ニューラルネットというのは、入力があって、複数の階層を経て出力を得るようなグラフ構造のことです。通常は、入力層・中間層・出力層のように層構造になっているようなものを差します。中でも、中間層が1層の、3層構造になっているものが多くとりあげられます。バックプロパゲーションは、誤差逆伝播法とも言って、ニューラルネットワークのパラメータを学習するための手法です。
ニューラルネットについてのサイトや本では、中間層を多層に対応した一般的な表現で説明されることが多いのですが、なかなか式を読み解くのが難しかったりするので、今回は3層で入力が2パラメータ、出力は1つ、中間層のニューロンは2つという、単純なものを取り上げます。
では、3層ニューラルネットワークでの判定時のデータの流れを見てみます。
3層ということになっていますが、実際の処理は2層になっています。実装するときには2層だと考えたほうがわかりやすいです。
それでは入力層から中間層g1までの流れを見てみます。まず、それぞれの入力x1、x2に係数w11、w21を掛けて足します。また、バイアスとしてw01という値も足します。
こうして集計した結果に対して、正なら1、負なら0という関数fを適用したものがg1の値になります。
式として書くと、次のようになります。
ここで処理上x0=1という入力値を仮においておくと、実装しやすくなります。
gについても一般化すると、次のようになります。
中間層から出力層までも同様に、g0=1と置いておくと次のようになります。
ここで、関数fに、次のようなしきい値関数を適用すると、fが微分できなくなるので、あとのバックプロパゲーション処理で不都合です。
そこで、tanhや、次のようなシグモイド関数を使います。
このグラフは次のようになります。
シグモイド関数の微分は次のようになるので、都合がよいです。
ニューラルネットワークの学習では、出力uと実際の値bとの誤差から、wやhの値を変更して行きます。
このとき、出力uと学習データbとの誤差から中間層→出力層の係数hを修正します。そして、hを修正した量によって入力層→中間層の係数wを修正します。このように、誤差を逆に伝播させることからバックプロパゲーション、誤差逆伝播法といいます。
具体的には、中間層→出力層の係数hの場合、次のように学習係数k、修正量eとすると
この修正量e1は、出力u、学習データbとして
のようになります。(u-b)が誤差、u(1-u)というのはシグモイド関数の微分になっています。
入力層→出力層の係数wの修正量cとすると
のようになります。前の層の修正量に伝播係数を掛けた分を修正しています。
学習係数kは、修正量をどの程度反映させるかという係数です。
ということで、実装してみたバックプロパゲーションによる学習の状況をみてみると次のようになりました。線形分離可能な場合にもうまく識別できています。*1
また、id:nowokay:20080330でやったようなパーセプトロンでの分離に比べて、両データの中間点に識別面ができています。これは単純なしきい値関数ではなく、シグモイド関数を使ったためともいえます。
実行にはid:nowokay:20080327のGraph.java、id:nowokay:20080326のLerningMachine.javaが必要です。
//BackPropergation.java import java.util.*; public class BackPropergation implements LearningMachine { List<Map.Entry<Integer, double[]>> patterns = new ArrayList<Map.Entry<Integer, double[]>>(); double[][] w;//入力→中間層の係数 double[] hidden;//中間層→出力の係数 int dim;//入力パラメータ数 int hiddendim;//中間層の数+1 public BackPropergation(int dim, int hiddendim) { this.dim = dim; this.hiddendim = hiddendim + 1; } public static void main(String[] args) { new Graph("バックプロパゲーション評価") { @Override public LearningMachine createLearningMachine() { return new BackPropergation(2, 2); } }; } public void learn(int cls, double[] data) { int yi = cls == 1 ? 1 : 0; patterns.add(new AbstractMap.SimpleEntry(yi, data)); final double k = .3;//学習係数 w = new double[hiddendim - 1][dim + 1]; for(int i = 0; i < w.length; ++i){ for(int j = 0; j < w[i].length; ++j){ w[i][j] = Math.random() * 2 - 1; } } hidden = new double[hiddendim]; for(int i = 0; i < hiddendim; ++i){ hidden[i] = Math.random() * 2 - 1; } for (int t = 0; t < 10000; ++t) { //学習を繰り返す boolean fin = false; for (Map.Entry<Integer, double[]> entry : patterns) { double[] pattern = new double[entry.getValue().length + 1]; for (int i = 0; i < entry.getValue().length; ++i) { pattern[i + 1] = entry.getValue()[i]; } pattern[0] = 1; int pcls = entry.getKey();//正解 double[] hiddenvalue = new double[hiddendim];//中間層の出力値 //入力層→中間層 for (int j = 0; j < w.length; ++j) { double in = 0; for (int i = 0; i < pattern.length; ++i) { in += pattern[i] * w[j][i]; } hiddenvalue[j + 1] = sigmoid(in); } hiddenvalue[0] = 1; //中間層→出力層 double out = 0;//出力 for (int i = 0; i < hiddenvalue.length; ++i) { out += hidden[i] * hiddenvalue[i]; } out = sigmoid(out); //出力層→中間層 double p = (pcls - out) * out * (1 - out); double[] e = new double[hiddendim];//中間層の補正値 double[] oldhidden = hidden.clone();//補正前の係数 for(int i = 0; i < hiddendim; ++i){ e[i] = p * hiddenvalue[i]; hidden[i] += e[i] * k; } //中間層→入力層 for(int i = 1; i< hiddendim; ++i){ double ek = e[i] * oldhidden[i] * hiddenvalue[i] * (1 - hiddenvalue[i]); for(int j = 0; j < dim + 1; ++j){ w[i - 1][j] += pattern[j] * ek * k; } } } if (fin) { break; } } } private double sigmoid(double d) { return 1 / (1 + Math.exp(-d)); } public int trial(double[] data) { double[] pattern = new double[data.length + 1]; for(int i = 0; i < data.length; ++i){ pattern[i + 1] = data[i]; } pattern[0] = 1; double[] hiddendata = new double[hiddendim]; //入力層→中間層 for (int j = 0; j < w.length; ++j) { double in = 0; for (int i = 0; i < pattern.length; ++i) { in += pattern[i] * w[j][i]; } hiddendata[j + 1] = sigmoid(in); } hiddendata[0] = 1; //中間層→出力層 double out = 0; for (int i = 0; i < hiddendata.length; ++i) { out += hiddendata[i] * hidden[i]; } return (sigmoid(out) > .5) ? 1 : -1; } }
*1:収束判定してないので、学習に失敗する場合もあります