Kengo's blog

Technical articles about original projects, JVM, Static Analysis and TypeScript.

ASMで実行時にメソッド定義を書き換える

今さらながらJavaクラスに対する動的なweavingを試してみました。DIコンテナフレームワークに任せることが一般的になった領域であり技術が直接役に立つことはないでしょうが、JVMやクラスファイルに対する理解を深めるきっかけにもなると考えたためです。
JavaバイトコードをいじるためのライブラリにはAspectJjavassistが有名ですが、今回はPMDが使用しているASMを選択しました。他のライブラリに比べ、軽快さやライブラリの小ささで優れているようです。

ライブラリの入手

ASMウェブサイトから最新のzipを入手して解凍します。今回はバージョン3.3を使用しました。
zipに含まれているjarにはいくつか種類がありますが、特に理由がなければ全部入りのasm-all-3.3.jarを使えばいいでしょう。これをCLASSPATHに追加してやれば準備完了です。依存するライブラリなどはありません。

ドキュメントを読む

ASMはドキュメントの充実っぷりが素晴らしいです。javadocもPDFもきちんと整備されており、スタックやフレームに関する知識があやふやでも学びながら読むことができます。
残念ながらすべて英語ですが、特に難しい言い回しはないように思います。まずはユーザガイド(PDF)から読まれることをおすすめします。特に英語が得意でない私でも4〜5時間程度で目を通すことができましたので、量もそこまで多くないと言えるでしょう。

目標を立てる

今回は「指定したメソッド名を持つメソッドに対して処理を埋め込み、System.errにメソッド実行にかかった時間を出力する」ことを目的としました。実践的な内容であるとともに、簡単すぎない適度な難易度であると考えたためです。

使うクラスを決める

目標を決めたら、ユーザガイドを参考にして必要なクラスを洗い出していきます。
メソッドに対して手を加えるのでMethodVisitorは必須でしょう。ただし今回は処理開始時間を記録するためのローカル変数を増やす必要があるため、newLocalメソッドを持つサブクラスLocalVariablesSorterの方が便利そうです(63ページ参照)。
またMethodVisitorを呼び出すクラスには、ClassAdapterを使用すればよさそうです。ではClassAdapterを呼び出すクラスは……?
ClassLoaderを継承する方法もありますが、今回はjava.lang.instrumentパッケージのClassFileTransformerを実装することでmainメソッド実行前にバイトコード変換用クラスを登録する手法をとりました。
なおこの手法はJava5以降専用であり、1.4以前では使えないようです。詳しくはHisidamaさんのサイトやinstrumentパッケージのjavadocに書かれていますので、そちらを参照のこと。

書いてみた

イメージとしては、メソッドの最初に

long timer = System.currentTimeMillis();

を、メソッドの最後に

System.err.println(System.currentTimeMillis() - timer);

を書き加えれば目的を達成できそうです。どのクラスのどのメソッドでかかった時間なのかを明確にするために、クラス名やメソッド名やメソッドのディスクリプタなどを併記すればより便利になりそうです。


以上の処理を実装したものが以下のクラスです。

import org.objectweb.asm.ClassAdapter;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.LocalVariablesSorter;

import static org.objectweb.asm.Opcodes.*;

public class AddLoggingAdapter extends ClassAdapter {
	private boolean isInterface;
	private String className;
	private String targetMethodName;

	public AddLoggingAdapter(ClassVisitor cv, String targetMethodName) {
		super(cv);
		this.targetMethodName = targetMethodName;
	}

	@Override
	public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
		cv.visit(version, access, name, signature, superName, interfaces);
		isInterface = (access & ACC_INTERFACE) != 0;
		className = name;
	}

	@Override
	public MethodVisitor visitMethod(int access, String name, String desc,
			String signature, String[] exceptions) {
		MethodVisitor mv = cv.visitMethod(access, name, desc, signature, exceptions);
		if (!isInterface && mv != null && name.equals(targetMethodName)) {
			mv = new AddLoggingVisitor(access, className, name, signature, desc, mv);
		}
		return mv;
	}

	private static final class AddLoggingVisitor extends LocalVariablesSorter {
		private int time;
		private final String header;

		public AddLoggingVisitor(int access, String className, String methodName, String signature, String desc, MethodVisitor mv) {
			super(access, desc, mv);
			this.header = className + '.' + methodName + desc + ": ";
		}

		@Override
		public void visitCode() {
			super.visitCode();
			mv.visitMethodInsn(INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J");
			time = newLocal(Type.LONG_TYPE);
			mv.visitVarInsn(LSTORE, time);
		}

		@Override
		public void visitInsn(int opcode) {
			if ((opcode >= IRETURN && opcode <= RETURN) || opcode == ATHROW) {
				mv.visitFieldInsn(GETSTATIC, "java/lang/System", "err", "Ljava/io/PrintStream;");
				mv.visitInsn(DUP);
				mv.visitLdcInsn(header);
				mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "print", "(Ljava/lang/String;)V");
				mv.visitMethodInsn(INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J");
				mv.visitVarInsn(LLOAD, time);
				mv.visitInsn(LSUB);
				mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(J)V");
			}
			super.visitInsn(opcode);
		}

		@Override
		public void visitMaxs(int maxStack, int maxLocals) {
			super.visitMaxs(maxStack + 8, maxLocals);
		}
	}
}

内部クラスを使っているので若干長く感じますが、バイトコードDUPやLSUBといった基本的なスタック処理やINVOKESTATICやINVOKEVIRTUALといったメソッド呼び出ししか使っていません。スタックのイメージが湧けば、ASMを使ってバイトコード処理を書くのはかなり簡単と言えるでしょう。
JVMのスタックについてはJVM勉強会における櫻庭さんの発表が分かりやすかったので、資料を貼っておきます。

使ってみる

では実際に動作確認をしてみましょう。
今回作ったClassAdapterをClassReader#acceptに渡してやる場合、第2引数にClassReader.EXPAND_FRAMESを指定する必要があることに注意が必要です。ローカル変数を動的に増やしているため、フレームを拡張する必要があるためです。

ここだけ注意すればあとは簡単に使うことができます。例えばClassFileTransformerを以下のように実装するだけです。

import java.lang.instrument.ClassFileTransformer;
import java.security.ProtectionDomain;

import org.objectweb.asm.ClassAdapter;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;

final class MyTransformer implements ClassFileTransformer {
	@Override
	public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) {
		ClassWriter cw = new ClassWriter(0);
		ClassAdapter ca = new AddLoggingAdapter(cw, "check");
		ClassReader cr = new ClassReader(classfileBuffer);
		cr.accept(ca, ClassReader.EXPAND_FRAMES);
		return cw.toByteArray();
	}
}

で、これを実行するとcheckメソッド実行時に

asm/test/Hoge.check()V: 500
asm/test/Hoge.check()V: 600
asm/test/Hoge.check()V: 700

と実行にかかった時間がミリ秒単位で標準エラー出力に出るようになります。やったね!