|
@@ -131,6 +131,29 @@ def cal_ent_attr(Xtrain, Ytrain):
|
131
|
131
|
min_mean = mean
|
132
|
132
|
return min_i, min_mean, min_ent
|
133
|
133
|
|
|
134
|
+# Find the attribute whose mean-threshold binary split has the lowest weighted Gini index
|
|
135
|
def cal_gini_attr(Xtrain, Ytrain):
    """Choose the column whose binary split gives the lowest weighted score.

    For every column of ``Xtrain`` the threshold is the mean of that column's
    *distinct* values. The labels are partitioned into "greater than the
    threshold" and "less than or equal to the threshold", each partition is
    scored with ``cal_gini`` (defined elsewhere in this file; presumably the
    Gini impurity -- confirm), and the scores are weighted by the partition's
    share of the samples.

    Args:
        Xtrain: 2-D array indexable as ``Xtrain[:, i]`` with a ``shape``
            attribute (assumed to be a NumPy array -- confirm with caller).
        Ytrain: label array that supports boolean-mask indexing and ``len``.

    Returns:
        A ``(column_index, threshold, weighted_score)`` tuple for the best
        column. If no column scores below the initial sentinel of 100 the
        defaults ``(0, 0, 100)`` are returned.
    """
    print('sharp', Xtrain.shape)

    # Running best; 100 is the sentinel "no split found yet" score.
    best_col, best_threshold, best_score = 0, 0, 100

    # Evaluate every attribute (column).
    for col in range(Xtrain.shape[1]):
        # Threshold is the mean over the distinct values in this column.
        distinct = {row[col] for row in Xtrain}
        threshold = sum(distinct) / len(distinct)

        # Binary split: rows above the threshold, then the remainder.
        values = Xtrain[:, col]
        score = 0
        for mask in (values > threshold, values <= threshold):
            part = Ytrain[mask]
            score = score + cal_gini(part) * len(part) / len(Ytrain)

        # Strict comparison keeps the earliest column on ties.
        if score < best_score:
            best_score, best_col, best_threshold = score, col, threshold

    return best_col, best_threshold, best_score
|
134
|
157
|
|
135
|
158
|
# Depth limit for tree growth: fit() returns a leaf node once depth >= MAX_T.
MAX_T = 5
|
136
|
159
|
|
|
@@ -180,7 +203,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
|
180
|
203
|
if depth >= MAX_T:
|
181
|
204
|
return leaf_node(Ytrain)
|
182
|
205
|
|
183
|
|
- i, mean, min_ent = cal_ent_attr(Xtrain, Ytrain)
|
|
206
|
+ i, mean, min_ent = cal_gini_attr(Xtrain, Ytrain)
|
184
|
207
|
total_ent = calc_ent(Ytrain)
|
185
|
208
|
print("第", i, "个属性,mean:", mean)
|
186
|
209
|
# 生成节点
|