yufeng 4 years ago
parent
commit
44b9a04f81
2 changed files with 92 additions and 93 deletions
  1. 76 79
      stock/dnn_predict_dmi.py
  2. 16 14
      stock/dnn_train_dmi.py

+ 76 - 79
stock/dnn_predict_dmi.py

@@ -17,8 +17,28 @@ def read_data(path):
17 17
     train_y=[s[size-1] for s in lines]
18 18
     return np.array(train_x),np.array(train_y),lines
19 19
 
20
+def _score(fact, line):
21
+    with open('dnn_predict_dmi_18d.txt', 'a') as f:
22
+        f.write(str([line[-2], line[-1]]) + "\n")
20 23
 
21
-def predict(file_path='', model_path='15min_dnn_seq.h5'):
24
+    up_right = 0
25
+    up_error = 0
26
+
27
+    if fact[0] == 1:
28
+        up_right = up_right + 1.12
29
+    elif fact[1] == 1:
30
+        up_right = up_right + 1.06
31
+    elif fact[2] == 1:
32
+        up_right = up_right + 1
33
+    elif fact[3] == 1:
34
+        up_right = up_right + 0.94
35
+    else:
36
+        up_error = up_error + 1
37
+        up_right = up_right + 0.88
38
+    return up_right,up_error
39
+
40
+
41
+def predict(file_path='', model_path='15min_dnn_seq.h5', idx=-1):
22 42
     test_x,test_y,lines=read_data(file_path)
23 43
 
24 44
     model=load_model(model_path)
@@ -34,58 +54,31 @@ def predict(file_path='', model_path='15min_dnn_seq.h5'):
34 54
     i = 0
35 55
     result=model.predict(test_x)
36 56
     win_dnn = []
37
-    with open('dnn_predict_dmi_18d.txt', 'a') as f:
38
-        for r in result:
39
-            fact = test_y[i]
40
-            if r[0] > 0.5:
41
-                f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
42
-                win_dnn.append([lines[i][-2], lines[i][-1]])
43
-                if fact[0] == 1:
44
-                    up_right = up_right + 1.12
45
-                elif fact[1] == 1:
46
-                    up_right = up_right + 1.06
47
-                elif fact[2] == 1:
48
-                    up_right = up_right + 1
49
-                elif fact[3] == 1:
50
-                    up_right = up_right + 0.94
51
-                else:
52
-                    up_error = up_error + 1
53
-                    up_right = up_right + 0.88
57
+    for r in result:
58
+        fact = test_y[i]
59
+
60
+        if idx in [0]:
61
+            if r[0] > 0.5 or r[1] > 0.5:
62
+                pass
63
+                # if fact[0] == 1:
64
+                #     up_right = up_right + 1.12
65
+                # elif fact[1] == 1:
66
+                #     up_right = up_right + 1.06
67
+                # elif fact[2] == 1:
68
+                #     up_right = up_right + 1
69
+                # elif fact[3] == 1:
70
+                #     up_right = up_right + 0.94
71
+                # else:
72
+                #     up_error = up_error + 1
73
+                #     up_right = up_right + 0.88
74
+                # up_num = up_num + 1
75
+        else:
76
+            if r[0] > 0.5 or r[1] > 0.5:
77
+                tmp_right,tmp_error = _score(fact, lines[i])
78
+                up_right = tmp_right + up_right
79
+                up_error = tmp_error + up_error
54 80
                 up_num = up_num + 1
55
-            elif r[1] > 0.5:
56
-                f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
57
-                win_dnn.append([lines[i][-2], lines[i][-1]])
58
-                if fact[0] == 1:
59
-                    up_right = up_right + 1.12
60
-                elif fact[1] == 1:
61
-                    up_right = up_right + 1.06
62
-                elif fact[2] == 1:
63
-                    up_right = up_right + 1
64
-                elif fact[3] == 1:
65
-                    up_right = up_right + 0.94
66
-                else:
67
-                    up_error = up_error + 1
68
-                    up_right = up_right + 0.88
69
-                up_num = up_num + 1
70
-
71
-            if r[3] > 0.6:
72
-                f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
73
-                win_dnn.append([lines[i][-2], lines[i][-1]])
74
-                if fact[0] == 1:
75
-                    down_error = down_error + 1
76
-                    down_right = down_right + 1.12
77
-                elif fact[1] == 1:
78
-                    down_right = down_right + 1.06
79
-                elif fact[2] == 1:
80
-                    down_right = down_right + 1
81
-                elif fact[3] == 1:
82
-                    down_right = down_right + 0.94
83
-                else:
84
-                    down_right = down_right + 0.88
85
-                down_num = down_num + 1
86
-            elif r[4] > 0.6:
87
-                f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
88
-                win_dnn.append([lines[i][-2], lines[i][-1]])
81
+            elif r[3] > 0.5 or r[4] > 0.5:
89 82
                 if fact[0] == 1:
