gradient_linear.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. '''
  4. The simplest MSE (mean squared error) example
  5. '''
  6. import sys
  7. reload(sys)
  8. sys.setdefaultencoding('utf-8')
  9. import random
  10. import numpy as np
  11. from sklearn.linear_model import LinearRegression
  12. from sklearn import metrics
  13. from draw import draw_util
  def read_data(path):
      """Load samples from a text file holding one Python literal per line.

      Every non-blank line is parsed into a Python object. The other functions
      in this file index a sample as ``item[0][0]`` (feature) and
      ``item[1][0]`` (target), so a line is expected to look something like
      ``[[1.0], [2.5]]``.

      :param path: path of the data file to read.
      :return: list with one parsed object per non-blank line, in file order.
      :raises IOError: if the file cannot be opened.
      :raises ValueError: if a line is not a plain Python literal.
      :raises SyntaxError: if a line is not syntactically valid.
      """
      # Function-scope import keeps this block self-contained.
      import ast

      samples = []
      with open(path) as f:
          for raw_line in f:
              text = raw_line.strip()
              if not text:
                  # A blank line (e.g. a trailing newline) holds no sample.
                  continue
              # SECURITY: this used to be eval(), which executes arbitrary
              # code found in the file. literal_eval accepts only literals
              # (numbers, strings, tuples, lists, dicts, ...) and rejects
              # everything else with ValueError.
              samples.append(ast.literal_eval(text))
      return samples
  def cal_step_pow(data, w, b=3):
      """Return the gradient of the mean squared error with respect to ``w``.

      For the model ``w * x + b`` this is the average over all samples of
      ``2 * (w * x + b - y) * x``, with ``x = item[0][0]`` and
      ``y = item[1][0]``.

      :param data: non-empty sequence of samples indexable as ``item[0][0]``
          and ``item[1][0]``.
      :param w: current slope.
      :param b: intercept used while differentiating (default 3).
      :return: the averaged gradient as a float.
      :raises ZeroDivisionError: if ``data`` is empty.
      """
      gradient_sum = sum(
          2 * (w * item[0][0] + b - item[1][0]) * item[0][0] for item in data
      )
      # float() forces true division: under Python 2 an all-integer input
      # would otherwise be floor-divided and silently lose the fraction.
      return gradient_sum / float(len(data))
  def cal_step_pow_b(data, w, b=1):
      """Return the gradient of the mean squared error with respect to ``b``.

      For the model ``w * x + b`` this is the average over all samples of
      ``2 * (w * x + b - y)``, with ``x = item[0][0]`` and ``y = item[1][0]``.

      :param data: non-empty sequence of samples indexable as ``item[0][0]``
          and ``item[1][0]``.
      :param w: slope of the model.
      :param b: current intercept (default 1).
      :return: the averaged gradient as a float.
      :raises ZeroDivisionError: if ``data`` is empty.
      """
      gradient_sum = sum(
          2 * (w * item[0][0] + b - item[1][0]) for item in data
      )
      # float() forces true division: under Python 2 an all-integer input
      # would otherwise be floor-divided and silently lose the fraction.
      return gradient_sum / float(len(data))
  def cal_mse(data, w, b=1):
      """Return the mean squared error of the model ``w * x + b`` on ``data``.

      The error of one sample is ``w * x + b - y`` with ``x = item[0][0]`` and
      ``y = item[1][0]``; the result is the average of the squared errors.

      :param data: non-empty sequence of samples indexable as ``item[0][0]``
          and ``item[1][0]``.
      :param w: slope of the model.
      :param b: intercept of the model (default 1).
      :return: the mean squared error as a float.
      :raises ZeroDivisionError: if ``data`` is empty.
      """
      # Compute each residual once instead of twice per sample.
      residuals = [w * item[0][0] + b - item[1][0] for item in data]
      squared_sum = sum(r * r for r in residuals)
      # float() forces true division: under Python 2 an all-integer input
      # would otherwise be floor-divided and silently lose the fraction.
      return squared_sum / float(len(data))
  def train():
      """Fit the slope ``w`` by gradient descent with the intercept held fixed.

      Reads the samples from the file ``train_data``, starts from a random
      ``w`` in [-50, 50] and performs 50 updates with a learning rate of 0.01.
      Each iteration prints ``w``, the update step and the current MSE.

      :return: the slope after the final update.
      """
      # One explicit intercept for both the gradient and the reported loss.
      # 3 is cal_step_pow's default, so the descent path (and the returned w)
      # is the same as when the argument was omitted; cal_mse's default is 1,
      # which previously made the printed MSE describe a different model.
      fixed_b = 3
      learning_rate = 0.01

      train_data = read_data('train_data')
      w = random.uniform(-50, 50)
      for _ in range(50):
          step = cal_step_pow(train_data, w, fixed_b) * learning_rate
          mse = cal_mse(train_data, w, fixed_b)
          # Formatted string prints the same text on Python 2 and Python 3.
          print("%s %s %s" % (w, step, mse))
          w = w - step
      return w
  def train_b(w):
      """Fit the intercept ``b`` by gradient descent for a given slope ``w``.

      Reads the samples from the file ``train_data``, starts from a random
      intercept in [-50, 50] and performs 1000 updates with a learning rate
      of 0.01. Each iteration prints the intercept, the update step and the
      current MSE.

      :param w: slope to keep constant while the intercept is adjusted.
      :return: the intercept after the final update.
      """
      samples = read_data('train_data')
      intercept = random.uniform(-50, 50)
      remaining = 1000
      while remaining > 0:
          delta = cal_step_pow_b(samples, w, intercept) * 0.01
          loss = cal_mse(samples, w, intercept)
          # Same text as a space-separated print of the three values.
          print("%s %s %s" % (intercept, delta, loss))
          intercept -= delta
          remaining -= 1
      return intercept
  if __name__ == '__main__':
      # Stage 1: gradient-descent fit of the slope.
      w = train()
      print("__________")
      # Stage 2: gradient-descent fit of the intercept for that slope.
      b = train_b(w)
      print("__________")
      print("%s %s" % (w, b))

      # Fit scikit-learn's LinearRegression on the same data and print the
      # coefficient and intercept it finds.
      train_data = read_data('train_data')
      X, y = (np.array(column) for column in zip(*train_data))
      model = LinearRegression()
      # fit() is the call that determines the model's w and b from X and y.
      model.fit(X, y)
      print("%s %s" % (model.coef_, model.intercept_))