cnn_predict.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import keras
  2. # -*- encoding:utf-8 -*-
  3. import numpy as np
  4. from keras.models import Sequential
  5. from keras.layers import Dense,Dropout
  6. import random
  7. from keras.models import load_model
  # Rows parsed by the most recent read_data() call. Kept at module level
  # because predict() reads the last two fields of each row from it.
  lines = []


  def read_data(path):
      """Load samples from a text file holding one Python literal per line.

      Each non-blank line is evaluated into a sequence. Every element except
      the last two forms the feature vector and the last element is the
      label; the second-to-last element is left out of both arrays but stays
      reachable through the module-level ``lines`` list.

      Args:
          path: Path of the file to read.

      Returns:
          A ``(features, labels)`` tuple of NumPy arrays with one entry per
          parsed line.
      """
      # Empty the list in place (no rebinding) so existing references to
      # ``lines`` stay valid and a repeated call does not return rows that
      # were read from an earlier file.
      del lines[:]
      with open(path) as f:
          for raw in f:
              text = raw.strip()
              if not text:
                  # eval("") raises SyntaxError; tolerate blank lines.
                  continue
              # NOTE(security): eval() executes whatever the file contains.
              # Only use this on trusted files; if every row is a plain
              # literal, ast.literal_eval would be a safer replacement.
              lines.append(eval(text))
      features = [row[:-2] for row in lines]
      labels = [row[-1] for row in lines]
      return np.array(features), np.array(labels)
  def predict(data_path="D:\\data\\quantization\\stock6_test.log",
              model_path="15min_cnn_seq.h5",
              threshold=0.5):
      """Evaluate a saved Keras model on a test file and collect its picks.

      Loads the test samples with ``read_data``, reshapes the features to
      ``(samples, 1, 80, 5)``, loads the model, prints its ``evaluate``
      score, and then scans the predictions. A sample is "picked" when the
      first output value exceeds ``threshold``.

      Args:
          data_path: Test file passed to ``read_data``.
          model_path: Saved model file passed to ``load_model``.
          threshold: Cut-off applied to the first output of each prediction.

      Returns:
          A list with one ``[row[-2], row[-1]]`` pair per picked sample,
          where ``row`` is the corresponding parsed line of the test file.
      """
      test_x, test_y = read_data(data_path)
      # The reshape fixes the per-sample layout at 1 x 80 x 5, so every
      # feature vector must contain exactly 400 values.
      test_x = test_x.reshape(test_x.shape[0], 1, 80, 5)

      model = load_model(model_path)
      score = model.evaluate(test_x, test_y)
      print('CNN', score)

      result = model.predict(test_x)

      up_num = 0    # number of picked samples
      up_right = 0  # weighted number of picked samples counted as correct
      win_dnn = []
      # ``lines`` is the module-level list filled by read_data(); its rows
      # are in the same order as test_y and the predictions.
      for r, fact, row in zip(result, test_y, lines):
          if r[0] > threshold:
              # NOTE(review): labels look one-hot with index 0 meaning "up";
              # a label with index 1 set earns 0.2 partial credit. Confirm
              # the intended meaning of that weight with the training code.
              if fact[0] == 1:
                  up_right = up_right + 1
              elif fact[1] == 1:
                  up_right = up_right + 0.2
              up_num = up_num + 1
              win_dnn.append([row[-2], row[-1]])

      # Accuracy of the "up" predictions. Report NaN instead of raising
      # ZeroDivisionError when no prediction crossed the threshold.
      up_ratio = up_right / up_num if up_num else float("nan")
      print('CNN', up_right, up_num, up_ratio)
      return win_dnn
  # Script entry point: run the evaluation only when this file is executed
  # directly, not when it is imported as a module.
  if __name__ == '__main__':
      predict()