Explorar o código

使用基尼指数进行属性划分

yufeng0528 hai 4 anos
pai
achega
aee3214dfc
Modificouse 1 ficheiro con 24 adicións e 1 borrado
  1. 24 1
      tree/my_tree.py

+ 24 - 1
tree/my_tree.py

@@ -131,6 +131,29 @@ def cal_ent_attr(Xtrain, Ytrain):
131 131
             min_mean = mean
132 132
     return min_i, min_mean, min_ent
133 133
 
134
+# 计算某个属性的基尼指数
135
+def cal_gini_attr(Xtrain, Ytrain):
136
+    print('sharp', Xtrain.shape)
137
+
138
+    # 对每个属性
139
+    min_ent = 100
140
+    min_i = 0
141
+    min_mean = 0
142
+    for i in range(Xtrain.shape[1]):
143
+        x_value_list = set([Xtrain[j][i] for j in range(Xtrain.shape[0])])
144
+        mean = sum(x_value_list)/len(x_value_list)
145
+        sum_ent = 0
146
+        # 二叉树
147
+        p = Ytrain[Xtrain[:, i] > mean]
148
+        sum_ent = sum_ent + cal_gini(p)*len(p)/len(Ytrain)
149
+        p = Ytrain[Xtrain[:, i] <= mean]
150
+        sum_ent = sum_ent + cal_gini(p)*len(p)/len(Ytrain)
151
+
152
+        if sum_ent < min_ent:
153
+            min_ent = sum_ent
154
+            min_i = i
155
+            min_mean = mean
156
+    return min_i, min_mean, min_ent
134 157
 
135 158
 MAX_T = 5
136 159
 
@@ -180,7 +203,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
180 203
     if depth >= MAX_T:
181 204
         return leaf_node(Ytrain)
182 205
 
183
-    i, mean, min_ent = cal_ent_attr(Xtrain, Ytrain)
206
+    i, mean, min_ent = cal_gini_attr(Xtrain, Ytrain)
184 207
     total_ent = calc_ent(Ytrain)
185 208
     print("第", i, "个属性,mean:", mean)
186 209
     # 生成节点