|
@@ -131,6 +131,47 @@ def cal_ent_attr(Xtrain, Ytrain):
|
131
|
131
|
min_mean = mean
|
132
|
132
|
return min_i, min_mean, min_ent
|
133
|
133
|
|
|
134
|
+
|
|
135
|
+def cal_max_ent_attr_c45(Xtrain, Ytrain):
|
|
136
|
+ max_ent = 0
|
|
137
|
+ max_mean = 0
|
|
138
|
+ h = calc_ent(Ytrain)
|
|
139
|
+ for k in range(len(Xtrain) - 1):
|
|
140
|
+ left = Xtrain[:k + 1]
|
|
141
|
+ right = Xtrain[k + 1:]
|
|
142
|
+
|
|
143
|
+ left_ent = calc_ent(Ytrain[:k+1])*len(left)/len(Ytrain)
|
|
144
|
+ right_ent = calc_ent(Ytrain[k + 1:])*len(right)/len(Ytrain)
|
|
145
|
+
|
|
146
|
+ iv = -len(left) / len(Ytrain) * np.log2(len(left) / len(Ytrain))
|
|
147
|
+ iv -= len(right) / len(Ytrain) * np.log2(len(right) / len(Ytrain))
|
|
148
|
+
|
|
149
|
+ gain_ent = (h - left_ent - right_ent)/iv
|
|
150
|
+
|
|
151
|
+ if gain_ent > max_ent:
|
|
152
|
+ max_ent = gain_ent
|
|
153
|
+ max_mean = left[-1]
|
|
154
|
+ return max_ent, max_mean
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+# 计算某个属性的信息增益率
|
|
158
|
+def cal_ent_attr_c45(Xtrain, Ytrain):
|
|
159
|
+ # 对每个属性
|
|
160
|
+ max_ent = 0
|
|
161
|
+ max_i = 0
|
|
162
|
+ max_mean = 0
|
|
163
|
+ for i in range(Xtrain.shape[1]): #每个属性
|
|
164
|
+ argsort = Xtrain[:,i].argsort()
|
|
165
|
+ x,y = Xtrain[:,i][argsort], Ytrain[argsort]
|
|
166
|
+
|
|
167
|
+ gain_ent, mean = cal_max_ent_attr_c45(x, y)
|
|
168
|
+
|
|
169
|
+ if gain_ent > max_ent:
|
|
170
|
+ max_ent = gain_ent
|
|
171
|
+ max_i = i
|
|
172
|
+ max_mean = mean
|
|
173
|
+ return max_i, max_mean, max_ent
|
|
174
|
+
|
134
|
175
|
# 计算某个属性的基尼指数
|
135
|
176
|
def cal_gini_attr(Xtrain, Ytrain):
|
136
|
177
|
print('sharp', Xtrain.shape)
|
|
@@ -203,7 +244,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
|
203
|
244
|
if depth >= MAX_T:
|
204
|
245
|
return leaf_node(Ytrain)
|
205
|
246
|
|
206
|
|
- i, mean, min_ent = cal_gini_attr(Xtrain, Ytrain)
|
|
247
|
+ i, mean, min_ent = cal_ent_attr_c45(Xtrain, Ytrain)
|
207
|
248
|
total_ent = calc_ent(Ytrain)
|
208
|
249
|
print("第", i, "个属性,mean:", mean)
|
209
|
250
|
# 生成节点
|
|
@@ -265,7 +306,7 @@ if __name__ == '__main__':
|
265
|
306
|
|
266
|
307
|
print("基尼指数", cal_gini(Ytrain))
|
267
|
308
|
|
268
|
|
- print(cal_ent_attr(Xtrain, Ytrain))
|
|
309
|
+ print("信息增益率", cal_ent_attr_c45(Xtrain, Ytrain))
|
269
|
310
|
|
270
|
311
|
node = fit(Xtrain, Ytrain, None, 0)
|
271
|
312
|
print_width([node], 1)
|