Browse Source

线性不可分

yufeng0528 4 years ago
parent
commit
a003ef848a
1 changed files with 65 additions and 0 deletions
  1. 65 0
      logistic/train_noseparability.py

+ 65 - 0
logistic/train_noseparability.py

@@ -0,0 +1,65 @@
1
+# -*- encoding:utf-8 -*-
2
+from sklearn import datasets
3
+from sklearn.model_selection import train_test_split 
4
+from sklearn.linear_model import LogisticRegression
5
+from sklearn.model_selection import cross_val_predict
6
+from numpy import shape
7
+from sklearn import metrics
8
+import numpy as np
9
+import random
10
+'''
11
+线性不可分
12
+'''
13
+def curve(x_train,w,w0):
14
+	results=x_train.tolist()
15
+	results=[x[0:2] for x in results]
16
+	step=0.0001
17
+	for i in np.arange(-0.2,1.2,step):
18
+		x1=i+step
19
+		x2=-1*(w[0]*x1+w0)/(w[1]+w[2]*x1) # 计算mse
20
+		if abs(x2)>5.0:
21
+			continue
22
+		results.append([x1,x2])
23
+	results=["{},{}".format(x1,x2) for [x1,x2] in results]
24
+	return results
25
+
26
+
27
+def get_data(center_label,num=100):
28
+	X_train=[]
29
+	y_train=[]
30
+	sigma=0.01
31
+	for point,label in center_label:
32
+		c1,c2=point
33
+		for _ in range(0,num):
34
+			x1=c1+random.uniform(-sigma,sigma)
35
+			x2=c2+random.uniform(-sigma,sigma)
36
+			X_train.append([x1,x2])
37
+			y_train.append([label])
38
+	return X_train,y_train
39
+
40
+
41
+center_label=[[[0,0],1],[[1,1],1],[[0,1],0],[[1,0],0]]
42
+X_train,y_train=get_data(center_label)
43
+#X_train=10*[[0,0],[1,1],[1,0],[0,1]]
44
+X_train=[ x+[x[0]*x[1]] for x in X_train]
45
+X_train=np.array(X_train)
46
+ 
47
+#model = LogisticRegression(penalty="l2")
48
+model = LogisticRegression()
49
+model.fit(X_train, y_train)
50
+ 
51
+print (model.coef_)
52
+print (model.intercept_)
53
+curve_results=curve(X_train,model.coef_.tolist()[0],model.intercept_.tolist()[0])
54
+
55
+'''
56
+with open("no_separa_traindata.csv","w") as f :
57
+	f.writelines("\n".join(curve_results[0:400]))
58
+with open("no_separa_train_with_splitline.csv","w") as f :
59
+	f.writelines("\n".join(curve_results))
60
+	'''
61
+
62
+
63
+
64
+ 
65
+