kmeans.py

# -*- encoding:utf-8 -*-
from sklearn.cluster import KMeans
import numpy as np
# from annoy import AnnoyIndex  # required by the Annoy-based helpers (annoy_sim / find_annoy) below
import joblib


def read_data(path):
    lines = []
    with open(path) as f:
        for x in range(20000):
            line = eval(f.readline().strip())  # each log line is a Python literal
            # if line[-1][0] == 1 or line[-1][1] == 1:
            lines.append(line)
    return lines


length = 18  # number of periods per sample
j = 20       # field stride between consecutive periods within a record
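
# A minimal sketch, assuming each record is a flat list with j fields per
# period and the 4 features of interest at offsets 5..8 of every period block
# (the synthetic record below is purely illustrative, not the real log format):
def _slice_demo():
    record = list(range(length * j + 2))         # hypothetical flat record
    features = []
    for x in range(length):
        features += record[x * j + 5:x * j + 9]  # same slicing as class_fic()
    assert len(features) == 4 * length           # 72-dimensional feature vector
    return features
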
def class_fic(file_path=''):
    lines = read_data(file_path)
    print('finished reading data')
    size = len(lines[0])
    x_list = []
    for s in lines:
        tmp_list = []
        for x in range(0, length):
            tmp_list = tmp_list + s[x*j+5:x*j+9]  # 4 feature fields per period
        x_list.append(tmp_list)
    train_x = np.array(x_list)
    # (an earlier, commented-out version built train_x by writing the 18 period
    #  slices out explicitly instead of using the loop above)
    # train_y = [s[size - 1] for s in lines]
    v_x = train_x.reshape(train_x.shape[0], 4*length)
    stock_list = [s[size - 2] for s in lines]
    estimator = KMeans(n_clusters=12, random_state=129)
    estimator.fit(v_x)
    label_pred = estimator.labels_          # cluster label of each sample
    centroids = estimator.cluster_centers_
    joblib.dump(estimator, 'km_dmi_18.pkl')
    print(estimator.predict(v_x[:10]))
    estimator = joblib.load('km_dmi_18.pkl')  # reload as a sanity check
    print(estimator.predict(v_x[10:20]))
    # annoy_sim(v_x)
    # print('finished saving data')
    # return find_annoy(train_y, stock_list)
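
# A minimal reuse sketch, assuming km_dmi_18.pkl was already written by
# class_fic() and that `rows` (a hypothetical placeholder) is an array of shape
# (n_samples, 4 * length) built with the same per-period slicing:
def predict_clusters(rows):
    km = joblib.load('km_dmi_18.pkl')        # reload the fitted KMeans model
    return km.predict(np.asarray(rows))      # one cluster id per feature row
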
def annoy_sim(lines):
    tree = 30
    t = AnnoyIndex(length*4, metric="angular")  # vector dimension is length*4 = 72
    i = 0
    for stock in lines:
        t.add_item(i, stock)
        i = i + 1
    t.build(tree)
    t.save('stock_20d.ann')
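
# A minimal query sketch, assuming stock_20d.ann was built by annoy_sim() and
# that `vec` (a hypothetical input) is a single 72-dimensional feature vector;
# get_nns_by_vector lets the query vector be one that was never indexed:
def nearest_items(vec, n=10):
    t = AnnoyIndex(length * 4, metric="angular")
    t.load('stock_20d.ann')                  # memory-map the saved index
    return t.get_nns_by_vector(vec, n, include_distances=True)
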
def find_annoy(lines, stock_list):
    t = AnnoyIndex(length*4, metric="angular")
    t.load('stock_20d.ann')
    num = 0
    right = 0
    win_dnn = []
    for i in range(len(lines)):
        index, distance = t.get_nns_by_item(i, 10, include_distances=True)
        # print(index, distance)
        # prediction: vote over the close neighbours
        total = 0
        g = 0
        for k in range(1, len(index)):
            if distance[k] < 0.4:
                total = total + 1
                if lines[index[k]][0] == 1:
                    g = g + 1
                elif lines[index[k]][1] == 1:
                    g = g + 1
                elif lines[index[k]][2] == 1:
                    g = g + 0.5
        if total > 1 and g / total > 0.38:
            right = right + 1
            if stock_list[i][1] > 20181101:
                print(stock_list[i])
                win_dnn.append([stock_list[i], lines[i]])
        # calculation
        # (a commented-out variant scored only samples with lines[i][0] == 1,
        #  using the same neighbour vote with a 0.21 threshold and incrementing
        #  num whenever total > 1; with it disabled, num stays 0 below)
    print(right, num)
    print('finished find')
    return win_dnn
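
# A minimal driver sketch for the Annoy path, following the calls commented out
# at the end of class_fic(); v_x, train_y and stock_list are assumed to be the
# feature matrix, label vectors and stock identifiers built there:
def run_annoy_pipeline(v_x, train_y, stock_list):
    annoy_sim(v_x)                           # writes stock_20d.ann
    return find_annoy(train_y, stock_list)   # neighbour-vote evaluation
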
if __name__ == '__main__':
    # class_fic(file_path="D:\\data\\quantization\\stock2_10.log")
    class_fic(file_path="D:\\data\\quantization\\stock9_18.log")