Hadoop: A Simple Implementation of the Naive Bayes Algorithm

Reposted from: https://blog.csdn.net/Angelababy_huan/article/details/53046151

A Bayesian classifier works by taking an object's prior probability, applying Bayes' theorem to compute its posterior probability — the probability that the object belongs to each class — and then assigning the object to the class with the largest posterior.
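In symbols, for a feature vector x = (x_1, ..., x_n), the decision rule is (a standard formulation of naive Bayes, added here for reference, not taken from the original post):

    \hat{c} = \arg\max_{c} P(c) \prod_{i=1}^{n} P(x_i \mid c)

The "naive" part is the assumption that the features are conditionally independent given the class, which is what lets the likelihood factor into a simple product.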

Here is a simple example:

Data: a log of the weather conditions and whether football was played each day

Date   | Played (label) | Weather      | Temperature | Humidity   | Wind
Day 1  | No (0)         | Sunny (0)    | Hot (0)     | High (0)   | Low (0)
Day 2  | No (0)         | Sunny (0)    | Hot (0)     | High (0)   | High (1)
Day 3  | Yes (1)        | Overcast (1) | Hot (0)     | High (0)   | Low (0)
Day 4  | Yes (1)        | Rain (2)     | Mild (1)    | High (0)   | Low (0)
Day 5  | Yes (1)        | Rain (2)     | Cool (2)    | Normal (1) | Low (0)
Day 6  | No (0)         | Rain (2)     | Cool (2)    | Normal (1) | High (1)
Day 7  | Yes (1)        | Overcast (1) | Cool (2)    | Normal (1) | High (1)
Day 8  | No (0)         | Sunny (0)    | Mild (1)    | High (0)   | Low (0)
Day 9  | Yes (1)        | Sunny (0)    | Cool (2)    | Normal (1) | Low (0)
Day 10 | Yes (1)        | Rain (2)     | Mild (1)    | Normal (1) | Low (0)
Day 11 | Yes (1)        | Sunny (0)    | Mild (1)    | Normal (1) | High (1)
Day 12 | Yes (1)        | Overcast (1) | Mild (1)    | High (0)   | High (1)
Day 13 | Yes (1)        | Overcast (1) | Hot (0)     | Normal (1) | Low (0)
Day 14 | No (0)         | Rain (2)     | Mild (1)    | High (0)   | High (1)
Day 15 | ?              | Sunny (0)    | Cool (2)    | High (0)   | High (1)
The task is to predict whether football will be played on day 15, given that day's weather.

Assume football is played on day 15; the score for that hypothesis is computed as follows:

P(play) = 9/14

P(sunny | play) = (days that were sunny and football was played) / (days played) = 2/9

P(cool | play) = (days that were cool and football was played) / (days played) = 3/9

P(high humidity | play) = (high-humidity days on which football was played) / (days played) = 3/9

P(high wind | play) = (high-wind days on which football was played) / (days played) = 3/9

So the score for playing on day 15 is P = 9/14 × 2/9 × 3/9 × 3/9 × 3/9 ≈ 0.00529

Following the same steps, the score for not playing on day 15 is P = 5/14 × 3/5 × 1/5 × 4/5 × 3/5 ≈ 0.02057

(These are unnormalized scores proportional to the posteriors; only their relative size matters, so they need not sum to 1.)

Since the score for not playing exceeds the score for playing, the model predicts that football will not be played on day 15.
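To double-check the arithmetic, the two scores can be reproduced in a few lines of plain Java (a minimal standalone sketch, not part of the original post):

public class HandBayes {
    public static void main(String[] args) {
        // Class priors from the 14 training days.
        double pPlay = 9.0 / 14, pNoPlay = 5.0 / 14;
        // Conditional probabilities for day 15's features:
        // sunny, cool, high humidity, high wind.
        double scorePlay   = pPlay   * (2.0 / 9) * (3.0 / 9) * (3.0 / 9) * (3.0 / 9);
        double scoreNoPlay = pNoPlay * (3.0 / 5) * (1.0 / 5) * (4.0 / 5) * (3.0 / 5);
        // Prints roughly 0.00529 and 0.02057 -> predict "no play".
        System.out.printf("play: %.5f, no play: %.5f%n", scorePlay, scoreNoPlay);
    }
}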

Once the naive Bayes workflow is clear, we can design the MapReduce program. The Mapper splits each training record into a class label and a feature vector; for example, the record `1,0,0,1` becomes key 1 (the label) and features [0, 0, 1]. The feature vector is wrapped in a custom value type (MyWritable, below) and passed on to the Reducer.

Mapper:

import java.io.IOException;   
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;  
import org.apache.hadoop.mapreduce.Mapper;  
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BayesMapper extends Mapper<Object, Text, IntWritable, MyWritable> {
    Logger log = LoggerFactory.getLogger(BayesMapper.class);
    private IntWritable myKey = new IntWritable();
    private MyWritable myValue = new MyWritable();

    @Override
    protected void map(Object key, Text value, Context context)
            throws IOException, InterruptedException {
        log.info("***" + value.toString());          // log each raw input record
        int[] values = getIntData(value);
        int label = values[0];                       // the class label
        int[] result = new int[values.length - 1];   // the feature vector
        for (int i = 1; i < values.length; i++) {
            result[i - 1] = values[i];
        }
        myKey.set(label);
        myValue.setValue(result);
        context.write(myKey, myValue);
    }

    // Parse a comma-separated line of digits into an int array;
    // empty or non-numeric fields are left as 0.
    private int[] getIntData(Text value) {
        String[] values = value.toString().split(",");
        int[] data = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            if (values[i].matches("^[0-9]+$")) {
                data[i] = Integer.parseInt(values[i]);
            }
        }
        return data;
    }
}
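Note the choice of map output key: keying by the class label means each reduce() call receives every training record of one class, which is exactly the grouping that the per-class counting in the Reducer relies on.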

MyWritable:

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;

public class MyWritable implements Writable {
    private int[] value; // the feature vector of one training record

    public MyWritable() {
    }

    public MyWritable(int[] value) {
        this.setValue(value);
    }

    // Serialize as a length prefix followed by the elements.
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(value.length);
        for (int i = 0; i < value.length; i++) {
            out.writeInt(value[i]);
        }
    }

    // Deserialize in exactly the order the fields were written.
    @Override
    public void readFields(DataInput in) throws IOException {
        int vLength = in.readInt();
        value = new int[vLength];
        for (int i = 0; i < vLength; i++) {
            value[i] = in.readInt();
        }
    }

    public int[] getValue() {
        return value;
    }

    public void setValue(int[] value) {
        this.value = value;
    }
}
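Hadoop calls write and readFields to move MyWritable instances between the map and reduce stages, so the two methods must mirror each other exactly. A quick standalone round-trip check (an added sketch, assuming MyWritable is on the classpath; not from the original post):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.Arrays;

public class MyWritableRoundTrip {
    public static void main(String[] args) throws Exception {
        MyWritable original = new MyWritable(new int[]{0, 1, 1, 0});
        // Serialize into an in-memory buffer.
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        original.write(new DataOutputStream(bytes));
        // Deserialize into a fresh instance.
        MyWritable copy = new MyWritable();
        copy.readFields(new DataInputStream(
                new ByteArrayInputStream(bytes.toByteArray())));
        // Both lines should print [0, 1, 1, 0].
        System.out.println(Arrays.toString(original.getValue()));
        System.out.println(Arrays.toString(copy.getValue()));
    }
}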

Reducer:

import java.io.BufferedReader;
import java.io.IOException;  
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;  
import org.apache.hadoop.conf.Configuration;  
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;  
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Reducer;  
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class BayesReducer extends Reducer<IntWritable, MyWritable, IntWritable, IntWritable> {
    Logger log = LoggerFactory.getLogger(BayesReducer.class);
    private String testFilePath;
    // The test records, one int[] per line (true label first, then features).
    private ArrayList<int[]> testData = new ArrayList<>();
    // Per-class statistics accumulated across all reduce() calls.
    private ArrayList<CountAll> allData = new ArrayList<>();

    @Override
    protected void setup(Context context)
            throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        // Read the test-file path from the configuration, falling back to the
        // hard-coded path from the original post. ("bayes.test.path" is a
        // made-up key chosen here; the original called conf.get("home/5.txt"),
        // which treats the path as a key and always returns null.)
        testFilePath = conf.get("bayes.test.path", "home/5.txt");
        Path path = new Path(testFilePath);
        FileSystem fs = path.getFileSystem(conf);
        readTestData(fs, path);
    }

    @Override
    protected void reduce(IntWritable key, Iterable<MyWritable> values,
            Context context)
            throws IOException, InterruptedException {
        // Feature count, taken from the width of the test records
        // (the label column excluded).
        Double[] myTest = new Double[testData.get(0).length - 1];
        // Laplace smoothing: start every "attribute = 1" count at 1 ...
        for (int i = 0; i < myTest.length; i++) {
            myTest[i] = 1.0;
        }
        // ... and the per-class record count at 2.
        Long sum = 2L;
        // Count, within this class, how often each attribute equals 1.
        for (MyWritable myWritable : values) {
            int[] myvalue = myWritable.getValue();
            for (int i = 0; i < myvalue.length; i++) {
                myTest[i] += myvalue[i];
            }
            sum += 1;
        }
        // Turn the counts into P(attribute = 1 | class).
        for (int i = 0; i < myTest.length; i++) {
            myTest[i] = myTest[i] / sum;
        }
        allData.add(new CountAll(sum, myTest, key.get()));
    }

    private IntWritable myKey = new IntWritable();
    private IntWritable myValue = new IntWritable();

    @Override
    protected void cleanup(Context context)
            throws IOException, InterruptedException {
        // The prior probability of each class in the training data, e.g.
        // 0 -> 0.4, 1 -> 0.6.
        HashMap<Integer, Double> labelG = new HashMap<>();
        Long allSum = getSum(allData); // total number of training records
        for (int i = 0; i < allData.size(); i++) {
            labelG.put(allData.get(i).getK(),
                    allData.get(i).getSum().doubleValue() / allSum);
        }
        // Each test record is one element longer than the feature vector,
        // because its first column is the true label.
        int sum = 0;
        int yes = 0;
        for (int[] test : testData) {
            int value = getClasify(test, labelG);
            if (test[0] == value) {
                yes += 1;
            }
            sum += 1;
            myKey.set(test[0]);  // true label
            myValue.set(value);  // predicted label
            context.write(myKey, myValue);
        }
        System.out.println("Accuracy: " + (double) yes / sum);
    }

    /***
     * Total number of training records across all classes.
     */
    private Long getSum(ArrayList<CountAll> allData2) {
        Long allSum = 0L;
        for (CountAll countAll : allData2) {
            log.info("class: " + countAll.getK() + " probabilities: "
                    + myString(countAll.getValue()) + " count: " + countAll.getSum());
            allSum += countAll.getSum();
        }
        return allSum;
    }

    /***
     * Classify one test record: compute a log-probability score for each
     * class and return the class with the larger score. (Exactly two classes
     * are assumed, with class 0 arriving at the reducer before class 1.)
     */
    private int getClasify(int[] test, HashMap<Integer, Double> labelG) {
        double[] result = new double[allData.size()]; // one score per class
        for (int i = 0; i < allData.size(); i++) {
            double count = 0.0;
            CountAll ca = allData.get(i);
            Double[] pdata = ca.getValue();
            for (int j = 1; j < test.length; j++) {
                if (test[j] == 1) {
                    // Probability that this attribute is 1 in this class.
                    count += Math.log(pdata[j - 1]);
                } else {
                    count += Math.log(1 - pdata[j - 1]);
                }
                log.info("count: " + count);
            }
            count += Math.log(labelG.get(ca.getK()));
            result[i] = count;
        }
        if (result[0] > result[1]) {
            return 0;
        } else {
            return 1;
        }
    }

    /***
     * Read the test data: one comma-separated record per line.
     */
    private void readTestData(FileSystem fs, Path path) throws NumberFormatException, IOException {
        FSDataInputStream data = fs.open(path);
        BufferedReader bf = new BufferedReader(new InputStreamReader(data));
        String line = "";
        while ((line = bf.readLine()) != null) {
            String[] str = line.split(",");
            int[] myData = new int[str.length];
            for (int i = 0; i < str.length; i++) {
                // The original also tested str[i] != "" (a reference
                // comparison, always true here); the digit match suffices.
                if (str[i].matches("^[0-9]+$")) {
                    myData[i] = Integer.parseInt(str[i]);
                }
            }
            testData.add(myData);
        }
        bf.close();
        data.close();
    }

    public static String myString(Double[] arr) {
        String num = "";
        for (int i = 0; i < arr.length; i++) {
            if (i == arr.length - 1) {
                num += String.valueOf(arr[i]);
            } else {
                num += String.valueOf(arr[i]) + ',';
            }
        }
        return num;
    }
}
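A note on the Math.log calls in getClasify: multiplying many probabilities smaller than 1 quickly underflows to 0.0 in double precision, so the per-class score is computed as a sum of logarithms instead, which preserves the ordering of the classes. A tiny standalone illustration of the problem (added here, not from the original post):

public class LogSpaceDemo {
    public static void main(String[] args) {
        double p = 0.01;
        double product = 1.0;
        double logSum = 0.0;
        for (int i = 0; i < 200; i++) {
            product *= p;          // underflows to 0.0 after enough factors
            logSum += Math.log(p); // stays finite: 200 * ln(0.01)
        }
        System.out.println("direct product: " + product); // 0.0
        System.out.println("sum of logs:    " + logSum);  // about -921.03
    }
}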

CountAll:

public class CountAll {
    private Long sum;       // record count for this class (including smoothing)
    private Double[] value; // P(attribute = 1 | class), one entry per attribute
    private int k;          // the class label
    public CountAll(){}  
    public CountAll(Long sum, Double[] value,int k){  
        this.sum = sum;  
        this.value = value;  
        this.k = k;  
    }  
    public Double[] getValue() {  
        return value;  
    }  
    public void setValue(Double[] value) {  
        this.value = value;  
    }  
    public Long getSum() {  
        return sum;  
    }  
    public void setSum(Long sum) {  
        this.sum = sum;  
    }  
    public int getK() {  
        return k;  
    }  
    public void setK(int k) {  
        this.k = k;  
    }  
}  

MainJob:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
public class MainJob {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        String[] otherArgs = new GenericOptionsParser(conf, args)
                .getRemainingArgs();
        if (otherArgs.length != 2) {
            System.err.println("Usage: bayes <in> <out>");
            System.exit(2);
        }
        long startTime = System.currentTimeMillis(); // for timing the job
        Job job = Job.getInstance(conf); // new Job(conf) is deprecated
        job.setJarByClass(MainJob.class);
        job.setMapperClass(BayesMapper.class);
        job.setReducerClass(BayesReducer.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(MyWritable.class);
        job.setOutputKeyClass(IntWritable.class);
        // The reducer emits IntWritable values (the predicted labels), so the
        // job's output value class must be IntWritable, not MyWritable.
        job.setOutputValueClass(IntWritable.class);
        FileInputFormat.addInputPath(job, new Path(otherArgs[0]));
        FileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));
        job.waitForCompletion(true);
        long endTime = System.currentTimeMillis();
        System.out.println("time=" + (endTime - startTime));
        System.exit(0);
    }
}
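To run the job, package the classes into a jar and submit it with something like `hadoop jar bayes.jar MainJob <input dir> <output dir>` (the jar name here is hypothetical); the input directory holds the comma-separated records below, and the test file must sit at the path the reducer opens in setup().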

Test data (the first column of every record is the class label):

1,0,0,0,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,0  
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1  
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,1  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,0,0,0,0,0  
1,0,0,0,1,0,0,0,0,1,0,0,0,1,1,0,1,0,0,0,1,0,1  
1,1,0,1,1,0,0,0,1,0,1,0,1,1,0,0,0,0,0,0,0,1,1  
1,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,1  
1,0,0,1,0,0,0,1,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1  
1,0,1,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0  
1,1,1,0,0,1,0,1,0,0,1,1,1,1,0,0,1,1,1,1,1,0,1  
1,1,1,0,0,1,1,1,0,1,1,1,1,0,1,0,0,1,0,1,1,0,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,1  
1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,1,0,0,1,1  
1,1,0,1,1,0,0,1,1,1,0,1,1,1,1,1,1,0,1,1,0,1,1  
1,0,1,1,0,0,1,1,1,0,0,0,1,1,0,0,1,1,1,0,1,1,1  
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,1,0,0,0,0,1,0  
1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1  
1,1,0,1,0,1,0,1,1,0,1,0,1,1,0,0,0,1,0,0,1,1,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0  
1,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,0,0,1,1  
1,1,0,0,0,1,1,0,1,0,0,1,0,0,0,0,0,0,0,1,1,0,0  
1,1,1,0,0,1,1,1,0,0,1,1,1,0,0,0,0,0,0,1,0,0,0  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0  
1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0  

Validation data:

1,1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,0,0,1,1,0,0  
1,1,0,0,1,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,1,0,1,0,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1  
1,0,1,1,1,0,0,1,0,1,0,0,1,1,1,0,1,0,0,0,0,1,0  
1,0,0,1,0,0,0,0,1,0,0,1,0,1,1,0,1,0,0,0,0,0,1  
1,0,0,1,1,0,1,0,0,1,0,1,0,1,0,0,1,0,0,0,0,1,1  
1,1,0,0,1,0,0,1,1,1,1,0,1,1,1,0,1,0,0,0,1,0,1  
1,1,0,0,1,0,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0  
1,1,0,0,1,1,1,0,0,1,1,1,0,0,1,0,1,1,0,1,0,0,0  
1,1,0,0,0,1,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0  
1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0  
1,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1  
1,1,0,0,0,1,0,0,0,1,1,0,0,0,1,0,0,0,1,1,0,0,0  
1,1,0,0,1,1,0,0,0,1,1,0,0,0,0,0,1,0,0,1,1,0,0  
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,1,0,0,1,0  
1,1,1,0,0,1,1,1,1,0,1,1,1,1,0,0,0,1,0,0,0,1,1  
1,1,0,0,0,0,1,1,0,0,1,1,1,0,0,0,0,1,0,0,0,0,1  
1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0  
1,1,1,1,0,1,0,1,1,0,1,0,1,1,0,0,1,0,0,0,1,1,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,1,0,0  
1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,1,0,1,1,1  
1,0,0,1,1,1,0,0,1,1,1,0,0,1,1,1,1,0,1,0,1,1,0  
1,1,1,0,1,1,1,1,0,0,0,1,1,0,0,0,1,1,0,0,1,0,0  
1,1,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0  
1,1,1,0,0,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0  
1,1,0,1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,1,0  
1,1,1,1,1,0,1,1,1,0,1,0,0,1,1,1,1,0,0,1,1,0,0 

Run result: (the original post showed a screenshot of the console output here, which is not reproduced)
