とりあえず、パターン認識の一番の基本になるNearestNeighbors法(NN法)でやってみます。
NN法は、判定するデータが、学習に使ったデータのうちの一番近いものに分類する方法です。
自分の字で学習させて自分の字を判定させるなら、これでいいかも、とか思う。
前処理として画像処理してやると、結構いい感じになりそうです。けど、今回は画像処理はあとまわしです。
とりあえず、学習器のインタフェースを用意します。learnメソッドで学習させて、trialメソッドで判定
//LearningMachine.java public interface LearningMachine { //学習 void learn(int cls, double[] data); //評価 int trial(double[] data); }
で、NN法を実装させるとこんな感じ。
//NearestNeighbor.java import java.util.*; class NearestNeighbor implements LearningMachine { List<Map.Entry<Integer, double[]>> patterns = new ArrayList<Map.Entry<Integer, double[]>>(); public static void main(String[] args) { new MachineLearning("NN法", new NearestNeighbor()); } public void learn(int cls, double[] data) { patterns.add(new AbstractMap.SimpleEntry(cls, data)); } public int trial(double[] data) { int cls = 0; //一番近いパターンを求める double mindist = Double.POSITIVE_INFINITY; for (Map.Entry<Integer, double[]> entry : patterns) { double[] ss = entry.getValue(); if (ss.length != data.length) { System.out.println("へんなデータ"); continue; } //データ間の距離を求める double dist = 0; for (int i = 0; i < ss.length; ++i) { dist += (ss[i] - data[i]) * (ss[i] - data[i]); } if (mindist > dist) { mindist = dist; cls = entry.getKey(); } } return cls; } }
で、画面系
//MachineLearning.java import java.awt.*; import java.awt.event.*; import java.awt.image.BufferedImage; import javax.swing.*; public class MachineLearning extends JComponent implements ActionListener, MouseMotionListener{ Image img = new BufferedImage(300, 300, BufferedImage.TYPE_INT_RGB); Graphics bg = img.getGraphics(); Point pt; LearningMachine learningMachine; int mesh = 10; public MachineLearning(String title, LearningMachine learningMachine) { this.learningMachine = learningMachine; JFrame f = new JFrame(title); f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); f.setSize(320, 420); clear(); f.add(this); JPanel p = new JPanel(); p.setLayout(new GridLayout(2, 5)); for(int i = 0; i < 10; ++i){ JButton b = new JButton(i +""); p.add(b); b.addActionListener(this); } addMouseMotionListener(this); JButton b = new JButton("判定"); b.addActionListener(this); f.add(b, BorderLayout.NORTH); f.add(p, BorderLayout.SOUTH); f.setVisible(true); } /** ボタンが押されたときの処理 */ public void actionPerformed(ActionEvent e) { BufferedImage bi = new BufferedImage(mesh, mesh, BufferedImage.TYPE_INT_RGB); Graphics2D g2 = (Graphics2D) bi.getGraphics(); g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC); g2.drawImage(img, 0, 0, mesh, mesh, this); double[] data = new double[mesh * mesh]; for(int x = 0; x < mesh; ++x){ for(int y = 0; y < mesh; ++y){ data[y * mesh + x] = (255 - bi.getRGB(x, y) & 255) / 255.; } } if("判定".equals(e.getActionCommand())){ //判定 int ans = learningMachine.trial(data); System.out.println("答えは" + ans); }else{ //学習処理 int idx = Integer.parseInt(e.getActionCommand()); learningMachine.learn(idx, data); } clear(); } /** マウスで描画 */ public void mouseDragged(MouseEvent e) { Point old = pt; pt = e.getPoint(); if(old != null){ bg.setColor(Color.BLACK); ((Graphics2D)bg).setStroke(new BasicStroke(25)); bg.drawLine(old.x, old.y, pt.x, pt.y); repaint(); } } public void mouseMoved(MouseEvent e) { pt = null; } //描画 @Override protected void paintComponent(Graphics g) { g.drawImage(img, 0, 0, this); } private void clear(){ bg.setColor(Color.WHITE); bg.fillRect(0, 0, 300, 300); bg.setColor(Color.LIGHT_GRAY); ((Graphics2D)bg).setStroke(new BasicStroke(1)); bg.drawRect(20, 20, 260, 260); bg.drawLine(0, 150, 300, 150); bg.drawLine(150, 0, 150, 300); repaint(); } }