# -*- encoding:utf-8 -*-
"""Replay cluster-routed DNN buy/sell signals over per-stock sample logs.

Each record is first assigned to a cluster by a fitted estimator; the
cluster id selects one of several Keras models whose output drives a simple
long-only position simulation priced with the next trading day's bar from
MongoDB. Per-stock final money is printed and appended to a profit log.
"""
import joblib
import numpy as np
import pymongo
from keras.models import load_model

from util.mongodb import get_mongo_table_instance

# NOTE(review): code_table is not used by the code in this module; kept in
# case other modules import it.
code_table = get_mongo_table_instance('tushare_code')
k_table = get_mongo_table_instance('stock_day_k')

_INIT_MONEY = 10000  # starting capital for every simulated stock


def read_data(path):
    """Group the sample records in *path* by stock code.

    Every non-blank line is parsed into one record (a sequence). The stock
    code is taken from ``record[-2][0]``; records keep file order.

    Args:
        path: Path of the sample log file.

    Returns:
        dict mapping ``str(stock code)`` to the list of that stock's records.
    """
    stock_lines = {}
    with open(path) as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                # A blank line (e.g. a trailing one) would make eval() raise.
                continue
            # NOTE(review): eval() executes arbitrary code read from the file.
            # If the log only holds Python literals, ast.literal_eval is a
            # safe drop-in replacement -- confirm before switching.
            record = eval(raw)
            stock_lines.setdefault(str(record[-2][0]), []).append(record)
    return stock_lines


def _cluster_features(line, x=24, k=18):
    """Build the (1, 4*k) clustering input from one record.

    ``line[1:x*k+1]`` is reshaped to ``(k, x)`` -- k periods with x values
    per period -- and columns 6..9 of every period are kept and flattened.
    """
    v = np.array(line[1:x * k + 1]).reshape(k, x)
    return v[:, 6:10].reshape(1, 4 * k)


def _next_bar(code, trade_date):
    """Return the first ``stock_day_k`` document for *code* strictly after
    *trade_date*, or None when no such bar exists."""
    bars = list(
        k_table.find({'code': code, 'tradeDate': {'$gt': int(trade_date)}})
        .sort('tradeDate', pymongo.ASCENDING)
        .limit(1)
    )
    return bars[0] if bars else None


def _apply_return(money, last_price, price):
    """Scale *money* by the relative price move from *last_price* to *price*."""
    return money * (price - last_price) / last_price + money


