MapReduce實現大規模矩陣乘法

矩陣乘法

設 M 爲 m×n 矩陣、N 爲 n×p 矩陣,則乘積 D = M·N 爲 m×p 矩陣,其元素爲

$$d_{ik} = \sum_{j=1}^{n} m_{ij}\, n_{jk}$$

第一種做法

兩次mr過程

第一次Map:對矩陣M的每個元素mij產生鍵值對(j,(M,i,mij)),對矩陣N的每個元素njk產生鍵值對(j,(N,k,njk))
第一次Reduce:對同一個key j,將所有來自M的元素與所有來自N的元素兩兩配對,生成鍵值對key:(i,k),value:mij*njk
第二次Map:不做計算,只把第一次Reduce的輸出重新解析成鍵值對((i,k),mij*njk)
第二次Reduce:對每個key(i,k),將所有value累加求和得((i,k),sum),即結果矩陣D中元素dik的值



import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.MultipleInputs;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.ReflectionUtils;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;

/**
 * One matrix entry as shuffled between the first map phase and the first reduce phase.
 *
 * <p>{@code tag} records which matrix the entry came from ("M" or "N"), {@code index} is the
 * row index for an M entry or the column index for an N entry, and {@code value} is the entry
 * itself.
 */
class Element implements Writable {

    public String tag; // marks whether the entry belongs to matrix M or matrix N
    public int index;
    public double value;

    /** No-arg constructor; instances are created reflectively and then filled by readFields. */
    Element() {
        tag = "";
        index = 0;
        value = 0.0;
    }

    Element(String tag, int index, double value) {
        this.tag = tag;
        this.index = index;
        this.value = value;
    }

    /**
     * Restores the fields in exactly the order and encoding used by {@link #write}.
     *
     * @param input stream positioned at the start of a serialized element
     * @throws IOException if the underlying stream fails or ends early
     */
    @Override
    public void readFields(DataInput input) throws IOException {
        // readUTF consumes only the length-prefixed tag. The previous readLine() kept reading
        // until a '\n' byte or end of stream, swallowing the binary int/double that follow.
        tag = input.readUTF();
        index = input.readInt();
        value = input.readDouble();
    }

    /**
     * Serializes the tag (length-prefixed), then the index, then the value.
     *
     * @param output stream to append the serialized element to
     * @throws IOException if the underlying stream fails
     */
    @Override
    public void write(DataOutput output) throws IOException {
        // writeUTF stores the string length first, so the reader knows where the tag ends.
        output.writeUTF(tag);
        output.writeInt(index);
        output.writeDouble(value);
    }
}

/**
 * Coordinate (i, j) of a cell in the result matrix, used as a MapReduce key.
 *
 * <p>Ordering is by {@code i} first, then by {@code j}. {@code equals} and {@code hashCode}
 * are defined on the same two fields so that equal coordinates hash identically.
 */
class Pair implements WritableComparable<Pair> {

    int i;
    int j;

    Pair() {
        i = 0;
        j = 0;
    }

    Pair(int i, int j) {
        this.i = i;
        this.j = j;
    }

    @Override
    public void readFields(DataInput input) throws IOException {
        i = input.readInt();
        j = input.readInt();
    }

    @Override
    public void write(DataOutput output) throws IOException {
        output.writeInt(i);
        output.writeInt(j);
    }

    /**
     * Compares by {@code i}, breaking ties by {@code j}.
     *
     * @param compare the pair to compare against
     * @return a negative, zero, or positive value when this pair sorts before, equal to,
     *         or after {@code compare}
     */
    @Override
    public int compareTo(Pair compare) {
        int byRow = Integer.compare(i, compare.i);
        if (byRow != 0) {
            return byRow;
        }
        return Integer.compare(j, compare.j);
    }

