基於ALS算法的最佳電影推薦(java版)

package spark;

import java.util.Arrays;
import java.util.List;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.storage.StorageLevel;

import scala.Tuple2;

public class SparkALSDemo {

    public static void main(String... args) throws Exception {
        Logger logger = Logger.getLogger(SparkALSDemo.class);
        // Set log levels and silence the embedded jetty container's logging.
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
        Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF);
        // Configure the runtime environment and create the SparkContext.
        SparkConf sparkConf = new SparkConf().setAppName("MovieLensALS");
        sparkConf.setMaster("local[4]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Load the sample ratings, keyed by (timestamp % 10) to split them into 10 buckets.
        String movielensHomeDir = "F:/ml-1m";
        JavaRDD<Tuple2<Long, Rating>> ratings = jsc.textFile(movielensHomeDir + "/ratings.dat").map(
                line -> {
                    // ratings.dat line format: userId::movieId::rating::timestamp
                    String[] fields = line.split("::");
                    return new Tuple2<Long, Rating>(Long.parseLong(fields[3]) % 10, new Rating(Integer.parseInt(fields[0]),
                            Integer.parseInt(fields[1]), Double.parseDouble(fields[2])));
                });

        // Load the personal ratings produced by the rating generator (personalRatings.txt).
        JavaRDD<String> data = jsc.textFile(movielensHomeDir + "/personalRatings.txt");
        JavaRDD<Rating> myRatingsRDD = data.map(s -> {
            String[] sarray = s.split("::");
            return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2]));
        });

        // Summarize the sample ratings.
        logger.info("Got " + ratings.count() + " ratings from "
                + ratings.map(tupe -> tupe._2.user()).distinct().count() + " users "
                + ratings.map(tupe -> tupe._2.product()).distinct().count() + " movies");
        // Training set: ratings whose key is in [0-5], plus the personal ratings.
        JavaRDD<Rating> training = ratings.filter(x -> x._1 < 6).map(tupe2 -> tupe2._2).union(myRatingsRDD)
                .repartition(4).persist(StorageLevel.MEMORY_ONLY());
        // Validation set: ratings whose key is in [6-7].
        JavaRDD<Rating> validation = ratings.filter(x -> x._1 >= 6 && x._1 < 8).map(tupe2 -> tupe2._2).repartition(4)
                .persist(StorageLevel.MEMORY_ONLY());
        // Test set: ratings whose key is in [8-9].
        JavaRDD<Rating> test = ratings.filter(x -> x._1 >= 8).map(tupe2 -> tupe2._2).persist(StorageLevel.MEMORY_ONLY());
        // Hoist the (action-triggering) counts so they are computed only once each.
        long numValidation = validation.count();
        long numTest = test.count();
        logger.info("Training: " + training.count() + " validation: " + numValidation + " test: " + numTest);

        // Grid-search every rank/lambda/numIter combination and keep the model with the
        // lowest validation RMSE. The original loop used a single index into all three
        // lists, so it only ever tried the 3 diagonal combinations out of the 27.
        // Note: Arrays.asList already returns List<T>; the old casts were redundant.
        List<Integer> ranks = Arrays.asList(8, 10, 12);
        List<Double> lambdas = Arrays.asList(0.1, 2.5, 5.0);
        List<Integer> numIters = Arrays.asList(10, 15, 20);
        MatrixFactorizationModel bestModel = null;
        double bestValidationRmse = Double.MAX_VALUE;
        int bestRank = 0;
        double bestLambda = -1.0;
        int bestNumIter = -1;
        for (int rank : ranks) {
            for (double lambda : lambdas) {
                for (int numIter : numIters) {
                    MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(training), rank, numIter, lambda);
                    double validationRmse = SparkALSDemo.computeRMSEAverage(model, validation, numValidation);
                    if (validationRmse < bestValidationRmse) {
                        bestModel = model;
                        bestValidationRmse = validationRmse;
                        bestRank = rank;
                        bestLambda = lambda;
                        bestNumIter = numIter;
                    }
                }
            }
        }
        double testRmse = SparkALSDemo.computeRMSEAverage(bestModel, test, numTest);
        logger.info("The best model was trained with rank = " + bestRank + " and lambda = " + bestLambda
                + ", and numIter = " + bestNumIter + ", and its RMSE on the test set is " + testRmse + ".");

        // Baseline: always predict the mean rating of training + validation, then report
        // how much the best model improves on that naive predictor.
        JavaRDD<Double> rdd = training.union(validation).map(d -> d.rating());
        double meanRating = rdd.reduce((a, b) -> a + b) / rdd.count();
        double baselineRmse = Math.sqrt(test.map(x -> (meanRating - x.rating()) * (meanRating - x.rating()))
                .reduce((a1, a2) -> a1 + a2) / numTest);
        double improvement = (baselineRmse - testRmse) / baselineRmse * 100;
        logger.info("The best model improves the baseline by " + String.format("%1.2f", improvement) + "%.");

        // Load the movie catalog. movies.dat line format: movieId::title::genres
        JavaRDD<Tuple2<Integer, String>> movies = jsc.textFile(movielensHomeDir + "/movies.dat").map(line -> {
            String[] fields = line.split("::");
            return new Tuple2<Integer, String>(Integer.parseInt(fields[0]), fields[1]);
        });
        // Filter out movies the user has already rated.
        List<Integer> myRatedMovieIds = myRatingsRDD.map(d -> d.product()).collect();
        JavaRDD<Integer> candidates = movies.map(s -> s._1).filter(m -> !myRatedMovieIds.contains(m));

        // Predict the 10 movies user 100 is most likely to enjoy.
        JavaRDD<Rating> rr = bestModel.predict(
                JavaPairRDD.fromJavaRDD(candidates.map(d -> new Tuple2<Integer, Integer>(100, d))))
                .sortBy(f -> f.rating(), false, 4);
        logger.info("Movies recommended for you:");
        rr.take(10).forEach(a -> logger.info("用戶" + a.user() + "-[ " + a.product() + "]-[" + a.rating() + "]"));
        // Release the SparkContext now that all actions have run.
        jsc.stop();
    }

    /**
     * Computes the mean squared prediction error (sum of squared errors divided by
     * {@code n}) of {@code model} over the (user, product) pairs in {@code data}.
     *
     * @param model trained matrix-factorization model used to predict ratings
     * @param data  actual ratings to score the model against
     * @param n     divisor for the squared-error sum (normally {@code data.count()})
     * @return the averaged squared error of the model's predictions on {@code data}
     */
    public static double computeRMSEAverage(MatrixFactorizationModel model, JavaRDD<Rating> data, long n) {

        // Predicted ratings for every (user, product) pair present in data.
        JavaRDD<Rating> jddRat = model.predict(JavaPairRDD.fromJavaRDD(data.map(d -> new Tuple2<Integer, Integer>(d.user(), d
                .product()))));
        // Key both predicted and actual ratings by "user_product" so they can be joined.
        JavaPairRDD<String, Double> pre = JavaPairRDD.fromJavaRDD(jddRat.map(f -> new Tuple2<String, Double>(f.user() + "_"
                + f.product(), f.rating())));
        JavaPairRDD<String, Double> rea = JavaPairRDD.fromJavaRDD(data.map(f -> new Tuple2<String, Double>(f.user() + "_"
                + f.product(), f.rating())));
        // Inner join (like SQL INNER JOIN): keeps only pairs rated in both RDDs.
        JavaRDD<Tuple2<Double, Double>> d = pre.join(rea).values();
        return d.map(f -> Math.pow(f._1 - f._2, 2)).reduce((a, b) -> a + b) / n;
    }
}

該文援引的是http://files.grouplens.org/datasets/movielens/ 中 ml-1m.zip的數據。下載解壓到本地,修改代碼中的路徑即可。

發佈了32 篇原創文章 · 獲贊 2 · 訪問量 2萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章