predict_everyweek_100.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import random
  5. from mix.stock_source import *
  6. import pymongo
  7. from util.mongodb import get_mongo_table_instance
  8. code_table = get_mongo_table_instance('tushare_code')
  9. k_table = get_mongo_table_instance('stock_day_k')
  10. stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
  11. all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
  12. gainian_map = {}
  13. hangye_map = {}
  14. Z_list = [] # 自选
  15. R_list = [] # ROE
  16. O_list = [] # 其他
  17. def predict_today(file, day, model='10_18d', log=True):
  18. # industry_list = get_hot_industry(day)
  19. lines = []
  20. with open(file) as f:
  21. for line in f.readlines()[:]:
  22. line = eval(line.strip())
  23. lines.append(line)
  24. size = len(lines[0])
  25. model=load_model(model)
  26. row = 18
  27. col = 11
  28. for line in lines:
  29. train_x = np.array([line[:size - 1]])
  30. train_x_a = train_x[:,:row*col]
  31. train_x_a = train_x_a.reshape(train_x.shape[0], row, col, 1)
  32. train_x_b = train_x[:, row*col:row*col + 11*16]
  33. train_x_b = train_x_b.reshape(train_x.shape[0], 11, 16, 1)
  34. train_x_c = train_x[:,row*col + 11*16:]
  35. result = model.predict([train_x_c, train_x_a, train_x_b])
  36. # print(result, line[-1])
  37. stock = code_table.find_one({'ts_code':line[-1][0]})
  38. if result[0][0] > 0.5:
  39. if line[-1][0].startswith('688'):
  40. continue
  41. # 去掉ST
  42. if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
  43. continue
  44. if stock['ts_code'] in ROE_stock_list or stock['ts_code'] in zeng_stock_list:
  45. R_list.append([stock['ts_code'], stock['name']])
  46. concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
  47. concept_detail_list = []
  48. if len(concept_code_list) > 0:
  49. for concept in concept_code_list:
  50. for c in all_concept_code_list:
  51. if c['code'] == concept['concept_code']:
  52. concept_detail_list.append(c['name'])
  53. if log is True:
  54. with open('D:\\data\\quantization\\predict\\' + str(day) + '_week_103A.txt', mode='a', encoding="utf-8") as f:
  55. f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' A ' + str(concept_detail_list) + ' ' + str(result[0][0]) + '\n')
  56. elif result[0][1] > 0.5:
  57. with open('D:\\data\\quantization\\predict\\' + str(day) + '_week_103A.txt', mode='a', encoding="utf-8") as f:
  58. f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' B ' + str(result[0][1]) + '\n')
  59. elif result[0][2] > 0.5:
  60. if stock['ts_code'] in holder_stock_list:
  61. print(stock['ts_code'], stock['name'], '警告危险')
  62. elif result[0][3] > 0.5:
  63. if stock['ts_code'] in holder_stock_list or stock['ts_code'] in zixuan_stock_list:
  64. print(stock['ts_code'], stock['name'], '赶紧卖出')
  65. else:
  66. pass
  67. # print(gainian_map)
  68. # print(hangye_map)
  69. random.shuffle(O_list)
  70. print(O_list[:3])
  71. random.shuffle(R_list)
  72. print('----ROE----')
  73. print(R_list[:3])
  74. if __name__ == '__main__':
  75. # 策略B
  76. # predict_today("D:\\data\\quantization\\stock505_28d_20200416.log", 20200416, model='505_28d_mix_5D_ma5_s_seq.h5', log=True)
  77. predict_today("D:\\data\\quantization\\week101_18d_20191213.log", 20191213, model='103_18d_mix_3W_s_seqA.h5', log=True)