train_xsquare.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. import sys
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. reload(sys)
  7. sys.setdefaultencoding('utf-8')
  8. import numpy as np
  9. from sklearn.linear_model import LinearRegression
  10. from sklearn import metrics
  11. '''
  12. lesson4
  13. 把特征性扩展平方
  14. '''
  15. def extend_feature(x):
  16. return [x[0], x[0] * x[0]]
  17. def read_data(path):
  18. with open(path) as f:
  19. lines = f.readlines()
  20. lines = [eval(line.strip()) for line in lines]
  21. X, y = zip(*lines)
  22. X = np.array(X)
  23. y = np.array(y)
  24. return X, y
  25. def read_data2(path):
  26. with open(path) as f:
  27. lines = f.readlines()
  28. lines = [eval(line.strip()) for line in lines]
  29. X, y = zip(*lines)
  30. X = [extend_feature(x) for x in X]
  31. X = np.array(X)
  32. y = np.array(y)
  33. return X, y
  34. def drawScatterAndLine(p, q, w, b):
  35. plt.scatter(p, q)
  36. plt.xlabel('p')
  37. plt.ylabel('q')
  38. plt.title('line regesion')
  39. x = np.arange(-11, 11)
  40. y = w * x + b
  41. plt.plot(x, y, color='red')
  42. plt.show()
  43. def drawScatterAndLine2(p, q, w, b):
  44. plt.scatter(p, q)
  45. plt.xlabel('p')
  46. plt.ylabel('q')
  47. plt.title('line regesion')
  48. x = np.arange(-11, 11)
  49. y = w[0] * x + w[1]*x*x + b
  50. plt.plot(x, y, color='red')
  51. plt.show()
  52. def test1():
  53. X_train, y_train = read_data("train_paracurve_data")
  54. X_test, y_test = read_data("test_paracurve_data")
  55. model = LinearRegression()
  56. model.fit(X_train, y_train)
  57. print model.coef_
  58. print model.intercept_
  59. y_pred_train = model.predict(X_train)
  60. train_mse = metrics.mean_squared_error(y_train, y_pred_train)
  61. print "特征+平方非线性"
  62. print "MSE:", train_mse
  63. y_pred_test = model.predict(X_test)
  64. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  65. print "MSE:", test_mse
  66. print "推广mse差", test_mse - train_mse
  67. return X_train, y_train, model.coef_, model.intercept_
  68. def test2():
  69. print("---------特征性修改平方------------")
  70. X_train, y_train = read_data2("train_paracurve_data")
  71. X_test, y_test = read_data2("test_paracurve_data")
  72. model = LinearRegression()
  73. model.fit(X_train, y_train)
  74. print model.coef_
  75. print model.intercept_
  76. y_pred_train = model.predict(X_train)
  77. train_mse = metrics.mean_squared_error(y_train, y_pred_train)
  78. print "特征+平方非线性"
  79. print "MSE:", train_mse
  80. y_pred_test = model.predict(X_test)
  81. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  82. print "MSE:", test_mse
  83. print "推广mse差", test_mse - train_mse
  84. return X_train, y_train, model.coef_, model.intercept_
  85. if __name__ == '__main__':
  86. p,q,w,b = test1()
  87. p = [i[0] for i in p.tolist()]
  88. q = [i[0] for i in q.tolist()]
  89. w = w[0]
  90. b = b[0]
  91. drawScatterAndLine(p, q, w, b)
  92. p,q,w,b = test2()
  93. p = [i[0] for i in p.tolist()]
  94. q = [i[0] for i in q.tolist()]
  95. w = w[0]
  96. b = b[0]
  97. drawScatterAndLine2(p, q, w, b)