日本語CLIPを使って画像検索を作ったら素晴らしすぎた

LINEヤフーから日本語CLIPが出ていたので、どうやって使うんだろうと試してたら、なんかめちゃくちゃ便利な画像検索ができてしまいました。

clip-japanese-basic

LINEヤフーの日本語CLIP、clip-japanese-baseはこちらで紹介されています。
高性能な日本語マルチモーダル基盤モデル「clip-japanese-base」を公開しました

HuggingFaceのモデルはこちら。
https://huggingface.co/line-corporation/clip-japanese-base

CLIPとは?

ところでCLIPとは、となりますけど、OpenAIが公開してる、言語と画像を扱える機械学習モデルです。Contrastive Language-Image Pre-Trainingの略らしい。
https://openai.com/index/clip/

言葉と画像に対してそれぞれベクトルを返してくれて、内容が近ければ同じような向きになっているという仕組みです。

なので、一つの画像のベクトルをとってきて、複数の言葉のベクトルと比べてみれば、どの言葉が一番近いかというのがわかります。

HuggingFaceのサンプルでは「犬」「猫」「象」と比べてどれに近いかという分類をしてますね。

text = tokenizer(["犬", "猫", "象"])

そして犬の画像を与えてるので、最初の値が1になってる。

[[1., 0., 0.]]

逆に、一つの言葉のベクトルをとってきて、複数の画像のうち近い向きのものを探してくれば画像の検索になるというわけです。

日本語CLIPは、rinna、Stable AI、Recruitからも出ています。
rinna社、日本語に特化した言語画像モデルCLIPを公開|rinna株式会社
最高性能の、日本語画像言語特徴抽出モデル「Japanese Stable CLIP」をリリースしました — Stability AI Japan
Recruit Data Blog | 日本語CLIP 学習済みモデルと評価用データセットの公開

CLIPサーバーをつくる

ということでCLIPを使ってなにかを作りたいのですけど、HuggingFaceのモデルなので基本的にはPythonで動かします。
けれど処理はJavaで書きたいので、Web APIをつくります。
FastAPIというのを使ってWeb APIをつくりました。

画像からベクトルを得るimage_embedというエンドポイントを用意。あとでこれを使って、画像のインデックスを作ります。

@app.post("/image_embed")
def embed_text(request: TextRequest):
  image = Image.open(request.text)
  image_t = processor(image, return_tensors="pt").to(device)
  with torch.no_grad():
    image_features = model.get_image_features(**image_t)
  embedding = image_features.cpu().numpy().tolist()
  return {"embedding": embedding[0]}

それと、言語からベクトルを得るtext_embedというエンドポイントを用意。検索時にこれで検索語句からベクトルをとってきて、画像インデックスの中から近いベクトルをもってるものを探していきます。

@app.post("/text_embed")
def embed_text(request: TextRequest):
  text_t = tokenizer(request.text).to(device)
  with torch.no_grad():
    text_features = model.get_text_features(**text_t)
  embedding = text_features.cpu().numpy().tolist()
  return {"embedding": embedding[0]}

ソースはこれ。
https://gist.github.com/kishida/6c66b3c212f432a19aa176859163e93c#file-clip_server-py

Javaからの呼び出し

PythonのWeb APIJavaから呼び出すコードを書きます。HttpClientを使うのだけど、FastAPIがHTTP2 Upgradeに対応してないようで、HttpClientがUpgrade: h2cというヘッダーをつけないように、HTTP/1.1を指定しておきます。

private static final HttpClient client = HttpClient.newBuilder()
        .version(HttpClient.Version.HTTP_1_1) 
        .build();

というのを前回のブログにまとめてます。
PythonのFastAPIにJavaのHttpClientから接続しようとするとupgradeできないというエラーになるのでHTTP 1.1を指定する - きしだのHatena

あとは、呼び出すのみ。

String json = mapper.writeValueAsString(new TextRequest(text));
var req = HttpRequest.newBuilder()
        .uri(URI.create(BASE_URL + endPoint))
        .header("Content-Type", "application/json")
        .POST(HttpRequest.BodyPublishers.ofString(json))
        .build();
var res = client.send(req, HttpResponse.BodyHandlers.ofString());
var body = mapper.readValue(res.body(), EmbedResponse.class);
return body.embedding();

コードはこれ。
https://gist.github.com/kishida/6c66b3c212f432a19aa176859163e93c#file-clipclient-java

画像からインデックスを作る

さて、JavaからWeb APIを経由してCLIPを呼び出せるようになったので、まずは画像からインデックスを作ります。
MongoDBとかに保存しようかと思ったけど、作ったインデックスを変更することはないので、ざっくりJSONで保存してます。

