|
@@ -0,0 +1,120 @@
|
|
1
|
+# -*- encoding:utf-8 -*-
|
|
2
|
+import numpy as np
|
|
3
|
+from keras.models import load_model
|
|
4
|
+import joblib
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+def read_data(path):
|
|
8
|
+ stock_lines = {}
|
|
9
|
+ with open(path) as f:
|
|
10
|
+ for line in f.readlines()[:]:
|
|
11
|
+ line = eval(line.strip())
|
|
12
|
+ stock = str(line[-2][0])
|
|
13
|
+
|
|
14
|
+ if stock in stock_lines:
|
|
15
|
+ stock_lines[stock].append(line)
|
|
16
|
+ else:
|
|
17
|
+ stock_lines[stock] = [line]
|
|
18
|
+ # print(len(day_lines['20191230']))
|
|
19
|
+ return stock_lines
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+import pymongo
|
|
23
|
+from util.mongodb import get_mongo_table_instance
|
|
24
|
+code_table = get_mongo_table_instance('tushare_code')
|
|
25
|
+k_table = get_mongo_table_instance('stock_day_k')
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+def predict(file_path='', model_path='15min_dnn_seq'):
|
|
29
|
+ stock_lines = read_data(file_path)
|
|
30
|
+ print('数据读取完毕')
|
|
31
|
+
|
|
32
|
+ models = []
|
|
33
|
+ # for x in range(0, 12):
|
|
34
|
+ models.append(load_model(model_path + '.h5'))
|
|
35
|
+ estimator = joblib.load('km_dmi_18.pkl')
|
|
36
|
+ print('模型加载完毕')
|
|
37
|
+
|
|
38
|
+ total_money = 0
|
|
39
|
+ total_num = 0
|
|
40
|
+ items = sorted(stock_lines.keys())
|
|
41
|
+ for key in items:
|
|
42
|
+ # print(day)
|
|
43
|
+ lines = stock_lines[key]
|
|
44
|
+ init_money = 10000
|
|
45
|
+ last_price = 1
|
|
46
|
+
|
|
47
|
+ if lines[0][-2][0].startswith('6'):
|
|
48
|
+ continue
|
|
49
|
+
|
|
50
|
+ buy = 0 # 0空 1买入 2卖出
|
|
51
|
+ chiyou_0 = 0
|
|
52
|
+ high_price = 0
|
|
53
|
+
|
|
54
|
+ x = 24 # 每条数据项数
|
|
55
|
+ k = 18 # 周期
|
|
56
|
+ for line in lines:
|
|
57
|
+ # v = line[1:x*k + 1]
|
|
58
|
+ # v = np.array(v)
|
|
59
|
+ # v = v.reshape(k, x)
|
|
60
|
+ # v = v[:,6:10]
|
|
61
|
+ # v = v.reshape(1, 4*k)
|
|
62
|
+ # print(v)
|
|
63
|
+
|
|
64
|
+ train_x = np.array([line[:-2]])
|
|
65
|
+ train_x = train_x.reshape(train_x.shape[0], 1,6,77)
|
|
66
|
+ result = models[0].predict(train_x)
|
|
67
|
+
|
|
68
|
+ stock_name = line[-2]
|
|
69
|
+ today_price = list(k_table.find({'code':line[-2][0], 'tradeDate':{'$gt':int(line[-2][1])}}).sort('tradeDate',pymongo.ASCENDING).limit(1))
|
|
70
|
+ today_price = today_price[0]
|
|
71
|
+
|
|
72
|
+ if result[0][0] > 0.6 or result[0][1] > 0.6: #and (r[0] not in [2,6,8,10]):
|
|
73
|
+ chiyou_0 = 0
|
|
74
|
+ if buy == 0:
|
|
75
|
+ last_price = today_price['open']
|
|
76
|
+ high_price = last_price
|
|
77
|
+ print('首次买入', stock_name, today_price['open'])
|
|
78
|
+ buy = 1
|
|
79
|
+ else:
|
|
80
|
+ init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
|
|
81
|
+ last_price = today_price['close']
|
|
82
|
+ print('买入+买入', stock_name, today_price['close'])
|
|
83
|
+ buy = 1
|
|
84
|
+ if last_price > high_price:
|
|
85
|
+ high_price = last_price
|
|
86
|
+ elif buy == 1:
|
|
87
|
+ chiyou_0 = chiyou_0 + 1
|
|
88
|
+ last_price = today_price['close']
|
|
89
|
+ if chiyou_0 > 2 and today_price['close'] < last_price:
|
|
90
|
+ print('卖出', stock_name, today_price['close'])
|
|
91
|
+ init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
|
|
92
|
+ buy = 0
|
|
93
|
+ chiyou_0 = 0
|
|
94
|
+
|
|
95
|
+ elif init_money < 9000:
|
|
96
|
+ print('止损卖出', stock_name, today_price['close'])
|
|
97
|
+ init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
|
|
98
|
+ buy = 0
|
|
99
|
+ chiyou_0 = 0
|
|
100
|
+
|
|
101
|
+ print(key, init_money)
|
|
102
|
+
|
|
103
|
+ with open('D:\\data\\quantization\\stock_16_18d' + '_' + 'profit.log', 'a') as f:
|
|
104
|
+ if init_money > 10000:
|
|
105
|
+ f.write(str(key) + ' ' + str(init_money) + '\n')
|
|
106
|
+ elif init_money < 10000:
|
|
107
|
+ f.write(str(key) + ' ' + str(init_money) + '\n')
|
|
108
|
+
|
|
109
|
+ if init_money != 10000:
|
|
110
|
+ total_money = total_money + init_money
|
|
111
|
+ total_num = total_num + 1
|
|
112
|
+
|
|
113
|
+ print(total_money, total_num, total_money/total_num/10000)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+if __name__ == '__main__':
|
|
117
|
+ # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
|
|
118
|
+ # predict(file_path='D:\\data\\quantization\\stock12_18d_test.log', model_path='12_18d_dnn_seq')
|
|
119
|
+ predict(file_path='D:\\data\\quantization\\stock16_18d_test.log', model_path='16_18d_cnn_seq')
|
|
120
|
+ # predict(file_path='D:\\data\\quantization\\stock12_18d_20190103_20190604.log', model_path='13_18d_dnn_seq')
|