yufeng0528 4 vuotta sitten
vanhempi
commit
2f08126ed5
1 muutettua tiedostoa jossa 2 lisäystä ja 2 poistoa
  1. 2 2
      tree/my_tree.py

+ 2 - 2
tree/my_tree.py

@@ -127,7 +127,7 @@ MAX_T = 5
127 127
 def is_end(Ytrain):
128 128
     if len(Ytrain) == 0:
129 129
         return True
130
-    if len(set(Ytrain)) == 1:
130
+    if len(set(Ytrain)) == 1: # 只有一个分类
131 131
         return True
132 132
 
133 133
 # 强行划分为叶子节点
@@ -166,7 +166,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
166 166
         print('这个是叶子节点')
167 167
         return TreeNode(-1, 0, True, -1, len(Ytrain), distrib(Ytrain))
168 168
 
169
-    if depth > MAX_T:
169
+    if depth >= MAX_T:
170 170
         return leaf_node(Ytrain)
171 171
 
172 172
     i, mean = cal_ent_attr(Xtrain, Ytrain)