yufeng 4 years ago
parent
commit
282d3852c8
5 changed files with 303 additions and 47 deletions
  1. 158 11
      stock/dnn_predict_dmi.py
  2. 4 4
      stock/dnn_train.py
  3. 26 23
      stock/dnn_train_dmi.py
  4. 18 9
      stock/kmeans.py
  5. 97 0
      util/mongodb.py

+ 158 - 11
stock/dnn_predict_dmi.py

@@ -1,13 +1,16 @@
1 1
 # -*- encoding:utf-8 -*-
2 2
 import numpy as np
3 3
 from keras.models import load_model
4
+import joblib
4 5
 
5 6
 
6 7
 def read_data(path):
7 8
     lines = []
8 9
     with open(path) as f:
9 10
         for line in f.readlines()[:]:
10
-            lines.append(eval(line.strip()))
11
+            line = eval(line.strip())
12
+            if line[-2][0].startswith('0') or line[-2][0].startswith('3'):
13
+                lines.append(line)
11 14
 
12 15
     size = len(lines[0])
13 16
     train_x=[s[:size - 2] for s in lines]
@@ -23,40 +26,184 @@ def predict(file_path='', model_path='15min_dnn_seq.h5'):
23 26
     print('DNN', score)
24 27
 
25 28
     up_num = 0
29
+    up_error = 0
26 30
     up_right = 0
31
+    down_num = 0
32
+    down_error = 0
33
+    down_right = 0
27 34
     i = 0
28 35
     result=model.predict(test_x)
29 36
     win_dnn = []
30
-    with open('dnn_predict_dmi_14d.txt', 'a') as f:
37
+    with open('dnn_predict_dmi_18d.txt', 'a') as f:
31 38
         for r in result:
32 39
             fact = test_y[i]
33 40
             if r[0] > 0.5:
34 41
                 f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
35 42
                 win_dnn.append([lines[i][-2], lines[i][-1]])
36 43
                 if fact[0] == 1:
44
+                    up_right = up_right + 1.12
45
+                elif fact[1] == 1:
37 46
                     up_right = up_right + 1.06
47
+                elif fact[2] == 1:
48
+                    up_right = up_right + 1
49
+                elif fact[3] == 1:
50
+                    up_right = up_right + 0.94
51
+                else:
52
+                    up_error = up_error + 1
53
+                    up_right = up_right + 0.88
54
+                up_num = up_num + 1
55
+            elif r[1] > 0.5:
56
+                f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
57
+                win_dnn.append([lines[i][-2], lines[i][-1]])
58
+                if fact[0] == 1:
59
+                    up_right = up_right + 1.12
38 60
                 elif fact[1] == 1:
61
+                    up_right = up_right + 1.06
62
+                elif fact[2] == 1:
39 63
                     up_right = up_right + 1
64
+                elif fact[3] == 1:
65
+                    up_right = up_right + 0.94
40 66
                 else:
41
-                    up_right = up_right + 0.9
67
+                    up_error = up_error + 1
68
+                    up_right = up_right + 0.88
42 69
                 up_num = up_num + 1
43 70
 
71
+            if r[3] > 0.6:
72
+                f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
73
+                win_dnn.append([lines[i][-2], lines[i][-1]])
74
+                if fact[0] == 1:
75
+                    down_error = down_error + 1
76
+                    down_right = down_right + 1.12
77
+                elif fact[1] == 1:
78
+                    down_right = down_right + 1.06
79
+                elif fact[2] == 1:
80
+                    down_right = down_right + 1
81
+                elif fact[3] == 1:
82
+                    down_right = down_right + 0.94
83
+                else:
84
+                    down_right = down_right + 0.88
85
+                down_num = down_num + 1
86
+            elif r[4] > 0.6:
87
+                f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
88
+                win_dnn.append([lines[i][-2], lines[i][-1]])
89
+                if fact[0] == 1:
90
+                    down_error = down_error + 1
91
+                    down_right = down_right + 1.12
92
+                elif fact[1] == 1:
93
+                    down_right = down_right + 1.06
94
+                elif fact[2] == 1:
95
+                    down_right = down_right + 1
96
+                elif fact[3] == 1:
97
+                    down_right = down_right + 0.94
98
+                else:
99
+                    down_right = down_right + 0.88
100
+                down_num = down_num + 1
101
+
44 102
             i = i + 1