    /** Two pairs are equal exactly when both coordinates match (consistent with compareTo). */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof Pair)) {
            return false;
        }
        Pair that = (Pair) other;
        return i == that.i && j == that.j;
    }

    /**
     * Value-based hash. The inherited identity hash differs between instances holding the same
     * coordinates, which would scatter one logical key across several reduce partitions.
     */
    @Override
    public int hashCode() {
        return 31 * i + j;
    }

    /** Text form used when the pair is written as an output key: "i j " (trailing space kept). */
    @Override
    public String toString() {
        return i + " " + j + " ";
    }

}

public class MultiplyMatrix {

    /**
     * Mapper for matrix M: for every entry m_ij emits the key-value pair (j, (M, i, m_ij)).
     */
    public static class matrixMapperM extends Mapper<Object, Text, IntWritable, Element> {

        @Override
        public void map(Object object, Text value, Context context) throws IOException, InterruptedException {
            // Each input line has the form "i,j,Mij".
            final String[] fields = value.toString().split(",");

            final int rowIndex = Integer.parseInt(fields[0]);
            final double entry = Double.parseDouble(fields[2]);
            final int joinIndex = Integer.parseInt(fields[1]);

            // Keyed by the column index j so it meets the matching row of N in the reducer.
            context.write(new IntWritable(joinIndex), new Element("M", rowIndex, entry));
        }
    }


    /**
     * Mapper for matrix N: for every entry n_jk emits the key-value pair (j, (N, k, n_jk)).
     */
    public static class matrixMapperN extends Mapper<Object, Text, IntWritable, Element> {

        @Override
        public void map(Object object, Text value, Context context) throws IOException, InterruptedException {
            // Each input line has the form "j,k,Njk".
            final String[] fields = value.toString().split(",");

            final int columnIndex = Integer.parseInt(fields[1]);
            final double entry = Double.parseDouble(fields[2]);
            final int joinIndex = Integer.parseInt(fields[0]);

            // Keyed by the row index j so it meets the matching column of M in the reducer.
            context.write(new IntWritable(joinIndex), new Element("N", columnIndex, entry));
        }
    }

    /**
     * First-pass reducer: pairs every M entry with every N entry that shares the same j.
     */
    public static class ReducerMXN extends Reducer<IntWritable, Element, Pair, DoubleWritable> {

        /**
         * Input: (j, list of tagged elements). Output: ((i, k), m_ij * n_jk) for each M/N pairing.
         *
         * @param key     the shared index j
         * @param values  all M and N elements emitted under that j
         * @param context job context used to read the configuration and emit results
         * @throws IOException          if copying an element or writing output fails
         * @throws InterruptedException if the write is interrupted
         */
        @Override
        public void reduce(IntWritable key, Iterable<Element> values, Context context) throws IOException, InterruptedException {

            final Configuration conf = context.getConfiguration();
            final ArrayList<Element> fromM = new ArrayList<>();
            final ArrayList<Element> fromN = new ArrayList<>();

            for (Element incoming : values) {
                // Store an independent copy rather than the iterated object itself
                // (presumably because the framework reuses the value instance - confirm).
                Element copy = ReflectionUtils.newInstance(Element.class, conf);
                ReflectionUtils.copy(conf, incoming, copy);

                if (copy.tag.equals("M")) {
                    fromM.add(copy);
                } else {
                    fromN.add(copy);
                }
            }

            for (Element left : fromM) {
                for (Element right : fromN) {
                    // key: (i, k), value: m_ij * n_jk
                    Pair cell = new Pair(left.index, right.index);
                    DoubleWritable product = new DoubleWritable(left.value * right.value);
                    context.write(cell, product);
                }
            }
        }
    }

    /*
     * The map-reduce pass below aggregates the partial products of the first pass.
     */
    public static class MapperMXN extends Mapper<Object, Text, Pair, DoubleWritable> {

