Browse Source

数据更新

yufeng0528 4 years ago
parent
commit
f95c729bb5
2 changed files with 16 additions and 9 deletions
  1. 9 0
      bbztx/get_data.py
  2. 7 9
      linear/train.py

+ 9 - 0
bbztx/get_data.py

@@ -115,6 +115,7 @@ def to_a1(attr_list):
115 115
         new_attr.append(sell_count)
116 116
 
117 117
         print(new_attr)
118
+        new_attr_list.append(new_attr)
118 119
 
119 120
     return new_attr_list
120 121
 
@@ -145,6 +146,14 @@ def to_list_attr(item, a_list):
145 146
     return c_list
146 147
 
147 148
 
149
+def to_file(data_list):
150
+    with open("train_data", "w") as f:
151
+        for line in data_list:
152
+            line = [line[:-2], [line[-1]]]
153
+            f.write(str(line) + "\n")
154
+
155
+
148 156
 if __name__ == '__main__':
149 157
     attr_list = get_articles()
150 158
     new_attr_list = to_a1(attr_list)
159
+    to_file(new_attr_list)

+ 7 - 9
linear/train.py

@@ -4,8 +4,6 @@
4 4
 最简单的mse
5 5
 '''
6 6
 import sys
7
-reload(sys)
8
-sys.setdefaultencoding('utf-8')
9 7
 
10 8
 import numpy as np
11 9
 from sklearn.linear_model import LinearRegression
@@ -32,7 +30,7 @@ def read_data(path):
32 30
 	return X,y
33 31
 
34 32
 
35
-def test():
33
+def demo():
36 34
 	X_train,y_train=read_data("train_data")
37 35
 	X_test,y_test=read_data("test_data")
38 36
 
@@ -41,9 +39,9 @@ def test():
41 39
 	#一调用这个函数,就会不停地找合适的w和b 直到误差最小
42 40
 	model.fit(X_train, y_train)
43 41
 	#打印W
44
-	print model.coef_
42
+	print(model.coef_)
45 43
 	#打印b
46
-	print model.intercept_
44
+	print(model.intercept_)
47 45
 	#模型已经训练完毕,用模型看下在训练集的表现
48 46
 	y_pred_train = model.predict(X_train)
49 47
 	#sklearn 求解训练集的mse
@@ -51,16 +49,16 @@ def test():
51 49
 	# y_pred_train 通过模型预测出来的y值
52 50
 	#计算  (y_train-y_pred_train)^2/n
53 51
 	train_mse = metrics.mean_squared_error(y_train, y_pred_train)
54
-	print "训练集MSE:".decode('utf-8'), train_mse
52
+	print("训练集MSE:", train_mse)
55 53
 
56 54
 	#看下在测试集上的效果
57 55
 	y_pred_test = model.predict(X_test)
58 56
 	test_mse = metrics.mean_squared_error(y_test, y_pred_test)
59
-	print "测试集MSE:".decode('utf-8'),test_mse
57
+	print("测试集MSE:",test_mse)
60 58
 
61 59
 	# train_curve = curce_data(X_train,y_train,y_pred_train)
62 60
 	# test_curve = curce_data(X_test,y_test,y_pred_test)
63
-	print "推广mse差".decode('utf-8'), test_mse-train_mse
61
+	print("推广mse差", test_mse-train_mse)
64 62
 
65 63
 	'''
66 64
 	with open("train_curve.csv","w") as f :
@@ -80,7 +78,7 @@ def draw_line():
80 78
 
81 79
 if __name__ == '__main__':
82 80
 	# draw_line()
83
-	p, q, w,b = test()
81
+	p, q, w,b = demo()
84 82
 	p = [i[0] for i in p.tolist()]
85 83
 	q = [i[0] for i in q.tolist()]
86 84
 	w = w[0]