45 103
     if up_num == 0:
46 104
         up_num = 1
47
-    print('DNN', up_right, up_num, up_right/up_num)
48
-    return win_dnn,up_right/up_num
105
+    print('DNN', up_right, up_num, up_right/up_num, up_error/up_num, down_right/down_num, down_error/down_num)
106
+    return win_dnn,up_right/up_num,down_right/down_num
49 107
 
50 108
 
51 109
 def multi_predict():
52
-    r = 0
53
-    for x in [0, 2, 4, 5, 7]:
54
-        win_dnn, ratio = predict(file_path='D:\\data\\quantization\\kmeans\\stock7_14_' + str(x) + '_test.log', model_path='14d_dnn_seq_' + str(x) + '.h5')
55
-        r = r + ratio
56
-    print(r)
110
+    r = 0;
111
+    p = 0
112
+    # for x in range(0, 12): # 0,2,3,4,6,8,9,10,11
113
+    # for x in [5,6,11]:
114
+    for x in [2,4,7,10]: # 2表现最好 优秀的
115
+        print(x)
116
+    # for x in [0,2,5,6,7]: # 5表现最好
117
+        win_dnn, up_ratio,down_ratio = predict(file_path='D:\\data\\quantization\\kmeans\\stock9_18_test_' + str(x) + '.log', model_path='18d_dnn_seq_' + str(x) + '.h5')
118
+        r = r + up_ratio
119
+        p = p + down_ratio
120
+    print(r, p)
121
+
122
+import pymongo
123
+from util.mongodb import get_mongo_table_instance
124
+code_table = get_mongo_table_instance('tushare_code')
125
+k_table = get_mongo_table_instance('stock_day_k')
126
+
127
+industry = ['全国地产', '区域地产', '酒店餐饮',
128
+            '家用电器', '文教休闲', '元器件', 'IT设备', '汽车服务',
129
+            '汽车配件', '港口', '机场', '商贸代理', '软件服务', '证券',
130
+            '供气供热', '多元金融', '百货','食品', '水务',
131
+            '互联网', '纺织', '保险', '航空',  '超市连锁', '软饮料',
132
+            '塑料', '电器连锁', '半导体', '乳制品',]
133
+
134
+
135
+def predict_today(day):
136
+    lines = []
137
+    with open('D:\\data\\quantization\\stock9_18_' +  str(day) +'.log') as f:
138
+        for line in f.readlines()[:]:
139
+            line = eval(line.strip())
140
+            if line[-1][0].startswith('0') or line[-1][0].startswith('3'):
141
+                lines.append(line)
142
+
143
+    size = len(lines[0])
144
+    train_x=[s[:size - 1] for s in lines]
145
+    np.array(train_x)
146
+
147
+    estimator = joblib.load('km_dmi_18.pkl')
148
+
149
+    models = []
150
+    for x in range(0, 12):
151
+        models.append(load_model('18d_dnn_seq_' + str(x) + '.h5'))
152
+
153
+    x = 21 # 每条数据项数
154
+    k = 18 # 周期
155
+    for line in lines:
156
+        v = line[1:x*k + 1]
157
+        v = np.array(v)
158
+        v = v.reshape(k, x)
159
+        v = v[:,4:8]
160
+        v = v.reshape(1, 4*k)
161
+        # print(v)
162
+        r = estimator.predict(v)
163
+
164
+        if r[0] in [5,6,11]:
165
+            train_x = np.array([line[:size - 1]])
166
+
167
+            result = models[r[0]].predict(train_x)
168
+            if result[0][3] > 0.5 or result[0][4] > 0.5:
169
+                stock = code_table.find_one({'ts_code':line[-1][0]})
170
+                if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
171
+                    continue
172
+                if line[0] > 80:
173
+                    continue
174
+                if stock['industry'] in industry:
175
+                    pass
176
+                    # print(line[-1], stock['name'], stock['industry'], 'sell')
177
+
178
+        if r[0] in [2,4,7,10]:
179
+            train_x = np.array([line[:size - 1]])
180
+
181
+            result = models[r[0]].predict(train_x)
182
+            # print(result, line[-1])
183
+            if result[0][0] > 0.5 or result[0][1] > 0.5:
184
+                if line[-1][0].startswith('688'):
185
+                    continue
186
+                # 去掉ST
187
+                stock = code_table.find_one({'ts_code':line[-1][0]})
188
+                if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
189
+                    continue
190
+
191
+                # 跌的
192
+                k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':20200214}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
193
+                if k_table_list[0]['close'] > k_table_list[-1]['close']*1.20:
194
+                    continue
195
+                if k_table_list[0]['close'] < k_table_list[-1]['close']*0.90:
196
+                    continue
197
+                if k_table_list[-1]['close'] > 80:
198
+                    continue
199
+
200
+                # 指定某几个行业
201
+                # if stock['industry'] in industry:
202
+                print(line[-1], stock['name'], stock['industry'], 'buy')
57 203
 
