Browse Source

和sklearn的比较下

yufeng0528 4 years ago
parent
commit
8144e8526e
2 changed files with 23 additions and 3 deletions
  1. 17 2
      integr/my_gbdt.py
  2. 6 1
      integr/my_gbdt_classic.py

+ 17 - 2
integr/my_gbdt.py

@@ -31,7 +31,7 @@ def fit(Xtrain, Ytrain):
31 31
         # 求残差
32 32
         gx = gx - fx0
33 33
         print("第", i, '轮 残差', gx[:10])
34
-        clf = tree.DecisionTreeRegressor(criterion="mse", max_features=5, max_depth=10)
34
+        clf = tree.DecisionTreeRegressor(criterion="mse", max_features=5, max_depth=10, random_state=10)
35 35
         clf.fit(Xtrain, gx)
36 36
         trees.append(clf)
37 37
 
@@ -74,7 +74,22 @@ def score(Xtest, Ytest, trees, fx0):
74 74
         sum = sum + (gx[i] - Ytest[i]) ** 2
75 75
     print("test mse0", sum / Ytest.shape[0])
76 76
 
77
+
77 78
 if __name__ == '__main__':
78 79
     Xtrain, Xtest, Ytrain, Ytest = read_data()
79 80
     trees, fx0 = fit(Xtrain, Ytrain)
80
-    score(Xtest, Ytest, trees, fx0)
81
+    score(Xtest, Ytest, trees, fx0)
82
+
83
+    gbm2 = GradientBoostingRegressor(n_estimators=55, max_depth=10, learning_rate=0.7,
84
+                                     max_features='sqrt', random_state=10)
85
+    gbm2.fit(Xtrain, Ytrain)  # 分数越高越好
86
+    print("gbdt1", gbm2.score(Xtest, Ytest))
87
+
88
+    gx = gbm2.predict(Xtest)
89
+    sum = 0
90
+    for i in range(Ytest.shape[0]):
91
+        sum = sum + (gx[i] - Ytest[i]) ** 2
92
+    print(gx[:10])
93
+    print(Ytest[:10])
94
+    print("gbdt mse", sum / Ytest.shape[0])
95
+

+ 6 - 1
integr/my_gbdt_classic.py

@@ -87,4 +87,9 @@ def score(Xtest, Ytest, trees, fx0):
87 87
 if __name__ == '__main__':
88 88
     Xtrain, Xtest, Ytrain, Ytest = read_data()
89 89
     trees,fx0 = fit(Xtrain, Ytrain)
90
-    score(Xtest, Ytest, trees, fx0)
90
+    score(Xtest, Ytest, trees, fx0)
91
+
92
+    gbm1 = GradientBoostingClassifier(n_estimators=10, max_depth=1, learning_rate=0.7,
93
+                                      max_features='sqrt', random_state=10)
94
+    gbm1.fit(Xtrain, Ytrain)
95
+    print("gbdt", gbm1.score(Xtest, Ytest))