WebAssemblyにコンパイルしたRustコードにJavaインタフェースをマッピングする

Chicoryを使ってRustをコンパイルしたwasmをJavaから呼び出してみました。
JVMでWebAssemblyにコンパイルしたRustのコードを動かす - きしだのHatena

ただ、結構呼び出しがめんどいので、Javaインタフェースを定義したらなんかメソッド呼び出しで使える、というよく見かけるやつを作ってみます。

Rustのコードはこう。

#[no_mangle]
pub fn add(left: i32, right: i32) -> i32 {
    left + right
}

#[no_mangle]
pub fn sub(left: i32, right: i32) -> i32 {
    left - right
}

#[no_mangle]
pub fn mul(left: i32, right: i32) -> i64 {
    (left as i64) * (right as i64)
}

こんなJavaインタフェースで呼び出せるようにしたい。

interface RustFuncs {
    int add(int left, int right);
    int sub(int left, int right);
    long mul(int left, int right);
}

ということで、こんな感じでProxyを作る。

static <T> T bind(String name, Class<T> type) {
    var wasm = CallWasm.class.getClassLoader().getResourceAsStream(name);
    var module = Module.builder(wasm).build();
    var instance = module.instantiate();

    T obj = (T) Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[]{type}, 
            (p, m, a) -> methodHandler(instance, type, p, m, a));
    return obj;
}

static Object methodHandler(Instance ins, Class type,
        Object proxy, Method method, Object[] args) throws Throwable {
    var m = ins.export(method.getName()); // todo:use annotation name
    List<Value> values = new ArrayList<>();
    for (int i = 0; i < args.length; ++i) {
        values.add(switch (args[i]) {
            case Integer n -> Value.i32(n);
            case Long l -> Value.i64(l);
            default ->throw new RuntimeException("unknown type " + method.getParameterTypes()[i]);
        });
    }
    var result = m.apply(values.toArray(Value[]::new))[0];
    var rt = method.getReturnType();
    if(rt == long.class || rt == Long.class) {
        return result.asLong();
    } else if (rt == int.class || rt == int.class) {
        return result.asInt();
    } else {
        throw new RuntimeException("unknown ret type " + rt);
    }
}

こうやって呼び出せるようになりました!やったね!

public static void main(String[] args) throws IOException {
    var rf = bind("hello_wasm.wasm", RustFuncs.class);
    System.out.println(rf.add(12, 34));
    System.out.println(rf.sub(34, 12));
    System.out.println(rf.mul(100_000_000, 100_000_000));
}

文字列とかはwasmでは定義されてなくて結構めんどいので、そのあたりはもう少し考えてみます。
使い物になりそうだったらasmかJavassistバイトコード生成を。