Browse Source

参数调整下

yufeng0528 4 years ago
parent
commit
4296d9f09f
1 changed files with 3 additions and 3 deletions
  1. 3 3
      stock/dnn_train.py

+ 3 - 3
stock/dnn_train.py

@@ -12,7 +12,7 @@ from imblearn.over_sampling import RandomOverSampler
12 12
 def read_data(path):
13 13
     lines = []
14 14
     with open(path) as f:
15
-        for x in range(100000):
15
+        for x in range(60000):
16 16
             lines.append(eval(f.readline().strip()))
17 17
 
18 18
     random.shuffle(lines)
@@ -42,9 +42,9 @@ def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stoc
42 42
     model = Sequential()
43 43
     model.add(Dense(units=120+input_dim, input_dim=input_dim,  activation='relu'))
44 44
     # model.add(Dense(units=60+int(input_dim/2), activation='relu'))
45
-    model.add(Dense(units=60+input_dim, activation='relu',kernel_regularizer=regularizers.l2(0.01)))
45
+    model.add(Dense(units=120+input_dim, activation='relu',kernel_regularizer=regularizers.l2(0.001)))
46 46
     model.add(Dropout(0.2))
47
-    model.add(Dense(units=60+input_dim, activation='relu',kernel_regularizer=regularizers.l2(0.01)))
47
+    model.add(Dense(units=60+input_dim, activation='relu'))
48 48
     model.add(Dropout(0.2))
49 49
     model.add(Dense(units=60+input_dim, activation='selu'))
50 50
     # model.add(Dropout(0.2))