Decorator annotationをJava5/Javassist3.1で実装してみるも...

先日の [id:bellbind:20050903:p2] でmemorization annotation/javaagentをつくったとき、汎用化するDecorator annotationを作れないか、と思って作ったのがこれ。

例コード

Fib.java

public class Fib {
    @Decorate(MemorizeDecorator.class)
    public long fib(int n) {
        System.out.println("when n=" + n);
        if (n < 2) return 1;
        long result = fib(n - 1) + fib(n - 2);
        return result;
    }
}

MemorizedDecorator.java

import java.util.*;

public class MemorizeDecorator implements Decorator {
    
    public Call decorate(final Call call) {
        Call wrapper = new Call() {
            private HashMap<Args, Object> map = new HashMap<Args, Object>();
            
            public Object proceed(Object[] arg) {
                Args args = new Args(arg);
                Object result = map.get(args);
                if (result == null) {
                    result = call.proceed(arg);
                    map.put(args, result);
                }
                return result;
            }
        };
        return wrapper;
    }
}

Args.java

import java.util.*;

public class Args {
    private Object[] args;
    public Args(Object[] args) {
        this.args = args;
    }
    
    public int hashCode() {
        return Arrays.deepHashCode(this.args);
    }
    
    public boolean equals(Object o) {
        if (o instanceof Args) {
            return Arrays.deepEquals(this.args, ((Args) o).args);
        }
        return false;
    }
}

Main.java

public class Main {
    public static void main(String[] args) {
        Fib fib = new Fib();
        System.out.println(fib.fib(80));
        //System.out.println(fib.fib(5));
    }
}

MemorizedDecoratorではメモ化のデコレーションしたメソッド(となるオブジェクト)を作って渡すコードを書き、それをクラスのannotationにつけて渡す、という機能になるよう実装してみた。

プラットフォーム側

Decorate.java

import java.lang.annotation.*;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Decorate {
    public Class<? extends Decorator> value();
}

パラメータにクラスオブジェクトをもつことができるアノテーション

Decorator.java

public interface Decorator {
    public Call decorate(Call call);
}

実際にデコレーションをさせるインタフェース。@Decorate(MyDecorator.class)のような感じで使う。アノテーション記述ごとに一つのインスタンスができる。
引数側のCallはもとのメソッド、戻り値はデコレーションしたメソッドになる。

Call.java

public interface Call {
    public Object proceed(Object[] args);
}

これは元のメソッドやデコレーションしたもののインタフェース。

DecoratorAgent.java

import java.lang.instrument.*;
import java.security.*;
import java.io.*;
import javassist.*;

public class DecoratorAgent implements ClassFileTransformer {
    private Instrumentation inst;
    private ClassPool pool;
    private CtClass callClass;
    private CtClass objectClass;
    private CtClass objectArrayClass;
    private static String FIELD_PREFIX = "_decorate_$_field_$_";
    private static String ESCAPE_PREFIX = "_decorate_$_original_$_";
    
    public DecoratorAgent(Instrumentation inst) throws Exception {
        this.inst = inst;
        this.pool = new ClassPool();
        this.pool.appendSystemPath();
        this.callClass = this.pool.get("Call");
        this.objectClass = this.pool.get("java.lang.Object");
        this.objectArrayClass = this.pool.get("java.lang.Object[]");
    }
    
    public static void premain(String agentArgs, Instrumentation inst) throws Exception {
        inst.addTransformer(new DecoratorAgent(inst));
    }
    
    public byte[] transform(ClassLoader loader, String className,
                              Class<?> classBeingRedefined,
                              ProtectionDomain protectionDomain,
                              byte[] classfileBuffer) throws IllegalClassFormatException {
        try {
            return this.getDecoratedClass(classfileBuffer);
        } catch (Exception ex) {
            ex.printStackTrace();
            throw new IllegalClassFormatException(ex.getMessage());
        }
    }
    
