|
@@ -0,0 +1,41 @@
|
|
1
|
+#!/usr/bin/python
|
|
2
|
+# -*- coding: UTF-8 -*-
|
|
3
|
+'''
|
|
4
|
+最简单的mse
|
|
5
|
+'''
|
|
6
|
+import sys
|
|
7
|
+import os
|
|
8
|
+sys.path.append(os.path.abspath('..'))
|
|
9
|
+
|
|
10
|
+import numpy as np
|
|
11
|
+from sklearn.tree import DecisionTreeRegressor
|
|
12
|
+from sklearn import metrics
|
|
13
|
+
|
|
14
|
+def read_data(path):
|
|
15
|
+ with open(path) as f :
|
|
16
|
+ lines=f.readlines()
|
|
17
|
+ lines=[eval(line.strip()) for line in lines]
|
|
18
|
+ X,y=zip(*lines)
|
|
19
|
+ X=np.array(X)
|
|
20
|
+ y=np.array(y)
|
|
21
|
+ return X,y
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+def demo():
|
|
25
|
+ X_train,y_train=read_data("../bbztx/train_data")
|
|
26
|
+ X_test,y_test=read_data("../bbztx/test_data")
|
|
27
|
+
|
|
28
|
+ dt1 = DecisionTreeRegressor(max_depth=10)
|
|
29
|
+ dt1.fit(X_train, y_train)
|
|
30
|
+
|
|
31
|
+ y_pred_train = dt1.predict(X_train)
|
|
32
|
+ train_mse = metrics.mean_squared_error(y_train, y_pred_train)
|
|
33
|
+ print("训练集MSE:", train_mse)
|
|
34
|
+
|
|
35
|
+ y_pred_test = dt1.predict(X_test)
|
|
36
|
+ test_mse = metrics.mean_squared_error(y_test, y_pred_test)
|
|
37
|
+ print("测试集MSE:", test_mse)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+if __name__ == '__main__':
|
|
41
|
+ demo()
|