yufeng 4 years ago
parent
commit
2e1c47269b
2 changed files with 241 additions and 4 deletions
  1. 232 0
      mix/mix_predict_everyday_200.py
  2. 9 4
      mix/mix_train_190.py

+ 232 - 0
mix/mix_predict_everyday_200.py

@@ -0,0 +1,232 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import joblib
5
+
6
+
7
+holder_stock_list = [
8
+                        # 医疗
9
+                        '000150.SZ', '300300.SZ', '603990.SH', '300759.SZ', '300347.SZ', '300003.SZ', '300253.SZ',
10
+                        # 5G
11
+                        '300698.SZ', '600498.SH', '300310.SZ', '600353.SH', '603912.SH', '603220.SH', '300602.SZ', '600260.SH',
12
+                        # 车联网
13
+                        '002369.SZ', '002920.SZ', '300020.SZ', '002373.SZ', '002869.SZ',
14
+                        # 工业互联网
15
+                        '002184.SZ', '002364.SZ','300310.SZ', '300670.SZ', '300166.SZ', '002169.SZ', '002380.SZ',
16
+                        # 特高压
17
+                        '300341.SZ', '300670.SZ', '300018.SZ', '600268.SH', '002879.SZ',
18
+                        # 基础建设
19
+                        '300041.SZ', '603568.SH', '000967.SZ', '603018.SH',
20
+                        # B
21
+                        '002555.SZ', '002174.SZ',
22
+                        # ROE
23
+    '002976.SZ', '002847.SZ', '002597.SZ', '300686.SZ', '000708.SZ', '603948.SH', '600507.SH', '300401.SZ', '002714.SZ', '600732.SH', '300033.SZ', '300822.SZ', '300821.SZ',
24
+    '002458.SZ', '000708.SZ', '600732.SH', '603719.SH', '300821.SZ', '300800.SZ', '300816.SZ', '300812.SZ', '603195.SH', '300815.SZ', '603053.SH', '603551.SH', '002975.SZ',
25
+    '603949.SH', '002970.SZ', '300809.SZ', '002968.SZ', '300559.SZ', '002512.SZ', '300783.SZ', '300003.SZ', '603489.SH', '300564.SZ', '600802.SH', '002600.SZ',
26
+    '000933.SZ', '601918.SH', '000651.SZ', '002916.SZ', '000568.SZ', '000717.SZ', '600452.SH', '603589.SH', '600690.SH', '603886.SH', '300117.SZ', '000858.SZ', '002102.SZ',
27
+    '300136.SZ', '600801.SH', '600436.SH', '300401.SZ', '002190.SZ', '300122.SZ', '002299.SZ', '603610.SH', '002963.SZ', '600486.SH', '300601.SZ', '300682.SZ', '300771.SZ',
28
+    '000868.SZ', '002607.SZ', '603068.SH', '603508.SH', '603658.SH', '300571.SZ', '603868.SH', '600768.SH', '300760.SZ', '002901.SZ', '603638.SH', '601100.SH', '002032.SZ',
29
+    '600083.SH', '600507.SH', '603288.SH', '002304.SZ', '000963.SZ', '300572.SZ', '000885.SZ', '600995.SH', '300080.SZ', '601888.SH', '000048.SZ', '000333.SZ', '300529.SZ',
30
+    '000537.SZ', '002869.SZ', '600217.SH', '000526.SZ', '600887.SH', '002161.SZ', '600267.SH', '600668.SH', '600052.SH', '002379.SZ', '603369.SH', '601360.SH', '002833.SZ',
31
+    '002035.SZ', '600031.SH', '600678.SH', '600398.SH', '600587.SH', '600763.SH', '002016.SZ', '603816.SH', '000031.SZ', '002555.SZ', '603983.SH', '002746.SZ', '603899.SH',
32
+    '300595.SZ', '300632.SZ', '600809.SH', '002507.SZ', '300198.SZ', '600779.SH', '603568.SH', '300638.SZ', '002011.SZ', '603517.SH', '000661.SZ', '300630.SZ', '000895.SZ',
33
+    '002841.SZ', '300602.SZ', '300418.SZ', '603737.SH', '002755.SZ', '002803.SZ', '002182.SZ', '600132.SH', '300725.SZ', '600346.SH', '300015.SZ', '300014.SZ', '300628.SZ',
34
+    '000789.SZ', '600368.SH', '300776.SZ', '600570.SH', '000509.SZ', '600338.SH', '300770.SZ', '600309.SH', '000596.SZ', '300702.SZ', '002271.SZ', '300782.SZ', '300577.SZ',
35
+    '603505.SH', '603160.SH', '300761.SZ', '603327.SH', '002458.SZ', '300146.SZ', '002463.SZ', '300417.SZ', '600566.SH', '002372.SZ', '600585.SH', '000848.SZ', '600519.SH',
36
+    '000672.SZ', '300357.SZ', '002234.SZ', '603444.SH', '300236.SZ', '603360.SH', '002677.SZ', '300487.SZ', '600319.SH', '002415.SZ', '000403.SZ', '600340.SH', '601318.SH',
37
+
38
+
39
+]
40
+
41
+
42
+def read_data(path):
43
+    lines = []
44
+    with open(path) as f:
45
+        for line in f.readlines()[:]:
46
+            line = eval(line.strip())
47
+            if line[-2][0].startswith('0') or line[-2][0].startswith('3'):
48
+                lines.append(line)
49
+
50
+    size = len(lines[0])
51
+    train_x=[s[:size - 2] for s in lines]
52
+    train_y=[s[size-1] for s in lines]
53
+    return np.array(train_x),np.array(train_y),lines
54
+
55
+
56
+import pymongo
57
+from util.mongodb import get_mongo_table_instance
58
+code_table = get_mongo_table_instance('tushare_code')
59
+k_table = get_mongo_table_instance('stock_day_k')
60
+stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
61
+all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
62
+
63
+
64
+industry = ['家用电器', '元器件', 'IT设备', '汽车服务',
65
+            '汽车配件', '软件服务',
66
+            '互联网', '纺织',
67
+            '塑料', '半导体',]
68
+
69
+A_concept_code_list = [   'TS2', # 5G
70
+                        'TS24', # OLED
71
+                        'TS26', #健康中国
72
+                        'TS43',  #新能源整车
73
+                        'TS59', # 特斯拉
74
+                        'TS65', #汽车整车
75
+                        'TS142', # 物联网
76
+                        'TS153', # 无人驾驶
77
+                        'TS163', # 雄安板块-智慧城市
78
+                        'TS175', # 工业自动化
79
+                        'TS232', # 新能源汽车
80
+                        'TS254', # 人工智能
81
+                        'TS258', # 互联网医疗
82
+                        'TS264', # 工业互联网
83
+                        'TS266', # 半导体
84
+                        'TS269', # 智慧城市
85
+                        'TS271', # 3D玻璃
86
+                        'TS295', # 国产芯片
87
+                        'TS303', # 医疗信息化
88
+                        'TS323', # 充电桩
89
+                        'TS328', # 虹膜识别
90
+                        'TS361', # 病毒
91
+    ]
92
+
93
+
94
+gainian_map = {}
95
+hangye_map = {}
96
+
97
+def predict_today(file, day, model='10_18d', log=True):
98
+    lines = []
99
+    with open(file) as f:
100
+        for line in f.readlines()[:]:
101
+            line = eval(line.strip())
102
+            # if line[-1][0].startswith('0') or line[-1][0].startswith('3'):
103
+            lines.append(line)
104
+
105
+    size = len(lines[0])
106
+
107
+    model=load_model(model)
108
+
109
+    for line in lines:
110
+        train_x = np.array([line[:size - 1]])
111
+        train_x_tmp = train_x[:,:18*20]
112
+        train_x_a = train_x_tmp.reshape(train_x.shape[0], 18, 20, 1)
113
+        # train_x_b = train_x_tmp.reshape(train_x.shape[0], 18, 24)
114
+        train_x_c = train_x[:,18*20:]
115
+
116
+        result = model.predict([train_x_c, train_x_a, ])
117
+        # print(result, line[-1])
118
+        stock = code_table.find_one({'ts_code':line[-1][0]})
119
+
120
+        if result[0][0] > 0.6:
121
+            if line[-1][0].startswith('688'):
122
+                continue
123
+            # 去掉ST
124
+            if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
125
+                continue
126
+
127
+            if stock['ts_code'] in holder_stock_list:
128
+                print(stock['ts_code'], stock['name'], '维持买入评级')
129
+
130
+            # 跌的
131
+            k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
132
+            # if k_table_list[0]['close'] > k_table_list[-1]['close']*1.20:
133
+            #     continue
134
+            # if k_table_list[0]['close'] < k_table_list[-1]['close']*0.90:
135
+            #     continue
136
+            # if k_table_list[-1]['close'] > 80:
137
+            #     continue
138
+
139
+            # 指定某几个行业
140
+            # if stock['industry'] in industry:
141
+            concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
142
+            concept_detail_list = []
143
+
144
+            # 处理行业
145
+            if stock['sw_industry'] in hangye_map:
146
+                i_c = hangye_map[stock['sw_industry']]
147
+                hangye_map[stock['sw_industry']] = i_c + 1
148
+            else:
149
+                hangye_map[stock['sw_industry']] = 1
150
+
151
+            if len(concept_code_list) > 0:
152
+                for concept in concept_code_list:
153
+                    for c in all_concept_code_list:
154
+                        if c['code'] == concept['concept_code']:
155
+                            concept_detail_list.append(c['name'])
156
+
157
+                            if c['name'] in gainian_map:
158
+                                g_c = gainian_map[c['name']]
159
+                                gainian_map[c['name']] = g_c + 1
160
+                            else:
161
+                                gainian_map[c['name']] = 1
162
+
163
+            print(line[-1], stock['name'], stock['sw_industry'], str(concept_detail_list), 'buy', k_table_list[0]['pct_chg'])
164
+
165
+            if log is True:
166
+                with open('D:\\data\\quantization\\predict\\' + str(day) + '_mix.txt', mode='a', encoding="utf-8") as f:
167
+                    f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' buy' + '\n')
168
+
169
+        elif result[0][1] > 0.4:
170
+            if stock['ts_code'] in holder_stock_list:
171
+                print(stock['ts_code'], stock['name'], '震荡评级')
172
+
173
+        elif result[0][2] > 0.5:
174
+            if stock['ts_code'] in holder_stock_list:
175
+                print(stock['ts_code'], stock['name'], '赶紧卖出')
176
+        else:
177
+            if stock['ts_code'] in holder_stock_list:
178
+                print(stock['ts_code'], stock['name'], result[0],)
179
+
180
+    # print(gainian_map)
181
+    # print(hangye_map)
182
+
183
+    gainian_list = [(key, gainian_map[key])for key in gainian_map]
184
+    gainian_list = sorted(gainian_list, key=lambda x:x[1], reverse=True)
185
+
186
+    hangye_list = [(key, hangye_map[key])for key in hangye_map]
187
+    hangye_list = sorted(hangye_list, key=lambda x:x[1], reverse=True)
188
+
189
+    print(gainian_list)
190
+    print(hangye_list)
191
+
192
+def _read_pfile_map(path):
193
+    s_list = []
194
+    with open(path, encoding='utf-8') as f:
195
+        for line in f.readlines()[:]:
196
+            s_list.append(line)
197
+    return s_list
198
+
199
+
200
+def join_two_day(a, b):
201
+    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '.txt')
202
+    b_list = _read_pfile_map('D:\\data\\quantization\\predict\\dmi_' + str(b) + '.txt')
203
+    for a in a_list:
204
+        for b in b_list:
205
+            if a[2:11] == b[2:11]:
206
+                print(a)
207
+
208
+
209
+def check_everyday(day, today):
210
+    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(day) + '.txt')
211
+    x = 0
212
+    for a in a_list:
213
+        print(a[:-1])
214
+        k_day_list = list(k_table.find({'code':a[2:11], 'tradeDate':{'$lte':int(today)}}).sort('tradeDate', pymongo.DESCENDING).limit(5))
215
+        if k_day_list is not None and len(k_day_list) > 0:
216
+            k_day = k_day_list[0]
217
+            k_day_0 = k_day_list[-1]
218
+            k_day_last = k_day_list[1]
219
+            if ((k_day_last['close'] - k_day_0['pre_close'])/k_day_0['pre_close']) < 0.2:
220
+                print(k_day['open'], k_day['close'], 100*(k_day['close'] - k_day_last['close'])/k_day_last['close'])
221
+                x = x + 100*(k_day['close'] - k_day_last['close'])/k_day_last['close']
222
+
223
+    print(x/len(a_list))
224
+
225
+
226
+if __name__ == '__main__':
227
+    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
228
+    # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
229
+    # multi_predict()
230
+    predict_today("D:\\data\\quantization\\stock215_18d_20200323.log", 20200327, model='215_18d_mix_6D_ma5_s_seq.h5', log=True)
231
+    # join_two_day(20200305, 20200305)
232
+    # check_everyday(20200311, 20200312)

