industry_predict_everyday_100.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import random
  5. from mix.stock_source import *
  6. import pymongo
  7. from util.mongodb import get_mongo_table_instance
  8. code_table = get_mongo_table_instance('tushare_code')
  9. k_table = get_mongo_table_instance('stock_day_k')
  10. stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
  11. all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
  12. gainian_map = {}
  13. hangye_map = {}
  14. Z_list = [] # 自选
  15. R_list = [] # ROE
  16. O_list = [] # 其他
  17. def predict_today(file, day, model='10_18d', log=True):
  18. industry_list = get_hot_industry(day)
  19. lines = []
  20. with open(file) as f:
  21. for line in f.readlines()[:]:
  22. line = eval(line.strip())
  23. lines.append(line)
  24. size = len(lines[0])
  25. model=load_model(model)
  26. for line in lines:
  27. train_x = np.array([line[:size - 1]])
  28. train_x_tmp = train_x[:,:11*11]
  29. train_x_a = train_x_tmp.reshape(train_x.shape[0], 11, 11, 1)
  30. # train_x_b = train_x_tmp.reshape(train_x.shape[0], 18, 24)
  31. train_x_c = train_x[:,11*11:]
  32. result = model.predict([train_x_c, train_x_a, ])
  33. # print(result, line[-1])
  34. stock = code_table.find_one({'ts_code':line[-1][0]})
  35. with open('D:\\data\\quantization\\predict\\' + str(day) + '_industry100.txt', mode='a', encoding="utf-8") as f:
  36. if result[0][0] > 0.5:
  37. print(line[-1], '大涨')
  38. O_list.append(line[-1])
  39. f.write(str(line[-1]) + ',大涨\n')
  40. elif result[0][1] > 0.5:
  41. print(line[-1], '涨')
  42. O_list.append(line[-1])
  43. f.write(str(line[-1]) + ',涨\n')
  44. elif result[0][2] > 0.5:
  45. print(line[-1], '跌')
  46. f.write(str(line[-1]) + ',跌\n')
  47. elif result[0][3] > 0.5:
  48. print(line[-1], '大跌')
  49. f.write(str(line[-1]) + ',大跌\n')
  50. random.shuffle(O_list)
  51. print(O_list[:3])
  52. if __name__ == '__main__':
  53. predict_today("D:\\data\\quantization\\industry\\stock15_10d_3D_20200417.log", 20200417, model='111_10d_mix_3D_s_seq.h5', log=True)
  54. # join_two_day(20200305, 20200305)
  55. # check_everyday(20200311, 20200312)