mix_predict_by_stock.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import joblib
  5. def read_data(path):
  6. stock_lines = {}
  7. with open(path) as f:
  8. for line in f.readlines()[:]:
  9. line = eval(line.strip())
  10. stock = str(line[-2][0])
  11. if stock in stock_lines:
  12. stock_lines[stock].append(line)
  13. else:
  14. stock_lines[stock] = [line]
  15. # print(len(day_lines['20191230']))
  16. return stock_lines
  17. import pymongo
  18. from util.mongodb import get_mongo_table_instance
# Mongo collection handle for the 'tushare_code' table.
# NOTE(review): code_table is not referenced anywhere in this script.
code_table = get_mongo_table_instance('tushare_code')
# Mongo collection handle for 'stock_day_k'. predict() queries it by
# 'code'/'tradeDate' and reads the 'open'/'close' fields of the result.
k_table = get_mongo_table_instance('stock_day_k')
  21. def predict(file_path='', model_path='15min_dnn_seq'):
  22. stock_lines = read_data(file_path)
  23. print('数据读取完毕')
  24. models = []
  25. models.append(load_model(model_path + '.h5'))
  26. # estimator = joblib.load('km_dmi_18.pkl')
  27. print('模型加载完毕')
  28. total_money = 0
  29. total_num = 0
  30. items = sorted(stock_lines.keys())
  31. for key in items:
  32. # print(day)
  33. lines = stock_lines[key]
  34. init_money = 10000
  35. last_price = 1
  36. if lines[0][-2][0].startswith('6'):
  37. continue
  38. buy = 0 # 0空 1买入 2卖出
  39. chiyou_0 = 0
  40. high_price = 0
  41. x = 24 # 每条数据项数
  42. k = 18 # 周期
  43. for line in lines:
  44. # v = line[1:x*k + 1]
  45. # v = np.array(v)
  46. # v = v.reshape(k, x)
  47. # v = v[:,6:10]
  48. # v = v.reshape(1, 4*k)
  49. # print(v)
  50. # r = estimator.predict(v)
  51. test_x = np.array([line[:-2]])
  52. test_x_a = test_x[:,:18*24]
  53. test_x_a = test_x_a.reshape(test_x.shape[0], 18, 24, 1)
  54. test_x_b = test_x[:, 18*24:18*24+2*18]
  55. test_x_b = test_x_b.reshape(test_x.shape[0], 18, 2, 1)
  56. test_x_c = test_x[:,18*24+2*18:]
  57. result = models[0].predict([test_x_c, test_x_a, test_x_b])
  58. stock_name = line[-2]
  59. today_price = list(k_table.find({'code':line[-2][0], 'tradeDate':{'$gt':int(line[-2][1])}}).sort('tradeDate',pymongo.ASCENDING).limit(1))
  60. today_price = today_price[0]
  61. if result[0][0] + result[0][1] > 0.7:
  62. chiyou_0 = 0
  63. if buy == 0:
  64. last_price = today_price['open']
  65. high_price = last_price
  66. print('首次买入', stock_name, today_price['open'])
  67. buy = 1
  68. else:
  69. init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  70. last_price = today_price['close']
  71. print('买入+买入', stock_name, today_price['close'])
  72. buy = 1
  73. if last_price > high_price:
  74. high_price = last_price
  75. elif buy == 1:
  76. chiyou_0 = chiyou_0 + 1
  77. if chiyou_0 > 2 and ((high_price - today_price['close'])/high_price*100 > 5):
  78. print('卖出', stock_name, today_price['close'])
  79. init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  80. buy = 0
  81. chiyou_0 = 0
  82. if init_money < 9000:
  83. print('止损卖出', stock_name, today_price['close'])
  84. init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  85. buy = 0
  86. chiyou_0 = 0
  87. else:
  88. print('继续持有', stock_name, today_price['close'])
  89. init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  90. buy = 1
  91. last_price = today_price['close']
  92. if last_price > high_price:
  93. high_price = last_price
  94. print(key, init_money)
  95. with open('D:\\data\\quantization\\stock_18_18d' + '_' + 'profit.log', 'a') as f:
  96. if init_money > 10000:
  97. f.write(str(key) + ' ' + str(init_money) + '\n')
  98. elif init_money < 10000:
  99. f.write(str(key) + ' ' + str(init_money) + '\n')
  100. if init_money != 10000:
  101. total_money = total_money + init_money
  102. total_num = total_num + 1
  103. print(total_money, total_num, total_money/total_num/10000)
  104. if __name__ == '__main__':
  105. predict(file_path='D:\\data\\quantization\\stock18_18d_test.log', model_path='18_18d_mix_seq')