yufeng 4 years ago
parent
commit
84510e9adf

+ 0 - 0
integr/stock_train.py


+ 1 - 1
mix/mix_predict_300.py

@@ -92,7 +92,7 @@ def predict(file_path='', model_path='15min_dnn_seq.h5', idx=-1, row=18, col=20)
92 92
 if __name__ == '__main__':
93 93
     # predict(file_path='D:\\data\\quantization\\stock181_18d_test.log', model_path='181_18d_mix_6D_ma5_s_seq.h5')
94 94
     # predict(file_path='D:\\data\\quantization\\stock217_18d_train1.log', model_path='218_18d_mix_5D_ma5_s_seq.h5', row=18, col=18)
95
-    predict(file_path='D:\\data\\quantization\\stock321_28d_test.log', model_path='321_28d_mix_5D_ma5_s_seq.h5', row=28, col=20)
95
+    predict(file_path='D:\\data\\quantization\\stock324_28d_20191211.log', model_path='324_28d_mix_5D_ma5_s_seq.h5', row=28, col=18)
96 96
     # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
97 97
     # multi_predict(model='15_18d')
98 98
     # predict_today(20200229, model='11_18d')

+ 1 - 1
mix/mix_predict_400.py

@@ -90,7 +90,7 @@ if __name__ == '__main__':
90 90
     # predict(file_path='D:\\data\\quantization\\stock181_18d_test.log', model_path='181_18d_mix_6D_ma5_s_seq.h5')
91 91
     # predict(file_path='D:\\data\\quantization\\stock217_18d_train1.log', model_path='218_18d_mix_5D_ma5_s_seq.h5', row=18, col=18)
92 92
     # predict(file_path='D:\\data\\quantization\\stock400_18d_train1.log', model_path='400_18d_mix_5D_ma5_s_seq.h5', row=18, col=18)
93
-    predict(file_path='D:\\data\\quantization\\stock403_30d_train1.log', model_path='403_30d_mix_5D_ma5_s_seq_2.h5', row=30, col=20)
93
+    predict(file_path='D:\\data\\quantization\\stock417_30d_train1.log', model_path='417_30d_mix_5D_ma5_s_seq.h5', row=30, col=19)
94 94
     # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
95 95
     # multi_predict(model='15_18d')
96 96
     # predict_today(20200229, model='11_18d')

+ 1 - 1
mix/mix_predict_500.py

@@ -94,7 +94,7 @@ if __name__ == '__main__':
94 94
     # predict(file_path='D:\\data\\quantization\\stock181_18d_test.log', model_path='181_18d_mix_6D_ma5_s_seq.h5')
95 95
     # predict(file_path='D:\\data\\quantization\\stock217_18d_train1.log', model_path='218_18d_mix_5D_ma5_s_seq.h5', row=18, col=18)
96 96
     # predict(file_path='D:\\data\\quantization\\stock400_18d_train1.log', model_path='400_18d_mix_5D_ma5_s_seq.h5', row=18, col=18)
97
-    predict(file_path='D:\\data\\quantization\\stock501_28d_train1.log', model_path='501_28d_mix_5D_ma5_s_seq.h5', row=28, col=19)
97
+    predict(file_path='D:\\data\\quantization\\stock507_28d_train1.log', model_path='507_28d_mix_5D_ma5_s_seq.h5', row=28, col=19)
98 98
     # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
99 99
     # multi_predict(model='15_18d')
100 100
     # predict_today(20200229, model='11_18d')

+ 96 - 0
mix/mix_predict_600.py

@@ -0,0 +1,96 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import joblib
5
+
6
+
7
def read_data(path):
    """Load one sample per line from *path* and split features from label.

    Each line is a Python-literal list of length N: indices [0, N-2) are
    features and index N-1 is the label.  The element at N-2 is not used
    (presumably an auxiliary/metadata slot — confirm with the writer side).

    Returns (features, labels, raw_samples) where the first two are
    float32 numpy arrays.

    NOTE(review): `eval` executes arbitrary code from the file — acceptable
    only for trusted local logs; ast.literal_eval would be safer.
    """
    samples = []
    with open(path) as f:
        for raw in f:
            samples.append(eval(raw.strip()))

    width = len(samples[0])
    features = [s[:width - 2] for s in samples]
    labels = [s[width - 1] for s in samples]
    return np.array(features, dtype=np.float32), np.array(labels, dtype=np.float32), samples
18
+
19
+
20
+def _score(fact, line):
21
+    up_right = 0
22
+    up_error = 0
23
+
24
+    if fact[0] == 1:
25
+        up_right = up_right + 1.06
26
+    elif fact[1] == 1:
27
+        up_error = up_error + 0.3
28
+        up_right = up_right + 0.98
29
+    else:
30
+        up_error = up_error + 1
31
+        up_right = up_right + 0.94
32
+    return up_right,up_error
33
+
34
+
35
def predict(file_path='', model_path='15min_dnn_seq.h5', idx=-1, row=18, col=20):
    """Evaluate a saved Keras model on a labelled sample file and print
    weighted hit statistics for its "up" and "down" calls.

    `row`/`col` give the shape of the image-like feature block at the
    front of each flat sample; the model expects inputs in
    [dense_features, image] order.  `idx == -2` selects an alternative
    branch that is currently a no-op.

    Returns (win_dnn, avg_up_score, avg_down_score); win_dnn is always
    empty here — kept for interface compatibility with sibling scripts.
    """
    test_x, test_y, lines = read_data(file_path)

    # Split each flat sample into the conv input and the dense remainder.
    img_input = test_x[:, :row * col].reshape(test_x.shape[0], row, col, 1)
    dense_input = test_x[:, row * col:]

    model = load_model(model_path)
    print('MIX', model.evaluate([dense_input, img_input, ], test_y))

    up_num = up_error = up_right = 0
    down_num = down_error = down_right = 0
    win_dnn = []

    predictions = model.predict([dense_input, img_input, ])
    for i, r in enumerate(predictions):
        fact = test_y[i]

        if idx in [-2]:
            # alternative scoring branch — intentionally does nothing yet
            if r[0] > 0.5 or r[1] > 0.5:
                pass
        elif r[0] > 0.5:
            # model calls "up": weigh the call against the one-hot truth
            tmp_right, tmp_error = _score(fact, lines[i])
            up_right += tmp_right
            up_error += tmp_error
            up_num += 1
        elif r[2] > 0.5:
            # model calls "down": payoff depends on which class was true
            if fact[0] == 1:
                down_error += 1
                down_right += 1.04
            elif fact[1] == 1:
                down_error += 0.3
                down_right += 0.98
            else:
                down_right += 0.92
            down_num += 1

    # Avoid division by zero when a class was never predicted.
    up_num = up_num or 1
    down_num = down_num or 1
    print('MIX', up_right, up_num, up_right/up_num, up_error/up_num, down_right/down_num, down_error/down_num)
    return win_dnn, up_right/up_num, down_right/down_num
87
+
88
+
89
if __name__ == '__main__':
    # Earlier experiment runs kept for reference (different datasets/models):
    # predict(file_path='D:\\data\\quantization\\stock181_18d_test.log', model_path='181_18d_mix_6D_ma5_s_seq.h5')
    # predict(file_path='D:\\data\\quantization\\stock217_18d_train1.log', model_path='218_18d_mix_5D_ma5_s_seq.h5', row=18, col=18)
    # predict(file_path='D:\\data\\quantization\\stock400_18d_train1.log', model_path='400_18d_mix_5D_ma5_s_seq.h5', row=18, col=18)
    # Current run: model 603, 30-day window, 19 features per day.
    predict(file_path='D:\\data\\quantization\\stock603_30d_train1.log', model_path='603_30d_mix_5D_ma5_s_seq.h5', row=30, col=19)
    # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
    # multi_predict(model='15_18d')
    # predict_today(20200229, model='11_18d')

+ 5 - 3
mix/mix_predict_by_day_190.py

@@ -60,9 +60,9 @@ def predict(file_path='', model_path='15min_dnn_seq', rows=18, cols=18):
60 60
             if result[0][0]> 0.5:
61 61
                 up_num = up_num + ratio
62 62
             elif result[0][1] > 0.5:
63
-                up_num = up_num + 0.5*ratio
63
+                up_num = up_num + 0.01*ratio
64 64
             elif result[0][2] > 0.5:
65
-                down_num = down_num + 0.5*ratio
65
+                down_num = down_num + 0.01*ratio
66 66
             else:
67 67
                 down_num = down_num + ratio
68 68
 
@@ -86,5 +86,7 @@ if __name__ == '__main__':
86 86
     # predict(file_path='D:\\data\\quantization\\stock9_18_2.log', model_path='18d_dnn_seq.h5')
87 87
     # predict(file_path='D:\\data\\quantization\\stock16_18d_20200310.log', model_path='16_18d_mix_seq')
