Java5/Javassist3.1による汎用Decorator実装

[id:bellbind:20050906:p3]の続き

概要

Java5のannotationでメソッドにデコレータをつけられるようにしたjavaagentの実装。

具体的には、[id:bellbind:20050906:p3]の最後の考察部分を実現できるようにした。

@Decorate({LoggingDecorator.class, MemorizeDecorator.class})
public long fib(int n) {
    // fib(0) == fib(1) == 1. The recursive calls go through the rewritten fib(),
    // so every level of the recursion passes through the decorator chain too.
    return (n <= 1) ? 1 : fib(n - 1) + fib(n - 2);
}

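さらに、前回は省いていた例外処理も追加した。AspectJのdeclare softのように、元のメソッドが投げた例外をExceptionCarrier（Errorのサブクラス）でラップしてデコレータのチェーンを通過させる。デコレートされたメソッドの入口では、それを元の例外に戻して再送出する（宣言されていないチェック例外だけはErrorで包んで投げる）。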

フレームワークコード

Decorate.java

import java.lang.annotation.*;

/**
 * Requests that the annotated method be wrapped by the given decorators when its
 * class is loaded through {@code DecoratorAgent}.
 *
 * <p>The decorators are applied in array order. The first entry wraps the original
 * method body, and each following entry wraps the result of the previous one, so
 * the last entry ends up outermost. Each class is instantiated with {@code new X()},
 * so it needs an accessible no-argument constructor.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Decorate {
    /** Decorator classes, innermost first. */
    Class<? extends Decorator>[] value();
}

Decorator.java: [id:bellbind:20050906:p3]のものと同じ

Call.java

/**
 * One step of the call chain built for a {@code @Decorate}d method.
 *
 * <p>Values travel as objects. Primitive arguments and results are boxed, and the
 * innermost generated step returns {@code null} for a {@code void} method.
 * Whatever the original method throws is delivered wrapped in an
 * {@link ExceptionCarrier}.
 */
public interface Call {
    /**
     * Continues the call with the given arguments.
     *
     * @param args the method arguments in declaration order (primitives boxed)
     * @return the (boxed) result of the call
     * @throws ExceptionCarrier wrapping whatever the original method threw
     */
    Object proceed(Object[] args) throws ExceptionCarrier;
}

ExceptionCarrier.java

/**
 * Unchecked wrapper that carries any {@link Throwable}, checked exceptions
 * included, through {@code Call.proceed}, which declares no checked exceptions.
 *
 * <p>Because it extends {@link Error}, it needs no {@code throws} clause and is not
 * caught by ordinary {@code catch (Exception e)} blocks. The carried throwable is
 * available from {@link #getCause()}.
 */
public class ExceptionCarrier extends Error {
    /**
     * Wraps a throwable.
     *
     * @param cause the throwable being carried; its {@code toString()} becomes this
     *              error's message (a {@code null} cause gives a {@code null} message)
     */
    public ExceptionCarrier(final Throwable cause) {
        super(cause);
    }
}

DecoratorAgent.java

import java.lang.instrument.*;
import java.security.*;
import java.io.*;
import javassist.*;

public class DecoratorAgent implements ClassFileTransformer {
    private Instrumentation inst;
    private ClassPool pool;
    private CtClass callClass;
    private CtClass objectClass;
    private CtClass objectArrayClass;
    private static String FIELD_PREFIX = "_decorate_$_field_$_";
    private static String ESCAPE_PREFIX = "_decorate_$_original_$_";
    
    public DecoratorAgent(Instrumentation inst) throws Exception {
        this.inst = inst;
        this.pool = new ClassPool();
        this.pool.appendSystemPath();
        this.callClass = this.pool.get("Call");
        this.objectClass = this.pool.get("java.lang.Object");
        this.objectArrayClass = this.pool.get("java.lang.Object[]");
    }
    
    public static void premain(String agentArgs, Instrumentation inst) throws Exception {
        inst.addTransformer(new DecoratorAgent(inst));
    }
    
    public byte[] transform(ClassLoader loader, String className,
                              Class<?> classBeingRedefined,
                              ProtectionDomain protectionDomain,
                              byte[] classfileBuffer) throws IllegalClassFormatException {
        try {
            return this.getDecoratedClass(classfileBuffer);
        } catch (Exception ex) {
            ex.printStackTrace();
            throw new IllegalClassFormatException(ex.getMessage());
        }
    }
    
    private byte[] getDecoratedClass(byte[] classfileBuffer) throws Exception {
        ByteArrayInputStream istream = new ByteArrayInputStream(classfileBuffer);
        CtClass ctClass = this.pool.makeClass(istream);
        int fieldIndex = 0;
        int methodIndex = 0;
        for (CtMethod method: ctClass.getMethods()) {
            Decorate decorate = this.getDecorateAnnotation(method);
            if (decorate == null) continue;
            
            String escapeName = ESCAPE_PREFIX + methodIndex;
            CtMethod escapeMethod = this.escapeMethod(ctClass, escapeName, method);
            methodIndex++;
            
            String callFieldName = FIELD_PREFIX + fieldIndex;
            CtClass escapeClass = this.createEscapeCallClass(ctClass, escapeName, escapeMethod);
            
            this.addEscapeCallField(ctClass, callFieldName, escapeClass);
            fieldIndex++;
            
            for (Class<? extends Decorator> decoratorClass: decorate.value()) {
                String calledFieldName = callFieldName;
                callFieldName = FIELD_PREFIX + fieldIndex;
                this.addDecoratedCallField(ctClass, callFieldName, decoratorClass, calledFieldName);
                fieldIndex++;
            }
            replaceMethod(ctClass, method, callFieldName);
        }
        return ctClass.toBytecode();
    }
    
    private Decorate getDecorateAnnotation(CtMethod method) throws Exception {
        for (Object annotation: method.getAnnotations()) {
            if (annotation instanceof Decorate) {
                return (Decorate) annotation;
            }
        }
        return null;
    }
    
    private CtMethod escapeMethod(CtClass ctClass, String escapeName, CtMethod method) throws Exception {
        // escape original method;
        CtMethod escapeMethod = new CtMethod(method, ctClass, null);
        escapeMethod.setName(escapeName);
        ctClass.addMethod(escapeMethod);
        return escapeMethod;
    }
    
    private CtClass createEscapeCallClass(CtClass ctClass, String escapeName, CtMethod escapeMethod) throws Exception {
        // create original method Call class
        String escapeCallClassName = ctClass.getName() + "$" + escapeName;
        CtClass escapeCallClass = this.pool.makeClass(escapeCallClassName);
        escapeCallClass.addInterface(this.callClass);
        
        CtField parentField = new CtField(ctClass, "parent", escapeCallClass);
        escapeCallClass.addField(parentField);
        
        CtConstructor constructor = new CtConstructor(new CtClass[] {ctClass}, escapeCallClass);
        constructor.setBody("{this.parent = $1;}");
        escapeCallClass.addConstructor(constructor);
        
        CtMethod proceedMethod = new CtMethod(this.objectClass, "proceed", new CtClass[] {this.objectArrayClass}, escapeCallClass);
        String body = "this.parent." + escapeMethod.getName() + "(" + this.unwrapperArgs(escapeMethod, "$1") + ")";
        String bodyCode = "{ " +
          "try { " +
          this.unwrapperReturnCode(escapeMethod, body) +
          "} catch (Throwable th) { " +
          "  throw new ExceptionCarrier(th); " +
          "}" +
          "}";
        //System.out.println(bodyCode);
        proceedMethod.setBody(bodyCode);
        escapeCallClass.addMethod(proceedMethod);
        
        // registering generated class into agent ClassLoader
        Class<?> escapeCallClassObject = this.pool.toClass(escapeCallClass);
        
        return escapeCallClass;
    }
    
    private void addEscapeCallField(CtClass ctClass, String fieldName, CtClass escapeCallClass) throws Exception {
        // add call field and initializer
        CtField callField = new CtField(this.callClass, fieldName, ctClass);
        String initExpr = "new " + escapeCallClass.getName() + "(this)";
        CtField.Initializer callFieldInitializer = CtField.Initializer.byExpr(initExpr);
        ctClass.addField(callField, callFieldInitializer);
    }
    
    private void addDecoratedCallField(CtClass ctClass, String fieldName,
                                       Class<? extends Decorator> decoratorClass,
                                       String prevCallField) throws Exception {
        // add call field and initializer
        CtField callField = new CtField(this.callClass, fieldName, ctClass);
        String initExpr = "new " + decoratorClass.getName() + "().decorate(" + prevCallField + ")";
        CtField.Initializer callFieldInitializer = CtField.Initializer.byExpr(initExpr);
        ctClass.addField(callField, callFieldInitializer);
    }
    
    private void replaceMethod(CtClass ctClass, CtMethod method, String callFieldName) throws Exception {
        // replace caller call
        String code = "{" +
        "try {" + 
        "return ($r) this." + callFieldName + ".proceed($args);" +
        "} catch (ExceptionCarrier ex) {" +
        this.unwrapperThrowCode(method, "ex") +
        "}" +
        "}";
        method.setBody(code);
    }
    
    private String unwrapperThrowCode(CtMethod method, String exceptionName) throws Exception {
        StringBuilder builder = new StringBuilder();
        builder.append("if (" + exceptionName + ".getCause() instanceof RuntimeException) {");
        builder.append("  throw (RuntimeException) " + exceptionName + ".getCause();");
        builder.append("} else if (" + exceptionName + ".getCause() instanceof Error) {");
        builder.append("  throw (Error) " + exceptionName + ".getCause();");
        builder.append("}");   
        for (CtClass exceptionClass: method.getExceptionTypes()) {
            builder.append(" else if (" + exceptionName + ".getCause() instanceof " + exceptionClass.getName() + ") {");
            builder.append("  throw (" + exceptionClass.getName() + ") " + exceptionName + ".getCause();");
            builder.append("}");
        }
        builder.append("throw new Error(" + exceptionName + ");");
        //System.out.println(builder);
        return builder.toString();
    }
    
    
    private String unwrapperReturnCode(CtMethod method, String body) throws Exception {
        CtClass returnType = method.getReturnType();
        if (returnType.equals(CtClass.voidType)) {
            // TBD
            return body + "; return null;";
        } else if (returnType.equals(CtClass.booleanType)) {
            return "return Boolean.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.byteType)) {
            return "return Byte.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.shortType)) {
            return "return Short.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.intType)) {
            return "return Integer.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.longType)) {
            return "return Long.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.floatType)) {
            return "return Float.valueOf(" + body + ");";
        } else if (returnType.equals(CtClass.doubleType)) {
            return "return Double.valueOf(" + body + ");";
        } else {
            return "return " + body + ";";
        }
    }
    
    // convert object array to flatten parameters
    private String unwrapperArgs(CtMethod method, String variable) throws Exception {
        StringBuilder buffer = new StringBuilder();
        int count = 0;
        for (CtClass paramType: method.getParameterTypes()) {
            if (count > 0) buffer.append(", ");
            int argIndex = count;
            if (paramType.equals(CtClass.booleanType)) {
                buffer.append("((Boolean) " + variable + "[" + argIndex + "]).booleanValue()");
            } else if (paramType.equals(CtClass.byteType)) {
                buffer.append("((Byte) " + variable + "[" + argIndex + "]).byteValue()");
            } else if (paramType.equals(CtClass.shortType)) {
                buffer.append("((Short) " + variable + "[" + argIndex + "]).shortValue()");
            } else if (paramType.equals(CtClass.intType)) {
                buffer.append("((Integer) " + variable + "[" + argIndex + "]).intValue()");
            } else if (paramType.equals(CtClass.longType)) {
                buffer.append("((Long) " + variable + "[" + argIndex + "]).longValue()");
            } else if (paramType.equals(CtClass.floatType)) {
                buffer.append("((Float) " + variable + "[" + argIndex + "]).floatValue()");
            } else if (paramType.equals(CtClass.doubleType)) {
                buffer.append("((Double) " + variable + "[" + argIndex + "]).doubleValue()");
            } else {
                buffer.append(variable + "[" + argIndex + "]");
            }
            count++;
        }
        return buffer.toString();
    }
    
}

基本的な流れは変更なし。Callのリストの取り出し方が変わったのと例外対策を追加しただけ。

効率化はしていない。この負荷がかかるのはクラスロード時だけだ。こういうタイプのクラス生成を本気で効率化するならバイトコードレベルでやることになるし、その用途ならJavassistよりASMを使うほうが分かりやすいだろう。

考察

実装上、デコレーションは左側に書いたものから順に包んでいく。つまり、最も右に書いたデコレータが一番外側になる。最初の例なら call = new MemorizeDecorator().decorate(new LoggingDecorator().decorate(original)); という感じ。

decorate()に渡されるのは直前のCallだけなので情報不足な気もする。ただ、安易に情報を増やすと複雑化しそうなので、そこは用途との整合性を取って決めるべき。