my_tree.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import numpy as np

# Chinese display names for the wine dataset's 13 features, padded with
# dummy single-letter names — the wine data has exactly 13 columns, so the
# padding looks unnecessary; TODO confirm nothing indexes past 12.
feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮',
                '非黄烷类酚类', '花青素', '颜色强度', '色调', 'od280/od315稀释葡萄酒', '脯氨酸',
                'A', 'B', 'c', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T']
# Chinese display labels for the three wine cultivars (class ids 0..2).
class_names = ["琴酒", "雪莉", "贝尔摩德"]
  10. # 生成决策树的节点类型
  11. class TreeNode(object):
  12. idx = 0 # 属性
  13. idx_value = 0.0 # 属性值
  14. ent = 0.0 # 信息增益
  15. is_leaf = False
  16. y = 0 # 预测值
  17. samples = 0 # 样本数
  18. value = [] # 分布情况
  19. left = None
  20. right = None
  21. def __init__(self, idx, idx_value, ent, is_leaf, y, samples, value, left=None, right=None):
  22. self.idx = idx
  23. self.idx_value = idx_value
  24. self.is_leaf = is_leaf
  25. self.ent = ent
  26. self.y = y
  27. self.samples = samples
  28. self.value = value
  29. self.left = left
  30. self.right = right
  31. if self.y is None:
  32. self.y = np.where(value == np.max(value))[0][0] ## TODO
  33. # print(self.y, self.value)
  34. def __str__(self):
  35. if self.idx == -1:
  36. return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
  37. ('', self.idx_value, self.ent, self.is_leaf,self.samples, self.value, class_names[self.y])
  38. else:
  39. return "%s,%f, ent:%f, leaf:%d, samples:%d, value:%s, %s" % \
  40. (feature_name[self.idx], self.idx_value, self.ent, self.is_leaf, self.samples, self.value, class_names[self.y])
  41. def read_data():
  42. wine = load_wine()
  43. print(wine.data.shape) # 178*13
  44. print(wine.target)
  45. print(wine.feature_names)
  46. print(wine.target_names)
  47. # 如果wine是一张表,应该长这样:
  48. import pandas as pd
  49. # pdata = pd.concat([pd.DataFrame(wine.data), pd.DataFrame(wine.target)], axis=1)
  50. Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
  51. return Xtrain, Xtest, Ytrain, Ytest
  52. def calc_ent(x, weights=None):
  53. """
  54. calculate shanno ent of x
  55. """
  56. x_value_list = set([x[i] for i in range(x.shape[0])])
  57. ent = 0.0
  58. for x_value in x_value_list:
  59. if weights is None:
  60. p = float(x[x == x_value].shape[0]) / x.shape[0]
  61. else:
  62. weights = weights/sum(weights)
  63. p = sum(sum([x == x_value]*weights))
  64. logp = np.log2(p)
  65. ent -= p * logp
  66. return ent
  67. def cal_gini(x):
  68. x_value_list = set([x[i] for i in range(x.shape[0])])
  69. ent = 0.0
  70. for x_value in x_value_list:
  71. p = float(x[x == x_value].shape[0]) / x.shape[0]
  72. ent += p*p
  73. return 1-ent
  74. def calc_ent1(x):
  75. """
  76. calculate shanno ent of x
  77. """
  78. p_set = []
  79. for item in x:
  80. for i in p_set:
  81. if i[0] == item:
  82. i[1] = i[1] + 1
  83. break
  84. else:
  85. i = [item, 1]
  86. p_set.append(i)
  87. pro_list = []
  88. size = len(x)
  89. for item in p_set:
  90. pro_list.append(item[1]/size)
  91. ent = 0.0
  92. for p in pro_list:
  93. logp = np.log2(p)
  94. ent -= p * logp
  95. return ent
  96. # 计算某个属性的信息增益
  97. def cal_ent_attr(Xtrain, Ytrain, weights):
  98. # print('sharp', Xtrain.shape)
  99. weights = weights / sum(weights)
  100. # 对每个属性
  101. min_ent = 100
  102. min_i = 0
  103. min_mean = 0
  104. for i in range(Xtrain.shape[1]):
  105. x_value_list = set([Xtrain[j][i] for j in range(Xtrain.shape[0])])
  106. mean = sum(x_value_list)/len(x_value_list)
  107. sum_ent = 0
  108. # 二叉树
  109. p = Ytrain[Xtrain[:, i] > mean]
  110. p0 = sum(weights[Xtrain[:, i] > mean])
  111. sum_ent = sum_ent + calc_ent(p, weights[Xtrain[:, i] > mean])*p0
  112. p = Ytrain[Xtrain[:, i] <= mean]
  113. sum_ent = sum_ent + calc_ent(p, weights[Xtrain[:, i] <= mean])*(1-p0)
  114. if sum_ent <= min_ent:
  115. min_ent = sum_ent
  116. min_i = i
  117. min_mean = mean
  118. return min_i, min_mean, min_ent
  119. def cal_max_ent_attr_c45(Xtrain, Ytrain, weights=None):
  120. max_ent = 0
  121. max_mean = 0
  122. weights = weights / sum(weights)
  123. h = calc_ent(Ytrain)
  124. for k in range(len(Xtrain) - 1):
  125. left = Xtrain[:k + 1]
  126. right = Xtrain[k + 1:]
  127. if weights is None:
  128. left_ent = calc_ent(Ytrain[:k+1])*len(left)/len(Ytrain)
  129. right_ent = calc_ent(Ytrain[k + 1:])*len(right)/len(Ytrain)
  130. iv = -len(left) / len(Ytrain) * np.log2(len(left) / len(Ytrain))
  131. iv -= len(right) / len(Ytrain) * np.log2(len(right) / len(Ytrain))
  132. else:
  133. p = sum(weights[:k+1])
  134. left_ent = calc_ent(Ytrain[:k + 1], weights[:k+1]) * p
  135. right_ent = calc_ent(Ytrain[k + 1:], weights[k+1:]) * (1-p)
  136. iv = -p * np.log2(p)
  137. iv -= (1-p) * np.log2(1-p)
  138. gain_ent = (h - left_ent - right_ent)/iv
  139. if gain_ent > max_ent:
  140. max_ent = gain_ent
  141. max_mean = left[-1]
  142. return max_ent, max_mean
