Browse Source

test_data_path

yufeng0528 4 years ago
parent
commit
53ea33daf0
2 changed files with 2 additions and 2 deletions
  1. 1 1
      integr/random_forest.py
  2. 1 1
      linear/train.py

+ 1 - 1
integr/random_forest.py

@@ -41,7 +41,7 @@ def read_data(path):
41 41
 
42 42
 def demo():
43 43
     X_train, y_train = read_data(config.get('application', 'train_data_path'))
44
-    X_test, y_test = read_data(config.get('application', 'train_data_path'))
44
+    X_test, y_test = read_data(config.get('application', 'test_data_path'))
45 45
     Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
46 46
     rfc = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=10)
47 47
     rfc = rfc.fit(Xtrain, Ytrain)

+ 1 - 1
linear/train.py

@@ -34,7 +34,7 @@ def read_data(path):
34 34
 
35 35
 def demo():
36 36
     X_train,y_train=read_data(config.get('application', 'train_data_path'))
37
-    X_test,y_test=read_data(config.get('application', 'train_data_path'))
37
+    X_test,y_test=read_data(config.get('application', 'test_data_path'))
38 38
 
39 39
 	#一个对象,它代表的线性回归模型,它的成员变量,就已经有了w,b. 刚生成w和b的时候 是随机的
40 40
     model = LinearRegression()