Java8で強化されたMapと、書きやすくなったメモ化再帰

Java8のlambda構文の話を書くと、旧来の書き方でいいというコメントがつくのですが、それでも便利になったMapの恩恵を受けることは多いんじゃないかと思います。
※ 2018/5/31 Java9からはメモ化再帰には使えなくなっています
※ 2019/2/15 なんか問題ない?


Mapには、lambda式を使ったメソッドが多く追加されていますが、たとえばgetOrDefaultメソッドのようなlambda式を使わないメソッドも追加されていて、これも便利です。
そして、このようなlambda式を使わないメソッドも、間接的にはlambda構文サポートでの言語拡張のおかげです。
Mapはインタフェースなので、Java7までの構文でメソッドを追加しようとすると、Mapを実装しているすべてのクラスに新しいメソッドの実装を追加する必要がありました。そしてそれは現実的に不可能なので、今までMapなどのインタフェースに手がいれられることはありませんでした。それが、lambda構文サポートの一環として入れられた、インタフェースのデフォルト実装のおかげで、インタフェースを拡張することができるようになったわけです。

forEachとgetOrDefault

では、今回のMap拡張で一番使うことになりそうな、forEachとgetOrDefaultを見てみます。
まず、文字列のリストを集計して、文字列ごとの出現回数をカウントして表示するコードをJava7構文で書いてみます。

List<String> strs = Arrays.asList("blue", "red", "black", "blue", "white", "black", "blue");

Map<String, Integer> counter = new HashMap<>();
for(String s : strs){
    Integer c = counter.get(s);
    if(c == null){
        c = 0;
    }
    counter.put(s, c + 1);
}
for(Map.Entry<String, Integer> me : counter.entrySet()){
    System.out.printf("%s:%d%n", me.getKey(), me.getValue());
}


先に、結果表示部分を見てみると、Map.Entryを取り出してループをまわしています。
実際にこのコードを書くとき、ここの型指定でいちいちMapが格納している型を確認するのが結構めんどくさかったりしました。Map.Entryって書くのも面倒です。
これが、forEachを使うと次のようになります。

counter.forEach((k, v) -> {
  System.out.printf("%s:%d%n", k, v)
});

すっきり。
型推論してくれるので、counterが保持している型を改めて書く必要もないし、Map.Entryのように新しい型をもってくる必要もありません。


次に集計部分を見てみます。
ここでは、次のようにして、Mapから値をとりだして、値がなければ0を使うようにしています。

Integer c = counter.get(s);
if(c == null){
    c = 0;
}


これが、getOrDefaultを使うと次のように書けます。

Integer c = counter.getOrDefault(s, 0);

条件文がなくなりました!


では、for文を使っているところもforEachを使って書き直すと、結局こんなコードになります。

Map<String, Integer> counter = new HashMap<>();
strs.forEach(s -> {
    Integer c = counter.getOrDefault(s, 0);
    counter.put(s, c + 1);
});

いい感じです。


ただ、Java8ではリストを集計するのにそもそもこんなコードは書く必要がなくて、コレクターを使って次のように書けます。

Map<String, Long> counter = strs.stream()
    .collect(Collectors.groupingBy(s -> s, Collectors.counting()));

ここまでの話はなんだったんだろう?って感じですね。

putIfAbsentとcomputeIfAbsent

文字列を、頭文字ごとにリストに入れるというコードをJava7までの構文で書いてみると次のようになります。

List<String> strs = Arrays.asList("blue", "red", "black", "blue", "white", "black", "blue");

Map<String, List<String>> words = new HashMap<>();
for(String str : strs){
    String initial = str.substring(0, 1);
    if(!words.containsKey(initial)){
        words.put(initial, new ArrayList<String>());
    }
    List<String> ls = words.get(initial);
    ls.add(str);
}
for(Map.Entry<String, List<String>> me : words.entrySet()){
    System.out.println(me.getKey());
    for(String s : me.getValue()){
        System.out.println("  " + s);
    }
}


ひとつ、Java8での目立たない改善点として、既存文法での型推論も強化されたということがあります。この部分、本当はArrayListに与える型は推論してほしいのですが、Java7ではメソッドの引数では型推論がきかないので、ダイヤモンド演算子ではなく型を明示する必要があります。

words.put(initial, new ArrayList<String>());

ここが、Java8ではちゃんと型推論が効いて、ダイヤモンド演算子を使うことができるようになっています。

words.put(initial, new ArrayList<>());


