dnn_predict_by_stock.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. from keras.models import load_model
  4. import joblib
  5. def read_data(path):
  6. stock_lines = {}
  7. with open(path) as f:
  8. for line in f.readlines()[:]:
  9. line = eval(line.strip())
  10. stock = str(line[-2][0])
  11. if stock in stock_lines:
  12. stock_lines[stock].append(line)
  13. else:
  14. stock_lines[stock] = [line]
  15. # print(len(day_lines['20191230']))
  16. return stock_lines
import pymongo
from util.mongodb import get_mongo_table_instance
# Module-level MongoDB table handles, obtained when this module is imported.
# NOTE(review): code_table is not referenced anywhere in the visible code of
# this file -- confirm it is still needed before removing it.
code_table = get_mongo_table_instance('tushare_code')
# Daily bar table; predict() queries it (find/sort/limit, so presumably a
# pymongo collection) for the first trading day after each sample's date.
k_table = get_mongo_table_instance('stock_day_k')
  21. def predict(file_path='', model_path='15min_dnn_seq'):
  22. stock_lines = read_data(file_path)
  23. print('数据读取完毕')
  24. models = []
  25. for x in range(0, 12):
  26. models.append(load_model(model_path + '_' + str(x) + '.h5'))
  27. estimator = joblib.load('km_dmi_18.pkl')
  28. print('模型加载完毕')
  29. total_money = 0
  30. total_num = 0
  31. items = sorted(stock_lines.keys())
  32. for key in items:
  33. # print(day)
  34. lines = stock_lines[key]
  35. init_money = 10000
  36. last_price = 1
  37. if lines[0][-2][0].startswith('6'):
  38. continue
  39. buy = 0 # 0空 1买入 2卖出
  40. chiyou_0 = 0
  41. high_price = 0
  42. x = 24 # 每条数据项数
  43. k = 18 # 周期
  44. for line in lines:
  45. v = line[1:x*k + 1]
  46. v = np.array(v)
  47. v = v.reshape(k, x)
  48. v = v[:,6:10]
  49. v = v.reshape(1, 4*k)
  50. # print(v)
  51. r = estimator.predict(v)
  52. train_x = np.array([line[:-2]])
  53. result = models[r[0]].predict(train_x)
  54. stock_name = line[-2]
  55. today_price = list(k_table.find({'code':line[-2][0], 'tradeDate':{'$gt':int(line[-2][1])}}).sort('tradeDate',pymongo.ASCENDING).limit(1))
  56. today_price = today_price[0]
  57. if result[0][0] > 0.5 or result[0][1] > 0.5: #and (r[0] not in [2,6,8,10]):
  58. chiyou_0 = 0
  59. print(r[0])
  60. if buy == 0:
  61. last_price = today_price['open']
  62. high_price = last_price
  63. print('首次买入', stock_name, today_price['open'])
  64. buy = 1
  65. else:
  66. init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  67. last_price = today_price['close']
  68. print('买入+买入', stock_name, today_price['open'])
  69. buy = 1
  70. if last_price > high_price:
  71. high_price = last_price
  72. elif result[0][3] > 0.5 or result[0][4] > 0.5:#and (r[0] not in [5,8]):
  73. if buy == 1:
  74. if chiyou_0 > 2 or init_money < 9000:
  75. init_money = init_money * (today_price['open'] - last_price)/last_price + init_money
  76. print('卖出', stock_name, today_price['open'])
  77. buy = 0
  78. chiyou_0 = 0
  79. # elif init_money > 15000 and 100*(today_price['close'] - high_price)/high_price < -15:
  80. # init_money = init_money * (today_price['open'] - last_price)/last_price + init_money
  81. # print('最高点回撤卖出', stock_name, today_price['open'])
  82. # buy = 0
  83. # chiyou_0 = 0
  84. else:
  85. init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  86. print('继续持有,不卖出', stock_name, today_price['close'])
  87. buy = 1
  88. chiyou_0 = chiyou_0 + 1
  89. if today_price['close'] > high_price:
  90. high_price = today_price['close']
  91. else:
  92. if buy == 1:
  93. init_money = (init_money * (today_price['close'] - last_price)/last_price) + init_money
  94. if init_money < 8500:
  95. print('止损卖出', stock_name, today_price['close'])
  96. buy = 0
  97. else:
  98. chiyou_0 = chiyou_0 + 1
  99. if init_money < 10500 and chiyou_0 > 1 and today_price['close'] < last_price:
  100. print('连续持有次数太多-- 卖出', stock_name, today_price['close'])
  101. buy = 0
  102. chiyou_0 = 0
  103. elif chiyou_0 > 2 and today_price['close'] < last_price:
  104. print('连续持有次数太多++ 卖出', stock_name, today_price['close'])
  105. buy = 0
  106. chiyou_0 = 0
  107. else:
  108. buy = 1
  109. print('持有', stock_name, today_price['close'])
  110. if today_price['close'] > high_price:
  111. high_price = today_price['close']
  112. last_price = today_price['close']
  113. else:
  114. # print('忽略')
  115. pass
  116. # 具有后验知识的存在,
  117. # if result[0][1] > 0.5 or result[0][2] > 0.5:
  118. # chiyou_0 = 0
  119. # if r[0] in [2,5,9,10,11]:
  120. # if buy == 0:
  121. # last_price = today_price['open']
  122. # print('首次买入', line[-2], today_price['open'])
  123. # buy = 1
  124. # elif buy == 1:
  125. # init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  126. # last_price = today_price['close']
  127. # print('买入+买入', line[-2], today_price['open'])
  128. # buy = 1
  129. # else:
  130. # last_price = today_price['close']
  131. # print('卖出后买入', line[-2], today_price['open'])
  132. # buy = 1
  133. # else:
  134. # if buy == 1:
  135. # init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  136. # last_price = today_price['close']
  137. # print('买入+买入', line[-2], today_price['open'])
  138. # buy = 1
  139. # elif result[0][1] > 0.5 or result[0][2] > 0.5:
  140. # # if r[0] in [0,1,3,4,6,7] and buy in [0,1]:
  141. # buy = 0
  142. # init_money = init_money * (today_price['close'] - last_price)/last_price + init_money
  143. # print('卖出', line[-2], today_price['close'])
  144. # else:
  145. # if buy == 1:
  146. # init_money = (init_money * (today_price['close'] - last_price)/last_price) + init_money
  147. # if init_money < 9000:
  148. # print('止损卖出', line[-2], today_price['close'])
  149. # buy = 0
  150. # else:
  151. # chiyou_0 = chiyou_0 + 1
  152. # if chiyou_0 > 5 and today_price['close'] < last_price:
  153. # print('连续持有次数太多 卖出', line[-2], today_price['close'])
  154. # buy = 0
  155. # else:
  156. # buy = 1
  157. # print('持有', line[-2], today_price['close'])
  158. # last_price = today_price['close']
  159. print(key, init_money)
  160. with open('D:\\data\\quantization\\stock_15_18d' + '_' + 'profit.log', 'a') as f:
  161. if init_money > 10000:
  162. f.write(str(key) + ' ' + str(init_money) + '\n')
  163. elif init_money < 10000:
  164. f.write(str(key) + ' ' + str(init_money) + '\n')
  165. if init_money != 10000:
  166. total_money = total_money + init_money
  167. total_num = total_num + 1
  168. print(total_money, total_num, total_money/total_num/10000)
  169. if __name__ == '__main__':
  170. # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
  171. # predict(file_path='D:\\data\\quantization\\stock12_18d_test.log', model_path='12_18d_dnn_seq')
  172. predict(file_path='D:\\data\\quantization\\stock15_18d_test.log', model_path='15_18d_dnn_seq')
  173. # predict(file_path='D:\\data\\quantization\\stock12_18d_20190103_20190604.log', model_path='13_18d_dnn_seq')