predict_everyweek_113.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import random
  5. from mix.stock_source import *
  6. import pymongo
  7. from util.mongodb import get_mongo_table_instance
  8. code_table = get_mongo_table_instance('tushare_code')
  9. k_table = get_mongo_table_instance('stock_day_k')
  10. stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
  11. all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))
  12. gainian_map = {}
  13. hangye_map = {}
  14. Z_list = [] # 自选
  15. R_list = [] # ROE
  16. O_list = [] # 其他
  17. def predict_today(file, day, model='10_18d', log=True):
  18. # industry_list = get_hot_industry(day)
  19. lines = []
  20. with open(file) as f:
  21. for line in f.readlines()[:]:
  22. line = eval(line.strip())
  23. lines.append(line)
  24. size = len(lines[0])
  25. model=load_model(model)
  26. row = 18
  27. col = 9
  28. for line in lines:
  29. train_x = np.array([line[:size - 1]])
  30. train_x_a = train_x[:,:row*col]
  31. train_x_a = train_x_a.reshape(train_x.shape[0], row, col, 1)
  32. train_x_b = train_x[:, row*col:row*col + 11*14]
  33. train_x_b = train_x_b.reshape(train_x.shape[0], 11, 14, 1)
  34. train_x_c = train_x[:,row*col + 11*14:]
  35. result = model.predict([train_x_c, train_x_a, train_x_b])
  36. # print(result, line[-1])
  37. stock = code_table.find_one({'ts_code':line[-1][0]})
  38. if result[0][0] > 0.85:
  39. if line[-1][0].startswith('688'):
  40. continue
  41. # 去掉ST
  42. if stock['name'].startswith('ST') or stock['name'].startswith('N') or stock['name'].startswith('*'):
  43. continue
  44. if stock['ts_code'] in ROE_stock_list or stock['ts_code'] in zeng_stock_list:
  45. R_list.append([stock['ts_code'], stock['name']])
  46. print(stock['ts_code'], stock['name'], 'zhang10')
  47. concept_code_list = list(stock_concept_table.find({'ts_code':stock['ts_code']}))
  48. concept_detail_list = []
  49. if len(concept_code_list) > 0:
  50. for concept in concept_code_list:
  51. for c in all_concept_code_list:
  52. if c['code'] == concept['concept_code']:
  53. concept_detail_list.append(c['name'])
  54. if log is True:
  55. with open('D:\\data\\quantization\\predict\\' + str(day) + '_week_119.txt', mode='a', encoding="utf-8") as f:
  56. f.write(str(line[-1]) + ' ' + stock['name'] + ' ' + stock['sw_industry'] + ' ' + str(concept_detail_list) + ' ' + str(result[0][0]) + '\n')
  57. elif result[0][1] > 0.7:
  58. print(stock['ts_code'], stock['name'], 'zhang5')
  59. elif result[0][2] > 0.5:
  60. pass
  61. elif result[0][3] > 0.5:
  62. pass
  63. else:
  64. pass
  65. # print(gainian_map)
  66. # print(hangye_map)
  67. random.shuffle(O_list)
  68. print(O_list[:3])
  69. random.shuffle(R_list)
  70. print('----ROE----')
  71. print(R_list[:3])
  72. if __name__ == '__main__':
  73. # 策略B
  74. # predict_today("D:\\data\\quantization\\stock505_28d_20200416.log", 20200416, model='505_28d_mix_5D_ma5_s_seq.h5', log=True)
  75. predict_today("D:\\data\\quantization\\week119_18d_20200403.log", 20200410, model='119_18d_mix_3W_s_seqA.h5', log=True)