Stability AIのチャットスクリプトを利用してRinnaのチャットモデルとお話する(追記あり)

Rinna社がチャットにも対応した日本語言語モデルをリリースしてました。
Rinnaの新しい3Bモデルを試してみる - きしだのHatena

そうするとちゃんとチャットとしてやりとりしたいですね。

ところで、Stable DiffusionのStability AIが言語モデルStableLMをリリースしています。
Stability AI 言語モデル「StableLM Suite」の第一弾をリリース - (英語Stability AI

で、チャットモデルもあって、Gradioを使ったWeb UIを公開しています。
Stablelm Tuned Alpha Chat - a Hugging Face Space by stabilityai

スクリプトはここのapp.pyです。
https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat/tree/main

ローカルで動かすとこんな感じ。賢い。VRAMは14.5GBくらい使います。

このスクリプトで読み込んでるモデルをRinnaモデルにします。

model_name = "rinna/japanese-gpt-neox-3.6b-instruction-sft"
m = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16).cuda()
tok = AutoTokenizer.from_pretrained(model_name, use_fast = False)

あと、メッセージ生成部分をRinnaモデルにあわせて変更します。

    curr_system_message = ""
    messages = curr_system_message + \
        "<NL>".join(["<NL>".join(["ユーザー: "+item[0], "システム: "+item[1]])
                for item in history])

しかし、こんな感じ。会話にならぬ・・・

返答は速いです。

※ 追記 14:54
どうやら入力が渡っていなかった模様。

tokenizerのencodeを呼び出すようにして。

    model_inputs = tok.encode(messages, return_tensors="pt", add_special_tokens=False).to("cuda")

dictでは渡さず

    generate_kwargs = dict(
#        model_inputs,

Thread呼び出しでgenerateに渡すようにすると会話できました。

    t = Thread(target=m.generate, args=[model_inputs], kwargs=generate_kwargs)

会話のコンテキストも引き継いでいますね。

コードはこちら。
https://gist.github.com/kishida/bca7a8c55ec64ee43d9caf12e4a1cf08