ということで、バイトコードコンパイラを作ってみる。

おととい昨日のプログラムを組み合わせて、数式をバイトコードコンパイルして実行させてみました。

import java.io.*;
import java.text.ParseException;
import java.util.Stack;

public class Compiler {
    public static void main(String[] a) throws Exception{
        System.out.printf("%d : %d%n", eval("4-3"), 4-3);
        System.out.printf("%d : %d%n", eval("4--3"), 4- -3);
        System.out.printf("%d : %d%n", eval("((-12 + 3) * 2)"), ((-12 + 3) * 2));
        System.out.printf("%d : %d%n", eval("-12 + 3 * 2"), -12 + 3 * 2);
        System.out.printf("%d : %d%n", eval("3 * 2 + 4 * -5"), 3 * 2 + 4 * -5);
        System.out.printf("%d : %d%n", eval("3 * 2 + -4 * 5"), 3 * 2 + -4 * 5);
        System.out.printf("%d : %d%n", eval("3 * (2 + -4) * 5"), 3 * (2 + -4) * 5);
        System.out.printf("%d : %d%n", eval("3 * (-2 + 4) * 5"), 3 * (-2 + 4) * 5);
        System.out.printf("%d : %d%n", eval("0+1+2+3+4+5+6"), 0+1+2+3+4+5+6);
        System.out.println(eval("(4 * 3) 2"));//おかしい
    }
    
    public static int eval(String str) throws Exception{
        String className = "Compiler";
        String methodName = "eval";
       
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(buf);
        dos.writeInt(0xcafebabe);//クラスファイルを識別するマジックナンバー
        dos.writeShort(0);//マイナーバージョン
        dos.writeShort(49);//メジャーバージョン
        
        //定数の出力
        ByteArrayOutputStream constbuf = new ByteArrayOutputStream();
        DataOutputStream constos = new DataOutputStream(constbuf);
        int id = 1;
        int idClassName = id++;
        constos.writeByte(1);//識別子
        constos.writeUTF(className);
        int idMethod = id++;
        constos.writeByte(1);//識別子
        constos.writeUTF(methodName);
        int idType = id++;
        constos.writeByte(1);//識別子
        constos.writeUTF("()I");
        int idObj = id++;
        constos.writeByte(1);//識別子
        constos.writeUTF("java/lang/Object");
        int idClassObj = id++;
        constos.writeByte(7);//クラス
        constos.writeShort(idObj);
        int idClass = id++;
        constos.writeByte(7);//クラス
        constos.writeShort(idClassName);
        int idCode = id++;
        constos.writeByte(1);//識別子
        constos.writeUTF("Code");//コード用の属性名
        constos.close();
        constbuf.close();
        
        dos.writeShort(id);//定数の個数(実際の個数より1多くする)
        dos.write(constbuf.toByteArray());

        //クラス
        dos.writeShort(0x21);// アクセス指定 ACC_SUPER(0x20) | ACC_PUBLIC(0x01)
        dos.writeShort(idClass);//クラス
        dos.writeShort(idClassObj);//スーパークラス
        dos.writeShort(0);//インタフェースの個数
        dos.writeShort(0);//フィールドの数
        dos.writeShort(1);//メソッドの数
        
        //メソッド
        dos.writeShort(9);// アクセス指定 ACC_PUBLIC(0x01) | ACC_STATIC(0x08)
        dos.writeShort(idMethod);//名前
        dos.writeShort(idType);//シグネチャ
        dos.writeShort(1);//属性の数(ここではメソッド本体)
        dos.writeShort(idCode);//属性名(Code)
        byte[] maincode = compile(str);
        
        dos.writeInt(maincode.length + 12);//属性の長さ
        dos.writeShort(maxdepth);//最大スタック
        dos.writeShort(0);//最大変数
        dos.writeInt(maincode.length);//コードのバイト数
        //メソッド
        dos.write(maincode);
        //メソッドおわり
        dos.writeShort(0);//例外の数
        dos.writeShort(0);//属性の数
        
        dos.writeShort(0);//属性の数
        
        dos.close();
        buf.close();
        
        //クラスファイルを読み込んで実行
        final byte[] code = buf.toByteArray();
        ClassLoader cl = new ClassLoader(){
            @Override
            protected Class<?> findClass(String name) throws ClassNotFoundException {
                return defineClass(null, code, 0, code.length);
            }
        };
        Class expClass = cl.loadClass(className);
        int ret = (Integer)expClass.getMethod(methodName).invoke(null, new Object[0]);
        return ret;
    }
    
