# analy_mse_entropy.py
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. import sys
  4. reload(sys)
  5. sys.setdefaultencoding('utf-8')
  6. import random
  7. import math
'''
Sweep logistic-regression weights and compare the MSE loss curve against
the cross-entropy (KL-distance) loss curve.
'''
  11. def logistic(x,w):
  12. d=sum([ s1*s2 for [s1,s2] in zip(x,w)])
  13. r=1.0/(1+math.exp(-1*d))
  14. return r
  15. def cal_entropy(s1,s2):
  16. s2=min(0.999,max(s2,0.001))
  17. if s1==0:
  18. return math.log(1.0/(1-s2))
  19. else:
  20. return math.log(1.0/s2)
  21. def cal_error(data,w):
  22. y2=[ logistic(x,w) for [x,_] in data]
  23. y=[s[1][0] for s in data]
  24. mse=sum([ (s1-s2)*(s1-s2) for [s1,s2] in zip(y,y2)])/len(data)
  25. entropy=sum([ cal_entropy(s1,s2) for [s1,s2] in zip(y,y2)])/len(data)
  26. return mse,entropy
  27. def read_data(path):
  28. with open(path) as f :
  29. lines=f.readlines()
  30. lines=[eval(line.strip()) for line in lines]
  31. return lines
  32. data=read_data("train_data")
  33. results=[]
  34. num=100
  35. step=0.2
  36. for i in range(-num,num):
  37. w1=1.87+step*i
  38. #for j in range(-num,num):
  39. w2=-1.87#+step*j
  40. e1,e2=cal_error(data,[w1,w2])
  41. results.append("{},{},{}".format(w1,e1,e2))
  42. print "{},{},{}".format(w1,e1,e2)
  43. # with open("mse_entropy.csv","w") as f :
  44. # f.writelines("\n".join(results))