# -*- encoding:utf-8 -*-
"""Daily stock-signal prediction script.

Loads a feature log, runs a saved Keras model over every sample, prints
buy / warning / sell signals and appends buy signals to a per-day result file.
"""
import numpy as np
from keras.models import load_model
import random
# Star import: this script relies on it for get_hot_industry,
# holder_stock_list and zixuan_stock_list, which are not defined here.
from mix.stock_source import *
import pymongo
from util.mongodb import get_mongo_table_instance
# Imported after the star import on purpose, so that `datetime` is the module
# even if mix.stock_source happens to export a `datetime` name.
import datetime

code_table = get_mongo_table_instance('tushare_code')
k_table = get_mongo_table_instance('stock_day_k')
stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))

# Concept code -> concept names (in all_concept_code_list order), built once so
# each stock's lookup is a dict hit instead of a scan of every concept.
_concept_names_by_code = {}
for _concept in all_concept_code_list:
    _concept_names_by_code.setdefault(_concept['code'], []).append(_concept['name'])

gainian_map = {}  # concept map (not used in this file)
hangye_map = {}   # industry map (not used in this file)
Z_list = []  # watchlist picks
R_list = []  # "ROE" picks: buy signals whose probability exceeds 0.9
O_list = []  # other picks: every buy signal that passes the filters


def _load_samples(file):
    """Parse the feature log at *file*: one Python literal per non-blank line."""
    with open(file) as f:
        # NOTE(review): eval executes arbitrary code. Only run this on logs the
        # pipeline itself produced; ast.literal_eval would be safer if every
        # line is a plain literal -- confirm before switching.
        return [eval(raw.strip()) for raw in f if raw.strip()]


def _concept_names(ts_code):
    """Return the names of every concept the stock *ts_code* is linked to."""
    names = []
    for concept in stock_concept_table.find({'ts_code': ts_code}):
        names.extend(_concept_names_by_code.get(concept['concept_code'], []))
    return names


def predict_today(file, day, model='10_18d', log=True, x=29, y=1,
                  output_dir='D:\\data\\quantization\\predict\\'):
    """Run the model over every sample in *file* and report trade signals.

    Args:
        file: feature log path. Each line is a Python literal list; all but
            the last element are features, and the last element is a sequence
            whose first item is the stock's ts_code.
        day: trading date as an int (YYYYMMDD when called from __main__); used
            in the result file name and passed to get_hot_industry.
        model: path handed to keras ``load_model``.
        log: unused; kept for call compatibility.
        x, y: the first ``x * y`` features are reshaped to a (1, x, y, 1) grid;
            the remaining features form the model's first input.
        output_dir: prefix, including the trailing separator, for the result
            file ``<day>_mix537.txt``.

    Side effects: prints signals, appends to the module-level O_list / R_list,
    appends buy signals to the result file, then shuffles and prints R_list.
    """
    # Result unused: the industry filter it fed is disabled. The call is kept
    # in case it has side effects (it lives in mix.stock_source).
    get_hot_industry(day)

    samples = _load_samples(file)
    if not samples:
        print('no samples in', file)
        return

    size = len(samples[0])
    grid_len = x * y
    net = load_model(model)
    out_path = output_dir + str(day) + '_mix537.txt'

    for sample in samples:
        features = np.array([sample[:size - 1]])
        grid = features[:, :grid_len].reshape(features.shape[0], x, y, 1)
        tail = features[:, grid_len:]
        probs = net.predict([tail, grid])[0]

        code = sample[-1][0]
        stock = code_table.find_one({'ts_code': code})
        if stock is None:
            # find_one returns None for unknown codes; indexing it would crash.
            print(code, 'not found in tushare_code, skipped')
            continue

        if probs[0] > 0.5:  # class 0: buy signal
            # Skip codes prefixed '688' (presumably STAR Market -- confirm).
            if code.startswith('688'):
                continue
            name = stock['name']
            # Remove ST stocks; 'N' presumably marks newly listed ones -- confirm.
            if name.startswith(('ST', 'N', '*')):
                continue
            concept_names = _concept_names(stock['ts_code'])
            print(stock['ts_code'], name, '买入')
            O_list.append([stock['ts_code'], name])
            if probs[0] > 0.9:
                R_list.append([stock['ts_code'], name])
            # NOTE(review): the collapsed original did not show whether this
            # write was nested under the > 0.9 check; it is reconstructed here
            # as recording every buy signal -- confirm.
            # str() is explicit: f-string formatting of numpy floats prints
            # more digits than str().
            record = ' '.join([str(sample[-1]), name, stock['sw_industry'],
                               str(concept_names), str(probs[0])])
            with open(out_path, mode='a', encoding="utf-8") as out:
                out.write(record + '\n')
        elif probs[1] > 0.5:  # class 1: no action
            pass
        elif probs[2] > 0.5:  # class 2: warn if currently held
            if stock['ts_code'] in holder_stock_list:
                print(stock['ts_code'], stock['name'], '警告危险')
        elif probs[3] > 0.5:  # class 3: sell if held or on the watchlist
            if stock['ts_code'] in holder_stock_list or stock['ts_code'] in zixuan_stock_list:
                print(stock['ts_code'], stock['name'], '赶紧卖出')

    random.shuffle(R_list)
    print('----ROE----')
    print(R_list)


if __name__ == '__main__':
    today = datetime.datetime.now().strftime('%Y%m%d')
    # Strategy B: model 539 applied to today's stock538 28-day feature log.
    predict_today("D:\\data\\quantization\\stock538_28d_" + today + ".log", int(today),
                  model='539_28d_mix_5D_ma5_s_seq.h5', log=True, x=28, y=17)