GPTのEmbeddingを利用してブログの投稿に対する近いものを探し出す

OpenAIでGPTを使ったAPIにembeddingというのがあって、これを使うと文章同士の距離がとれるので、近いエントリを取得したり文章から検索したりができるということで、試してみました。
思いのほかちゃんと動きました。おそらく、GPTで一番実用的なんじゃないでしょうか。

embeddingとは

なんか、文章の特徴を表す多次元のベクトルに変換してくれるらしい。
ようわからん。
OpenAIでは1500次元くらいのベクトルに変換します。 そして、このベクトルの距離が近ければ文章の内容も近いやろということで、似たエントリの抽出などができます。 しかし、テキストが要素数1500のdouble配列になるので、95KBくらい。要するに、95KBあってうまくやればどんな文章でも特徴を表せるよねって感じですね。ということで、だいたいのWeb文章よりはデータが大きくなります。
Azure OpenAI APIの解説がわかりやすいようなようわからんような。
Azure OpenAI Service の埋め込み - Azure OpenAI - embeddings and cosine similarity | Microsoft Learn

はてなブログからエントリのデータをとってくる

はてなブログではMovableType形式でブログをエクスポートできるので、ダウンロードします。
めちゃわかりにくいところにあって、設定 > 詳細の「読者になるボタン」のちょっと上にリンクがあります。

こんな感じになっています。これを読み込んでOpenAIに投げてインデックスを作っていく。

AUTHOR: nowokay
TITLE: ChatGPTは真にプログラミング知識なしでのコンピュータ操作を実現している
BASENAME: 2023/02/27/174524
STATUS: Publish
ALLOW COMMENTS: 1
CONVERT BREAKS: 0
DATE: 02/27/2023 17:45:24
CATEGORY: ChatGPT
CATEGORY: AI
IMAGE: https://cdn-ak.f.st-hatena.com/images/fotolife/n/nowokay/20230227/20230227171005.png
-----
BODY:
<p>ChatGPTで文章を要約したり口調を変えたりゲームのルールを教えてゲームを遊んだり、みんな いろいろな使い方や楽しみ方をしていると思います。<br/>
中にはプログラミングにあまり縁のない人も多くいます。<br/>

インデックスを作る

nowokay.hatenablog.com.export.txtというファイル名になってるので、それを読み込んで解析して、今回はMongoDBにblog_dbというDBでentriesというcollectionにデータをつっこんでいく作戦。