def _backtest_stock(lines, models, estimator):
    """Simulate the trading rules over one stock's records; return final money.

    Model output ``result[0][0]`` or ``result[0][1]`` above 0.5 is treated as a
    buy signal and ``result[0][3]`` or ``result[0][4]`` above 0.5 as a sell
    signal; column 2 is not used.
    # NOTE(review): presumably these are class probabilities -- confirm.
    """
    init_money = _INIT_MONEY
    last_price = 1
    buy = 0        # position state: 0 = flat, 1 = holding
    chiyou_0 = 0   # consecutive holding steps without a fresh buy signal

    for line in lines:
        stock_name = line[-2]
        # Trades are priced on the first bar after the record's date.
        today_price = _next_bar(line[-2][0], line[-2][1])
        if today_price is None:
            # No later bar (e.g. newest record in the log): nothing to trade on.
            print('没有下一交易日行情,跳过', stock_name)
            continue

        r = estimator.predict(_cluster_features(line))
        # The cluster id picks which of the per-cluster models scores the record.
        result = models[r[0]].predict(np.array([line[:-2]]))

        if result[0][0] > 0.5 or result[0][1] > 0.5:
            chiyou_0 = 0
            print(r[0])
            if buy == 0:
                last_price = today_price['open']
                print('首次买入', stock_name, today_price['open'])
            else:
                init_money = _apply_return(init_money, last_price, today_price['close'])
                last_price = today_price['close']
                print('买入+买入', stock_name, today_price['open'])
            buy = 1
        elif result[0][3] > 0.5 or result[0][4] > 0.5:
            if buy == 1:
                # Only act on a sell signal after >2 holding steps or a
                # 10% loss from the 10000 start.
                if chiyou_0 > 2 or init_money < 9000:
                    init_money = _apply_return(init_money, last_price, today_price['open'])
                    print('卖出', stock_name, today_price['open'])
                    buy = 0
                    chiyou_0 = 0
                else:
                    init_money = _apply_return(init_money, last_price, today_price['close'])
                    # Bug fix: the position was marked to the close, so the
                    # reference price must move too; otherwise the next
                    # valuation re-applies this move (double counting).
                    last_price = today_price['close']
                    print('继续持有,不卖出', stock_name, today_price['close'])
                    chiyou_0 = chiyou_0 + 1
        elif buy == 1:
            init_money = _apply_return(init_money, last_price, today_price['close'])
            if init_money < 8500:
                # Stop-loss at 15% below the 10000 start.
                print('止损卖出', stock_name, today_price['close'])
                buy = 0
            else:
                chiyou_0 = chiyou_0 + 1
                # Both exits compare against the previous reference price,
                # i.e. before last_price is moved to today's close below.
                if init_money < 10500 and chiyou_0 > 1 and today_price['close'] < last_price:
                    print('连续持有次数太多-- 卖出', stock_name, today_price['close'])
                    buy = 0
                    chiyou_0 = 0
                elif chiyou_0 > 2 and today_price['close'] < last_price:
                    print('连续持有次数太多++ 卖出', stock_name, today_price['close'])
                    buy = 0
                    chiyou_0 = 0
                else:
                    buy = 1
                    print('持有', stock_name, today_price['close'])
            last_price = today_price['close']
        # Flat with no buy signal: ignore the record.
    return init_money


def predict(file_path='', model_path='15min_dnn_seq',
            estimator_path='km_dmi_18.pkl', n_models=12,
            log_path='D:\\data\\quantization\\stock_15_18d_profit.log'):
    """Back-test every non-Shanghai stock in *file_path* and report profits.

    Args:
        file_path: Sample log readable by :func:`read_data`.
        model_path: Prefix of the per-cluster Keras models; model ``i`` is
            loaded from ``model_path + '_' + str(i) + '.h5'``.
        estimator_path: joblib file with the fitted clustering estimator.
        n_models: Number of per-cluster models to load (ids 0..n_models-1).
        log_path: File to which ``"<code> <final money>"`` lines are appended
            for every stock whose money changed.

    Prints each stock's final money and, at the end, the summed money, the
    number of stocks whose money changed, and their mean money relative to
    the 10000 start.
    """
    stock_lines = read_data(file_path)
    print('数据读取完毕')
    models = [load_model(model_path + '_' + str(i) + '.h5') for i in range(n_models)]
    estimator = joblib.load(estimator_path)
    print('模型加载完毕')

    total_money = 0
    total_num = 0
    with open(log_path, 'a') as log:
        for key in sorted(stock_lines):
            # Codes starting with '6' (Shanghai board) are excluded.
            if key.startswith('6'):
                continue
            init_money = _backtest_stock(stock_lines[key], models, estimator)
            print(key, init_money)
            # Stocks that never traded keep exactly the starting money and are
            # left out of both the log and the average.
            if init_money != _INIT_MONEY:
                log.write(str(key) + ' ' + str(init_money) + '\n')
                total_money = total_money + init_money
                total_num = total_num + 1

    if total_num:
        print(total_money, total_num, total_money / total_num / _INIT_MONEY)
    else:
        # Avoid ZeroDivisionError when no stock traded at all.
        print(total_money, total_num)


if __name__ == '__main__':
    # predict(file_path='D:\\data\\quantization\\stock6_5_test.log', model_path='5d_dnn_seq.h5')
    # predict(file_path='D:\\data\\quantization\\stock12_18d_test.log', model_path='12_18d_dnn_seq')
    predict(file_path='D:\\data\\quantization\\stock15_18d_test.log', model_path='15_18d_dnn_seq')
    # predict(file_path='D:\\data\\quantization\\stock12_18d_20190103_20190604.log', model_path='13_18d_dnn_seq')