|
@@ -19,7 +19,7 @@ class TreeNode(object):
|
19
|
19
|
left = None
|
20
|
20
|
right = None
|
21
|
21
|
|
22
|
|
- def __init__(self, idx,idx_value, is_leaf, y, samples, value, left=None, right=None):
|
|
22
|
+ def __init__(self, idx, idx_value, is_leaf, y, samples, value, left=None, right=None):
|
23
|
23
|
self.idx = idx
|
24
|
24
|
self.idx_value = idx_value
|
25
|
25
|
self.is_leaf = is_leaf
|
|
@@ -29,6 +29,10 @@ class TreeNode(object):
|
29
|
29
|
self.left = left
|
30
|
30
|
self.right = right
|
31
|
31
|
|
|
32
|
+ if self.y == -1:
|
|
33
|
+ self.y = np.where(value == np.max(value))[0][0]
|
|
34
|
+ print(self.y, self.value)
|
|
35
|
+
|
32
|
36
|
def __str__(self):
    """Return a one-line, human-readable summary of this node.

    The summary lists, in order: the name of the split feature, the split
    threshold, the leaf flag, the sample count, the stored ``value`` and
    the class name for ``self.y``.

    ``feature_name`` and ``class_names`` are not defined in this method;
    they are resolved from the enclosing module scope and are indexed by
    ``self.idx`` and ``self.y`` respectively.
    """
    template = "%s,%f, leaf:%d, samples:%d, value:%s, %s"
    split_feature = feature_name[self.idx]
    predicted_class = class_names[self.y]
    fields = (
        split_feature,
        self.idx_value,
        self.is_leaf,
        self.samples,
        self.value,
        predicted_class,
    )
    return template % fields
|
|
@@ -139,17 +143,24 @@ def leaf_node(Ytrain):
|
139
|
143
|
if item[1] > max_item[1]:
|
140
|
144
|
max_item = item
|
141
|
145
|
print('这个是叶子节点,value:', max_item[0])
|
142
|
|
- return TreeNode(0, 0, True, max_item[0], 0, 0)
|
|
146
|
+ return TreeNode(0, 0, True, max_item[0], len(Ytrain), distrib(Ytrain))
|
|
147
|
+
|
143
|
148
|
|
|
149
|
def distrib(Ytrain, n_classes=3):
    """Count how many samples of each class label occur in ``Ytrain``.

    Args:
        Ytrain: 1-D NumPy array of non-negative integer class labels.
        n_classes: Minimum length of the returned vector. Defaults to 3,
            the size that was previously hard-coded, so existing callers
            get the same result as before.

    Returns:
        A 1-D integer array ``counts`` where ``counts[k]`` is the number
        of entries of ``Ytrain`` equal to ``k``. Its length is at least
        ``n_classes`` and grows when a larger label is present, instead
        of failing with an IndexError.

    Raises:
        ValueError: If ``Ytrain`` contains a negative label or is not
            one-dimensional (raised by ``np.bincount``).
        TypeError: If the labels cannot be safely cast to integers
            (raised by ``np.bincount``).
    """
    labels = np.asarray(Ytrain)
    # Handle the empty case explicitly so the result is an all-zero
    # integer vector regardless of the dtype of the empty input.
    if labels.size == 0:
        return np.zeros(n_classes, dtype=int)
    # One vectorised pass replaces the per-label Python loop.
    return np.bincount(labels, minlength=n_classes)
|
146
|
157
|
|
147
|
158
|
|
148
|
159
|
def fit(Xtrain, Ytrain, parent_node, depth):
|
149
|
160
|
|
150
|
161
|
if is_end(Ytrain):
|
151
|
162
|
print('这个是叶子节点')
|
152
|
|
- return TreeNode(0, 0, True, 0, 0, 0)
|
|
163
|
+ return TreeNode(0, 0, True, -1, len(Ytrain), distrib(Ytrain))
|
153
|
164
|
|
154
|
165
|
if depth > MAX_T:
|
155
|
166
|
return leaf_node(Ytrain)
|
|
@@ -157,7 +168,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
|
157
|
168
|
i, mean = cal_ent_attr(Xtrain, Ytrain)
|
158
|
169
|
print("第", i, "个属性,mean:", mean)
|
159
|
170
|
# 生成节点
|
160
|
|
- parent_node = TreeNode(i, mean, False, 0, 0, 0)
|
|
171
|
+ parent_node = TreeNode(i, mean, False, -1, len(Ytrain), distrib(Ytrain))
|
161
|
172
|
|
162
|
173
|
# 切分数据
|
163
|
174
|
right_Ytrain = Ytrain[Xtrain[:, i] > mean]
|