ところでサポートベクターマシンって何なの?

銀河高原ビールの店

最近、機械学習とか、そのアルゴリズムのひとつであるサポートベクターマシンとかやってるわけですが、そもそも機械学習ってなんなんでしょか?
機械学習ってのは、なんとなく与えられた点の分類から、新たに与えられた点の分類を推測するのですが、ようするに、点が与えられたときにそこから分類の領域を推測しておいて、新たな点がきたときにはどの領域に入るかを判別するのです。


ニューラルネットワークは、名前にニューロンとかついてて、とてもステキな響きがするのですが、あれは関数のあてはめを行っているのです。そうやって関数をあてはめることで、領域の境界面を求めます。
NN法は、学習とかせず、一番近いデータが同じ分類になるはずという戦略でやってます。
サポートベクターマシンも考え方としてはNN法と同じで、新しい点がやってくると、学習したそれぞれの点までの近さを計算して、一番ちかい分類を求めます。そのため、学習データが増えると、学習に時間がかかるようになります。ただ、NN法と違うのは、学習した点のうち、境界面を形作るものだけをサポートベクターとして残して、判別にはそのサポートベクターだけを使うところです。そうすると、学習結果はサポートベクターとそれに対応する係数だけを覚えておけばいいことになります。そのため、実際の判別では、限られた数のサポートベクターとの比較を行えばよくなり、メモリも食わず時間もかからないということになります。


どんな点がサポートベクターになるかを見てみると、こんな感じになりました。大きい点がサポートベクターです。いい感じに、境界付近の点だけがサポートベクターになっています。サポートベクターではない点は、データ判別の段階では捨ててしまってもかまいません。


ところで、SVMではカーネルを使って非線形分離を行うのですが、そこで大切なのがカーネル固有のパラメタです。今回はガウシアンカーネルを使うのですが、ここでもひとつパラメータを与える必要があります。Math.expの引数の分母になってる数ですね。σの二乗を指定することになってます。
このパラメータがなにかを見てみるために、いろいろいじってみました。上の画像のときには1500を指定してるので、σはだいたい40程度ということになります。
これをσ^2=100、つまりσ=10とするとこんな感じになります。


境界面がガタガタになっています。ほとんどの点がサポートベクターになっています。この境界面は、学習データに適合しすぎて、学習データが変われば判別面は大きく変わるし、ちょっとデータがずれると判別結果が変わることになってしまいます。
機械学習で大切なことは、既知の学習データから未知のデータを分類することですが、これでは既知の学習データはうまく分類できるけど、未知のデータがうまく分類できるかどうかは学習データに強く依存することになってしまします。未知のデータをうまく分類できるかどうかのことを汎化能力といいます。
σ=5くらいにしてMath.exp(-n/30)とすると、もっとひどいことに。


学習データはうまく判別できてるのですが、おそらく未知データはまったく判別できない境界面になっています。数個の点を除いて、すべてがサポートベクターになっているという感じです。そうすると、近い学習点があるかどうかで判別をすることになってしまいます。
で、わかったことは、ガウシアンカーネルのσは、学習データの影響範囲で、判別面はその影響範囲をつなぎ合わせたものになるということです。今回のサンプルでは、σは学習データの影響するドット数と言えます。
ガウシアンカーネルのパラメータσは少なすぎるよりは大きすぎる方がよさそうです。


サポートベクターマシンの本はこの本が参考になります。

サポートベクターマシン入門

サポートベクターマシン入門


ということで、今回のプログラム。単独で動きます。

import java.awt.*;
import java.awt.event.*;
import java.awt.image.BufferedImage;
import java.util.*;
import javax.swing.JComponent;

public class SMOFrame extends javax.swing.JFrame {

    /** Creates new form SMOFrame */
    public SMOFrame() {
        initComponents();
        
        Random r = new Random();
        for(int i = 0; i < 350; ++i){
            int x = r.nextInt(400);
            int y = r.nextInt(300);
            patterns.add(new AbstractMap.SimpleEntry(
                    ((x * 4 / 400) % 2 + (y * 3 / 300) % 2 + 1) % 2 * 2 - 1, new Point(x, y)));
        }
        
        paint();
        pnlCanvas.add(canvas);
        canvas.addMouseListener(new MouseAdapter() {
            @Override
            public void mousePressed(MouseEvent e) {
                patterns.add(new AbstractMap.SimpleEntry<Integer, Point>(
                        rbBlue.isSelected() ? 1 : -1,
                        e.getPoint()));
                //SMOFrame.this.paint();
                canvas.repaint();
            }
        });
    }
    BufferedImage img = new BufferedImage(400, 300, BufferedImage.TYPE_INT_RGB);
    JComponent canvas = new JComponent(){
        @Override
        public void paintComponent(Graphics g){
            g.drawImage(img, 0, 0, this);
            for(int i = 0; i < patterns.size(); ++i){
                    Map.Entry<Integer, Point> e = patterns.get(i);
                if(e.getKey() < 0){
                    g.setColor(Color.RED);
                }else{
                    g.setColor(Color.BLUE);
                }
                    int r = (w != null && w.length > i && w[i] != 0) ? 8 : 4;
                g.fillOval(e.getValue().x - r / 2, e.getValue().y - r / 2, r, r);
            }
        }
    };
    boolean learned = false;
    
