PYthon 教你怎麼選擇SVM的核函數kernel及案例分析

關注微信公共號:小程在線

關注CSDN博客:程志偉的博客

4種核函數的適用場景

接上文可以選在非線性核函數,可以將數據明顯的區別開

clf = SVC(kernel = "rbf").fit(X,y)
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
plot_svc_decision_function(clf)
H:\Anaconda3\lib\site-packages\sklearn\svm\base.py:193: FutureWarning: The default value of gamma will change from 'auto' to 'scale' in version 0.22 to account better for unscaled features. Set gamma explicitly to 'auto' or 'scale' to avoid this warning.
  "avoid this warning.", FutureWarning)

 

#################探索核函數在不同數據集上的表現################
1. 導入所需要的庫和模塊
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import svm
from sklearn.datasets import make_circles, make_moons, make_blobs,make_classification

 

2. 創建數據集,定義核函數的選擇
n_samples = 100
datasets = [
        make_moons(n_samples=n_samples, noise=0.2, random_state=0),
        make_circles(n_samples=n_samples, noise=0.2, factor=0.5, random_state=1),
        make_blobs(n_samples=n_samples, centers=2, random_state=5),
        make_classification(n_samples=n_samples,n_features =
                            2,n_informative=2,n_redundant=0, random_state=5)
        ]

Kernel = ["linear","poly","rbf","sigmoid"]

for X,Y in datasets:
    plt.figure(figsize=(5,4))
    plt.scatter(X[:,0],X[:,1],c=Y,s=50,cmap="rainbow")

以上4張圖分別是月牙形,環形,雜亂性,對稱形

 

 

我們總共有四個數據集,四種核函數,我們希望觀察每種數據集下每個核函數的表現。以核函數爲列,以圖像分佈
爲行,我們總共需要16個子圖來展示分類結果。而同時,我們還希望觀察圖像本身的狀況,所以我們總共需要20
個子圖,其中第一列是原始圖像分佈,後面四列分別是這種分佈下不同核函數的表現

3. 構建子圖

nrows=len(datasets)
ncols=len(Kernel) + 1
fig, axes = plt.subplots(nrows, ncols,figsize=(20,16))

 

