mix_predict_everyday_600.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import random
  5. from mix.stock_source import *
  6. import pymongo
  7. from util.mongodb import get_mongo_table_instance
  8. code_table = get_mongo_table_instance('tushare_code')
  9. k_table = get_mongo_table_instance('stock_day_k')
  10. stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
  11. all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
  12. gainian_map = {}
  13. hangye_map = {}
  14. Z_list = [] # 自选
  15. R_list = [] # ROE
  16. O_list = [] # 其他
  17. def predict_today(file, day, model='10_18d', log=True):
  18. industry_list = get_hot_industry(day)
  19. lines = []
  20. with open(file) as f:
  21. for line in f.readlines()[:]:
  22. line = eval(line.strip())
  23. lines.append(line)
  24. size = len(lines[0])
  25. model=load_model(model)
  26. for line in lines:
  27. train_x = np.array([line[:size - 1]])
  28. train_x_tmp = train_x[:,:30*19]
  29. train_x_a = train_x_tmp.reshape(train_x.shape[0], 30, 19, 1)
  30. # train_x_b = train_x_tmp.reshape(train_x.shape[0], 18, 24)
  31. train_x_c = train_x[:,30*19:]
  32. result = model.predict([train_x_c, train_x_a, ])
  33. # print(result, line[-1])
  34. stock = code_table.find_one({'ts_code':line[-1][0]})
  35. if result[0][0] > 0.5 and stock['sw_industry'] in industry_list:
  36. if line[-1][0].startswith('688'):
  37. continue
  38. # 去掉ST
  39. if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
  40. continue
  41. k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
  42. # 指定某几个行业
  43. # if stock['industry'] in industry:
  44. concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
  45. concept_detail_list = []
  46. if len(concept_code_list) > 0:
  47. for concept in concept_code_list:
  48. for c in all_concept_code_list:
  49. if c['code'] == concept['concept_code']:
  50. concept_detail_list.append(c['name'])
  51. if stock['ts_code'] in zixuan_stock_list:
  52. # print(line[-1], stock['name'], stock['sw_industry'], str(concept_detail_list), 'buy', k_table_list[0]['pct_chg'])
  53. print(stock['ts_code'], stock['name'], '买入评级', k_table_list[0]['pct_chg'])
  54. Z_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
  55. elif stock['ts_code'] in ROE_stock_list:
  56. print(stock['ts_code'], stock['name'], '买入评级', k_table_list[0]['pct_chg'])
  57. R_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
  58. else:
  59. O_list.append([stock['name'], stock['sw_industry'], k_table_list[0]['pct_chg']])
  60. if log is True:
  61. with open('D:\\data\\quantization\\predict\\' + str(day) + '_mix.txt', mode='a', encoding="utf-8") as f:
  62. f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' ' + str(result[0][0]) + '\n')
  63. # elif result[0][1] > 0.5:
  64. # if stock['ts_code'] in holder_stock_list:
  65. # print(stock['ts_code'], stock['name'], '震荡评级')
  66. # elif result[0][2] > 0.4:
  67. # if stock['ts_code'] in holder_stock_list:
  68. # print(stock['ts_code'], stock['name'], '赶紧卖出')
  69. # else:
  70. # if stock['ts_code'] in holder_stock_list or stock['ts_code'] in ROE_stock_list:
  71. # print(stock['ts_code'], stock['name'], result[0],)
  72. # print(gainian_map)
  73. # print(hangye_map)
  74. # gainian_list = [(key, gainian_map[key])for key in gainian_map]
  75. # gainian_list = sorted(gainian_list, key=lambda x:x[1], reverse=True)
  76. #
  77. # hangye_list = [(key, hangye_map[key])for key in hangye_map]
  78. # hangye_list = sorted(hangye_list, key=lambda x:x[1], reverse=True)
  79. # print(gainian_list)
  80. # print(hangye_list)
  81. print('-----买入列表---------')
  82. print(Z_list)
  83. print(R_list)
  84. print(O_list)
  85. print('------随机结果--------')
  86. # random.shuffle(Z_list)
  87. # print('自选')
  88. # print(Z_list[:3])
  89. random.shuffle(R_list)
  90. print('ROE')
  91. print(R_list[:3])
  92. O_list.extend(Z_list)
  93. O_list.extend(Z_list)
  94. random.shuffle(O_list)
  95. print('其他')
  96. print(O_list[:3])
  97. def _read_pfile_map(path):
  98. s_list = []
  99. with open(path, encoding='utf-8') as f:
  100. for line in f.readlines()[:]:
  101. s_list.append(line)
  102. return s_list
  103. def join_two_day(a, b):
  104. a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '.txt')
  105. b_list = _read_pfile_map('D:\\data\\quantization\\predict\\dmi_' + str(b) + '.txt')
  106. for a in a_list:
  107. for b in b_list:
  108. if a[2:11] == b[2:11]:
  109. print(a)
  110. def check_everyday(day, today):
  111. a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(day) + '.txt')
  112. x = 0
  113. for a in a_list:
  114. print(a[:-1])
  115. k_day_list = list(k_table.find({'code':a[2:11], 'tradeDate':{'$lte':int(today)}}).sort('tradeDate', pymongo.DESCENDING).limit(5))
  116. if k_day_list is not None and len(k_day_list) > 0:
  117. k_day = k_day_list[0]
  118. k_day_0 = k_day_list[-1]
  119. k_day_last = k_day_list[1]
  120. if ((k_day_last['close'] - k_day_0['pre_close'])/k_day_0['pre_close']) < 0.2:
  121. print(k_day['open'], k_day['close'], 100*(k_day['close'] - k_day_last['close'])/k_day_last['close'])
  122. x = x + 100*(k_day['close'] - k_day_last['close'])/k_day_last['close']
  123. print(x/len(a_list))
  124. if __name__ == '__main__':
  125. # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
  126. # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
  127. # multi_predict()
  128. # predict_today("D:\\data\\quantization\\stock405_30d_20200413.log", 20200413, model='405_30d_mix_5D_ma5_s_seq.h5', log=True)
  129. # 模型A
  130. predict_today("D:\\data\\quantization\\stock603_30d_20200415.log", 20200415, model='603_30d_mix_5D_ma5_s_seq.h5', log=True)
  131. # join_two_day(20200305, 20200305)
  132. # check_everyday(20200311, 20200312)