|
@@ -38,28 +38,27 @@ def read_data(path):
|
38
|
38
|
|
39
|
39
|
def demo(file, model_file):
    """Fit a logistic-regression classifier on ``file`` and persist it.

    Every sample returned by ``read_data`` is used for fitting; this
    function performs no train/test split and no accuracy evaluation.

    NOTE(review): ``LogisticRegression`` and ``joblib`` are imported outside
    the visible part of the file -- presumably scikit-learn's estimator and
    the ``joblib`` package; confirm against the import block.

    Args:
        file: Path handed to ``read_data``, which returns a
            ``(features, labels)`` pair.
        model_file: Destination path handed to ``joblib.dump``.

    Side effects:
        Prints the fitted intercept and writes the model to ``model_file``.
    """
    features, labels = read_data(file)

    # Build the estimator with its default settings.
    model = LogisticRegression()
    # Estimate the model parameters from the full dataset.
    model.fit(features, labels)

    # Show the learned intercept (the bias term b).
    print(model.intercept_)

    # Save the fitted model.
    joblib.dump(model, model_file)
|
65
|
64
|
|
|
@@ -127,6 +126,6 @@ def demo_3(file, model_file):
|
127
|
126
|
if __name__ == '__main__':
    # Directory holding the input log files and directory receiving the
    # pickled models (Windows paths, trailing separator included).
    root_dir = 'D:\\data\\quantization\\jqxx2\\'
    model_dir = 'D:\\data\\quantization\\jqxx2_svm_model\\'

    # Input log to train on; its first six characters ('399308') become
    # the stem of the saved model file.
    m = '399308.SZ.log'
    demo_1(root_dir + m, model_dir + m[:6] + '.pkl')
|
132
|
131
|
|