yufeng 4 years ago
parent
commit
3def82b241
1 changed files with 203 additions and 0 deletions
  1. 203 0
      stock/dnn_predict_dmi_everyday.py

+ 203 - 0
stock/dnn_predict_dmi_everyday.py

@@ -0,0 +1,203 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import joblib
5
+
6
+
7
+holder_stock_list = [
8
+                     '000063.SZ'
9
+                     '300538.SZ',
10
+                     '002261.SZ',
11
+                     '002475.SZ',
12
+                     '300037.SZ'
13
+                     '300059.SZ',
14
+                     '300244.SZ',
15
+                     '300803.SZ',
16
+                     '300102.SZ',
17
+                     '002815.SZ']
18
+
19
+
20
+def read_data(path):
21
+    lines = []
22
+    with open(path) as f:
23
+        for line in f.readlines()[:]:
24
+            line = eval(line.strip())
25
+            if line[-2][0].startswith('0') or line[-2][0].startswith('3'):
26
+                lines.append(line)
27
+
28
+    size = len(lines[0])
29
+    train_x=[s[:size - 2] for s in lines]
30
+    train_y=[s[size-1] for s in lines]
31
+    return np.array(train_x),np.array(train_y),lines
32
+
33
+
34
+import pymongo
35
+from util.mongodb import get_mongo_table_instance
36
+code_table = get_mongo_table_instance('tushare_code')
37
+k_table = get_mongo_table_instance('stock_day_k')
38
+stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
39
+all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
40
+
41
+
42
+industry = ['家用电器', '元器件', 'IT设备', '汽车服务',
43
+            '汽车配件', '软件服务',
44
+            '互联网', '纺织',
45
+            '塑料', '半导体',]
46
+
47
+A_concept_code_list = [   'TS2', # 5G
48
+                        'TS24', # OLED
49
+                        'TS26', #健康中国
50
+                        'TS43',  #新能源整车
51
+                        'TS59', # 特斯拉
52
+                        'TS65', #汽车整车
53
+                        'TS142', # 物联网
54
+                        'TS153', # 无人驾驶
55
+                        'TS163', # 雄安板块-智慧城市
56
+                        'TS175', # 工业自动化
57
+                        'TS232', # 新能源汽车
58
+                        'TS254', # 人工智能
59
+                        'TS258', # 互联网医疗
60
+                        'TS264', # 工业互联网
61
+                        'TS266', # 半导体
62
+                        'TS269', # 智慧城市
63
+                        'TS271', # 3D玻璃
64
+                        'TS295', # 国产芯片
65
+                        'TS303', # 医疗信息化
66
+                        'TS323', # 充电桩
67
+                        'TS328', # 虹膜识别
68
+                        'TS361', # 病毒
69
+    ]
70
+
71
+
72
+gainian_map = {}
73
+hangye_map = {}
74
+
75
+def predict_today(day, model='10_18d', log=True):
76
+    lines = []
77
+    with open('D:\\data\\quantization\\stock' + model + '_' +  str(day) +'.log') as f:
78
+        for line in f.readlines()[:]:
79
+            line = eval(line.strip())
80
+            # if line[-1][0].startswith('0') or line[-1][0].startswith('3'):
81
+            lines.append(line)
82
+
83
+    size = len(lines[0])
84
+    train_x=[s[:size - 1] for s in lines]
85
+    np.array(train_x)
86
+
87
+    estimator = joblib.load('km_dmi_18.pkl')
88
+
89
+    models = []
90
+    for x in range(0, 12):
91
+        models.append(load_model(model + '_dnn_seq_' + str(x) + '.h5'))
92
+
93
+    x = 24 # 每条数据项数
94
+    k = 18 # 周期
95
+    for line in lines:
96
+        # print(line)
97
+        v = line[1:x*k + 1]
98
+        v = np.array(v)
99
+        v = v.reshape(k, x)
100
+        v = v[:,4:8]
101
+        v = v.reshape(1, 4*k)
102
+        # print(v)
103
+        r = estimator.predict(v)
104
+
105
+        train_x = np.array([line[:size - 1]])
106
+
107
+        result = models[r[0]].predict(train_x)
108
+        # print(result, line[-1])
109
+        stock = code_table.find_one({'ts_code':line[-1][0]})
110
+
111
+        if result[0][0] > 0.6 or result[0][1] > 0.6:
112
+            if line[-1][0].startswith('688'):
113
+                continue
114
+            # 去掉ST
115
+            if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
116
+                continue
117
+
118
+            if stock['ts_code'] in holder_stock_list:
119
+                print(stock['ts_code'], stock['name'], '维持买入评级')
120
+
121
+            # 跌的
122
+            k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
123
+            if k_table_list[0]['close'] > k_table_list[-1]['close']*1.20:
124
+                continue
125
+            if k_table_list[0]['close'] < k_table_list[-1]['close']*0.90:
126
+                continue
127
+            if k_table_list[-1]['close'] > 80:
128
+                continue
129
+
130
+            # 指定某几个行业
131
+            # if stock['industry'] in industry:
132
+            concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
133
+            concept_detail_list = []
134
+
135
+            # 处理行业
136
+            if stock['industry'] in hangye_map:
137
+                i_c = hangye_map[stock['industry']]
138
+                hangye_map[stock['industry']] = i_c + 1
139
+            else:
140
+                hangye_map[stock['industry']] = 1
141
+
142
+            if len(concept_code_list) > 0:
143
+                for concept in concept_code_list:
144
+                    for c in all_concept_code_list:
145
+                        if c['code'] == concept['concept_code']:
146
+                            concept_detail_list.append(c['name'])
147
+
148
+                            if c['name'] in gainian_map:
149
+                                g_c = gainian_map[c['name']]
150
+                                gainian_map[c['name']] = g_c + 1
151
+                            else:
152
+                                gainian_map[c['name']] = 1
153
+
154
+            print(line[-1], stock['name'], stock['industry'], str(concept_detail_list), 'buy')
155
+
156
+            if log is True:
157
+                with open('D:\\data\\quantization\\predict\\' + str(day) + '.txt', mode='a', encoding="utf-8") as f:
158
+                    f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['industry'] + ' ' + str(concept_detail_list) + ' buy' + '\n')
159
+
160
+
161
+
162
+
163
+            # concept_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
164
+            # concept_list = [c['concept_code'] for c in concept_list]
165
+
166
+        elif result[0][2] > 0.5:
167
+            if stock['ts_code'] in holder_stock_list:
168
+                print(stock['ts_code'], stock['name'], '震荡评级')
169
+
170
+        elif result[0][3] > 0.5 and result[0][4] > 0.5:
171
+            if stock['ts_code'] in holder_stock_list:
172
+                print(stock['ts_code'], stock['name'], '赶紧卖出')
173
+        else:
174
+            if stock['ts_code'] in holder_stock_list:
175
+                print(stock['ts_code'], stock['name'], result[0], r[0])
176
+
177
+    print(gainian_map)
178
+    print(hangye_map)
179
+
180
+
181
+def _read_pfile_map(path):
182
+    s_list = []
183
+    with open(path, encoding='utf-8') as f:
184
+        for line in f.readlines()[:]:
185
+            s_list.append(line)
186
+    return s_list
187
+
188
+
189
+def join_two_day(a, b):
190
+    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '".txt')
191
+    b_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(b) + '".txt')
192
+    for a in a_list:
193
+        for b in b_list:
194
+            if a[2:11] == b[2:11]:
195
+                print(a)
196
+
197
+
198
+if __name__ == '__main__':
199
+    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
200
+    # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
201
+    # multi_predict()
202
+    predict_today(20200227, model='11_18d', log=True)
203
+    join_two_day(20200226, 20200225)