Browse Source

优化按股票计算收益

yufeng 4 years ago
parent
commit
1a45393e41
1 changed files with 50 additions and 23 deletions
  1. 50 23
      stock/dnn_predict_by_stock.py

+ 50 - 23
stock/dnn_predict_by_stock.py

@@ -10,6 +10,7 @@ def read_data(path):
10 10
         for line in f.readlines()[:]:
11 11
             line = eval(line.strip())
12 12
             stock = str(line[-2][0])
13
+
13 14
             if stock in stock_lines:
14 15
                 stock_lines[stock].append(line)
15 16
             else:
@@ -29,7 +30,7 @@ def predict(file_path='', model_path='15min_dnn_seq'):
29 30
     print('数据读取完毕')
30 31
 
31 32
     models = []
32
-    for x in range(0, 20):
33
+    for x in range(0, 12):
33 34
         models.append(load_model(model_path + '_' + str(x) + '.h5'))
34 35
     estimator = joblib.load('km_dmi_18.pkl')
35 36
     print('模型加载完毕')
@@ -48,6 +49,7 @@ def predict(file_path='', model_path='15min_dnn_seq'):
48 49
 
49 50
         buy = 0 # 0空 1买入 2卖出
50 51
         chiyou_0 = 0
52
+        high_price = 0
51 53
 
52 54
         x = 24 # 每条数据项数
53 55
         k = 18 # 周期
@@ -55,7 +57,7 @@ def predict(file_path='', model_path='15min_dnn_seq'):
55 57
             v = line[1:x*k + 1]
56 58
             v = np.array(v)
57 59
             v = v.reshape(k, x)
58
-            v = v[:,4:8]
60
+            v = v[:,6:10]
59 61
             v = v.reshape(1, 4*k)
60 62
             # print(v)
61 63
             r = estimator.predict(v)
@@ -63,49 +65,73 @@ def predict(file_path='', model_path='15min_dnn_seq'):
63 65
             train_x = np.array([line[:-2]])
64 66
             result = models[r[0]].predict(train_x)
65 67
 
68
+            stock_name = line[-2]
66 69
             today_price = list(k_table.find({'code':line[-2][0], 'tradeDate':{'$gt':int(line[-2][1])}}).sort('tradeDate',pymongo.ASCENDING).limit(1))
67 70
             today_price = today_price[0]
68 71
 
69
-            if result[0][1] > 0.5 or result[0][2] > 0.5:
72
+            if result[0][0] > 0.5 or result[0][1] > 0.5: #and (r[0] not in [2,6,8,10]):
70 73
                 chiyou_0 = 0
71 74
                 print(r[0])
72 75
                 if buy == 0:
73 76
                     last_price = today_price['open']
74
-                    print('首次买入', line[-2], today_price['open'])
75
-                    buy = 1
76
-                elif buy == 1:
77
-                    init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
78
-                    last_price = today_price['close']
79
-                    print('买入+买入', line[-2], today_price['open'])
77
+                    high_price = last_price
78
+                    print('首次买入', stock_name, today_price['open'])
80 79
                     buy = 1
81 80
                 else:
81
+                    init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
82 82
                     last_price = today_price['close']
83
-                    print('卖出后买入', line[-2], today_price['open'])
83
+                    print('买入+买入', stock_name, today_price['open'])
84 84
                     buy = 1
