|
@@ -1,13 +1,16 @@
|
1
|
1
|
# -*- encoding:utf-8 -*-
|
2
|
2
|
import numpy as np
|
3
|
3
|
from keras.models import load_model
|
|
4
|
+import joblib
|
4
|
5
|
|
5
|
6
|
|
6
|
7
|
def read_data(path):
|
7
|
8
|
lines = []
|
8
|
9
|
with open(path) as f:
|
9
|
10
|
for line in f.readlines()[:]:
|
10
|
|
- lines.append(eval(line.strip()))
|
|
11
|
+ line = eval(line.strip())
|
|
12
|
+ if line[-2][0].startswith('0') or line[-2][0].startswith('3'):
|
|
13
|
+ lines.append(line)
|
11
|
14
|
|
12
|
15
|
size = len(lines[0])
|
13
|
16
|
train_x=[s[:size - 2] for s in lines]
|
|
@@ -23,40 +26,184 @@ def predict(file_path='', model_path='15min_dnn_seq.h5'):
|
23
|
26
|
print('DNN', score)
|
24
|
27
|
|
25
|
28
|
up_num = 0
|
|
29
|
+ up_error = 0
|
26
|
30
|
up_right = 0
|
|
31
|
+ down_num = 0
|
|
32
|
+ down_error = 0
|
|
33
|
+ down_right = 0
|
27
|
34
|
i = 0
|
28
|
35
|
result=model.predict(test_x)
|
29
|
36
|
win_dnn = []
|
30
|
|
- with open('dnn_predict_dmi_14d.txt', 'a') as f:
|
|
37
|
+ with open('dnn_predict_dmi_18d.txt', 'a') as f:
|
31
|
38
|
for r in result:
|
32
|
39
|
fact = test_y[i]
|
33
|
40
|
if r[0] > 0.5:
|
34
|
41
|
f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
|
35
|
42
|
win_dnn.append([lines[i][-2], lines[i][-1]])
|
36
|
43
|
if fact[0] == 1:
|
|
44
|
+ up_right = up_right + 1.12
|
|
45
|
+ elif fact[1] == 1:
|
37
|
46
|
up_right = up_right + 1.06
|
|
47
|
+ elif fact[2] == 1:
|
|
48
|
+ up_right = up_right + 1
|
|
49
|
+ elif fact[3] == 1:
|
|
50
|
+ up_right = up_right + 0.94
|
|
51
|
+ else:
|
|
52
|
+ up_error = up_error + 1
|
|
53
|
+ up_right = up_right + 0.88
|
|
54
|
+ up_num = up_num + 1
|
|
55
|
+ elif r[1] > 0.5:
|
|
56
|
+ f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
|
|
57
|
+ win_dnn.append([lines[i][-2], lines[i][-1]])
|
|
58
|
+ if fact[0] == 1:
|
|
59
|
+ up_right = up_right + 1.12
|
38
|
60
|
elif fact[1] == 1:
|
|
61
|
+ up_right = up_right + 1.06
|
|
62
|
+ elif fact[2] == 1:
|
39
|
63
|
up_right = up_right + 1
|
|
64
|
+ elif fact[3] == 1:
|
|
65
|
+ up_right = up_right + 0.94
|
40
|
66
|
else:
|
41
|
|
- up_right = up_right + 0.9
|
|
67
|
+ up_error = up_error + 1
|
|
68
|
+ up_right = up_right + 0.88
|
42
|
69
|
up_num = up_num + 1
|
43
|
70
|
|
|
71
|
+ if r[3] > 0.6:
|
|
72
|
+ f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
|
|
73
|
+ win_dnn.append([lines[i][-2], lines[i][-1]])
|
|
74
|
+ if fact[0] == 1:
|
|
75
|
+ down_error = down_error + 1
|
|
76
|
+ down_right = down_right + 1.12
|
|
77
|
+ elif fact[1] == 1:
|
|
78
|
+ down_right = down_right + 1.06
|
|
79
|
+ elif fact[2] == 1:
|
|
80
|
+ down_right = down_right + 1
|
|
81
|
+ elif fact[3] == 1:
|
|
82
|
+ down_right = down_right + 0.94
|
|
83
|
+ else:
|
|
84
|
+ down_right = down_right + 0.88
|
|
85
|
+ down_num = down_num + 1
|
|
86
|
+ elif r[4] > 0.6:
|
|
87
|
+ f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
|
|
88
|
+ win_dnn.append([lines[i][-2], lines[i][-1]])
|
|
89
|
+ if fact[0] == 1:
|
|
90
|
+ down_error = down_error + 1
|
|
91
|
+ down_right = down_right + 1.12
|
|
92
|
+ elif fact[1] == 1:
|
|
93
|
+ down_right = down_right + 1.06
|
|
94
|
+ elif fact[2] == 1:
|
|
95
|
+ down_right = down_right + 1
|
|
96
|
+ elif fact[3] == 1:
|
|
97
|
+ down_right = down_right + 0.94
|
|
98
|
+ else:
|
|
99
|
+ down_right = down_right + 0.88
|
|
100
|
+ down_num = down_num + 1
|
|
101
|
+
|
44
|
102
|
i = i + 1
|
45
|
103
|
if up_num == 0:
|
46
|
104
|
up_num = 1
|
47
|
|
- print('DNN', up_right, up_num, up_right/up_num)
|
48
|
|
- return win_dnn,up_right/up_num
|
|
105
|
+ print('DNN', up_right, up_num, up_right/up_num, up_error/up_num, down_right/down_num, down_error/down_num)
|
|
106
|
+ return win_dnn,up_right/up_num,down_right/down_num
|
49
|
107
|
|
50
|
108
|
|
51
|
109
|
def multi_predict(cluster_ids=(2, 4, 7, 10)):
    """Run ``predict`` for several clusters and print the summed ratios.

    For every cluster id, ``predict`` is called with that cluster's test log
    (``stock9_18_test_<id>.log``) and the matching model file
    (``18d_dnn_seq_<id>.h5``).  The second and third values returned by
    ``predict`` (up ratio and down ratio) are accumulated over all clusters.

    Args:
        cluster_ids: Iterable of cluster ids to evaluate.  Defaults to the
            selection that used to be hard-coded here, ``(2, 4, 7, 10)``.

    Returns:
        None.  The cluster id is printed before each run and the two
        accumulated sums are printed at the end.
    """
    up_total = 0
    down_total = 0
    # Author's note from the original comment: cluster 2 performed best.
    for cluster_id in cluster_ids:
        print(cluster_id)
        # The first returned value (the list of selected rows) is not needed here.
        _, up_ratio, down_ratio = predict(
            file_path='D:\\data\\quantization\\kmeans\\stock9_18_test_' + str(cluster_id) + '.log',
            model_path='18d_dnn_seq_' + str(cluster_id) + '.h5',
        )
        up_total = up_total + up_ratio
        down_total = down_total + down_ratio
    print(up_total, down_total)
