HMMの隠れたマルコフを覗いてみる

HMMではマルコフモデルが隠れていたのですが、こういう隠れているものを覗いてみたくなるのが人情というものです。
ということで、出力からマルコフ遷移を推定してみます。これにはビタビ(Viterbi)のアルゴリズムを使います。


ビタビのアルゴリズムは、要するに「それ動的計画法でできるよ!」というもので、図のように、ある状態に複数の遷移でたどりついたときに、一番確率高いものだけを考えればいいというものです。赤い遷移と青い遷移で同じ状態にたどりついたとき、確率の高い赤い遷移だけを残して、確率が低い方の青い遷移は捨てることができます。

確率が低い遷移が、そのあとの遷移によって確率が逆転することはないからです。これは、マルコフ過程が一度状態が決まればそれ以前の状態には影響を受けることがなくなることを利用したものです。
動的計画法は、このように、何かの最適解を求めるとき、それ以前の状態遷移には影響を受けない点があれば、その時点での最適解だけを残してあとの解を捨てるというアルゴリズムです。


それでは、前回の出力から、隠れた状態遷移を推定してみます。
プログラムを実行すると、このようになって、最後が違うだけで、あとはうまく状態遷移を推定できていることがわかります。

出力:serrar rer te rr rrsertah reht
正解:31 32 33 31 32 33 0 31 32 33 0 21 22 0 21 22 0 21 22 31 32 33 31 32 33 0 31 32 33 31
推定:31 32 33 31 32 33 0 31 32 33 0 21 22 0 21 22 0 21 22 31 32 33 31 32 33 0 31 32 33 21 


コードの前半は、出力プログラムと同じHMMを定義しています。
また、状態推定部分を見ると、出力文字数分のループ、状態種別数のループ、状態種別数のループ、という3重ループになっていることから、計算量がO(出力文字数×状態種別数^2)となることがわかります。

import java.util.*;

public class Viterbi {
    public static void main(String[] args){
        String[] Q = {//状態
            "0", "1", "21", "22", "31", "32", "33"};
        double[][] A = {//状態遷移頻度
            {0,  1,   2,    0,    2,    0,    0},
            {0,  0,   1,    0,    1,    0,    0},
            {0,  0,   0,    1,    0,    0,    0},
            {2,  2,   3,    0,    3,    0,    0},
            {0,  0,   0,    0,    0,    1,    0},
            {0,  0,   0,    0,    0,    0,    1},
            {3,  1,   3,    0,    2,    0,    0}};
        String[] Σ = {//出力文字
            " ", "a", "e", "s", "t", "h", "r"};
        double[][] E = {//文字出力頻度
            {1,   0,   0,   0,   0,   0,   0},
            {0,   2,   0,   3,   3,   0,   0},
            {0,   0,   0,   2,   2,   1,   3},
            {0,   1,   2,   0,   0,   1,   1},
            {0,   0,   0,   2,   2,   1,   2},
            {0,   2,   3,   0,   0,   0,   0},
            {0,   0,   0,   0,   0,   1,   2}};
        //確率の正規化
        for(double[] d : A){
            double sum = 0;
            for(double v : d) sum += v;
            for(int i = 0; i  < d.length; ++i){
                d[i] /= sum;
            }
        }
        for(double[] d : E){
            double sum = 0;
            for(double v : d) sum += v;
            for(int i = 0; i  < d.length; ++i){
                d[i] /= sum;
            }            
        }
        //↑ここまでは前のサンプルと同じ
        
        //出力文字列
        String output = "serrar rer te rr rrsertah reht";
        System.out.println("出力:" + output);
        System.out.println("正解:31 32 33 31 32 33 0 31 32 33 0 21 22 0" +
                " 21 22 0 21 22 31 32 33 31 32 33 0 31 32 33 31");
        
        //文字列を分解
        List<Integer> outputindex = new ArrayList<Integer>();
        //出力文字の転置インデックス
        Map<String, Integer> reverse = new HashMap<String, Integer>();
        for(int i = 0; i < Σ.length; ++i){
            reverse.put(Σ[i], i);
        }
        for(int i = 0; i < output.length(); ++i){
            outputindex.add(reverse.get(output.substring(i, i + 1)));
        }

        //状態遷移の推定
        double[] probervility = new double[Q.length];//この状態に遷移する確率
        probervility[0] = 1;//初期状態に遷移する確率を100%にしておく
        //各状態に遷移するもっとも確率の高い状態遷移列
        Map<Integer, List<Integer>> move = new HashMap<Integer, List<Integer>>();
        
        for(int s : outputindex){//各文字について
            double[] nextprob = new double[Q.length];
            Map<Integer, List<Integer>> nextmove = new HashMap<Integer, List<Integer>>();
            for(int i = 0; i < Q.length; ++i){//現在状態について
                //この状態までの遷移確率
                double pi = probervility[i];
                for(int j = 0; j < Q.length; ++j){//次の状態について
                    //状態S_iから状態S_jへの遷移確率
                    double pj = A[i][j] * pi;
                    //状態S_jでの文字sの出力確率
                    double ps = pj * E[j][s];
                    
                    //この状態遷移が状態S_jに遷移する中で一番確率が高ければ記録
                    if(nextprob[j] < ps){
                        nextprob[j] = ps;
                        List<Integer> m = move.get(i);
                        if(m == null) m = new ArrayList<Integer>();
                        List<Integer> nm = new ArrayList<Integer>(m);
                        nm.add(j);
                        nextmove.put(j, nm);
                    }
                }
            }
            probervility = nextprob;
            move = nextmove;
        }
        //最終状態に遷移したもののうち一番確率が高いものを取得
        List<Integer> result = null;
        double max = 0;
        for(int i = 0; i < probervility.length; ++i){
            if(max < probervility[i]){
                max = probervility[i];
                result = move.get(i);
            }
        }
        //結果出力
        System.out.print("推定:");
        for(int i : result){
            System.out.print(Q[i] + " ");
        }
        System.out.println();
    }
}