LockAgent: Refactor transformation code

Extract a helper class.

Bug: 124744938
Test: m
Test: manual
Change-Id: I9fb6b4b5d7bc47f4233cacef8a192989563ce279
diff --git a/tools/lock_agent/agent.cpp b/tools/lock_agent/agent.cpp
index 59bfa2b..8a775f8 100644
--- a/tools/lock_agent/agent.cpp
+++ b/tools/lock_agent/agent.cpp
@@ -77,118 +77,143 @@
 using namespace dex;
 using namespace lir;
 
-bool transform(std::shared_ptr<ir::DexFile> dexIr) {
-    bool modified = false;
+class Transformer {
+public:
+    explicit Transformer(std::shared_ptr<ir::DexFile> dexIr) : dexIr_(dexIr) {}
 
-    std::unique_ptr<ir::Builder> builder;
+    bool transform() {
+        bool classModified = false;
 
-    for (auto& method : dexIr->encoded_methods) {
-        // Do not look into abstract/bridge/native/synthetic methods.
-        if ((method->access_flags & (kAccAbstract | kAccBridge | kAccNative | kAccSynthetic))
-                != 0) {
-            continue;
+        std::unique_ptr<ir::Builder> builder;
+
+        for (auto& method : dexIr_->encoded_methods) {
+            // Do not look into abstract/bridge/native/synthetic methods.
+            if ((method->access_flags & (kAccAbstract | kAccBridge | kAccNative | kAccSynthetic))
+                    != 0) {
+                continue;
+            }
+
+            struct HookVisitor: public Visitor {
+                HookVisitor(Transformer* transformer, CodeIr* c_ir)
+                        : transformer(transformer), cIr(c_ir) {
+                }
+
+                bool Visit(Bytecode* bytecode) override {
+                    if (bytecode->opcode == OP_MONITOR_ENTER) {
+                        insertHook(bytecode, true,
+                                reinterpret_cast<VReg*>(bytecode->operands[0])->reg);
+                        return true;
+                    }
+                    if (bytecode->opcode == OP_MONITOR_EXIT) {
+                        insertHook(bytecode, false,
+                                reinterpret_cast<VReg*>(bytecode->operands[0])->reg);
+                        return true;
+                    }
+                    return false;
+                }
+
+                void insertHook(lir::Instruction* before, bool pre, u4 reg) {
+                    transformer->preparePrePost();
+                    transformer->addCall(cIr, before, OP_INVOKE_STATIC_RANGE,
+                            transformer->hookType_, pre ? "preLock" : "postLock",
+                            transformer->voidType_, transformer->objectType_, reg);
+                    myModified = true;
+                }
+
+                Transformer* transformer;
+                CodeIr* cIr;
+                bool myModified = false;
+            };
+
+            CodeIr c(method.get(), dexIr_);
+            bool methodModified = false;
+
+            HookVisitor visitor(this, &c);
+            for (auto it = c.instructions.begin(); it != c.instructions.end(); ++it) {
+                lir::Instruction* fi = *it;
+                fi->Accept(&visitor);
+            }
+            methodModified |= visitor.myModified;
+
+            if (methodModified) {
+                classModified = true;
+                c.Assemble();
+            }
         }
 
-        struct HookVisitor: public Visitor {
-            HookVisitor(std::unique_ptr<ir::Builder>* b, std::shared_ptr<ir::DexFile> d_ir,
-                    CodeIr* c_ir) :
-                    b(b), dIr(d_ir), cIr(c_ir) {
-            }
+        return classModified;
+    }
 
-            bool Visit(Bytecode* bytecode) override {
-                if (bytecode->opcode == OP_MONITOR_ENTER) {
-                    prepare();
-                    addCall(bytecode, OP_INVOKE_STATIC_RANGE, hookType, "preLock", voidType,
-                            objectType, reinterpret_cast<VReg*>(bytecode->operands[0])->reg);
-                    myModified = true;
-                    return true;
-                }
-                if (bytecode->opcode == OP_MONITOR_EXIT) {
-                    prepare();
-                    addCall(bytecode, OP_INVOKE_STATIC_RANGE, hookType, "postLock", voidType,
-                            objectType, reinterpret_cast<VReg*>(bytecode->operands[0])->reg);
-                    myModified = true;
-                    return true;
-                }
-                return false;
-            }
+private:
+    void preparePrePost() {
+        // Insert "void LockHook.(pre|post)(Object o)."
 
-            void prepare() {
-                if (*b == nullptr) {
-                    *b = std::unique_ptr<ir::Builder>(new ir::Builder(dIr));
-                }
-                if (voidType == nullptr) {
-                    voidType = (*b)->GetType("V");
-                    hookType = (*b)->GetType("Lcom/android/lock_checker/LockHook;");
-                    objectType = (*b)->GetType("Ljava/lang/Object;");
-                }
-            }
+        prepareBuilder();
 
-            void addInst(lir::Instruction* instructionAfter, Opcode opcode,
-                    const std::list<Operand*>& operands) {
-                auto instruction = cIr->Alloc<Bytecode>();
-
-                instruction->opcode = opcode;
-
-                for (auto it = operands.begin(); it != operands.end(); it++) {
-                    instruction->operands.push_back(*it);
-                }
-
-                cIr->instructions.InsertBefore(instructionAfter, instruction);
-            }
-
-            void addCall(lir::Instruction* instructionAfter, Opcode opcode, ir::Type* type,
-                    const char* methodName, ir::Type* returnType,
-                    const std::vector<ir::Type*>& types, const std::list<int>& regs) {
-                auto proto = (*b)->GetProto(returnType, (*b)->GetTypeList(types));
-                auto method = (*b)->GetMethodDecl((*b)->GetAsciiString(methodName), proto, type);
-
-                VRegList* paramRegs = cIr->Alloc<VRegList>();
-                for (auto it = regs.begin(); it != regs.end(); it++) {
-                    paramRegs->registers.push_back(*it);
-                }
-
-                addInst(instructionAfter, opcode,
-                        { paramRegs, cIr->Alloc<Method>(method, method->orig_index) });
-            }
-
-            void addCall(lir::Instruction* instructionAfter, Opcode opcode, ir::Type* type,
-                    const char* methodName, ir::Type* returnType, ir::Type* paramType,
-                    u4 paramVReg) {
-                auto proto = (*b)->GetProto(returnType, (*b)->GetTypeList( { paramType }));
-                auto method = (*b)->GetMethodDecl((*b)->GetAsciiString(methodName), proto, type);
-
-                VRegRange* args = cIr->Alloc<VRegRange>(paramVReg, 1);
-
-                addInst(instructionAfter, opcode,
-                        { args, cIr->Alloc<Method>(method, method->orig_index) });
-            }
-
-            std::unique_ptr<ir::Builder>* b;
-            std::shared_ptr<ir::DexFile> dIr;
-            CodeIr* cIr;
-            ir::Type* voidType = nullptr;
-            ir::Type* hookType = nullptr;
-            ir::Type* objectType = nullptr;
-            bool myModified = false;
-        };
-
-        CodeIr c(method.get(), dexIr);
-        HookVisitor visitor(&builder, dexIr, &c);
-
-        for (auto it = c.instructions.begin(); it != c.instructions.end(); ++it) {
-            lir::Instruction* fi = *it;
-            fi->Accept(&visitor);
+        if (voidType_ == nullptr) {
+            voidType_ = builder_->GetType("V");
         }
-
-        if (visitor.myModified) {
-            modified = true;
-            c.Assemble();
+        if (hookType_ == nullptr) {
+            hookType_ = builder_->GetType("Lcom/android/lock_checker/LockHook;");
+        }
+        if (objectType_ == nullptr) {
+            objectType_ = builder_->GetType("Ljava/lang/Object;");
         }
     }
 
-    return modified;
-}
+    void prepareBuilder() {
+        if (builder_ == nullptr) {
+            builder_ = std::unique_ptr<ir::Builder>(new ir::Builder(dexIr_));
+        }
+    }
+
+    static void addInst(CodeIr* cIr, lir::Instruction* instructionAfter, Opcode opcode,
+            const std::list<Operand*>& operands) {
+        auto instruction = cIr->Alloc<Bytecode>();
+
+        instruction->opcode = opcode;
+
+        for (auto it = operands.begin(); it != operands.end(); it++) {
+            instruction->operands.push_back(*it);
+        }
+
+        cIr->instructions.InsertBefore(instructionAfter, instruction);
+    }
+
+    void addCall(CodeIr* cIr, lir::Instruction* instructionAfter, Opcode opcode, ir::Type* type,
+            const char* methodName, ir::Type* returnType,
+            const std::vector<ir::Type*>& types, const std::list<int>& regs) {
+        auto proto = builder_->GetProto(returnType, builder_->GetTypeList(types));
+        auto method = builder_->GetMethodDecl(builder_->GetAsciiString(methodName), proto, type);
+
+        VRegList* paramRegs = cIr->Alloc<VRegList>();
+        for (auto it = regs.begin(); it != regs.end(); it++) {
+            paramRegs->registers.push_back(*it);
+        }
+
+        addInst(cIr, instructionAfter, opcode,
+                { paramRegs, cIr->Alloc<Method>(method, method->orig_index) });
+    }
+
+    void addCall(CodeIr* cIr, lir::Instruction* instructionAfter, Opcode opcode, ir::Type* type,
+            const char* methodName, ir::Type* returnType, ir::Type* paramType,
+            u4 paramVReg) {
+        auto proto = builder_->GetProto(returnType, builder_->GetTypeList( { paramType }));
+        auto method = builder_->GetMethodDecl(builder_->GetAsciiString(methodName), proto, type);
+
+        VRegRange* args = cIr->Alloc<VRegRange>(paramVReg, 1);
+
+        addInst(cIr, instructionAfter, opcode,
+                { args, cIr->Alloc<Method>(method, method->orig_index) });
+    }
+
+    std::shared_ptr<ir::DexFile> dexIr_;
+    std::unique_ptr<ir::Builder> builder_;
+
+    ir::Type* voidType_ = nullptr;
+    ir::Type* hookType_ = nullptr;
+    ir::Type* objectType_ = nullptr;
+};
 
 std::pair<dex::u1*, size_t> maybeTransform(const char* name, size_t classDataLen,
         const unsigned char* classData, dex::Writer::Allocator* allocator) {
@@ -201,8 +226,11 @@
     reader.CreateClassIr(index);
     std::shared_ptr<ir::DexFile> ir = reader.GetIr();
 
-    if (!transform(ir)) {
-        return std::make_pair(nullptr, 0);
+    {
+        Transformer transformer(ir);
+        if (!transformer.transform()) {
+            return std::make_pair(nullptr, 0);
+        }
     }
 
     size_t new_size;