Browse Source

大盘指数

yufeng 4 years ago
parent
commit
b964238c19
1 changed files with 21 additions and 22 deletions
  1. 21 22
      stock/dnn_predict_by_day.py

+ 21 - 22
stock/dnn_predict_by_day.py

@@ -18,13 +18,13 @@ def read_data(path):
18 18
     return day_lines
19 19
 
20 20
 
21
-def predict(file_path='', model_path='15min_dnn_seq.h5'):
21
+def predict(file_path='', model_path='15min_dnn_seq'):
22 22
     day_lines = read_data(file_path)
23 23
     print('数据读取完毕')
24 24
 
25 25
     models = []
26 26
     for x in range(0, 12):
27
-        models.append(load_model('10_18d_dnn_seq_' + str(x) + '.h5'))
27
+        models.append(load_model(model_path + '_' + str(x) + '.h5'))
28 28
     estimator = joblib.load('km_dmi_18.pkl')
29 29
     print('模型加载完毕')
30 30
 
@@ -49,27 +49,25 @@ def predict(file_path='', model_path='15min_dnn_seq.h5'):
49 49
             train_x = np.array([line[:-1]])
50 50
             result = models[r[0]].predict(train_x)
51 51
 
52
-            # if r[0] in [2,4,7,11]:
53
-            #     if result[0][0] > 0.5 or result[0][1] > 0.5:
54
-            #         up_num = up_num + 1
55
-            #     elif result[0][2] > 0.5:
56
-            #         up_num = up_num + 0.12
57
-            # elif r[0] in [1,6]:
58
-            #     if result[0][3] > 0.5 or result[0][4] > 0.5:
59
-            #         down_num = down_num + 1
60
-            #     elif result[0][2] > 0.5:
61
-            #         down_num = down_num + 0.12
62
-            # elif r[0] in [10]: #
63
-            #     if result[0][0] > 0.5 or result[0][1] > 0.5:
52
+            if result[0][3] > 0.5 or result[0][4] > 0.5:
53
+                down_num = down_num + 1
54
+            elif result[0][1] > 0.5 or result[0][2] > 0.5:
55
+                up_num = up_num + 0.5 # 悲观调大 乐观调小
56
+
57
+            # if result[0][0] > 0.5 or result[0][1] > 0.5:
58
+            #     if r[0] in [0,2,3,4,5,9,10,11]:
64 59
             #         up_num = up_num + 1
65
-            #     elif result[0][3] > 0.5 or result[0][4] > 0.5:
60
+            #     elif r[0] in [8]:
61
+            #         up_num = up_num + 0.6
62
+            #     else:
63
+            #         up_num = up_num + 0.4
64
+            # if result[0][3] > 0.5 or result[0][4] > 0.5:
65
+            #     if r[0] in [4,6,]:
66 66
             #         down_num = down_num + 1
67
-            # else:
68
-                # pass
69
-            if result[0][0] > 0.5 or result[0][1] > 0.5:
70
-                up_num = up_num + 0.5
71
-            elif result[0][3] > 0.5 or result[0][4] > 0.5:
72
-                down_num = down_num + 0.5
67
+            #     elif r[0] in [0,1,3,7,8,]:
68
+            #         down_num = down_num + 0.6
69
+            #     else:
70
+            #         down_num = down_num + 0.4
73 71
 
74 72
         print(key, int(up_num), int(down_num), (down_num*1.2 + 2)/(up_num*1.2 + 2))
75 73
 
@@ -78,5 +76,6 @@ if __name__ == '__main__':
78 76
     # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
79 77
     # predict(file_path='D:\\data\\quantization\\stock9_18_20200220.log', model_path='18d_dnn_seq.h5')
80 78
     # predict(file_path='D:\\data\\quantization\\stock9_18_2.log', model_path='18d_dnn_seq.h5')
81
-    predict(file_path='D:\\data\\quantization\\stock10_18d_20190103_20190604.log', model_path='18d_dnn_seq.h5')
79
+    predict(file_path='D:\\data\\quantization\\stock11_18d_20200221.log', model_path='11_18d_dnn_seq')
80
+    # predict(file_path='D:\\data\\quantization\\stock11_18d_20190103_20190604.log', model_path='11_18d_dnn_seq')
82 81
     # predict(file_path='D:\\data\\quantization\\stock9_18_4.log', model_path='18d_dnn_seq.h5')