kmeans.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # -*- encoding:utf-8 -*-
  2. from sklearn.cluster import KMeans
  3. import numpy as np
  4. from annoy import AnnoyIndex
  def read_data(path, max_lines=160000, encoding=None):
      """Parse up to *max_lines* lines of *path*, one Python literal per line.

      Each non-blank line is parsed with ``ast.literal_eval``. Unlike ``eval``,
      it cannot execute code embedded in the data file. Records must therefore
      be plain literals (lists, tuples, numbers, strings, ...); lines holding
      other expressions, such as function calls, are rejected.

      Reading stops at end of file, so a file with fewer than *max_lines*
      lines is read completely instead of failing on the empty string that
      ``readline`` returns past EOF.

      Args:
          path: Text file to read.
          max_lines: Maximum number of physical lines to read. Blank lines
              count towards the limit but produce no record. Defaults to the
              previously hard-coded 160000.
          encoding: Passed to ``open``; ``None`` keeps the platform default.

      Returns:
          List of parsed records, in file order.

      Raises:
          ValueError, SyntaxError: if a non-blank line is not a valid literal.
      """
      import ast
      from itertools import islice

      records = []
      with open(path, encoding=encoding) as f:
          for raw in islice(f, max_lines):
              text = raw.strip()
              if not text:
                  continue  # blank line (e.g. a trailing newline): no record
              records.append(ast.literal_eval(text))
      return records
  # Number of leading elements of each record used as the feature window.
  # Every feature vector built from it has 4 * length values (the reshape in
  # class_fic and the Annoy index dimension both use 4 * length). The index
  # file name "stock_20d.ann" suggests this means 20 days; confirm.
  length = 20


  def class_fic(file_path='', build_index=False):
      """Load samples from *file_path* and select them with a nearest-neighbour vote.

      Each record is split by position. The positions are taken from the first
      record, and all records are assumed to have the same length:

      * ``record[:length]``: feature window. It is assumed to hold ``length``
        items of 4 values each, so it flattens to ``4 * length`` numbers.
      * second-to-last element: stock metadata, passed to ``find_annoy``.
      * last element: label vector used for the neighbour vote.

      Args:
          file_path: Data file handed to ``read_data``.
          build_index: When True, flatten the feature windows, build and save
              the Annoy index from them with ``annoy_sim``, then query it.
              When False (the default, matching the previous behaviour), the
              feature vectors are not built and ``find_annoy`` loads an index
              that was saved earlier.

      Returns:
          Whatever ``find_annoy`` returns for these samples.

      Raises:
          ValueError: if *file_path* yields no records.
      """
      lines = read_data(file_path)
      print('读取数据完毕')
      if not lines:
          raise ValueError('no records read from %r' % (file_path,))
      size = len(lines[0])
      train_y = [s[size - 1] for s in lines]
      stock_list = [s[size - 2] for s in lines]
      if build_index:
          train_x = np.array([s[:length] for s in lines])
          # Flatten each feature window into one 4 * length vector for Annoy.
          v_x = train_x.reshape(train_x.shape[0], 4 * length)
          annoy_sim(v_x)
          print('save数据完毕')
      return find_annoy(train_y, stock_list)
  def annoy_sim(lines, n_trees=30, index_path='stock_20d.ann'):
      """Build an angular Annoy index over *lines* and save it to *index_path*.

      ``lines[i]`` is stored as item ``i``, so item ids follow the order of
      the input samples.

      Args:
          lines: Iterable of feature vectors, each with ``length * 4`` values.
          n_trees: Number of trees passed to ``AnnoyIndex.build``
              (previously hard-coded to 30). Per the Annoy documentation,
              more trees give more accurate results but a larger index.
          index_path: Output file (previously hard-coded to
              ``'stock_20d.ann'``).
      """
      # Vector dimension: length * 4 values per item (80 for length = 20).
      index = AnnoyIndex(length * 4, metric="angular")
      for item_id, vector in enumerate(lines):
          index.add_item(item_id, vector)
      index.build(n_trees)
      index.save(index_path)
  def _label_vote(label):
      """Return the vote one neighbour contributes, based on its label vector.

      The vote is 1 if ``label[0] == 1`` or ``label[1] == 1``, otherwise 0.5
      if ``label[2] == 1``, otherwise 0.
      """
      if label[0] == 1 or label[1] == 1:
          return 1
      if label[2] == 1:
          return 0.5
      return 0


  def find_annoy(lines, stock_list, index_path='stock_20d.ann', n_neighbors=10,
                 max_distance=0.4, vote_threshold=0.38, date_cutoff=20181101):
      """Select samples whose nearest neighbours in a saved Annoy index vote for them.

      Item ``i`` of the index is assumed to correspond to ``lines[i]`` and
      ``stock_list[i]``; ``annoy_sim`` adds vectors in input order. For each
      sample, up to *n_neighbors* neighbours are fetched. Every neighbour other
      than the sample itself whose angular distance is below *max_distance*
      votes with its own label (see ``_label_vote``). A sample with more than
      one voting neighbour is selected when the average vote exceeds
      *vote_threshold*.

      Selected samples whose ``stock_list[i][1]`` exceeds *date_cutoff* are
      printed. (``stock_list[i][1]`` is presumably a YYYYMMDD date stored as an
      int; confirm.) The function then prints the counts ``right`` (selected
      samples) and ``num`` (samples with enough close neighbours to vote).

      Args:
          lines: Label vectors, one per sample, indexed as ``label[0..2]``.
          stock_list: Per-sample metadata, parallel to *lines*.
          index_path: Annoy index file to load.
          n_neighbors: Neighbours requested per query, including the query
              item itself.
          max_distance: Only neighbours strictly closer than this vote.
          vote_threshold: Average vote a sample must exceed to be selected.
          date_cutoff: Threshold on ``stock_list[i][1]`` for printing.

      Returns:
          List of ``[stock_list[i], lines[i]]`` pairs for the selected samples.
      """
      index = AnnoyIndex(length * 4, metric="angular")
      index.load(index_path)
      num = 0    # samples with more than one close neighbour, i.e. able to vote
      right = 0  # samples whose average neighbour vote exceeds vote_threshold
      win_dnn = []
      for i, label in enumerate(lines):
          neighbor_ids, distances = index.get_nns_by_item(
              i, n_neighbors, include_distances=True)
          total = 0
          g = 0
          for neighbor, dist in zip(neighbor_ids, distances):
              # Skip the query item itself (normally its own nearest hit), and
              # skip neighbours that are too far away to vote.
              if neighbor == i or dist >= max_distance:
                  continue
              total += 1
              # Vote with the neighbour's own label. neighbor is an item id,
              # and item ids follow the order of *lines*.
              g += _label_vote(lines[neighbor])
          if total > 1:
              num += 1
              if g / total > vote_threshold:
                  right += 1
                  if stock_list[i][1] > date_cutoff:
                      print(stock_list[i])
                  win_dnn.append([stock_list[i], label])
      print(right, num)
      print('find数据完毕')
      return win_dnn
  if __name__ == "__main__":
      # Run the pipeline on a hard-coded Windows path to the input data file.
      input_path = "D:\\data\\quantization\\stock2_20.log"
      class_fic(file_path=input_path)