|
@@ -71,6 +71,15 @@ def calc_ent(x):
|
71
|
71
|
return ent
|
72
|
72
|
|
73
|
73
|
|
|
74
|
+def cal_gini(x):
|
|
75
|
+ x_value_list = set([x[i] for i in range(x.shape[0])])
|
|
76
|
+ ent = 0.0
|
|
77
|
+ for x_value in x_value_list:
|
|
78
|
+ p = float(x[x == x_value].shape[0]) / x.shape[0]
|
|
79
|
+ ent += p*p
|
|
80
|
+ return ent
|
|
81
|
+
|
|
82
|
+
|
74
|
83
|
def calc_ent1(x):
|
75
|
84
|
"""
|
76
|
85
|
calculate shanno ent of x
|
|
@@ -231,6 +240,8 @@ if __name__ == '__main__':
|
231
|
240
|
print(calc_ent1(Ytrain))
|
232
|
241
|
print(calc_ent(Ytrain))
|
233
|
242
|
|
|
243
|
+ print("基尼指数", cal_gini(Ytrain))
|
|
244
|
+
|
234
|
245
|
print(cal_ent_attr(Xtrain, Ytrain))
|
235
|
246
|
|
236
|
247
|
node = fit(Xtrain, Ytrain, None, 0)
|