|
@@ -0,0 +1,61 @@
|
|
1
|
+
|
|
2
|
+# -*- encoding:utf-8 -*-
|
|
3
|
+from sklearn.tree import DecisionTreeClassifier
|
|
4
|
+from sklearn.ensemble import RandomForestClassifier
|
|
5
|
+from sklearn.ensemble import RandomForestRegressor
|
|
6
|
+from sklearn.datasets import load_wine
|
|
7
|
+from sklearn.model_selection import train_test_split
|
|
8
|
+import numpy as np
|
|
9
|
+from sklearn.tree import DecisionTreeRegressor
|
|
10
|
+from sklearn import metrics
|
|
11
|
+
|
|
12
|
+'''
|
|
13
|
+参数 含义
|
|
14
|
+criterion 不纯度的衡量指标,有基尼系数和信息熵两种选择
|
|
15
|
+max_depth 树的最大深度,超过最大深度的树枝都会被剪掉
|
|
16
|
+min_samples_leaf 一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生
|
|
17
|
+min_samples_split 一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生
|
|
18
|
+max_features max_features限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃,默认值为总特征个数开平方取整
|
|
19
|
+min_impurity_decrease 限制信息增益的大小,信息增益小于设定数值的分枝不会发生
|
|
20
|
+'''
|
|
21
|
+def demo_wine():
|
|
22
|
+ wine = load_wine()
|
|
23
|
+ #print wine.data
|
|
24
|
+ #print wine.target
|
|
25
|
+ Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
|
|
26
|
+ rfc = RandomForestClassifier(random_state=0,n_estimators=10)
|
|
27
|
+ rfc = rfc.fit(Xtrain,Ytrain)
|
|
28
|
+ print(rfc.score(Xtest,Ytest))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+def read_data(path):
|
|
32
|
+ with open(path) as f :
|
|
33
|
+ lines=f.readlines()
|
|
34
|
+ lines=[eval(line.strip()) for line in lines]
|
|
35
|
+ X,y=zip(*lines)
|
|
36
|
+ X=np.array(X)
|
|
37
|
+ y=np.array(y)
|
|
38
|
+ return X,y
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+def demo():
|
|
42
|
+ X_train, y_train = read_data("../bbztx/train_data")
|
|
43
|
+ X_test, y_test = read_data("../bbztx/test_data")
|
|
44
|
+ Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
|
|
45
|
+ rfc = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=10)
|
|
46
|
+ rfc = rfc.fit(Xtrain, Ytrain)
|
|
47
|
+ print(rfc.score(Xtest, Ytest))
|
|
48
|
+
|
|
49
|
+ print(rfc.predict(X_test))
|
|
50
|
+ # print(rfc.score(Xtest, y_test))
|
|
51
|
+
|
|
52
|
+ y_pred_test = rfc.predict(X_test)
|
|
53
|
+ test_mse = metrics.mean_squared_error(y_test, y_pred_test)
|
|
54
|
+ print("测试集MSE:", test_mse)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+if __name__ == '__main__':
|
|
58
|
+ demo()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
|