Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add atomic fadd for reverse mode #849

Merged
merged 4 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ if (NOT DEFINED LLVM_EXTERNAL_LIT)
message("found llvm match ${CMAKE_MATCH_1} dir ${LLVM_DIR}")
if (EXISTS ${LLVM_DIR}/../../../bin/llvm-lit)
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/../../../bin/llvm-lit)
else()
set(LLVM_EXTERNAL_LIT lit)
endif()
else()
if (EXISTS ${LLVM_DIR}/bin/llvm-lit)
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/bin/llvm-lit)
else()
set(LLVM_EXTERNAL_LIT lit)
endif()
endif()
endif()
Expand Down
106 changes: 80 additions & 26 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -935,12 +935,25 @@ class AdjointGenerator
}

void visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
if (Mode == DerivativeMode::ForwardMode) {
IRBuilder<> BuilderZ(&I);
getForwardBuilder(BuilderZ);
switch (I.getOperation()) {
case AtomicRMWInst::FAdd:
case AtomicRMWInst::FSub: {

if (gutils->isConstantInstruction(&I) && gutils->isConstantValue(&I)) {
if (Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ForwardModeSplit) {
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
} else {
eraseIfUnused(I);
}
return;
}

switch (I.getOperation()) {
case AtomicRMWInst::FAdd:
case AtomicRMWInst::FSub: {

if (Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeSplit) {
IRBuilder<> BuilderZ(&I);
getForwardBuilder(BuilderZ);
auto rule = [&](Value *ptr, Value *dif) -> Value * {
if (!gutils->isConstantInstruction(&I)) {
assert(ptr);
Expand Down Expand Up @@ -981,32 +994,73 @@ class AdjointGenerator
setDiffe(&I, diff, BuilderZ);
return;
}
default:
break;
if (Mode == DerivativeMode::ReverseModePrimal) {
eraseIfUnused(I);
return;
}
}
if (!gutils->isConstantInstruction(&I) || !gutils->isConstantValue(&I)) {
if (looseTypeAnalysis) {
auto &DL = gutils->newFunc->getParent()->getDataLayout();
auto valType = I.getValOperand()->getType();
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
auto fp = TR.firstPointer(storeSize, I.getPointerOperand(),
/*errifnotfound*/ false,
/*pointerIntSame*/ true);
if (!fp.isKnown() && valType->isIntOrIntVectorTy()) {
goto noerror;
if ((Mode == DerivativeMode::ReverseModeCombined ||
Mode == DerivativeMode::ReverseModeGradient) &&
gutils->isConstantValue(&I)) {
if (!gutils->isConstantValue(I.getValOperand())) {
assert(!gutils->isConstantValue(I.getPointerOperand()));
IRBuilder<> Builder2(&I);
getReverseBuilder(Builder2);
Value *ip = gutils->invertPointerM(I.getPointerOperand(), Builder2);
auto order = I.getOrdering();
if (order == AtomicOrdering::Release)
order = AtomicOrdering::Monotonic;
else if (order == AtomicOrdering::AcquireRelease)
order = AtomicOrdering::Acquire;

// Emits the reverse-pass read of a single shadow pointer `ip`: an atomic load
// of I's type that reuses I's volatility and sync scope together with the
// `order` value captured from the enclosing scope. Returns the created load.
auto rule = [&](Value *ip) -> Value * {
#if LLVM_VERSION_MAJOR > 7
  LoadInst *dif1 = Builder2.CreateLoad(I.getType(), ip, I.isVolatile());
#else
  // This overload takes no explicit result type. The pointer to read is the
  // shadow pointer handed to this rule; no other pointer is in scope here.
  LoadInst *dif1 = Builder2.CreateLoad(ip, I.isVolatile());
#endif
  // NOTE(review): confirm AtomicRMWInst::getAlign() is available on every
  // LLVM version this header is built against; it may need a version guard.
  dif1->setAlignment(I.getAlign());
  dif1->setOrdering(order);
  dif1->setSyncScopeID(I.getSyncScopeID());
  return dif1;
};
Value *diff = applyChainRule(I.getType(), Builder2, rule, ip);

addToDiffe(I.getValOperand(), diff, Builder2,
I.getValOperand()->getType()->getScalarType());
}
if (Mode == DerivativeMode::ReverseModeGradient) {
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
} else
eraseIfUnused(I);
return;
}
TR.dump();
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
llvm::errs() << "I: " << I << "\n";
assert(0 && "Active atomic inst not handled");
break;
}
default:
break;
}
noerror:;

if (Mode == DerivativeMode::ReverseModeGradient) {
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
// Loose type analysis escape hatch: when the pointee type cannot be determined
// and the stored value is an integer (or integer vector), leave the atomic
// without a derivative and return before the diagnostic that follows.
if (looseTypeAnalysis) {
  auto &DL = gutils->newFunc->getParent()->getDataLayout();
  auto valType = I.getValOperand()->getType();
  auto storeSize = DL.getTypeSizeInBits(valType) / 8;
  auto fp = TR.firstPointer(storeSize, I.getPointerOperand(),
                            /*errifnotfound*/ false,
                            /*pointerIntSame*/ true);
  if (!fp.isKnown() && valType->isIntOrIntVectorTy()) {
    // Same pair of modes as the constant-instruction early exit at the top of
    // this visitor; the original tested ReverseModeGradient twice.
    if (Mode == DerivativeMode::ReverseModeGradient ||
        Mode == DerivativeMode::ForwardModeSplit) {
      eraseIfUnused(I, /*erase*/ true, /*check*/ false);
    } else {
      eraseIfUnused(I);
    }
    return;
  }
}
TR.dump();
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
llvm::errs() << "I: " << I << "\n";
llvm_unreachable("Active atomic inst not yet handled");
}

void visitStoreInst(llvm::StoreInst &SI) {
Expand Down
107 changes: 82 additions & 25 deletions enzyme/test/Enzyme/ReverseMode/atomicadd.ll
Original file line number Diff line number Diff line change
@@ -1,33 +1,90 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s

; Function Attrs: norecurse nounwind readonly uwtable
define dso_local double @sum(i64* nocapture %n, double %x) #0 {
entry:
%res = atomicrmw add i64* %n, i64 1 monotonic
%fp = uitofp i64 %res to double
%mul = fmul double %fp, %x
ret double %mul
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s

; ModuleID = '<source>'
source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

; foo1: volatile atomic fadd of %v into *%p with monotonic ordering.
; The fetched old value (%a10) is unused. The expected output for diffefoo1
; reads the shadow pointer with a monotonic atomic load.
define dso_local void @foo1(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v monotonic
ret void
}
; foo2: same as foo1 but with acquire ordering on the atomicrmw.
; The expected output for diffefoo2 keeps acquire on the shadow load.
define dso_local void @foo2(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v acquire
ret void
}
; foo3: same as foo1 but with release ordering on the atomicrmw.
; The expected output for diffefoo3 uses monotonic on the shadow load,
; since release is not a valid ordering for a load.
define dso_local void @foo3(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v release
ret void
}
; foo4: same as foo1 but with acq_rel ordering on the atomicrmw.
; The expected output for diffefoo4 uses acquire on the shadow load.
define dso_local void @foo4(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v acq_rel
ret void
}
; foo5: same as foo1 but with seq_cst ordering on the atomicrmw.
; The expected output for diffefoo5 keeps seq_cst on the shadow load.
define dso_local void @foo5(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v seq_cst
ret void
}
; foo6: atomic fadd of the constant 1.0 into *%p; the %v argument is unused.
; The expected output for diffefoo6 contains no shadow load and returns
; zeroinitializer.
define dso_local void @foo6(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
ret void
}

; Function Attrs: nounwind uwtable
define dso_local void @dsum(i64* %x, i64* %xp, double %n) local_unnamed_addr #1 {
entry:
%0 = tail call double (double (i64*, double)*, ...) @__enzyme_autodiff(double (i64*, double)* nonnull @sum, i64* %x, double %n)
; caller: requests a derivative of each of foo1..foo6 through the
; __enzyme_autodiff declaration, passing %a, %b and %v every time.
; Presumably %a is the primal pointer and %b its shadow - confirm against
; the Enzyme calling convention. The results %r1..%r6 are unused.
define void @caller(double* %a, double* %b, double %v) {
%r1 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), double* %a, double* %b, double %v)
%r2 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo2 to i8*), double* %a, double* %b, double %v)
%r3 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo3 to i8*), double* %a, double* %b, double %v)
%r4 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo4 to i8*), double* %a, double* %b, double %v)
%r5 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo5 to i8*), double* %a, double* %b, double %v)
%r6 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), double* %a, double* %b, double %v)
ret void
}

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (i64*, double)*, ...) #2
declare double @_Z17__enzyme_autodiffPviRdS0_(i8*, double*, double*, double)


