yufeng 4 years ago
parent
commit
f10d4d38a4
3 changed files with 136 additions and 28 deletions
  1. 120 0
      stock/cnn_predict_by_stock.py
  2. 4 14
      stock/cnn_predict_dmi.py
  3. 12 14
      stock/cnn_train_dmi.py

+ 120 - 0
stock/cnn_predict_by_stock.py

@@ -0,0 +1,120 @@
1
+# -*- encoding:utf-8 -*-
2
+import numpy as np
3
+from keras.models import load_model
4
+import joblib
5
+
6
+
7
+def read_data(path):
8
+    stock_lines = {}
9
+    with open(path) as f:
10
+        for line in f.readlines()[:]:
11
+            line = eval(line.strip())
12
+            stock = str(line[-2][0])
13
+
14
+            if stock in stock_lines:
15
+                stock_lines[stock].append(line)
16
+            else:
17
+                stock_lines[stock] = [line]
18
+    # print(len(day_lines['20191230']))
19
+    return stock_lines
20
+
21
+
22
+import pymongo
23
+from util.mongodb import get_mongo_table_instance
24
+code_table = get_mongo_table_instance('tushare_code')
25
+k_table = get_mongo_table_instance('stock_day_k')
26
+
27
+
28
+def predict(file_path='', model_path='15min_dnn_seq'):
29
+    stock_lines = read_data(file_path)
30
+    print('数据读取完毕')
31
+
32
+    models = []
33
+    # for x in range(0, 12):
34
+    models.append(load_model(model_path + '.h5'))
35
+    estimator = joblib.load('km_dmi_18.pkl')
36
+    print('模型加载完毕')
37
+
38
+    total_money = 0
39
+    total_num = 0
40
+    items = sorted(stock_lines.keys())
41
+    for key in items:
42
+        # print(day)
43
+        lines = stock_lines[key]
44
+        init_money = 10000
45
+        last_price = 1
46
+
47
+        if lines[0][-2][0].startswith('6'):
48
+            continue
49
+
50
+        buy = 0 # 0空 1买入 2卖出
51
+        chiyou_0 = 0
52
+        high_price = 0
53
+
54
+        x = 24 # 每条数据项数
55
+        k = 18 # 周期
56
+        for line in lines:
57
+            # v = line[1:x*k + 1]
58
+            # v = np.array(v)
59
+            # v = v.reshape(k, x)
60
+            # v = v[:,6:10]
61
+            # v = v.reshape(1, 4*k)
62
+            # print(v)
63
+
64
+            train_x = np.array([line[:-2]])
65
+            train_x = train_x.reshape(train_x.shape[0], 1,6,77)
66
+            result = models[0].predict(train_x)
67
+
68
+            stock_name = line[-2]
69
+            today_price = list(k_table.find({'code':line[-2][0], 'tradeDate':{'$gt':int(line[-2][1])}}).sort('tradeDate',pymongo.ASCENDING).limit(1))
70
+            today_price = today_price[0]
71
+
72
+            if result[0][0] > 0.6 or result[0][1] > 0.6: #and (r[0] not in [2,6,8,10]):
73
+                chiyou_0 = 0
74
+                if buy == 0:
75
+                    last_price = today_price['open']
76
+                    high_price = last_price
77
+                    print('首次买入', stock_name, today_price['open'])
78
+                    buy = 1
79
+                else:
80
+                    init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
81
+                    last_price = today_price['close']
82
+                    print('买入+买入', stock_name, today_price['close'])
83
+                    buy = 1
84
+                    if last_price > high_price:
85
+                        high_price = last_price
86
+            elif buy == 1:
87
+                chiyou_0 = chiyou_0 + 1
88
+                last_price = today_price['close']
89
+                if chiyou_0 > 2 and today_price['close'] < last_price:
90
+                    print('卖出', stock_name, today_price['close'])
91
+                    init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
92
+                    buy = 0
93
+                    chiyou_0 = 0
94
+
95
+                elif init_money < 9000:
96
+                    print('止损卖出', stock_name, today_price['close'])
97
+                    init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
98
+                    buy = 0
99
+                    chiyou_0 = 0
100
+
101
+        print(key, init_money)
102
+
103
+        with open('D:\\data\\quantization\\stock_16_18d' + '_' +  'profit.log', 'a') as f:
104
+            if init_money > 10000:
105
+                f.write(str(key) + ' ' + str(init_money) + '\n')
106
+            elif init_money < 10000:
107
+                f.write(str(key) + ' ' + str(init_money) + '\n')
108
+
109
+        if init_money != 10000:
110
+            total_money = total_money + init_money
111
+            total_num = total_num + 1
112
+
113
+    print(total_money, total_num, total_money/total_num/10000)
114
+
115
+
116
+if __name__ == '__main__':
117
+    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
118
+    # predict(file_path='D:\\data\\quantization\\stock12_18d_test.log', model_path='12_18d_dnn_seq')
119
+    predict(file_path='D:\\data\\quantization\\stock16_18d_test.log', model_path='16_18d_cnn_seq')
120
+    # predict(file_path='D:\\data\\quantization\\stock12_18d_20190103_20190604.log', model_path='13_18d_dnn_seq')

