diff --git a/lab6/llvm-pass.so.cc b/lab6/llvm-pass.so.cc index 6c6e17e..d6ec5d6 100644 --- a/lab6/llvm-pass.so.cc +++ b/lab6/llvm-pass.so.cc @@ -1,34 +1,79 @@ #include "llvm/Passes/PassPlugin.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" using namespace llvm; +namespace { -struct LLVMPass : public PassInfoMixin { - PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); -}; + struct LLVMPass : public PassInfoMixin { + PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { + //to get the LLVM context + LLVMContext &Context = M.getContext(); -PreservedAnalyses LLVMPass::run(Module &M, ModuleAnalysisManager &MAM) { - LLVMContext &Ctx = M.getContext(); - IntegerType *Int32Ty = IntegerType::getInt32Ty(Ctx); - FunctionCallee debug_func = M.getOrInsertFunction("debug", Int32Ty); - ConstantInt *debug_arg = ConstantInt::get(Int32Ty, 48763); + Function *MainFunction = M.getFunction("main"); + if (!MainFunction || MainFunction->arg_size() < 2) + return PreservedAnalyses::all(); - for (auto &F : M) { - errs() << "func: " << F.getName() << "\n"; + Function::arg_iterator ArgsIterator = MainFunction->arg_begin(); + Argument *ArgcArgument = &*ArgsIterator++; + Argument *ArgvArgument = &*ArgsIterator; - } - return PreservedAnalyses::none(); + BasicBlock &EntryBlock = MainFunction->getEntryBlock(); + Instruction *FirstInstruction = &*EntryBlock.getFirstInsertionPt(); + IRBuilder<> Builder(FirstInstruction); + + //prep debug(48763) + Type *Int32Type = Type::getInt32Ty(Context); + ConstantInt *DebugValue = ConstantInt::get(Int32Type, 48763); + + FunctionType *DebugFunctionType = + FunctionType::get(Type::getVoidTy(Context), {Int32Type}, false); + FunctionCallee DebugFunction = M.getOrInsertFunction("debug", DebugFunctionType); + + //insert call debug(48763) + Builder.CreateCall(DebugFunction, {DebugValue}); + + // replace all uses of argc with 48763 + if (!ArgcArgument->use_empty()) { + ArgcArgument->replaceAllUsesWith(DebugValue); + } + + Value *hayakuString = + Builder.CreateGlobalStringPtr("hayaku... motohayaku!"); + + + Value *Index1 = ConstantInt::get(Int32Type, 1); + Value *Argv1Pointer = Builder.CreateGEP( + ArgvArgument->getType()->getPointerElementType(), + ArgvArgument, + Index1 + ); + + //stores the string in argv[1] + Builder.CreateStore(hayakuString, Argv1Pointer); + + return PreservedAnalyses::none(); + + } + }; } -extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo() { return {LLVM_PLUGIN_API_VERSION, "LLVMPass", "1.0", [](PassBuilder &PB) { - PB.registerOptimizerLastEPCallback( - [](ModulePassManager &MPM, OptimizationLevel OL) { - MPM.addPass(LLVMPass()); - }); - }}; + PB.registerPipelineParsingCallback( + [](StringRef Name, ModulePassManager &MPM, + ArrayRef) { + if (Name == "llvm-pass") { + MPM.addPass(LLVMPass()); + return true; + } + return false; + }); + } + }; }