yufeng 1 year ago
parent
commit
e7fd939b04
1 changed files with 15 additions and 16 deletions
  1. 15 16
      linear/train_jqxx2.py

+ 15 - 16
linear/train_jqxx2.py

@@ -38,28 +38,27 @@ def read_data(path):
38 38
 
39 39
 def demo(file, model_file):
40 40
     X_train,y_train=read_data(file)
41
-    # X_test,y_test=read_data(config.get('application', 'test_data_path'))
42
-    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
41
+    # Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
43 42
     # 一个对象,它代表的线性回归模型,它的成员变量,就已经有了w,b. 刚生成w和b的时候 是随机的
44 43
     model = LogisticRegression()
45 44
     # 一调用这个函数,就会不停地找合适的w和b 直到误差最小
46
-    model.fit(Xtrain, Ytrain)
45
+    model.fit(X_train, y_train)
47 46
     # 打印W
48 47
     # print(model.coef_)
49 48
     # 打印b
50 49
     print(model.intercept_)
51 50
     # 模型已经训练完毕,用模型看下在训练集的表现
52
-    y_pred_train = model.predict(Xtrain)
53
-    # sklearn 求解训练集的mse
54
-    # y_train 在训练集上 真实的y值
55
-    # y_pred_train 通过模型预测出来的y值
56
-    # 计算  (y_train-y_pred_train)^2/n
57
-    train_mse = metrics.mean_squared_error(Ytrain, y_pred_train)
58
-    print("train准确率:", accuracy_score(y_pred_train, Ytrain))
59
-
60
-    # 看下在测试集上的效果
61
-    y_pred_test = model.predict(Xtest)
62
-    print("test准确率:", accuracy_score(y_pred_test, Ytest))
51
+    # y_pred_train = model.predict(Xtrain)
52
+    # # sklearn 求解训练集的mse
53
+    # # y_train 在训练集上 真实的y值
54
+    # # y_pred_train 通过模型预测出来的y值
55
+    # # 计算  (y_train-y_pred_train)^2/n
56
+    # train_mse = metrics.mean_squared_error(Ytrain, y_pred_train)
57
+    # print("train准确率:", accuracy_score(y_pred_train, Ytrain))
58
+    #
59
+    # # 看下在测试集上的效果
60
+    # y_pred_test = model.predict(Xtest)
61
+    # print("test准确率:", accuracy_score(y_pred_test, Ytest))
63 62
     # 保存模型
64 63
     joblib.dump(model,  model_file)
65 64
 
@@ -127,6 +126,6 @@ def demo_3(file, model_file):
127 126
 if __name__ == '__main__':
128 127
     root_dir = 'D:\\data\\quantization\\jqxx2\\'
129 128
     model_dir = 'D:\\data\\quantization\\jqxx2_svm_model\\'
130
-    m = '399306.SZ.log'
131
-    demo(root_dir + m, model_dir + str(m)[:6] + '.pkl')
129
+    m = '399308.SZ.log'
130
+    demo_1(root_dir + m, model_dir + str(m)[:6] + '.pkl')
132 131