|
@@ -127,7 +127,7 @@ MAX_T = 5
|
127
|
127
|
def is_end(Ytrain):
|
128
|
128
|
if len(Ytrain) == 0:
|
129
|
129
|
return True
|
130
|
|
- if len(set(Ytrain)) == 1:
|
|
130
|
+ if len(set(Ytrain)) == 1: # 只有一个分类
|
131
|
131
|
return True
|
132
|
132
|
|
133
|
133
|
# 强行划分为叶子节点
|
|
@@ -166,7 +166,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
|
166
|
166
|
print('这个是叶子节点')
|
167
|
167
|
return TreeNode(-1, 0, True, -1, len(Ytrain), distrib(Ytrain))
|
168
|
168
|
|
169
|
|
- if depth > MAX_T:
|
|
169
|
+ if depth >= MAX_T:
|
170
|
170
|
return leaf_node(Ytrain)
|
171
|
171
|
|
172
|
172
|
i, mean = cal_ent_attr(Xtrain, Ytrain)
|