88 88
     # predict(file_path='D:\\data\\quantization\\stock196_18d_20200326.log', model_path='196_18d_mix_6D_ma5_s_seq')
89
-    predict(file_path='D:\\data\\quantization\\stock321_28d_5D_20200403.log', model_path='321_28d_mix_5D_ma5_s_seq', rows=28, cols=20)
89
+    predict(file_path='D:\\data\\quantization\\stock321_28d_5D_20200414.log', model_path='321_28d_mix_5D_ma5_s_seq_2', rows=28, cols=20)
90 90
     # predict(file_path='D:\\data\\quantization\\stock9_18_4.log', model_path='18d_dnn_seq.h5')
91
+    # predict(file_path='D:\\data\\quantization\\stock324_28d_3D_20200414_A.log', model_path='324_28d_mix_5D_ma5_s_seq', rows=28, cols=18)
92
+    # predict(file_path='D:\\data\\quantization\\stock324_28d_3D_20200414_A.log', model_path='603_30d_mix_5D_ma5_s_seq', rows=30, cols=19)

+ 86 - 0
mix/mix_predict_by_day_324.py

@@ -0,0 +1,86 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import joblib
5
+
6
+
7
def read_data(path):
    """Group samples from *path* by trading day.

    Each line is a Python-literal list whose last element's last entry is
    the date.  Returns {date_string: [samples...]} preserving file order
    within each day.

    NOTE(review): `eval` runs arbitrary code from the input file; only
    use on trusted local data (ast.literal_eval would be safer).
    """
    grouped = {}
    with open(path) as f:
        for raw in f:
            sample = eval(raw.strip())
            key = str(sample[-1][-1])
            grouped.setdefault(key, []).append(sample)
    return grouped
19
+
20
+
21
def predict(file_path='', model_path='15min_dnn_seq', rows=18, cols=18):
    """Replay a labelled sample file day by day through a saved model.

    For each trading day, prints how often each of the 4 output classes
    won the argmax, plus a down/up pressure ratio.  `rows`/`cols` give the
    shape of the image-like feature block at the front of each sample;
    the model expects inputs in [dense_features, image] order.
    """
    day_lines = read_data(file_path)
    print('数据读取完毕')

    model = load_model(model_path + '.h5')
    print('模型加载完毕')

    for day in sorted(day_lines.keys()):
        samples = day_lines[day]
        sample_len = len(samples[0])

        up_num = 0
        down_num = 0
        # Per-class argmax win counts for this day (ties count for each).
        counts = [0, 0, 0, 0]

        for sample in samples:
            features = np.array([sample[:sample_len - 1]])
            image = features[:, :rows * cols].reshape(features.shape[0], rows, cols, 1)
            dense = features[:, rows * cols:]

            scores = model.predict([dense, image])[0]

            # The last three dense flags boost this sample's weight
            # (presumably recency/priority markers — confirm upstream).
            weight = 1
            if dense[0][-1] == 1:
                weight = 2
            elif dense[0][-2] == 1:
                weight = 1.6
            elif dense[0][-3] == 1:
                weight = 1.3

            if scores[0] > 0.5:
                up_num += weight
            elif scores[1] > 0.5:
                up_num += 0.01 * weight
            elif scores[2] > 0.5:
                down_num += 0.01 * weight
            else:
                down_num += weight

            best = max(scores)
            for k in range(4):
                if scores[k] == best:
                    counts[k] += 1

        print(day, counts[0], counts[1], counts[2], counts[3], (down_num*1.2 + 2)/(up_num*1.2 + 2))
81
+
82
+
83
if __name__ == '__main__':
    # predict(file_path='D:\\data\\quantization\\stock324_28d_3D_20200415.log', model_path='324_28d_mix_5D_ma5_s_seq', rows=28, cols=18)

    # Replay the 2019-12-21 dataset through model 324 (28-day window, 18 features).
    predict(file_path='D:\\data\\quantization\\stock324_28d_3D_20191221.log', model_path='324_28d_mix_5D_ma5_s_seq', rows=28, cols=18)

+ 0 - 247
mix/mix_predict_everyday_400.py