4. 開始進行子圖循環
#第一層循環:在不同的數據集中循環
for ds_cnt, (X,Y) in enumerate(datasets):
    #在圖像中的第一列,第一個,放置原數據的分佈
    #zorder=10表示畫布的層級,edgecolors表示邊緣額顏色
    ax = axes[ds_cnt, 0]
    if ds_cnt == 0:
        ax.set_title("Input data")
    ax.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired,edgecolors='k')
    ax.set_xticks(())
    ax.set_yticks(())
    #第二層循環:在不同的核函數中循環
    #從圖像的第二列開始,一個個填充分類結果
    for est_idx, kernel in enumerate(Kernel):
        #定義子圖位置,從第一列,第二個開始
        ax = axes[ds_cnt, est_idx + 1]
        #建模
        clf = svm.SVC(kernel=kernel, gamma=2).fit(X, Y)
        score = clf.score(X, Y)
        
        #繪製圖像本身分佈的散點圖
        ax.scatter(X[:, 0], X[:, 1], c=Y
                   ,zorder=10
                   ,cmap=plt.cm.Paired,edgecolors='k')
        
        #繪製支持向量
        ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=50,
                   facecolors='none', zorder=10, edgecolors='k')
        
        #繪製決策邊界
        x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
        y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
        
        #np.mgrid,合併了我們之前使用的np.linspace和np.meshgrid的用法
        #一次性使用最大值和最小值來生成網格
        #表示爲[起始值:結束值:步長]
        #如果步長是複數,則其整數部分就是起始值和結束值之間創建的點的數量,並且結束值被包含在內
        XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
        
        #np.c_,類似於np.vstack的功能
        Z = clf.decision_function(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
        
        #填充等高線不同區域的顏色
        ax.pcolormesh(XX, YY, Z > 0, cmap=plt.cm.Paired)
        
        #繪製等高線
        ax.contour(XX, YY, Z, colors=['k', 'k', 'k'], linestyles=['--', '-', '--'],
                   levels=[-1, 0, 1])
        
        #設定座標軸爲不顯示
        ax.set_xticks(())
        ax.set_yticks(())
        
        #將標題放在第一行的頂上
        if ds_cnt == 0:
            ax.set_title(kernel)
        
        #爲每張圖添加分類的分數
        ax.text(0.95, 0.06, ('%.2f' % score).lstrip('0')
                , size=15
                , bbox=dict(boxstyle='round', alpha=0.8, facecolor='white')
                #爲分數添加一個白色的格子作爲底色
                , transform=ax.transAxes #確定文字所對應的座標軸,就是ax子圖的座標軸本身
                , horizontalalignment='right' #位於座標軸的什麼方向
                )

plt.tight_layout()
plt.show()
__main__:53: UserWarning: No contour levels were found within the data range.

可以觀察到,線性核函數和多項式核函數在非線性數據上表現會浮動,如果數據相對線性可分,則表現不錯,如果
是像環形數據那樣徹底不可分的,則表現糟糕。在線性數據集上,線性核函數和多項式核函數即便有擾動項也可以
表現不錯,可見多項式核函數是雖然也可以處理非線性情況,但更偏向於線性的功能。
Sigmoid核函數就比較尷尬了,它在非線性數據上強於兩個線性核函數,但效果明顯不如rbf,它在線性數據上完全
比不上線性的核函數們,對擾動項的抵抗也比較弱,所以它功能比較弱小,很少被用到。rbf,高斯徑向基核函數基本在任何數據集上都表現不錯,屬於比較萬能的核函數。

 

#########################探索核函數的優勢和缺陷###########################

通過繪製SVC在不同核函數下的決策邊界並計算SVC在不同核函數下分類準確率來觀察覈函數的效用
1. 導入所需要的庫和模塊
from sklearn.datasets import load_breast_cancer
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
from time import time
import datetime

data = load_breast_cancer()
X = data.data
y = data.target

X.shape
Out[3]: (569, 30)

np.unique(y)
Out[4]: array([0, 1])

plt.scatter(X[:,0],X[:,1],c=y)
plt.show()

from sklearn.decomposition import PCA
X_dr = PCA(2).fit_transform(X)
X_dr.shape
Out[6]: (569, 2)

plt.scatter(X_dr[:,0],X_dr[:,1],c=y)
plt.show()

Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3,random_state=420)

#下面的代碼運行不出來
Kernel = ["linear","poly","rbf","sigmoid"]
for kernel in Kernel:
    time0 = time()
    clf= SVC(kernel = kernel
             , gamma="auto"
             # , degree = 1
             , cache_size=5000  #內存
             ).fit(Xtrain,Ytrain)
    print("The accuracy under kernel %s is %f" % (kernel,clf.score(Xtest,Ytest)))
    print(datetime.datetime.fromtimestamp(time()-time0).strftime("%M:%S:%f"))
模型一直停留在線性核函數之後,就沒有再打印結果了。這證明,多項式核函數此時此刻要消耗大量的時間,運算非常的緩慢
 

#時間戳

time()
Out[9]: 1585731238.5509906

now = time()

datetime.datetime.fromtimestamp(now).strftime("%Y-%m-%d,%H:%M:%S:%f")
Out[10]: '2020-04-01,16:56:35:263156'

 

在循環中去掉多項式核函數

Kernel = ["linear","rbf","sigmoid"]
for kernel in Kernel:
    time0 = time()
    clf= SVC(kernel = kernel
             , gamma="auto"
             # , degree = 1
             , cache_size=5000  #內存
             ).fit(Xtrain,Ytrain)
    print("The accuracy under kernel %s is %f" % (kernel,clf.score(Xtest,Ytest)))
    print(datetime.datetime.fromtimestamp(time()-time0).strftime("%M:%S:%f"))