var images = Files.list(Path.of(PATH))
        .filter(p -> !Files.isDirectory(p))
        .map(p -> new ImageData(p, ClipClient.imageEmbedding(p.toString())))
        .toArray(ImageData[]::new);
ObjectMapper mapper = new ObjectMapper();
mapper.writeValue(Files.newOutputStream(Path.of("index.json")), images);

ちゃんとやるときはFiles.walkなどを使いましょう。
あと、CLIPの返すベクトルが単位ベクトルではないようなので、ちゃんと長さ1になるよう正規化したほうがいいです。

ソースこれ。
https://gist.github.com/kishida/6c66b3c212f432a19aa176859163e93c#file-createindex-java

検索ワードと比較する

インデックスができたら、検索です。
検索語句のベクトルと角度が近いベクトルの画像を探すことになります。

大量の画像に対してまじめにやるなら近似最近傍探索(ANN)を使うほうがいいのだけど、個人の画像で数万件程度だと素朴な処理で十分です。

まずは角度をとるので内積。これVector APISIMD化すると速そうだけど、要素数が少ないので、あまり効きませんでした。

    static float prod(float[] a, float[] b) {
        float score = 0;
        for(int i = 0; i < a.length; ++i) {
            score += a[i] * b[i];
        }
    }

内積ではコサインがとれるので、大きい方が角度が浅く2つのベクトルが近いということになります。あとは大きいものを5つくらいとってくるようにすればいいだけ。

    Result[] search(String text) {
        var emb = ClipClient.textEmbedding(text);

        var top5 = new Result[topCount + 1];
        for (var img : images) {
            var score = prod(emb, img.embedding());
            for (int i = top5.length - 2; i >= 0; --i) {
                if (top5[i] == null) {
                    if (i == 0 || top5[i - 1] != null) {
                        top5[i] = new Result(img, score);
                    }
                    continue;
                }
                if (top5[i].score() < score) {
                    top5[i + 1] = top5[i];
                    top5[i] = new Result(img, score);
                }
            }
        }
        return Arrays.stream(top5).limit(top5.length - 1).toArray(Result[]::new);
    }

件数が少ないので、マルチスレッドで並列化しても効果なかった。

ソースこれ。
https://gist.github.com/kishida/6c66b3c212f432a19aa176859163e93c#file-sercher-java

UIをつくる

じゃあ処理が全部できんたのでUIつくりましょう。ふつうのSwingのアプリです。 結果の表示をJTextPaneにHTMLで表示とか試してたのだけど、結局自分で描画しました。
ただ、そのときImageIO.readJPEGのOrientationを見てくれないので、metadata-extractorでOrientaionをみて回転するコードが必要です。

        int ori = 1;
        try {
            var metadata = ImageMetadataReader.readMetadata(Files.newInputStream(path));
            var dir = metadata.getFirstDirectoryOfType(ExifIFD0Directory.class);
            ori = dir.getInt(ExifIFD0Directory.TAG_ORIENTATION);
        } catch(Exception ex) {}
        var trans = new AffineTransform();
        switch (ori) {
            case 6 -> {//右
                trans.translate(scaledH, 0);
                trans.rotate(Math.toRadians(90));
            }
            case 3 -> {//逆
                trans.translate(scaledW, scaledH);
                trans.rotate(Math.toRadians(180));
            }
            case 8 -> {//左
                trans.translate(0, scaledW);
                trans.rotate(Math.toRadians(270));
            }
        }
        var rotated = new BufferedImage(scaledW, scaledH, img.getType());
        var op = new AffineTransformOp(trans, AffineTransformOp.TYPE_BICUBIC);
        op.filter(img, rotated);

https://gist.github.com/kishida/6c66b3c212f432a19aa176859163e93c#file-drawui-java

結果

結構いい感じです。
「戦う猫」でケンカ中の猫の写真がでてることから、ちゃんとシチュエーションの認識もできていますね。
また、「オブジェクト指向の本」でオブジェクト指向の本や「Object Oriented」と書いてあるプレゼンテーションが出ていることから、文字が読めていることもわかります。

あと、「寝てる猫」で出てる写真は結構気に入ってたけど探そうと思ってもなかなかみつからないやつで、こういったなつかしい写真がいろいろ掘り起こされるのが面白いです。

ということで、割と簡単なプログラムでめちゃくちゃ便利な検索ができるので、みんな作りましょう。
ソース、すでに出してますけど、あらためて。
https://gist.github.com/kishida/6c66b3c212f432a19aa176859163e93c