kmeans.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # -*- encoding:utf-8 -*-
  2. from sklearn.cluster import KMeans
  3. import numpy as np
  4. from annoy import AnnoyIndex
  5. import joblib
  6. def read_data(path):
  7. lines = []
  8. with open(path) as f:
  9. for x in range(100000):
  10. line = eval(f.readline().strip())
  11. # if line[-1][0] == 1 or line[-1][1] == 1:
  12. lines.append(line)
  13. return lines
  14. length = 10
  15. def class_fic(file_path=''):
  16. lines = read_data(file_path)
  17. print('读取数据完毕')
  18. size = len(lines[0])
  19. train_x = np.array([s[:length] for s in lines])
  20. train_y = [s[size - 1] for s in lines]
  21. v_x = train_x.reshape(train_x.shape[0], 4*length)
  22. stock_list = [s[size - 2] for s in lines]
  23. estimator = KMeans(n_clusters=16, random_state=9)
  24. estimator.fit(v_x)
  25. label_pred = estimator.labels_ # 获取聚类标签
  26. centroids = estimator.cluster_centers_
  27. joblib.dump(estimator , 'km.pkl')
  28. print(estimator.predict(v_x[:10]))
  29. estimator = joblib.load('km.pkl')
  30. print(estimator.predict(v_x[10:20]))
  31. # annoy_sim(v_x)
  32. # print('save数据完毕')
  33. # return find_annoy(train_y, stock_list)
  34. def annoy_sim(lines):
  35. tree = 30
  36. t = AnnoyIndex(length*4, metric="angular") # 24是向量维度
  37. i = 0
  38. for stock in lines:
  39. t.add_item(i, stock)
  40. i = i + 1
  41. t.build(tree)
  42. t.save('stock_20d.ann')
  43. def find_annoy(lines, stock_list):
  44. t = AnnoyIndex(length*4, metric="angular")
  45. t.load('stock_20d.ann')
  46. num = 0
  47. right = 0
  48. win_dnn = []
  49. for i in range(len(lines)):
  50. index, distance = t.get_nns_by_item(i, 10, include_distances=True)
  51. # print(index, distance)
  52. # 预测
  53. total = 0
  54. g = 0
  55. for j in range(1, len(index)):
  56. if distance[j] < 0.4:
  57. total = total + 1
  58. if lines[j][0] == 1:
  59. g = g + 1
  60. elif lines[j][1] == 1:
  61. g = g + 1
  62. elif lines[j][2] == 1:
  63. g = g + 0.5
  64. if total > 1 and g / total > 0.38:
  65. right = right + 1
  66. if stock_list[i][1] > 20181101:
  67. print(stock_list[i])
  68. win_dnn.append([stock_list[i], lines[i]])
  69. # 计算
  70. # if lines[i][0] == 1:
  71. # g = 0
  72. # total = 0
  73. # for j in range(1,len(index)):
  74. # if distance[j] < 0.4:
  75. # total = total + 1
  76. # if lines[j][0] == 1:
  77. # g = g+1
  78. # elif lines[j][1] == 1:
  79. # g = g+1
  80. # if total > 1 and g/total > 0.21:
  81. # right = right + 1
  82. # if total > 1:
  83. # num = num + 1
  84. print(right, num)
  85. print('find数据完毕')
  86. return win_dnn
  87. if __name__ == '__main__':
  88. class_fic(file_path="D:\\data\\quantization\\stock2_10.log")