The accuracy under kernel linear is 0.929825
00:00:926657
The accuracy under kernel rbf is 0.596491
00:00:084060
The accuracy under kernel sigmoid is 0.596491
00:00:010509

 

乳腺癌數據集是一個線性數據集,線性核函數跑出來的效果很好。rbf和sigmoid兩個擅長非線性的數據從效果上來看完全不可用。其次,線性核函數的運行速度遠遠不如非線性的兩個核函數。如果數據是線性的,那如果我們把degree參數調整爲1,多項式核函數應該也可以得到不錯的結果。

Kernel = ["linear","poly","rbf","sigmoid"]
for kernel in Kernel:
    time0 = time()
    clf= SVC(kernel = kernel
             , gamma="auto"
             , degree = 1
             , cache_size=5000
             ).fit(Xtrain,Ytrain)
    print("The accuracy under kernel %s is %f" % (kernel,clf.score(Xtest,Ytest)))
    print(datetime.datetime.fromtimestamp(time()-time0).strftime("%M:%S:%f"))
The accuracy under kernel linear is 0.929825
00:00:823586
The accuracy under kernel poly is 0.923977
00:00:157116
The accuracy under kernel rbf is 0.596491
00:00:078048
The accuracy under kernel sigmoid is 0.596491
00:00:010008

 

多項式核函數的運行速度立刻加快了,並且精度也提升到了接近線性核函數的水平,rbf在線性數據上也可以表現得非常好,那在這裏,爲什麼跑出來的結果如此糟糕呢?其實,這裏真正的問題是數據的量綱問題。回憶一下我們如何求解決策邊界,如何判斷點是否在決策邊界的一邊?是靠計算”距離“,雖然我們不能說SVM是完全的距離類模型,但是它嚴重受到數據量綱的影響。讓我們來探索一下乳腺癌數據集的量綱

import pandas as pd
data = pd.DataFrame(X)
data.describe([0.01,0.05,0.1,0.25,0.5,0.75,0.9,0.99]).T
Out[13]: 
    count        mean         std  ...          90%          99%         max
0   569.0   14.127292    3.524049  ...    19.530000    24.371600    28.11000
1   569.0   19.289649    4.301036  ...    24.992000    30.652000    39.28000
2   569.0   91.969033   24.298981  ...   129.100000   165.724000   188.50000
3   569.0  654.889104  351.914129  ...  1177.400000  1786.600000  2501.00000
4   569.0    0.096360    0.014064  ...     0.114820     0.132888     0.16340
5   569.0    0.104341    0.052813  ...     0.175460     0.277192     0.34540
6   569.0    0.088799    0.079720  ...     0.203040     0.351688     0.42680
7   569.0    0.048919    0.038803  ...     0.100420     0.164208     0.20120
8   569.0    0.181162    0.027414  ...     0.214940     0.259564     0.30400
9   569.0    0.062798    0.007060  ...     0.072266     0.085438     0.09744
10  569.0    0.405172    0.277313  ...     0.748880     1.291320     2.87300
11  569.0    1.216853    0.551648  ...     1.909400     2.915440     4.88500
12  569.0    2.866059    2.021855  ...     5.123200     9.690040    21.98000
13  569.0   40.337079   45.491006  ...    91.314000   177.684000   542.20000
14  569.0    0.007041    0.003003  ...     0.010410     0.017258     0.03113
15  569.0    0.025478    0.017908  ...     0.047602     0.089872     0.13540
16  569.0    0.031894    0.030186  ...     0.058520     0.122292     0.39600
17  569.0    0.011796    0.006170  ...     0.018688     0.031194     0.05279
18  569.0    0.020542    0.008266  ...     0.030120     0.052208     0.07895
19  569.0    0.003795    0.002646  ...     0.006185     0.012650     0.02984
20  569.0   16.269190    4.833242  ...    23.682000    30.762800    36.04000
21  569.0   25.677223    6.146258  ...    33.646000    41.802400    49.54000
22  569.0  107.261213   33.602542  ...   157.740000   208.304000   251.20000
23  569.0  880.583128  569.356993  ...  1673.000000  2918.160000  4254.00000
24  569.0    0.132369    0.022832  ...     0.161480     0.188908     0.22260
25  569.0    0.254265    0.157336  ...     0.447840     0.778644     1.05800
26  569.0    0.272188    0.208624  ...     0.571320     0.902380     1.25200
27  569.0    0.114606    0.065732  ...     0.208940     0.269216     0.29100
28  569.0    0.290076    0.061867  ...     0.360080     0.486908     0.66380
29  569.0    0.083946    0.018061  ...     0.106320     0.140628     0.20750

