analy_mse_entropy.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. import sys
  4. reload(sys)
  5. sys.setdefaultencoding('utf-8')
  6. import random
  7. import math
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. def read_data(path):
  11. with open(path) as f :
  12. lines=f.readlines()
  13. lines=[eval(line.strip()) for line in lines]
  14. def cal_mse(data,w,b=3):
  15. y2=[ w*x[0]+b for [x,_] in data]
  16. y=[s[1][0] for s in data]
  17. mse=sum([(s1-s2)*(s1-s2) for [s1,s2] in zip(y,y2)])/len(data)
  18. return mse
  19. with open("train_data") as f :
  20. lines=f.readlines()
  21. data=[eval(line.strip()) for line in lines]
  22. sub_data1=random.sample(data,10)
  23. sub_data2=random.sample(data,10)
  24. sub_data3=random.sample(data,50)
  25. sub_data4=random.sample(data,100)
  26. # results=["w,all_sample,one_sample1,one_sample2,50_sample,100_sampe"]
  27. results = []
  28. x= []
  29. y= []
  30. def drawLine(x, y):
  31. plt.xlabel("w")
  32. plt.ylabel("mse")
  33. plt.plot(x, [i[0] for i in y], color="black")
  34. plt.plot(x, [i[1] for i in y], color="red")
  35. plt.plot(x, [i[2] for i in y], color="green")
  36. plt.plot(x, [i[3] for i in y], color="blue")
  37. plt.plot(x, [i[4] for i in y], color="yellow")
  38. plt.show()
  39. '''
  40. 计算w
  41. '''
  42. for i in range(-2000,2000):
  43. w=5 + 1.0*i/500
  44. mse=cal_mse(data,w)
  45. sub_mse1=cal_mse(sub_data1,w)
  46. sub_mse2=cal_mse(sub_data2,w)
  47. sub_mse3=cal_mse(sub_data3,w)
  48. sub_mse4=cal_mse(sub_data4,w)
  49. results.append("{},{},{},{},{},{}".format(w,mse,sub_mse1,sub_mse2,sub_mse3,sub_mse4))
  50. # print "{},{},{},{},{},{}".format(w, mse, sub_mse1,sub_mse2,sub_mse3,sub_mse4)
  51. x.append(w)
  52. y.append((mse, sub_mse1,sub_mse2,sub_mse3,sub_mse4))
  53. drawLine(x, y)
  54. x = []
  55. y = []
  56. '''
  57. 计算b
  58. '''
  59. for i in range(-2000,2000):
  60. b=1 + 1.0*i/500
  61. mse=cal_mse(data,5, b)
  62. sub_mse1=cal_mse(sub_data1,5, b)
  63. sub_mse2=cal_mse(sub_data2,5, b)
  64. sub_mse3=cal_mse(sub_data3,5, b)
  65. sub_mse4=cal_mse(sub_data4,5, b)
  66. results.append("{},{},{},{},{},{}".format(w,mse,sub_mse1,sub_mse2,sub_mse3,sub_mse4))
  67. # print "{},{},{},{},{},{}".format(w, mse, sub_mse1,sub_mse2,sub_mse3,sub_mse4)
  68. x.append(b)
  69. y.append((mse, sub_mse1,sub_mse2,sub_mse3,sub_mse4))
  70. drawLine(x, y)
  71. '''
  72. with open("mse_curve.csv","w") as f :
  73. f.writelines("\n".join(results))
  74. '''