        /**
         * Re-parses one text line of the first job's output into ((i, k), partial product).
         *
         * <p>The line is read from {@code value}; the map input key is not used. Fields may be
         * separated by whitespace (the "i j <TAB>product" layout that results from
         * {@code Pair.toString()} plus the text output separator) or by commas.
         *
         * @param key     map input key, ignored
         * @param value   one line holding i, k and the partial product
         * @param context job context used to emit the parsed pair
         * @throws IOException          if the line has fewer than three fields or the write fails
         * @throws InterruptedException if the write is interrupted
         */
        @Override
        public void map(Object key, Text value, Context context) throws IOException, InterruptedException {

            String readLine = value.toString().trim();
            if (readLine.isEmpty()) {
                return; // nothing to aggregate on a blank line
            }

            String[] pairValue = readLine.split("[,\\s]+");
            if (pairValue.length < 3) {
                throw new IOException("Expected 'i k value' but got: " + readLine);
            }

            Pair p = new Pair(Integer.parseInt(pairValue[0]), Integer.parseInt(pairValue[1]));
            DoubleWritable doubleWritable = new DoubleWritable(Double.parseDouble(pairValue[2]));
            context.write(p, doubleWritable);

        }
    }

    /**
     * Second-pass reducer: adds up every partial product that shares the key (i, k).
     */
    public static class ReduceMXN extends Reducer<Pair, DoubleWritable, Pair, DoubleWritable> {

        @Override
        public void reduce(Pair p, Iterable<DoubleWritable> values, Context context) throws IOException, InterruptedException {
            double total = 0;
            for (DoubleWritable partial : values) {
                total = total + partial.get();
            }

            // The accumulated sum is the value of cell (i, k) in the result matrix.
            DoubleWritable cellValue = new DoubleWritable(total);
            context.write(p, cellValue);
        }

    }

