train.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. '''
  4. 最简单的mse
  5. '''
  6. import sys
  7. reload(sys)
  8. sys.setdefaultencoding('utf-8')
  9. import numpy as np
  10. from sklearn.linear_model import LinearRegression
  11. from sklearn import metrics
  12. from draw import draw_util
  def curce_data(x,y,y_pred):
      """Format inputs, targets and predictions as "x,y,y_pred" text rows.

      Each argument must expose ``tolist()`` and yield rows that can be
      indexed; only the first element of every row is written out
      (presumably the arrays are shaped (n, 1) -- confirm against callers).

      Args:
          x: Input samples.
          y: Ground-truth values aligned with ``x``.
          y_pred: Model predictions aligned with ``x``.

      Returns:
          A list of strings, one per sample, each of the form
          "<x>,<y>,<y_pred>". The list is as long as the shortest input.
      """
      rows = []
      triples = zip(x.tolist(), y.tolist(), y_pred.tolist())
      for x_row, y_row, pred_row in triples:
          rows.append("{},{},{}".format(x_row[0], y_row[0], pred_row[0]))
      return rows
  def read_data(path):
      """Load (X, y) samples from a text file with one sample per line.

      Every non-blank line is evaluated as a Python expression that must
      yield a two-item sequence ``(x_row, y_row)``. Blank lines (for example
      a trailing empty line at the end of the file) are skipped.

      Args:
          path: Path of the data file to read.

      Returns:
          Tuple ``(X, y)`` of numpy arrays: ``X`` stacks the first item of
          every sample and ``y`` the second, in file order.

      Raises:
          ValueError: If the file holds no samples, or if a line does not
              evaluate to exactly two items.
      """
      samples = []
      with open(path) as f:
          for raw_line in f:
              text = raw_line.strip()
              if not text:
                  # Skip empty lines instead of crashing inside eval("").
                  continue
              # NOTE(review): eval() executes arbitrary code found in the data
              # file. Only use this on trusted files; if the lines are always
              # plain literals, ast.literal_eval would be the safer choice.
              samples.append(eval(text))

      if not samples:
          # zip(*[]) would otherwise fail with an unhelpful unpacking error.
          raise ValueError("no samples found in {!r}".format(path))

      X, y = zip(*samples)
      return np.array(X), np.array(y)
  28. def test():
  29. X_train,y_train=read_data("train_data")
  30. X_test,y_test=read_data("test_data")
  31. #一个对象,它代表的线性回归模型,它的成员变量,就已经有了w,b. 刚生成w和b的时候 是随机的
  32. model = LinearRegression()
  33. #一调用这个函数,就会不停地找合适的w和b 直到误差最小
  34. model.fit(X_train, y_train)
  35. #打印W
  36. print model.coef_
  37. #打印b
  38. print model.intercept_
  39. #模型已经训练完毕,用模型看下在训练集的表现
  40. y_pred_train = model.predict(X_train)
  41. #sklearn 求解训练集的mse
  42. # y_train 在训练集上 真实的y值
  43. # y_pred_train 通过模型预测出来的y值
  44. #计算 (y_train-y_pred_train)^2/n
  45. train_mse = metrics.mean_squared_error(y_train, y_pred_train)
  46. print "训练集MSE:".decode('utf-8'), train_mse
  47. #看下在测试集上的效果
  48. y_pred_test = model.predict(X_test)
  49. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  50. print "测试集MSE:".decode('utf-8'),test_mse
  51. # train_curve = curce_data(X_train,y_train,y_pred_train)
  52. # test_curve = curce_data(X_test,y_test,y_pred_test)
  53. print "推广mse差".decode('utf-8'), test_mse-train_mse
  54. '''
  55. with open("train_curve.csv","w") as f :
  56. f.writelines("\n".join(train_curve))
  57. with open("test_curve.csv","w") as f :
  58. f.writelines("\n".join(test_curve))
  59. '''
  60. return X_train,y_train, model.coef_, model.intercept_
  def draw_line():
      """Print the training samples and pass them to draw_util.drawScatter.

      Reads "train_data" through read_data, prints the inputs and then the
      targets as plain lists, and hands both lists to the drawing helper
      (presumably a scatter plot -- see draw_util).
      """
      features, targets = read_data("train_data")
      xs = features.tolist()
      print(xs)
      ys = targets.tolist()
      print(ys)
      draw_util.drawScatter(xs, ys)
  if __name__ == '__main__':
      # Script entry point: train the model, then draw the training samples
      # together with the fitted line. (draw_line() is an alternative view
      # that is not called here.)
      features, targets, coef, intercept = test()

      # Keep only the first element of every row so the drawing helper
      # receives flat lists.
      p = [row[0] for row in features.tolist()]
      q = [row[0] for row in targets.tolist()]

      # First entry of the weights and of the bias returned by test().
      w, b = coef[0], intercept[0]
      draw_util.drawScatterAndLine(p, q, w, b)