さて、putIfAbsentメソッド。これを使うと、Mapが値を保持していないときだけ値を設定するということが可能になります。
そうすると、次のように書けます。

strs.forEach(str -> {
    String initial = str.substring(0, 1);
    words.putIfAbsent(initial, new ArrayList<>());
    List<String> ls = words.get(initial);
    ls.add(str);
});

ついでにforEachを使っています。


ここで、ちょっと問題なのは、wordsがinitialに対応する値をもっている場合でも、ArrayListのオブジェクトが生成されてしまうことです。Javaは遅延評価ではないので、引数の値が実際には使われないとしても、メソッドを呼び出す前に評価されることになって、いちいちArrayListのオブジェクトが生成されるというわけです。


そこで、Mapの初期値を与えるというような場合には、computeIfAbsentメソッドを使うほうが適しています。

strs.forEach(str -> {
    String initial = str.substring(0, 1);
    List<String> ls = words.computeIfAbsent(initial, s -> new ArrayList<>());
    ls.add(str);
});

こうすると、wordsがinitialに対応する値を保持していないときだけlambda式で与えた関数が実行されて、必要なときにだけArrayListのオブジェクトが生成されます。


で、まあこういう風にリストを集計して個別のリストに振り分けるという場合も、コレクターが使えるので、そもそもこんなコードを書く必要はないのですね。

Map<String, List<String>> words = strs.stream()
        .collect(Collectors.groupingBy(str -> str.substring(0, 1)));

computeIfAbsentを使ったメモ化再帰

computeIfAbsentメソッドのドキュメントにも、よくある利用法のひとつとして「memoized result」つまりメモ化があげられてます。
で、まあメモ化というとメモ化再帰、メモ化再帰というとフィボナッチということで、フィボナッチ書いてみます。

public static void main(String... args){
    IntStream.range(1, 101).forEach(i -> {
        System.out.printf("%d:%d%n", i, fib(i))
    });
}
public static long fib(int n){
    if(n == 0){
        return 0;
    }else if(n == 1){
        return 1;
    }else{
        return fib(n - 2) + fib(n - 1);
    }
}


こうすると、40あたりで実行が遅くなって、50まではたどりつかないくらい重くなります。
これは、同じ値を何回も計算してるからなので、一度計算した値はキャッシュすることにして、次のように書き換えます。

private static Map<Integer, Long> memo;

public static void main(String... args){
    memo = new HashMap<>();
    memo.put(0, 0L);
    memo.put(1, 1L);
    IntStream.range(1, 101).forEach(i -> {
        System.out.printf("%d:%d%n", i, fib(i))
    });
}

public static long fib(int n){
    Long result = memo.get(n);
    if(result != null){
        return result;
    }
    result = fib(n - 2) + fib(n - 1);
    memo.put(n, result);
    return result;
}


このfibメソッドを、computeIfAbsentメソッドを使って書き換えると次のようになります。

public static long fib(int n){
    return memo.computeIfAbsent(n, i -> fib(i - 2) + fib(i - 1));
}

いちぎょうだ!
computeIfAbsentメソッドのおかげで、メモ化がやりやすくなったことがわかります。
※ 2018/5/31 Java9からはConccurentModificationExceptionを吐きます
※ 2019/2/15 なんか問題ない?


ところで、このフィボナッチ数列、93あたりで負の値がでてきて、なんだか怪しい感じです。longが桁あふれしてますね。

・・・
90:2880067194370816120
91:4660046610375530309
92:7540113804746346429
93:-6246583658587674878
94:1293530146158671551
・・・


Java8では、桁あふれで例外がでるような演算メソッドがMathクラスに用意されているので、これを使えば桁あふれが検出できるようになっています。

public static long fib(int n){
    return memo.computeIfAbsent(n, i -> Math.addExact(fib(i - 2),fib(i - 1)));
}


こうすると、93を計算するときには、次のようにoverflowの例外が発生します。

Exception in thread "main" java.lang.ArithmeticException: long overflow


ほかにもJava8ではintやlongを符号なし整数として扱うメソッドもIntegerクラスやLongクラスにそれぞれ用意されています。
大きい数を計算するときにはありがたいですね。


メモ化フィボナッチの話は、こちらを参考にしています。
Memoized Fibonacci Numbers with Java 8 | Informatech CR Blog
ここ、「そんなMapとか使わなくても変数2つあれば十分だよ」ってコメントついてて「たしかにそうだけど、これは再帰の効率化の例だからね」って返答してて、よくある光景だなーと思いました。