dnn_train_dmi.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
# -*- encoding:utf-8 -*-
import contextlib
import itertools
import random

import joblib
import numpy as np
from imblearn.over_sampling import RandomOverSampler
from keras import regularizers
from keras.layers import Dense, Dropout
from keras.models import Sequential
  9. def read_data(path):
  10. lines = []
  11. with open(path) as f:
  12. # for x in range(30000):
  13. # lines.append(eval(f.readline().strip()))
  14. for line in f.readlines()[:]:
  15. lines.append(eval(line.strip()))
  16. random.shuffle(lines)
  17. print('读取数据完毕')
  18. d=int(0.95*len(lines))
  19. size = len(lines[0])
  20. train_x=[s[:size - 2] for s in lines[0:d]]
  21. train_y=[s[size-1] for s in lines[0:d]]
  22. test_x=[s[:size - 2] for s in lines[d:]]
  23. test_y=[s[size-1] for s in lines[d:]]
  24. print('转换数据完毕')
  25. ros = RandomOverSampler(random_state=0)
  26. X_resampled, y_resampled = ros.fit_sample(np.array(train_x), np.array(train_y))
  27. print('数据重采样完毕')
  28. return X_resampled,y_resampled,np.array(test_x),np.array(test_y)
  29. def resample(path):
  30. lines = []
  31. with open(path) as f:
  32. i = 0
  33. for x in range(110000):
  34. # print(i)
  35. lines.append(eval(f.readline().strip()))
  36. i = i + 1
  37. estimator = joblib.load('km_dmi_18.pkl')
  38. file_list = []
  39. for x in range(0, 12):
  40. file_list.append(open('D:\\data\\quantization\\kmeans\\stock9_18_train_' + str(x) + '.log', 'a'))
  41. x = 21 # 每条数据项数
  42. k = 18 # 周期
  43. for line in lines:
  44. v = line[1:x*k + 1]
  45. v = np.array(v)
  46. v = v.reshape(k, x)
  47. v = v[:,4:8]
  48. v = v.reshape(1, 4*k)
  49. # print(v)
  50. r = estimator.predict(v)
  51. file_list[r[0]].write(str(line) + '\n')
  52. def mul_train():
  53. # for x in range(0, 12):
  54. for x in [11,0,1,3,8,9]:
  55. # for x in [2,4,7,10]:
  56. score = train(input_dim=384, result_class=5, file_path="D:\\data\\quantization\\kmeans\\stock9_18_train_" + str(x) + ".log",
  57. model_name='18d_dnn_seq_' + str(x) + '.h5')
  58. with open('D:\\data\\quantization\\kmeans\\stock9_18_dmi.log', 'a') as f:
  59. f.write(str(x) + ':' + str(score[1]) + '\n')
  60. def train(input_dim=400, result_class=3, file_path="D:\\data\\quantization\\stock6.log", model_name=''):
  61. train_x,train_y,test_x,test_y=read_data(file_path)
  62. model = Sequential()
  63. model.add(Dense(units=120+input_dim, input_dim=input_dim, activation='relu'))
  64. model.add(Dense(units=120+input_dim, activation='relu',kernel_regularizer=regularizers.l1(0.002)))
  65. model.add(Dropout(0.2))
  66. model.add(Dense(units=120+input_dim, activation='relu'))
  67. model.add(Dense(units=120+input_dim, activation='relu'))
  68. model.add(Dense(units=120+input_dim, activation='relu',kernel_regularizer=regularizers.l1(0.002)))
  69. model.add(Dropout(0.2))
  70. model.add(Dense(units=120 + input_dim, activation='relu'))
  71. model.add(Dropout(0.2))
  72. model.add(Dense(units=120+input_dim, activation='selu'))
  73. model.add(Dropout(0.2))
  74. model.add(Dense(units=120+input_dim, activation='selu'))
  75. model.add(Dense(units=512, activation='relu'))
  76. model.add(Dense(units=result_class, activation='softmax'))
  77. model.compile(loss='categorical_crossentropy', optimizer="adam",metrics=['accuracy'])
  78. print("Starting training ")
  79. model.fit(train_x, train_y, batch_size=4096, epochs=900 + 6*int(len(train_x)/600), shuffle=True)
  80. score = model.evaluate(test_x, test_y)
  81. print(score)
  82. print('Test score:', score[0])
  83. print('Test accuracy:', score[1])
  84. model.save(model_name)
  85. return score
  86. # model=None
  87. # model=load_model(model_name)
  88. # result=model.predict(test_x)
  89. # print(result)
  90. # print(test_y)
  91. if __name__ == '__main__':
  92. # resample('D:\\data\\quantization\\stock9_18_1.log')
  93. mul_train()