基於Spark on Yarn的apriori算法java實現

一  前言

爲了處理大數據集並找出其中的強關聯規則，本文基於 Spark 使用 Java 語言實現了 Apriori 算法。算法已經通過測試，文末附有一個測試實驗及其運行結果。

二  apriori算法描述

Apriori 是一種經典的數據挖掘算法，可以挖掘出數據庫中哪些物品經常一起出現；同時滿足最小支持度和最小置信度的規則稱爲強關聯規則。因此，算法需要找出所有的強關聯規則，從而爲實際決策提供依據或預測未來的結果。Apriori 算法採用逐層搜索的迭代思想：利用頻繁 k 項集找出頻繁 (k+1) 項集，依此類推，直到找出所有的頻繁項集；最後再從這些頻繁項集中進一步找出所有強關聯規則。其依據是 Apriori 性質：頻繁項集的所有非空子集也必然是頻繁的，因此只要某個候選項集含有非頻繁子集，就可以直接剪枝。本文實現的算法主要完成找出所有頻繁項集這一步，這也是 Apriori 算法中最重要的一步。

三  算法實現

package org.min.apriori;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;


import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;


import scala.Tuple2;
import scala.collection.mutable.ArrayBuffer;


/**
 * Apriori frequent-itemset mining on Spark, written against the Java RDD API.
 *
 * <p>Each input line is one transaction: an ID, a space, then the items separated by
 * {@link #SEPARATOR}. The job writes the frequent k-itemsets for k = 1..{@link #NOFITEMS}
 * to {@code <output>/result_k}, one {@code "item1,item2,...:supportRate"} line per itemset,
 * where supportRate = support count / number of input lines.
 *
 * @author ShiMin
 * @date   2015/10/13
 */
public class FrequentItemset
{
    public static int SUPPORT_DEGREE = 4;//minimum support count (weighted number of transactions) an itemset needs
    public static int TRANSACTION_NUM = 25;//legacy constant; main() now derives the transaction count from the input
    public static String SEPARATOR = " ";//separator between items within a line
    public static int NOFITEMS = 4;//the largest itemset size that is mined

    /** Input path used when the program is launched without arguments. */
    private static final String DEFAULT_INPUT = "hdfs://master:9000/data/input/wordcounts.txt";
    /** Output directory used when the program is launched without arguments. */
    private static final String DEFAULT_OUTPUT = "hdfs://master:9000/data/output";

    /** Adds two partial support counts; shared by every reduceByKey in the job. */
    @SuppressWarnings("serial")
    private static final Function2<Integer, Integer, Integer> SUM_COUNTS = new Function2<Integer, Integer, Integer>()
    {
        public Integer call(Integer v1, Integer v2) throws Exception
        {
            return v1 + v2;
        }
    };

    /**
     * Entry point.
     *
     * @param args {@code <Datapath> <Output>}; when empty, the original HDFS defaults are used
     */
    public static void main(String[] args)
    {
        Logger.getLogger("org.apache.spark").setLevel(Level.OFF);

        // The original overwrote args unconditionally, silently discarding real arguments.
        // Keep the old paths only as a fallback for argument-less launches.
        if(args.length == 0)
        {
            args = new String[]{DEFAULT_INPUT, DEFAULT_OUTPUT};
        }
        if(args.length != 2)
        {
            System.err.println("USage:<Datapath> <Output>");
            System.exit(1);
        }

        // setIfMissing lets spark-submit (e.g. --master yarn) pick the master; an unconditional
        // setMaster("local[4]") would override it, so the job could never run on YARN.
        SparkConf sparkConf = new SparkConf().setAppName("apriori algorithm").setIfMissing("spark.master", "local[4]");
        JavaSparkContext ctx = new JavaSparkContext(sparkConf);
        try
        {
            mine(ctx, args[0], args[1]);
        }
        finally
        {
            ctx.stop();
        }
    }