58 204
 
59 205
 if __name__ == '__main__':
60 206
     # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
61 207
     # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
62
-    multi_predict()
208
+    multi_predict()
209
+    # predict_today(20200219)

+ 4 - 4
stock/dnn_train.py

@@ -43,7 +43,7 @@ def read_data(path):
43 43
 def resample(path):
44 44
     lines = []
45 45
     with open(path) as f:
46
-        for x in range(160000):
46
+        for x in range(330000):
47 47
             lines.append(eval(f.readline().strip()))
48 48
     estimator = joblib.load('km.pkl')
49 49
 
@@ -56,7 +56,7 @@ def resample(path):
56 56
         v = v.reshape(1, 40)
57 57
         # print(v)
58 58
         r = estimator.predict(v)
59
-        with open('D:\\data\\quantization\\kmeans\\stock2_10_' + str(r[0]) + '.log', 'a') as f:
59
+        with open('D:\\data\\quantization\\kmeans\\stock8_14_train_' + str(r[0]) + '.log', 'a') as f:
60 60
             f.write(str(line) + '\n')
61 61
 
62 62
 
@@ -106,5 +106,5 @@ def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stoc
106 106
 if __name__ == '__main__':
107 107
     # train(input_dim=176, result_class=5, file_path="D:\\data\\quantization\\stock6_5.log", model_name='5d_dnn_seq.h5')
108 108
     # train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stock6.log", model_name='15m_dnn_seq.h5')
109
-    # resample('D:\\data\\quantization\\stock6_5.log')
110
-    mul_train()
109
+    resample('D:\\data\\quantization\\stock8_14.log')
110
+    # mul_train()

+ 26 - 23
stock/dnn_train_dmi.py

@@ -1,14 +1,11 @@
1
-import keras
2 1
 # -*- encoding:utf-8 -*-
3 2
 import numpy as np
4 3
 from keras.models import Sequential
5 4
 from keras.layers import Dense,Dropout
6 5
 import random
7 6
 from keras import regularizers
8
-from keras.models import load_model
9 7
 from imblearn.over_sampling import RandomOverSampler
10 8
 import joblib
11
-import tensorflow
12 9
 
13 10
 
14 11
 def read_data(path):
@@ -43,12 +40,19 @@ def read_data(path):
43 40
 def resample(path):
44 41
     lines = []
45 42
     with open(path) as f:
46
-        for x in range(76000):
43
+        i = 0
44
+        for x in range(110000):
45
+            # print(i)
47 46
             lines.append(eval(f.readline().strip()))
48
-    estimator = joblib.load('km_dmi.pkl')
47
+            i = i + 1
48
+    estimator = joblib.load('km_dmi_18.pkl')
49
+
50
+    file_list = []
51
+    for x in range(0, 12):
52
+        file_list.append(open('D:\\data\\quantization\\kmeans\\stock9_18_train_' + str(x) + '.log', 'a'))
49 53
 
