Browse Source

实现c4.5算法

yufeng0528 4 years ago
parent
commit
b576a4ad2a
1 changed file with 43 additions and 2 deletions
  1. 43 2
      tree/my_tree.py

+ 43 - 2
tree/my_tree.py

@@ -131,6 +131,47 @@ def cal_ent_attr(Xtrain, Ytrain):
131 131
             min_mean = mean
132 132
     return min_i, min_mean, min_ent
133 133
 
134
+
135
def cal_max_ent_attr_c45(Xtrain, Ytrain):
    """Find the best binary split of a single attribute by gain ratio (C4.5).

    Every cut position k divides the samples into ``[:k + 1]`` (left) and
    ``[k + 1:]`` (right).  For each cut the information gain
    ``H(Y) - H(Y | split)`` is divided by the split information ``iv``.

    Args:
        Xtrain: 1-D sequence with the values of one attribute.  The cut
            positions are only meaningful when the values are in ascending
            order -- presumably guaranteed by the caller; verify at new
            call sites.
        Ytrain: labels aligned with ``Xtrain``.

    Returns:
        ``(max_ent, max_mean)``: the highest gain ratio found and the
        threshold that produced it (the largest value placed in the left
        part).  Both are 0 when no cut has a gain ratio greater than 0.
    """
    max_ent = 0
    max_mean = 0
    total = len(Ytrain)
    h = calc_ent(Ytrain)
    for k in range(len(Xtrain) - 1):
        # A cut between two equal values cannot be expressed as a threshold:
        # any test on the value sends both samples to the same side.  Scoring
        # it would report a gain ratio for a partition that can never be
        # built, so only cuts between distinct neighbours are candidates.
        if Xtrain[k] == Xtrain[k + 1]:
            continue

        n_left = k + 1
        n_right = len(Xtrain) - n_left

        # Entropy of each side, weighted by its share of the samples.
        left_ent = calc_ent(Ytrain[:n_left]) * n_left / total
        right_ent = calc_ent(Ytrain[n_left:]) * n_right / total

        # Split information; both sides are non-empty here, so iv > 0.
        p_left = n_left / total
        p_right = n_right / total
        iv = -p_left * np.log2(p_left)
        iv -= p_right * np.log2(p_right)

        gain_ent = (h - left_ent - right_ent) / iv

        if gain_ent > max_ent:
            max_ent = gain_ent
            max_mean = Xtrain[k]
    return max_ent, max_mean
155
+
156
+
157
# Select the attribute with the highest information gain ratio.
def cal_ent_attr_c45(Xtrain, Ytrain):
    """Return ``(index, threshold, gain_ratio)`` of the best attribute.

    Each column of ``Xtrain`` is sorted with ``argsort`` and ``Ytrain`` is
    reordered the same way before being scored by ``cal_max_ent_attr_c45``.
    A column replaces the current best only when its gain ratio is strictly
    greater, so ``(0, 0, 0)`` is returned when no column scores above 0.
    """
    best = (0, 0, 0)  # (column index, threshold, gain ratio)
    for col in range(Xtrain.shape[1]):
        order = Xtrain[:, col].argsort()
        sorted_values = Xtrain[:, col][order]
        sorted_labels = Ytrain[order]

        ratio, threshold = cal_max_ent_attr_c45(sorted_values, sorted_labels)

        if ratio > best[2]:
            best = (col, threshold, ratio)
    return best
174
+
134 175
 # 计算某个属性的基尼指数
135 176
 def cal_gini_attr(Xtrain, Ytrain):
136 177
     print('sharp', Xtrain.shape)
@@ -203,7 +244,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
203 244
     if depth >= MAX_T:
204 245
         return leaf_node(Ytrain)
205 246
 
206
-    i, mean, min_ent = cal_gini_attr(Xtrain, Ytrain)
247
+    i, mean, min_ent = cal_ent_attr_c45(Xtrain, Ytrain)
207 248
     total_ent = calc_ent(Ytrain)
208 249
     print("第", i, "个属性,mean:", mean)
209 250
     # 生成节点
@@ -265,7 +306,7 @@ if __name__ == '__main__':
265 306
 
266 307
     print("基尼指数", cal_gini(Ytrain))
267 308
 
268
-    print(cal_ent_attr(Xtrain, Ytrain))
309
+    print("信息增益率", cal_ent_attr_c45(Xtrain, Ytrain))
269 310
 
270 311
     node = fit(Xtrain, Ytrain, None, 0)
271 312
     print_width([node], 1)