Skip to content

Commit

Permalink
pass trace as argument (#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Feb 14, 2023
1 parent 1843339 commit 5894561
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 171 deletions.
9 changes: 5 additions & 4 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1894,11 +1894,12 @@ class EnzymeBase {
Logic.CreateTrace(F, generativeFunctions, mode, has_dynamic_interface);

Value *trace =
Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
if (!F->getReturnType()->isVoidTy())
trace = Builder.CreateExtractValue(trace, {1});
Builder.CreateCall(interface->newTraceTy(), interface->newTrace(), {});

args.push_back(trace);

Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);

// try to cast i8* returned from trace to CI->getRetType....
if (CI->getType() != trace->getType())
trace = Builder.CreatePointerCast(trace, CI->getType());

Expand Down
31 changes: 18 additions & 13 deletions enzyme/Enzyme/TraceGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,11 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
{
ElseTerm->getParent()->setName("condition." + call.getName() +
".without.trace");
ElseChoice =

auto choice =
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
sample_args, "sample." + call.getName());
ElseChoice = choice;
}

Builder.SetInsertPoint(new_call);
Expand Down Expand Up @@ -132,11 +134,16 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
Logic.CreateTrace(called, tutils->generativeFunctions, tutils->mode,
tutils->hasDynamicTraceInterface());

auto trace = tutils->CreateTrace(Builder);

Instruction *tracecall;
switch (mode) {
case ProbProgMode::Trace: {
tracecall = Builder.CreateCall(samplefn->getFunctionType(), samplefn,
args, "trace." + called->getName());
SmallVector<Value *, 2> args_and_trace = SmallVector(args);
args_and_trace.push_back(trace);
tracecall =
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
args_and_trace, "trace." + called->getName());
break;
}
case ProbProgMode::Condition: {
Expand All @@ -158,8 +165,9 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
ThenTerm->getParent()->setName("condition." + call.getName() +
".with.trace");
SmallVector<Value *, 2> args_and_cond = SmallVector(args);
auto trace = tutils->GetTrace(Builder, address,
called->getName() + ".subtrace");
auto observations = tutils->GetTrace(Builder, address,
called->getName() + ".subtrace");
args_and_cond.push_back(observations);
args_and_cond.push_back(trace);
ThenTracecall = Builder.CreateCall(samplefn->getFunctionType(),
samplefn, args_and_cond,
Expand All @@ -171,8 +179,9 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
ElseTerm->getParent()->setName("condition." + call.getName() +
".without.trace");
SmallVector<Value *, 2> args_and_null = SmallVector(args);
auto trace = ConstantPointerNull::get(cast<PointerType>(
auto observations = ConstantPointerNull::get(cast<PointerType>(
tutils->getTraceInterface()->newTraceTy()->getReturnType()));
args_and_null.push_back(observations);
args_and_null.push_back(trace);
ElseTracecall =
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
Expand All @@ -188,14 +197,10 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
}
}

Value *ret = Builder.CreateExtractValue(tracecall, {0});
Value *subtrace = Builder.CreateExtractValue(
tracecall, {1}, "newtrace." + called->getName());

tutils->InsertCall(Builder, address, subtrace);
tutils->InsertCall(Builder, address, trace);

ret->takeName(new_call);
new_call->replaceAllUsesWith(ret);
tracecall->takeName(new_call);
new_call->replaceAllUsesWith(tracecall);
new_call->eraseFromParent();
}
}
Expand Down
49 changes: 11 additions & 38 deletions enzyme/Enzyme/TraceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TraceUtils {
private:
TraceInterface *interface;
Value *dynamic_interface = nullptr;
Instruction *trace;
Value *trace;
Value *observations = nullptr;

public:
Expand Down Expand Up @@ -70,10 +70,9 @@ class TraceUtils {
if (mode == ProbProgMode::Condition)
params.push_back(traceType);

Type *RetTy = traceType;
if (!oldFunc->getReturnType()->isVoidTy())
RetTy = StructType::get(Context, {oldFunc->getReturnType(), traceType});
params.push_back(traceType);

Type *RetTy = oldFunc->getReturnType();
FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg());

Twine Name = (mode == ProbProgMode::Condition ? "condition_" : "trace_") +
Expand All @@ -94,21 +93,27 @@ class TraceUtils {
}

if (has_dynamic_interface) {
auto arg = newFunc->arg_end() - (1 + (mode == ProbProgMode::Condition));
auto arg = newFunc->arg_end() - (2 + (mode == ProbProgMode::Condition));
dynamic_interface = arg;
arg->setName("interface");
arg->addAttr(Attribute::ReadOnly);
arg->addAttr(Attribute::NoCapture);
}

if (mode == ProbProgMode::Condition) {
auto arg = newFunc->arg_end() - 1;
auto arg = newFunc->arg_end() - 2;
observations = arg;
arg->setName("observations");
if (oldFunc->getReturnType()->isVoidTy())
arg->addAttr(Attribute::Returned);
}

auto arg = newFunc->arg_end() - 1;
trace = arg;
arg->setName("trace");
if (oldFunc->getReturnType()->isVoidTy())
arg->addAttr(Attribute::Returned);

SmallVector<ReturnInst *, 4> Returns;
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
Expand All @@ -126,38 +131,6 @@ class TraceUtils {
} else {
interface = new StaticTraceInterface(F->getParent());
}

// Create trace for current function

IRBuilder<> Builder(
newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
Builder.SetCurrentDebugLocation(oldFunc->getEntryBlock()
.getFirstNonPHIOrDbgOrLifetime()
->getDebugLoc());

trace = CreateTrace(Builder);

// Replace returns with ret trace

SmallVector<ReturnInst *, 3> toReplace;
for (auto &&BB : *newFunc) {
for (auto &&Inst : BB) {
if (auto Ret = dyn_cast<ReturnInst>(&Inst)) {
toReplace.push_back(Ret);
}
}
}

for (auto Ret : toReplace) {
IRBuilder<> Builder(Ret);
if (Ret->getReturnValue()) {
Value *retvals[2] = {Ret->getReturnValue(), trace};
Builder.CreateAggregateRet(retvals, 2);
} else {
Builder.CreateRet(trace);
}
Ret->eraseFromParent();
}
};

~TraceUtils() { delete interface; }
Expand Down
91 changes: 42 additions & 49 deletions enzyme/test/Enzyme/ProbProg/condition-dynamic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,18 @@ entry:

; CHECK: define i8* @condition(double* %data, i32 %n, i8* %trace, i8** %interface)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = load i32, i32* @enzyme_condition
; CHECK-NEXT: %1 = load i32, i32* @enzyme_interface
; CHECK-NEXT: %2 = call { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace)
; CHECK-NEXT: %3 = extractvalue { double, i8* } %2, 1
; CHECK-NEXT: ret i8* %3
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %1 = load i8*, i8** %0
; CHECK-NEXT: %new_trace = bitcast i8* %1 to i8* ()*
; CHECK-NEXT: %2 = load i32, i32* @enzyme_condition
; CHECK-NEXT: %3 = load i32, i32* @enzyme_interface
; CHECK-NEXT: %4 = call i8* %new_trace()
; CHECK-NEXT: %5 = call double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace, i8* %4)
; CHECK-NEXT: ret i8* %4
; CHECK-NEXT: }


; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations)
; CHECK: define internal double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 2
; CHECK-NEXT: %1 = load i8*, i8** %0
Expand All @@ -86,21 +89,20 @@ entry:
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 6
; CHECK-NEXT: %5 = load i8*, i8** %4
; CHECK-NEXT: %has_call = bitcast i8* %5 to i1 (i8*, i8*)*
; CHECK-NEXT: %call1.ptr = alloca double
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %7 = load i8*, i8** %6
; CHECK-NEXT: %insert_choice = bitcast i8* %7 to void (i8*, i8*, double, i8*, i64)*
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 1
; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
; CHECK-NEXT: %call1.ptr = alloca double
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %9 = load i8*, i8** %8
; CHECK-NEXT: %get_choice = bitcast i8* %9 to i64 (i8*, i8*, i8*, i64)*
; CHECK-NEXT: %call.ptr = alloca double
; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 7
; CHECK-NEXT: %insert_choice = bitcast i8* %9 to void (i8*, i8*, double, i8*, i64)*
; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 1
; CHECK-NEXT: %11 = load i8*, i8** %10
; CHECK-NEXT: %has_choice = bitcast i8* %11 to i1 (i8*, i8*)*
; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %get_choice = bitcast i8* %11 to i64 (i8*, i8*, i8*, i64)*
; CHECK-NEXT: %call.ptr = alloca double
; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 7
; CHECK-NEXT: %13 = load i8*, i8** %12
; CHECK-NEXT: %new_trace = bitcast i8* %13 to i8* ()*
; CHECK-NEXT: %trace = call i8* %new_trace()
; CHECK-NEXT: %has_choice = bitcast i8* %13 to i1 (i8*, i8*)*
; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace

Expand Down Expand Up @@ -139,30 +141,27 @@ entry:
; CHECK-NEXT: %18 = bitcast double %call1 to i64
; CHECK-NEXT: %19 = inttoptr i64 %18 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %19, i64 8)
; CHECK-NEXT: %trace1 = call i8* %new_trace()
; CHECK-NEXT: %has.call.call2 = call i1 %has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace

; CHECK: condition.call2.with.trace: ; preds = %entry.cntd.cntd
; CHECK-NEXT: %calculate_loss.subtrace = call i8* %get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
; CHECK-NEXT: %condition.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace)
; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace, i8* %trace1)
; CHECK-NEXT: br label %entry.cntd.cntd.cntd

; CHECK: condition.call2.without.trace: ; preds = %entry.cntd.cntd
; CHECK-NEXT: %trace.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null)
; CHECK-NEXT: %trace.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null, i8* %trace1)
; CHECK-NEXT: br label %entry.cntd.cntd.cntd

; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
; CHECK-NEXT: ret { double, i8* } %mrv1
; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %trace1)
; CHECK-NEXT: ret double %call2
; CHECK-NEXT: }


; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations)
; CHECK: define internal double @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %1 = load i8*, i8** %0
Expand All @@ -174,10 +173,6 @@ entry:
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 7
; CHECK-NEXT: %5 = load i8*, i8** %4
; CHECK-NEXT: %has_choice = bitcast i8* %5 to i1 (i8*, i8*)*
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %7 = load i8*, i8** %6
; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
; CHECK-NEXT: %trace = call i8* %new_trace()
; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0
; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup

Expand All @@ -186,42 +181,40 @@ entry:
; CHECK-NEXT: br label %for.body

; CHECK: for.cond.cleanup: ; preds = %for.body.cntd, %entry
; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %14, %for.body.cntd ]
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa, 0
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
; CHECK-NEXT: ret { double, i8* } %mrv1
; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %12, %for.body.cntd ]
; CHECK-NEXT: ret double %loss.0.lcssa

; CHECK: for.body: ; preds = %for.body.cntd, %for.body.preheader
; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body.cntd ]
; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %14, %for.body.cntd ]
; CHECK-NEXT: %8 = trunc i64 %indvars.iv to i32
; CHECK-NEXT: %conv2 = sitofp i32 %8 to double
; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %12, %for.body.cntd ]
; CHECK-NEXT: %6 = trunc i64 %indvars.iv to i32
; CHECK-NEXT: %conv2 = sitofp i32 %6 to double
; CHECK-NEXT: %mul1 = fmul double %conv2, %m
; CHECK-NEXT: %9 = fadd double %mul1, %b
; CHECK-NEXT: %7 = fadd double %mul1, %b
; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0))
; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace

; CHECK: condition.call.with.trace: ; preds = %for.body
; CHECK-NEXT: %10 = bitcast double* %call.ptr to i8*
; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %10, i64 8)
; CHECK-NEXT: %8 = bitcast double* %call.ptr to i8*
; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %8, i64 8)
; CHECK-NEXT: %from.trace.call = load double, double* %call.ptr
; CHECK-NEXT: br label %for.body.cntd

; CHECK: condition.call.without.trace: ; preds = %for.body
; CHECK-NEXT: %sample.call = call double @normal(double %9, double 1.000000e+00)
; CHECK-NEXT: %sample.call = call double @normal(double %7, double 1.000000e+00)
; CHECK-NEXT: br label %for.body.cntd

; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %9, double 1.000000e+00, double %call)
; CHECK-NEXT: %11 = bitcast double %call to i64
; CHECK-NEXT: %12 = inttoptr i64 %11 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %12, i64 8)
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %7, double 1.000000e+00, double %call)
; CHECK-NEXT: %9 = bitcast double %call to i64
; CHECK-NEXT: %10 = inttoptr i64 %9 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %10, i64 8)
; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
; CHECK-NEXT: %13 = load double, double* %arrayidx3
; CHECK-NEXT: %sub = fsub double %call, %13
; CHECK-NEXT: %11 = load double, double* %arrayidx3
; CHECK-NEXT: %sub = fsub double %call, %11
; CHECK-NEXT: %mul2 = fmul double %sub, %sub
; CHECK-NEXT: %14 = fadd double %mul2, %loss.021
; CHECK-NEXT: %12 = fadd double %mul2, %loss.021
; CHECK-NEXT: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
; CHECK-NEXT: %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
Expand Down
Loading

0 comments on commit 5894561

Please sign in to comment.