50 54
     x = 21 # 每条数据项数
51
-    k = 14 # 周期
55
+    k = 18 # 周期
52 56
     for line in lines:
53 57
         v = line[1:x*k + 1]
54 58
         v = np.array(v)
@@ -57,18 +61,18 @@ def resample(path):
57 61
         v = v.reshape(1, 4*k)
58 62
         # print(v)
59 63
         r = estimator.predict(v)
60
-        with open('D:\\data\\quantization\\kmeans\\stock7_14_' + str(r[0]) + '_test.log', 'a') as f:
61
-            f.write(str(line) + '\n')
64
+        file_list[r[0]].write(str(line) + '\n')
62 65
 
63 66
 
64 67
 def mul_train():
65
-    # for x in range(0, 8):
66
-    for x in [0, 2, 4, 5, 7]:
67
-        score = train(input_dim=300, result_class=3, file_path="D:\\data\\quantization\\kmeans\\stock7_14_" + str(x) + ".log",
68
-              model_name='14d_dnn_seq_' + str(x) + '.h5')
68
+    # for x in range(0, 12):
69
+    for x in [11,0,1,3,8,9]:
70
+    # for x in [2,4,7,10]:
71
+        score = train(input_dim=384, result_class=5, file_path="D:\\data\\quantization\\kmeans\\stock9_18_train_" + str(x) + ".log",
72
+              model_name='18d_dnn_seq_' + str(x) + '.h5')
69 73
 
70
-        with open('D:\\data\\quantization\\kmeans\\stock7_14_' + str(x) + '_dmi.log', 'a') as f:
71
-            f.write(str(score[1]) + '\n')
74
+        with open('D:\\data\\quantization\\kmeans\\stock9_18_dmi.log', 'a') as f:
75
+            f.write(str(x) + ':' + str(score[1]) + '\n')
72 76
 
73 77
 
74 78
 def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stock6.log", model_name=''):
@@ -76,29 +80,30 @@ def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stoc
76 80
 
77 81
     model = Sequential()
78 82
     model.add(Dense(units=120+input_dim, input_dim=input_dim,  activation='relu'))
79
-    model.add(Dense(units=120+input_dim, activation='relu',kernel_regularizer=regularizers.l1(0.001)))
83
+    model.add(Dense(units=120+input_dim, activation='relu',kernel_regularizer=regularizers.l1(0.002)))
80 84
     model.add(Dropout(0.2))
81 85
     model.add(Dense(units=120+input_dim, activation='relu'))
86
+    model.add(Dense(units=120+input_dim, activation='relu'))
87
+    model.add(Dense(units=120+input_dim, activation='relu',kernel_regularizer=regularizers.l1(0.002)))
82 88
     model.add(Dropout(0.2))
83 89
     model.add(Dense(units=120 + input_dim, activation='relu'))
84 90
     model.add(Dropout(0.2))
85 91
     model.add(Dense(units=120+input_dim, activation='selu'))
86
-    model.add(Dropout(0.1))
87
-    # model.add(Dense(units=60+input_dim, activation='selu'))
88
-    # model.add(Dropout(0.2))
92
+    model.add(Dropout(0.2))
93
+    model.add(Dense(units=120+input_dim, activation='selu'))
89 94
     model.add(Dense(units=512, activation='relu'))
90 95
 
91 96
     model.add(Dense(units=result_class, activation='softmax'))
92 97
     model.compile(loss='categorical_crossentropy', optimizer="adam",metrics=['accuracy'])
93 98
 
94 99
     print("Starting training ")
95
-    # model.fit(train_x, train_y, batch_size=1024, epochs=400 + 4*int(len(train_x)/1000), shuffle=True)
96
-    model.fit(train_x, train_y, batch_size=2048, epochs=500 + 6*int(len(train_x)/700), shuffle=True)
100
+    model.fit(train_x, train_y, batch_size=4096, epochs=900 + 6*int(len(train_x)/600), shuffle=True)
97 101
     score = model.evaluate(test_x, test_y)