85
-            elif result[0][1] > 0.5 or result[0][2] > 0.5:
86
-                buy = 0
87
-                init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
88
-                print('卖出', line[-2], today_price['close'])
89
-            else:
85
+                    if last_price > high_price:
86
+                        high_price = last_price
87
+            elif result[0][3] > 0.5 or result[0][4] > 0.5:#and (r[0] not in [5,8]):
88
+                if buy == 1:
89
+                    if chiyou_0 > 2 or init_money < 9000:
90
+                        init_money = init_money * (today_price['open'] - last_price)/last_price + init_money
91
+                        print('卖出', stock_name, today_price['open'])
92
+                        buy = 0
93
+                        chiyou_0 = 0
94
+                    # elif init_money > 15000 and 100*(today_price['close'] - high_price)/high_price < -15:
95
+                    #     init_money = init_money * (today_price['open'] - last_price)/last_price + init_money
96
+                    #     print('最高点回撤卖出', stock_name, today_price['open'])
97
+                    #     buy = 0
98
+                    #     chiyou_0 = 0
99
+                    else:
100
+                        init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
101
+                        print('继续持有,不卖出', stock_name, today_price['close'])
102
+                        buy = 1
103
+                        chiyou_0 = chiyou_0 + 1
104
+
105
+                        if today_price['close'] > high_price:
106
+                            high_price = today_price['close']
107
+
108
+        else:
90 109
                 if buy == 1:
91 110
                     init_money = (init_money * (today_price['close'] - last_price)/last_price) + init_money
92 111
                     if init_money < 8500:
93
-                        print('止损卖出', line[-2], today_price['close'])
112
+                        print('止损卖出', stock_name, today_price['close'])
94 113
                         buy = 0
95 114
                     else:
96 115
                         chiyou_0 = chiyou_0 + 1
97
-                        if init_money < 10000 and chiyou_0 > 3 and today_price['close'] < last_price:
98
-                            print('连续持有次数太多-- 卖出', line[-2], today_price['close'])
116
+                        if init_money < 10500 and chiyou_0 > 1 and today_price['close'] < last_price:
117
+                            print('连续持有次数太多-- 卖出', stock_name, today_price['close'])
99 118
                             buy = 0
100 119
                             chiyou_0 = 0
101
-                        elif chiyou_0 > 5 and today_price['close'] < last_price:
102
-                            print('连续持有次数太多++ 卖出', line[-2], today_price['close'])
120
+                        elif chiyou_0 > 2 and today_price['close'] < last_price:
121
+                            print('连续持有次数太多++ 卖出', stock_name, today_price['close'])
103 122
                             buy = 0
104 123
                             chiyou_0 = 0
105 124
                         else:
106 125
                             buy = 1
107
-                            print('持有', line[-2], today_price['close'])
126
+                            print('持有', stock_name, today_price['close'])
127
+
128
+                            if today_price['close'] > high_price:
129
+                                high_price = today_price['close']
130
+
108 131
                     last_price = today_price['close']
132
+                else:
133
+                    # print('忽略')
134
+                    pass
109 135
 
110 136
             # 具有后验知识的存在,
111 137
             # if result[0][1] > 0.5 or result[0][2] > 0.5:
@@ -153,7 +179,7 @@ def predict(file_path='', model_path='15min_dnn_seq'):
153 179
 
154 180
         print(key, init_money)
155 181
 
156
-        with open('D:\\data\\quantization\\stock_12_18d' + '_' +  'profit.log', 'a') as f:
182
+        with open('D:\\data\\quantization\\stock_15_18d' + '_' +  'profit.log', 'a') as f:
157 183
             if init_money > 10000:
158 184
                 f.write(str(key) + ' ' + str(init_money) + '\n')
159 185
             elif init_money < 10000:
@@ -169,4 +195,5 @@ def predict(file_path='', model_path='15min_dnn_seq'):
169 195
 if __name__ == '__main__':
170 196
     # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
171 197
     # predict(file_path='D:\\data\\quantization\\stock12_18d_test.log', model_path='12_18d_dnn_seq')
172
-    predict(file_path='D:\\data\\quantization\\stock11_18d_test.log', model_path='12_18d_dnn_seq')
198
+    predict(file_path='D:\\data\\quantization\\stock15_18d_test.log', model_path='15_18d_dnn_seq')
199
+    # predict(file_path='D:\\data\\quantization\\stock12_18d_20190103_20190604.log', model_path='13_18d_dnn_seq')