Browse Source

随机森林

yufeng0528 4 years ago
parent
commit
151be3f0fb
1 changed files with 61 additions and 0 deletions
  1. 61 0
      integr/random_forest.py

+ 61 - 0
integr/random_forest.py

@@ -0,0 +1,61 @@
1
+
2
+# -*- encoding:utf-8 -*-
3
+from sklearn.tree import DecisionTreeClassifier
4
+from sklearn.ensemble import RandomForestClassifier
5
+from sklearn.ensemble import RandomForestRegressor
6
+from sklearn.datasets import load_wine
7
+from sklearn.model_selection import train_test_split
8
+import numpy as np
9
+from sklearn.tree import DecisionTreeRegressor
10
+from sklearn import metrics
11
+
12
+'''
13
+参数	含义
14
+criterion	不纯度的衡量指标,有基尼系数和信息熵两种选择
15
+max_depth	树的最大深度,超过最大深度的树枝都会被剪掉
16
+min_samples_leaf	一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生
17
+min_samples_split	一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生
18
+max_features	max_features限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃,默认值为总特征个数开平方取整
19
+min_impurity_decrease	限制信息增益的大小,信息增益小于设定数值的分枝不会发生
20
+'''
21
+def demo_wine():
22
+    wine = load_wine()
23
+    #print wine.data
24
+    #print wine.target
25
+    Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
26
+    rfc = RandomForestClassifier(random_state=0,n_estimators=10)
27
+    rfc = rfc.fit(Xtrain,Ytrain)
28
+    print(rfc.score(Xtest,Ytest))
29
+
30
+
31
+def read_data(path):
32
+    with open(path) as f :
33
+        lines=f.readlines()
34
+    lines=[eval(line.strip()) for line in lines]
35
+    X,y=zip(*lines)
36
+    X=np.array(X)
37
+    y=np.array(y)
38
+    return X,y
39
+
40
+
41
+def demo():
42
+    X_train, y_train = read_data("../bbztx/train_data")
43
+    X_test, y_test = read_data("../bbztx/test_data")
44
+    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
45
+    rfc = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=10)
46
+    rfc = rfc.fit(Xtrain, Ytrain)
47
+    print(rfc.score(Xtest, Ytest))
48
+
49
+    print(rfc.predict(X_test))
50
+    # print(rfc.score(Xtest, y_test))
51
+
52
+    y_pred_test = rfc.predict(X_test)
53
+    test_mse = metrics.mean_squared_error(y_test, y_pred_test)
54
+    print("测试集MSE:", test_mse)
55
+
56
+
57
+if __name__ == '__main__':
58
+	demo()
59
+
60
+
61
+