基本のNearestNeighbors法(NN法)でパターン認識

とりあえず、パターン認識の一番の基本になるNearestNeighbors法(NN法)でやってみます。
NN法は、判定するデータが、学習に使ったデータのうちの一番近いものに分類する方法です。
自分の字で学習させて自分の字を判定させるなら、これでいいかも、とか思う。
前処理として画像処理してやると、結構いい感じになりそうです。けど、今回は画像処理はあとまわしです。


とりあえず、学習器のインタフェースを用意します。learnメソッドで学習させて、trialメソッドで判定

//LearningMachine.java
public interface LearningMachine {
    //学習
    void learn(int cls, double[] data);
    //評価
    int trial(double[] data);
}


で、NN法を実装させるとこんな感じ。

//NearestNeighbor.java
import java.util.*;

class NearestNeighbor implements LearningMachine {
    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();

    public static void main(String[] args) {
        new MachineLearning("NN法", new NearestNeighbor());
    }

    public void learn(int cls, double[] data) {
        patterns.add(new AbstractMap.SimpleEntry(cls, data));
    }

    public int trial(double[] data) {
        int cls = 0;
        //一番近いパターンを求める
        double mindist = Double.POSITIVE_INFINITY;
        for (Map.Entry<Integer, double[]> entry : patterns) {
            double[] ss = entry.getValue();
            if (ss.length != data.length) {
                System.out.println("へんなデータ");
                continue;
            }
            //データ間の距離を求める
            double dist = 0;
            for (int i = 0; i < ss.length; ++i) {
                dist += (ss[i] - data[i]) * (ss[i] - data[i]);
            }
            if (mindist > dist) {
                mindist = dist;
                cls = entry.getKey();
            }
        }
        return cls;
    }
}


で、画面系

//MachineLearning.java
import java.awt.*;
import java.awt.event.*;
import java.awt.image.BufferedImage;
import javax.swing.*;

public class MachineLearning extends JComponent 
        implements ActionListener, MouseMotionListener{
    Image img = new BufferedImage(300, 300, BufferedImage.TYPE_INT_RGB);
    Graphics bg = img.getGraphics();
    Point pt;
    LearningMachine learningMachine;
    int mesh = 10;

    public MachineLearning(String title, LearningMachine learningMachine) {
        this.learningMachine = learningMachine;
       JFrame f = new JFrame(title);
       f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
       f.setSize(320, 420);
       clear();
       f.add(this);
       
       JPanel p = new JPanel();
       p.setLayout(new GridLayout(2, 5));
       for(int i = 0; i < 10; ++i){
           JButton b = new JButton(i +"");
           p.add(b);
           b.addActionListener(this);
       }
       addMouseMotionListener(this);
       
       JButton b = new JButton("判定");
       b.addActionListener(this);
       f.add(b, BorderLayout.NORTH);
       
       f.add(p, BorderLayout.SOUTH);
       f.setVisible(true);
    }
    
    /** ボタンが押されたときの処理 */
    public void actionPerformed(ActionEvent e) {
        BufferedImage bi = new BufferedImage(mesh, mesh, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2 = (Graphics2D) bi.getGraphics();
        g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, 
                RenderingHints.VALUE_INTERPOLATION_BICUBIC);
        g2.drawImage(img, 0, 0, mesh, mesh, this);
        double[] data = new double[mesh * mesh];
        for(int x = 0; x < mesh; ++x){
            for(int y = 0; y < mesh; ++y){
                data[y * mesh + x] = (255 - bi.getRGB(x, y) & 255) / 255.;
            }
        }
        
        if("判定".equals(e.getActionCommand())){
            //判定
            int ans = learningMachine.trial(data);
            System.out.println("答えは" + ans);
        }else{
            //学習処理
            int idx = Integer.parseInt(e.getActionCommand());
            learningMachine.learn(idx, data);
        }
        clear();
    }

    /** マウスで描画 */
    public void mouseDragged(MouseEvent e) {
        Point old = pt;
        pt = e.getPoint();
        if(old != null){
            bg.setColor(Color.BLACK);
            ((Graphics2D)bg).setStroke(new BasicStroke(25));
            bg.drawLine(old.x, old.y, pt.x, pt.y);
            repaint();
        }
    }

    public void mouseMoved(MouseEvent e) {
        pt = null;
    }
    
    //描画
    @Override
    protected void paintComponent(Graphics g) {
        g.drawImage(img, 0, 0, this);
    }
    
    private void clear(){
        bg.setColor(Color.WHITE);
        bg.fillRect(0, 0, 300, 300);
        bg.setColor(Color.LIGHT_GRAY);
        ((Graphics2D)bg).setStroke(new BasicStroke(1));
        bg.drawRect(20, 20, 260, 260);
        bg.drawLine(0, 150, 300, 150);
        bg.drawLine(150, 0, 150, 300);
        repaint();
    }
}