Browse Source

增加配置

yufeng0528 2 years ago
parent
commit
5803ba24ab
3 changed files with 9 additions and 8 deletions
  1. 3 2
      integr/random_forest.py
  2. 3 3
      linear/train.py
  3. 3 3
      tree/train.py

+ 3 - 2
integr/random_forest.py

@@ -6,6 +6,7 @@ from sklearn.ensemble import RandomForestRegressor
6 6
 from sklearn.datasets import load_wine
7 7
 from sklearn.model_selection import train_test_split
8 8
 import numpy as np
9
+from util.config import config
9 10
 from sklearn.tree import DecisionTreeRegressor
10 11
 from sklearn import metrics
11 12
 
@@ -39,8 +40,8 @@ def read_data(path):
39 40
 
40 41
 
41 42
 def demo():
42
-    X_train, y_train = read_data("../bbztx/train_data")
43
-    X_test, y_test = read_data("../bbztx/test_data")
43
+    X_train, y_train = read_data(config.get('application', 'train_data_path'))
44
+    X_test, y_test = read_data(config.get('application', 'train_data_path'))
44 45
     Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
45 46
     rfc = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=10)
46 47
     rfc = rfc.fit(Xtrain, Ytrain)

+ 3 - 3
linear/train.py

@@ -6,7 +6,7 @@
6 6
 import sys
7 7
 import os
8 8
 sys.path.append(os.path.abspath('..'))
9
-
9
+from util.config import config
10 10
 import numpy as np
11 11
 from sklearn.linear_model import LinearRegression
12 12
 from sklearn import metrics
@@ -33,8 +33,8 @@ def read_data(path):
33 33
 
34 34
 
35 35
 def demo():
36
-    X_train,y_train=read_data("../bbztx/train_data")
37
-    X_test,y_test=read_data("../bbztx/test_data")
36
+    X_train,y_train=read_data(config.get('application', 'train_data_path'))
37
+    X_test,y_test=read_data(config.get('application', 'train_data_path'))
38 38
 
39 39
 	#一个对象,它代表的线性回归模型,它的成员变量,就已经有了w,b. 刚生成w和b的时候 是随机的
40 40
     model = LinearRegression()

+ 3 - 3
tree/train.py

@@ -6,10 +6,10 @@
6 6
 import sys
7 7
 import os
8 8
 sys.path.append(os.path.abspath('..'))
9
-
10 9
 import numpy as np
11 10
 from sklearn.tree import DecisionTreeRegressor
12 11
 from sklearn import metrics
12
+from util.config import config
13 13
 
14 14
 def read_data(path):
15 15
     with open(path) as f :
@@ -22,8 +22,8 @@ def read_data(path):
22 22
 
23 23
 
24 24
 def demo():
25
-    X_train,y_train=read_data("../bbztx/train_data")
26
-    X_test,y_test=read_data("../bbztx/test_data")
25
+    X_train,y_train=read_data(config.get('application', 'train_data_path'))
26
+    X_test,y_test=read_data(config.get('application', 'test_data_path'))
27 27
 
28 28
     dt1 = DecisionTreeRegressor(max_depth=10)
29 29
     dt1.fit(X_train, y_train)