[30 rows x 13 columns]

 

數據存在嚴重的量綱不一的問題。我們來使用數據預處理中的標準化的類,對數據進行標準化
from sklearn.preprocessing import StandardScaler
X = StandardScaler().fit_transform(X)
data = pd.DataFrame(X)
data.describe([0.01,0.05,0.1,0.25,0.5,0.75,0.9,0.99]).T
Out[14]: 
    count          mean      std  ...       90%       99%        max
0   569.0 -3.162867e-15  1.00088  ...  1.534446  2.909529   3.971288
1   569.0 -6.530609e-15  1.00088  ...  1.326975  2.644095   4.651889
2   569.0 -7.078891e-16  1.00088  ...  1.529432  3.037982   3.976130
3   569.0 -8.799835e-16  1.00088  ...  1.486075  3.218702   5.250529
4   569.0  6.132177e-15  1.00088  ...  1.313694  2.599511   4.770911
5   569.0 -1.120369e-15  1.00088  ...  1.347811  3.275782   4.568425
6   569.0 -4.421380e-16  1.00088  ...  1.434288  3.300560   4.243589
7   569.0  9.732500e-16  1.00088  ...  1.328412  2.973759   3.927930
8   569.0 -1.971670e-15  1.00088  ...  1.233221  2.862418   4.484751
9   569.0 -1.453631e-15  1.00088  ...  1.342243  3.209454   4.910919
10  569.0 -9.076415e-16  1.00088  ...  1.240514  3.198294   8.906909
11  569.0 -8.853492e-16  1.00088  ...  1.256518  3.081820   6.655279
12  569.0  1.773674e-15  1.00088  ...  1.117354  3.378079   9.461986
13  569.0 -8.291551e-16  1.00088  ...  1.121579  3.021867  11.041842
14  569.0 -7.541809e-16  1.00088  ...  1.123053  3.405812   8.029999
15  569.0 -3.921877e-16  1.00088  ...  1.236492  3.598943   6.143482
16  569.0  7.917900e-16  1.00088  ...  0.882848  2.997338  12.072680
17  569.0 -2.739461e-16  1.00088  ...  1.117927  3.146456   6.649601
18  569.0 -3.108234e-16  1.00088  ...  1.159654  3.834036   7.071917
19  569.0 -3.366766e-16  1.00088  ...  0.904208  3.349301   9.851593
20  569.0 -2.333224e-15  1.00088  ...  1.535063  3.001373   4.094189
21  569.0  1.763674e-15  1.00088  ...  1.297666  2.625885   3.885905
22  569.0 -1.198026e-15  1.00088  ...  1.503553  3.009644   4.287337
23  569.0  5.049661e-16  1.00088  ...  1.393000  3.581882   5.930172
24  569.0 -5.213170e-15  1.00088  ...  1.276124  2.478455   3.955374
25  569.0 -2.174788e-15  1.00088  ...  1.231407  3.335783   5.112877
26  569.0  6.856456e-16  1.00088  ...  1.435090  3.023359   4.700669
27  569.0 -1.412656e-16  1.00088  ...  1.436382  2.354181   2.685877
28  569.0 -2.289567e-15  1.00088  ...  1.132518  3.184317   6.046041
29  569.0  2.575171e-15  1.00088  ...  1.239884  3.141089   6.846856