    /**
     * Runs the level-wise Apriori search over {@code input} and saves every level's frequent
     * itemsets under {@code output}.
     */
    @SuppressWarnings("serial")
    private static void mine(JavaSparkContext ctx, String input, String output)
    {
        // Closures ship captured locals to the executors. Static fields, by contrast, are
        // re-initialised in every executor JVM and would ignore any driver-side change.
        final int minSupport = SUPPORT_DEGREE;
        final String separator = SEPARATOR;

        JavaRDD<String> lines = ctx.textFile(input, 1); //textFile(path: String, minPartitions: Int)
        // Denominator of the support rate: one input line per transaction.
        final long transactionNum = lines.count();

        JavaPairRDD<List<String>, Integer> transactions = lines.map(new Function<String, String>()
        {
            public String call(String v1) throws Exception
            {
                //remove the ID of transaction (everything up to the first space).
                return v1.substring(v1.indexOf(" ") + 1).trim();
            }
        })
        .mapToPair(new PairFunction<String, String, Integer>()
        {
            public Tuple2<String, Integer> call(String t) throws Exception
            {
                return new Tuple2<String, Integer>(t, 1);
            }
        })
        //identical transactions are merged; the value is their multiplicity.
        .reduceByKey(SUM_COUNTS)
        .mapToPair(new PairFunction<Tuple2<String, Integer>, List<String>, Integer>()
        {
            public Tuple2<List<String>, Integer> call(Tuple2<String, Integer> t) throws Exception
            {
                // Count each item at most once per transaction, and skip the empty tokens
                // produced by repeated separators.
                Set<String> items = new LinkedHashSet<String>();
                for(String item : t._1.split(separator))
                {
                    if(!item.isEmpty())
                    {
                        items.add(item);
                    }
                }
                return new Tuple2<List<String>, Integer>(new ArrayList<String>(items), t._2);
            }
        })
        //the transactions are scanned once per level, so keep them in memory.
        .cache();

        //frequent 1-itemsets.
        JavaPairRDD<String, Integer> onefi = transactions.flatMapToPair(new PairFlatMapFunction<Tuple2<List<String>, Integer>, String, Integer>()
        {
            public Iterable<Tuple2<String, Integer>> call(Tuple2<List<String>, Integer> t) throws Exception
            {
                List<Tuple2<String, Integer>> t2list = new ArrayList<Tuple2<String, Integer>>();
                for(String item : t._1)
                {
                    t2list.add(new Tuple2<String, Integer>(item, t._2));
                }
                return t2list;
            }
        })
        .reduceByKey(SUM_COUNTS)
        .filter(atLeast(minSupport))
        .cache();

        saveWithSupportRate(onefi, transactionNum, output + "/result_1");

        //frequent k-itemsets, each level built from the previous one.
        JavaPairRDD<String, Integer> kfi = onefi;
        for(int k = 2; k <= NOFITEMS; ++k)
        {
            List<String> candiatefi = getCandiateKFI(kfi, k);
            System.out.println(k + " = " + candiatefi);

            // Broadcast the plain candidate list. The original broadcast a JavaRDD and ran an
            // action on it inside this task. RDD operations cannot be nested inside another RDD's
            // closure (SPARK-5063), so it only worked when executors ran in the driver's JVM.
            final Broadcast<List<String>> bcCandidates = ctx.broadcast(candiatefi);

            kfi = transactions.flatMapToPair(new PairFlatMapFunction<Tuple2<List<String>, Integer>, String, Integer>()
            {
                private static final long serialVersionUID = 3107941823066446782L;

                public Iterable<Tuple2<String, Integer>> call(Tuple2<List<String>, Integer> line) throws Exception
                {
                    Set<String> lineItems = new HashSet<String>(line._1);
                    List<Tuple2<String, Integer>> t2list = new ArrayList<Tuple2<String, Integer>>();
                    for(String candidate : bcCandidates.value())
                    {
                        // a transaction supports a candidate when it contains all of its items.
                        if(lineItems.containsAll(Arrays.asList(candidate.split(","))))
                        {
                            t2list.add(new Tuple2<String, Integer>(candidate, line._2));
                        }
                    }
                    return t2list;
                }
            })
            .reduceByKey(SUM_COUNTS)
            .filter(atLeast(minSupport))
            .cache();

            saveWithSupportRate(kfi, transactionNum, output + "/result_" + k);
        }
    }

    /** Builds a filter that keeps only itemsets whose support count is at least {@code minSupport}. */
    @SuppressWarnings("serial")
    private static Function<Tuple2<String, Integer>, Boolean> atLeast(final int minSupport)
    {
        return new Function<Tuple2<String, Integer>, Boolean>()
        {
            public Boolean call(Tuple2<String, Integer> v1) throws Exception
            {
                return v1._2 >= minSupport;
            }
        };
    }

