パーセプトロンってなんだろう?

ということで、パーセプトロンがどんな感じで判定してるかを見てみました。


判定の境界が直線になっていることがわかります。パーセプトロンは、こういう具合に判定結果を分離する直線(平面)を求めるもので、これが線形分離ということです。そのため、線形分離不可能な場合には、誤判定が発生してます。

import java.util.*;

public class Perceptron implements LearningMachine{
    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();
    double b;
    double[] p;
    int dim;
    public Perceptron(int dim) {
        this.dim = dim;
    }

    public static void main(String[] args) {
        new Graph("パーセプトロン評価"){
            @Override
            public LearningMachine createLearningMachine() {
                return new Perceptron(2);
            }
        };
    }

    public void learn(int cls, double[] data) {
        int yi = cls == 1 ? 1 : -1;
        patterns.add(new AbstractMap.SimpleEntry(yi, data));
        
        final double k = .01;
        b = 0;
        p = new double[dim];
        
        for(int j = 0; j < 100; ++j){
            //学習を繰り返す
            boolean fin = true;
            for(Map.Entry<Integer, double[]> entry : patterns){
                double[] pattern = entry.getValue();
                int pcls = entry.getKey();
                double in = 0;
                for(int i = 0; i < pattern.length; ++i){
                    in += pattern[i] * p[i];
                }
                in += b;
                if(in * pcls <= 0){
                    //誤判定で再学習
                    fin = false;
                    for(int i = 0; i < p.length; ++i){
                        p[i] += pattern[i] * k * pcls;
                    }
                    b += k * pcls;
                }
            }
            if(fin) break;//パーセプトロンの収束
        }
        
    }
    public int trial(double[] data) {
        double in = 0;
        for(int i = 0; i < data.length; ++i){
            in += data[i] * p[i];
        }
        in += b;
        return (in > 0) ? 1 : -1;
    }
}

3-NN法

NN法だと識別の境界ががたがたすぎるので、3-NN法っていうのを使ってみます。
3-NN法は、近いほうから3つの学習パターンをとってきて、そのうちの多数決で識別する方法です。そういう意味では、NN法ってのは1-NN法になります。
そうすると、識別の境界がちょっと滑らかになります。


けども、この結果をみると、学習パターンでも誤判別されているものがあります。これは、今回単純にデータの多数決を取ってるためで、データの距離によって重みをつけたりして改善させる必要があります。

import java.util.*;

public class NN3 implements LearningMachine{

    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();

    public static void main(String[] args) {
        new Graph("3-NN法評価"){
            @Override
            public LearningMachine createLearningMachine() {
                return new NN3();
            }
        };
    }

    public void learn(int cls, double[] data) {
        patterns.add(new AbstractMap.SimpleEntry(cls, data));
    }
    public int trial(double[] data) {
        //パターンを近い順に求める
        Map<Double, Map.Entry<Integer, double[]>> sorting =
                new TreeMap<Double, Map.Entry<Integer, double[]>>();
        for (Map.Entry<Integer, double[]> entry : patterns) {
            double[] ss = entry.getValue();
            if (ss.length != data.length) {
                System.out.println("へんなデータ");
                continue;
            }
            //データ間の距離を求める
            double dist = 0;
            for (int i = 0; i < ss.length; ++i) {
                dist += (ss[i] - data[i]) * (ss[i] - data[i]);
            }
            sorting.put(dist, entry);
        }
        //近い順から3つ
        List<Integer> clss = new ArrayList<Integer>();
        for(Map.Entry<Integer, double[]> entry : sorting.values()){
            clss.add(entry.getKey());
            if(clss.size() == 3) break;
        }
        //多数決
        int cls = clss.get(0);
        if(clss.size() < 3) return cls;
        if(clss.get(1) == clss.get(2)){
            return clss.get(1);
        }
        return cls;
    }
}

NN法っていいよね

NN法で学習した結果が、どんな感じの判定になるのか表示してみます。
2パラメータの学習データで、どういう風に判定されるのかを表示しています。


NN法では、一番近い学習データによって判定しているので、学習データに関しては正しい判定結果になります。けれども、判定の境界ががたがたになっていて、学習データ以外では外れた結果になることも多そうです。


ここまでのコメントで、「線形分離可能」という言葉が出てますが、これは、直線で判定結果を分けることができるかどうかということです。
NN法では、判定の境界が直線ではないので、「非線形分離」になっています。これが、ある程度の学習性能につながっていて、自分で書いた字を学習して自分で書いた字を判定するには十分かも、という結果につながっています。ただ、NN法は、すべての学習データを利用して判定するため、学習データが多ければデータ量も計算量も多くなってしまいます。


