svm_train.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. #!/usr/bin/python
  2. # -*- encoding:utf-8 -*-
  3. from sklearn import svm
  4. import numpy as np
  5. from sklearn.model_selection import GridSearchCV
  6. import matplotlib.pyplot as plt
  7. def read_data(path):
  8. with open(path) as f :
  9. lines=f.readlines()
  10. lines=[eval(line.strip()) for line in lines]
  11. X,y=zip(*lines)
  12. X=np.array(X)
  13. y=np.array(y)
  14. return X,y
  15. def drawScatterAndLine(p, q, support):
  16. x1 = []
  17. x2 = []
  18. y1 = []
  19. y2 = []
  20. x3 = []
  21. y3 = []
  22. x4 = []
  23. y4 = []
  24. for idx,i in enumerate(q):
  25. item = (p[idx][0], p[idx][1])
  26. if item in support:
  27. is_supppot = True
  28. print(i, "support:", item)
  29. else:
  30. is_supppot = False
  31. if i == 0:
  32. if is_supppot:
  33. x3.append(p[idx][0])
  34. y3.append(p[idx][1])
  35. else:
  36. x1.append(p[idx][0])
  37. y1.append(p[idx][1])
  38. else:
  39. if is_supppot:
  40. x4.append(p[idx][0])
  41. y4.append(p[idx][1])
  42. else:
  43. x2.append(p[idx][0])
  44. y2.append(p[idx][1])
  45. plt.scatter(x1, y1)
  46. plt.scatter(x2, y2)
  47. plt.scatter(x3, y3, c="g")
  48. plt.scatter(x4, y4, c='black')
  49. plt.xlabel('p')
  50. plt.ylabel('q')
  51. plt.title('SVM')
  52. plt.show()
  53. X_train,y_train=read_data("train_data")
  54. X_test,y_test=read_data("test_data")
  55. # C对样本错误的容忍
  56. # 默认核函数是rbf, 导致这个数据支撑向量不在图的边界点上
  57. # 如果把核函数调成linear或者poly,支撑向量就在图的边界点上
  58. model= svm.SVC(kernel='rbf')
  59. model.fit(X_train, y_train)
  60. print(model.support_vectors_)
  61. print(model.support_) #各类的支持向量在训练样本中的索引
  62. print(len(model.support_))
  63. drawScatterAndLine(X_train, y_train, model.support_vectors_)
  64. score = model.score(X_test,y_test)
  65. print(score)
  66. #网格搜索
  67. # search_space = {'C': np.logspace(-3, 3, 7)}
  68. # print(search_space['C'])
  69. # model= svm.SVC()
  70. # gridsearch = GridSearchCV(model, param_grid=search_space)
  71. # gridsearch.fit(X_train,y_train)
  72. # cv_performance = gridsearch.best_score_
  73. # test_performance = gridsearch.score(X_test, y_test)
  74. # print("C",gridsearch.best_params_['C'])