98 102
     print(score)
99 103
     print('Test score:', score[0])
100 104
     print('Test accuracy:', score[1])
101 105
 
106
+
102 107
     model.save(model_name)
103 108
 
104 109
     return score
@@ -110,7 +115,5 @@ def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stoc
110 115
 
111 116
 
112 117
 if __name__ == '__main__':
113
-    # train(input_dim=176, result_class=5, file_path="D:\\data\\quantization\\stock6_5.log", model_name='5d_dnn_seq.h5')
114
-    # train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stock6.log", model_name='15m_dnn_seq.h5')
115
-    # resample('D:\\data\\quantization\\stock7_14_test.log')
118
+    # resample('D:\\data\\quantization\\stock9_18_1.log')
116 119
     mul_train()

+ 18 - 9
stock/kmeans.py

@@ -7,7 +7,7 @@ import joblib
7 7
 def read_data(path):
8 8
     lines = []
9 9
     with open(path) as f:
10
-        for x in range(100000):
10
+        for x in range(20000):
11 11
             line = eval(f.readline().strip())
12 12
             # if line[-1][0] == 1 or line[-1][1] == 1:
13 13
             lines.append(line)
@@ -15,27 +15,36 @@ def read_data(path):
15 15
     return lines
16 16
 
17 17
 
18
-length = 14  # 周期是多少
19
-j = 21
18
+length = 18  # 周期是多少
19
+j = 20
20 20
 def class_fic(file_path=''):
21 21
     lines = read_data(file_path)
22 22
     print('读取数据完毕')
23 23
     size = len(lines[0])
24
-    train_x = np.array([s[5:9] + s[j+5:j+9] + s[j*2+5:j*2+9] + s[j*3+5:j*3+9] + s[j*4+5:j*4+9] + s[j*5+5:j*5+9] + s[j*6+5:j*6+9]
25
-                        + s[j*7+5:j*7+9] + s[j*8+5:j*8+9] + s[j*9+5:j*9+9] + s[j*10+5:j*10+9] + s[j*11+5:j*11+9] + s[j*12+5:j*12+9] + s[j*13+5:j*13+9]for s in lines])
24
+    x_list = []
25
+    for s in lines:
26
+        tmp_list = []
27
+        for x in range(0, length):
28
+            tmp_list = tmp_list + s[x*j+5:x*j+9]
29
+        x_list.append(tmp_list)
30
+    train_x = np.array(x_list)
31
+    # train_x = np.array([s[5:9] + s[j+5:j+9] + s[j*2+5:j*2+9] + s[j*3+5:j*3+9] + s[j*4+5:j*4+9] + s[j*5+5:j*5+9] + s[j*6+5:j*6+9]
32
+    #                     + s[j*7+5:j*7+9] + s[j*8+5:j*8+9] + s[j*9+5:j*9+9] + s[j*10+5:j*10+9] + s[j*11+5:j*11+9] + s[j*12+5:j*12+9] + s[j*13+5:j*13+9]
33
+    #                     + s[j*14+5:j*14+9] + s[j*15+5:j*15+9] + s[j*9+5:j*9+9] + s[j*10+5:j*10+9] + s[j*11+5:j*11+9] + s[j*12+5:j*12+9] + s[j*13+5:j*13+9]
34
+    #                     for s in lines])
26 35
     # train_y = [s[size - 1] for s in lines]
27 36
     v_x = train_x.reshape(train_x.shape[0], 4*length)
28 37
     stock_list = [s[size - 2] for s in lines]
29 38
 
30
-    estimator = KMeans(n_clusters=8, random_state=19)
39
+    estimator = KMeans(n_clusters=12, random_state=129)
31 40
     estimator.fit(v_x)
32 41
     label_pred = estimator.labels_  # 获取聚类标签
33 42
     centroids = estimator.cluster_centers_
34
-    joblib.dump(estimator , 'km_dmi.pkl')
43
+    joblib.dump(estimator , 'km_dmi_18.pkl')
35 44
 