表示させるプログラムはこんな感じ。

//Graph.java
import java.awt.*;
import java.awt.image.BufferedImage;
import javax.swing.*;

public class Graph {
    public static void main(String[] args){
        new Graph("NN法評価");
    }
    public LearningMachine createLearningMachine(){
        return new NearestNeighbor();
    }
    
    public Graph(String title) {
        JFrame f = new JFrame(title);
        f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        f.setSize(420, 300);
        f.setLayout(new GridLayout(1, 2));
        
        //線形分離可能
        double[] linear1X = {0.15, 0.3, 0.35, 0.4, 0.55};
        double[] linear1Y = {0.3,  0.6, 0.25, 0.5, 0.4};
        double[] linear2X = {0.4,  0.7, 0.7, 0.85, 0.9};
        double[] linear2Y = {0.85, 0.9, 0.8, 0.7,  0.6};
        f.add(createGraph("線形分離可能", 
                linear1X, linear1Y, linear2X, linear2Y));
        //線形分離不可能
        double[] nonlinear1X = {0.15, 0.45, 0.6, 0.3, 0.75, 0.9};
        double[] nonlinear1Y = {0.5,  0.85, 0.75,  0.75, 0.7, 0.55};
        double[] nonlinear2X = {0.2,  0.55, 0.4,  0.6, 0.8, 0.85};
        double[] nonlinear2Y = {0.3,  0.6,  0.55, 0.4, 0.55, 0.2};
        f.add(createGraph("線形分離不可能",
                nonlinear1X, nonlinear1Y, nonlinear2X, nonlinear2Y));
        
        f.setVisible(true);
    }
    
    JLabel createGraph(String title, double[] linear1X, double[] linear1Y, double[] linear2X, double[] linear2Y) {
        LearningMachine lm = createLearningMachine();
        //学習
        for(int i = 0; i < linear1X.length; ++i){
            lm.learn(-1, new double[]{linear1X[i], linear1Y[i]});
        }
        for(int i = 0; i < linear2X.length; ++i){
            lm.learn( 1, new double[]{linear2X[i], linear2Y[i]});
        }        
        Image img = new BufferedImage(200, 200, BufferedImage.TYPE_INT_RGB);
        Graphics g = img.getGraphics();
        g.setColor(Color.WHITE);
        g.fillRect(0, 0, 200, 200);

        //判定結果
        for (int x = 0; x < 180; x += 2) {
            for (int y = 0; y < 180; y += 2) {
                int cls = lm.trial(new double[]{x / 180., y / 180.});
                g.setColor(cls == 1 ? new Color(192, 192, 255) : new Color(255, 192, 192));
                g.fillRect(x + 10, y + 10, 5, 5);
            }
        }
        //学習パターン
        for (int i = 0; i < linear1X.length; ++i) {
            int x = (int) (linear1X[i] * 180) + 10;
            int y = (int) (linear1Y[i] * 180) + 10;
            g.setColor(Color.RED);
            g.fillOval(x - 3, y - 3, 7, 7);
        }
        for (int i = 0; i < linear2X.length; ++i) {
            int x = (int) (linear2X[i] * 180) + 10;
            int y = (int) (linear2Y[i] * 180) + 10;
            g.setColor(Color.BLUE);
            g.fillOval(x - 3, y - 3, 7, 7);
        }
        //ラベル作成
        JLabel l = new JLabel(title, new ImageIcon(img), JLabel.CENTER);
        l.setVerticalTextPosition(JLabel.BOTTOM);
        l.setHorizontalTextPosition(JLabel.CENTER);
        return l;
    }
}

基本のNearestNeighbors法(NN法)でパターン認識

とりあえず、パターン認識の一番の基本になるNearestNeighbors法(NN法)でやってみます。
NN法は、判定するデータが、学習に使ったデータのうちの一番近いものに分類する方法です。
自分の字で学習させて自分の字を判定させるなら、これでいいかも、とか思う。
前処理として画像処理してやると、結構いい感じになりそうです。けど、今回は画像処理はあとまわしです。


とりあえず、学習器のインタフェースを用意します。learnメソッドで学習させて、trialメソッドで判定

//LearningMachine.java
public interface LearningMachine {
    //学習
    void learn(int cls, double[] data);
    //評価
    int trial(double[] data);
}


で、NN法を実装させるとこんな感じ。

//NearestNeighbor.java
import java.util.*;

class NearestNeighbor implements LearningMachine {
    List<Map.Entry<Integer, double[]>> patterns = 
            new ArrayList<Map.Entry<Integer, double[]>>();

