Browse Source

个股收益率计算

yufeng 4 years ago
parent
commit
8bd39d29c2
1 changed files with 129 additions and 0 deletions
  1. 129 0
      stock/dnn_predict_by_stock.py

+ 129 - 0
stock/dnn_predict_by_stock.py

@@ -0,0 +1,129 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import joblib
5
+
6
+
7
+def read_data(path):
8
+    stock_lines = {}
9
+    with open(path) as f:
10
+        for line in f.readlines()[:]:
11
+            line = eval(line.strip())
12
+            stock = str(line[-2][0])
13
+            if stock in stock_lines:
14
+                stock_lines[stock].append(line)
15
+            else:
16
+                stock_lines[stock] = [line]
17
+    # print(len(day_lines['20191230']))
18
+    return stock_lines
19
+
20
+
21
+import pymongo
22
+from util.mongodb import get_mongo_table_instance
23
+code_table = get_mongo_table_instance('tushare_code')
24
+k_table = get_mongo_table_instance('stock_day_k')
25
+
26
+
27
+def predict(file_path='', model_path='15min_dnn_seq'):
28
+    stock_lines = read_data(file_path)
29
+    print('数据读取完毕')
30
+
31
+    models = []
32
+    for x in range(0, 12):
33
+        models.append(load_model(model_path + '_' + str(x) + '.h5'))
34
+    estimator = joblib.load('km_dmi_18.pkl')
35
+    print('模型加载完毕')
36
+
37
+    total_money = 0
38
+    total_num = 0
39
+    items = sorted(stock_lines.keys())
40
+    for key in items:
41
+        # print(day)
42
+        lines = stock_lines[key]
43
+        init_money = 10000
44
+        last_price = 1
45
+
46
+        if lines[0][-2][0].startswith('6'):
47
+            continue
48
+
49
+        buy = 0 # 0空 1买入 2卖出
50
+        chiyou_0 = 0
51
+
52
+        x = 24 # 每条数据项数
53
+        k = 18 # 周期
54
+        for line in lines:
55
+            v = line[1:x*k + 1]
56
+            v = np.array(v)
57
+            v = v.reshape(k, x)
58
+            v = v[:,4:8]
59
+            v = v.reshape(1, 4*k)
60
+            # print(v)
61
+            r = estimator.predict(v)
62
+
63
+            train_x = np.array([line[:-2]])
64
+            result = models[r[0]].predict(train_x)
65
+
66
+            today_price = list(k_table.find({'code':line[-2][0], 'tradeDate':{'$gt':int(line[-2][1])}}).sort('tradeDate',pymongo.ASCENDING).limit(1))
67
+            today_price = today_price[0]
68
+
69
+            if result[0][1] > 0.5 or result[0][2] > 0.5:
70
+                chiyou_0 = 0
71
+                if r[0] in [2,5,9,10,11]:
72
+                    if buy == 0:
73
+                        last_price = today_price['open']
74
+                        print('首次买入', line[-2], today_price['open'])
75
+                        buy = 1
76
+                    elif buy == 1:
77
+                        init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
78
+                        last_price = today_price['close']
79
+                        print('买入+买入', line[-2], today_price['open'])
80
+                        buy = 1
81
+                    else:
82
+                        last_price = today_price['close']
83
+                        print('卖出后买入', line[-2], today_price['open'])
84
+                        buy = 1
85
+                else:
86
+                    if buy == 1:
87
+                        init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
88
+                        last_price = today_price['close']
89
+                        print('买入+买入', line[-2], today_price['open'])
90
+                        buy = 1
91
+            elif result[0][1] > 0.5 or result[0][2] > 0.5:
92
+                # if r[0] in [0,1,3,4,6,7] and buy in [0,1]:
93
+                buy = 0
94
+                init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
95
+                print('卖出', line[-2], today_price['close'])
96
+            else:
97
+                if buy == 1:
98
+                    init_money = (init_money * (today_price['close'] - last_price)/last_price) + init_money
99
+                    if init_money < 9000:
100
+                        print('止损卖出', line[-2], today_price['close'])
101
+                        buy = 0
102
+                    else:
103
+                        chiyou_0 = chiyou_0 + 1
104
+                        if chiyou_0 > 5 and today_price['close'] < last_price:
105
+                            print('连续持有次数太多 卖出', line[-2], today_price['close'])
106
+                            buy = 0
107
+                        else:
108
+                            buy = 1
109
+                            print('持有', line[-2], today_price['close'])
110
+                    last_price = today_price['close']
111
+
112
+        print(key, init_money)
113
+
114
+        with open('D:\\data\\quantization\\stock_11_18d' + '_' +  'profit.log', 'a') as f:
115
+            if init_money > 10000:
116
+                f.write(str(key) + ' ' + str(init_money) + '\n')
117
+            elif init_money < 10000:
118
+                f.write(str(key) + ' ' + str(init_money) + '\n')
119
+
120
+        if init_money != 10000:
121
+            total_money = total_money + init_money
122
+            total_num = total_num + 1
123
+
124
+    print(total_money, total_num, total_money/total_num/10000)
125
+
126
+
127
+if __name__ == '__main__':
128
+    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
129
+    predict(file_path='D:\\data\\quantization\\stock11_18d_stock.log', model_path='11_18d_dnn_seq')