1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- #!/usr/bin/python
- # -*- encoding:utf-8 -*-
- from sklearn import svm
- import numpy as np
- from sklearn.model_selection import GridSearchCV
- import matplotlib.pyplot as plt
def read_data(path):
    """Load a dataset where each line is a Python literal ``(features, label)``.

    Every non-blank line is evaluated and must yield a two-item sequence:
    the feature vector and its label.

    Args:
        path: path of the text file to read (decoded as UTF-8, matching the
            encoding declared at the top of this script).

    Returns:
        A tuple ``(X, y)`` of NumPy arrays: ``X`` stacks the feature vectors
        and ``y`` holds the labels, both in file order.

    Raises:
        ValueError: if the file contains no data lines.
    """
    with open(path, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing empty line) are skipped: eval("") would
        # raise SyntaxError.
        # NOTE(review): eval() executes arbitrary code taken from the file. If
        # the data files may come from an untrusted source, switch to
        # ast.literal_eval, which only accepts Python literals.
        records = [eval(line) for line in (raw.strip() for raw in f) if line]
    if not records:
        raise ValueError(f"no data records found in {path!r}")
    X, y = zip(*records)
    return np.array(X), np.array(y)
def drawScatterAndLine(p, q, support):
    """Scatter-plot labelled samples and highlight the support vectors.

    Despite the name, only points are drawn; no decision line is plotted.
    Samples labelled ``0`` and samples with any other label are drawn as two
    default-coloured series. Support vectors from those two groups are drawn
    in green and black respectively. Each support vector found is printed,
    and the figure is displayed with ``plt.show()``.

    Args:
        p: sequence of samples; the first two coordinates of each sample are
            plotted as (x, y).
        q: labels parallel to ``p``.
        support: rows of support-vector coordinates, e.g. an SVC's
            ``support_vectors_``. A sample is highlighted when it equals one
            of these rows exactly.
    """
    # Exact row lookup. The previous ``tuple in ndarray`` test is evaluated by
    # NumPy element-wise as ``(support == item).any()``, so a sample sharing
    # just one coordinate with any support vector was wrongly flagged.
    support_rows = {tuple(row) for row in support}

    # (xs, ys) buckets keyed by (label is 0, sample is a support vector).
    buckets = {
        (True, False): ([], []),
        (False, False): ([], []),
        (True, True): ([], []),
        (False, True): ([], []),
    }
    for sample, label in zip(p, q):
        point = (sample[0], sample[1])
        is_support = tuple(sample) in support_rows
        if is_support:
            print(label, "support:", point)
        xs, ys = buckets[(bool(label == 0), is_support)]
        xs.append(point[0])
        ys.append(point[1])

    plt.scatter(*buckets[(True, False)])
    plt.scatter(*buckets[(False, False)])
    plt.scatter(*buckets[(True, True)], c="g")
    plt.scatter(*buckets[(False, True)], c="black")
    plt.xlabel('p')
    plt.ylabel('q')
    plt.title('SVM')
    plt.show()
X_train, y_train = read_data("train_data")
X_test, y_test = read_data("test_data")

# C sets the penalty for misclassified training samples (how much error is
# tolerated); it is left at its default here.
# The kernel is rbf (also SVC's default). With this data the rbf support
# vectors do not sit on the outer boundary points of the plot; with a
# linear or poly kernel they do.
model = svm.SVC(kernel="rbf").fit(X_train, y_train)

# Support-vector coordinates, their indices into the training set, and
# their count, each printed on its own line.
print(model.support_vectors_, model.support_, len(model.support_), sep="\n")

drawScatterAndLine(X_train, y_train, model.support_vectors_)

score = model.score(X_test, y_test)
print(score)

# Grid search over C (disabled):
# search_space = {'C': np.logspace(-3, 3, 7)}
# print(search_space['C'])
# model= svm.SVC()
# gridsearch = GridSearchCV(model, param_grid=search_space)
# gridsearch.fit(X_train,y_train)
# cv_performance = gridsearch.best_score_
# test_performance = gridsearch.score(X_test, y_test)
# print("C",gridsearch.best_params_['C'])
|