@@ -1,247 +0,0 @@
1
-# -*- encoding:utf-8 -*-
2
-import numpy as np
3
-from keras.models import load_model
4
-import random
5
-
6
-
7
-holder_stock_list = [
8
-        # 医疗
9
-        '000150.SZ', '300300.SZ', '603990.SH', '300759.SZ', '300347.SZ', '300003.SZ', '300253.SZ','002421.SZ','300168.SZ','002432.SZ',
10
-        # 5G
11
-        '300698.SZ', '600498.SH', '300310.SZ', '600353.SH', '603912.SH', '603220.SH', '300602.SZ', '600260.SH', '002463.SZ','300738.SZ',
12
-        # 车联网
13
-        '002369.SZ', '002920.SZ', '300020.SZ', '002373.SZ', '002869.SZ','300098.SZ','300048.SZ','002401.SZ',
14
-        # 工业互联网
15
-        '002184.SZ', '002364.SZ','300310.SZ', '300670.SZ', '300166.SZ', '002169.SZ', '002380.SZ','002421.SZ',
16
-        # 特高压
17
-        '300341.SZ', '300670.SZ', '300018.SZ', '600268.SH', '002879.SZ','002028.SZ',
18
-        # 基础建设
19
-        '300041.SZ', '603568.SH', '000967.SZ', '603018.SH','002062.SZ',
20
-        # 华为
21
-        '300687.SZ','002316.SZ','300339.SZ','300378.SZ','300020.SZ','300634.SZ','002570.SZ',
22
-        '600801.SH', '300113.SZ','002555.SZ', '002174.SZ','600585.SH','600276.SH','002415.SZ','000651.SZ',
23
-        '300074.SZ'
24
-]
25
-
26
-ROE_stock_list =   [                      # ROE
27
-'002976.SZ', '002847.SZ', '002597.SZ', '300686.SZ', '000708.SZ', '603948.SH', '600507.SH', '300401.SZ', '002714.SZ', '600732.SH', '300033.SZ', '300822.SZ', '300821.SZ',
28
-'002458.SZ', '000708.SZ', '600732.SH', '603719.SH', '300821.SZ', '300800.SZ', '300816.SZ', '300812.SZ', '603195.SH', '300815.SZ', '603053.SH', '603551.SH', '002975.SZ',
29
-'603949.SH', '002970.SZ', '300809.SZ', '002968.SZ', '300559.SZ', '002512.SZ', '300783.SZ', '300003.SZ', '603489.SH', '300564.SZ', '600802.SH', '002600.SZ',
30
-'000933.SZ', '601918.SH', '000651.SZ', '002916.SZ', '000568.SZ', '000717.SZ', '600452.SH', '603589.SH', '600690.SH', '603886.SH', '300117.SZ', '000858.SZ', '002102.SZ',
31
-'300136.SZ', '600801.SH', '600436.SH', '300401.SZ', '002190.SZ', '300122.SZ', '002299.SZ', '603610.SH', '002963.SZ', '600486.SH', '300601.SZ', '300682.SZ', '300771.SZ',
32
-'000868.SZ', '002607.SZ', '603068.SH', '603508.SH', '603658.SH', '300571.SZ', '603868.SH', '600768.SH', '300760.SZ', '002901.SZ', '603638.SH', '601100.SH', '002032.SZ',
33
-'600083.SH', '600507.SH', '603288.SH', '002304.SZ', '000963.SZ', '300572.SZ', '000885.SZ', '600995.SH', '300080.SZ', '601888.SH', '000048.SZ', '000333.SZ', '300529.SZ',
34
-'000537.SZ', '002869.SZ', '600217.SH', '000526.SZ', '600887.SH', '002161.SZ', '600267.SH', '600668.SH', '600052.SH', '002379.SZ', '603369.SH', '601360.SH', '002833.SZ',
35
-'002035.SZ', '600031.SH', '600678.SH', '600398.SH', '600587.SH', '600763.SH', '002016.SZ', '603816.SH', '000031.SZ', '002555.SZ', '603983.SH', '002746.SZ', '603899.SH',
36
-'300595.SZ', '300632.SZ', '600809.SH', '002507.SZ', '300198.SZ', '600779.SH', '603568.SH', '300638.SZ', '002011.SZ', '603517.SH', '000661.SZ', '300630.SZ', '000895.SZ',
37
-'002841.SZ', '300602.SZ', '300418.SZ', '603737.SH', '002755.SZ', '002803.SZ', '002182.SZ', '600132.SH', '300725.SZ', '600346.SH', '300015.SZ', '300014.SZ', '300628.SZ',
38
-'000789.SZ', '600368.SH', '300776.SZ', '600570.SH', '000509.SZ', '600338.SH', '300770.SZ', '600309.SH', '000596.SZ', '300702.SZ', '002271.SZ', '300782.SZ', '300577.SZ',
39
-'603505.SH', '603160.SH', '300761.SZ', '603327.SH', '002458.SZ', '300146.SZ', '002463.SZ', '300417.SZ', '600566.SH', '002372.SZ', '600585.SH', '000848.SZ', '600519.SH',
40
-'000672.SZ', '300357.SZ', '002234.SZ', '603444.SH', '300236.SZ', '603360.SH', '002677.SZ', '300487.SZ', '600319.SH', '002415.SZ', '000403.SZ', '600340.SH', '601318.SH',
41
-
42
-]
43
-
44
-
45
-import pymongo
46
-from util.mongodb import get_mongo_table_instance
47
-code_table = get_mongo_table_instance('tushare_code')
48
-k_table = get_mongo_table_instance('stock_day_k')
49
-stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
50
-all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
51
-
52
-
53
-industry = ['家用电器', '元器件', 'IT设备', '汽车服务',
54
-            '汽车配件', '软件服务',
55
-            '互联网', '纺织',
56
-            '塑料', '半导体',]
57
-
58
-A_concept_code_list = [   'TS2', # 5G
59
-                        'TS24', # OLED
60
-                        'TS26', #健康中国
61
-                        'TS43',  #新能源整车
62
-                        'TS59', # 特斯拉
63
-                        'TS65', #汽车整车
64
-                        'TS142', # 物联网
65
-                        'TS153', # 无人驾驶
66
-                        'TS163', # 雄安板块-智慧城市
67
-                        'TS175', # 工业自动化
68
-                        'TS232', # 新能源汽车
69
-                        'TS254', # 人工智能
70
-                        'TS258', # 互联网医疗
71
-                        'TS264', # 工业互联网
72
-                        'TS266', # 半导体
73
-                        'TS269', # 智慧城市
74
-                        'TS271', # 3D玻璃
75
-                        'TS295', # 国产芯片
76
-                        'TS303', # 医疗信息化
77
-                        'TS323', # 充电桩
78
-                        'TS328', # 虹膜识别
79
-                        'TS361', # 病毒
80
-    ]
81
-
82
-
83
-gainian_map = {}
84
-hangye_map = {}
85
-
86
-Z_list = []  # 自选
87
-R_list = []  #  ROE
88
-O_list = []  # 其他
89
-
90
-def predict_today(file, day, model='10_18d', log=True):
91
-    lines = []
92
-    with open(file) as f:
93
-        for line in f.readlines()[:]:
94
-            line = eval(line.strip())
95
-            lines.append(line)
96
-
97
-    size = len(lines[0])
98
-
99
-    model=load_model(model)
100
-
101
-    for line in lines:
102
-        train_x = np.array([line[:size - 1]])
103
-        train_x_tmp = train_x[:,:30*19]
104
-        train_x_a = train_x_tmp.reshape(train_x.shape[0], 30, 19, 1)
105
-        # train_x_b = train_x_tmp.reshape(train_x.shape[0], 18, 24)
106
-        train_x_c = train_x[:,30*19:]
107
-
108
-        result = model.predict([train_x_c, train_x_a, ])
109
-        # print(result, line[-1])
110
-        stock = code_table.find_one({'ts_code':line[-1][0]})
111
-
112
-        if result[0][0] > 0.6:
113
-            if line[-1][0].startswith('688'):
114
-                continue
115
-            # 去掉ST
116
-            if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
117
-                continue
118
-
119
-            k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
120
-            # if k_table_list[0]['close'] > k_table_list[-1]['close']*1.20:
121
-            #     continue
122
-            # if k_table_list[0]['close'] < k_table_list[-1]['close']*0.90:
123
-            #     continue
124
-            # if k_table_list[-1]['close'] > 80:
125
-            #     continue
126
-
127
-            # 指定某几个行业
128
-            # if stock['industry'] in industry:
129
-            concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
130
-            concept_detail_list = []
131
-
132
-            # 处理行业
133
-            if stock['sw_industry'] in hangye_map:
134
-                i_c = hangye_map[stock['sw_industry']]
135
-                hangye_map[stock['sw_industry']] = i_c + 1
136
-            else:
137
-                hangye_map[stock['sw_industry']] = 1
138
-
139
-            if len(concept_code_list) > 0:
140
-                for concept in concept_code_list:
141
-                    for c in all_concept_code_list:
142
-                        if c['code'] == concept['concept_code']:
143
-                            concept_detail_list.append(c['name'])
144
-
145
-                            if c['name'] in gainian_map:
146
-                                g_c = gainian_map[c['name']]
147
-                                gainian_map[c['name']] = g_c + 1
148
-                            else:
149
-                                gainian_map[c['name']] = 1
150
-
151
-            if stock['ts_code'] in holder_stock_list:
152
-                print(line[-1], stock['name'], stock['sw_industry'], str(concept_detail_list), 'buy', k_table_list[0]['pct_chg'])
153
-                print(stock['ts_code'], stock['name'], '买入评级')
154
-                Z_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
155
-            elif stock['ts_code'] in ROE_stock_list:
156
-                print(line[-1], stock['name'], stock['sw_industry'], str(concept_detail_list), 'buy', k_table_list[0]['pct_chg'])
157
-                print(stock['ts_code'], stock['name'], '买入评级')
158
-                R_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
159
-            else:
160
-                O_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
161
-
162
-            if log is True:
163
-                with open('D:\\data\\quantization\\predict\\' + str(day) + '_mix.txt', mode='a', encoding="utf-8") as f:
164
-                    f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' buy' + '\n')
165
-
166
-        # elif result[0][1] > 0.5:
167
-        #     if stock['ts_code'] in holder_stock_list:
168
-        #         print(stock['ts_code'], stock['name'], '震荡评级')
169
-        # elif result[0][2] > 0.4:
170
-        #     if stock['ts_code'] in holder_stock_list:
171
-        #         print(stock['ts_code'], stock['name'], '赶紧卖出')
172
-        # else:
173
-        #     if stock['ts_code'] in holder_stock_list or stock['ts_code'] in ROE_stock_list:
174
-        #         print(stock['ts_code'], stock['name'], result[0],)
175
-
176
-    # print(gainian_map)
177
-    # print(hangye_map)
178
-
179
-    gainian_list = [(key, gainian_map[key])for key in gainian_map]
180
-    gainian_list = sorted(gainian_list, key=lambda x:x[1], reverse=True)
181
-
182
-    hangye_list = [(key, hangye_map[key])for key in hangye_map]
183
-    hangye_list = sorted(hangye_list, key=lambda x:x[1], reverse=True)
184
-
185
-    print(gainian_list)
186
-    print(hangye_list)
187
-
188
-    print('-----买入列表---------')
189
-    print(Z_list)
190
-    print(R_list)
191
-    print(O_list)
192
-
193
-    print('------随机结果--------')
194
-    random.shuffle(Z_list)
195
-    print('自选')
196
-    print(Z_list[:3])
197
-
198
-    random.shuffle(R_list)
199
-    print('ROE')
200
-    print(R_list[:3])
201
-
202
-    random.shuffle(O_list)
203
-    print('其他')
204
-    print(O_list[:3])
205
-
206
-
207
-def _read_pfile_map(path):
208
-    s_list = []
209
-    with open(path, encoding='utf-8') as f:
210
-        for line in f.readlines()[:]:
211
-            s_list.append(line)
212
-    return s_list
213
-
214
-
215
-def join_two_day(a, b):
216
-    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '.txt')
217
-    b_list = _read_pfile_map('D:\\data\\quantization\\predict\\dmi_' + str(b) + '.txt')
218
-    for a in a_list:
219
-        for b in b_list:
220
-            if a[2:11] == b[2:11]:
221
-                print(a)
222
-
223
-
224
-def check_everyday(day, today):
225
-    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(day) + '.txt')
226
-    x = 0
227
-    for a in a_list:
228
-        print(a[:-1])
229
-        k_day_list = list(k_table.find({'code':a[2:11], 'tradeDate':{'$lte':int(today)}}).sort('tradeDate', pymongo.DESCENDING).limit(5))
230
-        if k_day_list is not None and len(k_day_list) > 0:
231
-            k_day = k_day_list[0]
232
-            k_day_0 = k_day_list[-1]
233
-            k_day_last = k_day_list[1]
234
-            if ((k_day_last['close'] - k_day_0['pre_close'])/k_day_0['pre_close']) < 0.2:
235
-                print(k_day['open'], k_day['close'], 100*(k_day['close'] - k_day_last['close'])/k_day_last['close'])
236
-                x = x + 100*(k_day['close'] - k_day_last['close'])/k_day_last['close']
237
-
238
-    print(x/len(a_list))
239
-
240
-
241
-if __name__ == '__main__':
242
-    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
243
-    # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
244
-    # multi_predict()
245
-    predict_today("D:\\data\\quantization\\stock405_30d_20200403.log", 20200403, model='405_30d_mix_5D_ma5_s_seq.h5', log=True)
246
-    # join_two_day(20200305, 20200305)
247
-    # check_everyday(20200311, 20200312)

