Browse Source

300基本费了

yufeng 4 years ago
parent
commit
18706f72f9
2 changed files with 25 additions and 16 deletions
  1. 1 1
      mix/mix_predict_200.py
  2. 24 15
      mix/mix_train_300.py

+ 1 - 1
mix/mix_predict_200.py

@@ -91,7 +91,7 @@ def predict(file_path='', model_path='15min_dnn_seq.h5', idx=-1, row=18, col=20)
91 91
 if __name__ == '__main__':
92 92
     # predict(file_path='D:\\data\\quantization\\stock181_18d_test.log', model_path='181_18d_mix_6D_ma5_s_seq.h5')
93 93
     # predict(file_path='D:\\data\\quantization\\stock201_18d_train1.log', model_path='213_18d_mix_6D_ma5_s_seq.h5', row=18, col=20)
94
-    predict(file_path='D:\\data\\quantization\\stock301_18d_train1.log', model_path='301_18d_mix_6D_ma5_s_seq.h5', row=30, col=20)
94
+    predict(file_path='D:\\data\\quantization\\stock314_24d_train1.log', model_path='314_24d_mix_6D_ma5_s_seq.h5', row=24, col=18)
95 95
     # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
96 96
     # multi_predict(model='15_18d')
97 97
     # predict_today(20200229, model='11_18d')

+ 24 - 15
mix/mix_train_300.py

@@ -20,15 +20,24 @@ early_stopping = EarlyStopping(monitor='accuracy', patience=5, verbose=2)
20 20
 
21 21
 epochs= 88
22 22
 size = 400000 #18W 60W
23
-file_path = 'D:\\data\\quantization\\stock302_18d_train2.log'
24
-model_path = '302_18d_mix_6D_ma5_s_seq.h5'
25
-file_path1='D:\\data\\quantization\\stock302_18d_test.log'
26
-col = 20
23
+file_path = 'D:\\data\\quantization\\stock314_24d_train2.log'
24
+model_path = '314_24d_mix_6D_ma5_s_seq.h5'
25
+file_path1='D:\\data\\quantization\\stock314_24d_test.log'
26
+col = 18
27
+row = 24
27 28
 '''
28
-ROC     30*18           38,100,17
29
-DMI     30*20           39,101,13
30
-MACD    30*19           
31
-RSI     30*17
29
+30d+ma5
30
+0 ROC     30*18           38,100,17
31
+1 DMI     30*20           39,101,13
32
+2 MACD    30*19           
33
+3 RSI     30*17           
34
+30d+close
35
+4 ROC     30*18           
36
+5 DMI     30*20           
37
+6 MACD    30*19           32,96,44
38
+7 RSI     30*17           31,96,42
39
+24d+close
40
+14 ROC    24*18           31,95,52
32 41
 
33 42
 '''
34 43
 
@@ -67,11 +76,11 @@ def read_data(path, path1=file_path1):
67 76
 
68 77
 train_x,train_y,test_x,test_y=read_data(file_path)
69 78
 
70
-train_x_a = train_x[:,:30*col]
71
-train_x_a = train_x_a.reshape(train_x.shape[0], 30, col, 1)
79
+train_x_a = train_x[:,:row*col]
80
+train_x_a = train_x_a.reshape(train_x.shape[0], row, col, 1)
72 81
 # train_x_b = train_x[:, 9*26:18*26]
73 82
 # train_x_b = train_x_b.reshape(train_x.shape[0], 9, 26, 1)
74
-train_x_c = train_x[:,30*col:]
83
+train_x_c = train_x[:,row*col:]
75 84
 
76 85
 
77 86
 def create_mlp(dim, regress=False):
@@ -135,7 +144,7 @@ def create_cnn(width, height, depth, size=48, kernel_size=(5, 6), regress=False,
135 144
 # create the MLP and CNN models
136 145
 mlp = create_mlp(train_x_c.shape[1], regress=False)
137 146
 # cnn_0 = create_cnn(18, 20, 1, kernel_size=(3, 3), size=90, regress=False, output=96)       # 31 97 46
138
-cnn_0 = create_cnn(30, col, 1, kernel_size=(6, col), size=96, regress=False, output=96)         # 29 98 47
147
+cnn_0 = create_cnn(row, col, 1, kernel_size=(6, col), size=96, regress=False, output=96)         # 29 98 47
139 148
 # cnn_0 = create_cnn(18, 20, 1, kernel_size=(9, 9), size=90, regress=False, output=96)         # 28 97 53
140 149
 # cnn_0 = create_cnn(18, 20, 1, kernel_size=(3, 20), size=90, regress=False, output=96)
141 150
 # cnn_1 = create_cnn(18, 20, 1, kernel_size=(18, 10), size=80, regress=False, output=96)
@@ -182,11 +191,11 @@ model.fit(
182 191
 
183 192
 model.save(model_path)
184 193
 
185
-test_x_a = test_x[:,:30*col]
186
-test_x_a = test_x_a.reshape(test_x.shape[0], 30, col, 1)
194
+test_x_a = test_x[:,:row*col]
195
+test_x_a = test_x_a.reshape(test_x.shape[0], row, col, 1)
187 196
 # test_x_b = test_x[:, 9*26:9*26+9*26]
188 197
 # test_x_b = test_x_b.reshape(test_x.shape[0], 9, 26, 1)
189
-test_x_c = test_x[:,30*col:]
198
+test_x_c = test_x[:,row*col:]
190 199
 
191 200
 # make predictions on the testing data
192 201
 print("[INFO] predicting house prices...")