; CHECK: define internal { double } @diffefoo1(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo2(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acquire
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

attributes #0 = { norecurse nounwind readonly uwtable }
attributes #1 = { nounwind uwtable }
attributes #2 = { nounwind }
; CHECK: define internal { double } @diffefoo3(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v release
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo4(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acq_rel
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo5(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v seq_cst
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" seq_cst, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffesum(i64* nocapture %n, double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %res = atomicrmw add i64* %n, i64 1 monotonic
; CHECK-NEXT: %fp = uitofp i64 %res to double
; CHECK-NEXT: %m1diffex = fmul fast double %differeturn, %fp
; CHECK-NEXT: %0 = insertvalue { double } undef, double %m1diffex, 0
; CHECK-NEXT: ret { double } %0
; CHECK: define internal { double } @diffefoo6(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
; CHECK-NEXT: ret { double } zeroinitializer
; CHECK-NEXT: }
42 changes: 42 additions & 0 deletions enzyme/test/Enzyme/ReverseModeVector/atomicadd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s

; ModuleID = '<source>'
source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

; foo1: volatile atomic fadd of %v into *%p with monotonic ordering.
; The fetched old value (%a10) is unused. The expected output for diffe2foo1
; does one monotonic atomic load per element of the [2 x double*] shadow.
define dso_local void @foo1(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v monotonic
ret void
}
; foo6: atomic fadd of the constant 1.0 into *%p; the %v argument is unused.
; The expected output for diffe2foo6 contains no shadow load and returns
; zeroinitializer.
define dso_local void @foo6(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
ret void
}

; caller: requests width-2 derivatives of foo1 and foo6 by passing the
; "enzyme_width" metadata followed by i64 2. %b is passed twice after %a;
; presumably these are the two shadow pointers for the primal %a - confirm
; against the Enzyme calling convention. The results %r1 and %r6 are unused.
define void @caller(double* %a, double* %b, double %v) {
%r1 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v)
%r6 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v)
ret void
}

declare [2 x double] @_Z17__enzyme_autodiffPviRdS0_(...)

; CHECK: define internal { [2 x double] } @diffe2foo1(double* %p, [2 x double*] %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic
; CHECK-NEXT: %0 = extractvalue [2 x double*] %"p'", 0
; CHECK-NEXT: %1 = load atomic volatile double, double* %0 monotonic, align 8
; CHECK-NEXT: %2 = extractvalue [2 x double*] %"p'", 1
; CHECK-NEXT: %3 = load atomic volatile double, double* %2 monotonic, align 8
; CHECK-NEXT: %.fca.0.insert5 = insertvalue [2 x double] {{(undef|poison)}}, double %1, 0
; CHECK-NEXT: %.fca.1.insert8 = insertvalue [2 x double] %.fca.0.insert5, double %3, 1
; CHECK-NEXT: %4 = insertvalue { [2 x double] } undef, [2 x double] %.fca.1.insert8, 0
; CHECK-NEXT: ret { [2 x double] } %4
; CHECK-NEXT: }

; CHECK: define internal { [2 x double] } @diffe2foo6(double* %p, [2 x double*] %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
; CHECK-NEXT: ret { [2 x double] } zeroinitializer
; CHECK-NEXT: }