|
|
121
|
+
|
|
122
|
+import pymongo
|
|
123
|
+from util.mongodb import get_mongo_table_instance
|
|
124
|
# MongoDB collection handles, looked up by collection name.
# predict_today reads 'name' / 'industry' from code_table documents and
# 'close' from k_table documents.
code_table = get_mongo_table_instance('tushare_code')
k_table = get_mongo_table_instance('stock_day_k')

# Industry names that a stock's 'industry' field is compared against.
industry = [
    '全国地产',
    '区域地产',
    '酒店餐饮',
    '家用电器',
    '文教休闲',
    '元器件',
    'IT设备',
    '汽车服务',
    '汽车配件',
    '港口',
    '机场',
    '商贸代理',
    '软件服务',
    '证券',
    '供气供热',
    '多元金融',
    '百货',
    '食品',
    '水务',
    '互联网',
    '纺织',
    '保险',
    '航空',
    '超市连锁',
    '软饮料',
    '塑料',
    '电器连锁',
    '半导体',
    '乳制品',
]
|
|
133
|
+
|
|
134
|
+
|
|
135
|
def _has_excluded_name(name):
    """Return True when a stock name starts with 'ST', 'N' or '*'.

    predict_today skips such stocks (the original comment says "remove ST").
    """
    return name.startswith(('ST', 'N', '*'))


