train.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. '''
  4. 最简单的mse
  5. '''
  6. import sys
  7. import os
  8. sys.path.append(os.path.abspath('..'))
  9. from util.config import config
  10. import numpy as np
  11. from sklearn.linear_model import LinearRegression
  12. from sklearn import metrics
  13. from draw import draw_util
  14. def curce_data(x,y,y_pred):
  15. x=x.tolist()
  16. y=y.tolist()
  17. y_pred=y_pred.tolist()
  18. results=zip(x,y,y_pred)
  19. results=["{},{},{}".format(s[0],s[1][0],s[2][0]) for s in results ]
  20. return results
  21. def read_data(path):
  22. with open(path) as f :
  23. lines=f.readlines()
  24. lines=[eval(line.strip()) for line in lines]
  25. X,y=zip(*lines)
  26. X=np.array(X)
  27. y=np.array(y)
  28. return X,y
  29. def demo():
  30. X_train,y_train=read_data(config.get('application', 'train_data_path'))
  31. X_test,y_test=read_data(config.get('application', 'test_data_path'))
  32. #一个对象,它代表的线性回归模型,它的成员变量,就已经有了w,b. 刚生成w和b的时候 是随机的
  33. model = LinearRegression()
  34. #一调用这个函数,就会不停地找合适的w和b 直到误差最小
  35. model.fit(X_train, y_train)
  36. #打印W
  37. print(model.coef_)
  38. #打印b
  39. print(model.intercept_)
  40. #模型已经训练完毕,用模型看下在训练集的表现
  41. y_pred_train = model.predict(X_train)
  42. #sklearn 求解训练集的mse
  43. # y_train 在训练集上 真实的y值
  44. # y_pred_train 通过模型预测出来的y值
  45. #计算 (y_train-y_pred_train)^2/n
  46. train_mse = metrics.mean_squared_error(y_train, y_pred_train)
  47. print("训练集MSE:", train_mse)
  48. #看下在测试集上的效果
  49. y_pred_test = model.predict(X_test)
  50. print(y_pred_test)
  51. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  52. print("测试集MSE:",test_mse)
  53. # train_curve = curce_data(X_train,y_train,y_pred_train)
  54. test_curve = curce_data(X_test,y_test,y_pred_test)
  55. print("推广mse差", test_mse-train_mse)
  56. '''
  57. with open("train_curve.csv","w") as f :
  58. f.writelines("\n".join(train_curve))
  59. 保存数据
  60. with open("test_curve.csv","w") as f :
  61. f.writelines("\n".join(test_curve))
  62. '''
  63. for x in test_curve:
  64. print(x)
  65. return X_train,y_train, model.coef_, model.intercept_
  66. def draw_line():
  67. x_train, y_train = read_data("../bbztx/train_data")
  68. print(x_train.tolist())
  69. print(y_train.tolist())
  70. draw_util.drawScatter(x_train.tolist(), y_train.tolist())
  71. if __name__ == '__main__':
  72. # draw_line()
  73. p, q, w,b = demo()
  74. p = [i[0] for i in p.tolist()]
  75. q = [i[0] for i in q.tolist()]
  76. w = w[0]
  77. b = b[0]
  78. # draw_util.drawScatterAndLine(p, q, w, b)