mix_predict_everyday.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import joblib
  5. holder_stock_list = [
  6. # 医疗
  7. '000150.SZ', '300300.SZ', '603990.SH', '300759.SZ', '300347.SZ', '300003.SZ', '300253.SZ',
  8. # 5G
  9. '300698.SZ', '600498.SH', '300310.SZ', '600353.SH', '603912.SH', '603220.SH', '300602.SZ', '600260.SH',
  10. # 车联网
  11. '002369.SZ', '002920.SZ', '300020.SZ', '002373.SZ', '002869.SZ',
  12. # 工业互联网
  13. '002184.SZ', '002364.SZ','300310.SZ', '300670.SZ', '300166.SZ', '002169.SZ', '002380.SZ',
  14. # 特高压
  15. '300341.SZ', '300670.SZ', '300018.SZ', '600268.SH', '002879.SZ',
  16. # 基础建设
  17. '300041.SZ', '603568.SH', '000967.SZ', '603018.SH',
  18. # B
  19. '002555.SZ', '002174.SZ'
  20. ]
  21. def read_data(path):
  22. lines = []
  23. with open(path) as f:
  24. for line in f.readlines()[:]:
  25. line = eval(line.strip())
  26. if line[-2][0].startswith('0') or line[-2][0].startswith('3'):
  27. lines.append(line)
  28. size = len(lines[0])
  29. train_x=[s[:size - 2] for s in lines]
  30. train_y=[s[size-1] for s in lines]
  31. return np.array(train_x),np.array(train_y),lines
  32. import pymongo
  33. from util.mongodb import get_mongo_table_instance
  34. code_table = get_mongo_table_instance('tushare_code')
  35. k_table = get_mongo_table_instance('stock_day_k')
  36. stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
  37. all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
  38. industry = ['家用电器', '元器件', 'IT设备', '汽车服务',
  39. '汽车配件', '软件服务',
  40. '互联网', '纺织',
  41. '塑料', '半导体',]
  42. A_concept_code_list = [ 'TS2', # 5G
  43. 'TS24', # OLED
  44. 'TS26', #健康中国
  45. 'TS43', #新能源整车
  46. 'TS59', # 特斯拉
  47. 'TS65', #汽车整车
  48. 'TS142', # 物联网
  49. 'TS153', # 无人驾驶
  50. 'TS163', # 雄安板块-智慧城市
  51. 'TS175', # 工业自动化
  52. 'TS232', # 新能源汽车
  53. 'TS254', # 人工智能
  54. 'TS258', # 互联网医疗
  55. 'TS264', # 工业互联网
  56. 'TS266', # 半导体
  57. 'TS269', # 智慧城市
  58. 'TS271', # 3D玻璃
  59. 'TS295', # 国产芯片
  60. 'TS303', # 医疗信息化
  61. 'TS323', # 充电桩
  62. 'TS328', # 虹膜识别
  63. 'TS361', # 病毒
  64. ]
  65. gainian_map = {}
  66. hangye_map = {}
  67. def predict_today(file, day, model='10_18d', log=True):
  68. lines = []
  69. with open(file) as f:
  70. for line in f.readlines()[:]:
  71. line = eval(line.strip())
  72. # if line[-1][0].startswith('0') or line[-1][0].startswith('3'):
  73. lines.append(line)
  74. size = len(lines[0])
  75. model=load_model(model)
  76. for line in lines:
  77. train_x = np.array([line[:size - 1]])
  78. train_x_tmp = train_x[:,:18*18]
  79. train_x_a = train_x_tmp.reshape(train_x.shape[0], 18, 18, 1)
  80. # train_x_b = train_x_tmp.reshape(train_x.shape[0], 18, 24)
  81. train_x_c = train_x[:,18*18:]
  82. result = model.predict([train_x_c, train_x_a, ])
  83. # print(result, line[-1])
  84. stock = code_table.find_one({'ts_code':line[-1][0]})
  85. if result[0][0] > 0.5:
  86. if line[-1][0].startswith('688'):
  87. continue
  88. # 去掉ST
  89. if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
  90. continue
  91. if stock['ts_code'] in holder_stock_list:
  92. print(stock['ts_code'], stock['name'], '维持买入评级')
  93. # 跌的
  94. k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
  95. # if k_table_list[0]['close'] > k_table_list[-1]['close']*1.20:
  96. # continue
  97. # if k_table_list[0]['close'] < k_table_list[-1]['close']*0.90:
  98. # continue
  99. # if k_table_list[-1]['close'] > 80:
  100. # continue
  101. # 指定某几个行业
  102. # if stock['industry'] in industry:
  103. concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
  104. concept_detail_list = []
  105. # 处理行业
  106. if stock['sw_industry'] in hangye_map:
  107. i_c = hangye_map[stock['sw_industry']]
  108. hangye_map[stock['sw_industry']] = i_c + 1
  109. else:
  110. hangye_map[stock['sw_industry']] = 1
  111. if len(concept_code_list) > 0:
  112. for concept in concept_code_list:
  113. for c in all_concept_code_list:
  114. if c['code'] == concept['concept_code']:
  115. concept_detail_list.append(c['name'])
  116. if c['name'] in gainian_map:
  117. g_c = gainian_map[c['name']]
  118. gainian_map[c['name']] = g_c + 1
  119. else:
  120. gainian_map[c['name']] = 1
  121. print(line[-1], stock['name'], stock['sw_industry'], str(concept_detail_list), 'buy', k_table_list[0]['pct_chg'])
  122. if log is True:
  123. with open('D:\\data\\quantization\\predict\\' + str(day) + '_mix.txt', mode='a', encoding="utf-8") as f:
  124. f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' buy' + '\n')
  125. elif result[0][1] > 0.5:
  126. if stock['ts_code'] in holder_stock_list:
  127. print(stock['ts_code'], stock['name'], '震荡评级')
  128. elif result[0][2] > 0.5:
  129. if stock['ts_code'] in holder_stock_list:
  130. print(stock['ts_code'], stock['name'], '赶紧卖出')
  131. else:
  132. if stock['ts_code'] in holder_stock_list:
  133. print(stock['ts_code'], stock['name'], result[0],)
  134. # print(gainian_map)
  135. # print(hangye_map)
  136. gainian_list = [(key, gainian_map[key])for key in gainian_map]
  137. gainian_list = sorted(gainian_list, key=lambda x:x[1], reverse=True)
  138. hangye_list = [(key, hangye_map[key])for key in hangye_map]
  139. hangye_list = sorted(hangye_list, key=lambda x:x[1], reverse=True)
  140. print(gainian_list)
  141. print(hangye_list)
  142. def _read_pfile_map(path):
  143. s_list = []
  144. with open(path, encoding='utf-8') as f:
  145. for line in f.readlines()[:]:
  146. s_list.append(line)
  147. return s_list
  148. def join_two_day(a, b):
  149. a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '.txt')
  150. b_list = _read_pfile_map('D:\\data\\quantization\\predict\\dmi_' + str(b) + '.txt')
  151. for a in a_list:
  152. for b in b_list:
  153. if a[2:11] == b[2:11]:
  154. print(a)
  155. def check_everyday(day, today):
  156. a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(day) + '.txt')
  157. x = 0
  158. for a in a_list:
  159. print(a[:-1])
  160. k_day_list = list(k_table.find({'code':a[2:11], 'tradeDate':{'$lte':int(today)}}).sort('tradeDate', pymongo.DESCENDING).limit(5))
  161. if k_day_list is not None and len(k_day_list) > 0:
  162. k_day = k_day_list[0]
  163. k_day_0 = k_day_list[-1]
  164. k_day_last = k_day_list[1]
  165. if ((k_day_last['close'] - k_day_0['pre_close'])/k_day_0['pre_close']) < 0.2:
  166. print(k_day['open'], k_day['close'], 100*(k_day['close'] - k_day_last['close'])/k_day_last['close'])
  167. x = x + 100*(k_day['close'] - k_day_last['close'])/k_day_last['close']
  168. print(x/len(a_list))
  169. if __name__ == '__main__':
  170. # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
  171. # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
  172. # multi_predict()
  173. predict_today("D:\\data\\quantization\\stock186_18d_20200325.log", 20200325, model='186_18d_mix_6D_ma5_s_seq.h5', log=True)
  174. # join_two_day(20200305, 20200305)
  175. # check_everyday(20200311, 20200312)