線形サポートベクターマシン失敗

コエド

とりあえずの目標はサポートベクターマシンてことで、次のサイト参考に組んでみました。*1
http://www.neuro.sfc.keio.ac.jp/~masato/study/SVM/index.htm


なんか、なぜかちゃんと角度的にはいい感じになってるけど、ぜんぜんマージン最大化していません。見てみたら、λは0以下にならないはずなのに、-4000とかになってるし。


で、よく考えたら、どうにか無理やり
\sum\limit_i \lambda_iy_i
の条件にあてはめてやってたけどλ≧0っていう条件を考えていないし、といろいろ試行錯誤して、どうも単純な最急降下法でやるのがまずいんじゃないかなと思ったのです。
上記のサイトでは「それは考えてみてよ.」と軽く書いてあったけど、これって、研究分野になるじゃないの?
結局、そういうのをどうにかするために、SMO法とかSVMLightとかいう分割法があって、それを使わないといけないということまでわかりました。
SMO法はサポートベクターマシン入門に載ってて & いまのところ実装方法が載ってる唯一の日本語の本なのだけど、福岡においてきてしまった*2。なので福岡に戻るまでおあづけかなぁ。もう一冊買っちゃおうかな。


ただ、上記のサイトは2003年と古くて、そのころはサポートベクターマシンに関する本はまったくでてないし、恐らくWebでも情報がなかったと思うので、しかたないと思われます。
bの求め方も上記サイトの式ではうまくいかなかったので、次のようにしてます。
b=\frac{1}{2}(w^tx_1+w^tx_{-1})


とりあえず、ちゃんと動いてないソース。
実行にはid:nowokay:20080327のGraph.javaid:nowokay:20080326のLerningMachine.javaが必要です。

import java.util.AbstractMap.SimpleEntry;
import java.util.*;

public class LinerSVM implements LearningMachine{
    public static void main(String[] args) {
        new Graph("線形SVM評価"){
            @Override
            public LearningMachine createLearningMachine() {
                return new LinerSVM();
            }
        };
    }
    
    double[] w;
    double b;
    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();
    public void learn(int cls, double[] data) {
        int yi = cls == 1 ? 1 : -1;
        w = new double[data.length];
        b = 0;
        patterns.add(new SimpleEntry<Integer, double[]>(yi, data));
        if(patterns.size() < 10) return;
        
        double k = .01;//学習定数
        double[] lambda = new double[patterns.size()];
        for(int i = 0; i < lambda.length; ++i){
            lambda[i] = Math.random();
        }
        //未定乗数を求める
        for(int lp = 0; lp < 50000; ++lp){
            //正規化
            double plus = 0;
            double minus = 0;
            for(int i = 0; i < lambda.length; ++i){
                //if(lambda[i] < 0) lambda[i] = 0;
                if(patterns.get(i).getKey() < 0){
                    minus += lambda[i];
                }else{
                    plus += lambda[i];
                }
            }
            if(minus != 0 && plus != 0){
                double max = (plus + minus) / 2;//Math.max(plus, minus);
                for(int i = 0; i < lambda.length; ++i){
                    if(patterns.get(i).getKey() < 0){
                        lambda[i] *= max / minus;
                    }else{
                        lambda[i] *= max / plus;
                    }
                }
            }
            
            //最急降下法
            for(int i = 0; i < lambda.length; ++i){
                double delta = 1;
                double[] xi = patterns.get(i).getValue();
                for(int j = 0; j < lambda.length; ++j){
                    double r = 0;
                    double[] xj = patterns.get(j).getValue();
                    for(int n = 0; n < xi.length; ++n){
                        r += xj[n] * xi[n];
                    }
                    delta -= lambda[j] * patterns.get(i).getKey() * patterns.get(j).getKey() * r;
                }
                lambda[i] = lambda[i] + k * delta;//途中経過だけど更新しちゃったほうが収束が早いかな?
            }
            
            //実際には収束判定する
        }
        
        //wの値を求める
        for(int d = 0; d < w.length; ++d){
            w[d] = 0;
            for(int i = 0; i < lambda.length; ++i){
                w[d] += lambda[i] * patterns.get(i).getKey() * patterns.get(i).getValue()[d];
            }
        }
        //bを求める
        //ラムダが0ではない要素
        b = 0;
        for(int i = 0; i < lambda.length; ++i){
            if(lambda[i] > 0 && patterns.get(i).getKey() < 0){
                //b = patterns.get(i).getKey();
                for(int j = 0; j < w.length; ++j){
                    b -= w[j] * patterns.get(i).getValue()[j];
                }
                break;
            }
        }
        for(int i = 0; i < lambda.length; ++i){
            if(lambda[i] > 0 && patterns.get(i).getKey() > 0){
                //b = patterns.get(i).getKey();
                for(int j = 0; j < w.length; ++j){
                    b -= w[j] * patterns.get(i).getValue()[j];
                }
                break;
            }
        }
        b /= 2;
        for(int i = 0; i < lambda.length; ++i){
            System.out.printf("%.4f ", lambda[i]);
        }
        System.out.println();
    }

    public int trial(double[] data) {
        double s = 0;
        for(int i = 0; i < data.length; ++i){
            s += data[i] * w[i];
        }
        return (s + b) > 0 ? 1 : -1;
    }

}

*1:サポートベクターマシンについては、はてなキーワード作って書いておきました

*2:今、東京にいる