+ 9 - 4
mix/mix_train_190.py

@@ -21,14 +21,19 @@ early_stopping = EarlyStopping(monitor='accuracy', patience=5, verbose=2)
21 21
 epochs= 88
22 22
 size = 400000 #18W 60W
23 23
 file_path = 'D:\\data\\quantization\\stock196_18d_train2.log'
24
-model_path = '196_18d_mix_6D_ma5_s_seq.h5'
24
+model_path = '196A_18d_mix_6D_ma5_s_seq.h5'
25 25
 file_path1='D:\\data\\quantization\\stock196_18d_test.log'
26 26
 
27 27
 '''
28 28
 大盘预测
29 29
 结果均用使用ma
30
-6 ROC               37,99,28
31
-5 after用5日         
30
+6 ROC   cnn18*18                                             37,99,28
31
+7 ROC + 窗口6*18+ cnn18*18                                        
32
+8 after用5日                                        
33
+9 after5 + roc in before                            
34
+9A after5 + roc in before + beta                    
35
+9B after5 + roc in before + beta + 其他信息         
36
+        
32 37
 '''
33 38
 
34 39
 def read_data(path, path1=file_path1):
@@ -136,7 +141,7 @@ mlp = create_mlp(train_x_c.shape[1], regress=False)
136 141
 # cnn_0 = create_cnn(18, 21, 1, kernel_size=(3, 3), size=64, regress=False, output=128)       # 31 97 46
137 142
 # cnn_0 = create_cnn(18, 21, 1, kernel_size=(6, 6), size=64, regress=False, output=128)         # 29 98 47
138 143
 # cnn_0 = create_cnn(18, 21, 1, kernel_size=(9, 9), size=64, regress=False, output=128)         # 28 97 53
139
-cnn_0 = create_cnn(18, 18, 1, kernel_size=(3, 18), size=96, regress=False, output=96)       #A23 99 33 A' 26 99 36 #B 34 98 43
144
+cnn_0 = create_cnn(18, 18, 1, kernel_size=(6, 18), size=96, regress=False, output=128)       #A23 99 33 A' 26 99 36 #B 34 98 43
140 145
 # cnn_1 = create_cnn(18, 21, 1, kernel_size=(18, 11), size=96, regress=False, output=96)
141 146
 # cnn_1 = create_cnn(9, 26, 1, kernel_size=(2, 14), size=36, regress=False, output=64)
142 147