Kaynağa Gözat

扩充节点属性

yufeng0528 4 yıl önce
ebeveyn
işleme
1e849e2652
1 değiştirilmiş dosya ile 30 ekleme ve 14 silme
  1. 30 14
      tree/my_tree.py

+ 30 - 14
tree/my_tree.py

@@ -4,23 +4,34 @@ from sklearn.datasets import load_wine
4 4
 from sklearn.model_selection import train_test_split
5 5
 import numpy as np
6 6
 
7
+feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮',
8
+                '非黄烷类酚类', '花青素', '颜色强度', '色调', 'od280/od315稀释葡萄酒', '脯氨酸']
9
+class_names=["琴酒", "雪莉", "贝尔摩德"]
10
+
7 11
 # 生成决策树的节点类型
8 12
 class TreeNode(object):
9
-    idx = ''                 # 属性
13
+    idx = 0                  # 属性
10 14
     idx_value = 0.0          # 属性值
11 15
     is_leaf = False
12 16
     y = 0                   # 预测值
13
-    next_list = []          # 分支
17
+    samples = 0             # 样本数
18
+    value = []              # 分布情况
19
+    left = None
20
+    right = None
14 21
 
15
-    def __init__(self, idx,idx_value, is_leaf, y, next_list):
22
+    def __init__(self, idx,idx_value, is_leaf, y, samples, value, left=None, right=None):
16 23
         self.idx = idx
17 24
         self.idx_value = idx_value
18 25
         self.is_leaf = is_leaf
19 26
         self.y = y
20
-        self.next_list = next_list
27
+        self.samples = samples
28
+        self.value = value
29
+        self.left = left
30
+        self.right = right
21 31
 
22 32
     def __str__(self):
23
-        return "%s,%f,%f,%d" % (self.idx, self.idx_value, self.is_leaf, self.y)
33
+        return "%s,%f, leaf:%d, samples:%d, value:%s, %s" % \
34
+               (feature_name[self.idx], self.idx_value, self.is_leaf,self.samples, self.value, class_names[self.y])
24 35
 
25 36
 
26 37
 def read_data():
@@ -101,6 +112,7 @@ def cal_ent_attr(Xtrain, Ytrain):
101 112
             min_mean = mean
102 113
     return min_i, min_mean
103 114
 
115
+
104 116
 MAX_T = 5
105 117
 
106 118
 
@@ -127,7 +139,7 @@ def leaf_node(Ytrain):
127 139
         if item[1] > max_item[1]:
128 140
             max_item = item
129 141
     print('这个是叶子节点,value:', max_item[0])
130
-    return TreeNode('', 0, True, max_item[0], [])
142
+    return TreeNode(0, 0, True, max_item[0], 0, 0)
131 143
 
132 144
 
133 145
 # def create_node()
@@ -137,7 +149,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
137 149
 
138 150
     if is_end(Ytrain):
139 151
         print('这个是叶子节点')
140
-        return TreeNode('', 0, True, 0, [])
152
+        return TreeNode(0, 0, True, 0, 0, 0)
141 153
 
142 154
     if depth > MAX_T:
143 155
         return leaf_node(Ytrain)
@@ -145,22 +157,22 @@ def fit(Xtrain, Ytrain, parent_node, depth):
145 157
     i, mean = cal_ent_attr(Xtrain, Ytrain)
146 158
     print("第", i, "个属性,mean:", mean)
147 159
     # 生成节点
148
-    parent_node = TreeNode(i, mean, False, 0, [])
160
+    parent_node = TreeNode(i, mean, False, 0, 0, 0)
149 161
 
150 162
     # 切分数据
151 163
     right_Ytrain = Ytrain[Xtrain[:, i] > mean]
152 164
     right_Xtrain = Xtrain[Xtrain[:, i] > mean]
153
-    right_Xtrain = np.delete(right_Xtrain, i, axis=1)
165
+    # right_Xtrain = np.delete(right_Xtrain, i, axis=1) # 这个属性还可以再被切分
154 166
 
155 167
     right_node = fit(right_Xtrain, right_Ytrain, parent_node, depth+1)
156 168
 
157 169
     left_Ytrain = Ytrain[Xtrain[:, i] <= mean]
158 170
     left_Xtrain = Xtrain[Xtrain[:, i] <= mean]
159
-    left_Xtrain = np.delete(left_Xtrain, i, axis=1)
171
+    # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
160 172
     left_node = fit(left_Xtrain, left_Ytrain, parent_node, depth + 1)
161 173
 
162
-    parent_node.next_list.append(left_node)
163
-    parent_node.next_list.append(right_node)
174
+    parent_node.left = left_node
175
+    parent_node.right = right_node
164 176
     return parent_node
165 177
 
166 178
 
@@ -172,11 +184,15 @@ def print_width(nodes, depth):
172 184
     node_down = []
173 185
     for node in nodes:
174 186
         print(node)
175
-        if len(node.next_list) > 0:
176
-            node_down.extend(node.next_list)
187
+        if node.left is not None:
188
+            node_down.append(node.left)
189
+        if node.right is not None:
190
+            node_down.append(node.right)
191
+
177 192
 
178 193
     print_width(node_down, depth+1)
179 194
 
195
+
180 196
 if __name__ == '__main__':
181 197
     Xtrain, Xtest, Ytrain, Ytest = read_data()
182 198
     print(calc_ent1(Ytrain))