sparkmllib交替最小二乘法

http://spark.apache.org/docs/2.2.0/ml-collaborative-filtering.html

不需要用戶和商品屬性的信息,這類算法通常稱爲協同過濾算法

例子:根據兩個用戶的年齡相同來判斷他們可能有相似的偏好,這不叫協同過濾。相反,根據兩個用戶播放過許多相同歌曲來判斷他們可能都喜歡某首歌,這才叫協同過濾。

SparkMLlib 的ALS算法 要求用戶和產品ID必須是數值型,這意味着大於Integer.MAX_VALUE(2147483647)的值都是非法的。

訓練出的模型可以保存到文件,還可以從文件load模型

package test
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.ml.recommendation.{ALS, ALSModel}

import org.apache.spark.ml.Model
/**
  * Created by othc on 2018-01-19.
  */
object ALS1 {
  case class Rating(userId:Int,artistId:Int,count:Float)
  def main(args: Array[String]): Unit = {
    //session
    val spark = SparkSession.builder().config("spark.sql.warehouse.dir","/usr/local/testdata/spark-warehouse").appName("als").getOrCreate()
    import spark.implicits._
    //用戶id 藝術家id 次數
    val rawUserArtstData: Dataset[String] = spark.read.textFile("/usr/local/mldata/user_artist_data.txt")
    //藝術家id 名字
    val rawArtistData =  spark.read.textFile("/usr/local/mldata/artist_data.txt")
    val artistById = rawArtistData.flatMap(line => {
      val (id, name) = line.span(_ != '\t')
      if (name.isEmpty) {
        None
      } else {
        try {
          Some((id.toInt, name.trim))
        } catch {
          case e: NumberFormatException => None
        }
      }
    })
    //將錯誤的藝術家id或不標準的id 映射成藝術家正規的名字
    val rawArtistAlias = spark.read.textFile("/usr/local/mldata/artist_alias.txt")
    val artistAlias = rawArtistAlias.flatMap(line=>{
      val tokens = line.split("\t")
      if(tokens(0).isEmpty){
        None
      }else{
        Some((tokens(0).toInt,tokens(1).toInt))
      }
    }).rdd.collectAsMap()
    //將map變量廣播
    val bArtistAlias = spark.sparkContext.broadcast(artistAlias)

    val trainData = rawUserArtstData.map(line=>{
      val Array(userId,artistId,count) = line.split(" ").map(_.toInt)
      val finalArtistId= bArtistAlias.value.getOrElse(artistId,artistId)
      Rating(userId,finalArtistId,count.toFloat)
    }).toDF().cache()
    val Array(train,test) = trainData.randomSplit(Array(0.8,0.2))
    val als: ALS = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("artistId").setRatingCol("count")

    val model: ALSModel = als.fit(train)
    //去掉userid或artistId 是NAN的
    model.setColdStartStrategy("drop")
//    保存模型
//    model.save("")
//    //加載模型
//    import org.apache.spark.ml.recommendation.ALS._
//    val load1: ALS = load("")

    val predictions: DataFrame = model.transform(test)
    val evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("count").setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions)
    println(s"Root-mean-square error = $rmse")

    //每個用戶推薦的前十個電影
    val userRecs: DataFrame = model.recommendForAllUsers(10)
    userRecs.rdd.saveAsTextFile("/usr/local/testdata/")

    //每個電影推薦的十個用戶
    val movieRecs = model.recommendForAllItems(10)
    movieRecs.rdd.saveAsTextFile("/usr/local/testdata/")
    userRecs.show()
    movieRecs.show()

    spark.stop()
  }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章