example.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. from sklearn import tree
  4. from sklearn.datasets import load_wine
  5. from sklearn.model_selection import train_test_split
  6. import numpy
  7. import graphviz
  8. wine = load_wine()
  9. print(wine.data.shape) #178*13
  10. print(wine.target)
  11. #如果wine是一张表,应该长这样:
  12. import pandas as pd
  13. pdata = pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1)
  14. print(wine.feature_names)
  15. print(wine.target_names)
  16. Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
  17. numpy.savetxt("foo.csv", Xtrain, delimiter=",")
  18. clf = tree.DecisionTreeClassifier(criterion="entropy", max_features=1, max_depth=1)#实例化,criterion不写的话默认是基尼系数
  19. # clf.n_features_ = 2
  20. clf = clf.fit(Xtrain, Ytrain)
  21. score = clf.score(Xtest, Ytest) #返回预测的准确度
  22. print("score:", score)
  23. feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮',
  24. '非黄烷类酚类', '花青素', '颜色强度', '色调', 'od280/od315稀释葡萄酒', '脯氨酸']
  25. dot_data = tree.export_graphviz(clf
  26. # ,out_file = None
  27. , feature_names=feature_name
  28. , class_names=["琴酒", "雪莉", "贝尔摩德"]
  29. , filled=True # 让树的每一块有颜色,颜色越浅,表示不纯度越高
  30. , rounded=True # 树的块的形状
  31. )
  32. dot_data = dot_data.replace('helvetica', '"Microsoft YaHei"')
  33. graph = graphviz.Source(dot_data)
  34. graph.render("Tree1")
  35. graph # graph.view()