一、交叉驗證
在建立分類模型時,交叉驗證(Cross Validation)簡稱爲CV,CV是用來驗證分類器的性能。它的主體思想是將原始數據進行分組,一部分作爲訓練集,一部分作爲驗證集。利用訓練集訓練出模型,利用驗證集來測試模型,以評估分類模型的性能。
二、交叉驗證的作用
- 驗證分類器的性能
- 用於模型的選擇
三、交叉驗證常用的幾種方法
3.1 k折交叉驗證 K-fold Cross Validation(記爲K-CV)
1、將數據集平均分割成K個等份(參數cv值,一般選擇5折10折,即測試集爲20%)
2、使用1份數據作爲測試數據,其餘作爲訓練數據
3、計算測試準確率
4、使用不同的測試集,重複2、3步
5、對測試準確率做平均,作爲對未知數據預測準確率的估計
優點:
因爲每一個樣本數據既可以作爲測試集又可以作爲訓練集,可有效避免欠學習和過學習狀態的發生,得到的結果比較有說服力。
3.2 留一法交叉驗證 Leave-One-Out Cross Validation(記爲LOO-CV)
假設樣本數據集中有N個樣本數據。將每個樣本單獨作爲測試集,其餘N-1個樣本作爲訓練集,這樣得到了N個分類器或模型,用這N個分類器或模型的分類準確率的平均數作爲此分類器的性能指標。
優點:
a. 每一個分類器或模型幾乎所有的樣本都用來作爲訓練模型,因此最接近樣本,實驗評估可靠;
b. 實驗過程沒有隨機因素影響實驗結果,所以實驗結果可複製,因此實驗結果穩定。
缺點:
計算成本高,因爲需要建立的模型數量與樣本數據數量相同,當N很大時,計算相當耗時。
3.3 留p交叉驗證
留p驗證指訓練集上隨機選擇p個樣本作爲測試集,其餘作爲子訓練集。時間複雜度爲CpN,是階乘的複雜度,不可取。
3.4 重複隨機子抽樣驗證 Hold-Out Method
將數據集隨機劃分爲訓練集和測試集。對每一個劃分,用訓練集訓練分類器或模型,用測試集評估預測的精確度。進行多次劃分,用均值來表示效能。
優點:
與K值無關。嚴格意義來說Hold-Out Method不屬於交叉驗證方法,這種方法與k無關。
缺點:
驗證集結果準確率的高低和原始分組有很大關係,可能導致一些數據從未做過訓練或測試數據;而一些數據不止一次選爲訓練或測試數據的情況發生,因此結果不具有說服力。
四、交叉驗證函數
cross_val_score詳情可見官網
train_test_split
#導入
from sklearn.cross_validation import cross_val_score
from sklearn.cross_validation import train_test_split
五、代碼
例子:垃圾郵件分類
input:
from numpy import *
from sklearn import metrics
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import GaussianNB as NB
from sklearn.neighbors import KNeighborsClassifier as KNN
from sklearn.linear_model import LogisticRegression as LR
#將詞條合併爲一個列表
def createVocabList(dataSet):
vocabSet = set([]) #創建一個空集
for document in dataSet:
vocabSet = vocabSet | set(document) #創建兩個集合的並集
return list(vocabSet)
#將詞彙轉化爲向量
def bagOfWords2VecMN(vocabList, inputSet):
returnVec = [0]*len(vocabList) #初始化 詞彙等長的0向量
for word in inputSet:
if word in vocabList:
returnVec[vocabList.index(word)] += 1
return returnVec
#預處理 統一小寫,去除長度小於2個的詞彙
def textParse(bigString):
import re
listOfTokens = re.split(r'\W*', bigString)
return [tok.lower() for tok in listOfTokens if len(tok) > 2]
#統計詞頻前10
def calcMostFreq(vocabList,fullText):
import operator
freqDict = {}
for token in vocabList:
freqDict[token]=fullText.count(token)
sortedFreq = sorted(freqDict.items(), key=operator.itemgetter(1), reverse=True)
return sortedFreq[:10]
#讀取數據
def spamTest():
docList=[]; classList = []; fullText =[]
for i in range(1,26):
wordList = textParse(open('email/spam/%d.txt' % i).read())
docList.append(wordList)
fullText.extend(wordList)
classList.append(1)
wordList = textParse(open('email/ham/%d.txt' % i).read())
docList.append(wordList)
fullText.extend(wordList)
classList.append(0)
vocabList = createVocabList(docList) #創建詞列表
top10Words = calcMostFreq(vocabList,fullText) #刪除詞頻前10
for pairW in top10Words:
if pairW[0] in vocabList: vocabList.remove(pairW[0])
trainingSet = list(range(50)) #0-49,,50個數字,50封郵件
train_data = [] #存儲 所有訓練詞彙的向量
train_target = [] #存儲 類別標籤
for docIndex in trainingSet: #得到訓練數據的向量
train_data.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
train_target.append(classList[docIndex])
return train_data,train_target
5.1 cross_val_score
input:
from sklearn.cross_validation import cross_val_score
if __name__ == '__main__':
data = []
target = []
data, target = spamTest()
clf1 = KNN(n_neighbors=8)
clf2 = LR()
clf3 = NB()
#交叉驗證 cv:數據分成的份數,其中一份作爲cv集,其餘n-1作爲訓練集(默認爲3)
for clf,lable in zip([clf1, clf2, clf3],['KNN','LR','NB']):
scores = cross_val_score(clf,data,target,cv=5,scoring='accuracy')
#print(scores)
print("Accuracy:%0.2f (+/-%0.2f)[%s]"%(scores.mean(),scores.std(),lable)) #計算均值及標準差
output:
G:\Anacanda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match.
return _compile(pattern, flags).split(string, maxsplit)
Accuracy:0.64 (+/-0.05)[KNN]
Accuracy:0.94 (+/-0.05)[LR]
Accuracy:0.92 (+/-0.07)[NB]
5.2 train_test_split
input:
from sklearn.cross_validation import train_test_split
if __name__ == '__main__':
data = []
target = []
data, target = spamTest()
clf1 = KNN(n_neighbors=8)
clf2 = LR()
clf3 = NB()
'''
#交叉驗證 cv:數據分成的份數,其中一份作爲cv集,其餘n-1作爲訓練集(默認爲3)
for clf,lable in zip([clf1, clf2, clf3],['KNN','LR','NB']):
scores = cross_val_score(clf,data,target,cv=5,scoring='accuracy')
#print(scores)
print("Accuracy:%0.2f (+/-%0.2f)[%s]"%(scores.mean(),scores.std(),lable)) #計算均值及標準差
'''
#交叉驗證
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.2) #交叉驗證 20%選取測試集
clf = clf2.fit(x_train, y_train)
predicted = clf.predict(x_test)
expected = y_test
print(metrics.classification_report(expected, predicted))
print(metrics.confusion_matrix(expected, predicted))
print('Score:',accuracy_score(expected,predicted))
output:
G:\Anacanda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match.
return _compile(pattern, flags).split(string, maxsplit)
precision recall f1-score support
0 1.00 1.00 1.00 5
1 1.00 1.00 1.00 5
avg / total 1.00 1.00 1.00 10
[[5 0]
[0 5]]
Score: 1.0