    /**
     * Runs the two chained jobs: the first emits partial products, the second sums them.
     *
     * @param args {@code args[0]} input path of matrix M, {@code args[1]} input path of
     *             matrix N, {@code args[2]} intermediate output directory (also the input of
     *             the second job), {@code args[3]} final output directory
     * @throws Exception if job setup or execution fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 4) {
            System.err.println("Usage: MultiplyMatrix <M input> <N input> <intermediate dir> <output dir>");
            System.exit(2);
        }

        Configuration conf = new Configuration();

        // Pass the configuration to the job; it was previously created but never used.
        Job job = Job.getInstance(conf, "MapIntermediate");
        job.setJarByClass(MultiplyMatrix.class);

        MultipleInputs.addInputPath(job, new Path(args[0]), TextInputFormat.class, matrixMapperM.class);
        MultipleInputs.addInputPath(job, new Path(args[1]), TextInputFormat.class, matrixMapperN.class);
        job.setReducerClass(ReducerMXN.class);

        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Element.class);

        job.setOutputKeyClass(Pair.class);
        job.setOutputValueClass(DoubleWritable.class);

        job.setOutputFormatClass(TextOutputFormat.class);

        FileOutputFormat.setOutputPath(job, new Path(args[2]));

        // The second job reads the first job's output, so stop here if the first job failed.
        if (!job.waitForCompletion(true)) {
            System.exit(1);
        }

        Job job2 = Job.getInstance(conf, "MapFinalOutput");
        job2.setJarByClass(MultiplyMatrix.class);

        job2.setMapperClass(MapperMXN.class);
        job2.setReducerClass(ReduceMXN.class);

        job2.setMapOutputKeyClass(Pair.class);
        job2.setMapOutputValueClass(DoubleWritable.class);

        job2.setOutputKeyClass(Pair.class);
        job2.setOutputValueClass(DoubleWritable.class);

        job2.setInputFormatClass(TextInputFormat.class);
        job2.setOutputFormatClass(TextOutputFormat.class);

        FileInputFormat.setInputPaths(job2, new Path(args[2]));
        FileOutputFormat.setOutputPath(job2, new Path(args[3]));

        // Report overall success or failure through the process exit status.
        System.exit(job2.waitForCompletion(true) ? 0 : 1);
    }
}

第二種做法:

只用一次map和一次reduce
Map對每個元素直接生成
key:(i,k)
value:(M/N, j, 元素值)
即M的元素mij對每個k各發送一份,N的元素njk對每個i各發送一份

然後reduce針對每一個(i,k),把j相同的mij與njk相乘,再對所有j累加,得到結果矩陣的元素dik

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;


import java.io.IOException;


public class matrixMul {

    public static class MatrixMapper extends Mapper<Object, Text, Text, Text> {

        private Text map_key = new Text();
        private Text map_value = new Text();

        int row = 300;
        int col = 500;

        String Target;//Matrix M or N?
        String i, j, k, ij, jk;

        public void map(Object key, Text value, Context context) throws IOException, InterruptedException {

            String eachterm[] = value.toString().split("#");
            Target = eachterm[0];
            if (Target.equals("M")) {

                i = eachterm[1];//m row
                j = eachterm[2];//m col
                ij = eachterm[3];//m[i][j]

                //  Map函數:對於矩陣M的每個元素M[i,j],
                //  產生一系列的鍵值對(i,k)->(M,j, M[i,j]),其中k=1,2…,
                //  直到矩陣N的列數。
                for (int c = 1; c <= col; c++) {
                    map_key.set(i + "#" + String.valueOf(c));
                    map_value.set("M" + "#" + j + "#" + ij);
                    context.write(map_key, map_value);
                }
            }

            //同樣,對於矩陣N的每個元素N[j,k],
            // 產生一系列的鍵值對(i,k)->(N,j,N[j,k]),
            // 其中i=1,2…,直到矩陣M的行數。
            else if (Target.equals("N")) {
                j = eachterm[1];
                k = eachterm[2];
                jk = eachterm[3];
                for (int r = 1; r <= row; r++) {
                    map_key.set(String.valueOf(r) + "#" + k);
                    map_value.set("N" + "#" + j + "#" + jk);
                    context.write(map_key, map_value);
                }
            }
        }
    }

    public static class MatrixReducer extends Reducer<Text, Text, Text, Text> {


        private Text reduce_value = new Text();
        int jNum = 150; //M的col,N的row

        //M的行向量
        int[] M_ij = new int[jNum + 1];
        //N的列向量
        int[] N_jk = new int[jNum + 1];

        int j;//index
        int ij, jk;//M[ij],N[jk]

        int sum = 0;
        String target;

        /**
         * input:
         * @param key
         * @param values
         * @param context
         * @throws IOException
         * @throws InterruptedException
         */
        public void reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException {

            sum = 0;
            for (Text val : values) {
                String eachterm[] = val.toString().split("#");
                target = eachterm[0];
                j = Integer.parseInt(eachterm[1]);
                if (target.equals("M")) {
                    ij = Integer.parseInt(eachterm[2]);
                    M_ij[j] = ij;
                } else if (target.equals("N")) {
                    jk = Integer.parseInt(eachterm[2]);
                    N_jk[j] = jk;
                }
            }

            for (int i = 1; i <= jNum; i++) {
                sum += M_ij[i] * N_jk[i];
            }
            reduce_value.set(String.valueOf(sum));
            context.write(key, reduce_value);
        }
    }

    /**
     * Configures and runs the single-pass matrix multiplication job.
     *
     * @param args generic Hadoop options followed by the input path and the output path
     * @throws Exception if job setup or execution fails
     */
    public static void main(String[] args) throws Exception {

        Configuration conf = new Configuration();
        String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
        if (otherArgs.length != 2) {
            System.err.println("Usage:<in><out>");
            System.exit(2);
        }

        // Factory method instead of the deprecated Job(Configuration, String) constructor.
        Job job = Job.getInstance(conf, "matrixMul");
        job.setJarByClass(matrixMul.class);
        job.setMapperClass(MatrixMapper.class);
        job.setReducerClass(MatrixReducer.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(Text.class);
        FileInputFormat.addInputPath(job, new Path(otherArgs[0]));
        FileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));
        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }


}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章