train_feature.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. import sys
  4. reload(sys)
  5. sys.setdefaultencoding('utf-8')
  6. import numpy as np
  7. from sklearn.linear_model import LinearRegression
  8. from sklearn import metrics
  9. '''
  10. 随机特征
  11. '''
  12. def extend_feature_repeat(x):
  13. return [x[0],x[0]]
  14. '''
  15. 重复特征
  16. '''
  17. def extend_feature_random(x):
  18. import random
  19. return [x[0],random.uniform(-10,10)]
  20. def read_data(path, fun):
  21. with open(path) as f :
  22. lines=f.readlines()
  23. lines=[eval(line.strip()) for line in lines]
  24. X,y=zip(*lines)
  25. X=[fun(x) for x in X]
  26. X=np.array(X)
  27. y=np.array(y)
  28. return X,y
  29. def repeat():
  30. X_train,y_train=read_data("train_data", extend_feature_repeat)
  31. X_test,y_test=read_data("test_data", extend_feature_repeat)
  32. model = LinearRegression()
  33. model.fit(X_train, y_train)
  34. print model.coef_
  35. print model.intercept_
  36. y_pred_train = model.predict(X_train)
  37. train_mse=metrics.mean_squared_error(y_train, y_pred_train)
  38. print "重复特征"
  39. print "MSE:", train_mse
  40. y_pred_test = model.predict(X_test)
  41. test_mse=metrics.mean_squared_error(y_test, y_pred_test)
  42. print "MSE:",test_mse
  43. print "推广mse差", test_mse-train_mse
  44. def random():
  45. X_train, y_train = read_data("train_data", extend_feature_random)
  46. X_test, y_test = read_data("test_data", extend_feature_random)
  47. model = LinearRegression()
  48. model.fit(X_train, y_train)
  49. print model.coef_
  50. print model.intercept_
  51. y_pred_train = model.predict(X_train)
  52. train_mse = metrics.mean_squared_error(y_train, y_pred_train)
  53. print "+随机特征"
  54. print "MSE:", train_mse
  55. y_pred_test = model.predict(X_test)
  56. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  57. print "MSE:", test_mse
  58. print "推广mse差", test_mse - train_mse
  59. if __name__ == '__main__':
  60. repeat()
  61. print('-------------------')
  62. random()