
Expand node attributes

yufeng0528 committed 4 years ago
commit 1e849e2652
1 changed file with 30 additions and 14 deletions

tree/my_tree.py  (+30 -14)

@@ -4,23 +4,34 @@ from sklearn.datasets import load_wine
 from sklearn.model_selection import train_test_split
 import numpy as np
 
+feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮',
+                '非黄烷类酚类', '花青素', '颜色强度', '色调', 'od280/od315稀释葡萄酒', '脯氨酸']
+class_names = ["琴酒", "雪莉", "贝尔摩德"]
+
 # Node type used to build the decision tree
 class TreeNode(object):
-    idx = ''                 # feature index
+    idx = 0                  # feature index
     idx_value = 0.0          # split threshold for that feature
     is_leaf = False
     y = 0                   # predicted class
-    next_list = []          # child branches
+    samples = 0             # number of samples at this node
+    value = []              # class distribution at this node
+    left = None
+    right = None
 
-    def __init__(self, idx, idx_value, is_leaf, y, next_list):
+    def __init__(self, idx, idx_value, is_leaf, y, samples, value, left=None, right=None):
         self.idx = idx
         self.idx_value = idx_value
         self.is_leaf = is_leaf
         self.y = y
-        self.next_list = next_list
+        self.samples = samples
+        self.value = value
+        self.left = left
+        self.right = right
 
     def __str__(self):
-        return "%s,%f,%f,%d" % (self.idx, self.idx_value, self.is_leaf, self.y)
+        return "%s,%f, leaf:%d, samples:%d, value:%s, %s" % \
+               (feature_name[self.idx], self.idx_value, self.is_leaf, self.samples, self.value, class_names[self.y])
 
 
 def read_data():
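
The reworked __str__ prints a node with a human-readable feature name and class name instead of raw indices. A minimal sketch of what that looks like, using made-up threshold, sample-count, and distribution values (the commit itself still passes 0 for samples and value when it builds nodes in fit and leaf_node):

    node = TreeNode(0, 12.5, False, 0, 59, [41, 12, 6])
    print(node)
    # 酒精,12.500000, leaf:0, samples:59, value:[41, 12, 6], 琴酒
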
@@ -101,6 +112,7 @@ def cal_ent_attr(Xtrain, Ytrain):
             min_mean = mean
     return min_i, min_mean
 
+
 MAX_T = 5
 
 
@@ -127,7 +139,7 @@ def leaf_node(Ytrain):
         if item[1] > max_item[1]:
             max_item = item
     print('这个是叶子节点,value:', max_item[0])
-    return TreeNode('', 0, True, max_item[0], [])
+    return TreeNode(0, 0, True, max_item[0], 0, 0)
 
 
 # def create_node()
@@ -137,7 +149,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
 
     if is_end(Ytrain):
         print('这个是叶子节点')
-        return TreeNode('', 0, True, 0, [])
+        return TreeNode(0, 0, True, 0, 0, 0)
 
     if depth > MAX_T:
         return leaf_node(Ytrain)
@@ -145,22 +157,22 @@ def fit(Xtrain, Ytrain, parent_node, depth):
     i, mean = cal_ent_attr(Xtrain, Ytrain)
     print("第", i, "个属性,mean:", mean)
     # create the split node
-    parent_node = TreeNode(i, mean, False, 0, [])
+    parent_node = TreeNode(i, mean, False, 0, 0, 0)
 
     # split the data
     right_Ytrain = Ytrain[Xtrain[:, i] > mean]
     right_Xtrain = Xtrain[Xtrain[:, i] > mean]
-    right_Xtrain = np.delete(right_Xtrain, i, axis=1)
+    # right_Xtrain = np.delete(right_Xtrain, i, axis=1)  # keep the column so this feature can be split on again
 
     right_node = fit(right_Xtrain, right_Ytrain, parent_node, depth+1)
 
     left_Ytrain = Ytrain[Xtrain[:, i] <= mean]
     left_Xtrain = Xtrain[Xtrain[:, i] <= mean]
-    left_Xtrain = np.delete(left_Xtrain, i, axis=1)
+    # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
     left_node = fit(left_Xtrain, left_Ytrain, parent_node, depth + 1)
 
-    parent_node.next_list.append(left_node)
-    parent_node.next_list.append(right_node)
+    parent_node.left = left_node
+    parent_node.right = right_node
     return parent_node
 
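The splitting convention is unchanged: rows whose value for feature i exceeds the chosen mean go to the right child, the rest to the left. What changes is that the feature column is no longer deleted, so deeper levels may split on the same feature again, and the two children are now attached as left/right instead of being appended to next_list. A small standalone sketch of the masking, using hypothetical toy arrays rather than the wine data:

    import numpy as np

    X = np.array([[1.0, 5.0],
                  [2.0, 7.0],
                  [3.0, 6.0]])
    y = np.array([0, 1, 1])
    i, mean = 0, 1.5          # pretend cal_ent_attr chose feature 0 with threshold 1.5

    right_X, right_y = X[X[:, i] > mean], y[X[:, i] > mean]   # rows 2 and 3
    left_X, left_y = X[X[:, i] <= mean], y[X[:, i] <= mean]   # row 1 only
    # Column i is kept in both halves, so recursion can split on it again.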
 
@@ -172,11 +184,15 @@ def print_width(nodes, depth):
     node_down = []
     for node in nodes:
         print(node)
-        if len(node.next_list) > 0:
-            node_down.extend(node.next_list)
+        if node.left is not None:
+            node_down.append(node.left)
+        if node.right is not None:
+            node_down.append(node.right)
+
 
     print_width(node_down, depth+1)
 
+
 if __name__ == '__main__':
     Xtrain, Xtest, Ytrain, Ytest = read_data()
     print(calc_ent1(Ytrain))
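
With left and right children replacing next_list, driving the code after this commit would look roughly like the sketch below. It assumes the existing read_data, fit, and print_width functions; the starting depth and the exact calls in the repository's own __main__ block may differ:

    Xtrain, Xtest, Ytrain, Ytest = read_data()
    root = fit(Xtrain, Ytrain, None, 0)   # fit builds its own nodes, so the parent argument can be None
    print_width([root], 1)                # breadth-first dump using the new __str__ format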