3-NN法

NN法だと識別の境界ががたがたすぎるので、3-NN法っていうのを使ってみます。
3-NN法は、近いほうから3つの学習パターンをとってきて、そのうちの多数決で識別する方法です。そういう意味では、NN法ってのは1-NN法になります。
そうすると、識別の境界がちょっと滑らかになります。


けども、この結果をみると、学習パターンでも誤判別されているものがあります。これは、今回単純にデータの多数決を取ってるためで、データの距離によって重みをつけたりして改善させる必要があります。

import java.util.*;

public class NN3 implements LearningMachine{

    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();

    public static void main(String[] args) {
        new Graph("3-NN法評価"){
            @Override
            public LearningMachine createLearningMachine() {
                return new NN3();
            }
        };
    }

    public void learn(int cls, double[] data) {
        patterns.add(new AbstractMap.SimpleEntry(cls, data));
    }
    public int trial(double[] data) {
        //パターンを近い順に求める
        Map<Double, Map.Entry<Integer, double[]>> sorting =
                new TreeMap<Double, Map.Entry<Integer, double[]>>();
        for (Map.Entry<Integer, double[]> entry : patterns) {
            double[] ss = entry.getValue();
            if (ss.length != data.length) {
                System.out.println("へんなデータ");
                continue;
            }
            //データ間の距離を求める
            double dist = 0;
            for (int i = 0; i < ss.length; ++i) {
                dist += (ss[i] - data[i]) * (ss[i] - data[i]);
            }
            sorting.put(dist, entry);
        }
        //近い順から3つ
        List<Integer> clss = new ArrayList<Integer>();
        for(Map.Entry<Integer, double[]> entry : sorting.values()){
            clss.add(entry.getKey());
            if(clss.size() == 3) break;
        }
        //多数決
        int cls = clss.get(0);
        if(clss.size() < 3) return cls;
        if(clss.get(1) == clss.get(2)){
            return clss.get(1);
        }
        return cls;
    }
}