非線形サポートベクターマシン

インドの青鬼

とりあえず最適化の問題は置いておいて、ここを参考に非線形分離できるようにしてみました。
http://www.neuro.sfc.keio.ac.jp/~masato/study/SVM/SVM_3_1.htm

さぁ,これで君も非線形SVMのコーディングができちゃうのだ.素晴らしき哉.

ほんとにできた。
うん、ぼくにも非線形SVMのコーディングができちゃいましたよ!
すばらしきかな


SVMも基本は線形分離なので、非線形分離に対応するにはパーセプトロンでやったようにデータの次元を増やしてそこで線形分離します。
で、SVMがすごいのはそこでの計算をごにょごにょして、データの次元を実際には増やさずに高次元で計算したことにしてしまうのです。
SVMでは全データ同士の内積の計算をしていたのですが、その代わりにカーネル関数と呼ばれる関数を使います。カーネル関数には、高次元で内積を計算したことになるようなものを選びます。そうすると、データを高次元に飛ばすことなく、高次元に持っていったのと同じ結果がえられるのです。これをカーネルトリックといいます。
カーネルトリック


まず、多項式カーネルというカーネル関数を使ってやってみました。こんなカーネルです。
k(x_1, x_2)=(1+x_1^tx_2)^p


次にガウシアンカーネル
k(x_1, x_2)=exp(\frac{-||x_1-x_2||^2}{2\sigma^2})


なんか、もう、まさにサポートベクターマシンという分離面ができました。
2σ^2=1.2だとこんな感じに。

2σ^2=2だとこう。


ソースはこれで。
実行にはid:nowokay:20080327のGraph.javaid:nowokay:20080326のLerningMachine.javaが必要です。

//Nonlinear.java
import java.util.AbstractMap.SimpleEntry;
import java.util.*;

public class NonlinearSVM implements LearningMachine{
    public static void main(String[] args) {
        new Graph("非線形SVM評価"){
            @Override
            public LearningMachine createLearningMachine() {
                return new NonlinearSVM();
            }
        };
    }

    double kernel(double[] x1, double[] x2){
        /*
        //多項式カーネル
        double k = 1;
        for(int i = 0; i < x1.length; ++i){
            k += x1[i] * x2[i];
        }
        double ret = k * k;
         */

        //ガウシアンカーネル
        double n = 0;
        for (int i = 0; i < x1.length; i++) {
            n += (x1[i] - x2[i]) * (x1[i] - x2[i]);
        }
        return Math.exp(-n / 2);//2は 2σ^2
    }
    
    double[] w;
    double b;
    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();
    public void learn(int cls, double[] data) {
        int yi = cls == 1 ? 1 : -1;
        patterns.add(new SimpleEntry<Integer, double[]>(yi, data));
        if(patterns.size() < 10) return;//途中をとばす。
        
        w = new double[patterns.size()];
        double k = .01;//学習係数
        double[] lambda = new double[patterns.size()];
        for(int i = 0; i < lambda.length; ++i){
            lambda[i] = Math.random();
        }
        //未定乗数を求める
        double[][] cache = new double[patterns.size()][patterns.size()];
        for(int lp = 0; lp < 5000000; ++lp){
            //正規化
            double plus = 0;
            double minus = 0;
            for(int i = 0; i < lambda.length; ++i){
                if(lambda[i] < 0) continue;
                if(patterns.get(i).getKey() < 0){
                    minus += lambda[i];
                }else{
                    plus += lambda[i];
                }
            }
            if(minus != 0 && plus != 0){
                double max = (plus + minus) / 2;//Math.max(plus, minus);
                for(int i = 0; i < lambda.length; ++i){
                    if(patterns.get(i).getKey() < 0){
                        lambda[i] *= max / minus;
                    }else{
                        lambda[i] *= max / plus;
                    }
                }
            }
            
            //最急降下法
            for(int i = 0; i < lambda.length; ++i){
                double delta = 1;
                double[] xi = patterns.get(i).getValue();
                for(int j = 0; j < lambda.length; ++j){
                    double[] xj = patterns.get(j).getValue();
                    double r = kernel(xi, xj);
                    
                    delta -= lambda[j] * patterns.get(i).getKey() * patterns.get(j).getKey() * r;
                }
                lambda[i] = lambda[i] + k * delta;//途中経過だけど更新しちゃったほうが収束が早いかな?
            }
            
            //実際には収束判定する
        }
        
        //wの値を求める
        for(int i = 0; i < w.length; ++i){
            w[i] = lambda[i] * patterns.get(i).getKey();
        }
        //bを求める
        //ラムダが0ではない要素
        b = 0;
        for(int i = 0; i < lambda.length; ++i){
            if(lambda[i] > 0){
                b = patterns.get(i).getKey();
                for(int j = 0; j < w.length; ++j){
                    b -= w[j] * cache[i][j];
                }
                break;
            }
        }

        for(int i = 0; i < lambda.length; ++i){
            System.out.printf("%.4f ", lambda[i]);
        }
        System.out.println();
    }

    public int trial(double[] data) {
        double s = 0;
        for(int i = 0; i < w.length; ++i){
            Map.Entry<Integer, double[]> p = patterns.get(i);
            s += w[i] * kernel(data, p.getValue());
        }
        return (s + b) > 0 ? 1 : -1;
    }
}