[30 rows x 13 columns]

 

再次運行核函數

Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3,random_state=420)
Kernel = ["linear","poly","rbf","sigmoid"]
for kernel in Kernel:
    time0 = time()
    clf= SVC(kernel = kernel
             , gamma="auto"
             , degree = 1
             , cache_size=5000
             ).fit(Xtrain,Ytrain)
    print("The accuracy under kernel %s is %f" % (kernel,clf.score(Xtest,Ytest)))
    print(datetime.datetime.fromtimestamp(time()-time0).strftime("%M:%S:%f"))
The accuracy under kernel linear is 0.976608
00:00:016012
The accuracy under kernel poly is 0.964912
00:00:006004
The accuracy under kernel rbf is 0.970760
00:00:013005
The accuracy under kernel sigmoid is 0.953216
00:00:007990

量綱統一之後,可以觀察到,所有核函數的運算時間都大大地減少了,尤其是對於線性核來說,而多項式核函數居
然變成了計算最快的。其次,rbf表現出了非常優秀的結果。經過我們的探索,我們可以得到的結論是:
1. 線性核,尤其是多項式核函數在高次項時計算非常緩慢
2. rbf和多項式核函數都不擅長處理量綱不統一的數據集

 

 選取與核函數相關的參數:degree & gamma & coef0
對於高斯徑向基核函數,調整gamma的方式其實比較容易,那就是畫學習曲線。我們來試試看高斯徑向基核函數
rbf的參數gamma在乳腺癌數據集上的表現

score = []
gamma_range = np.logspace(-10, 1, 50) #返回在對數刻度上均勻間隔的數字
for i in gamma_range:
    clf = SVC(kernel="rbf",gamma = i,cache_size=5000).fit(Xtrain,Ytrain)
    score.append(clf.score(Xtest,Ytest))

print(max(score), gamma_range[score.index(max(score))])
plt.plot(gamma_range,score)
plt.show()
0.9766081871345029 0.012067926406393264

通過學習曲線,很容就找出了rbf的最佳gamma值。但我們觀察到,這其實與線性核函數的準確率一模一樣之前的
準確率。我們可以多次調整gamma_range來觀察結果,可以發現97.6608應該是rbf核函數的極限了。


但對於多項式核函數來說,一切就沒有那麼容易了,因爲三個參數共同作用在一個數學公式上影響它的效果,因此
我們往往使用網格搜索來共同調整三個對多項式核函數有影響的參數。依然使用乳腺癌數據集

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import GridSearchCV

time0 = time()
gamma_range = np.logspace(-10,1,20)
coef0_range = np.linspace(0,5,10)
param_grid = dict(gamma = gamma_range
                  ,coef0 = coef0_range)

cv = StratifiedShuffleSplit(n_splits=5, test_size=0.3, random_state=420)
grid = GridSearchCV(SVC(kernel = "poly",degree=1,cache_size=5000),
param_grid=param_grid, cv=cv)
grid.fit(X, y)
print("The best parameters are %s with a score of %0.5f" % (grid.best_params_,grid.best_score_))
print(datetime.datetime.fromtimestamp(time()-time0).strftime("%M:%S:%f"))
The best parameters are {'coef0': 0.0, 'gamma': 0.18329807108324375} with a score of 0.96959
00:16:152746

網格搜索爲我們返回了參數coef0=0,gamma=0.18329807108324375,但整體的分數是0.96959,雖然比調參前略有提高,但依然沒有超過線性核函數核rbf的結果。可見,如果最初選擇核函數的時候,就發現多項式的結果不如rbf和線性核函數

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章