|
@@ -47,7 +47,7 @@ class TreeNode(object):
|
47
|
47
|
|
48
|
48
|
def read_data():
|
49
|
49
|
wine = load_wine()
|
50
|
|
- print(wine.data.shape) # 178*13
|
|
50
|
+ print("数组结构", wine.data.shape) # 178*13
|
51
|
51
|
print(wine.target)
|
52
|
52
|
print(wine.feature_names)
|
53
|
53
|
print(wine.target_names)
|
|
@@ -148,7 +148,7 @@ def cal_max_ent_attr_c45(Xtrain, Ytrain, weights=None):
|
148
|
148
|
weights = weights / sum(weights)
|
149
|
149
|
h = calc_ent(Ytrain, weights)
|
150
|
150
|
p = 0
|
151
|
|
- for k in range(len(Xtrain) - 1):
|
|
151
|
+ for k in range(0, len(Xtrain) - 1, 3):
|
152
|
152
|
left = Xtrain[:k + 1]
|
153
|
153
|
right = Xtrain[k + 1:]
|
154
|
154
|
|
|
@@ -259,40 +259,6 @@ def distrib(Ytrain):
|
259
|
259
|
return d_list
|
260
|
260
|
|
261
|
261
|
|
262
|
|
-def fit(Xtrain, Ytrain, parent_node, depth, weights):
|
263
|
|
-
|
264
|
|
- if is_end(Ytrain):
|
265
|
|
- # print('这个是叶子节点')
|
266
|
|
- return leaf_node(Ytrain, weights)
|
267
|
|
-
|
268
|
|
- if depth >= MAX_T:
|
269
|
|
- return leaf_node(Ytrain, weights)
|
270
|
|
-
|
271
|
|
- i, mean, min_ent = cal_ent_attr(Xtrain, Ytrain, weights)
|
272
|
|
- total_ent = 0 # calc_ent(Ytrain)
|
273
|
|
- # print("第", i, "个属性,mean:", mean)
|
274
|
|
- # 生成节点
|
275
|
|
- parent_node = TreeNode(i, mean, total_ent - min_ent, False, -2, len(Ytrain), distrib(Ytrain))
|
276
|
|
-
|
277
|
|
- # 切分数据
|
278
|
|
- right_position = Xtrain[:, i] > mean
|
279
|
|
- right_Ytrain = Ytrain[right_position]
|
280
|
|
- right_Xtrain = Xtrain[right_position]
|
281
|
|
- # right_Xtrain = np.delete(right_Xtrain, i, axis=1) # 这个属性还可以再被切分
|
282
|
|
-
|
283
|
|
- right_node = fit(right_Xtrain, right_Ytrain, parent_node, depth+1, weights[right_position])
|
284
|
|
-
|
285
|
|
- left_position = Xtrain[:, i] <= mean
|
286
|
|
- left_Ytrain = Ytrain[left_position]
|
287
|
|
- left_Xtrain = Xtrain[left_position]
|
288
|
|
- # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
|
289
|
|
- left_node = fit(left_Xtrain, left_Ytrain, parent_node, depth + 1, weights[left_position])
|
290
|
|
-
|
291
|
|
- parent_node.left = left_node
|
292
|
|
- parent_node.right = right_node
|
293
|
|
- return parent_node
|
294
|
|
-
|
295
|
|
-
|
296
|
262
|
def print_width(nodes, depth):
|
297
|
263
|
if len(nodes) == 0:
|
298
|
264
|
return
|
|
@@ -313,8 +279,8 @@ def print_width(nodes, depth):
|
313
|
279
|
def predit_one(X, Y, node):
|
314
|
280
|
if node.is_leaf:
|
315
|
281
|
# print(class_names[node.y], class_names[Y])
|
316
|
|
- if node.y == 0:
|
317
|
|
- return -1
|
|
282
|
+ # if node.y == 0:
|
|
283
|
+ # return -1
|
318
|
284
|
return node.y
|
319
|
285
|
else:
|
320
|
286
|
if X[node.idx] <= node.idx_value:
|
|
@@ -323,12 +289,69 @@ def predit_one(X, Y, node):
|
323
|
289
|
return predit_one(X, Y, node.right)
|
324
|
290
|
|
325
|
291
|
|
326
|
|
-def predict(Xtest, Ytest, node):
|
327
|
|
- result = []
|
328
|
|
- for i in range(Xtest.shape[0]):
|
329
|
|
- result.append(predit_one(Xtest[i], None, node))
|
330
|
|
- return np.array(result)
|
|
292
|
+class MyDT(object):
|
|
293
|
+
|
|
294
|
+ criterion = None
|
|
295
|
+ max_depth = None
|
331
|
296
|
|
|
297
|
+ root_node = None
|
|
298
|
+
|
|
299
|
+ def __init__(self, criterion, max_depth):
|
|
300
|
+ self.criterion = criterion
|
|
301
|
+ self.max_depth = max_depth
|
|
302
|
+
|
|
303
|
+ def fit(self, Xtrain, Ytrain, sample_weight=None):
|
|
304
|
+ if sample_weight is None:
|
|
305
|
+ sample_weight = np.ones(Ytrain.shape[0]) / Ytrain.shape[0]
|
|
306
|
+ self.root_node = self.do_fit(Xtrain, Ytrain, 0, sample_weight)
|
|
307
|
+
|
|
308
|
+ def do_fit(self, Xtrain, Ytrain, depth, weights):
|
|
309
|
+
|
|
310
|
+ if is_end(Ytrain):
|
|
311
|
+ # print('这个是叶子节点')
|
|
312
|
+ return leaf_node(Ytrain, weights)
|
|
313
|
+
|
|
314
|
+ if depth >= self.max_depth:
|
|
315
|
+ return leaf_node(Ytrain, weights)
|
|
316
|
+
|
|
317
|
+ if self.criterion == 'entropy':
|
|
318
|
+ i, mean, min_ent = cal_ent_attr(Xtrain, Ytrain, weights)
|
|
319
|
+ elif self.criterion == 'C4.5':
|
|
320
|
+ i, mean, min_ent = cal_ent_attr_c45(Xtrain, Ytrain, weights)
|
|
321
|
+ else:
|
|
322
|
+ i, mean, min_ent = cal_gini_attr(Xtrain, Ytrain, weights)
|
|
323
|
+ total_ent = 0 # calc_ent(Ytrain)
|
|
324
|
+ # print("第", i, "个属性,mean:", mean)
|
|
325
|
+ # 生成节点
|
|
326
|
+ parent_node = TreeNode(i, mean, total_ent - min_ent, False, None, len(Ytrain), distrib(Ytrain))
|
|
327
|
+
|
|
328
|
+ # 切分数据
|
|
329
|
+ right_position = Xtrain[:, i] > mean
|
|
330
|
+ right_Ytrain = Ytrain[right_position]
|
|
331
|
+ right_Xtrain = Xtrain[right_position]
|
|
332
|
+ # right_Xtrain = np.delete(right_Xtrain, i, axis=1) # 这个属性还可以再被切分
|
|
333
|
+
|
|
334
|
+ right_node = self.do_fit(right_Xtrain, right_Ytrain, depth + 1, weights[right_position])
|
|
335
|
+
|
|
336
|
+ left_position = Xtrain[:, i] <= mean
|
|
337
|
+ left_Ytrain = Ytrain[left_position]
|
|
338
|
+ left_Xtrain = Xtrain[left_position]
|
|
339
|
+ # left_Xtrain = np.delete(left_Xtrain, i, axis=1)
|
|
340
|
+ left_node = self.do_fit(left_Xtrain, left_Ytrain, depth + 1, weights[left_position])
|
|
341
|
+
|
|
342
|
+ parent_node.left = left_node
|
|
343
|
+ parent_node.right = right_node
|
|
344
|
+ return parent_node
|
|
345
|
+
|
|
346
|
+ def predict(self, Xtest):
|
|
347
|
+ result = []
|
|
348
|
+ for i in range(Xtest.shape[0]):
|
|
349
|
+ result.append(predit_one(Xtest[i], None, self.root_node))
|
|
350
|
+ return np.array(result)
|
|
351
|
+
|
|
352
|
+ def score(self, Xtest, Ytest):
|
|
353
|
+ result = self.predict(Xtest)
|
|
354
|
+ return sum(result == Ytest)/Ytest.shape[0]
|
332
|
355
|
|
333
|
356
|
if __name__ == '__main__':
|
334
|
357
|
Xtrain, Xtest, Ytrain, Ytest = read_data()
|
|
@@ -342,7 +365,11 @@ if __name__ == '__main__':
|
342
|
365
|
|
343
|
366
|
print("信息增益率", cal_ent_attr_c45(Xtrain, Ytrain, weights))
|
344
|
367
|
|
345
|
|
- node = fit(Xtrain, Ytrain, None, 0, weights)
|
346
|
|
- print_width([node], 1)
|
|
368
|
+ clf = MyDT(criterion="entropy", max_depth=3,)
|
|
369
|
+ clf.fit(Xtrain, Ytrain, weights)
|
|
370
|
+
|
|
371
|
+ # print_width([node], 1)
|
|
372
|
+
|
|
373
|
+ print(clf.predict(Xtest))
|
347
|
374
|
|
348
|
|
- print(predict(Xtest, Ytest, node))
|
|
375
|
+ print(clf.score(Xtest, Ytest))
|