|
@@ -4,23 +4,34 @@ from sklearn.datasets import load_wine
|
4
|
4
|
from sklearn.model_selection import train_test_split
|
5
|
5
|
import numpy as np
|
6
|
6
|
|
|
7
|
+feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮',
|
|
8
|
+ '非黄烷类酚类', '花青素', '颜色强度', '色调', 'od280/od315稀释葡萄酒', '脯氨酸']
|
|
9
|
+class_names=["琴酒", "雪莉", "贝尔摩德"]
|
|
10
|
+
|
7
|
11
|
# 生成决策树的节点类型
|
8
|
12
|
class TreeNode(object):
|
9
|
|
- idx = '' # 属性
|
|
13
|
+ idx = 0 # 属性
|
10
|
14
|
idx_value = 0.0 # 属性值
|
11
|
15
|
is_leaf = False
|
12
|
16
|
y = 0 # 预测值
|
13
|
|
- next_list = [] # 分支
|
|
17
|
+ samples = 0 # 样本数
|
|
18
|
+ value = [] # 分布情况
|
|
19
|
+ left = None
|
|
20
|
+ right = None
|
14
|
21
|
|
15
|
|
- def __init__(self, idx,idx_value, is_leaf, y, next_list):
|
|
22
|
+ def __init__(self, idx,idx_value, is_leaf, y, samples, value, left=None, right=None):
|
16
|
23
|
self.idx = idx
|
17
|
24
|
self.idx_value = idx_value
|
18
|
25
|
self.is_leaf = is_leaf
|
19
|
26
|
self.y = y
|
20
|
|
- self.next_list = next_list
|
|
27
|
+ self.samples = samples
|
|
28
|
+ self.value = value
|
|
29
|
+ self.left = left
|
|
30
|
+ self.right = right
|
21
|
31
|
|
22
|
32
|
def __str__(self):
|
23
|
|
- return "%s,%f,%f,%d" % (self.idx, self.idx_value, self.is_leaf, self.y)
|
|
33
|
+ return "%s,%f, leaf:%d, samples:%d, value:%s, %s" % \
|
|
34
|
+ (feature_name[self.idx], self.idx_value, self.is_leaf,self.samples, self.value, class_names[self.y])
|
24
|
35
|
|
25
|
36
|
|
26
|
37
|
def read_data():
|
|
@@ -101,6 +112,7 @@ def cal_ent_attr(Xtrain, Ytrain):
|
101
|
112
|
min_mean = mean
|
102
|
113
|
return min_i, min_mean
|
103
|
114
|
|
|
115
|
+
|
104
|
116
|
MAX_T = 5
|
105
|
117
|
|
106
|
118
|
|
|
@@ -127,7 +139,7 @@ def leaf_node(Ytrain):
|
127
|
139
|
if item[1] > max_item[1]:
|
128
|
140
|
max_item = item
|
129
|
141
|
print('这个是叶子节点,value:', max_item[0])
|
130
|
|
- return TreeNode('', 0, True, max_item[0], [])
|
|
142
|
+ return TreeNode(0, 0, True, max_item[0], 0, 0)
|
131
|
143
|
|
132
|
144
|
|
133
|
145
|
# def create_node()
|
|
@@ -137,7 +149,7 @@ def fit(Xtrain, Ytrain, parent_node, depth):
|
137
|
149
|
|
138
|
150
|
if is_end(Ytrain):
|
139
|
151
|
print('这个是叶子节点')
|
140
|
|
- return TreeNode('', 0, True, 0, [])
|
|
152
|
+ return TreeNode(0, 0, True, 0, 0, 0)
|
141
|
153
|
|
142
|
154
|
if depth > MAX_T:
|
143
|
155
|
return leaf_node(Ytrain)
|
|
@@ -145,22 +157,22 @@ def fit(Xtrain, Ytrain, parent_node, depth):
|
145
|
157
|
i, mean = cal_ent_attr(Xtrain, Ytrain)
|
146
|
158
|
print("第", i, "个属性,mean:", mean)
|
147
|
159
|
# 生成节点
|
148
|
|
- parent_node = TreeNode(i, mean, False, 0, [])
|
|
160
|
+ parent_node = TreeNode(i, mean, False, 0, 0, 0)
|
149
|
161
|
|
150
|
162
|
# 切分数据
|
151
|
163
|
right_Ytrain = Ytrain[Xtrain[:, i] > mean]
|
152
|
164
|
right_Xtrain = Xtrain[Xtrain[:, i] > mean]
|
153
|
|
- right_Xtrain = np.delete(right_Xtrain, i, axis=1)
|
|
165
|
+ # right_Xtrain = np.delete(right_Xtrain, i, axis=1) # 这个属性还可以再被切分
|
154
|
166
|
|
155
|
167
|
right_node = fit(right_Xtrain, right_Ytrain, parent_node, depth+1)
|
156
|
168
|
|
157
|
169
|
left_Ytrain = Ytrain[Xtrain[:, i] <= mean]
|
158
|
170
|
left_Xtrain = Xtrain[Xtrain[:, i] <= mean]
|
159
|
|
- left_Xtrain = np.delete(left_Xtrain, i, axis=1)
|
|
171
|
+ # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
|
160
|
172
|
left_node = fit(left_Xtrain, left_Ytrain, parent_node, depth + 1)
|
161
|
173
|
|
162
|
|
- parent_node.next_list.append(left_node)
|
163
|
|
- parent_node.next_list.append(right_node)
|
|
174
|
+ parent_node.left = left_node
|
|
175
|
+ parent_node.right = right_node
|
164
|
176
|
return parent_node
|
165
|
177
|
|
166
|
178
|
|
|
@@ -172,11 +184,15 @@ def print_width(nodes, depth):
|
172
|
184
|
node_down = []
|
173
|
185
|
for node in nodes:
|
174
|
186
|
print(node)
|
175
|
|
- if len(node.next_list) > 0:
|
176
|
|
- node_down.extend(node.next_list)
|
|
187
|
+ if node.left is not None:
|
|
188
|
+ node_down.append(node.left)
|
|
189
|
+ if node.right is not None:
|
|
190
|
+ node_down.append(node.right)
|
|
191
|
+
|
177
|
192
|
|
178
|
193
|
print_width(node_down, depth+1)
|
179
|
194
|
|
|
195
|
+
|
180
|
196
|
if __name__ == '__main__':
|
181
|
197
|
Xtrain, Xtest, Ytrain, Ytest = read_data()
|
182
|
198
|
print(calc_ent1(Ytrain))
|