+ 38 - 89
mix/mix_predict_everyday_500.py

@@ -2,71 +2,28 @@
2 2
 import numpy as np
3 3
 from keras.models import load_model
4 4
 import random
5
-
6
-
7
-zixuan_stock_list = [
8
-                        # 医疗
9
-                        '000150.SZ', '300300.SZ', '603990.SH', '300759.SZ', '300347.SZ', '300003.SZ', '300253.SZ','002421.SZ','300168.SZ','002432.SZ',
10
-                        # 5G
11
-                        '300698.SZ', '600498.SH', '300310.SZ', '600353.SH', '603912.SH', '603220.SH', '300602.SZ', '600260.SH', '002463.SZ','300738.SZ',
12
-                        # 车联网
13
-                        '002369.SZ', '002920.SZ', '300020.SZ', '002373.SZ', '002869.SZ','300098.SZ','300048.SZ','002401.SZ',
14
-                        # 工业互联网
15
-                        '002184.SZ', '002364.SZ','300310.SZ', '300670.SZ', '300166.SZ', '002169.SZ', '002380.SZ','002421.SZ',
16
-                        # 特高压
17
-                        '300341.SZ', '300670.SZ', '300018.SZ', '600268.SH', '002879.SZ','002028.SZ',
18
-                        # 基础建设
19
-                        '300041.SZ', '603568.SH', '000967.SZ', '603018.SH','002062.SZ',
20
-                        # 华为
21
-                        '300687.SZ','002316.SZ','300339.SZ','300378.SZ','300020.SZ','300634.SZ','002570.SZ',
22
-                        '600801.SH', '300113.SZ','002555.SZ', '002174.SZ','600585.SH','600276.SH','002415.SZ','000651.SZ',
23
-                        '300074.SZ'
24
-]
25
-
26
-ROE_stock_list =   [                      # ROE
27
-'002976.SZ', '002847.SZ', '002597.SZ', '300686.SZ', '000708.SZ', '603948.SH', '600507.SH', '300401.SZ', '002714.SZ', '600732.SH', '300033.SZ', '300822.SZ', '300821.SZ',
28
-'002458.SZ', '000708.SZ', '600732.SH', '603719.SH', '300821.SZ', '300800.SZ', '300816.SZ', '300812.SZ', '603195.SH', '300815.SZ', '603053.SH', '603551.SH', '002975.SZ',
29
-'603949.SH', '002970.SZ', '300809.SZ', '002968.SZ', '300559.SZ', '002512.SZ', '300783.SZ', '300003.SZ', '603489.SH', '300564.SZ', '600802.SH', '002600.SZ',
30
-'000933.SZ', '601918.SH', '000651.SZ', '002916.SZ', '000568.SZ', '000717.SZ', '600452.SH', '603589.SH', '600690.SH', '603886.SH', '300117.SZ', '000858.SZ', '002102.SZ',
31
-'300136.SZ', '600801.SH', '600436.SH', '300401.SZ', '002190.SZ', '300122.SZ', '002299.SZ', '603610.SH', '002963.SZ', '600486.SH', '300601.SZ', '300682.SZ', '300771.SZ',
32
-'000868.SZ', '002607.SZ', '603068.SH', '603508.SH', '603658.SH', '300571.SZ', '603868.SH', '600768.SH', '300760.SZ', '002901.SZ', '603638.SH', '601100.SH', '002032.SZ',
33
-'600083.SH', '600507.SH', '603288.SH', '002304.SZ', '000963.SZ', '300572.SZ', '000885.SZ', '600995.SH', '300080.SZ', '601888.SH', '000048.SZ', '000333.SZ', '300529.SZ',
34
-'000537.SZ', '002869.SZ', '600217.SH', '000526.SZ', '600887.SH', '002161.SZ', '600267.SH', '600668.SH', '600052.SH', '002379.SZ', '603369.SH', '601360.SH', '002833.SZ',
35
-'002035.SZ', '600031.SH', '600678.SH', '600398.SH', '600587.SH', '600763.SH', '002016.SZ', '603816.SH', '000031.SZ', '002555.SZ', '603983.SH', '002746.SZ', '603899.SH',
36
-'300595.SZ', '300632.SZ', '600809.SH', '002507.SZ', '300198.SZ', '600779.SH', '603568.SH', '300638.SZ', '002011.SZ', '603517.SH', '000661.SZ', '300630.SZ', '000895.SZ',
37
-'002841.SZ', '300602.SZ', '300418.SZ', '603737.SH', '002755.SZ', '002803.SZ', '002182.SZ', '600132.SH', '300725.SZ', '600346.SH', '300015.SZ', '300014.SZ', '300628.SZ',
38
-'000789.SZ', '600368.SH', '300776.SZ', '600570.SH', '000509.SZ', '600338.SH', '300770.SZ', '600309.SH', '000596.SZ', '300702.SZ', '002271.SZ', '300782.SZ', '300577.SZ',
39
-'603505.SH', '603160.SH', '300761.SZ', '603327.SH', '002458.SZ', '300146.SZ', '002463.SZ', '300417.SZ', '600566.SH', '002372.SZ', '600585.SH', '000848.SZ', '600519.SH',
40
-'000672.SZ', '300357.SZ', '002234.SZ', '603444.SH', '300236.SZ', '603360.SH', '002677.SZ', '300487.SZ', '600319.SH', '002415.SZ', '000403.SZ', '600340.SH', '601318.SH',
41
-
42
-]
43
-
44
-holder_stock_list = [
45
-'600498.SH', '002223.SZ', '300136.SZ', '300559.SZ',
46
-    '600496.SH', '300682.SZ'
47
-]
48
-
5
+from mix.stock_source import *
49 6
 import pymongo
50 7
 from util.mongodb import get_mongo_table_instance
8
+
51 9
 code_table = get_mongo_table_instance('tushare_code')
52 10
 k_table = get_mongo_table_instance('stock_day_k')
53 11
 stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
54 12
 all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
55 13
 
56 14
 
57
-industry = ['家用电器', '元器件', 'IT设备', '汽车服务',
58
-            '汽车配件', '软件服务',
59
-            '互联网', '纺织',
60
-            '塑料', '半导体',]
61
-
62 15
 gainian_map = {}
63 16
 hangye_map = {}
64 17
 
18
+
65 19
 Z_list = []  # 自选
66 20
 R_list = []  #  ROE
67 21
 O_list = []  # 其他
68 22
 
23
+
69 24
 def predict_today(file, day, model='10_18d', log=True):
25
+    industry_list = get_hot_industry(day)
26
+
70 27
     lines = []
71 28
     with open(file) as f:
72 29
         for line in f.readlines()[:]:
@@ -88,12 +45,35 @@ def predict_today(file, day, model='10_18d', log=True):
88 45
         # print(result, line[-1])
89 46
         stock = code_table.find_one({'ts_code':line[-1][0]})
90 47
 
91
-        if result[0][0] > 0.6:
92
-            pass
93
-
94
-        elif result[0][1] > 0.5:
48
+        if result[0][0] > 0.5 and stock['sw_industry'] in industry_list:
49
+            if line[-1][0].startswith('688'):
50
+                continue
51
+            # 去掉ST
52
+            if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
53
+                continue
54
+            k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
55
+
56
+            # 指定某几个行业
57
+            # if stock['industry'] in industry:
58
+            concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
59
+            concept_detail_list = []
60
+
61
+            if len(concept_code_list) > 0:
62
+                for concept in concept_code_list:
63
+                    for c in all_concept_code_list:
64
+                        if c['code'] == concept['concept_code']:
65
+                            concept_detail_list.append(c['name'])
66
+            # if stock['ts_code'] in ROE_stock_list:
67
+            print(stock['ts_code'], stock['name'], '买入')
68
+            O_list.append([stock['ts_code'], stock['name']])
69
+
70
+            if log is True:
71
+                with open('D:\\data\\quantization\\predict\\' + str(day) + '_mix500.txt', mode='a', encoding="utf-8") as f:
72
+                    f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' ' + str(result[0][0]) + '\n')
73
+
74
+        elif result[0][1] > 0.5 or result[0][2] > 0.5 :
95 75
             pass
96
-        elif result[0][2] > 0.5 or result[0][3] > 0.5:
76
+        elif result[0][3] > 0.5:
97 77
             if stock['ts_code'] in holder_stock_list or stock['ts_code'] in zixuan_stock_list:
98 78
                 print(stock['ts_code'], stock['name'], '赶紧卖出')
99 79
         else:
@@ -101,46 +81,15 @@ def predict_today(file, day, model='10_18d', log=True):
101 81
 
102 82
     # print(gainian_map)
103 83
     # print(hangye_map)
