Browse Source

完善类

yufeng0528 4 years ago
parent
commit
cfaf9b76b6
1 changed file with 21 additions and 5 deletions
  1. 21 5
      tree/my_tree.py

+ 21 - 5
tree/my_tree.py

@@ -44,6 +44,14 @@ class TreeNode(object):
44 44
             return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
45 45
                    (feature_name[self.idx], self.idx_value, self.ent, self.is_leaf, self.samples, self.value, class_names[self.y])
46 46
 
47
+    def export(self, _feature_name, _class_names):
48
+        if self.idx == -1:
49
+            return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
50
+               ('', self.idx_value, self.ent, self.is_leaf,self.samples, self.value, _class_names[self.y])
51
+        else:
52
+            return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
53
+                   (_feature_name[self.idx], self.idx_value, self.ent, self.is_leaf, self.samples, self.value, _class_names[self.y])
54
+
47 55
 
48 56
 def read_data():
49 57
     wine = load_wine()
@@ -259,21 +267,21 @@ def distrib(Ytrain):
259 267
     return d_list
260 268
 
261 269
 
262
-def print_width(nodes, depth):
270
+def print_width(nodes, depth, _feature_name, _class_names):
263 271
     if len(nodes) == 0:
264 272
         return
265 273
 
266 274
     print("--第", depth, "层--")
267 275
     node_down = []
268 276
     for node in nodes:
269
-        print(node)
277
+        print(node.export(_feature_name, _class_names))
270 278
         if node.left is not None:
271 279
             node_down.append(node.left)
272 280
         if node.right is not None:
273 281
             node_down.append(node.right)
274 282
 
275 283
 
276
-    print_width(node_down, depth+1)
284
+    print_width(node_down, depth+1, _feature_name, _class_names)
277 285
 
278 286
 
279 287
 def predit_one(X, Y, node):
@@ -353,6 +361,12 @@ class MyDT(object):
353 361
         result = self.predict(Xtest)
354 362
         return sum(result == Ytest)/Ytest.shape[0]
355 363
 
364
+    @staticmethod
365
+    def export(tree, feature_names, class_names):
366
+        nodes = tree.root_node
367
+        print_width([nodes], 1, feature_names, class_names)
368
+
369
+
356 370
 if __name__ == '__main__':
357 371
     Xtrain, Xtest, Ytrain, Ytest = read_data()
358 372
     print(calc_ent1(Ytrain))
@@ -365,11 +379,13 @@ if __name__ == '__main__':
365 379
 
366 380
     print("信息增益率", cal_ent_attr_c45(Xtrain, Ytrain, weights))
367 381
 
368
-    clf = MyDT(criterion="entropy", max_depth=3,)
382
+    clf = MyDT(criterion="entropy", max_depth=1,)
369 383
     clf.fit(Xtrain, Ytrain, weights)
370 384
 
371 385
     # print_width([node], 1)
372 386
 
373 387
     print(clf.predict(Xtest))
374 388
 
375
-    print(clf.score(Xtest, Ytest))
389
+    print(clf.score(Xtest, Ytest))
390
+    print(clf.score(Xtrain, Ytrain))
391
+    MyDT.export(clf, feature_name, class_names)