iteratorや拡張forよりStreamのforEachが速い?

ちょっと気になったので、簡単にベンチマークしてみました。
最初は、ラムダ呼び出しが入る分forEachは遅いんじゃないかと思っていたら、倍の速さに。
もちろん、いろんな条件で変わるんだろうけど、ここまで差が出ることがあるのは驚き。

あと、Collectors.summingIntのような基本型に対するCollectorを使うよりは、intStreamに変換してからsumなど専用メソッドを使うほうが圧倒的に速いことも確認できた。

とりあえず、0から10万件のListを用意。

array = IntStream.range(0, 100_000).boxed().collect(Collectors.toList());

それからベンチマーク用のメソッドを用意。

    public static void bench(String name, Supplier<Integer> proc){
        bench(name, 50_000, proc);
    }
    public static void bench(String name, int count, Supplier<Integer> proc){
        for(int i = 0; i < 100; ++i){
            proc.get();
        }
        long s = System.currentTimeMillis();
        for(int i = 0; i < count; ++i){
            proc.get();
        }
        System.out.printf("%s(%d):%dms%n", name, proc.get(), System.currentTimeMillis() - s);
    }

まずは拡張for

    static Integer forEach(){
        int c = 0;
        for(Integer i : array){
            c += i;
        }
        return c;
    }

これをこんな感じで呼び出す

bench("forEach", LoopBench::forEach);

11124msという結果。

Iteratorを使うループ

    static Integer ite(){
        int c = 0; 
        for(Iterator<Integer> ite = array.iterator(); ite.hasNext();){
            c += ite.next();
        }
        return c;
    }

10610msとなって、拡張forよりちょっと速い。

ついでにインデックスでのアクセス

    static Integer index(){
        int c = 0;
        for(int i = 0; i < array.size(); ++i){
            c += array.get(i);
        }
        return c;
    }

10813ms。拡張forより速く、iteratorより遅い。まあでもだいたい同じ。

Streamを使ってみる。

    static Integer stream(){
        int[] c = {0};
        array.stream().forEach(i -> c[0] += i);
        return c[0];
    }

5715msで、かなり速い。

ただ、これは集計用変数を配列にする必要があったりと、ちょっと美しくない。
こういうのはreduce使う必要があるので、Collectors.summingIntを使ってみる。

    static Integer reduce(){
        return array.stream().collect(Collectors.summingInt(i -> i));
    }

40280msと、想像以上に遅い。parallelStreamにするとコア数分速くなって11305msになるけど、4コア使ってようやく拡張forと同じ処理時間というのは割にあわない。

これを、先にmapToIntでintStreamに変換してからsumすると、5303msとなってものすごく速くなる。

    static Integer intreduce(){
        return array.stream().mapToInt(i -> i).sum();
    }

parallelStreamにすると2011msになって、並列化のうれしさもある。

Streamが入ってから、同じループ処理の書き方の選択肢が増えて、特性もそれぞれ違うので、違いをイメージしつつコードが書けるほうがよさそう。

2023/2/12 グラフ追記

ソースコード全体

public class LoopBench {
    public static void main(String[] args) {
        array = IntStream.range(0, 100_000).boxed().collect(Collectors.toList());
        bench("forEach", LoopBench::forEach);
        bench("ite", LoopBench::ite);
        bench("index", LoopBench::index);
        bench("stream", LoopBench::stream);
        bench("intreduce", LoopBench::intreduce);
        bench("paraintreduce", LoopBench::paraintreduce);
        bench("reduce", LoopBench::reduce);
        bench("para", LoopBench::parareduce);
    }
    
    static List<Integer> array; 
    static Integer forEach(){
        int c = 0;
        for(Integer i : array){
            c += i;
        }
        return c;
    }
    static Integer ite(){
        int c = 0; 
        for(Iterator<Integer> ite = array.iterator(); ite.hasNext();){
            c += ite.next();
        }
        return c;
    }
    static Integer index(){
        int c = 0;
        for(int i = 0; i < array.size(); ++i){
            c += array.get(i);
        }
        return c;
    }
    static Integer stream(){
        int[] c = {0};
        array.stream().forEach(i -> c[0] += i);
        return c[0];
    }
    static Integer reduce(){
        return array.stream().collect(Collectors.summingInt(i -> i));
    }
    static Integer parareduce(){
        return array.parallelStream().collect(Collectors.summingInt(i -> i));
    }
    static Integer intreduce(){
        return array.stream().mapToInt(i -> i).sum();
    }
    static Integer paraintreduce(){
        return array.parallelStream().mapToInt(i -> i).sum();
    }
    public static void bench(String name, Supplier<Integer> proc){
        bench(name, 50_000, proc);
    }
    public static void bench(String name, int count, Supplier<Integer> proc){
        for(int i = 0; i < 100; ++i){
            proc.get();
        }
        long s = System.currentTimeMillis();
        for(int i = 0; i < count; ++i){
            proc.get();
        }
        System.out.printf("%s(%d):%dms%n", name, proc.get(), System.currentTimeMillis() - s);
    }
}
import matplotlib.pyplot as plt
data = [11124, 10610, 10813, 5715, 40280, 11305, 5303, 2011]
label = ["for-each", "iterator", "index", "stream", "collector", "collectorP", "IntStr", "IntStr(p)"]
plt.bar(label, data)