104
-
105
-
106
-def _read_pfile_map(path):
107
-    s_list = []
108
-    with open(path, encoding='utf-8') as f:
109
-        for line in f.readlines()[:]:
110
-            s_list.append(line)
111
-    return s_list
112
-
113
-
114
-def join_two_day(a, b):
115
-    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '.txt')
116
-    b_list = _read_pfile_map('D:\\data\\quantization\\predict\\dmi_' + str(b) + '.txt')
117
-    for a in a_list:
118
-        for b in b_list:
119
-            if a[2:11] == b[2:11]:
120
-                print(a)
121
-
122
-
123
-def check_everyday(day, today):
124
-    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(day) + '.txt')
125
-    x = 0
126
-    for a in a_list:
127
-        print(a[:-1])
128
-        k_day_list = list(k_table.find({'code':a[2:11], 'tradeDate':{'$lte':int(today)}}).sort('tradeDate', pymongo.DESCENDING).limit(5))
129
-        if k_day_list is not None and len(k_day_list) > 0:
130
-            k_day = k_day_list[0]
131
-            k_day_0 = k_day_list[-1]
132
-            k_day_last = k_day_list[1]
133
-            if ((k_day_last['close'] - k_day_0['pre_close'])/k_day_0['pre_close']) < 0.2:
134
-                print(k_day['open'], k_day['close'], 100*(k_day['close'] - k_day_last['close'])/k_day_last['close'])
135
-                x = x + 100*(k_day['close'] - k_day_last['close'])/k_day_last['close']
136
-
137
-    print(x/len(a_list))
84
+    random.shuffle(O_list)
85
+    print(O_list[:3])
138 86
 
139 87
 
140 88
 if __name__ == '__main__':
141 89
     # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
142 90
     # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
143 91
     # multi_predict()
144
-    predict_today("D:\\data\\quantization\\stock500_28d_20200403.log", 20200403, model='500_28d_mix_5D_ma5_s_seq.h5', log=True)
92
+    # 策略B
93
+    predict_today("D:\\data\\quantization\\stock505_28d_20200415.log", 20200415, model='505_28d_mix_5D_ma5_s_seq.h5', log=True)
145 94
     # join_two_day(20200305, 20200305)
146 95
     # check_everyday(20200311, 20200312)

+ 166 - 0
mix/mix_predict_everyday_600.py

@@ -0,0 +1,166 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import random
5
+from mix.stock_source import *
6
+import pymongo
7
+from util.mongodb import get_mongo_table_instance
8
+
9
# MongoDB collections used by the daily prediction run.
code_table = get_mongo_table_instance('tushare_code')
k_table = get_mongo_table_instance('stock_day_k')
stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
# All concept (theme) definitions, materialized once at import time.
all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))


gainian_map = {}  # concept name -> hit count (currently unused)
hangye_map = {}   # industry name -> hit count (currently unused)

# Buckets filled by predict_today():
Z_list = []  # picks that are on the personal watchlist (自选)
R_list = []  # picks that are on the high-ROE list (ROE)
O_list = []  # all other picks (其他)
21
+
22
+
23
def predict_today(file, day, model='10_18d', log=True):
    """Run the mix model over one day's feature file and bucket the picks.

    Parameters
    ----------
    file : str
        Path to the day's feature log. Each line is a Python literal list
        (parsed with ``eval``) whose last element holds the stock id tuple.
    day : int
        Trade date as YYYYMMDD; used for the k-line lookup and the log name.
    model : str
        Path of the Keras model file to load.
    log : bool
        When True, append every pick to ``<day>_mix.txt``.

    Side effects: prints picks and fills module-level Z_list/R_list/O_list.
    """
    industry_list = get_hot_industry(day)

    lines = []
    with open(file) as f:
        # NOTE(review): ``eval`` on file content is unsafe unless the file is
        # fully trusted; kept because these logs are produced locally.
        for line in f.readlines()[:]:
            line = eval(line.strip())
            lines.append(line)

    size = len(lines[0])

    # Load once; bind to a new name instead of shadowing the ``model`` argument.
    net = load_model(model)

    for line in lines:
        train_x = np.array([line[:size - 1]])
        # First 30*19 values are the 30-day x 19-feature "image"; the rest
        # feeds the MLP branch.
        train_x_tmp = train_x[:, :30 * 19]
        train_x_a = train_x_tmp.reshape(train_x.shape[0], 30, 19, 1)
        train_x_c = train_x[:, 30 * 19:]

        result = net.predict([train_x_c, train_x_a, ])

        stock = code_table.find_one({'ts_code': line[-1][0]})
        if stock is None:
            # BUG FIX: unknown codes used to crash on stock['sw_industry'].
            continue

        if result[0][0] > 0.5 and stock['sw_industry'] in industry_list:
            # Skip STAR-market (688xxx) codes.
            if line[-1][0].startswith('688'):
                continue
            # Drop ST / new-listing / delist-risk names.
            if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
                continue

            k_table_list = list(k_table.find({'code': line[-1][0], 'tradeDate': {'$lte': day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
            if not k_table_list:
                # BUG FIX: guard missing k-line data before k_table_list[0].
                continue

            # Resolve concept (theme) names for the log line.
            concept_code_list = list(stock_concept_table.find({'ts_code': stock['ts_code']}))
            concept_detail_list = []
            for concept in concept_code_list:
                for c in all_concept_code_list:
                    if c['code'] == concept['concept_code']:
                        concept_detail_list.append(c['name'])

            if stock['ts_code'] in zixuan_stock_list:
                print(stock['ts_code'], stock['name'], '买入评级', k_table_list[0]['pct_chg'])
                Z_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
            elif stock['ts_code'] in ROE_stock_list:
                print(stock['ts_code'], stock['name'], '买入评级', k_table_list[0]['pct_chg'])
                R_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
            else:
                O_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])

            if log is True:
                with open('D:\\data\\quantization\\predict\\' + str(day) + '_mix.txt', mode='a', encoding="utf-8") as f:
                    f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' ' + str(result[0][0]) + '\n')

    print('-----买入列表---------')
    print(Z_list)
    print(R_list)
    print(O_list)

    print('------随机结果--------')

    random.shuffle(R_list)
    print('ROE')
    print(R_list[:3])

    # Z_list is merged in twice — presumably to double the watchlist's weight
    # in the random draw; TODO confirm the duplication is intentional.
    O_list.extend(Z_list)
    O_list.extend(Z_list)
    random.shuffle(O_list)
    print('其他')
    print(O_list[:3])
122
+
123
+
124
+def _read_pfile_map(path):
125
+    s_list = []
126
+    with open(path, encoding='utf-8') as f:
127
+        for line in f.readlines()[:]:
128
+            s_list.append(line)
129
+    return s_list
130
+
131
+
132
def join_two_day(a, b):
    """Print lines from day ``a``'s predictions that also appear in day ``b``'s
    DMI predictions, matched on the stock code embedded at chars 2..10.

    ``a`` and ``b`` are YYYYMMDD trade dates naming the prediction files.
    """
    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '.txt')
    b_list = _read_pfile_map('D:\\data\\quantization\\predict\\dmi_' + str(b) + '.txt')
    # BUG FIX: the original loop variables shadowed the ``a``/``b`` parameters;
    # behavior was unchanged but the shadowing invited mistakes.
    for line_a in a_list:
        for line_b in b_list:
            if line_a[2:11] == line_b[2:11]:
                print(line_a)
139
+
140
+
141
def check_everyday(day, today):
    """Replay day ``day``'s picks: print each stock's recent move up to
    ``today`` and the average percent change across all picks.

    Reads the prediction file written for ``day`` and, for each pick, pulls
    the latest 5 daily bars up to ``today`` from MongoDB.
    """
    a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(day) + '.txt')
    x = 0
    for a in a_list:
        print(a[:-1])
        k_day_list = list(k_table.find({'code':a[2:11], 'tradeDate':{'$lte':int(today)}}).sort('tradeDate', pymongo.DESCENDING).limit(5))
        # BUG FIX: the body indexes [0], [1] and [-1], so at least two rows
        # are required. The old `is not None` test was useless (list(...) is
        # never None) and `len > 0` still crashed on a single-row result.
        if len(k_day_list) >= 2:
            k_day = k_day_list[0]        # most recent bar
            k_day_0 = k_day_list[-1]     # oldest of the window
            k_day_last = k_day_list[1]   # previous bar
            # Skip stocks that already ran up >= 20% over the window.
            if ((k_day_last['close'] - k_day_0['pre_close'])/k_day_0['pre_close']) < 0.2:
                print(k_day['open'], k_day['close'], 100*(k_day['close'] - k_day_last['close'])/k_day_last['close'])
                x = x + 100*(k_day['close'] - k_day_last['close'])/k_day_last['close']

    # BUG FIX: avoid ZeroDivisionError when the prediction file is empty.
    if a_list:
        print(x/len(a_list))
