1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
5  * in compliance with the License. You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software distributed under the License
10  * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
11  * or implied. See the License for the specific language governing permissions and limitations under
12  * the License.
13  */
14 package lockedregioncodeinjection;
15 
16 import static com.google.common.base.Preconditions.checkElementIndex;
17 import static com.google.common.base.Preconditions.checkNotNull;
18 import static com.google.common.base.Preconditions.checkState;
19 
20 import org.objectweb.asm.ClassVisitor;
21 import org.objectweb.asm.MethodVisitor;
22 import org.objectweb.asm.Opcodes;
23 import org.objectweb.asm.commons.TryCatchBlockSorter;
24 import org.objectweb.asm.tree.AbstractInsnNode;
25 import org.objectweb.asm.tree.InsnList;
26 import org.objectweb.asm.tree.InsnNode;
27 import org.objectweb.asm.tree.LabelNode;
28 import org.objectweb.asm.tree.LineNumberNode;
29 import org.objectweb.asm.tree.MethodInsnNode;
30 import org.objectweb.asm.tree.MethodNode;
31 import org.objectweb.asm.tree.TypeInsnNode;
32 import org.objectweb.asm.tree.TryCatchBlockNode;
33 import org.objectweb.asm.tree.analysis.Analyzer;
34 import org.objectweb.asm.tree.analysis.AnalyzerException;
35 import org.objectweb.asm.tree.analysis.BasicValue;
36 import org.objectweb.asm.tree.analysis.Frame;
37 
38 import java.util.ArrayList;
39 import java.util.Arrays;
40 import java.util.LinkedList;
41 import java.util.List;
42 
43 /**
44  * This visitor operates on two kinds of targets.  For a legacy target, it does the following:
45  *
46  * 1. Finds all the MONITOR_ENTER / MONITOR_EXIT in the byte code and inserts the corresponding pre
47  * and post methods calls should it matches one of the given target type in the Configuration.
48  *
49  * 2. Find all methods that are synchronized and insert pre method calls in the beginning and post
50  * method calls just before all return instructions.
51  *
52  * For a scoped target, it does the following:
53  *
54  * 1. Finds all the MONITOR_ENTER instructions in the byte code.  If the target of the opcode is
55  *    named in a --scope switch, then the pre method is invoked ON THE TARGET immediately after
56  *    MONITOR_ENTER opcode completes.
57  *
58  * 2. Finds all the MONITOR_EXIT instructions in the byte code.  If the target of the opcode is
59  *    named in a --scope switch, then the post method is invoked ON THE TARGET immediately before
60  *    MONITOR_EXIT opcode completes.
61  */
62 class LockFindingClassVisitor extends ClassVisitor {
63     private String className = null;
64     private final List<LockTarget> targets;
65 
LockFindingClassVisitor(List<LockTarget> targets, ClassVisitor chain)66     public LockFindingClassVisitor(List<LockTarget> targets, ClassVisitor chain) {
67         super(Utils.ASM_VERSION, chain);
68         this.targets = targets;
69     }
70 
71     @Override
visitMethod(int access, String name, String desc, String signature, String[] exceptions)72     public MethodVisitor visitMethod(int access, String name, String desc, String signature,
73             String[] exceptions) {
74         assert this.className != null;
75         MethodNode mn = new TryCatchBlockSorter(null, access, name, desc, signature, exceptions);
76         MethodVisitor chain = super.visitMethod(access, name, desc, signature, exceptions);
77         return new LockFindingMethodVisitor(this.className, mn, chain);
78     }
79 
80     @Override
visit(int version, int access, String name, String signature, String superName, String[] interfaces)81     public void visit(int version, int access, String name, String signature, String superName,
82             String[] interfaces) {
83         this.className = name;
84         super.visit(version, access, name, signature, superName, interfaces);
85     }
86 
87     class LockFindingMethodVisitor extends MethodVisitor {
88         private String owner;
89         private MethodVisitor chain;
90         private final String className;
91         private final String methodName;
92 
LockFindingMethodVisitor(String owner, MethodNode mn, MethodVisitor chain)93         public LockFindingMethodVisitor(String owner, MethodNode mn, MethodVisitor chain) {
94             super(Utils.ASM_VERSION, mn);
95             assert owner != null;
96             this.owner = owner;
97             this.chain = chain;
98             className = owner;
99             methodName = mn.name;
100         }
101 
102         @SuppressWarnings("unchecked")
103         @Override
visitEnd()104         public void visitEnd() {
105             MethodNode mn = (MethodNode) mv;
106 
107             Analyzer a = new Analyzer(new LockTargetStateAnalysis(targets));
108 
109             LockTarget ownerMonitor = null;
110             if ((mn.access & Opcodes.ACC_SYNCHRONIZED) != 0) {
111                 for (LockTarget t : targets) {
112                     if (t.getTargetDesc().equals("L" + owner + ";")) {
113                         ownerMonitor = t;
114                         if (ownerMonitor.getScoped()) {
115                             final String emsg = String.format(
116                                 "scoped targets do not support synchronized methods in %s.%s()",
117                                 className, methodName);
118                             throw new RuntimeException(emsg);
119                         }
120                     }
121                 }
122             }
123 
124             try {
125                 a.analyze(owner, mn);
126             } catch (AnalyzerException e) {
127                 throw new RuntimeException("Locked region code injection: " + e.getMessage(), e);
128             }
129             InsnList instructions = mn.instructions;
130 
131             Frame[] frames = a.getFrames();
132             List<Frame> frameMap = new LinkedList<>();
133             frameMap.addAll(Arrays.asList(frames));
134 
135             List<List<TryCatchBlockNode>> handlersMap = new LinkedList<>();
136 
137             for (int i = 0; i < instructions.size(); i++) {
138                 handlersMap.add(a.getHandlers(i));
139             }
140 
141             if (ownerMonitor != null) {
142                 AbstractInsnNode s = instructions.getFirst();
143                 MethodInsnNode call = new MethodInsnNode(Opcodes.INVOKESTATIC,
144                         ownerMonitor.getPreOwner(), ownerMonitor.getPreMethod(), "()V", false);
145                 insertMethodCallBeforeSync(mn, frameMap, handlersMap, s, 0, call);
146             }
147 
148             boolean anyDup = false;
149 
150             for (int i = 0; i < instructions.size(); i++) {
151                 AbstractInsnNode s = instructions.get(i);
152 
153                 if (s.getOpcode() == Opcodes.MONITORENTER) {
154                     Frame f = frameMap.get(i);
155                     BasicValue operand = (BasicValue) f.getStack(f.getStackSize() - 1);
156                     if (operand instanceof LockTargetState) {
157                         LockTargetState state = (LockTargetState) operand;
158                         for (int j = 0; j < state.getTargets().size(); j++) {
159                             LockTarget target = state.getTargets().get(j);
160                             MethodInsnNode call = methodCall(target, true);
161                             if (target.getScoped()) {
162                                 TypeInsnNode cast = typeCast(target);
163                                 i += insertInvokeAcquire(mn, frameMap, handlersMap, s, i,
164                                         call, cast);
165                                 anyDup = true;
166                             } else {
167                                 i += insertMethodCallBefore(mn, frameMap, handlersMap, s, i, call);
168                             }
169                         }
170                     }
171                 }
172 
173                 if (s.getOpcode() == Opcodes.MONITOREXIT) {
174                     Frame f = frameMap.get(i);
175                     BasicValue operand = (BasicValue) f.getStack(f.getStackSize() - 1);
176                     if (operand instanceof LockTargetState) {
177                         LockTargetState state = (LockTargetState) operand;
178                         for (int j = 0; j < state.getTargets().size(); j++) {
179                             // The instruction after a monitor_exit should be a label for
180                             // the end of the implicit catch block that surrounds the
181                             // synchronized block to call monitor_exit when an exception
182                             // occurs.
183                             checkState(instructions.get(i + 1).getType() == AbstractInsnNode.LABEL,
184                                 "Expected to find label after monitor exit");
185 
186                             int labelIndex = i + 1;
187                             checkElementIndex(labelIndex, instructions.size());
188 
189                             LabelNode label = (LabelNode)instructions.get(labelIndex);
190 
191                             checkNotNull(handlersMap.get(i));
192                             checkElementIndex(0, handlersMap.get(i).size());
193                             checkState(handlersMap.get(i).get(0).end == label,
194                                 "Expected label to be the end of monitor exit's try block");
195 
196                             LockTarget target = state.getTargets().get(j);
197                             MethodInsnNode call = methodCall(target, false);
198                             if (target.getScoped()) {
199                                 TypeInsnNode cast = typeCast(target);
200                                 i += insertInvokeRelease(mn, frameMap, handlersMap, s, i,
201                                         call, cast);
202                                 anyDup = true;
203                             } else {
204                                 insertMethodCallAfter(mn, frameMap, handlersMap, label,
205                                         labelIndex, call);
206                             }
207                         }
208                     }
209                 }
210 
211                 if (ownerMonitor != null && (s.getOpcode() == Opcodes.RETURN
212                         || s.getOpcode() == Opcodes.ARETURN || s.getOpcode() == Opcodes.DRETURN
213                         || s.getOpcode() == Opcodes.FRETURN || s.getOpcode() == Opcodes.IRETURN)) {
214                     MethodInsnNode call =
215                             new MethodInsnNode(Opcodes.INVOKESTATIC, ownerMonitor.getPostOwner(),
216                                     ownerMonitor.getPostMethod(), "()V", false);
217                     insertMethodCallBeforeSync(mn, frameMap, handlersMap, s, i, call);
218                     i++; // Skip ahead. Otherwise, we will revisit this instruction again.
219                 }
220             }
221 
222             if (anyDup) {
223                 mn.maxStack++;
224             }
225 
226             super.visitEnd();
227             mn.accept(chain);
228         }
229 
230         // Insert a call to a monitor pre handler.  The node and the index identify the
231         // monitorenter call itself.  Insert DUP immediately prior to the MONITORENTER.
232         // Insert the typecast and call (in that order) after the MONITORENTER.
insertInvokeAcquire(MethodNode mn, List<Frame> frameMap, List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index, MethodInsnNode call, TypeInsnNode cast)233         public int insertInvokeAcquire(MethodNode mn, List<Frame> frameMap,
234                 List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index,
235                 MethodInsnNode call, TypeInsnNode cast) {
236             InsnList instructions = mn.instructions;
237 
238             // Insert a DUP right before MONITORENTER, to capture the object being locked.
239             // Note that the object will be typed as java.lang.Object.
240             instructions.insertBefore(node, new InsnNode(Opcodes.DUP));
241             frameMap.add(index, frameMap.get(index));
242             handlersMap.add(index, handlersMap.get(index));
243 
244             // Insert the call right after the MONITORENTER.  These entries are pushed after
245             // MONITORENTER so they are inserted in reverse order.  MONITORENTER should be
246             // the target of a try/catch block, which means it must be immediately
247             // followed by a label (which is part of the try/catch block definition).
248             // Move forward past the label so the invocation in inside the proper block.
249             // Throw an error if the next instruction is not a label.
250             node = node.getNext();
251             if (!(node instanceof LabelNode)) {
252                 throw new RuntimeException(String.format("invalid bytecode sequence in %s.%s()",
253                                 className, methodName));
254             }
255             node = node.getNext();
256             index = instructions.indexOf(node);
257 
258             instructions.insertBefore(node, cast);
259             frameMap.add(index, frameMap.get(index));
260             handlersMap.add(index, handlersMap.get(index));
261 
262             instructions.insertBefore(node, call);
263             frameMap.add(index, frameMap.get(index));
264             handlersMap.add(index, handlersMap.get(index));
265 
266             return 3;
267         }
268 
269         // Insert instructions completely before the current opcode.  This is slightly
270         // different from insertMethodCallBefore(), which inserts the call before MONITOREXIT
271         // but inserts the start and end labels after MONITOREXIT.
insertInvokeRelease(MethodNode mn, List<Frame> frameMap, List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index, MethodInsnNode call, TypeInsnNode cast)272         public int insertInvokeRelease(MethodNode mn, List<Frame> frameMap,
273                 List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index,
274                 MethodInsnNode call, TypeInsnNode cast) {
275             InsnList instructions = mn.instructions;
276 
277             instructions.insertBefore(node, new InsnNode(Opcodes.DUP));
278             frameMap.add(index, frameMap.get(index));
279             handlersMap.add(index, handlersMap.get(index));
280 
281             instructions.insertBefore(node, cast);
282             frameMap.add(index, frameMap.get(index));
283             handlersMap.add(index, handlersMap.get(index));
284 
285             instructions.insertBefore(node, call);
286             frameMap.add(index, frameMap.get(index));
287             handlersMap.add(index, handlersMap.get(index));
288 
289             return 3;
290         }
291     }
292 
methodCall(LockTarget target, boolean pre)293     public static MethodInsnNode methodCall(LockTarget target, boolean pre) {
294         String spec = "()V";
295         if (!target.getScoped()) {
296             if (pre) {
297                 return new MethodInsnNode(
298                     Opcodes.INVOKESTATIC, target.getPreOwner(), target.getPreMethod(), spec);
299             } else {
300                 return new MethodInsnNode(
301                     Opcodes.INVOKESTATIC, target.getPostOwner(), target.getPostMethod(), spec);
302             }
303         } else {
304             if (pre) {
305                 return new MethodInsnNode(
306                     Opcodes.INVOKEVIRTUAL, target.getPreOwner(), target.getPreMethod(), spec);
307             } else {
308                 return new MethodInsnNode(
309                     Opcodes.INVOKEVIRTUAL, target.getPostOwner(), target.getPostMethod(), spec);
310             }
311         }
312     }
313 
typeCast(LockTarget target)314     public static TypeInsnNode typeCast(LockTarget target) {
315         if (!target.getScoped()) {
316             return null;
317         } else {
318             // preOwner and postOwner return the same string for scoped targets.
319             return new TypeInsnNode(Opcodes.CHECKCAST, target.getPreOwner());
320         }
321     }
322 
323     /**
324      * Insert a method call before the beginning or end of a synchronized method.
325      */
insertMethodCallBeforeSync(MethodNode mn, List<Frame> frameMap, List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index, MethodInsnNode call)326     public static void insertMethodCallBeforeSync(MethodNode mn, List<Frame> frameMap,
327             List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index,
328             MethodInsnNode call) {
329         List<TryCatchBlockNode> handlers = handlersMap.get(index);
330         InsnList instructions = mn.instructions;
331         LabelNode end = new LabelNode();
332         instructions.insert(node, end);
333         frameMap.add(index, null);
334         handlersMap.add(index, null);
335         instructions.insertBefore(node, call);
336         frameMap.add(index, null);
337         handlersMap.add(index, null);
338 
339         LabelNode start = new LabelNode();
340         instructions.insert(node, start);
341         frameMap.add(index, null);
342         handlersMap.add(index, null);
343         updateCatchHandler(mn, handlers, start, end, handlersMap);
344     }
345 
insertMethodCallAfter(MethodNode mn, List<Frame> frameMap, List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index, MethodInsnNode call)346     public static void insertMethodCallAfter(MethodNode mn, List<Frame> frameMap,
347             List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index,
348             MethodInsnNode call) {
349         List<TryCatchBlockNode> handlers = handlersMap.get(index + 1);
350         InsnList instructions = mn.instructions;
351 
352         LabelNode end = new LabelNode();
353         instructions.insert(node, end);
354         frameMap.add(index + 1, null);
355         handlersMap.add(index + 1, null);
356 
357         instructions.insert(node, call);
358         frameMap.add(index + 1, null);
359         handlersMap.add(index + 1, null);
360 
361         LabelNode start = new LabelNode();
362         instructions.insert(node, start);
363         frameMap.add(index + 1, null);
364         handlersMap.add(index + 1, null);
365 
366         updateCatchHandler(mn, handlers, start, end, handlersMap);
367     }
368 
369     // Insert instructions completely before the current opcode.  This is slightly different from
370     // insertMethodCallBeforeSync(), which inserts the call before MONITOREXIT but inserts the
371     // start and end labels after MONITOREXIT.
insertMethodCallBefore(MethodNode mn, List<Frame> frameMap, List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index, MethodInsnNode call)372     public int insertMethodCallBefore(MethodNode mn, List<Frame> frameMap,
373             List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index,
374             MethodInsnNode call) {
375         InsnList instructions = mn.instructions;
376 
377         instructions.insertBefore(node, call);
378         frameMap.add(index, frameMap.get(index));
379         handlersMap.add(index, handlersMap.get(index));
380 
381         return 1;
382     }
383 
384 
385     @SuppressWarnings("unchecked")
updateCatchHandler(MethodNode mn, List<TryCatchBlockNode> handlers, LabelNode start, LabelNode end, List<List<TryCatchBlockNode>> handlersMap)386     public static void updateCatchHandler(MethodNode mn, List<TryCatchBlockNode> handlers,
387             LabelNode start, LabelNode end, List<List<TryCatchBlockNode>> handlersMap) {
388         if (handlers == null || handlers.size() == 0) {
389             return;
390         }
391 
392         InsnList instructions = mn.instructions;
393         List<TryCatchBlockNode> newNodes = new ArrayList<>(handlers.size());
394         for (TryCatchBlockNode handler : handlers) {
395             if (!(instructions.indexOf(handler.start) <= instructions.indexOf(start)
396                     && instructions.indexOf(end) <= instructions.indexOf(handler.end))) {
397                 TryCatchBlockNode newNode =
398                         new TryCatchBlockNode(start, end, handler.handler, handler.type);
399                 newNodes.add(newNode);
400                 for (int i = instructions.indexOf(start); i <= instructions.indexOf(end); i++) {
401                     if (handlersMap.get(i) == null) {
402                         handlersMap.set(i, new ArrayList<>());
403                     }
404                     handlersMap.get(i).add(newNode);
405                 }
406             } else {
407                 for (int i = instructions.indexOf(start); i <= instructions.indexOf(end); i++) {
408                     if (handlersMap.get(i) == null) {
409                         handlersMap.set(i, new ArrayList<>());
410                     }
411                     handlersMap.get(i).add(handler);
412                 }
413             }
414         }
415         mn.tryCatchBlocks.addAll(0, newNodes);
416     }
417 }
418