train_5d.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
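'''
Simplest MSE baseline: fit a linear regression on each data file and
print the mean squared error on the train and test splits.
'''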
  6. import sys
  7. import os
  8. sys.path.append(os.path.abspath('..'))
  9. from util.config import config
  10. import numpy as np
  11. from sklearn.linear_model import LinearRegression
  12. from sklearn import metrics
  13. from sklearn.model_selection import train_test_split
  14. from draw import draw_util
  15. import joblib
  16. def curce_data(x,y,y_pred):
  17. x=x.tolist()
  18. y=y.tolist()
  19. y_pred=y_pred.tolist()
  20. results=zip(x,y,y_pred)
  21. results=["{},{},{}".format(s[0],s[1][0],s[2][0]) for s in results ]
  22. return results
  23. def read_data(path):
  24. with open(path) as f :
  25. lines=f.readlines()
  26. lines=[eval(line.strip()) for line in lines]
  27. X,z,y=zip(*lines)
  28. X=np.array(X)
  29. y=np.array(y)
  30. return X,y
  31. def demo(file, model_file):
  32. X_train,y_train=read_data(file)
  33. # X_test,y_test=read_data(config.get('application', 'test_data_path'))
  34. Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
  35. # 一个对象,它代表的线性回归模型,它的成员变量,就已经有了w,b. 刚生成w和b的时候 是随机的
  36. model = LinearRegression()
  37. # 一调用这个函数,就会不停地找合适的w和b 直到误差最小
  38. model.fit(Xtrain, Ytrain)
  39. # 打印W
  40. # print(model.coef_)
  41. # 打印b
  42. print(model.intercept_)
  43. # 模型已经训练完毕,用模型看下在训练集的表现
  44. y_pred_train = model.predict(Xtrain)
  45. # sklearn 求解训练集的mse
  46. # y_train 在训练集上 真实的y值
  47. # y_pred_train 通过模型预测出来的y值
  48. # 计算 (y_train-y_pred_train)^2/n
  49. train_mse = metrics.mean_squared_error(Ytrain, y_pred_train)
  50. print("训练集MSE:", train_mse)
  51. # 看下在测试集上的效果
  52. y_pred_test = model.predict(Xtest)
  53. # print(y_pred_test)
  54. test_mse = metrics.mean_squared_error(Ytest, y_pred_test)
  55. print("测试集MSE:",test_mse)
  56. # 保存模型
  57. joblib.dump(model, model_file)
  58. def draw_line():
  59. x_train, y_train = read_data("../bbztx/train_data")
  60. print(x_train.tolist())
  61. print(y_train.tolist())
  62. draw_util.drawScatter(x_train.tolist(), y_train.tolist())
  63. if __name__ == '__main__':
  64. root_dir = 'D:\\data\\quantization\\5d\\'
  65. model_dir = 'D:\\data\\quantization\\5d_lr_model\\'
  66. list = os.listdir(root_dir)
  67. for f in list:
  68. print(f)
  69. m = f.split('.')[0][-6:]
  70. demo(root_dir + str(f), model_dir + '' + str(m) + '.pkl')