Rinna社がチャットにも対応した日本語言語モデルをリリースしてました。
Rinnaの新しい3Bモデルを試してみる - きしだのHatena
そうするとちゃんとチャットとしてやりとりしたいですね。
ところで、Stable DiffusionのStability AIが言語モデルStableLMをリリースしています。
Stability AI 言語モデル「StableLM Suite」の第一弾をリリース - (英語Stability AI
で、チャットモデルもあって、Gradioを使ったWeb UIを公開しています。
Stablelm Tuned Alpha Chat - a Hugging Face Space by stabilityai
スクリプトはここのapp.pyです。
https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat/tree/main
ローカルで動かすとこんな感じ。賢い。VRAMは14.5GBくらい使います。
Stability AIのチャット、割と賢いな。そして反応も速い。ただし、チャットを6-7往復くらい続けると16GB VRAMではOut Of Memoryになる。 pic.twitter.com/2RtRLZOQsd
— きしだൠ(K1S) (@kis) 2023年5月22日
このスクリプトで読み込んでるモデルをRinnaモデルにします。
model_name = "rinna/japanese-gpt-neox-3.6b-instruction-sft" m = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16).cuda() tok = AutoTokenizer.from_pretrained(model_name, use_fast = False)
あと、メッセージ生成部分をRinnaモデルにあわせて変更します。
curr_system_message = "" messages = curr_system_message + \ "<NL>".join(["<NL>".join(["ユーザー: "+item[0], "システム: "+item[1]]) for item in history])
しかし、こんな感じ。会話にならぬ・・・
返答は速いです。
Stability AIのチャットを改造してRinnaモデルとチャットできるようにしたけど、会話にならぬ・・・ pic.twitter.com/SryTENSVWm
— きしだൠ(K1S) (@kis) 2023年5月22日
※ 追記 14:54
どうやら入力が渡っていなかった模様。
tokenizerのencodeを呼び出すようにして。
model_inputs = tok.encode(messages, return_tensors="pt", add_special_tokens=False).to("cuda")
dictでは渡さず
generate_kwargs = dict( # model_inputs,
Thread呼び出しでgenerateに渡すようにすると会話できました。
t = Thread(target=m.generate, args=[model_inputs], kwargs=generate_kwargs)
会話のコンテキストも引き継いでいますね。
Rinnaチャットとお話できるようになった!
— きしだൠ(K1S) (@kis) 2023年5月22日
割とちゃんとしている! pic.twitter.com/yRv8tTyXqX
コードはこちら。
https://gist.github.com/kishida/bca7a8c55ec64ee43d9caf12e4a1cf08