    public static void main(String[] args) {
        new MachineLearning("NN法", new NearestNeighbor());
    }

    public void learn(int cls, double[] data) {
        patterns.add(new AbstractMap.SimpleEntry(cls, data));
    }

    public int trial(double[] data) {
        int cls = 0;
        //一番近いパターンを求める
        double mindist = Double.POSITIVE_INFINITY;
        for (Map.Entry<Integer, double[]> entry : patterns) {
            double[] ss = entry.getValue();
            if (ss.length != data.length) {
                System.out.println("へんなデータ");
                continue;
            }
            //データ間の距離を求める
            double dist = 0;
            for (int i = 0; i < ss.length; ++i) {
                dist += (ss[i] - data[i]) * (ss[i] - data[i]);
            }
            if (mindist > dist) {
                mindist = dist;
                cls = entry.getKey();
            }
        }
        return cls;
    }
}


で、画面系

//MachineLearning.java
import java.awt.*;
import java.awt.event.*;
import java.awt.image.BufferedImage;
import javax.swing.*;

public class MachineLearning extends JComponent 
        implements ActionListener, MouseMotionListener{
    Image img = new BufferedImage(300, 300, BufferedImage.TYPE_INT_RGB);
    Graphics bg = img.getGraphics();
    Point pt;
    LearningMachine learningMachine;
    int mesh = 10;

    public MachineLearning(String title, LearningMachine learningMachine) {
        this.learningMachine = learningMachine;
       JFrame f = new JFrame(title);
       f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
       f.setSize(320, 420);
       clear();
       f.add(this);
       
       JPanel p = new JPanel();
       p.setLayout(new GridLayout(2, 5));
       for(int i = 0; i < 10; ++i){
           JButton b = new JButton(i +"");
           p.add(b);
           b.addActionListener(this);
       }
       addMouseMotionListener(this);
       
       JButton b = new JButton("判定");
       b.addActionListener(this);
       f.add(b, BorderLayout.NORTH);
       
       f.add(p, BorderLayout.SOUTH);
       f.setVisible(true);
    }
    
    /** ボタンが押されたときの処理 */
    public void actionPerformed(ActionEvent e) {
        BufferedImage bi = new BufferedImage(mesh, mesh, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2 = (Graphics2D) bi.getGraphics();
        g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, 
                RenderingHints.VALUE_INTERPOLATION_BICUBIC);
        g2.drawImage(img, 0, 0, mesh, mesh, this);
        double[] data = new double[mesh * mesh];
        for(int x = 0; x < mesh; ++x){
            for(int y = 0; y < mesh; ++y){
                data[y * mesh + x] = (255 - bi.getRGB(x, y) & 255) / 255.;
            }
        }
        
        if("判定".equals(e.getActionCommand())){
            //判定
            int ans = learningMachine.trial(data);
            System.out.println("答えは" + ans);
        }else{
            //学習処理
            int idx = Integer.parseInt(e.getActionCommand());
            learningMachine.learn(idx, data);
        }
        clear();
    }

    /** マウスで描画 */
    public void mouseDragged(MouseEvent e) {
        Point old = pt;
        pt = e.getPoint();
        if(old != null){
            bg.setColor(Color.BLACK);
            ((Graphics2D)bg).setStroke(new BasicStroke(25));
            bg.drawLine(old.x, old.y, pt.x, pt.y);
            repaint();
        }
    }

    public void mouseMoved(MouseEvent e) {
        pt = null;
    }
    
    //描画
    @Override
    protected void paintComponent(Graphics g) {
        g.drawImage(img, 0, 0, this);
    }
    
    private void clear(){
        bg.setColor(Color.WHITE);
        bg.fillRect(0, 0, 300, 300);
        bg.setColor(Color.LIGHT_GRAY);
        ((Graphics2D)bg).setStroke(new BasicStroke(1));
        bg.drawRect(20, 20, 260, 260);
        bg.drawLine(0, 150, 300, 150);
        bg.drawLine(150, 0, 150, 300);
        repaint();
    }
}

バックプロパゲーション

単層パーセプトロンだと線形分離可能なものしか学習できないということで、3層パーセプトロンにしてバックプロパゲーションなるもので学習しようと思って、とりあえずコードを書いたのだけど、まだちゃんと学習できてない><
機械学習の入門書は、これが読みやすい感じ。

フリーソフトでつくる音声認識システム パターン認識・機械学習の初歩から対話システムまで

フリーソフトでつくる音声認識システム パターン認識・機械学習の初歩から対話システムまで

音声認識って書いてあるけど、半分はパターン認識。で、難しい言葉では書いてない。もちろん数式満載ではあるけど。