    void paint(){
        Graphics g = img.getGraphics();
        if(!learned){
            g.setColor(Color.WHITE);
            g.fillRect(0, 0, img.getWidth(), img.getHeight());
        }else{
            for(int x = 0; x < img.getWidth(); ++x){
                for(int y = 0; y < img.getHeight(); ++y){
                    int cls = trial(new Point(x, y));
                    g.setColor(cls == 1 ? new Color(192, 192, 255) : new Color(255, 192, 192));
                    g.fillRect(x, y, 1, 1);
                }
            }
        }
    }
    
    @SuppressWarnings("unchecked")
    private void initComponents() {

        buttonGroup1 = new javax.swing.ButtonGroup();
        pnlCanvas = new javax.swing.JPanel();
        jPanel1 = new javax.swing.JPanel();
        rbRed = new javax.swing.JRadioButton();
        rbBlue = new javax.swing.JRadioButton();
        btnClear = new javax.swing.JButton();
        btnLearn = new javax.swing.JButton();

        setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE);

        pnlCanvas.setMinimumSize(new java.awt.Dimension(400, 300));
        pnlCanvas.setPreferredSize(new java.awt.Dimension(400, 300));
        pnlCanvas.setLayout(new java.awt.BorderLayout());
        getContentPane().add(pnlCanvas, java.awt.BorderLayout.CENTER);

        buttonGroup1.add(rbRed);
        rbRed.setSelected(true);
        rbRed.setText("赤");
        jPanel1.add(rbRed);

        buttonGroup1.add(rbBlue);
        rbBlue.setText("青");
        jPanel1.add(rbBlue);

        btnClear.setText("クリア");
        btnClear.addActionListener(new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                btnClearActionPerformed(evt);
            }
        });
        jPanel1.add(btnClear);

        btnLearn.setText("学習");
        btnLearn.addActionListener(new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                btnLearnActionPerformed(evt);
            }
        });
        jPanel1.add(btnLearn);

        getContentPane().add(jPanel1, java.awt.BorderLayout.PAGE_START);

        pack();
    }

    
private void btnClearActionPerformed(java.awt.event.ActionEvent evt) {
    patterns.clear();
    learned = false;
    paint();
    canvas.repaint();
}