156
+
157
+
158
if __name__ == '__main__':
    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
    # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
    # multi_predict()
    # predict_today("D:\\data\\quantization\\stock405_30d_20200413.log", 20200413, model='405_30d_mix_5D_ma5_s_seq.h5', log=True)
    # Model A (603 dataset: 30-day window, 19 features per day)
    predict_today("D:\\data\\quantization\\stock603_30d_20200415.log", 20200415, model='603_30d_mix_5D_ma5_s_seq.h5', log=True)
    # join_two_day(20200305, 20200305)
    # check_everyday(20200311, 20200312)

+ 1 - 1
mix/mix_train_180.py

@@ -18,7 +18,7 @@ from keras.callbacks import EarlyStopping
18 18
 
19 19
 early_stopping = EarlyStopping(monitor='accuracy', patience=5, verbose=2)
20 20
 
21
-epochs= 68
21
+epochs= 58
22 22
 size = 380000 #18W 60W
23 23
 file_path = 'D:\\data\\quantization\\stock186E_18d_train2.log'
24 24
 model_path = '186E_18d_mix_6D_ma5_s_seq.h5'

+ 19 - 9
mix/mix_train_300.py

@@ -18,13 +18,14 @@ from keras.callbacks import EarlyStopping
18 18
 
19 19
 early_stopping = EarlyStopping(monitor='accuracy', patience=5, verbose=2)
20 20
 
21
-epochs= 108
22
-size = 580000 #共68W
23
-file_path = 'D:\\data\\quantization\\stock321_28d_train2.log'
24
-model_path = '321_28d_mix_5D_ma5_s_seq_2.h5'
25
-file_path1='D:\\data\\quantization\\stock321_28d_test.log'
21
+epochs= 77
22
+size = 440000 #共68W
23
+file_path = 'D:\\data\\quantization\\stock324_28d_train2.log'
24
+model_path = '324_28d_mix_5D_ma5_s_seq.h5'
25
+file_path1='D:\\data\\quantization\\stock324_28d_test.log'
26
+file_path2='D:\\data\\quantization\\stock324_28d_train1.log'
26 27
 row = 28
27
-col = 20
28
+col = 18
28 29
 '''
29 30
 30d+ma5+流通市值>40
30 31
 0 ROC     30*18           38,100,17
@@ -36,8 +37,12 @@ col = 20
36 37
 11 DMI     28*20          37,101,16
37 38
 12 MACD    28*19           
38 39
 28d+ma5+5+流通市值>10
39
-21 DMI     28*20          43,102,9  非常好 
40
- 
40
+21 DMI     28*20          43,102,9  非常好    46,102,8
41
+22 MACD    28*19          46,102,9
42
+1d close
43
+23 DMI     28*20          34,97,36
44
+3d close 去掉ma的两个字段
45
+24 DMI     28*18          41,96,42-13
41 46
        
42 47
 30d+close
43 48
 4 ROC     30*18           
@@ -57,10 +62,15 @@ def read_data(path, path1=file_path1):
57 62
             lines.append(line)
58 63
 
59 64
     with open(path1) as f:
60
-        for x in range(60000): #6w
65
+        for x in range(50000): #6w
61 66
             line = eval(f.readline().strip())
62 67
             lines.append(line)
63 68
 
69
+    # with open(file_path2) as f:
70
+    #     for x in range(60000): #6w
71
+    #         line = eval(f.readline().strip())
72
+    #         lines.append(line)
73
+
64 74
     random.shuffle(lines)
65 75
     print('读取数据完毕')
66 76
 

+ 9 - 7
mix/mix_train_400.py

@@ -5,9 +5,7 @@ from keras.models import Sequential
5 5
 # 优化方法选用Adam(其实可选项有很多,如SGD)
6 6
 from keras.optimizers import Adam
7 7
 import random
8
-from keras.models import load_model
9 8
 from imblearn.over_sampling import RandomOverSampler
10
-from keras.utils import np_utils
11 9
 # 用于模型初始化,Conv2D模型初始化、Activation激活函数,MaxPooling2D是池化层
12 10
 # Flatten作用是将多位输入进行一维化
13 11
 # Dense是全连接层
@@ -20,11 +18,11 @@ early_stopping = EarlyStopping(monitor='accuracy', patience=5, verbose=2)
20 18
 
21 19
 epochs= 40
22 20
 size = 440000 #共68W
23
-file_path = 'D:\\data\\quantization\\stock403_30d_train2.log'
24
-model_path = '403_30d_mix_5D_ma5_s_seq_2.h5'
25
-file_path1='D:\\data\\quantization\\stock403_30d_test.log'
21
+file_path = 'D:\\data\\quantization\\stock417_30d_train2.log'
22
+model_path = '417_30d_mix_5D_ma5_s_seq.h5'
23
+file_path1='D:\\data\\quantization\\stock417_30d_test.log'
26 24
 row = 30
27
-col = 20
25
+col = 19
28 26
 '''
29 27
 0    roc 涨幅int表示  18*18             59,97,46                                       
30 28
 1    dmi              24*20             59,98,41
@@ -33,10 +31,14 @@ col = 20
33 31
 3B   dmi    9
34 32
 3A   dmi              30*20             53,97,44
35 33
 4    roc              30*18             63,98,40
36
-5    macd             30*19             64,98,39       !
34
+5    macd             30*19             64,98,39-24     !
37 35
 5_1   macd   9                          62,98,41
38 36
 5_2   macd   12                         58,98,43
39 37
 9    rsi              30*17             50,97,43
38
+15  macd+20ma占比     30*19             62,97,44
39
+16  macd+beta1        30*19             62,98,37-27
40
+17  macd+beta1+去掉阶段涨幅             61,98,41-28           
41
+18  macd+资金/ma向量     30*21
40 42
 
41 43
 6    dmi+大盘形态     30*20             52,98,39
42 44
 7    roc+大盘形态     30*18             59,98,40

+ 13 - 8
mix/mix_train_500.py

@@ -18,17 +18,22 @@ from keras.callbacks import EarlyStopping
18 18
 
19 19
 early_stopping = EarlyStopping(monitor='accuracy', patience=5, verbose=2)
20 20
 
21
-epochs= 42
22
-size = 420000 #共68W
23
-file_path = 'D:\\data\\quantization\\stock501_28d_train2.log'
24
-model_path = '501_28d_mix_5D_ma5_s_seq.h5'
25
-file_path1='D:\\data\\quantization\\stock501_28d_test.log'
21
+epochs= 44
22
+size = 440000 #共68W
23
+file_path = 'D:\\data\\quantization\\stock507_28d_train2.log'
24
+model_path = '507_28d_mix_5D_ma5_s_seq.h5'
25
+file_path1='D:\\data\\quantization\\stock507_28d_test.log'
26 26
 row = 28
27 27
 col = 19
28 28
 '''
29 29
 0    dmi            28*20      38,98,51/5     下跌预判非常准                                    
30
-1    macd           28*19      41,98,53/8       
31
-       
30
+1    macd           28*19      41,98,53/8  
31
+2    dmi-对大盘对比 28*20      35,99,46/17 
32
+3    5d-dmi-对大盘对比 28*20   42,99,39/10
33
+4    3d-dmi-对大盘对比 28*20   40,99,39/07   
34
+5    3d-beta1                  55,99,39/07    ==> 用这个
35
+6    3d-ma20                   40,99,41/07
36
+7    3d-macd   28*19           55,99,40/07
32 37
 '''
33 38
 
34 39
 def read_data(path, path1=file_path1):
@@ -39,7 +44,7 @@ def read_data(path, path1=file_path1):
39 44
             lines.append(line)
40 45
 
41 46
     with open(path1) as f:
42
-        for x in range(32000): #6w
47
+        for x in range(33000): #6w
43 48
             line = eval(f.readline().strip())
44 49
             lines.append(line)
45 50
 

+ 194 - 0
mix/mix_train_600.py

@@ -0,0 +1,194 @@
1
+import keras
2
+# -*- encoding:utf-8 -*-
3
+import numpy as np
4
+from keras.models import Sequential
5
+# 优化方法选用Adam(其实可选项有很多,如SGD)
6
+from keras.optimizers import Adam
7
+import random
8
+from imblearn.over_sampling import RandomOverSampler
9
+# 用于模型初始化,Conv2D模型初始化、Activation激活函数,MaxPooling2D是池化层
10
+# Flatten作用是将多位输入进行一维化
11
+# Dense是全连接层
12
+from keras.layers import Conv2D, Activation, MaxPool2D, Flatten, Dense,Dropout,Input,MaxPooling2D,BatchNormalization,concatenate
13
+from keras import regularizers
14
+from keras.models import Model
15
+from keras.callbacks import EarlyStopping
16
+
17
# Stop training when the training accuracy plateaus for 5 epochs.
early_stopping = EarlyStopping(monitor='accuracy', patience=5, verbose=2)