def predict_today(day):
    """Print buy candidates found in the feature log for ``day``.

    Steps:
      1. Read ``stock9_18_<day>.log``; keep rows whose stock code (first item
         of the row's last element) starts with '0' or '3'.
      2. Assign each row to a k-means cluster using columns 4..7 of its
         18 x 21 feature window.
      3. For clusters 2, 4, 7, 10 run the matching DNN; when output 0 or 1
         exceeds 0.5, apply the name / price filters and print the row as 'buy'.
      4. For clusters 5, 6, 11 the sell-side checks are evaluated, but
         reporting is disabled (nothing is printed).

    Args:
        day: Trading day as used in the log file name, e.g. ``20200219``.
            It is also used (as an int) as the upper bound of the K-line
            query; previously that bound was hard-coded to 20200214.

    Returns:
        None.  Results are printed.
    """
    sell_clusters = (5, 6, 11)
    buy_clusters = (2, 4, 7, 10)

    rows = []
    with open('D:\\data\\quantization\\stock9_18_' + str(day) + '.log') as f:
        for raw in f:
            text = raw.strip()
            if not text:
                continue
            # SECURITY: eval executes whatever expression the log line contains.
            # Only feed this function logs produced locally by trusted code.
            row = eval(text)
            if row[-1][0].startswith(('0', '3')):
                rows.append(row)

    if not rows:
        # Nothing to score; the original raised IndexError on lines[0] here.
        return

    size = len(rows[0])

    estimator = joblib.load('km_dmi_18.pkl')

    # Only the clusters that are actually scored below need their model loaded.
    models = {
        cluster_id: load_model('18d_dnn_seq_' + str(cluster_id) + '.h5')
        for cluster_id in sell_clusters + buy_clusters
    }

    items_per_record = 21  # number of items per record
    period = 18  # number of periods in the window
    for row in rows:
        code = row[-1][0]

        # Rebuild the (period x items) window and keep columns 4..7 as the
        # clustering features, flattened to a single sample.
        window = np.array(row[1:items_per_record * period + 1])
        window = window.reshape(period, items_per_record)
        cluster_features = window[:, 4:8].reshape(1, 4 * period)
        cluster = int(estimator.predict(cluster_features)[0])

        if cluster in sell_clusters:
            result = models[cluster].predict(np.array([row[:size - 1]]))
            if result[0][3] > 0.5 or result[0][4] > 0.5:
                stock = code_table.find_one({'ts_code': code})
                if stock is None or _has_excluded_name(stock['name']):
                    continue
                if row[0] > 80:
                    continue
                if stock['industry'] in industry:
                    # Sell reporting is disabled; the original print was:
                    # print(row[-1], stock['name'], stock['industry'], 'sell')
                    pass
        elif cluster in buy_clusters:
            result = models[cluster].predict(np.array([row[:size - 1]]))
            if not (result[0][0] > 0.5 or result[0][1] > 0.5):
                continue
            # Kept from the original; codes were already limited to '0'/'3'
            # prefixes when the log was read, so this never matches.
            if code.startswith('688'):
                continue

            # Drop ST (and 'N' / '*' prefixed) stocks.
            stock = code_table.find_one({'ts_code': code})
            if stock is None or _has_excluded_name(stock['name']):
                continue

            # Five most recent daily bars up to and including `day`, newest first.
            recent = list(
                k_table.find({'code': code, 'tradeDate': {'$lte': int(day)}})
                .sort("tradeDate", pymongo.DESCENDING)
                .limit(5)
            )
            if not recent:
                continue
            newest_close = recent[0]['close']
            oldest_close = recent[-1]['close']
            # Skip stocks that rose more than 20% or fell more than 10%
            # across the fetched bars, and those whose oldest close is above 80.
            if newest_close > oldest_close * 1.20:
                continue
            if newest_close < oldest_close * 0.90:
                continue
            if oldest_close > 80:
                continue

            # Restricting buys to the `industry` list is currently disabled.
            print(row[-1], stock['name'], stock['industry'], 'buy')
|
57
|
203
|
|
58
|
204
|
|
59
|
205
|
if __name__ == '__main__':
    # Entry point: run multi_predict().
    # Alternative manual runs kept from the original (disabled):
    #   predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
    #   predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
    #   predict_today(20200219)
    multi_predict()
|