123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- # -*- encoding:utf-8 -*-
- import numpy as np
- from keras.models import load_model
- import joblib
def read_data(path):
    """Group the samples stored in a text file by stock code.

    Each non-blank line of *path* is evaluated as a Python expression describing
    one sample. The second-to-last element of that sample is a sequence whose
    first item is the stock code. Samples are grouped under ``str(code)`` and
    keep their file order within each group.

    Args:
        path: Path of the sample file to read.

    Returns:
        A plain dict mapping stock code (str) to the list of parsed samples.
    """
    stock_lines = {}
    with open(path) as f:
        # Iterate the file lazily instead of materialising (and copying) readlines().
        for raw in f:
            raw = raw.strip()
            if not raw:
                # eval('') raises SyntaxError, so skip blank/whitespace-only lines
                # such as an extra trailing newline.
                continue
            # SECURITY: eval() executes arbitrary code from the file. Only use this
            # on trusted, locally generated data; ast.literal_eval would be safer
            # if every line is a plain literal.
            line = eval(raw)
            stock = str(line[-2][0])
            stock_lines.setdefault(stock, []).append(line)
    return stock_lines
- import pymongo
- from util.mongodb import get_mongo_table_instance
# MongoDB collection handles, created once at import time and in this order.
# 'stock_day_k' is queried by predict() for daily bars ('code', 'tradeDate',
# 'open', 'close'); 'tushare_code' is presumably the stock-code list (TODO confirm).
code_table, k_table = (
    get_mongo_table_instance(collection_name)
    for collection_name in ('tushare_code', 'stock_day_k')
)
def _next_trading_day(code, trade_date):
    """Return the first ``k_table`` bar for *code* strictly after *trade_date*, or None.

    A sample's signal is traded on the following bar, hence the strict ``$gt``.
    None is returned when the collection holds no later bar for that code.
    """
    bars = list(
        k_table.find({'code': code, 'tradeDate': {'$gt': int(trade_date)}})
        .sort('tradeDate', pymongo.ASCENDING)
        .limit(1)
    )
    return bars[0] if bars else None


def _apply_return(money, from_price, to_price):
    """Grow *money* by the simple return of a move from *from_price* to *to_price*."""
    return money * (to_price - from_price) / from_price + money


def _backtest_stock(lines, model, start_money=10000, threshold=0.6,
                    stop_loss_money=9000, max_idle_days=2):
    """Simulate trading one stock on the model's signals; return the final equity.

    Each sample is acted on at the next trading day's bar:

    * Signal (either of the first two model outputs > *threshold*):
      - when flat, buy at that bar's open;
      - when holding, mark to market at that bar's close.
    * No signal while holding:
      - first mark to market at the close;
      - then sell if more than *max_idle_days* consecutive no-signal days have
        passed and the close fell below the previous reference price;
      - otherwise stop out if equity dropped below *stop_loss_money*.
    """
    money = start_money
    holding = False
    last_price = None  # reference price of the open position; set on first buy
    idle_days = 0      # consecutive no-signal days while holding
    for line in lines:
        stock_name = line[-2]  # (code, trade date, ...) info of this sample
        bar = _next_trading_day(line[-2][0], line[-2][1])
        if bar is None:
            # No later bar to trade on; skip instead of crashing with IndexError.
            continue
        # Model input: all fields except the trailing two, reshaped to (1, 1, 6, 77).
        features = np.array([line[:-2]])
        features = features.reshape(features.shape[0], 1, 6, 77)
        result = model.predict(features)
        close = bar['close']
        if result[0][0] > threshold or result[0][1] > threshold:
            idle_days = 0
            if not holding:
                last_price = bar['open']
                print('首次买入', stock_name, bar['open'])
                holding = True
            else:
                money = _apply_return(money, last_price, close)
                last_price = close
                print('买入+买入', stock_name, close)
        elif holding:
            idle_days += 1
            previous_price = last_price
            # Book the day's move before advancing the reference price; otherwise
            # the P&L is lost and the "price fell" test below can never be true.
            money = _apply_return(money, last_price, close)
            last_price = close
            if idle_days > max_idle_days and close < previous_price:
                print('卖出', stock_name, close)
                holding = False
                idle_days = 0
            elif money < stop_loss_money:
                print('止损卖出', stock_name, close)
                holding = False
                idle_days = 0
    return money


def predict(file_path='', model_path='15min_dnn_seq',
            log_path='D:\\data\\quantization\\stock_16_18d_profit.log'):
    """Backtest a trained Keras model stock by stock over a sample file.

    Args:
        file_path: Sample file, parsed with ``read_data``.
        model_path: Model path without the ``.h5`` suffix.
        log_path: File the per-stock final equity is appended to.

    Side effects:
        Prints progress and trades to stdout. Appends ``"<code> <equity>"``
        lines to *log_path*. Queries ``k_table`` for next-day prices.
    """
    start_money = 10000
    stock_lines = read_data(file_path)
    print('数据读取完毕')
    model = load_model(model_path + '.h5')
    print('模型加载完毕')
    total_money = 0
    total_num = 0
    with open(log_path, 'a') as log:
        for key in sorted(stock_lines):
            # Codes starting with '6' are excluded (presumably Shanghai stocks; confirm).
            # `key` is str(code), which also works when codes are stored as ints.
            if key.startswith('6'):
                continue
            money = _backtest_stock(stock_lines[key], model, start_money=start_money)
            print(key, money)
            # Unchanged equity is used as a proxy for "never traded"; such stocks
            # are neither logged nor counted in the average.
            if money != start_money:
                log.write(str(key) + ' ' + str(money) + '\n')
                total_money += money
                total_num += 1
    # Avoid ZeroDivisionError when no stock traded.
    avg_ratio = total_money / total_num / start_money if total_num else float('nan')
    print(total_money, total_num, avg_ratio)
if __name__ == '__main__':
    # Backtest the '16_18d_cnn_seq' model on its test sample file. Earlier runs
    # paired other sample logs with other models (e.g. 'stock12_18d_test.log'
    # with '12_18d_dnn_seq').
    test_log = 'D:\\data\\quantization\\stock16_18d_test.log'
    model_name = '16_18d_cnn_seq'
    predict(file_path=test_log, model_path=model_name)
|