    /** Saves each itemset as {@code "items:rate"}, where rate = count / transactionNum. */
    @SuppressWarnings("serial")
    private static void saveWithSupportRate(JavaPairRDD<String, Integer> fi, final long transactionNum, String path)
    {
        fi.map(new Function<Tuple2<String, Integer>, String>()
        {
            public String call(Tuple2<String, Integer> v1) throws Exception
            {
                return v1._1 + ":" + (double) v1._2 / transactionNum;
            }
        })
        .saveAsTextFile(path);
    }

    /**
     * Generates the candidate k-itemsets from the frequent (k-1)-itemsets in {@code kfi}.
     *
     * <p>Itemsets are comma-joined strings whose items are sorted (see {@link #sortItems}).
     * For k == 2, every pair of frequent items is combined. For k > 2, two itemsets are joined
     * only when their first k-2 items agree. A candidate is kept only if all of its
     * (k-1)-subsets are themselves frequent.
     *
     * @param kfi frequent (k-1)-itemsets keyed by their comma-joined string
     * @param k   size of the candidates to build; must be at least 2
     * @return the candidate k-itemsets, each in sorted comma-joined form
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public static List<String> getCandiateKFI(JavaPairRDD<String, Integer> kfi, int k)
    {
        if(k < 2)
        {
            throw new IllegalArgumentException("k must be >= 2, got " + k);
        }

        List<String> itemlist = kfi.keys().collect();
        // Hash lookup for the subset-pruning step (the original scanned the list each time).
        Set<String> frequent = new HashSet<String>(itemlist);
        List<String> candiateItemSet = new ArrayList<String>();

        for(int i = 0; i < itemlist.size() - 1; i++)
        {
            for(int j = i + 1; j < itemlist.size(); j++)
            {
                String candidate = joinItemsets(itemlist.get(i), itemlist.get(j), k);
                if(candidate != null && allSubsetsFrequent(candidate, frequent))
                {
                    candiateItemSet.add(candidate);
                }
            }
        }
        return candiateItemSet;
    }

    /** Joins two sorted (k-1)-itemsets into a sorted k-itemset, or returns null if they cannot be joined. */
    private static String joinItemsets(String s1, String s2, int k)
    {
        if(2 == k)
        {
            return sortItems(s1 + "," + s2);
        }
        int cut1 = s1.lastIndexOf(',');
        int cut2 = s2.lastIndexOf(',');
        if(cut1 < 0 || cut2 < 0)
        {
            return null;
        }
        // Apriori join: combine only itemsets that share their first k-2 items.
        if(s1.substring(0, cut1).equals(s2.substring(0, cut2)))
        {
            return sortItems(s1 + s2.substring(cut2));
        }
        return null;
    }

    /** Returns true when every subset of {@code itemset} obtained by dropping one item is in {@code frequent}. */
    private static boolean allSubsetsFrequent(String itemset, Set<String> frequent)
    {
        String[] items = itemset.split(",");
        for(int m = 0; m < items.length; m++)
        {
            StringBuilder subItem = new StringBuilder();
            String sep = "";
            for(int n = 0; n < items.length; n++)
            {
                if(m != n)
                {
                    subItem.append(sep).append(items[n]);
                    sep = ",";
                }
            }
            if(!frequent.contains(subItem.toString()))
            {
                return false;
            }
        }
        return true;
    }

    /**
     * Sorts the comma-separated items of {@code itemStr} lexicographically and re-joins them
     * with commas, which gives each itemset a single canonical string form.
     *
     * @param itemStr comma-separated items, e.g. {@code "b,a"}
     * @return the sorted, comma-joined items, e.g. {@code "a,b"}
     */
    public static String sortItems(String itemStr)
    {
        String[] items = itemStr.split(",");
        Arrays.sort(items);
        StringBuilder result = new StringBuilder();
        for(int i = 0; i < items.length; i++)
        {
            if(i > 0)
            {
                result.append(',');
            }
            result.append(items[i]);
        }
        return result.toString();
    }
}


四  實驗結果

程序中設定最多挖掘到 4 項集，生成的結果分別保存在 4 個輸出目錄（result_1 至 result_4）中，結果如下：


每個文件內容如下:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章