private void btnLearnActionPerformed(java.awt.event.ActionEvent evt) {
    if(patterns.size() == 0) return;
    learn();
    learned = true;
    paint();
    canvas.repaint();    
}

    double kernel(Point x1, Point x2){
        //ガウシアンカーネル
        double n = (x1.x - x2.x) * (x1.x - x2.x) + (x1.y - x2.y) * (x1.y - x2.y);
        return Math.exp(-n / 1500);//1500は分散の2乗
    }
    
    double[] w;//係数
    double b;//バイアス
    final double c = 10;//許容範囲?無限大にするとハードマージンになるはずだけど
    final double tol = 0.7;//KKT条件の許容範囲(1 - ε)
    double[] lambda;
    double z = 0;
    List<Map.Entry<Integer, Point>> patterns = 
            new ArrayList<Map.Entry<Integer, Point>>();
    public void learn() {
        w = new double[patterns.size()];
        b = 0;
        
        lambda = new double[patterns.size()];
        for(int i = 0; i < lambda.length; ++i){
            lambda[i] = 0;
        }
        
        //未定乗数を求める
        boolean alldata = true;//すべてのデータを処理する場合
        boolean changed = false;//変更があった
        eCache = new double[patterns.size()];
        for(int lp = 0; lp < 500000 && (alldata || changed); ++lp)  {
            changed = false;
            z = 0;
            boolean lastchange = true;
            PROC_LOOP:
            for(int j = 0; j < patterns.size(); ++j){
                //基準点2を選ぶ
                double alpha2 = lambda[j];
                if(!alldata && (alpha2 <= 0 || alpha2 >= c)){// 0 < α < C の点だけ処理する
                    continue;
                }
                if(lastchange){
                    //初回やデータがかわったときキャッシュのクリア
                    for(int i = 0; i < eCache.length; ++i) eCache[i] = Double.NaN;
                }
                lastchange = false;
                
                int t2 = patterns.get(j).getKey();
                double fx2 = calcE(j);
                
                //KKT条件の判定
                double r2 = fx2 * t2;
                if(!((alpha2 < c && r2 < -tol) || (alpha2 > 0 && r2 > tol))){//KKT条件をみたすなら処理しない
                    continue;
                }
                //基準点1を選ぶ
                //選択法1
                int i = 0;
                int offset = (int)(Math.random() * patterns.size());
                
                double max = -1;
                for(int ll = 0; ll < patterns.size(); ++ll){//全データにつき
                    int l = (ll + offset) % patterns.size();
                    //0 < α < C
                    if(0 >= lambda[l] || c <= lambda[l]) continue;
                    double dif = Math.abs(calcE(l) - fx2);
                    if(dif > max){
                        max = dif;
                        i = l;
                    }
                }
                if(max >= 0){
                    if(step(i, j)){
                        //処理をしたら次へ
                        changed = true;
                        lastchange = true;
                        continue PROC_LOOP;
                    }
                }
                //選択法2
                offset = (int)(Math.random() * patterns.size());//ランダムな位置から
                for(int l = 0; l < patterns.size(); ++l){
                    //0 < α < C
                    i = (l + offset) % patterns.size();
                    if(0 >= lambda[i] || c <= lambda[i]) continue;
                    if(step(i, j)){
                        //処理をしたら次へ
                        changed = true;
                        lastchange = true;
                        continue PROC_LOOP;
                    }
                }
                //選択法3
                offset = (int)(Math.random() * patterns.size());//ランダムな位置から
                for(int l = 0; l < patterns.size(); ++l){
                    i = (l + offset) % patterns.size();
                    if(step(i, j)){
                        //処理をしたら次へ
                        changed = true;
                        lastchange = true;
                        continue PROC_LOOP;
                    }
                }
            }
            
            ////すべてのデータを処理しても処理するものがなければ終了になる
            if(z < 0.01) changed = false;
            if(alldata){
                alldata = false;
            }else{
                if(changed) alldata = true;
            }
        }

        //wの値を求める
        for(int i = 0; i < w.length; ++i){
            w[i] = lambda[i] * patterns.get(i).getKey();
        }
        //bを求める
        b = 0;
        int count = 0;
        for(int i = 0; i < lambda.length; ++i){
            if(Math.abs(w[i]) <= 0.05) continue;
            for(int l = 0; l < w.length; ++l){
                b -= w[l] * kernel(
                        patterns.get(i).getValue(), patterns.get(l).getValue());
            }
            ++count;
        }
        b /= count;
    }

    public int trial(Point data) {
        double s = b;
        for(int i = 0; i < w.length; ++i){
            Map.Entry<Integer, Point> p = patterns.get(i);
            s += w[i] * kernel(data, p.getValue());
        }
        return s > 0 ? 1 : -1;
    }
    
    private double[] eCache;
    private double calcE(int i){
        if(!Double.isNaN(eCache[i])) return eCache[i];
        double e = b - patterns.get(i).getKey();
        for(int j = 0; j < lambda.length; ++j){
            e += lambda[j] * patterns.get(j).getKey() * 
                    kernel(patterns.get(j).getValue(), patterns.get(i).getValue());
        }        
        eCache[i] = e;
        return e;
    }
    
    /** 実際の計算処理 */
    private boolean step(int i, int j) {
        if(i == j) return false;
        double fx2 = calcE(j);
        
        int t1 = patterns.get(i).getKey();
        int t2 = patterns.get(j).getKey();

        double fx1 = calcE(i);
        
        //基準点2を計算
        double k11 = kernel(patterns.get(i).getValue(), patterns.get(i).getValue());
        double k22 = kernel(patterns.get(j).getValue(), patterns.get(j).getValue());
        double k12 = kernel(patterns.get(i).getValue(), patterns.get(j).getValue());
        double eta = k11 + k22 - 2 * k12;
        if(eta <= 0) return false;
        double newwj = lambda[j] + t2 * (fx1 - fx2) / eta;
        //クリッピング
        double u;
        double v;
        if(t1 == t2){
            u = Math.max(0, lambda[j] + lambda[i] - c);
            v = Math.min(c, lambda[j] + lambda[i]);
        }else{
            u = Math.max(0, lambda[j] - lambda[i]);
            v = Math.min(c, c + lambda[j] - lambda[i]);
        }
        if(u == v) return false;
        newwj = Math.max(u, newwj);
        newwj = Math.min(v, newwj);

        //基準点2から基準点1を計算
        z += Math.abs(lambda[j] - newwj);
        lambda[i] += t1 * t2 * (lambda[j] - newwj);
        lambda[j] = newwj;
        return true;
    }

    /**
    * @param args the command line arguments
    */
    public static void main(String args[]) {
        java.awt.EventQueue.invokeLater(new Runnable() {
            public void run() {
                new SMOFrame().setVisible(true);
            }
        });
    }

    // Variables declaration - do not modify
    private javax.swing.JButton btnClear;
    private javax.swing.JButton btnLearn;
    private javax.swing.ButtonGroup buttonGroup1;
    private javax.swing.JPanel jPanel1;
    private javax.swing.JPanel pnlCanvas;
    private javax.swing.JRadioButton rbBlue;
    private javax.swing.JRadioButton rbRed;
    // End of variables declaration
}