diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 80f0dc95481a..af3e6c7c9157 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -58,10 +58,14 @@ if (NOT DEFINED LLVM_EXTERNAL_LIT) message("found llvm match ${CMAKE_MATCH_1} dir ${LLVM_DIR}") if (EXISTS ${LLVM_DIR}/../../../bin/llvm-lit) set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/../../../bin/llvm-lit) + else() + set(LLVM_EXTERNAL_LIT lit) endif() else() if (EXISTS ${LLVM_DIR}/bin/llvm-lit) set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/bin/llvm-lit) + else() + set(LLVM_EXTERNAL_LIT lit) endif() endif() endif() diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 7660e012b84f..5cd84f0c250c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -935,12 +935,25 @@ class AdjointGenerator } void visitAtomicRMWInst(llvm::AtomicRMWInst &I) { - if (Mode == DerivativeMode::ForwardMode) { - IRBuilder<> BuilderZ(&I); - getForwardBuilder(BuilderZ); - switch (I.getOperation()) { - case AtomicRMWInst::FAdd: - case AtomicRMWInst::FSub: { + + if (gutils->isConstantInstruction(&I) && gutils->isConstantValue(&I)) { + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ForwardModeSplit) { + eraseIfUnused(I, /*erase*/ true, /*check*/ false); + } else { + eraseIfUnused(I); + } + return; + } + + switch (I.getOperation()) { + case AtomicRMWInst::FAdd: + case AtomicRMWInst::FSub: { + + if (Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeSplit) { + IRBuilder<> BuilderZ(&I); + getForwardBuilder(BuilderZ); auto rule = [&](Value *ptr, Value *dif) -> Value * { if (!gutils->isConstantInstruction(&I)) { assert(ptr); @@ -981,32 +994,84 @@ class AdjointGenerator setDiffe(&I, diff, BuilderZ); return; } - default: - break; + if (Mode == DerivativeMode::ReverseModePrimal) { + eraseIfUnused(I); + return; } - } - if (!gutils->isConstantInstruction(&I) || !gutils->isConstantValue(&I)) { - if (looseTypeAnalysis) { - auto &DL = 
gutils->newFunc->getParent()->getDataLayout(); - auto valType = I.getValOperand()->getType(); - auto storeSize = DL.getTypeSizeInBits(valType) / 8; - auto fp = TR.firstPointer(storeSize, I.getPointerOperand(), - /*errifnotfound*/ false, - /*pointerIntSame*/ true); - if (!fp.isKnown() && valType->isIntOrIntVectorTy()) { - goto noerror; + if ((Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ReverseModeGradient) && + gutils->isConstantValue(&I)) { + if (!gutils->isConstantValue(I.getValOperand())) { + assert(!gutils->isConstantValue(I.getPointerOperand())); + IRBuilder<> Builder2(&I); + getReverseBuilder(Builder2); + Value *ip = gutils->invertPointerM(I.getPointerOperand(), Builder2); + auto order = I.getOrdering(); + if (order == AtomicOrdering::Release) + order = AtomicOrdering::Monotonic; + else if (order == AtomicOrdering::AcquireRelease) + order = AtomicOrdering::Acquire; + + auto rule = [&](Value *ip) -> Value * { +#if LLVM_VERSION_MAJOR > 7 + LoadInst *dif1 = + Builder2.CreateLoad(I.getType(), ip, I.isVolatile()); +#else + LoadInst *dif1 = Builder2.CreateLoad(ip, I.isVolatile()); +#endif + +#if LLVM_VERSION_MAJOR >= 11 + dif1->setAlignment(I.getAlign()); +#else + const DataLayout &DL = I.getModule()->getDataLayout(); + auto tmpAlign = DL.getTypeStoreSize(I.getValOperand()->getType()); +#if LLVM_VERSION_MAJOR >= 10 + dif1->setAlignment(MaybeAlign(tmpAlign.getFixedSize())); +#else + dif1->setAlignment(tmpAlign); +#endif +#endif + dif1->setOrdering(order); + dif1->setSyncScopeID(I.getSyncScopeID()); + return dif1; + }; + Value *diff = applyChainRule(I.getType(), Builder2, rule, ip); + + addToDiffe(I.getValOperand(), diff, Builder2, + I.getValOperand()->getType()->getScalarType()); } + if (Mode == DerivativeMode::ReverseModeGradient) { + eraseIfUnused(I, /*erase*/ true, /*check*/ false); + } else + eraseIfUnused(I); + return; } - TR.dump(); - llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n"; - llvm::errs() << "I: " << I << "\n"; - 
assert(0 && "Active atomic inst not handled"); + break; + } + default: + break; } - noerror:; - if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(I, /*erase*/ true, /*check*/ false); + if (looseTypeAnalysis) { + auto &DL = gutils->newFunc->getParent()->getDataLayout(); + auto valType = I.getValOperand()->getType(); + auto storeSize = DL.getTypeSizeInBits(valType) / 8; + auto fp = TR.firstPointer(storeSize, I.getPointerOperand(), + /*errifnotfound*/ false, + /*pointerIntSame*/ true); + if (!fp.isKnown() && valType->isIntOrIntVectorTy()) { + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ForwardModeSplit) { + eraseIfUnused(I, /*erase*/ true, /*check*/ false); + } else + eraseIfUnused(I); + return; + } } + TR.dump(); + llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n"; + llvm::errs() << "I: " << I << "\n"; + llvm_unreachable("Active atomic inst not yet handled"); } void visitStoreInst(llvm::StoreInst &SI) { diff --git a/enzyme/test/Enzyme/ReverseMode/atomicfadd.ll b/enzyme/test/Enzyme/ReverseMode/atomicfadd.ll new file mode 100644 index 000000000000..d254608f6130 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/atomicfadd.ll @@ -0,0 +1,90 @@ +; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s; fi + +; ModuleID = '' +source_filename = "" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +define dso_local void @foo1(double* %p, double %v) { + %a10 = atomicrmw volatile fadd double* %p, double %v monotonic + ret void +} +define dso_local void @foo2(double* %p, double %v) { + %a10 = atomicrmw volatile fadd double* %p, double %v acquire + ret void +} +define dso_local void @foo3(double* %p, double %v) { + %a10 = atomicrmw volatile fadd double* %p, double %v release + ret void +} +define dso_local void @foo4(double* %p, double %v) { +
%a10 = atomicrmw volatile fadd double* %p, double %v acq_rel + ret void +} +define dso_local void @foo5(double* %p, double %v) { + %a10 = atomicrmw volatile fadd double* %p, double %v seq_cst + ret void +} +define dso_local void @foo6(double* %p, double %v) { + %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst + ret void +} + +define void @caller(double* %a, double* %b, double %v) { + %r1 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), double* %a, double* %b, double %v) + %r2 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo2 to i8*), double* %a, double* %b, double %v) + %r3 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo3 to i8*), double* %a, double* %b, double %v) + %r4 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo4 to i8*), double* %a, double* %b, double %v) + %r5 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo5 to i8*), double* %a, double* %b, double %v) + %r6 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), double* %a, double* %b, double %v) + ret void +} + +declare double @_Z17__enzyme_autodiffPviRdS0_(i8*, double*, double*, double) + + +; CHECK: define internal { double } @diffefoo1(double* %p, double* %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic +; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8 +; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffefoo2(double* %p, double* %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acquire +; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8 +; 
CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffefoo3(double* %p, double* %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v release +; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8 +; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffefoo4(double* %p, double* %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acq_rel +; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8 +; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffefoo5(double* %p, double* %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v seq_cst +; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" seq_cst, align 8 +; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffefoo6(double* %p, double* %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst +; CHECK-NEXT: ret { double } zeroinitializer +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseModeVector/atomicadd.ll b/enzyme/test/Enzyme/ReverseModeVector/atomicadd.ll new file mode 100644 index 000000000000..122dbe6770e5 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseModeVector/atomicadd.ll @@ -0,0 +1,42 @@ +; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s; fi + +; ModuleID = '' 
+source_filename = "" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +define dso_local void @foo1(double* %p, double %v) { + %a10 = atomicrmw volatile fadd double* %p, double %v monotonic + ret void +} +define dso_local void @foo6(double* %p, double %v) { + %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst + ret void +} + +define void @caller(double* %a, double* %b, double %v) { + %r1 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v) + %r6 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v) + ret void +} + +declare [2 x double] @_Z17__enzyme_autodiffPviRdS0_(...) + +; CHECK: define internal { [2 x double] } @diffe2foo1(double* %p, [2 x double*] %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic +; CHECK-NEXT: %0 = extractvalue [2 x double*] %"p'", 0 +; CHECK-NEXT: %1 = load atomic volatile double, double* %0 monotonic, align 8 +; CHECK-NEXT: %2 = extractvalue [2 x double*] %"p'", 1 +; CHECK-NEXT: %3 = load atomic volatile double, double* %2 monotonic, align 8 +; CHECK-NEXT: %.fca.0.insert5 = insertvalue [2 x double] {{(undef|poison)}}, double %1, 0 +; CHECK-NEXT: %.fca.1.insert8 = insertvalue [2 x double] %.fca.0.insert5, double %3, 1 +; CHECK-NEXT: %4 = insertvalue { [2 x double] } undef, [2 x double] %.fca.1.insert8, 0 +; CHECK-NEXT: ret { [2 x double] } %4 +; CHECK-NEXT: } + +; CHECK: define internal { [2 x double] } @diffe2foo6(double* %p, [2 x double*] %"p'", double %v) +; CHECK-NEXT: invert: +; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst +; CHECK-NEXT: ret { [2 x double] 
} zeroinitializer +; CHECK-NEXT: } \ No newline at end of file