1台のマシンで並列実行するためのMapReduceフレームワークを作ってみた

ブックマークのコメントに

1台でもGbyte単位のデータを処理可能なプログラムを簡単に書けるのは十分なメリットだと思う。

とあったので、ついカッとなって作ってみた。


サンプルは前と同じく、クラスがJavaソース中でimportされている回数を数えるもの。
mapreduceメソッドを適当に呼び出せばおっけーです。
こんな感じの結果ファイルが生成されました。

java.io.FileOutputStream	1
java.io.FileReader	1
java.awt.Graphics	1
java.io.Reader	1
java.awt.event.*	3
java.io.BufferedReader	3


書いてみたら動いたというレベルなので、実際に使うにはちゃんと例外処理とかをやってください。
ソースはこんな感じ

import java.io.*;
import java.lang.management.ManagementFactory;
import java.util.*;
import java.util.concurrent.*;

public class MapReduce {
    private static final String INPUT_PATH =
            "C:/Users/naoki/java/raytrace/src/ml";
    private static final String OUTPUT_PATH =
            "C:/Users/naoki/java/raytrace/build/temp";
    private static final int reducerCount = 20;

    public static void main(String[] args) throws Exception{
        mapreduce(new Mapper() {
            public void map(String line, 
                    List<Map.Entry<String, String>> output)
            {
                if(line.startsWith("import")){
                    //import文のとき、クラス名をキーに値を1にして出力
                    String word = line.substring(7, line.length() - 1);
                    output.add(new AbstractMap.SimpleEntry<String, String>(
                            word, "1"));
                }
            }
        },new Reducer() {
            public void reduce(String key, List<String> values, 
                    List<String> output)
            {
                int sum = 0;
                for(String value : values){
                    sum += Integer.parseInt(value);
                }
                output.add(String.format("%s\t%d", key, sum));
            }
        });
    }

    /** Mapper */
    interface Mapper{
        void map(String line, List<Map.Entry<String, String>> output);
    }
    /** Reducer */
    interface Reducer{
        void reduce(String key, List<String> values, List<String> output);
    }

    /** MapReduce処理 */
    static void mapreduce(final Mapper mapper, final Reducer reducer)
            throws IOException, InterruptedException
    {
        int procs = ManagementFactory.getOperatingSystemMXBean().
                getAvailableProcessors();
        ExecutorService exec = Executors.newFixedThreadPool(procs);

        //結果フォルダの準備
        File out = new File(OUTPUT_PATH);
        if(!out.exists()) out.mkdirs();
        for(File f : out.listFiles()) f.delete();

        //入力フォルダの確認
        File in = new File(INPUT_PATH);
        if(!in.exists()){
            System.out.println(in.getCanonicalPath() + "がありません。");
            return;
        }

        //一時ファイルの準備
        File[] tempFiles = new File[reducerCount];
        for(int i = 0; i < tempFiles.length; ++i){
            tempFiles[i] = new File(out,
                    String.format("%02d.txt", i));
        }

        //一時ファイル書き込み準備
        final FileWriter[] fws = new FileWriter[tempFiles.length];
        for(int i = 0; i < fws.length; ++i){
            fws[i] = new FileWriter(tempFiles[i], true);
        }

        //ファイルを読み込む
        System.out.println("map");
        LinkedList<File> dirs = new LinkedList<File>();
        dirs.add(in);
        while(!dirs.isEmpty()){
            for(File f : dirs.pop().listFiles()){
                //各ファイルの処理
                if(f.isDirectory()){
                    dirs.add(f);
                    continue;
                }
                if(!f.getName().endsWith(".java")) continue;
                FileReader fr = new FileReader(f);
                BufferedReader buf = new BufferedReader(fr);
                //一行ずつ読み込んでMapperに渡す
                for(String line; (line = buf.readLine()) != null;){
                    final List<Map.Entry<String, String>> outputs =
                            new ArrayList<Map.Entry<String, String>>();
                    final String paramline = line;
                    exec.submit(new Runnable() {
                        public void run() {
                            mapper.map(paramline, outputs);
                            //mapperの結果を一時ファイルに振り分ける
                            try {
                                for(Map.Entry<String, String> me : outputs){
                                    int hash = me.getKey().hashCode();
                                    if(hash < 0) hash = -hash;
                                    int num = (hash / 100) % reducerCount;
                                    //ファイルに書き込む
                                    synchronized(fws[num]){
                                        fws[num].append(String.format(
                                                "%s\t%s%n",
                                                me.getKey(), me.getValue()));
                                    }
                                }
                            } catch (IOException ex) {
                                ex.printStackTrace();
                            }
                        }
                    });
                }
            }
        }

        //Map待ち
        exec.shutdown();
        exec.awaitTermination(3, TimeUnit.MINUTES);
        for(FileWriter fw : fws){
            fw.close();
        }
        //Reduce
        System.out.println("reduce");
        exec = Executors.newFixedThreadPool(procs);

        //結果ファイルの準備
        File resultFile = new File(out, "result.txt");
        final FileWriter fw = new FileWriter(resultFile);
        //中間ファイルをひとつづつ処理
        for(final File f : tempFiles){
            exec.submit(new Runnable() {
                public void run() {
                    try {
                        //中間ファイルのデータをキーごとにまとめる
                        Map<String, List<String>> mapresult =
                                new HashMap<String, List<String>>();
                        List<String> result = new ArrayList<String>();
                        FileReader fr = new FileReader(f);
                        BufferedReader buf = new BufferedReader(fr);
                        for (String line; (line = buf.readLine()) != null;) {
                            String[] strs = line.split("\t", 2);
                            String key = strs[0];
                            String value = strs[1];
                            if (!mapresult.containsKey(key)) {
                                mapresult.put(key, new ArrayList<String>());
                            }
                            mapresult.get(key).add(value);
                        }
                        //キーごとにreduce処理
                        for (Map.Entry<String, List<String>> me : mapresult.entrySet()) {
                            reducer.reduce(me.getKey(), me.getValue(), result);
                        }
                        //reduce結果を書き出し
                        synchronized (fw) {
                            for (String line : result) {
                                fw.append(line + "\n");
                            }
                        }
                        buf.close();
                        fr.close();
                    } catch (IOException ex) {
                        ex.printStackTrace();
                    } 
                }
            });
        }

        //Reduce待ち
        exec.shutdown();
        exec.awaitTermination(3, TimeUnit.MINUTES);

        //終了
        fw.close();
        System.out.println("finish");
    }

}