-
Notifications
You must be signed in to change notification settings - Fork 4
/
InjectAMDGCNFunction.cpp
58 lines (49 loc) · 1.96 KB
/
InjectAMDGCNFunction.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include "InjectAMDGCNFunction.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
using namespace llvm;
/// Inject a call to the device function `_Z11PrintKerneli` at the top of
/// every AMDGPU kernel *defined* in \p M.
///
/// At the first insertion point of each kernel's entry block this emits:
///   %tid = call i32 @llvm.amdgcn.workitem.id.x()
///   call void @_Z11PrintKerneli(i32 %tid)
///
/// Only a declaration of `_Z11PrintKerneli` (type `void(i32)`) is created if
/// one is not already present; its definition must be supplied elsewhere.
/// Modules without any defined kernel are left untouched.
///
/// \param M the module to instrument.
/// \returns true if at least one kernel was instrumented (the IR changed).
bool InjectAMDGCNFunc::runOnModule(Module &M) {
  // Gather the kernels up front: declaring the callee and the intrinsic below
  // appends new functions to M, so the module's function list is never
  // mutated while it is being walked.
  SmallVector<Function *, 8> Kernels;
  for (Function &F : M) {
    // A kernel that is only declared has no entry block to insert into;
    // getEntryBlock() must not be called on it.
    if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL && !F.isDeclaration())
      Kernels.push_back(&F);
  }
  if (Kernels.empty())
    return false;

  LLVMContext &CTX = M.getContext();

  // void(i32) -- must agree with the mangled name `_Z11PrintKerneli`.
  FunctionType *FT = FunctionType::get(Type::getVoidTy(CTX),
                                       {Type::getInt32Ty(CTX)}, false);
  FunctionCallee InjectedFunctionCallee =
      M.getOrInsertFunction("_Z11PrintKerneli", FT);
  errs() << "Function To Be Injected: "
         << InjectedFunctionCallee.getCallee()->getName() << "\n";

  // Loop-invariant: the intrinsic declaration is the same for every kernel.
  Function *WorkItemXIDIntrinsicFunc =
      Intrinsic::getDeclaration(&M, Intrinsic::amdgcn_workitem_id_x);

  for (Function *F : Kernels) {
    // Insert at the top of the kernel (after any PHIs, of which an entry
    // block has none).
    BasicBlock &Entry = F->getEntryBlock();
    IRBuilder<> Builder(&Entry, Entry.getFirstInsertionPt());
    Value *WorkItemXValue = Builder.CreateCall(WorkItemXIDIntrinsicFunc, {});
    // Call through the FunctionCallee instead of cast<Function>(): if a
    // declaration of this symbol with a different type already exists,
    // getOrInsertFunction may not hand back a plain Function.
    Builder.CreateCall(InjectedFunctionCallee, {WorkItemXValue});
    errs() << "Injecting Device Function Into AMDGPU Kernel: " << F->getName()
           << "\n";
  }
  return true;
}
/// Describe this plugin to the new pass manager.
///
/// The returned info registers InjectAMDGCNFunc at PassBuilder's
/// pipeline early-simplification extension point under the plugin name
/// "inject-amdgcn-func".
PassPluginLibraryInfo getPassPluginInfo() {
  const auto Callback = [](PassBuilder &PB) {
    // The extension-point callback's parameter list differs between LLVM
    // releases (newer ones pass a ThinOrFullLTOPhase after the
    // OptimizationLevel); a variadic generic lambda accepts either form.
    // The callback's result type is void, so nothing is returned.
    PB.registerPipelineEarlySimplificationEPCallback(
        [](ModulePassManager &MPM, auto &&...) {
          MPM.addPass(InjectAMDGCNFunc());
        });
  };
  return {LLVM_PLUGIN_API_VERSION, "inject-amdgcn-func", LLVM_VERSION_STRING,
          Callback};
}
/// C-linkage pass-plugin entry point: the symbol LLVM's plugin loader looks
/// up in this shared object. It simply forwards to getPassPluginInfo().
extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo
llvmGetPassPluginInfo() {
  ::llvm::PassPluginLibraryInfo Info = getPassPluginInfo();
  return Info;
}