ナイーブベイズ分類器であいさつbot作ってみた

スパムフィルタでよく使われてる、ベイジアンフィルタのことです。これを使って「おはよう」判定してみました。
1000speakersのときに「ベイズやらないんですか?」って言われたり、tockriの人が以前「やっぱベイズよくできとるわ」とか書いてたり、テキスト分類ならベイズでいいん違うかと思ってみたので、やってみました。


結果から言うと、なんか不安定というか、1文字の違いで「おはよう」判定してしまうことがあります。ちゃんと性能測ってみないとわかんないですけど。
考えられるのは、データが少ないことですが(「おはよう」が750件中39件)、それよりも、ついったの発言では「おはよう」発言の中に「おはよう」と関係ない言葉が入ることが多くて、それが悪影響あたえてるんじゃないかと思います。ナイーブベイズアルゴリズムを考えると、データ中のノイズに弱い気がします。
「。」や「!」で文章を区切って学習・判定するといいかもしれません。これだとSVMの場合でも性能でそう。


比べてみて思ったのは、ベイズは実装が簡単だし大量データに強いということです。学習はデータ数をカウントするだけだし、判定は割り算して全部掛けるだけ。
決定論的に計算するので、ゆらぎがないし、ノイズがなければ強そう。
一方で、やっぱSVMは大量データには学習時間かかりまくるのだけど、データが少なくてもうまいこと特徴を抜き出してくれるな、と思いました。


ということで、ナイーブベイズ版のソース
こんな感じで「おはよう」の先頭に「m,」をつけて区別したデータを使います。このファイルをtwit.txtという名前でプログラムと同じフォルダに保存しておきます。

バスで駅に向かっている。
ただいま。
m,おはよ-
早く来い飯。
とりあえずシャワ-浴びるか。
だいじなメ-ルおくった。
m,うぐう、起きた。10時半から会社。
m,おはようございます。本を読んでたら寝てしまってました。


あとは、ソースはこんな感じ。GUIまわりは、NetBeansで生成したソースからコメントを抜いただけです。

import java.io.*;
import java.util.*;

public class BayesianMorningBot extends javax.swing.JFrame {

    public BayesianMorningBot() {
        initComponents();
        try{
            InputStream is = BayesianMorningBot.class.getResourceAsStream(
                    "twit.txt");
            InputStreamReader fr = new InputStreamReader(is);
            BufferedReader bur = new BufferedReader(fr);
            for(String line; (line = bur.readLine()) != null; ){
                int cls = -1;
                if(line.startsWith("m,")){
                    line = line.substring(2);
                    cls = 1;
                }
                learn(line, cls == 1);
            }
            bur.close();
            fr.close();
            is.close();
        }catch(IOException e){
        }
        System.out.printf("%d - %d%n", morningcount, totalcount - morningcount);
    }

    int morningcount = 0;
    int totalcount = 0;
    Map<String, int[]> bicount = new HashMap<String, int[]>();
    
    String normal(String str){
        str = "。" + str.trim();
        if(!str.endsWith("。")){
            str = str + "。";
        }
        return str;
    }
    
    void learn(String str, boolean morning){
        str = normal(str);
        
        Set<String> appear = new HashSet<String>();
        int length = str.length() - 1;
        for(int i = 0; i < length; ++i){
            String bi = str.substring(i, i + 2);
            if(appear.contains(bi)) continue;
            appear.add(bi);
            if(!bicount.containsKey(bi)){
                bicount.put(bi, new int[2]);
            }
            bicount.get(bi)[morning ? 0 : 1]++;
        }
        if(morning) morningcount++;
        totalcount++;
    }
    
    double bias = 0.2;
    public int trial(String d) {
        String str = normal(d);
        
        double morningProb = 1;
        double normalProb = 1;
        Set<String> appear = new HashSet<String>();
        int length = str.length() - 1;
        for(int i = 0; i < length; ++i){
            String bi = str.substring(i, i + 2);
            if(appear.contains(bi)) continue;
            appear.add(bi);
            double p1 = 0;
            double p2 = 0;
            if(bicount.containsKey(bi)){
                p1 = bicount.get(bi)[0];
                p2 = bicount.get(bi)[1];
            }
            p1 = (p1 + 0.5 * bias) / (morningcount + bias);
            p2 = (p2 + 0.5 * bias) / ((totalcount - morningcount) + bias);
            morningProb *= p1;
            normalProb *= p2;
        }        
        morningProb *= morningcount;
        normalProb *= (totalcount - morningcount);
        return morningProb > normalProb ? 1 : -1;
    }
    
    // <editor-fold defaultstate="collapsed" desc="Generated Code">
    private void initComponents() {

        javax.swing.JScrollPane jScrollPane1 = new javax.swing.JScrollPane();
        taOutput = new javax.swing.JTextArea();
        javax.swing.JPanel panel = new javax.swing.JPanel();
        txtMessage = new javax.swing.JTextField();
        btnSpeak = new javax.swing.JButton();

        setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE);

        taOutput.setRows(10);
        jScrollPane1.setViewportView(taOutput);

        getContentPane().add(jScrollPane1, java.awt.BorderLayout.CENTER);

        txtMessage.setColumns(15);
        panel.add(txtMessage);

        btnSpeak.setText("発言");
        btnSpeak.addActionListener(new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                btnSpeakActionPerformed(evt);
            }
        });
        panel.add(btnSpeak);

        getContentPane().add(panel, java.awt.BorderLayout.PAGE_START);

        pack();
    }// </editor-fold>

    private void btnSpeakActionPerformed(java.awt.event.ActionEvent evt){
        taOutput.append(txtMessage.getText() + "\n");
        int cl = trial(txtMessage.getText());
        taOutput.append("> " + (cl > 0 ? "おはよ〜" : "ふむふむ" )+ "\n");
    }

    public static void main(String args[]) {
        java.awt.EventQueue.invokeLater(new Runnable() {
            public void run() {
                new BayesianMorningBot().setVisible(true);
            }
        });
    }

    private javax.swing.JButton btnSpeak;
    private javax.swing.JTextArea taOutput;
    private javax.swing.JTextField txtMessage;

}