epochs= 44
size = 440000 # number of training lines to read (~680k available in total)
file_path = 'D:\\data\\quantization\\stock603_30d_train2.log'
model_path = '603_30d_mix_5D_ma5_s_seq.h5'
file_path1='D:\\data\\quantization\\stock603_30d_test.log'
row = 30 # days per sample window
col = 19 # features per day
'''
1   macd+beta1 盈利       30*19             64,99,31-31
2   macd+beta1 亏损       30*19
3   macd+减少指数参数     30*19             train1:69,99,32-31 |train3:62,100,23-39
'''
31
+
32
def read_data(path, path1=file_path1):
    """Load samples from ``path`` plus extra lines from ``path1``, shuffle,
    split 85/15 into train/test, and oversample the training classes.

    Each file line is a Python literal list: features..., code, label.
    Returns (train_x, train_y, test_x, test_y) as float32 numpy arrays.
    """
    lines = []
    with open(path) as f:
        for x in range(size): # read the first `size` lines (file has ~680000)
            # NOTE(review): eval() assumes the log files are trusted local output.
            line = eval(f.readline().strip())
            lines.append(line)

    with open(path1) as f:
        for x in range(30000): # borrow 30k lines from the test file as well
            line = eval(f.readline().strip())
            lines.append(line)

    random.shuffle(lines)
    print('读取数据完毕')

    # 85/15 split point after shuffling.
    d=int(0.85*len(lines))
    length = len(lines[0])

    # s[:length-2] drops the trailing (code, label) pair; s[-1] is the label.
    train_x=[s[:length - 2] for s in lines[0:d]]
    train_y=[s[-1] for s in lines[0:d]]
    test_x=[s[:length - 2] for s in lines[d:]]
    test_y=[s[-1] for s in lines[d:]]

    print('转换数据完毕')

    # Balance classes in the training split by random oversampling (imblearn).
    ros = RandomOverSampler(random_state=0)
    X_resampled, y_resampled = ros.fit_sample(np.array(train_x, dtype=np.float32), np.array(train_y, dtype=np.float32))

    print('数据重采样完毕')

    return X_resampled,y_resampled,np.array(test_x, dtype=np.float32),np.array(test_y, dtype=np.float32)
63
+
64
+
65
train_x,train_y,test_x,test_y=read_data(file_path)

# CNN branch input: the first row*col values, reshaped to a (row, col, 1) image.
train_x_a = train_x[:,:row*col]
train_x_a = train_x_a.reshape(train_x.shape[0], row, col, 1)
# train_x_b = train_x[:, 9*26:18*26]
# train_x_b = train_x_b.reshape(train_x.shape[0], 9, 26, 1)
# MLP branch input: whatever follows the image block.
train_x_c = train_x[:,row*col:]
72
+
73
+
74
def create_mlp(dim, regress=False):
    """Build the dense (MLP) branch of the mixed model.

    ``dim`` is the flat input width; when ``regress`` is True a single
    linear output node is appended.
    """
    net = Sequential()
    net.add(Dense(256, input_dim=dim, activation="relu"))
    net.add(Dropout(0.2))
    # Three more fully-connected layers, narrowing at the end.
    for width in (256, 256, 128):
        net.add(Dense(width, activation="relu"))

    # Optional regression head.
    if regress:
        net.add(Dense(1, activation="linear"))

    return net
89
+
90
+
91
def create_cnn(width, height, depth, size=48, kernel_size=(5, 6), regress=False, output=24):
    """Build the convolutional branch of the mixed model.

    The input is a single-channel (width, height) feature image
    (channels-last). ``size``/``kernel_size`` configure the conv layer,
    ``output`` sizes the two FC layers, and ``regress`` appends a linear
    output node. Returns a keras ``Model``.
    """
    inputs = Input(shape=(width, height, 1))

    # CONV => RELU => BN (batch-norm on the channel axis, channels-last).
    net = Conv2D(size, kernel_size, strides=2, padding="same")(inputs)
    net = Activation("relu")(net)
    net = BatchNormalization(axis=-1)(net)

    # Flatten, then FC => RELU => BN => DROPOUT.
    net = Flatten()(net)
    net = Dense(output)(net)
    net = Activation("relu")(net)
    net = BatchNormalization(axis=-1)(net)
    net = Dropout(0.2)(net)

    # Second FC layer so the branch width matches the MLP branch.
    net = Dense(output)(net)
    net = Activation("relu")(net)

    # Optional regression head.
    if regress:
        net = Dense(1, activation="linear")(net)

    return Model(inputs, net)
130
+
131
+
132
# Build the two branches: MLP on the flat tail features, CNN on the
# (row x col) window image. Commented variants record past experiments.
mlp = create_mlp(train_x_c.shape[1], regress=False)
# cnn_0 = create_cnn(18, 20, 1, kernel_size=(3, 3), size=90, regress=False, output=96)       # 31 97 46
cnn_0 = create_cnn(row, col, 1, kernel_size=(6, col), size=96, regress=False, output=96)         # 29 98 47
# cnn_0 = create_cnn(18, 20, 1, kernel_size=(9, 9), size=90, regress=False, output=96)         # 28 97 53
# cnn_0 = create_cnn(18, 20, 1, kernel_size=(3, 20), size=90, regress=False, output=96)
# cnn_1 = create_cnn(18, 20, 1, kernel_size=(18, 10), size=80, regress=False, output=96)
# cnn_1 = create_cnn(9, 26, 1, kernel_size=(2, 14), size=36, regress=False, output=64)

# Concatenate the outputs of both branches into one feature vector.
combinedInput = concatenate([mlp.output, cnn_0.output, ])

# Classification head: three wide dense layers (L1-regularized first layer)
# ending in a 3-way softmax over the label classes.
x = Dense(1024, activation="relu", kernel_regularizer=regularizers.l1(0.003))(combinedInput)
x = Dropout(0.2)(x)
x = Dense(1024, activation="relu")(x)
x = Dense(1024, activation="relu")(x)
# (room here for one more layer if needed)
x = Dense(3, activation="softmax")(x)

# Final two-input model: MLP features on one input, CNN image on the other.
model = Model(inputs=[mlp.input, cnn_0.input, ], outputs=x)


print("Starting training ")
# h = model.fit(train_x, train_y, batch_size=4096*2, epochs=500, shuffle=True)

# Adam with a slow decay; categorical cross-entropy for the 3-class softmax.
opt = Adam(lr=1e-3, decay=1e-3 / 200)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=['accuracy'])

# Train with early stopping on training accuracy (see early_stopping above).
print("[INFO] training model...")
model.fit(
    [train_x_c, train_x_a, ], train_y,
    # validation_data=([testAttrX, testImagesX], testY),
    # epochs=int(3*train_x_a.shape[0]/1300),
    epochs=epochs,
    batch_size=2048, shuffle=True,
    callbacks=[early_stopping]
)

model.save(model_path)

# Split the held-out test set the same way as the training inputs.
test_x_a = test_x[:,:row*col]
test_x_a = test_x_a.reshape(test_x.shape[0], row, col, 1)
# test_x_b = test_x[:, 9*26:9*26+9*26]
# test_x_b = test_x_b.reshape(test_x.shape[0], 9, 26, 1)
test_x_c = test_x[:,row*col:]

# Evaluate on the held-out split. (The message below is a leftover from the
# tutorial this was adapted from; it actually evaluates the stock classifier.)
print("[INFO] predicting house prices...")
score  = model.evaluate([test_x_c, test_x_a,], test_y)

print(score)
print('Test score:', score[0])
print('Test accuracy:', score[1])

+ 65 - 0
mix/stock_source.py

@@ -0,0 +1,65 @@
1
+from util.mysqlutil import Mysql
2
+import pymongo
3
+from util.mongodb import get_mongo_table_instance
4
# Process-wide MySQL handle shared by this module's helpers.
mysql_handler = Mysql()

# Personal watchlist, grouped by theme.
zixuan_stock_list = [
    # healthcare / medical (医疗)
    '603990.SH', '300759.SZ', '300347.SZ','002421.SZ','300168.SZ','002432.SZ','300074.SZ','300677.SZ',
    # 5G
    '300003.SZ', '600498.SH', '300310.SZ', '603912.SH', '603220.SH', '300602.SZ', '600260.SH', '002463.SZ','300738.SZ','002402.SZ',
    # connected vehicles (车联网)
    '002369.SZ', '002920.SZ', '300020.SZ', '002869.SZ','300098.SZ','300048.SZ','000851.SZ','300682.SZ',
    # industrial internet (工业互联网)
    '002184.SZ', '002364.SZ','300310.SZ', '300670.SZ', '300166.SZ', '002169.SZ', '002380.SZ','002421.SZ','603083.SH',
    # ultra-high voltage grid (特高压)
    '300341.SZ', '300670.SZ', '300018.SZ', '600268.SH', '002879.SZ','002028.SZ','300477.SZ',
    # infrastructure (基础建设)
    '603568.SH', '000967.SZ', '603018.SH','002062.SZ',
    # Huawei supply chain (华为)
    '300687.SZ','002316.SZ','300339.SZ','300378.SZ','300020.SZ','300634.SZ','002570.SZ', '300766.SZ',

    '002555.SZ','600585.SH','600276.SH','002415.SZ','000651.SZ',

]

