dnn_predict_dmi_everyday.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import joblib
  5. holder_stock_list = [
  6. '000063.SZ',
  7. '002373.SZ',
  8. '300253.SZ',
  9. '300059.SZ',
  10. # b账户
  11. '002373.SZ',
  12. '300422.SZ',
  13. '300468.SZ',
  14. ]
  15. def read_data(path):
  16. lines = []
  17. with open(path) as f:
  18. for line in f.readlines()[:]:
  19. line = eval(line.strip())
  20. if line[-2][0].startswith('0') or line[-2][0].startswith('3'):
  21. lines.append(line)
  22. size = len(lines[0])
  23. train_x=[s[:size - 2] for s in lines]
  24. train_y=[s[size-1] for s in lines]
  25. return np.array(train_x),np.array(train_y),lines
import pymongo
from util.mongodb import get_mongo_table_instance

# Mongo collections used by predict_today. get_mongo_table_instance is a
# project helper; presumably it returns a pymongo Collection. Confirm in
# util.mongodb.
# Stock master data: looked up by 'ts_code' and read for 'name' and
# 'sw_industry'.
code_table = get_mongo_table_instance('tushare_code')
# Daily bars: queried by 'code' and 'tradeDate', sorted by 'tradeDate', and
# read for 'pct_chg'.
k_table = get_mongo_table_instance('stock_day_k')
# Stock -> concept membership rows (looked up by 'ts_code').
stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
# Every concept document, loaded eagerly. list(...find({})) runs a full
# collection query at import time.
all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
# Industry whitelist for an optional industry filter: home appliances,
# components, IT equipment, auto services, auto parts, software services,
# internet, textiles, plastics, semiconductors.
# Not referenced by any live code in this file.
industry = ['家用电器', '元器件', 'IT设备', '汽车服务',
            '汽车配件', '软件服务',
            '互联网', '纺织',
            '塑料', '半导体',]
# Concept codes of interest (tushare concept ids).
# Not referenced by any live code in this file.
A_concept_code_list = [ 'TS2', # 5G
                        'TS24', # OLED
                        'TS26', # Healthy China
                        'TS43', # new-energy complete vehicles
                        'TS59', # Tesla
                        'TS65', # complete vehicles (automakers)
                        'TS142', # Internet of Things
                        'TS153', # autonomous driving
                        'TS163', # Xiong'an sector / smart city
                        'TS175', # industrial automation
                        'TS232', # new-energy vehicles
                        'TS254', # artificial intelligence
                        'TS258', # internet healthcare
                        'TS264', # industrial internet
                        'TS266', # semiconductors
                        'TS269', # smart city
                        'TS271', # 3D glass
                        'TS295', # domestic chips
                        'TS303', # healthcare IT
                        'TS323', # EV charging piles
                        'TS328', # iris recognition
                        'TS361', # virus
                        ]