90 83
                     down_error = down_error + 1
91 84
                     down_right = down_right + 1.12
@@ -99,9 +92,11 @@ def predict(file_path='', model_path='15min_dnn_seq.h5'):
99 92
                     down_right = down_right + 0.88
100 93
                 down_num = down_num + 1
101 94
 
102
-            i = i + 1
95
+        i = i + 1
103 96
     if up_num == 0:
104 97
         up_num = 1
98
+    if down_num == 0:
99
+        down_num = 1
105 100
     print('DNN', up_right, up_num, up_right/up_num, up_error/up_num, down_right/down_num, down_error/down_num)
106 101
     return win_dnn,up_right/up_num,down_right/down_num
107 102
 
@@ -110,11 +105,13 @@ def multi_predict():
110 105
     r = 0;
111 106
     p = 0
112 107
     # for x in range(0, 12): # 0,2,3,4,6,8,9,10,11
113
-    # for x in [5,6,11]:
114
-    for x in [2,4,7,10]: # 2表现最好 优秀的
108
+    # for x in [2,3,4,5,6,7,8,9,11]: 10_18,0没数据需要重新计算
109
+    for x in [0,1,10]:
110
+    # for x in [2,4,7,10]: # 2表现最好 优秀的 0,8正确的反向指标,(9错误的反向指标 样本量太少)
115 111
         print(x)
116 112
     # for x in [0,2,5,6,7]: # 5表现最好
117
-        win_dnn, up_ratio,down_ratio = predict(file_path='D:\\data\\quantization\\kmeans\\stock9_18_test_' + str(x) + '.log', model_path='18d_dnn_seq_' + str(x) + '.h5')
113
+        win_dnn, up_ratio,down_ratio = predict(file_path='D:\\data\\quantization\\kmeans\\stock10_18_test_' + str(x) + '.log',
114
+                                               model_path='18d_dnn_seq_' + str(x) + '.h5', idx=x)
118 115
         r = r + up_ratio
119 116
         p = p + down_ratio
120 117
     print(r, p)
@@ -132,9 +129,9 @@ industry = ['全国地产', '区域地产', '酒店餐饮',
132 129
             '塑料', '电器连锁', '半导体', '乳制品',]
133 130
 
134 131
 
135
-def predict_today(day):
132
+def predict_today(day, model='10_18d'):
136 133
     lines = []
137
-    with open('D:\\data\\quantization\\stock9_18_' +  str(day) +'.log') as f:
134
+    with open('D:\\data\\quantization\\stock' + model + '_' +  str(day) +'.log') as f:
138 135
         for line in f.readlines()[:]:
139 136
             line = eval(line.strip())
140 137
             if line[-1][0].startswith('0') or line[-1][0].startswith('3'):
@@ -148,9 +145,9 @@ def predict_today(day):
148 145
 
149 146
     models = []
150 147
     for x in range(0, 12):
151
-        models.append(load_model('18d_dnn_seq_' + str(x) + '.h5'))
148
+        models.append(load_model(model + '_dnn_seq_' + str(x) + '.h5'))
152 149
 
153
-    x = 21 # 每条数据项数
150
+    x = 24 # 每条数据项数
154 151
     k = 18 # 周期
155 152
     for line in lines:
156 153
         v = line[1:x*k + 1]
@@ -161,21 +158,21 @@ def predict_today(day):
161 158
         # print(v)
162 159
         r = estimator.predict(v)
163 160
 
164
-        if r[0] in [5,6,11]:
165
-            train_x = np.array([line[:size - 1]])
166
-
167
-            result = models[r[0]].predict(train_x)
168
-            if result[0][3] > 0.5 or result[0][4] > 0.5:
169
-                stock = code_table.find_one({'ts_code':line[-1][0]})
170
-                if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
171
-                    continue
172
-                if line[0] > 80:
173
-                    continue
174
-                if stock['industry'] in industry:
175
-                    pass
176
-                    # print(line[-1], stock['name'], stock['industry'], 'sell')
177
-
178
-        if r[0] in [2,4,7,10]:
161
+        # if r[0] in [1,6,10]:
162
+        #     train_x = np.array([line[:size - 1]])
163
+        #
164
+        #     result = models[r[0]].predict(train_x)
165
+        #     if result[0][3] > 0.5 or result[0][4] > 0.5:
166
+        #         stock = code_table.find_one({'ts_code':line[-1][0]})
167
+        #         if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
168
+        #             continue
169
+        #         if line[0] > 80:
170
+        #             continue
171
+        #         if stock['industry'] in industry:
172
+        #             pass
173
+        #             # print(line[-1], stock['name'], stock['industry'], 'sell')
174
+
175
+        if r[0] in [2,3,4,5,6,7,8,9,11]:
179 176
             train_x = np.array([line[:size - 1]])