# Per-sample weights (module-level placeholder; __main__ rebuilds a uniform
# numpy weight vector before fitting — presumably this list itself is never
# read, TODO confirm).
weights = []
  145. # 计算某个属性的信息增益率
  146. def cal_ent_attr_c45(Xtrain, Ytrain, weights):
  147. # 对每个属性
  148. max_ent = 0
  149. max_i = 0
  150. max_mean = 0
  151. weights = weights / sum(weights)
  152. for i in range(Xtrain.shape[1]): #每个属性
  153. argsort = Xtrain[:,i].argsort()
  154. x,y,w = Xtrain[:,i][argsort], Ytrain[argsort], weights[argsort]
  155. gain_ent, mean = cal_max_ent_attr_c45(x, y, w)
  156. if gain_ent > max_ent:
  157. max_ent = gain_ent
  158. max_i = i
  159. max_mean = mean
  160. return max_i, max_mean, max_ent
  161. # 计算某个属性的基尼指数
  162. def cal_gini_attr(Xtrain, Ytrain):
  163. # print('sharp', Xtrain.shape)
  164. # 对每个属性
  165. min_ent = 100
  166. min_i = 0
  167. min_mean = 0
  168. for i in range(Xtrain.shape[1]):
  169. x_value_list = set([Xtrain[j][i] for j in range(Xtrain.shape[0])])
  170. mean = sum(x_value_list)/len(x_value_list)
  171. sum_ent = 0
  172. # 二叉树
  173. p = Ytrain[Xtrain[:, i] > mean]
  174. sum_ent = sum_ent + cal_gini(p)*len(p)/len(Ytrain)
  175. p = Ytrain[Xtrain[:, i] <= mean]
  176. sum_ent = sum_ent + cal_gini(p)*len(p)/len(Ytrain)
  177. if sum_ent < min_ent:
  178. min_ent = sum_ent
  179. min_i = i
  180. min_mean = mean
  181. return min_i, min_mean, min_ent