# High-ROE screen results.
ROE_stock_list =   [                      # ROE
    '002976.SZ', '002847.SZ', '002597.SZ', '300686.SZ', '000708.SZ', '603948.SH', '600507.SH', '300401.SZ', '002714.SZ', '600732.SH', '300033.SZ', '300822.SZ', '300821.SZ',
    '002458.SZ', '000708.SZ', '600732.SH', '603719.SH', '300821.SZ', '300800.SZ', '300816.SZ', '300812.SZ', '603195.SH', '300815.SZ', '603053.SH', '603551.SH', '002975.SZ',
    '603949.SH', '002970.SZ', '300809.SZ', '002968.SZ', '300559.SZ', '002512.SZ', '300783.SZ', '300003.SZ', '603489.SH', '300564.SZ', '600802.SH', '002600.SZ',
    '000933.SZ', '601918.SH', '000651.SZ', '002916.SZ', '000568.SZ', '000717.SZ', '600452.SH', '603589.SH', '600690.SH', '603886.SH', '300117.SZ', '000858.SZ', '002102.SZ',
    '300136.SZ', '600801.SH', '600436.SH', '300401.SZ', '002190.SZ', '300122.SZ', '002299.SZ', '603610.SH', '002963.SZ', '600486.SH', '300601.SZ', '300682.SZ', '300771.SZ',
    '000868.SZ', '002607.SZ', '603068.SH', '603508.SH', '603658.SH', '300571.SZ', '603868.SH', '600768.SH', '300760.SZ', '002901.SZ', '603638.SH', '601100.SH', '002032.SZ',
    '600083.SH', '600507.SH', '603288.SH', '002304.SZ', '000963.SZ', '300572.SZ', '000885.SZ', '600995.SH', '300080.SZ', '601888.SH', '000048.SZ', '000333.SZ', '300529.SZ',
    '000537.SZ', '002869.SZ', '600217.SH', '000526.SZ', '600887.SH', '002161.SZ', '600267.SH', '600668.SH', '600052.SH', '002379.SZ', '603369.SH', '601360.SH', '002833.SZ',
    '002035.SZ', '600031.SH', '600678.SH', '600398.SH', '600587.SH', '600763.SH', '002016.SZ', '603816.SH', '000031.SZ', '002555.SZ', '603983.SH', '002746.SZ', '603899.SH',
    '300595.SZ', '300632.SZ', '600809.SH', '002507.SZ', '300198.SZ', '600779.SH', '603568.SH', '300638.SZ', '002011.SZ', '603517.SH', '000661.SZ', '300630.SZ', '000895.SZ',
    '002841.SZ', '300602.SZ', '300418.SZ', '603737.SH', '002755.SZ', '002803.SZ', '002182.SZ', '600132.SH', '300725.SZ', '600346.SH', '300015.SZ', '300014.SZ', '300628.SZ',
    '000789.SZ', '600368.SH', '300776.SZ', '600570.SH', '000509.SZ', '600338.SH', '300770.SZ', '600309.SH', '000596.SZ', '300702.SZ', '002271.SZ', '300782.SZ', '300577.SZ',
    '603505.SH', '603160.SH', '300761.SZ', '603327.SH', '002458.SZ', '300146.SZ', '002463.SZ', '300417.SZ', '600566.SH', '002372.SZ', '600585.SH', '000848.SZ', '600519.SH',
    '000672.SZ', '300357.SZ', '002234.SZ', '603444.SH', '300236.SZ', '603360.SH', '002677.SZ', '300487.SZ', '600319.SH', '002415.SZ', '000403.SZ', '600340.SH', '601318.SH',
]


# Currently-held positions.
# NOTE(review): '603638.SZ' below looks like a typo — 603xxx codes are
# Shanghai ('.SH'); confirm before relying on membership tests.
holder_stock_list = [
    '600498.SH', '002223.SZ',
    '600496.SH', '300682.SZ','601162.SH','002401.SZ','601111.SH',
    '000851.SZ','300639.SZ','603990.SH','603003.SH','603628.SH','601186.SH',
    '600196.SH', '300003.SZ','300748.SZ','603638.SZ',
    '601211.SH'
]
51
+
52
+
53
def get_hot_industry(trade_date):
    """Return the names of "hot" industries for one trade date.

    ``trade_date`` is a YYYYMMDD int or string; it is reformatted to
    YYYY-MM-DD and matched against the ``index_industry_day`` table with
    the fixed breadth filters ``num_zhangfu1>=2 AND num_zhangfu<65``.
    """
    day = str(trade_date)
    # trade_day_list = get_mongo_table_instance('tradeDayTableTuShare').find({'is_open':1, 'cal_date':{'$lte':int(trade_date.replace('-', ''))}}) \
    #     .sort('cal_date', direction=pymongo.DESCENDING).skip(3).limit(1)
    # trade_date = str(trade_day_list[0]['cal_date'])
    day_str = day[:4] + '-' + day[4:6] + '-' + day[6:]
    rows = mysql_handler.select_list("SELECT * FROM index_industry_day WHERE "
                                     " trade_date='%s' AND num_zhangfu1>=2 AND num_zhangfu<65 "
                                     % (day_str))
    return [row['name'] for row in rows]

+ 91 - 0
util/mysqlutil.py

@@ -0,0 +1,91 @@
1
+#!/usr/bin/python
2
+# -*- coding:utf-8 -*-
3
+
4
+import sys
5
+import os
6
+
7
+sys.path.append(os.path.abspath('..'))
8
+from util.config import config
9
+import pymysql
10
+from sqlalchemy import create_engine
11
+
12
+
13
class MysqlConfig(object):
    """MySQL connection settings read from the [mysql] section of util.config."""
    HOST = config.get('mysql', 'host')
    USER = config.get('mysql', 'user')
    PASSWORD = config.get('mysql', 'password')
    DB = config.get('mysql', 'db')
18
+
19
+
20
# Module-level SQLAlchemy engine built from the same config.
# NOTE(review): credentials are interpolated into the URL without escaping —
# a password containing '@' or ':' would break the DSN; confirm if relevant.
engine = create_engine('mysql+pymysql://{}:{}@{}/{}?charset=utf8'.format(
    MysqlConfig.USER,
    MysqlConfig.PASSWORD,
    MysqlConfig.HOST,
    MysqlConfig.DB,
))
26
+
27
+
28
class Mysql(object):
    """Thin pymysql helper that opens a fresh connection per operation.

    Connection parameters are taken from :class:`MysqlConfig`; no pooling
    is performed despite the name of ``__pool`` below.
    """

    def __init__(self):
        # Cache the connection parameters; the actual connection is opened
        # lazily per call in __GetConnect().
        self.host = MysqlConfig.HOST
        self.user = MysqlConfig.USER
        self.pwd = MysqlConfig.PASSWORD
        self.db = MysqlConfig.DB

    # Reserved for a future connection pool; currently unused.
    __pool = None

    # @staticmethod
    def __GetConnect(self):
        """Open and return a new database connection.

        Raises NameError when the database name is not configured or the
        connection could not be established.
        """
        if not self.db:
            # BUG FIX: the original `raise (NameError, "...")` raised a tuple,
            # which is itself a TypeError under Python 3.
            raise NameError("没有设置数据库信息")
        self.conn = pymysql.connect(host=self.host, user=self.user, password=self.pwd, database=self.db)
        db = self.conn
        if not db:
            raise NameError("连接数据库失败")
        return db

    def insert_batch(self, sql, data):
        """Run ``sql`` with executemany over ``data`` and commit.

        Rolls back and prints the error on failure; the connection and
        cursor are always closed.
        """
        db = self.__GetConnect()
        cur = db.cursor()
        try:
            cur.executemany(sql, data)
            db.commit()
        except Exception as e:
            db.rollback()
            print(e)
        finally:
            cur.close()
            db.close()

    def select_list(self, sql):
        """Run a SELECT and return all rows as a list of dicts.

        Returns an empty list when the query fails (the error is printed,
        preserving the original best-effort behavior).
        """
        db = self.__GetConnect()
        cur = db.cursor(cursor=pymysql.cursors.DictCursor)
        results = []
        try:
            cur.execute(sql)
            results = cur.fetchall()
            db.commit()
        except Exception as e:
            print(e)
        finally:
            cur.close()
            db.close()

        return results
84
+
85
+
86
# Manual smoke test: run a simple SELECT against the dev schema and print
# one column. Requires a reachable MySQL configured in util.config.
if __name__ == '__main__':
    sql = 'select name from `ai-callcenter-dev`.callcenter_ai_terminology limit 10'
    mysql = Mysql()
    result = mysql.select_list(sql)
    for item in result:
        print(item['name'])