yufeng0528 4 年 前
コミット
ca0d007ec9
共有1 個のファイルを変更した11 個の追加6 個の削除を含む
  1. 11 6
      stock/dnn_predict.py

+ 11 - 6
stock/dnn_predict.py

@@ -6,10 +6,10 @@ from keras.layers import Dense,Dropout
6 6
 import random
7 7
 from keras.models import load_model
8 8
 
9
+lines = []
9 10
 def read_data(path):
10
-    lines = []
11 11
     with open(path) as f:
12
-        for line in f.readlines()[:1000]:
12
+        for line in f.readlines():
13 13
             lines.append(eval(line.strip()))
14 14
 
15 15
     size = len(lines[0])
@@ -17,16 +17,21 @@ def read_data(path):
17 17
     train_y=[s[size-1] for s in lines]
18 18
     return np.array(train_x),np.array(train_y)
19 19
 
20
-test_x,test_y=read_data("D:\\data\\quantization\\stock_test.log")
20
+# test_x,test_y=read_data("D:\\data\\quantization\\stock_test.log")
21
+test_x,test_y=read_data("D:\\data\\quantization\\s\\stock_2020-01-07.log")
21 22
 
22 23
 path="model_seq.h5"
23 24
 model=load_model(path)
24
-score = model.evaluate(test_x, test_y)
25
-print(score)
25
+# score = model.evaluate(test_x, test_y)
26
+# print(score)
26 27
 
27 28
 result=model.predict(test_x)
28 29
 # print(result)
29 30
 i = 0
30
-for x in test_y:
31
+for x in result:
31 32
     # print(str(i) + ":" + str(x))
33
+    if x[0] > 0.8:
34
+        print(lines[i][-2], x, 1)
35
+    # elif x[0] + x[1] > 0.9:
36
+    #     print(lines[i][-2], x, 2)
32 37
     i = i + 1