yufeng0528 пре 4 година
родитељ
комит
f95423fc8e
1 измењених фајлова са 73 додато и 46 уклоњено
  1. 73 46
      tree/my_tree.py

+ 73 - 46
tree/my_tree.py

@@ -47,7 +47,7 @@ class TreeNode(object):
47 47
 
48 48
 def read_data():
49 49
     wine = load_wine()
50
-    print(wine.data.shape)  # 178*13
50
+    print("数组结构", wine.data.shape)  # 178*13
51 51
     print(wine.target)
52 52
     print(wine.feature_names)
53 53
     print(wine.target_names)
@@ -148,7 +148,7 @@ def cal_max_ent_attr_c45(Xtrain, Ytrain, weights=None):
148 148
     weights = weights / sum(weights)
149 149
     h = calc_ent(Ytrain, weights)
150 150
     p = 0
151
-    for k in range(len(Xtrain) - 1):
151
+    for k in range(0, len(Xtrain) - 1, 3):
152 152
         left = Xtrain[:k + 1]
153 153
         right = Xtrain[k + 1:]
154 154
 
@@ -259,40 +259,6 @@ def distrib(Ytrain):
259 259
     return d_list
260 260
 
261 261
 
262
-def fit(Xtrain, Ytrain, parent_node, depth, weights):
263
-
264
-    if is_end(Ytrain):
265
-        # print('这个是叶子节点')
266
-        return leaf_node(Ytrain, weights)
267
-
268
-    if depth >= MAX_T:
269
-        return leaf_node(Ytrain, weights)
270
-
271
-    i, mean, min_ent = cal_ent_attr(Xtrain, Ytrain, weights)
272
-    total_ent = 0 # calc_ent(Ytrain)
273
-    # print("第", i, "个属性,mean:", mean)
274
-    # 生成节点
275
-    parent_node = TreeNode(i, mean, total_ent - min_ent, False, -2, len(Ytrain), distrib(Ytrain))
276
-
277
-    # 切分数据
278
-    right_position = Xtrain[:, i] > mean
279
-    right_Ytrain = Ytrain[right_position]
280
-    right_Xtrain = Xtrain[right_position]
281
-    # right_Xtrain = np.delete(right_Xtrain, i, axis=1) # 这个属性还可以再被切分
282
-
283
-    right_node = fit(right_Xtrain, right_Ytrain, parent_node, depth+1, weights[right_position])
284
-
285
-    left_position = Xtrain[:, i] <= mean
286
-    left_Ytrain = Ytrain[left_position]
287
-    left_Xtrain = Xtrain[left_position]
288
-    # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
289
-    left_node = fit(left_Xtrain, left_Ytrain, parent_node, depth + 1, weights[left_position])
290
-
291
-    parent_node.left = left_node
292
-    parent_node.right = right_node
293
-    return parent_node
294
-
295
-
296 262
 def print_width(nodes, depth):
297 263
     if len(nodes) == 0:
298 264
         return
@@ -313,8 +279,8 @@ def print_width(nodes, depth):
313 279
 def predit_one(X, Y, node):
314 280
     if node.is_leaf:
315 281
         # print(class_names[node.y], class_names[Y])
316
-        if node.y == 0:
317
-            return -1
282
+        # if node.y == 0:
283
+        #     return -1
318 284
         return node.y
319 285
     else:
320 286
         if X[node.idx] <= node.idx_value:
@@ -323,12 +289,69 @@ def predit_one(X, Y, node):
323 289
             return predit_one(X, Y, node.right)
324 290
 
325 291
 
326
-def predict(Xtest, Ytest, node):
327
-    result = []
328
-    for i in range(Xtest.shape[0]):
329
-        result.append(predit_one(Xtest[i], None, node))
330
-    return np.array(result)
292
+class MyDT(object):
293
+
294
+    criterion = None
295
+    max_depth = None
331 296
 
297
+    root_node = None
298
+
299
+    def __init__(self, criterion, max_depth):
300
+        self.criterion = criterion
301
+        self.max_depth = max_depth
302
+
303
+    def fit(self, Xtrain, Ytrain, sample_weight=None):
304
+        if sample_weight is None:
305
+            sample_weight = np.ones(Ytrain.shape[0]) / Ytrain.shape[0]
306
+        self.root_node = self.do_fit(Xtrain, Ytrain, 0, sample_weight)
307
+
308
+    def do_fit(self, Xtrain, Ytrain, depth, weights):
309
+
310
+        if is_end(Ytrain):
311
+            # print('这个是叶子节点')
312
+            return leaf_node(Ytrain, weights)
313
+
314
+        if depth >= self.max_depth:
315
+            return leaf_node(Ytrain, weights)
316
+
317
+        if self.criterion == 'entropy':
318
+            i, mean, min_ent = cal_ent_attr(Xtrain, Ytrain, weights)
319
+        elif self.criterion == 'C4.5':
320
+            i, mean, min_ent = cal_ent_attr_c45(Xtrain, Ytrain, weights)
321
+        else:
322
+            i, mean, min_ent = cal_gini_attr(Xtrain, Ytrain, weights)
323
+        total_ent = 0  # calc_ent(Ytrain)
324
+        # print("第", i, "个属性,mean:", mean)
325
+        # 生成节点
326
+        parent_node = TreeNode(i, mean, total_ent - min_ent, False, None, len(Ytrain), distrib(Ytrain))
327
+
328
+        # 切分数据
329
+        right_position = Xtrain[:, i] > mean
330
+        right_Ytrain = Ytrain[right_position]
331
+        right_Xtrain = Xtrain[right_position]
332
+        # right_Xtrain = np.delete(right_Xtrain, i, axis=1) # 这个属性还可以再被切分
333
+
334
+        right_node = self.do_fit(right_Xtrain, right_Ytrain, depth + 1, weights[right_position])
335
+
336
+        left_position = Xtrain[:, i] <= mean
337
+        left_Ytrain = Ytrain[left_position]
338
+        left_Xtrain = Xtrain[left_position]
339
+        # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
340
+        left_node = self.do_fit(left_Xtrain, left_Ytrain, depth + 1, weights[left_position])
341
+
342
+        parent_node.left = left_node
343
+        parent_node.right = right_node
344
+        return parent_node
345
+
346
+    def predict(self, Xtest):
347
+        result = []
348
+        for i in range(Xtest.shape[0]):
349
+            result.append(predit_one(Xtest[i], None, self.root_node))
350
+        return np.array(result)
351
+
352
+    def score(self, Xtest, Ytest):
353
+        result = self.predict(Xtest)
354
+        return sum(result == Ytest)/Ytest.shape[0]
332 355
 
333 356
 if __name__ == '__main__':
334 357
     Xtrain, Xtest, Ytrain, Ytest = read_data()
@@ -342,7 +365,11 @@ if __name__ == '__main__':
342 365
 
343 366
     print("信息增益率", cal_ent_attr_c45(Xtrain, Ytrain, weights))
344 367
 
345
-    node = fit(Xtrain, Ytrain, None, 0, weights)
346
-    print_width([node], 1)
368
+    clf = MyDT(criterion="entropy", max_depth=3,)
369
+    clf.fit(Xtrain, Ytrain, weights)
370
+
371
+    # print_width([node], 1)
372
+
373
+    print(clf.predict(Xtest))
347 374
 
348
-    print(predict(Xtest, Ytest, node))
375
+    print(clf.score(Xtest, Ytest))