Browse Source

print(y_pred_test)

yufeng0528 4 years ago
parent
commit
4ea27ba0f5
1 changed files with 1 additions and 1 deletions
  1. 1 1
      linear/train.py

+ 1 - 1
linear/train.py

@@ -55,6 +55,7 @@ def demo():
55 55
 
56 56
     #看下在测试集上的效果
57 57
     y_pred_test = model.predict(X_test)
58
+    print(y_pred_test)
58 59
     test_mse = metrics.mean_squared_error(y_test, y_pred_test)
59 60
     print("测试集MSE:",test_mse)
60 61
 
@@ -81,7 +82,6 @@ def draw_line():
81 82
     print(y_train.tolist())
82 83
     draw_util.drawScatter(x_train.tolist(), y_train.tolist())
83 84
 
84
-
85 85
 if __name__ == '__main__':
86 86
 	# draw_line()
87 87
 	p, q, w,b = demo()