Spark2.0機器學習系列之7:多類分類問題(方法歸總和分類結果評估)

一對多(One-vs-Rest classifier)

將只能用於二分問題的分類(如Logistic迴歸、SVM)方法擴展到多類。

參考:http://www.cnblogs.com/CheeseZH/p/5265959.html

“一對多”方法

    訓練時依次把某個類別的樣本歸爲一類,其他剩餘的樣本歸爲另一類,這樣k個類別的樣
    本就構造出了k個binary分類器。分類時將未知樣本分類爲具有最大分類函數值的那
    類。
    假如我有四類要劃分(也就是4個Label),他們是A、B、C、D。
      於是我在抽取訓練集的時候,分別抽取
      (1)A所對應的向量作爲正集,B,C,D所對應的向量作爲負集;
      (2)B所對應的向量作爲正集,A,C,D所對應的向量作爲負集;
      (3)C所對應的向量作爲正集,A,B,D所對應的向量作爲負集;
      (4)D所對應的向量作爲正集,A,B,C所對應的向量作爲負集;
      使用這四個訓練集分別進行訓練,然後的得到四個訓練結果文件。
      在測試的時候,把對應的測試向量分別利用這四個訓練結果文件進行測試。
      最後每個測試都有一個結果f1(x),f2(x),f3(x),f4(x)。
      於是最終的結果便是這四個值中最大的一個作爲分類結果。

  這種方法有種缺陷,因爲訓練集是1:M ,這種情況下存在biased(即正負樣本數可能很不均衡)。

另外還有“一對一”方法,Spark2.0中還沒有實現。
其做法是在任意兩類樣本之間設計一個分類器,因此k個類別的樣本就需要設計k(k-1)/2個SVM。
當對一個未知樣本進行分類時,最後得票最多的類別即爲該未知樣本的類別。
Libsvm中的多類分類就是根據這個方法實現的。
主要缺點:當類別很多的時候,model的個數是n*(n-1)/2,代價還是相當大的。(是不是不適合大數據集?)

Spark “一對多”代碼

//定義一個binary分類器,如:LogisticRegression 
LogisticRegression lr=new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.3)
                .setElasticNetParam(0.2)                
                .setThreshold(0.5);
//建立一對多多分類器model                
OneVsRestModel model=new OneVsRest()
                .setClassifier(lr)//將binary分類器用這種辦法加入
                .fit(training);
//利用多分類器model預測
Dataset<Row>predictions=model.transform(test);  

Spark中那些方法可以用於多類分類

多類分類結果評估

(MulticlassClassificationEvaluator類)
在前面一篇文章裏面介紹的關於二分問題的評估方法,部分評估方法可以延伸到多類分類爲問題。這些概念可以參考
下面的文章:
http://blog.csdn.net/qq_34531825/article/details/52313553
Spark中多分類分類基於JavaRDD的評估方法如下:
Precision,Recall,F-measure都有按照不同label分別評價,或者加權總體評價。
這裏寫圖片描述
但是目前基於DataFrame的好像還沒有這麼多,沒有介紹文檔
通過explainParams函數打印出來就四種:

System.out.println(evaluator.explainParams());
metricName: metric name in evaluation (f1|weightedPrecision|weightedRecall|accuracy) 

使用方法如下:

MultilayerPerceptronClassificationModel model=
                multilayerPerceptronClassifier.fit(training);   


Dataset<Row> predictions=model.transform(test);     
MulticlassClassificationEvaluator evaluator=
        new MulticlassClassificationEvaluator()
        .setLabelCol("label")               
        .setPredictionCol("prediction");


//System.out.println(evaluator.explainParams());
double accuracy =evaluator.setMetricName("accuracy").evaluate(predictions);
double weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions);
double weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions);
double f1=evaluator.setMetricName("f1").evaluate(predictions);      
發佈了42 篇原創文章 · 獲贊 57 · 訪問量 30萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章