yufeng0528 4 年前
コミット
7e4f7cb319
1 個のファイルを変更:16 個の追加、5 個の削除
  1. 16 5
      tree/my_tree.py

+ 16 - 5
tree/my_tree.py

@@ -19,7 +19,7 @@ class TreeNode(object):
19 19
     left = None
20 20
     right = None
21 21
 
22
-    def __init__(self, idx,idx_value, is_leaf, y, samples, value, left=None, right=None):
22
+    def __init__(self, idx, idx_value, is_leaf, y, samples, value, left=None, right=None):
23 23
         self.idx = idx
24 24
         self.idx_value = idx_value
25 25
         self.is_leaf = is_leaf
@@ -29,6 +29,10 @@ class TreeNode(object):
29 29
         self.left = left
30 30
         self.right = right
31 31
 
32
+        if self.y == -1:
33
+            self.y = np.where(value == np.max(value))[0][0]
34
+            print(self.y, self.value)
35
+
32 36
     def __str__(self):
33 37
         return "%s,%f, leaf:%d, samples:%d, value:%s, %s" % \
34 38
                (feature_name[self.idx], self.idx_value, self.is_leaf,self.samples, self.value, class_names[self.y])
@@ -139,17 +143,24 @@ def leaf_node(Ytrain):
139 143
         if item[1] > max_item[1]:
140 144
             max_item = item
141 145
     print('这个是叶子节点,value:', max_item[0])
142
-    return TreeNode(0, 0, True, max_item[0], 0, 0)
146
+    return TreeNode(0, 0, True, max_item[0], len(Ytrain), distrib(Ytrain))
147
+
143 148
 
149
+def distrib(Ytrain):
150
+    x_value_list = set([Ytrain[i] for i in range(Ytrain.shape[0])])
151
+    ent = 0.0
152
+    d_list = np.zeros(3, dtype=int)
153
+    for x_value in x_value_list:
154
+        d_list[x_value] = len([1 for i in Ytrain == x_value if i])
144 155
 
145
-# def create_node()
156
+    return d_list
146 157
 
147 158
 
148 159
 def fit(Xtrain, Ytrain, parent_node, depth):
149 160
 
150 161
     if is_end(Ytrain):
151 162
         print('这个是叶子节点')
152
-        return TreeNode(0, 0, True, 0, 0, 0)
163
+        return TreeNode(0, 0, True, -1, len(Ytrain), distrib(Ytrain))
153 164
 
154 165
     if depth > MAX_T:
155 166
         return leaf_node(Ytrain)
@@ -157,7 +168,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
157 168
     i, mean = cal_ent_attr(Xtrain, Ytrain)
158 169
     print("第", i, "个属性,mean:", mean)
159 170
     # 生成节点
160
-    parent_node = TreeNode(i, mean, False, 0, 0, 0)
171
+    parent_node = TreeNode(i, mean, False, -1, len(Ytrain), distrib(Ytrain))
161 172
 
162 173
     # 切分数据
163 174
     right_Ytrain = Ytrain[Xtrain[:, i] > mean]