yufeng0528 4 years ago
parent
commit
930273314f
1 changed file with 38 additions and 38 deletions
  1. 38 38
      linear/train.py

+ 38 - 38
linear/train.py

@@ -14,17 +14,17 @@ from draw import draw_util
14 14
 
15 15
 
16 16
 def curce_data(x,y,y_pred):
17
-	x=x.tolist()
18
-	y=y.tolist()
19
-	y_pred=y_pred.tolist()
20
-	results=zip(x,y,y_pred)
21
-	results=["{},{},{}".format(s[0][0],s[1][0],s[2][0]) for s in results ]
22
-	return results
17
+    x=x.tolist()
18
+    y=y.tolist()
19
+    y_pred=y_pred.tolist()
20
+    results=zip(x,y,y_pred)
21
+    results=["{},{},{}".format(s[0][0],s[1][0],s[2][0]) for s in results ]
22
+    return results
23 23
 
24 24
 
25 25
 def read_data(path):
26
-	with open(path) as f :
27
-		lines=f.readlines()
26
+    with open(path) as f :
27
+        lines=f.readlines()
28 28
 	lines=[eval(line.strip()) for line in lines]
29 29
 	X,y=zip(*lines)
30 30
 	X=np.array(X)
@@ -33,44 +33,44 @@ def read_data(path):
33 33
 
34 34
 
35 35
 def demo():
36
-	X_train,y_train=read_data("../bbztx/train_data")
37
-	X_test,y_test=read_data("../bbztx/test_data")
36
+    X_train,y_train=read_data("../bbztx/train_data")
37
+    X_test,y_test=read_data("../bbztx/test_data")
38 38
 
39 39
 	#一个对象,它代表的线性回归模型,它的成员变量,就已经有了w,b. 刚生成w和b的时候 是随机的
40
-	model = LinearRegression()
41
-	#一调用这个函数,就会不停地找合适的w和b 直到误差最小
42
-	model.fit(X_train, y_train)
43
-	#打印W
44
-	print(model.coef_)
45
-	#打印b
46
-	print(model.intercept_)
47
-	#模型已经训练完毕,用模型看下在训练集的表现
48
-	y_pred_train = model.predict(X_train)
49
-	#sklearn 求解训练集的mse
50
-	# y_train 在训练集上 真实的y值
51
-	# y_pred_train 通过模型预测出来的y值
52
-	#计算  (y_train-y_pred_train)^2/n
53
-	train_mse = metrics.mean_squared_error(y_train, y_pred_train)
54
-	print("训练集MSE:", train_mse)
55
-
56
-	#看下在测试集上的效果
57
-	y_pred_test = model.predict(X_test)
58
-	test_mse = metrics.mean_squared_error(y_test, y_pred_test)
59
-	print("测试集MSE:",test_mse)
60
-
61
-	# train_curve = curce_data(X_train,y_train,y_pred_train)
62
-	test_curve = curce_data(X_test,y_test,y_pred_test)
63
-	print("推广mse差", test_mse-train_mse)
64
-
65
-	'''
40
+    model = LinearRegression()
41
+    #一调用这个函数,就会不停地找合适的w和b 直到误差最小
42
+    model.fit(X_train, y_train)
43
+    #打印W
44
+    print(model.coef_)
45
+    #打印b
46
+    print(model.intercept_)
47
+    #模型已经训练完毕,用模型看下在训练集的表现
48
+    y_pred_train = model.predict(X_train)
49
+    #sklearn 求解训练集的mse
50
+    # y_train 在训练集上 真实的y值
51
+    # y_pred_train 通过模型预测出来的y值
52
+    #计算  (y_train-y_pred_train)^2/n
53
+    train_mse = metrics.mean_squared_error(y_train, y_pred_train)
54
+    print("训练集MSE:", train_mse)
55
+
56
+    #看下在测试集上的效果
57
+    y_pred_test = model.predict(X_test)
58
+    test_mse = metrics.mean_squared_error(y_test, y_pred_test)
59
+    print("测试集MSE:",test_mse)
60
+
61
+    # train_curve = curce_data(X_train,y_train,y_pred_train)
62
+    test_curve = curce_data(X_test,y_test,y_pred_test)
63
+    print("推广mse差", test_mse-train_mse)
64
+
65
+    '''
66 66
 	with open("train_curve.csv","w") as f :
67 67
 		f.writelines("\n".join(train_curve))
68 68
 	保存数据
69 69
 	with open("test_curve.csv","w") as f :
70 70
 		f.writelines("\n".join(test_curve))
71 71
 	'''
72
-	for x in test_curve:
73
-		print(x)
72
+    for x in test_curve:
73
+        print(x)
74 74
 
75 75
 	return X_train,y_train, model.coef_, model.intercept_
76 76