|
@@ -44,6 +44,14 @@ class TreeNode(object):
|
44
|
44
|
return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
|
45
|
45
|
(feature_name[self.idx], self.idx_value, self.ent, self.is_leaf, self.samples, self.value, class_names[self.y])
|
46
|
46
|
|
|
47
|
+ def export(self, _feature_name, _class_names):
|
|
48
|
+ if self.idx == -1:
|
|
49
|
+ return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
|
|
50
|
+ ('', self.idx_value, self.ent, self.is_leaf,self.samples, self.value, _class_names[self.y])
|
|
51
|
+ else:
|
|
52
|
+ return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
|
|
53
|
+ (_feature_name[self.idx], self.idx_value, self.ent, self.is_leaf, self.samples, self.value, _class_names[self.y])
|
|
54
|
+
|
47
|
55
|
|
48
|
56
|
def read_data():
|
49
|
57
|
wine = load_wine()
|
|
@@ -259,21 +267,21 @@ def distrib(Ytrain):
|
259
|
267
|
return d_list
|
260
|
268
|
|
261
|
269
|
|
262
|
|
-def print_width(nodes, depth):
|
|
270
|
+def print_width(nodes, depth, _feature_name, _class_names):
|
263
|
271
|
if len(nodes) == 0:
|
264
|
272
|
return
|
265
|
273
|
|
266
|
274
|
print("--第", depth, "层--")
|
267
|
275
|
node_down = []
|
268
|
276
|
for node in nodes:
|
269
|
|
- print(node)
|
|
277
|
+ print(node.export(_feature_name, _class_names))
|
270
|
278
|
if node.left is not None:
|
271
|
279
|
node_down.append(node.left)
|
272
|
280
|
if node.right is not None:
|
273
|
281
|
node_down.append(node.right)
|
274
|
282
|
|
275
|
283
|
|
276
|
|
- print_width(node_down, depth+1)
|
|
284
|
+ print_width(node_down, depth+1, _feature_name, _class_names)
|
277
|
285
|
|
278
|
286
|
|
279
|
287
|
def predit_one(X, Y, node):
|
|
@@ -353,6 +361,12 @@ class MyDT(object):
|
353
|
361
|
result = self.predict(Xtest)
|
354
|
362
|
return sum(result == Ytest)/Ytest.shape[0]
|
355
|
363
|
|
|
364
|
+ @staticmethod
|
|
365
|
+ def export(tree, feature_names, class_names):
|
|
366
|
+ nodes = tree.root_node
|
|
367
|
+ print_width([nodes], 1, feature_names, class_names)
|
|
368
|
+
|
|
369
|
+
|
356
|
370
|
if __name__ == '__main__':
|
357
|
371
|
Xtrain, Xtest, Ytrain, Ytest = read_data()
|
358
|
372
|
print(calc_ent1(Ytrain))
|
|
@@ -365,11 +379,13 @@ if __name__ == '__main__':
|
365
|
379
|
|
366
|
380
|
print("信息增益率", cal_ent_attr_c45(Xtrain, Ytrain, weights))
|
367
|
381
|
|
368
|
|
- clf = MyDT(criterion="entropy", max_depth=3,)
|
|
382
|
+ clf = MyDT(criterion="entropy", max_depth=1,)
|
369
|
383
|
clf.fit(Xtrain, Ytrain, weights)
|
370
|
384
|
|
371
|
385
|
# print_width([node], 1)
|
372
|
386
|
|
373
|
387
|
print(clf.predict(Xtest))
|
374
|
388
|
|
375
|
|
- print(clf.score(Xtest, Ytest))
|
|
389
|
+ print(clf.score(Xtest, Ytest))
|
|
390
|
+ print(clf.score(Xtrain, Ytrain))
|
|
391
|
+ MyDT.export(clf, feature_name, class_names)
|