#!/usr/bin/python # -*- encoding:utf-8 -*- from sklearn import svm import numpy as np from sklearn.model_selection import GridSearchCV import matplotlib.pyplot as plt def read_data(path): with open(path) as f : lines=f.readlines() lines=[eval(line.strip()) for line in lines] X,y=zip(*lines) X=np.array(X) y=np.array(y) return X,y def drawScatterAndLine(p, q, support): x1 = [] x2 = [] y1 = [] y2 = [] x3 = [] y3 = [] x4 = [] y4 = [] for idx,i in enumerate(q): item = (p[idx][0], p[idx][1]) if item in support: is_supppot = True print(i, "support:", item) else: is_supppot = False if i == 0: if is_supppot: x3.append(p[idx][0]) y3.append(p[idx][1]) else: x1.append(p[idx][0]) y1.append(p[idx][1]) else: if is_supppot: x4.append(p[idx][0]) y4.append(p[idx][1]) else: x2.append(p[idx][0]) y2.append(p[idx][1]) plt.scatter(x1, y1) plt.scatter(x2, y2) plt.scatter(x3, y3, c="g") plt.scatter(x4, y4, c='black') plt.xlabel('p') plt.ylabel('q') plt.title('SVM') plt.show() X_train,y_train=read_data("train_data") X_test,y_test=read_data("test_data") # C对样本错误的容忍 # 默认核函数是rbf, 导致这个数据支撑向量不在图的边界点上 # 如果把核函数调成linear或者poly,支撑向量就在图的边界点上 model= svm.SVC(kernel='rbf') model.fit(X_train, y_train) print(model.support_vectors_) print(model.support_) #各类的支持向量在训练样本中的索引 print(len(model.support_)) drawScatterAndLine(X_train, y_train, model.support_vectors_) score = model.score(X_test,y_test) print(score) #网格搜索 # search_space = {'C': np.logspace(-3, 3, 7)} # print(search_space['C']) # model= svm.SVC() # gridsearch = GridSearchCV(model, param_grid=search_space) # gridsearch.fit(X_train,y_train) # cv_performance = gridsearch.best_score_ # test_performance = gridsearch.score(X_test, y_test) # print("C",gridsearch.best_params_['C'])