123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- # -*- encoding:utf-8 -*-
- from sklearn.cluster import KMeans
- import numpy as np
- from annoy import AnnoyIndex
def read_data(path, limit=160000):
    """Load up to ``limit`` records from a file with one Python expression per line.

    Args:
        path: Path of the text file to read.
        limit: Maximum number of physical lines consumed from the start of
            the file. Defaults to 160000, the count that used to be
            hard-coded.

    Returns:
        A list holding one evaluated record per non-blank line, in file order.

    Reading stops quietly at end of file, so a file shorter than ``limit``
    yields the records it has instead of failing on the empty string that
    ``readline()`` returns at EOF.
    """
    records = []
    with open(path) as f:
        for line_no, raw in enumerate(f):
            if line_no >= limit:
                break
            text = raw.strip()
            if not text:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            # SECURITY: eval() executes whatever expression the line holds.
            # Only use this on files you produced yourself; if every record
            # is a plain literal, ast.literal_eval is the safe replacement.
            records.append(eval(text))
    return records
# Number of leading entries taken from every record when building its vector.
# The flattened vector (and the Annoy index) has ``length * 4`` components,
# which assumes each of those entries holds 4 values -- TODO confirm.
length = 20


def class_fic(file_path='', rebuild_index=False):
    """Read the record file and return the picks produced by ``find_annoy``.

    Args:
        file_path: Path of the record file handed to ``read_data``.
        rebuild_index: When true, flatten the first ``length`` entries of
            every record and rebuild the on-disk index through ``annoy_sim``
            before searching. When false (the default) the previously saved
            index is reused, and the feature matrix is not built at all.

    Returns:
        The list returned by ``find_annoy``, or an empty list when the file
        contains no records.
    """
    lines = read_data(file_path)
    print('读取数据完毕')
    if not lines:
        # Nothing to index or score; lines[0] below would raise IndexError.
        return []

    # Field positions are taken from the first record: the label is its last
    # field and the stock descriptor the one before it.
    size = len(lines[0])
    train_y = [s[size - 1] for s in lines]
    stock_list = [s[size - 2] for s in lines]

    if rebuild_index:
        train_x = np.array([s[:length] for s in lines])
        v_x = train_x.reshape(train_x.shape[0], 4 * length)
        annoy_sim(v_x)
    print('save数据完毕')
    return find_annoy(train_y, stock_list)
def annoy_sim(lines, n_trees=30, index_path='stock_20d.ann'):
    """Build an angular-distance Annoy index over ``lines`` and save it to disk.

    Args:
        lines: Iterable of vectors, each with ``length * 4`` components. The
            position of a vector in the iterable becomes its item id.
        n_trees: Number of trees passed to ``AnnoyIndex.build``.
        index_path: File the finished index is written to.

    Returns:
        None. The result is the file written at ``index_path``.
    """
    # The index dimension must equal the size of every vector added below.
    t = AnnoyIndex(length * 4, metric="angular")
    for item_id, vector in enumerate(lines):
        t.add_item(item_id, vector)
    t.build(n_trees)
    t.save(index_path)
def _label_weight(label):
    """Return the score one neighbour's label contributes: 1, 0.5 or 0."""
    if label[0] == 1 or label[1] == 1:
        return 1
    if label[2] == 1:
        return 0.5
    return 0


def find_annoy(lines, stock_list, index_path='stock_20d.ann'):
    """Score every item by the labels of its nearest neighbours in the saved index.

    For each item the 10 nearest neighbours are fetched from the Annoy index.
    Neighbours closer than 0.4 each contribute a weight derived from their
    label (see ``_label_weight``). An item is selected when it has more than
    one such neighbour and the mean weight exceeds 0.38.

    Args:
        lines: Per-item labels, indexable by Annoy item id. Each label is a
            sequence whose first three fields are compared with 1.
        stock_list: Per-item descriptors, parallel to ``lines``. Field ``[1]``
            is compared with 20181101 -- presumably a YYYYMMDD date; confirm
            against the data producer.
        index_path: File the Annoy index is loaded from.

    Returns:
        A list of ``[stock_list[i], lines[i]]`` pairs for the selected items
        whose ``stock_list[i][1]`` is greater than 20181101.
    """
    t = AnnoyIndex(length * 4, metric="angular")
    t.load(index_path)

    # NOTE(review): ``num`` is never incremented here; only a since-removed,
    # commented-out evaluation pass touched it. It is kept so the summary
    # line below prints the same two fields as before.
    num = 0
    right = 0
    win_dnn = []
    for i, label in enumerate(lines):
        index, distance = t.get_nns_by_item(i, 10, include_distances=True)

        total = 0
        g = 0
        # The first hit is skipped -- presumably because it is the query
        # item itself.
        for neighbor_id, dist in zip(index[1:], distance[1:]):
            if dist < 0.4:
                total += 1
                # Look the label up by the neighbour's item id. Using the
                # position inside the result list would always read the
                # labels of dataset rows 1..9 instead of the neighbours'.
                g += _label_weight(lines[neighbor_id])

        if total > 1 and g / total > 0.38:
            right += 1
            if stock_list[i][1] > 20181101:
                print(stock_list[i])
                win_dnn.append([stock_list[i], label])

    print(right, num)
    print('find数据完毕')
    return win_dnn
if __name__ == '__main__':
    # Script entry point: run the whole pipeline on the fixed local log file.
    data_file = "D:\\data\\quantization\\stock2_20.log"
    class_fic(file_path=data_file)
|