# -*- encoding:utf-8 -*-
"""Score a pre-built feature log with a trained Keras model and report picks.

Each line of the input log is a Python literal row: every element except the
last is a model feature, and the last element is metadata whose first item is
a tushare ``ts_code``.  Rows whose first-class probability is high are printed,
optionally appended to a per-day text file, and -- when the code is in the
ROE / growth lists -- collected into ``R_list``.
"""
import os
import random

import numpy as np
from keras.models import load_model

from mix.stock_source import *  # presumably provides ROE_stock_list / zeng_stock_list -- verify
import pymongo
from util.mongodb import get_mongo_table_instance

code_table = get_mongo_table_instance('tushare_code')
k_table = get_mongo_table_instance('stock_day_k')
stock_concept_table = get_mongo_table_instance('tushare_concept_detail')
all_concept_code_list = list(get_mongo_table_instance('tushare_concept').find({}))

gainian_map = {}  # pinyin "concept" map; not used by predict_today
hangye_map = {}   # pinyin "industry" map; not used by predict_today
Z_list = []  # self-selected (watchlist) picks
R_list = []  # picks whose code is in ROE_stock_list or zeng_stock_list
O_list = []  # other picks

# Shapes of the two grid-shaped model inputs, taken from the front of each row.
_GRID_A_SHAPE = (18, 9)
_GRID_B_SHAPE = (11, 14)
_OUTPUT_SUFFIX = '_week_119.txt'


def _load_samples(path):
    """Return one parsed row per non-blank line of *path*.

    SECURITY: each line is passed to eval(), which executes arbitrary code.
    Only use this on trusted, self-generated feature logs; if the rows are
    plain literals, ast.literal_eval would be a safer replacement.
    """
    samples = []
    with open(path) as f:
        for raw in f:
            raw = raw.strip()
            if raw:  # eval('') would raise SyntaxError on a blank line
                samples.append(eval(raw))
    return samples


def _split_inputs(features):
    """Slice an (n, size-1) feature matrix into the model's three inputs.

    Returns (x_c, x_a, x_b) in the order the model's predict() expects:
    x_a = first 18*9 columns reshaped to (n, 18, 9, 1),
    x_b = next 11*14 columns reshaped to (n, 11, 14, 1),
    x_c = all remaining columns.
    """
    n = features.shape[0]
    a_rows, a_cols = _GRID_A_SHAPE
    b_rows, b_cols = _GRID_B_SHAPE
    a_end = a_rows * a_cols
    b_end = a_end + b_rows * b_cols
    x_a = features[:, :a_end].reshape(n, a_rows, a_cols, 1)
    x_b = features[:, a_end:b_end].reshape(n, b_rows, b_cols, 1)
    x_c = features[:, b_end:]
    return x_c, x_a, x_b


def _build_concept_index():
    """Map each concept code to its names, preserving all_concept_code_list order."""
    index = {}
    for concept in all_concept_code_list:
        index.setdefault(concept['code'], []).append(concept['name'])
    return index


def _concept_names(ts_code, concept_index):
    """Return the concept names linked to *ts_code* in tushare_concept_detail."""
    names = []
    for link in stock_concept_table.find({'ts_code': ts_code}):
        names.extend(concept_index.get(link['concept_code'], ()))
    return names


def _find_stock(ts_code):
    """Look up *ts_code* in tushare_code; print a notice and return None if absent."""
    stock = code_table.find_one({'ts_code': ts_code})
    if stock is None:
        print(ts_code, 'not found in tushare_code, skipped')
    return stock


def predict_today(file, day, model='10_18d', log=True,
                  output_dir='D:\\data\\quantization\\predict'):
    """Score every row of *file* with a Keras model and report likely risers.

    Args:
        file: path to the feature log (one Python literal row per line).
        day: value used to name the output file, e.g. 20200410.
        model: path of the saved Keras model passed to load_model().
        log: when truthy, append each strong pick to
            ``<output_dir>/<day>_week_119.txt`` (UTF-8).
        output_dir: directory of that file.  The default is the original
            Windows location; the path is built with os.path.join.

    Side effects: prints picks, appends to the module-level R_list, and
    finally shuffles and prints O_list and R_list.
    """
    samples = _load_samples(file)
    if not samples:
        print('no samples in', file)
        return

    size = len(samples[0])
    net = load_model(model)  # don't shadow the `model` argument (a path)
    features = np.array([row[:size - 1] for row in samples])
    # One batched predict() call instead of one call per row.
    probs = net.predict(list(_split_inputs(features)))
    concept_index = _build_concept_index()
    log_path = os.path.join(output_dir, str(day) + _OUTPUT_SUFFIX)

    for row, p in zip(samples, probs):
        meta = row[-1]
        ts_code = meta[0]
        if p[0] > 0.85:
            # Skip codes starting with '688' (presumably the STAR board).
            if ts_code.startswith('688'):
                continue
            stock = _find_stock(ts_code)
            # Drop names starting with ST / N / * (presumably special-treatment
            # and newly listed stocks).
            if stock is None or stock['name'].startswith(('ST', 'N', '*')):
                continue
            if stock['ts_code'] in ROE_stock_list or stock['ts_code'] in zeng_stock_list:
                R_list.append([stock['ts_code'], stock['name']])
            print(stock['ts_code'], stock['name'], 'zhang10')
            if log:
                concepts = _concept_names(stock['ts_code'], concept_index)
                # str() (not an f-string) keeps numpy float32 printed as str() does.
                record = ' '.join([str(meta), stock['name'], str(stock['sw_industry']),
                                   str(concepts), str(p[0])])
                with open(log_path, mode='a', encoding="utf-8") as out:
                    out.write(record + '\n')
        elif p[1] > 0.7:
            stock = _find_stock(ts_code)
            if stock is not None:
                print(stock['ts_code'], stock['name'], 'zhang5')
        # Probabilities for the remaining classes are not acted on.

    random.shuffle(O_list)
    print(O_list[:3])
    random.shuffle(R_list)
    print('----ROE----')
    print(R_list[:3])


if __name__ == '__main__':
    # Strategy B
    # predict_today("D:\\data\\quantization\\stock505_28d_20200416.log", 20200416, model='505_28d_mix_5D_ma5_s_seq.h5', log=True)
    predict_today("D:\\data\\quantization\\week119_18d_20200403.log", 20200410, model='119_18d_mix_3W_s_seqA.h5', log=True)