    private byte[] getDecoratedClass(byte[] classfileBuffer) throws Exception {
        ByteArrayInputStream istream = new ByteArrayInputStream(classfileBuffer);
        CtClass ctClass = this.pool.makeClass(istream);
        int fieldIndex = 0;
        int methodIndex = 0;
        for (CtMethod method: ctClass.getMethods()) {
            String callFieldName = null;
            for (Object annotation: method.getAnnotations()) {
                if (!(annotation instanceof Decorate)) break;
                Decorate decorate = (Decorate) annotation;
                if (callFieldName == null) {
                    String escapeName = ESCAPE_PREFIX + methodIndex;
                    CtMethod escapeMethod = this.escapeMethod(ctClass, escapeName, method);
                    methodIndex++;
                    
                    callFieldName = FIELD_PREFIX + fieldIndex;
                    CtClass escapeClass = this.createEscapeCallClass(ctClass, escapeName, escapeMethod);
                    
                    this.addEscapeCallField(ctClass, callFieldName, escapeClass);
                    fieldIndex++;
                }
                String calledFieldName = callFieldName;
                callFieldName = FIELD_PREFIX + fieldIndex;
                this.addDecoratedCallField(ctClass, callFieldName, decorate.value(), calledFieldName);
                fieldIndex++;
            }
            if (callFieldName != null) {
                replaceMethod(ctClass, method, callFieldName);
            }
        }
        return ctClass.toBytecode();
    }
    
    private CtMethod escapeMethod(CtClass ctClass, String escapeName, CtMethod method) throws Exception {
        // escape original method;
        CtMethod escapeMethod = new CtMethod(method, ctClass, null);
        escapeMethod.setName(escapeName);
        ctClass.addMethod(escapeMethod);
        return escapeMethod;
    }
    
    private CtClass createEscapeCallClass(CtClass ctClass, String escapeName, CtMethod escapeMethod) throws Exception {
        // create original method Call class
        String escapeCallClassName = ctClass.getName() + "$" + escapeName;
        CtClass escapeCallClass = this.pool.makeClass(escapeCallClassName);
        escapeCallClass.addInterface(this.callClass);
        
        CtField parentField = new CtField(ctClass, "parent", escapeCallClass);
        escapeCallClass.addField(parentField);
        
        CtConstructor constructor = new CtConstructor(new CtClass[] {ctClass}, escapeCallClass);
        constructor.setBody("{this.parent = $1;}");
        escapeCallClass.addConstructor(constructor);
        
        CtMethod proceedMethod = new CtMethod(this.objectClass, "proceed", new CtClass[] {this.objectArrayClass}, escapeCallClass);
        String body = "this.parent." + escapeMethod.getName() + "(" + this.unwrapperArgs(escapeMethod, "$1") + ")";
        String bodyCode = "{" + this.unwrapperReturnCode(escapeMethod, body) + "}";
        //System.out.println(bodyCode);
        proceedMethod.setBody(bodyCode);
        escapeCallClass.addMethod(proceedMethod);
        
        // registering class
        Class<?> escapeCallClassObject = this.pool.toClass(escapeCallClass);
        
        return escapeCallClass;
    }
    
    private void addEscapeCallField(CtClass ctClass, String fieldName, CtClass escapeCallClass) throws Exception {
        // add call field and initializer
        CtField callField = new CtField(this.callClass, fieldName, ctClass);
        String initExpr = "new " + escapeCallClass.getName() + "(this)";
        CtField.Initializer callFieldInitializer = CtField.Initializer.byExpr(initExpr);
        ctClass.addField(callField, callFieldInitializer);
    }
    
    private void addDecoratedCallField(CtClass ctClass, String fieldName,
                                       Class<? extends Decorator> decoratorClass,
                                       String prevCallField) throws Exception {
        // add call field and initializer
        CtField callField = new CtField(this.callClass, fieldName, ctClass);
        String initExpr = "new " + decoratorClass.getName() + "().decorate(" + prevCallField + ")";
        CtField.Initializer callFieldInitializer = CtField.Initializer.byExpr(initExpr);
        ctClass.addField(callField, callFieldInitializer);
    }
    
    private void replaceMethod(CtClass ctClass, CtMethod method, String callFieldName) throws Exception {
        // replace caller call
        String code = "{" +
        "return ($r) this." + callFieldName + ".proceed($args);" +
        "}";
        method.setBody(code);
    }
    
    private String unwrapperReturnCode(CtMethod method, String body) throws Exception {
        CtClass returnType = method.getReturnType();
        if (returnType.equals(CtClass.voidType)) {
            // TBD
            return body + "; return null;";
        } else if (returnType.equals(CtClass.booleanType)) {
            return "return Boolean.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.byteType)) {
            return "return Byte.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.shortType)) {
            return "return Short.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.intType)) {
            return "return Integer.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.longType)) {
            return "return Long.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.floatType)) {
            return "return Float.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.doubleType)) {
            return "return Double.valueOf(" + body + ");";
        } else {
            return "return " + body + ";";
        }
    }
    
