# Source listing size: 3.7 KB

# -*- encoding:utf-8 -*-
import os

import joblib
import numpy as np
from keras.models import load_model
  5. model_path = '160_18d_lstm_5D_ma5_s_seq.h5'
  6. data_dir = 'D:\\data\\quantization\\'
  7. kmeans = 'roc'
  8. def read_data(path):
  9. lines = []
  10. with open(path) as f:
  11. for line in f.readlines()[:]:
  12. line = eval(line.strip())
  13. if line[-2][0].startswith('0') or line[-2][0].startswith('3'):
  14. lines.append(line)
  15. size = len(lines[0])
  16. train_x=[s[:size - 2] for s in lines]
  17. train_y=[s[size-1] for s in lines]
  18. return np.array(train_x),np.array(train_y),lines
  19. def _score(fact, line):
  20. with open('mix_predict_dmi_18d.txt', 'a') as f:
  21. f.write(str([line[-2], line[-1]]) + "\n")
  22. up_right = 0
  23. up_error = 0
  24. if fact[0] == 1:
  25. up_right = up_right + 1.12
  26. elif fact[1] == 1:
  27. up_right = up_right + 1.06
  28. elif fact[2] == 1:
  29. up_right = up_right + 1
  30. up_error = up_error + 0.5
  31. elif fact[3] == 1:
  32. up_right = up_right + 0.94
  33. up_error = up_error + 1
  34. else:
  35. up_error = up_error + 1
  36. up_right = up_right + 0.88
  37. return up_right,up_error
  38. def mul_predict(name="10_18d"):
  39. r = 0
  40. p = 0
  41. for x in range(0, 8):
  42. win_dnn, up_ratio,down_ratio = predict(data_dir + kmeans + '\\stock160_18d_train1_B_' + str(x) + ".log", x) # stock160_18d_trai_0
  43. r = r + up_ratio
  44. p = p + down_ratio
  45. print(r, p)
  46. def predict(file_path='', idx=-1):
  47. test_x,test_y,lines=read_data(file_path)
  48. print(idx, 'Load data success')
  49. test_x_a = test_x[:,:18*24]
  50. test_x_a = test_x_a.reshape(test_x.shape[0], 18, 24)
  51. # test_x_b = test_x[:, 18*16:18*16+10*18]
  52. # test_x_b = test_x_b.reshape(test_x.shape[0], 18, 10, 1)
  53. test_x_c = test_x[:,18*24:]
  54. model=load_model(model_path.split('.')[0] + '_' + str(idx) + '.h5')
  55. score = model.evaluate([test_x_c, test_x_a, ], test_y)
  56. print('LSTM', score)
  57. up_num = 0
  58. up_error = 0
  59. up_right = 0
  60. down_num = 0
  61. down_error = 0
  62. down_right = 0
  63. i = 0
  64. result=model.predict([test_x_c, test_x_a, ])
  65. win_dnn = []
  66. for r in result:
  67. fact = test_y[i]
  68. if idx in [-2]:
  69. if r[0] > 0.5 or r[1] > 0.5:
  70. pass
  71. else:
  72. if r[0] > 0.6 or r[1] > 0.6:
  73. tmp_right,tmp_error = _score(fact, lines[i])
  74. up_right = tmp_right + up_right
  75. up_error = tmp_error + up_error
  76. up_num = up_num + 1
  77. elif r[3] > 0.7 or r[4] > 0.7:
  78. if fact[0] == 1:
  79. down_error = down_error + 1
  80. down_right = down_right + 1.12
  81. elif fact[1] == 1:
  82. down_error = down_error + 1
  83. down_right = down_right + 1.06
  84. elif fact[2] == 1:
  85. down_error = down_error + 0.5
  86. down_right = down_right + 1
  87. elif fact[3] == 1:
  88. down_right = down_right + 0.94
  89. else:
  90. down_right = down_right + 0.88
  91. down_num = down_num + 1
  92. i = i + 1
  93. if up_num == 0:
  94. up_num = 1
  95. if down_num == 0:
  96. down_num = 1
  97. print('LSTM', up_right, up_num, up_right/up_num, up_error/up_num, down_right/down_num, down_error/down_num)
  98. return win_dnn,up_right/up_num,down_right/down_num
  99. if __name__ == '__main__':
  100. # predict(file_path='D:\\data\\quantization\\stock160_18d_10D_test.log', model_path='160_18d_lstm_5D_ma5_s_seq.h5')
  101. # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
  102. mul_predict(name='stock160_18d')
  103. # predict_today(20200229, model='11_18d')