Skip to content

Commit

Permalink
Fix insert choice (rust-lang#988)
Browse files Browse the repository at this point in the history
* fix insert choice

* fix tests
  • Loading branch information
tgymnich authored Feb 12, 2023
1 parent 23d6ecc commit 248dbbb
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 88 deletions.
35 changes: 17 additions & 18 deletions enzyme/Enzyme/TraceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,36 +489,35 @@ class TraceUtils {

CallInst *InsertChoice(IRBuilder<> &Builder, Value *address, Value *score,
Value *choice) {
auto size = choice->getType()->getPrimitiveSizeInBits() / 8;
Type *size_type = interface->getChoiceTy()->getParamType(3);

auto M = interface->getSampleFunction()->getParent();
auto &DL = M->getDataLayout();
auto pointersize = DL.getPointerSizeInBits();
auto choicesize = choice->getType()->getPrimitiveSizeInBits();

Value *retval;
if (choice->getType()->isPointerTy()) {
retval = Builder.CreatePointerCast(choice, Builder.getInt8PtrTy());
} else {
IRBuilder<> AllocaBuilder(
newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
auto alloca = AllocaBuilder.CreateAlloca(choice->getType(), nullptr,
choice->getName() + ".ptr");
Builder.CreateStore(choice, alloca);
bool fitsInPointer =
choice->getType()->getPrimitiveSizeInBits() == pointersize;
if (fitsInPointer) {
auto dblptr =
PointerType::get(Builder.getInt8PtrTy(), DL.getAllocaAddrSpace());
retval = Builder.CreateLoad(Builder.getInt8PtrTy(),
Builder.CreatePointerCast(alloca, dblptr));
auto M = interface->getSampleFunction()->getParent();
auto &DL = M->getDataLayout();
auto pointersize = DL.getPointerSizeInBits();
if (choicesize <= pointersize) {
auto cast = Builder.CreateBitCast(
choice, IntegerType::get(M->getContext(), choicesize));
cast = choicesize == pointersize
? cast
: Builder.CreateZExt(cast, Builder.getIntPtrTy(DL));
retval = Builder.CreateIntToPtr(cast, Builder.getInt8PtrTy());
} else {
IRBuilder<> AllocaBuilder(
newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
auto alloca = AllocaBuilder.CreateAlloca(choice->getType(), nullptr,
choice->getName() + ".ptr");
Builder.CreateStore(choice, alloca);
retval = alloca;
}
}

Value *args[] = {trace, address, score, retval,
ConstantInt::get(size_type, size)};
ConstantInt::get(size_type, choicesize / 8)};

auto call = Builder.CreateCall(interface->insertChoiceTy(),
interface->insertChoice(), args);
Expand Down
24 changes: 9 additions & 15 deletions enzyme/test/Enzyme/ProbProg/condition-dynamic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,10 @@ entry:
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 6
; CHECK-NEXT: %5 = load i8*, i8** %4
; CHECK-NEXT: %has_call = bitcast i8* %5 to i1 (i8*, i8*)*
; CHECK-NEXT: %call1.ptr3 = alloca double
; CHECK-NEXT: %call1.ptr = alloca double
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %7 = load i8*, i8** %6
; CHECK-NEXT: %insert_choice = bitcast i8* %7 to void (i8*, i8*, double, i8*, i64)*
; CHECK-NEXT: %call.ptr2 = alloca double
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 1
; CHECK-NEXT: %9 = load i8*, i8** %8
; CHECK-NEXT: %get_choice = bitcast i8* %9 to i64 (i8*, i8*, i8*, i64)*
Expand Down Expand Up @@ -119,9 +117,8 @@ entry:
; CHECK: entry.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call)
; CHECK-NEXT: store double %call, double* %call.ptr2
; CHECK-NEXT: %15 = bitcast double* %call.ptr2 to i8**
; CHECK-NEXT: %16 = load i8*, i8** %15
; CHECK-NEXT: %15 = bitcast double %call to i64
; CHECK-NEXT: %16 = inttoptr i64 %15 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.call, i8* %16, i64 8)
; CHECK-NEXT: %has.choice.call1 = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0))
; CHECK-NEXT: br i1 %has.choice.call1, label %condition.call1.with.trace, label %condition.call1.without.trace
Expand All @@ -139,9 +136,8 @@ entry:
; CHECK: entry.cntd.cntd: ; preds = %condition.call1.without.trace, %condition.call1.with.trace
; CHECK-NEXT: %call1 = phi double [ %from.trace.call1, %condition.call1.with.trace ], [ %sample.call1, %condition.call1.without.trace ]
; CHECK-NEXT: %likelihood.call1 = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call1)
; CHECK-NEXT: store double %call1, double* %call1.ptr3
; CHECK-NEXT: %18 = bitcast double* %call1.ptr3 to i8**
; CHECK-NEXT: %19 = load i8*, i8** %18
; CHECK-NEXT: %18 = bitcast double %call1 to i64
; CHECK-NEXT: %19 = inttoptr i64 %18 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %19, i64 8)
; CHECK-NEXT: %has.call.call2 = call i1 %has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace
Expand All @@ -156,9 +152,9 @@ entry:
; CHECK-NEXT: br label %entry.cntd.cntd.cntd

; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
; CHECK-NEXT: %call24 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call24, 0
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call24, 1
; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
Expand All @@ -171,7 +167,6 @@ entry:
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %1 = load i8*, i8** %0
; CHECK-NEXT: %insert_choice = bitcast i8* %1 to void (i8*, i8*, double, i8*, i64)*
; CHECK-NEXT: %call.ptr2 = alloca double
; CHECK-NEXT: %2 = getelementptr inbounds i8*, i8** %interface, i32 1
; CHECK-NEXT: %3 = load i8*, i8** %2
; CHECK-NEXT: %get_choice = bitcast i8* %3 to i64 (i8*, i8*, i8*, i64)*
Expand Down Expand Up @@ -219,9 +214,8 @@ entry:
; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %9, double 1.000000e+00, double %call)
; CHECK-NEXT: store double %call, double* %call.ptr2
; CHECK-NEXT: %11 = bitcast double* %call.ptr2 to i8**
; CHECK-NEXT: %12 = load i8*, i8** %11
; CHECK-NEXT: %11 = bitcast double %call to i64
; CHECK-NEXT: %12 = inttoptr i64 %11 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %12, i64 8)
; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
; CHECK-NEXT: %13 = load double, double* %arrayidx3
Expand Down
24 changes: 9 additions & 15 deletions enzyme/test/Enzyme/ProbProg/condition-static.ll
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ entry:

; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8* %observations)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call1.ptr3 = alloca double
; CHECK-NEXT: %call1.ptr = alloca double
; CHECK-NEXT: %call.ptr2 = alloca double
; CHECK-NEXT: %call.ptr = alloca double
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
; CHECK-NEXT: %has.choice.call = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
Expand All @@ -101,9 +99,8 @@ entry:
; CHECK: entry.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call)
; CHECK-NEXT: store double %call, double* %call.ptr2
; CHECK-NEXT: %1 = bitcast double* %call.ptr2 to i8**
; CHECK-NEXT: %2 = load i8*, i8** %1
; CHECK-NEXT: %1 = bitcast double %call to i64
; CHECK-NEXT: %2 = inttoptr i64 %1 to i8*
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.call, i8* %2, i64 8)
; CHECK-NEXT: %has.choice.call1 = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0))
; CHECK-NEXT: br i1 %has.choice.call1, label %condition.call1.with.trace, label %condition.call1.without.trace
Expand All @@ -121,9 +118,8 @@ entry:
; CHECK: entry.cntd.cntd: ; preds = %condition.call1.without.trace, %condition.call1.with.trace
; CHECK-NEXT: %call1 = phi double [ %from.trace.call1, %condition.call1.with.trace ], [ %sample.call1, %condition.call1.without.trace ]
; CHECK-NEXT: %likelihood.call1 = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call1)
; CHECK-NEXT: store double %call1, double* %call1.ptr3
; CHECK-NEXT: %4 = bitcast double* %call1.ptr3 to i8**
; CHECK-NEXT: %5 = load i8*, i8** %4
; CHECK-NEXT: %4 = bitcast double %call1 to i64
; CHECK-NEXT: %5 = inttoptr i64 %4 to i8*
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %5, i64 8)
; CHECK-NEXT: %has.call.call2 = call i1 @__enzyme_has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace
Expand All @@ -138,9 +134,9 @@ entry:
; CHECK-NEXT: br label %entry.cntd.cntd.cntd

; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
; CHECK-NEXT: %call24 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call24, 0
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call24, 1
; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
; CHECK-NEXT: call void @__enzyme_insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
Expand All @@ -150,7 +146,6 @@ entry:

; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8* %observations)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call.ptr2 = alloca double
; CHECK-NEXT: %call.ptr = alloca double
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0
Expand Down Expand Up @@ -189,9 +184,8 @@ entry:
; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %1, double 1.000000e+00, double %call)
; CHECK-NEXT: store double %call, double* %call.ptr2
; CHECK-NEXT: %3 = bitcast double* %call.ptr2 to i8**
; CHECK-NEXT: %4 = load i8*, i8** %3
; CHECK-NEXT: %3 = bitcast double %call to i64
; CHECK-NEXT: %4 = inttoptr i64 %3 to i8*
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %4, i64 8)
; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
; CHECK-NEXT: %5 = load double, double* %arrayidx3
Expand Down
12 changes: 4 additions & 8 deletions enzyme/test/Enzyme/ProbProg/simple-condition.ll
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ entry:

; CHECK: define internal i8* @condition_test(i8* %observations)
; CHECK-NEXT: entry:
; CHECK-NEXT: %x.ptr2 = alloca double
; CHECK-NEXT: %x.ptr = alloca double
; CHECK-NEXT: %mu.ptr1 = alloca double
; CHECK-NEXT: %mu.ptr = alloca double
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
; CHECK-NEXT: %has.choice.mu = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0))
Expand All @@ -72,9 +70,8 @@ entry:
; CHECK: entry.cntd: ; preds = %condition.mu.without.trace, %condition.mu.with.trace
; CHECK-NEXT: %mu = phi double [ %from.trace.mu, %condition.mu.with.trace ], [ %sample.mu, %condition.mu.without.trace ]
; CHECK-NEXT: %likelihood.mu = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %mu)
; CHECK-NEXT: store double %mu, double* %mu.ptr1
; CHECK-NEXT: %1 = bitcast double* %mu.ptr1 to i8**
; CHECK-NEXT: %2 = load i8*, i8** %1
; CHECK-NEXT: %1 = bitcast double %mu to i64
; CHECK-NEXT: %2 = inttoptr i64 %1 to i8*
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), double %likelihood.mu, i8* %2, i64 8)
; CHECK-NEXT: %has.choice.x = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
; CHECK-NEXT: br i1 %has.choice.x, label %condition.x.with.trace, label %condition.x.without.trace
Expand All @@ -92,9 +89,8 @@ entry:
; CHECK: entry.cntd.cntd: ; preds = %condition.x.without.trace, %condition.x.with.trace
; CHECK-NEXT: %x = phi double [ %from.trace.x, %condition.x.with.trace ], [ %sample.x, %condition.x.without.trace ]
; CHECK-NEXT: %likelihood.x = call double @normal_logpdf(double %mu, double 1.000000e+00, double %x)
; CHECK-NEXT: store double %x, double* %x.ptr2
; CHECK-NEXT: %4 = bitcast double* %x.ptr2 to i8**
; CHECK-NEXT: %5 = load i8*, i8** %4
; CHECK-NEXT: %4 = bitcast double %x to i64
; CHECK-NEXT: %5 = inttoptr i64 %4 to i8*
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.x, i8* %5, i64 8)
; CHECK-NEXT: ret i8* %trace
; CHECK-NEXT: }
12 changes: 4 additions & 8 deletions enzyme/test/Enzyme/ProbProg/simple-trace.ll
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,16 @@ entry:

; CHECK: define internal i8* @trace_test()
; CHECK-NEXT: entry:
; CHECK-NEXT: %x.ptr = alloca double
; CHECK-NEXT: %mu.ptr = alloca double
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
; CHECK-NEXT: %mu = call double @normal(double 0.000000e+00, double 1.000000e+00)
; CHECK-NEXT: %likelihood.mu = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %mu)
; CHECK-NEXT: store double %mu, double* %mu.ptr
; CHECK-NEXT: %0 = bitcast double* %mu.ptr to i8**
; CHECK-NEXT: %1 = load i8*, i8** %0
; CHECK-NEXT: %0 = bitcast double %mu to i64
; CHECK-NEXT: %1 = inttoptr i64 %0 to i8*
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), double %likelihood.mu, i8* %1, i64 8)
; CHECK-NEXT: %x = call double @normal(double %mu, double 1.000000e+00)
; CHECK-NEXT: %likelihood.x = call double @normal_logpdf(double %mu, double 1.000000e+00, double %x)
; CHECK-NEXT: store double %x, double* %x.ptr
; CHECK-NEXT: %2 = bitcast double* %x.ptr to i8**
; CHECK-NEXT: %3 = load i8*, i8** %2
; CHECK-NEXT: %2 = bitcast double %x to i64
; CHECK-NEXT: %3 = inttoptr i64 %2 to i8*
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.x, i8* %3, i64 8)
; CHECK-NEXT: ret i8* %trace
; CHECK-NEXT: }
Loading

0 comments on commit 248dbbb

Please sign in to comment.