瀏覽代碼

加强深度

yufeng0528 4 年之前
父節點
當前提交
f4945da332
共有 1 個文件被更改,包括 14 次插入5 次删除
  1. 14 5
      stock/dnn_train.py

+ 14 - 5
stock/dnn_train.py

@@ -11,10 +11,10 @@ def read_data(path):
11 11
     train_y=[]
12 12
     lines = []
13 13
     with open(path) as f:
14
-        for x in range(150000):
14
+        for x in range(100000):
15 15
             lines.append(eval(f.readline().strip()))
16 16
     random.shuffle(lines)
17
-    lines = lines[:20000]
17
+    lines = lines[:80000]
18 18
     d=int(0.95*len(lines))
19 19
 
20 20
     size = len(lines[0])
@@ -24,18 +24,27 @@ def read_data(path):
24 24
     test_y=[s[size-1] for s in lines[d:]]
25 25
     return np.array(train_x),np.array(train_y),np.array(test_x),np.array(test_y)
26 26
 
27
-train_x,train_y,test_x,test_y=read_data("D:\\data\\quantization\\stock5.log")
27
+train_x,train_y,test_x,test_y=read_data("D:\\data\\quantization\\stock6.log")
28 28
 
29 29
 model = Sequential()
30
-model.add(Dense(units=325, input_dim=163,  activation='relu'))
30
+model.add(Dense(units=425, input_dim=166,  activation='relu'))
31
+model.add(Dense(units=325, activation='relu'))
32
+model.add(Dense(units=325, activation='relu'))
31 33
 model.add(Dense(units=225, activation='relu'))
34
+model.add(Dense(units=225, activation='relu'))
35
+model.add(Dense(units=225, activation='relu'))
36
+model.add(Dense(units=225, activation='relu'))
37
+model.add(Dense(units=225, activation='relu'))
38
+model.add(Dense(units=225, activation='relu'))
39
+model.add(Dense(units=125, activation='relu'))
32 40
 model.add(Dense(units=125, activation='relu'))
41
+model.add(Dense(units=166, activation='relu'))
33 42
 # model.add(Dropout(0.2)(Dense(units=225, activation='relu')))
34 43
 model.add(Dense(units=8, activation='softmax'))
35 44
 model.compile(loss='categorical_crossentropy', optimizer="adam",metrics=['accuracy'])
36 45
 
37 46
 print("Starting training ")
38
-h=model.fit(train_x, train_y, batch_size=16, epochs=10, shuffle=True)
47
+h=model.fit(train_x, train_y, batch_size=32, epochs=8, shuffle=True)
39 48
 score = model.evaluate(test_x, test_y)
40 49
 print(score)
41 50
 print('Test score:', score[0])