    // convert object array to flatten parameters
    private String unwrapperArgs(CtMethod method, String variable) throws Exception {
        StringBuilder buffer = new StringBuilder();
        int count = 0;
        for (CtClass paramType: method.getParameterTypes()) {
            if (count > 0) buffer.append(", ");
            int argIndex = count;
            if (paramType.equals(CtClass.booleanType)) {
                buffer.append("((Boolean) " + variable + "[" + argIndex + "]).booleanValue()");
            } else if (paramType.equals(CtClass.byteType)) {
                buffer.append("((Byte) " + variable + "[" + argIndex + "]).byteValue()");
            } else if (paramType.equals(CtClass.shortType)) {
                buffer.append("((Short) " + variable + "[" + argIndex + "]).shortValue()");
            } else if (paramType.equals(CtClass.intType)) {
                buffer.append("((Integer) " + variable + "[" + argIndex + "]).intValue()");
            } else if (paramType.equals(CtClass.longType)) {
                buffer.append("((Long) " + variable + "[" + argIndex + "]).longValue()");
            } else if (paramType.equals(CtClass.floatType)) {
                buffer.append("((Float) " + variable + "[" + argIndex + "]).floatValue()");
            } else if (paramType.equals(CtClass.doubleType)) {
                buffer.append("((Double) " + variable + "[" + argIndex + "]).doubleValue()");
            } else {
                buffer.append(variable + "[" + argIndex + "]");
            }
            count++;
        }
        return buffer.toString();
    }
    
}

このjavaagentは、@Decorateがあるメソッドに対して

  • 1. $_decorator_$_original_$_nという名前で、そのメソッドと同じ内容のメソッドにコピー追加
  • 2. DecoratedClass$_decorator_$_original_$_nという名の、上記メソッドを呼ぶCallを実装した(内部クラス的)クラスを生成
  • 3. $_decorator_$_field_$_mという名のフィールドを作り、上記Callのインスタンスで初期化するコードを追加
  • 4. $_decorator_$_field_$_lという名(l=m+1)のフィールドを作り、@Decorateのvalue()のdecorate($_decorator_$_field_$_m)で初期化を行うコードを追加
  • 5. @Decorateのついたメソッドの中身を$_decorator_$_field_$_lを実行するように入れ替え

を行うもの(効率化と例外処理は省略している)

JavaagentのPremainクラスのクラスローダーはメインプログラムのクラスローダーとシステムのクラスローダーの間に位置する。そこで新たなクラスを読み込むと、副作用的にメインプログラムでも使えるクラスになる。

実行

agent.mf

Premain-Class: DecoratorAgent

実行はMemorizeAgentと同じように

$ jar cvfm agent.jar agent.mf
$ javac -classpath javassist.jar;. *.java
$ java -classpath javassist.jar;. -javaagent:agent.jar Main
when n=80
when n=79
when n=78
...
when n=1
when n=0
37889062373143906

これ自体は普通に実行できる。

Java5のannotationは同一要素に一つだけ

このDecoratorでLoggingもつけたいと思った。つまり、LoggingDecorator.java

import java.util.*;

public class LoggingDecorator implements Decorator {
    public Call decorate(final Call call) {
        Call wrapper = new Call() {
            public Object proceed(Object[] args) {
                Object result = call.proceed(args);
                System.out.println("method(" + Arrays.deepToString(args) + ") =" + result);
                return result;
            }
        };
        return wrapper;
    }
}

を作って、

@Decorate(LoggingDecorator.class)  // 実際はコンパイルエラー
@Decorator(MemorizeDecorator.class)
public long fib(int n) {
  if (n <= 1) return 1;
  return fib(n - 1) + fib(n - 2);
}

みたいに書きたい、と。実際にそうかければ実現できるように上記DecoratorAgentは書いた。

ただし、ここで一つ問題がでる。Java5では同じ名前のAnnotationは同じ要素には一つしか付けられない。つまり上記は無理ということだ。

ここまで書いて、

@Decorate({LoggingDecorator.class, MemorizeDecorator.class})
public long fib(int n) {
  if (n <= 1) return 1;
  return fib(n - 1) + fib(n - 2);
}

これができればいいんじゃないかと気が付いた。