var path = Path.of("nowokay.hatenablog.com.export.txt");
try (var bur = Files.newBufferedReader(path);
     var client = MongoClients.create("mongodb://localhost:27017"))
{
    var db = client.getDatabase("blog_db");
    var coll = db.getCollection("entries", BlogEntry.class);

    var service = new OpenAiService(getToken(), Duration.ZERO);

データはこんな感じのrecordに。

public record BlogEntry (
        String title, String baseName, String image, String date, boolean published,
        String body, String stripedBody, List<Double> vector) { }

パースの状態遷移はこんな感じ。

enum Part {HEADER, CONTENT, BODY, COMMENT}

recordはイミュータブルでパースしつつ値をつっこむということができないので、ヘッダーを一時的に格納するクラス。

static class Header{ String baseName; String image; String title; 
        String date; boolean published;}

こんな感じでヘッダーをつくっていきます。

switch (p) {
    case HEADER -> {
        if (line.startsWith("BASENAME")) {
            System.out.println("bn:" + (h.baseName = line.substring("BASENAME: ".length())));
        } else if(line.startsWith("IMAGE")) {
            System.out.println("img:" + (h.image = line.substring("IMAGE: ".length())));
        } else if(line.startsWith("TITLE")) {
            System.out.println("title:" + (h.title = line.substring("TITLE: ".length())));
        } else if(line.startsWith("DATE")) {
            System.out.println("date:" + (h.date = line.substring("DATE: ".length())));
        } else if(line.equals("STATUS: Publish")) {
            h.published = true;
        } else if (line.equals("-----")) {
            p = Part.CONTENT;
        }
    }
    case CONTENT -> {

さて1エントリ読み込めました、ってなったらOpenAIを呼び出して文章に対応するベクトルをとってきます。 リクエストはこんな感じ。モデルにtext-embedding-ada-002を指定しています。adaは一番軽く安いモデルという位置づけだけど、だいたいどんな用途でもこれで十分って書いてあったのでadaを使います。

var req = EmbeddingRequest.builder()
        .user("dummy")
        .model("text-embedding-ada-002")
        .input(List.of(text.substring(0, Math.min(text.length(), 4000)))).build();

で、呼び出すんだけど、頻繁に「The server is currently overloaded with other requests」といってコケるので、リトライを仕込んでおきます。失敗したら1分待つ。

EmbeddingResult res = null;
for (int i = 0; i < 5; ++i) {
    try {
        res = service.createEmbeddings(req);
    } catch (OpenAiHttpException ex) {
        System.out.println(ex.getMessage());
        Thread.sleep(Duration.ofMinutes(1));
        continue;
    }
    break;
}
if (res == null) {
    System.out.println("retry 5 times but could not access");
    return;
}

無事にベクトルが取れたらMongoDBにつっこみます。

BlogEntry ent = new BlogEntry(
        h.title, h.baseName, h.image, h.date, h.published,
        body.toString(), text, res.getData().get(0).getEmbedding());
coll.insertOne(ent);

あと、無課金勢は1分に20リクエストまでという制限があるので、3秒待ちます。ここでは100ミリ秒ほど余裕もたせてます。

Thread.sleep(Duration.ofSeconds(3).plusMillis(100)); // 20 request per min for the rate limit

コード全体は本文最後に貼っておくけど、gistはこれ。
https://gist.githubusercontent.com/kishida/0ac9f96cbf9f4d4f91906f74205472c8/raw/ea63107a22444764e624cf6849111d25b9193d5b/HatenaReader.java

実行して一晩寝ておくと こんな感じのデータができます。ちなみにこれはBudibaseというローコードツール。

Budibaseでデータメンテできるようにしようかと思ったのだけど、MongoDBだとページングがめんどいので、データ確認だけにしてます。Budibaseを実際使うときはPostgreSQLにしたほうがよさそう。

ところで気になるEmbedding APIの課金ですが、2500エントリを処理して$0.5でした。登録時にもらえる$18のクーポン使いきる気配がありません。

検索サーバーをつくる

まず↑のプログラムで作ったデータをMongoDBから全件とってきてます。

try (var client = MongoClients.create("mongodb://localhost:27017")) {
    var db = client.getDatabase("blog_db");
    var entColl = db.getCollection("entries", BlogEntry.class);
    var keyColl = db.getCollection("keywords", Keyword.class);
    
    List<BlogEntry> entries = StreamSupport.stream(entColl.find().spliterator(), false).toList();
    var service = new OpenAiService(System.getenv("OPENAI_TOKEN"), Duration.ZERO);

あと、今回はサーバーサイドなので、フレームワーク使うとか考えたのだけど、フレームワークようわからんのでソケットでWebサーバー作っておきます。

ServerSocket serverSoc = new ServerSocket(8989);
for (;;) {
    try (Socket s = serverSoc.accept();
         InputStream is = s.getInputStream();
         BufferedReader bur = new BufferedReader(new InputStreamReader(is));
         OutputStream os = s.getOutputStream();
         PrintWriter pw = new PrintWriter(os))
    {
        String firstLine = bur.readLine();
        String query = firstLine == null ? "" : firstLine.split(" ")[1].substring(1);
        bur.lines().takeWhile(Predicate.not(String::isEmpty)).count();
        pw.println("HTTP/1.0 200 OK");
        pw.println("Content-Type: text/html; charset=utf-8");
        pw.println();

近いエントリを見つけるのは、今回はデータが2500件程度なので、素朴に全件のベクトルの距離を計算して近い順に3件とってきてます。

private static void printRelated(PrintWriter pw, List<BlogEntry> entries, List<Double> vector) {
    record Score(double score, BlogEntry entry) {}
    TreeSet<Score> ts = new TreeSet<>(
            (s1, s2) -> Double.compare(s1.score(), s2.score()));
    for (var e : entries) {
        if (!e.published() || e.stripedBody().length() < 50) {
            continue;
        }
        double score = 0;
        for (int i = 0; i < vector.size(); ++i) {
            score += vector.get(i) * e.vector().get(i);
        }
        ts.add(new Score(-score, e));
        while (ts.size() > 4) ts.remove(ts.last());
    }
    ts.stream().skip(1).forEach(sc -> { // 最初の一件は同じエントリ
        printEntry(pw, sc.entry());
        pw.println("score: %f".formatted(sc.score()));
    });
}

といいつつ、上記のコードの距離の計算部分、こんな感じにしています。

double score = 0;
for (int i = 0; i < vector.size(); ++i) {
    score += vector.get(i) * e.vector().get(i);
}

embeddingで得られたベクトルは単位ベクトルのはずなので、距離が近いものと角度が小さいものが一致します。そうすると距離ではなくて内積が使えるので、それぞれの要素を互いに掛けたものを足せば、角度のcosがとれるという作戦。引き算が2回減るだけだから速度的にはそんなに変わらないと思うけど、コードがすっきりします。

次のようなコードを書いても、展開すると上記のscoreに2かけて2から引いた値になるはずなので、そういう方向でも上記式が求めれるはず。

score += (vector.get(i) - e.vector().get(i)) *
         (vector.get(i) * e.vector().get(i));

あと、検索ワードが指定されていたら、そこからembeddingでベクトルをとってきて比較をしています。

var q = URLDecoder.decode(query.substring(query.indexOf('?') + 3), "utf-8");

pw.println(header.formatted("", q));
var word = keyColl.find(Filters.eq("word", q)).first();
List<Double> vec;
if (word == null) {
    var req = EmbeddingRequest.builder()
            .user("dummy")
            .model("text-embedding-ada-002")
            .input(List.of(limitText(q, 4000))).build();
    EmbeddingResult res = service.createEmbeddings(req);
    vec = res.getData().get(0).getEmbedding();
    keyColl.insertOne(new Keyword(q, vec));
} else {
    vec = word.vector();
}
printRelated(pw, entries, vec);

今回は動作テストで同じキーワードを指定することが多いので、MongoDBにキャッシュするようにしています。
Azure OpenAI APIのほうだと本文用と検索キーワード用にモデルがわかれているけど、OpenAI本家ではtext-embedding-ada-002で本文と検索キーワード両方に対応できるみたいですね。

※追記 2023/3/10 Maximum Inner Product Search(MIPS)というらしい。

ソース全体はここ
https://gist.githubusercontent.com/kishida/0ac9f96cbf9f4d4f91906f74205472c8/raw/ea63107a22444764e624cf6849111d25b9193d5b/RelatedBlog.java

で、動かしたらこう。

「近いエントリを探す」をクリックすると、近いエントリが3件表示されます。ChatGPTに関するものが得れてます。

あまりに自然にそれっぽいものが取れてるので「普通の結果やん」って思ってしまうけど、ベクトル計算しただけで似たエントリが取れるというのは非常に面白いです。

あと、「ビールを飲みたい」で検索するとビールの話をしてるエントリがひっかかっています。

ということで、いい感じに検索やレコメンドができました。
やろうと思えばブログの分類などもできますね。

全体ソース

gistはこちら。
https://gist.github.com/kishida/0ac9f96cbf9f4d4f91906f74205472c8