+ 4 - 14
stock/cnn_predict_dmi.py

@@ -33,6 +33,7 @@ def _score(fact, line):
33 33
         up_right = up_right + 1
34 34
     elif fact[3] == 1:
35 35
         up_right = up_right + 0.94
36
+        up_error = up_error + 1
36 37
     else:
37 38
         up_error = up_error + 1
38 39
         up_right = up_right + 0.88
@@ -62,29 +63,18 @@ def predict(file_path='', model_path='15min_dnn_seq.h5', idx=-1):
62 63
         if idx in [-2]:
63 64
             if r[0] > 0.5 or r[1] > 0.5:
64 65
                 pass
65
-                # if fact[0] == 1:
66
-                #     up_right = up_right + 1.12
67
-                # elif fact[1] == 1:
68
-                #     up_right = up_right + 1.06
69
-                # elif fact[2] == 1:
70
-                #     up_right = up_right + 1
71
-                # elif fact[3] == 1:
72
-                #     up_right = up_right + 0.94
73
-                # else:
74
-                #     up_error = up_error + 1
75
-                #     up_right = up_right + 0.88
76
-                # up_num = up_num + 1
77 66
         else:
78
-            if r[0] > 0.6 or r[1] > 0.6:
67
+            if r[0] > 0.7 or r[1] > 0.7:
79 68
                 tmp_right,tmp_error = _score(fact, lines[i])
80 69
                 up_right = tmp_right + up_right
81 70
                 up_error = tmp_error + up_error
82 71
                 up_num = up_num + 1
83
-            elif r[3] > 0.5 or r[4] > 0.5:
72
+            elif r[3] > 0.6 or r[4] > 0.6:
84 73
                 if fact[0] == 1:
85 74
                     down_error = down_error + 1
86 75
                     down_right = down_right + 1.12
87 76
                 elif fact[1] == 1:
77
+                    down_error = down_error + 1
88 78
                     down_right = down_right + 1.06
89 79
                 elif fact[2] == 1:
90 80
                     down_right = down_right + 1

+ 12 - 14
stock/cnn_train_dmi.py

@@ -18,7 +18,7 @@ from keras import regularizers
18 18
 def read_data(path):
19 19
     lines = []
20 20
     with open(path) as f:
21
-        for x in range(200000):
21
+        for x in range(20000):
22 22
             lines.append(eval(f.readline().strip()))
23 23
 
24 24
     random.shuffle(lines)
@@ -42,8 +42,8 @@ def read_data(path):
42 42
 
43 43
 
44 44
 train_x,train_y,test_x,test_y=read_data("D:\\data\\quantization\\stock16_18d_train.log")
45
-train_x = train_x.reshape(train_x.shape[0], 1,6,77)
46
-test_x = test_x.reshape(test_x.shape[0], 1,6, 77)
45
+train_x = train_x.reshape(train_x.shape[0], 1,77,6)
46
+test_x = test_x.reshape(test_x.shape[0], 1,77, 6)
47 47
 
48 48
 
49 49
 
@@ -51,22 +51,20 @@ model = Sequential()
51 51
 
52 52
 # 模型卷积层设计
53 53
 model.add(Conv2D(
54
-    nb_filter=32,  # 第一层设置32个滤波器
55
-    nb_row=10,
56
-    nb_col=6,  # 设置滤波器的大小为5*5
54
+    kernel_size=(5, 6), filters=64,
57 55
     padding='same',  # 选择滤波器的扫描方式,即是否考虑边缘
58
-    input_shape=(1,6,77),  # 设置输入的形状
56
+    input_shape=(1,77,6),  # 设置输入的形状
59 57
     # batch_input_shape=(64, 1, 28, 28),
60 58
 ))
61 59
 # 选择激活函数
62 60
 model.add(Activation('relu'))
63 61
 
64
-# 设置下采样(池化层)
65
-model.add(MaxPool2D(
66
-    pool_size=(4,1),  # 下采样格为2*2
67
-    strides=(2,2),  # 向右向下的步长
68
-    padding='same', # padding mode is 'same'
69
-))
62
+# # 设置下采样(池化层)
63
+# model.add(MaxPool2D(
64
+#     pool_size=(4,1),  # 下采样格为2*2
65
+#     strides=(2,2),  # 向右向下的步长
66
+#     padding='same', # padding mode is 'same'
67
+# ))
70 68
 
71 69
 # 使用Flatten函数,将输入数据扁平化(因为输入数据是一个多维的形式,需要将其扁平化)
72 70
 model.add(Flatten())  # 将多维的输入一维化
@@ -89,7 +87,7 @@ model.compile(optimizer=adam,
89 87
     metrics=['accuracy'])
90 88
 
91 89
 print("Starting training ")
92
-h=model.fit(train_x, train_y, batch_size=4096*2, epochs=150, shuffle=True)
90
+h=model.fit(train_x, train_y, batch_size=4096*2, epochs=50, shuffle=True)
93 91
 score = model.evaluate(test_x, test_y)
94 92
 print(score)
95 93
 print('Test score:', score[0])