Browse Source

指数预测

yufeng 4 years ago
parent
commit
7a50338500
1 changed files with 77 additions and 0 deletions
  1. 77 0
      stock/dnn_predict_by_day.py

+ 77 - 0
stock/dnn_predict_by_day.py

@@ -0,0 +1,77 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import joblib
5
+
6
+
7
+def read_data(path):
8
+    day_lines = {}
9
+    with open(path) as f:
10
+        for line in f.readlines()[:]:
11
+            line = eval(line.strip())
12
+            date = str(line[-1][-1])
13
+            if date in day_lines:
14
+                day_lines[date].append(line)
15
+            else:
16
+                day_lines[date] = [line]
17
+    # print(len(day_lines['20191230']))
18
+    return day_lines
19
+
20
+
21
+def predict(file_path='', model_path='15min_dnn_seq.h5'):
22
+    day_lines = read_data(file_path)
23
+    print('数据读取完毕')
24
+
25
+    models = []
26
+    for x in range(0, 12):
27
+        models.append(load_model('18d_dnn_seq_' + str(x) + '.h5'))
28
+    estimator = joblib.load('km_dmi_18.pkl')
29
+    print('模型加载完毕')
30
+
31
+    items = sorted(day_lines.keys())
32
+    for key in items:
33
+        # print(day)
34
+        lines = day_lines[key]
35
+
36
+        up_num = 0
37
+        down_num = 0
38
+        x = 21 # 每条数据项数
39
+        k = 18 # 周期
40
+        for line in lines:
41
+            v = line[1:x*k + 1]
42
+            v = np.array(v)
43
+            v = v.reshape(k, x)
44
+            v = v[:,4:8]
45
+            v = v.reshape(1, 4*k)
46
+            # print(v)
47
+            r = estimator.predict(v)
48
+
49
+            if r[0] in [2,4,7,10]:
50
+                train_x = np.array([line[:-1]])
51
+
52
+                result = models[r[0]].predict(train_x)
53
+                if result[0][0] > 0.5 or result[0][1] > 0.5:
54
+                    up_num = up_num + 1
55
+
56
+            elif r[0] in [5,6,11]:
57
+                train_x = np.array([line[:-1]])
58
+
59
+                result = models[r[0]].predict(train_x)
60
+                if result[0][3] > 0.5 or result[0][4] > 0.5:
61
+                    down_num = down_num + 1
62
+
63
+            else:
64
+                train_x = np.array([line[:-1]])
65
+
66
+                result = models[r[0]].predict(train_x)
67
+                if result[0][0] > 0.5 or result[0][1] > 0.5:
68
+                    up_num = up_num + 1
69
+                elif result[0][3] > 0.5 or result[0][4] > 0.5:
70
+                    down_num = down_num + 1
71
+
72
+        print(key, up_num, down_num, (down_num*1.5 + 1)/(up_num*1.5 + 1))
73
+
74
+
75
+if __name__ == '__main__':
76
+    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
77
+    predict(file_path='D:\\data\\quantization\\stock9_18_20200219.log', model_path='18d_dnn_seq.h5')