# Concept (gainian) name -> hit count. Printed by predict_today but not
# incremented by any live code in this file.
gainian_map = {}
# Industry (stock['sw_industry']) -> number of buy candidates. predict_today
# increments and prints it but never resets it, so counts accumulate across
# calls within one process.
hangye_map = {}
  61. def predict_today(day, model='10_18d', log=True):
  62. lines = []
  63. with open('D:\\data\\quantization\\stock' + model + '_' + str(day) +'.log') as f:
  64. for line in f.readlines()[:]:
  65. line = eval(line.strip())
  66. # if line[-1][0].startswith('0') or line[-1][0].startswith('3'):
  67. lines.append(line)
  68. size = len(lines[0])
  69. train_x=[s[:size - 1] for s in lines]
  70. np.array(train_x)
  71. estimator = joblib.load('km_dmi_18.pkl')
  72. models = []
  73. for x in range(0, 12):
  74. models.append(load_model(model + '_dnn_seq_' + str(x) + '.h5'))
  75. x = 24 # 每条数据项数
  76. k = 18 # 周期
  77. shift = 1
  78. for line in lines:
  79. # print(line)
  80. v = line[1:x*k + 1]
  81. v = np.array(v)
  82. v = v.reshape(k, x)
  83. v = v[:,4:8]
  84. v = v.reshape(1, 4*k)
  85. # print(v)
  86. r = estimator.predict(v)
  87. train_x = np.array([line[:size - 1]])
  88. result = models[r[0]].predict(train_x)
  89. # print(result, line[-1])
  90. stock = code_table.find_one({'ts_code':line[-1][0]})
  91. if result[0][0] > 0.5 or result[0][1] > 0.5:
  92. if line[-1][0].startswith('688'):
  93. continue
  94. # 去掉ST
  95. if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
  96. continue
  97. if stock['ts_code'] in holder_stock_list:
  98. print(stock['ts_code'], stock['name'], '维持买入评级')
  99. # 跌的
  100. k_table_list = list(k_table.find({'code':line[-1][0], 'tradeDate':{'$lte':day}}).sort("tradeDate", pymongo.DESCENDING).limit(5))
  101. # if k_table_list[0]['close'] > k_table_list[-1]['close']*1.20:
  102. # continue
  103. # if k_table_list[0]['close'] < k_table_list[-1]['close']*0.90:
  104. # continue
  105. # if k_table_list[-1]['close'] > 80:
  106. # continue
  107. # 指定某几个行业
  108. # if stock['industry'] in industry:
  109. concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
  110. concept_detail_list = []
  111. # 处理行业
  112. if stock['sw_industry'] in hangye_map:
  113. i_c = hangye_map[stock['sw_industry']]
  114. hangye_map[stock['sw_industry']] = i_c + 1
  115. else:
  116. hangye_map[stock['sw_industry']] = 1
  117. # if len(concept_code_list) > 0:
  118. # for concept in concept_code_list:
  119. # for c in all_concept_code_list:
  120. # if c['code'] == concept['concept_code']:
  121. # concept_detail_list.append(c['name'])
  122. #
  123. # if c['name'] in gainian_map:
  124. # g_c = gainian_map[c['name']]
  125. # gainian_map[c['name']] = g_c + 1
  126. # else:
  127. # gainian_map[c['name']] = 1
  128. print(line[-1], stock['name'], stock['sw_industry'], str(concept_detail_list), 'buy', k_table_list[0]['pct_chg'])
  129. if log is True:
  130. with open('D:\\data\\quantization\\predict\\' + str(day) + '.txt', mode='a', encoding="utf-8") as f:
  131. f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' buy' + '\n')
  132. # concept_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
  133. # concept_list = [c['concept_code'] for c in concept_list]
  134. elif result[0][2] > 0.5:
  135. if stock['ts_code'] in holder_stock_list:
  136. print(stock['ts_code'], stock['name'], '震荡评级')
  137. elif result[0][3] > 0.5 or result[0][4] > 0.5:
  138. if stock['ts_code'] in holder_stock_list:
  139. print(stock['ts_code'], stock['name'], '赶紧卖出')
  140. else:
  141. if stock['ts_code'] in holder_stock_list:
  142. print(stock['ts_code'], stock['name'], result[0], r[0])
  143. print(gainian_map)
  144. print(hangye_map)
  145. def _read_pfile_map(path):
  146. s_list = []
  147. with open(path, encoding='utf-8') as f:
  148. for line in f.readlines()[:]:
  149. s_list.append(line)
  150. return s_list
  151. def join_two_day(a, b):
  152. a_list = _read_pfile_map('D:\\data\\quantization\\predict\\' + str(a) + '.txt')
  153. b_list = _read_pfile_map('D:\\data\\quantization\\predict\\dmi_' + str(b) + '.txt')
  154. for a in a_list:
  155. for b in b_list:
  156. if a[2:11] == b[2:11]:
  157. print(a)
if __name__ == '__main__':
    # Earlier entry points kept for reference. predict / multi_predict are
    # not defined anywhere above in this file.
    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
    # predict(file_path='D:\\data\\quantization\\stock6_test.log', model_path='15m_dnn_seq.h5')
    # multi_predict()
    # Daily run: score the day's feature log and append buy lines to the
    # per-day predict file.
    # predict_today(20200305, model='11_18d', log=True)
    # Print 2020-03-11 predictions whose stock also appears in that day's
    # dmi_ prediction file.
    join_two_day(20200311, 20200311)