Browse Source

手写KL距离

yufeng0528 4 years ago
parent
commit
50e33848af
1 changed files with 60 additions and 0 deletions
  1. 60 0
      logistic/analy_mse_entropy.py

+ 60 - 0
logistic/analy_mse_entropy.py

@@ -0,0 +1,60 @@
1
+#!/usr/bin/python
2
+# -*- coding: UTF-8 -*-
3
+import sys
4
+reload(sys) 
5
+sys.setdefaultencoding('utf-8') 
6
+import random
7
+import math
8
+
9
+'''
10
+逻辑回归的mse即KL距离
11
+'''
12
+def logistic(x,w):
13
+	d=sum([ s1*s2 for [s1,s2] in zip(x,w)])
14
+	r=1.0/(1+math.exp(-1*d))
15
+	return r
16
+
17
+def cal_entropy(s1,s2):
18
+	s2=min(0.999,max(s2,0.001))
19
+	if s1==0:
20
+		return math.log(1.0/(1-s2))
21
+	else:
22
+		return math.log(1.0/s2)
23
+
24
+def cal_error(data,w):
25
+	y2=[ logistic(x,w) for [x,_] in data]
26
+	y=[s[1][0] for s in data]
27
+	mse=sum([   (s1-s2)*(s1-s2) for [s1,s2] in zip(y,y2)])/len(data)
28
+	entropy=sum([   cal_entropy(s1,s2) for [s1,s2] in zip(y,y2)])/len(data)
29
+	return  mse,entropy
30
+
31
+
32
+def read_data(path):
33
+	with open(path) as f :
34
+		lines=f.readlines()
35
+	lines=[eval(line.strip()) for line in lines]
36
+	return lines
37
+
38
+data=read_data("train_data")
39
+results=[]
40
+num=100
41
+step=0.2
42
+for i in range(-num,num):
43
+	w1=1.87+step*i
44
+	#for j in range(-num,num):
45
+	w2=-1.87#+step*j
46
+	e1,e2=cal_error(data,[w1,w2])
47
+	results.append("{},{},{}".format(w1,e1,e2))
48
+	print "{},{},{}".format(w1,e1,e2)
49
+
50
+# with open("mse_entropy.csv","w") as f  :
51
+# 	f.writelines("\n".join(results))
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+