パーセプトロンで非線形分離するには

パーセプトロンでは線形分離できない場合に対応できないということでしたが、これを可能にする方法を[id:nowokay:20080318]でしましまさんに教えてもらっていました。
実装してみるとこんな感じで、非線形分離できるようになりました。


x^2、xy、y^2を新たにデータとして追加してるわけですが、要するにこれは2次元データを5次元データに拡張してることになります。2次元じゃ線形分離は無理だけど5次元だと線形分離できますよ、という話。

//NonlinearPerceptron.java
import java.util.*;

public class NonlinearPerceptron implements LearningMachine{
    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();
    double b;
    double[] p;
    int dim;
    public NonlinearPerceptron (int dim) {
        this.dim = dim;
    }

    public static void main(String[] args) {
        new Graph("非線形パーセプトロン評価"){
            @Override
            public LearningMachine createLearningMachine() {
                return new Perceptron(2);
            }
        };
    }

    public void learn(int cls, double[] data) {
        int yi = cls == 1 ? 1 : -1;

        patterns.add(new AbstractMap.SimpleEntry(yi, kernel(data)));
        
        final double k = .01;
        b = 0;
        p = new double[dim + 3];
        
        for(int j = 0; j < 1000; ++j){
            //学習を繰り返す
            boolean fin = true;
            for(Map.Entry<Integer, double[]> entry : patterns){
                double[] pattern = entry.getValue();
                int pcls = entry.getKey();
                double in = 0;
                for(int i = 0; i < pattern.length; ++i){
                    in += pattern[i] * p[i];
                }
                in += b;
                if(in * pcls <= 0){
                    //誤判定で再学習
                    fin = false;
                    for(int i = 0; i < p.length; ++i){
                        p[i] += pattern[i] * k * pcls;
                    }
                    b += k * pcls;
                }
            }
            if(fin) break;
        }
        
    }
    public int trial(double[] data) {
        double in = 0;
        double[] d = kernel(data);
        for(int i = 0; i < d.length; ++i){
            in += d[i] * p[i];
        }
        in += b;
        return (in > 0) ? 1 : -1;
    }

    private double[] kernel(double[] data){
        double[] d = new double[data.length + 3];
        for(int i = 0; i < data.length; ++i){
            d[i] = data[i];
        }       
        d[d.length - 3] = d[1] * d[1];
        d[d.length - 2] = d[0] * d[0];
        d[d.length - 1] = d[1] * d[0];
        return d;
    }
}