MAX_T = 1  # maximum tree depth: fit() turns any node at this depth into a leaf
  183. def is_end(Ytrain):
  184. if len(Ytrain) == 0:
  185. return True
  186. if len(set(Ytrain)) == 1: # 只有一个分类
  187. return True
  188. # 强行划分为叶子节点
  189. def leaf_node(Ytrain, weights):
  190. p_set = []
  191. k = 0
  192. for item in Ytrain:
  193. for i in p_set:
  194. if i[0] == item:
  195. i[1] = i[1] + weights[k]
  196. break
  197. else:
  198. i = [item, weights[k]]
  199. p_set.append(i)
  200. k = k + 1
  201. max_item = [0, 0]
  202. for item in p_set:
  203. if item[1] > max_item[1]:
  204. max_item = item
  205. # print('这个是叶子节点,value:', max_item[0])
  206. return TreeNode(-1, 0, 0, True, max_item[0], len(Ytrain), distrib(Ytrain))
  207. def distrib(Ytrain):
  208. x_value_list = set([Ytrain[i] for i in range(Ytrain.shape[0])])
  209. ent = 0.0
  210. d_list = np.zeros(3, dtype=int)
  211. for x_value in x_value_list:
  212. d_list[x_value] = len([1 for i in Ytrain == x_value if i])
  213. return d_list
def fit(Xtrain, Ytrain, parent_node, depth, weights):
    """Recursively grow a binary decision tree; returns the subtree root.

    :param Xtrain: sample matrix for this subtree
    :param Ytrain: labels for this subtree
    :param parent_node: NOTE(review): never read — it is overwritten below
        before use; kept only for interface compatibility
    :param depth: current depth; reaching MAX_T forces a leaf
    :param weights: per-sample weights, forwarded to the split search and
        to leaf construction
    """
    if is_end(Ytrain):
        # Pure or empty node: nothing left to split.
        return leaf_node(Ytrain, weights)
    if depth >= MAX_T:
        # Depth limit reached: force a leaf.
        return leaf_node(Ytrain, weights)
    i, mean, min_ent = cal_ent_attr_c45(Xtrain, Ytrain, weights)
    total_ent = calc_ent(Ytrain)
    # print("attribute", i, "mean:", mean)
    # Internal node; y=-2 marks "no prediction" for non-leaves.
    parent_node = TreeNode(i, mean, total_ent - min_ent, False, -2, len(Ytrain), distrib(Ytrain))
    # Split the data on column i at the threshold: right = strictly greater.
    right_Ytrain = Ytrain[Xtrain[:, i] > mean]
    right_Xtrain = Xtrain[Xtrain[:, i] > mean]
    # right_Xtrain = np.delete(right_Xtrain, i, axis=1)  # the attribute may be split again
    right_node = fit(right_Xtrain, right_Ytrain, parent_node, depth+1, weights[Xtrain[:, i] > mean])
    left_Ytrain = Ytrain[Xtrain[:, i] <= mean]
    left_Xtrain = Xtrain[Xtrain[:, i] <= mean]
    # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
    left_node = fit(left_Xtrain, left_Ytrain, parent_node, depth + 1, weights[Xtrain[:, i] <= mean])
    parent_node.left = left_node
    parent_node.right = right_node
    return parent_node
  237. def print_width(nodes, depth):
  238. if len(nodes) == 0:
  239. return
  240. print("--第", depth, "层--")
  241. node_down = []
  242. for node in nodes:
  243. print(node)
  244. if node.left is not None:
  245. node_down.append(node.left)
  246. if node.right is not None:
  247. node_down.append(node.right)
  248. print_width(node_down, depth+1)
  249. def predit_one(X, Y, node):
  250. if node.is_leaf:
  251. # print(class_names[node.y], class_names[Y])
  252. if node.y == 0:
  253. return -1
  254. return node.y
  255. else:
  256. if X[node.idx] <= node.idx_value:
  257. return predit_one(X,Y,node.left)
  258. else:
  259. return predit_one(X, Y, node.right)
  260. def predict(Xtest, Ytest, node):
  261. result = []
  262. for i in range(Xtest.shape[0]):
  263. result.append(predit_one(Xtest[i], None, node))
  264. return np.array(result)
  265. if __name__ == '__main__':
  266. Xtrain, Xtest, Ytrain, Ytest = read_data()
  267. print(calc_ent1(Ytrain))
  268. weights = np.ones(len(Ytrain))/Ytrain.shape[0]
  269. print("熵值", calc_ent(Ytrain))
  270. print("熵值", calc_ent(Ytrain, weights))
  271. print("基尼指数", cal_gini(Ytrain))
  272. print("信息增益率", cal_ent_attr_c45(Xtrain, Ytrain))
  273. node = fit(Xtrain, Ytrain, None, 0, weights)
  274. print_width([node], 1)
  275. print(predict(Xtest, Ytest, node))