180 177
 
181 178
             result = models[r[0]].predict(train_x)
@@ -198,12 +195,12 @@ def predict_today(day):
198 195
                     continue
199 196
 
200 197
                 # 指定某几个行业
201
-                # if stock['industry'] in industry:
202
-                print(line[-1], stock['name'], stock['industry'], 'buy')
198
+                if stock['industry'] in industry:
199
+                    print(line[-1], stock['name'], stock['industry'], 'buy')
203 200
 
204 201
 
205 202
 if __name__ == '__main__':
206 203
     # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
207 204
     # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
208
-    multi_predict()
209
-    # predict_today(20200219)
205
+    # multi_predict()
206
+    predict_today(20200221)

+ 16 - 14
stock/dnn_train_dmi.py

@@ -36,12 +36,14 @@ def read_data(path):
36 36
 
37 37
     return X_resampled,y_resampled,np.array(test_x),np.array(test_y)
38 38
 
39
+data_dir = 'D:\\data\\quantization\\kmeans\\'
39 40
 
40
-def resample(path):
41
+
42
+def resample(path, suffix='test'):
41 43
     lines = []
42 44
     with open(path) as f:
43 45
         i = 0
44
-        for x in range(110000):
46
+        for x in range(67000): # 42万 10万 6.7万
45 47
             # print(i)
46 48
             lines.append(eval(f.readline().strip()))
47 49
             i = i + 1
@@ -49,9 +51,9 @@ def resample(path):
49 51
 
50 52
     file_list = []
51 53
     for x in range(0, 12):
52
-        file_list.append(open('D:\\data\\quantization\\kmeans\\stock9_18_train_' + str(x) + '.log', 'a'))
54
+        file_list.append(open(data_dir + 'stock11_18d_' + suffix + '_' + str(x) + '.log', 'a'))
53 55
 
54
-    x = 21 # 每条数据项数
56
+    x = 24 # 每条数据项数
55 57
     k = 18 # 周期
56 58
     for line in lines:
57 59
         v = line[1:x*k + 1]
@@ -64,14 +66,14 @@ def resample(path):
64 66
         file_list[r[0]].write(str(line) + '\n')
65 67
 
66 68
 
67
-def mul_train():
68
-    # for x in range(0, 12):
69
-    for x in [11,0,1,3,8,9]:
70
-    # for x in [2,4,7,10]:
71
-        score = train(input_dim=384, result_class=5, file_path="D:\\data\\quantization\\kmeans\\stock9_18_train_" + str(x) + ".log",
69
+def mul_train(name="10_18"):
70
+    for x in range(0, 12):
71
+    # for x in [11,0,1,3,8,9]:
72
+    # for x in [11,0,1,3,5,6,8,9]:
73
+        score = train(input_dim=440, result_class=5, file_path=data_dir + "stock"+ name + "_train_" + str(x) + ".log",
72 74
               model_name='18d_dnn_seq_' + str(x) + '.h5')
73 75
 
74
-        with open('D:\\data\\quantization\\kmeans\\stock9_18_dmi.log', 'a') as f:
76
+        with open(data_dir + 'stock' + name + '_dmi.log', 'a') as f:
75 77
             f.write(str(x) + ':' + str(score[1]) + '\n')
76 78
 
77 79
 
@@ -89,7 +91,7 @@ def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stoc
89 91
     model.add(Dense(units=120 + input_dim, activation='relu'))
90 92
     model.add(Dropout(0.2))
91 93
     model.add(Dense(units=120+input_dim, activation='selu'))
92
-    model.add(Dropout(0.2))
94
+    model.add(Dropout(0.1))
93 95
     model.add(Dense(units=120+input_dim, activation='selu'))
94 96
     model.add(Dense(units=512, activation='relu'))
95 97
 
@@ -97,7 +99,7 @@ def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stoc
97 99
     model.compile(loss='categorical_crossentropy', optimizer="adam",metrics=['accuracy'])
98 100
 
99 101
     print("Starting training ")
100
-    model.fit(train_x, train_y, batch_size=4096, epochs=900 + 6*int(len(train_x)/600), shuffle=True)
102
+    model.fit(train_x, train_y, batch_size=4096, epochs=555 + 5*int(len(train_x)/888), shuffle=True)
101 103
     score = model.evaluate(test_x, test_y)
102 104
     print(score)
103 105
     print('Test score:', score[0])
@@ -115,5 +117,5 @@ def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stoc
115 117
 
116 118
 
117 119
 if __name__ == '__main__':
118
-    # resample('D:\\data\\quantization\\stock9_18_1.log')
119
-    mul_train()
120
+    # resample('D:\\data\\quantization\\stock11_18d_test.log', suffix='test')
121
+    mul_train('11_18d')