    static int BC_ICONST_0 = 0x03;
    static int BC_BIPUSH = 0x10;
    static int BC_SIPUSH = 0x11;
    static int BC_IADD = 0x60;
    static int BC_ISUB = 0x64;
    static int BC_IMUL = 0x68;
    static int BC_IDIV = 0x6c;
    static int BC_INEG = 0x74;
    static int BC_IRETURN = 0xac;
    
    static int depth;//スタックの深さ
    static int maxdepth;//スタックの深さの最大値
    private static byte[] compile(String str) throws IOException{
        depth = 0;
        maxdepth = 0;
        ByteArrayOutputStream evalbuf = new ByteArrayOutputStream();
        DataOutputStream evaldos = new DataOutputStream(evalbuf);
        
        int value = 0;
        int numphase = 0;//0:数字がくる 1:数字の途中 2:演算子が来る
        boolean minus = false;
        Stack<State> statestack = new Stack<State>();
        State state = new State(evaldos);
        statestack.push(state);
        try{
        for(int i = 0; i < str.length(); ++i){
            char ch = str.charAt(i);
            if(ch == ' '){
                continue;
            }else if(ch >= '0' && ch <= '9'){
                if(numphase == 2){
                    throw new ParseException("数字がきちゃだめなところに" + ch, i);
                }
                value = value * 10 + (ch - '0');
                numphase = 1;
            }else{
                if(numphase == 0){
                    if(ch == '-'){
                        minus = !minus;
                        continue;
                    }else if(ch == '('){
                        state = new State(evaldos);
                        state.minus = minus;
                        statestack.push(state);
                        minus = false;
                        continue;
                    }else{
                        throw new ParseException(ch + "が来たのでパースエラーです", i);
                    }
                }
                if(numphase == 1){
                    state.push(value, minus);
                }
                value = 0;
                minus = false;
                
                state.mul();
                
                if("*/".indexOf(ch) >= 0){
                    state.opstack.push(ch);
                    numphase = 0;
                }else if("+-".indexOf(ch) >= 0){
                    state.calc();
                    state.opstack.push(ch);
                    numphase = 0;
                }else if(ch == ')'){
                    state.calc();
                    if(state.minus){
                        evaldos.writeByte(BC_INEG);
                    }
                    statestack.pop();
                    if(statestack.size() == 0){
                        //対応する括弧がない
                        throw new ParseException("対応するカッコがない", i);
                    }
                    state = statestack.peek();
                    numphase = 2;
                }else{
                    //エラー
                    throw new ParseException(ch + "はわかんない", i);
                }
                if(maxdepth < depth) maxdepth = depth;
            }
        }
        }catch(ParseException e){
            System.out.println(e.getMessage());
            return new byte[]{
                (byte)BC_ICONST_0,
                (byte)BC_IRETURN
            };
        }
        if(numphase == 1) state.push(value, minus);
        state.mul();
        state.calc();
        
        evaldos.writeByte(BC_IRETURN);
        evaldos.close();
        evalbuf.close();
        return evalbuf.toByteArray();
    }

    static class State{
        boolean minus = false;
        private Stack<Character> opstack = new Stack<Character>();
        private DataOutputStream dos;

        public State(DataOutputStream dos) {
            this.dos = dos;
        }
        
        void push(int value, boolean sign)throws IOException{
            int v = (sign ? -1 : 1) * value;
            if(value <= 5 && !sign){
                dos.writeByte(BC_ICONST_0 + value);
            }else if(value < 128){
                dos.writeByte(BC_BIPUSH);
                dos.writeByte(v);
            }else{
                dos.writeByte(BC_SIPUSH);
                dos.writeShort(v);//今回はshortの範囲までしか対応しない
            }
            ++depth;
            if(maxdepth < depth) maxdepth = depth;
        }
        void mul()throws IOException{
            if(opstack.empty()) return;
            char op = opstack.peek();
            if(op == '*'){
                dos.writeByte(BC_IMUL);
                opstack.pop();
                --depth;
            }else if(op == '/'){
                dos.writeByte(BC_IDIV);
                opstack.pop();
                --depth;
            }
        }
        
        void calc()throws IOException{
            if(opstack.empty()) return;
            char op = opstack.peek();
            if(op == '+'){
                dos.writeByte(BC_IADD);
                opstack.pop();
                --depth;
            }else if(op == '-'){
                dos.writeByte(BC_ISUB);
                opstack.pop();
                --depth;
            }else{
                //なんかおかしい
                System.out.println("なんかおかしいね");
            }
        }
    }
}