36 45
     print(estimator.predict(v_x[:10]))
37 46
 
38
-    estimator = joblib.load('km_dmi.pkl')
47
+    estimator = joblib.load('km_dmi_18.pkl')
39 48
     print(estimator.predict(v_x[10:20]))
40 49
     # annoy_sim(v_x)
41 50
     # print('save数据完毕')
@@ -107,4 +116,4 @@ def find_annoy(lines, stock_list):
107 116
 
108 117
 if __name__ == '__main__':
109 118
     # class_fic(file_path="D:\\data\\quantization\\stock2_10.log")
110
-    class_fic(file_path="D:\\data\\quantization\\stock7_5.log")
119
+    class_fic(file_path="D:\\data\\quantization\\stock9_18.log")

+ 97 - 0
util/mongodb.py

@@ -0,0 +1,97 @@
1
+#!/usr/bin/python
2
+# -*- coding:utf-8 -*-
3
+
4
+import pymongo
5
+import datetime
6
+import time
7
+import random
8
+import logging
9
+from util.config import config
10
+
11
+
12
+class MongoConfig(object):
13
+    HOST = config.get('mongodb', 'host')
14
+    USER = config.get('mongodb', 'user')
15
+    PORT = int(config.get('mongodb', 'port'))
16
+    PASSWORD = config.get('mongodb', 'password')
17
+    DB = config.get('mongodb', 'db')
18
+
19
+
20
+def _get_default_mongodb_instance():
21
+    mongo_client = pymongo.MongoClient(MongoConfig.HOST, MongoConfig.PORT)
22
+    db = mongo_client[MongoConfig.DB]
23
+    if MongoConfig.PASSWORD is not None and MongoConfig.PASSWORD != '':
24
+        db.authenticate(MongoConfig.USER, MongoConfig.PASSWORD)
25
+    return db
26
+
27
+
28
+def get_mongodb_tablenames():
29
+    db = _get_default_mongodb_instance()
30
+    return db.collection_names(include_system_collections=False)
31
+
32
+
33
+def get_mongo_table_instance(tablename):
34
+    db = _get_default_mongodb_instance()
35
+    return db[tablename]
36
+
37
+
38
+def del_mongodb_table(tablename):
39
+    try:
40
+        db = _get_default_mongodb_instance()
41
+        db.drop_collection(tablename)
42
+        return True
43
+    except:
44
+        return False
45
+
46
+
47
+def rename_mongodb_table(from_tablename, to_tablename):
48
+    try:
49
+        get_mongo_table_instance(from_tablename).rename(to_tablename)
50
+        return True
51
+    except:
52
+        return False
53
+
54
+
55
+def get_mongodb_table_indexes(tablename):
56
+    try:
57
+        ret = []
58
+
59
+        ftable = get_mongo_table_instance(tablename)
60
+        findex = ftable.list_indexes()
61
+        listindex = list(findex)
62
+        if listindex is None or len(listindex) == 0:
63
+            return ret
64
+
65
+        for i in listindex:
66
+            sonobj = i['key']
67
+            indexes = []
68
+
69
+            for k, v in sonobj.iteritems():
70
+                v1 = pymongo.ASCENDING if v > 0 else pymongo.DESCENDING
71
+                temp = (k, v1)
72
+                indexes.append(temp)
73
+
74
+            ret.append(indexes)
75
+
76
+        return ret
77
+    except Exception as e:
78
+        return None
79
+
80
+
81
+def set_mongodb_table_indexes(tablename, indexes):
82
+    try:
83
+        ftable = get_mongo_table_instance(tablename)
84
+
85
+        if indexes is None or len(indexes) == 0:
86
+            return True
87
+        for i in indexes:
88
+            ftable.ensure_index(i)
89
+
90
+        return True
91
+    except Exception as e:
92
+        return False
93
+
94
+
95
+def copy_mongodb_table_indexes(from_tablename, to_tablename):
96
+    indexes = get_mongodb_table_indexes(from_tablename)
97
+    set_mongodb_table_indexes(to_tablename, indexes)