Browse Source

都是用dnn

yufeng0528 4 years ago
parent
commit
78e63e274a
2 changed files with 10 additions and 7 deletions
  1. 2 1
      stock/compont_predict.py
  2. 8 6
      stock/dnn_predict.py

+ 2 - 1
stock/compont_predict.py

@@ -3,8 +3,9 @@ from stock import dnn_predict
3 3
 
4 4
 
5 5
 def and_predict():
6
-    cnn_result = cnn_predict.predict()
6
+    # cnn_result = cnn_predict.predict()
7 7
     dnn_result = dnn_predict.predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
8
+    cnn_result = dnn_predict.predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
8 9
     print('计算完成')
9 10
 
10 11
     with open('and_predict.txt', 'a') as f:

+ 8 - 6
stock/dnn_predict.py

@@ -6,20 +6,21 @@ from keras.layers import Dense,Dropout
6 6
 import random
7 7
 from keras.models import load_model
8 8
 
9
-lines = []
9
+
10 10
 def read_data(path):
11
+    lines = []
11 12
     with open(path) as f:
12
-        for line in f.readlines():
13
+        for line in f.readlines()[:]:
13 14
             lines.append(eval(line.strip()))
14 15
 
15 16
     size = len(lines[0])
16 17
     train_x=[s[:size - 2] for s in lines]
17 18
     train_y=[s[size-1] for s in lines]
18
-    return np.array(train_x),np.array(train_y)
19
+    return np.array(train_x),np.array(train_y),lines
19 20
 
20 21
 
21 22
 def predict(file_path='', model_path='15min_dnn_seq.h5'):
22
-    test_x,test_y=read_data(file_path)
23
+    test_x,test_y,lines=read_data(file_path)
23 24
 
24 25
     model=load_model(model_path)
25 26
     score = model.evaluate(test_x, test_y)
@@ -35,6 +36,7 @@ def predict(file_path='', model_path='15min_dnn_seq.h5'):
35 36
             fact = test_y[i]
36 37
             if r[0] > 0.5:
37 38
                 f.write(str([lines[i][-2], lines[i][-1]]) + "\n")
39
+                win_dnn.append([lines[i][-2], lines[i][-1]])
38 40
                 if fact[0] == 1:
39 41
                     up_right = up_right + 1
40 42
                 elif fact[1] == 1:
@@ -47,5 +49,5 @@ def predict(file_path='', model_path='15min_dnn_seq.h5'):
47 49
 
48 50
 
49 51
 if __name__ == '__main__':
50
-    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
51
-    predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
52
+    predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
53
+    # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')