Skip to content

Commit

Permalink
codegen: explicitly handle Float16 intrinsics (#45249)
Browse files Browse the repository at this point in the history
Fixes #44829, as a workaround until LLVM itself fixes its support for these intrinsics.

We also need to handle vectors, since the vectorizer may have introduced them.

Also change our runtime emulation versions to f32 for consistency.
  • Loading branch information
vtjnash authored May 18, 2022
1 parent 2d40898 commit f2c627e
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 92 deletions.
6 changes: 3 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
6 changes: 0 additions & 6 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@
environ;
__progname;

/* compiler run-time intrinsics */
__gnu_h2f_ieee;
__extendhfsf2;
__gnu_f2h_ieee;
__truncdfhf2;

local:
*;
};
14 changes: 12 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1523,8 +1523,18 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
Expand Down
296 changes: 242 additions & 54 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,194 @@ INST_STATISTIC(FCmp);

namespace {

// Return the function-level attributes stored in `Attrs`.
// The accessor is chosen at compile time from JL_LLVM_VERSION: builds with
// JL_LLVM_VERSION >= 140000 call getFnAttrs(), older ones call getFnAttributes().
inline AttributeSet getFnAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION < 140000
    return Attrs.getFnAttributes();
#else
    return Attrs.getFnAttrs();
#endif
}

// Return the return-value attributes stored in `Attrs`.
// The accessor is chosen at compile time from JL_LLVM_VERSION: builds with
// JL_LLVM_VERSION >= 140000 call getRetAttrs(), older ones call getRetAttributes().
inline AttributeSet getRetAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION < 140000
    return Attrs.getRetAttributes();
#else
    return Attrs.getRetAttrs();
#endif
}

// Create a call to the same intrinsic as `call`, re-resolved for the return type
// `RetTy` and for the types of the values in `args`.
//
// `args` must contain more entries than the intrinsic has parameters, and its final
// entry must be the function called by `call`; that final entry is dropped when the
// new call is built. The new call is inserted immediately before `call`. It copies
// the tail-call kind and the function/return attributes of `call`, but not its
// parameter attributes. `call` itself is neither modified nor erased here.
static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
{
    const Intrinsic::ID intrinsicID = call->getIntrinsicID();
    assert(intrinsicID);
    FunctionType *origFnTy = call->getFunctionType();
    const unsigned numParams = origFnTy->getNumParams();
    assert(args.size() > numParams);

    // The new signature takes its parameter types from the leading `numParams` values.
    SmallVector<Type*, 8> paramTys;
    paramTys.reserve(numParams);
    for (Value *arg : args.take_front(numParams))
        paramTys.push_back(arg->getType());
    FunctionType *wantedFnTy = FunctionType::get(RetTy, paramTys, origFnTy->isVarArg());

    // Accumulate an array of overloaded types for the given intrinsic
    // and compute the new name mangling schema
    SmallVector<Type*, 4> overloadTys;
    SmallVector<Intrinsic::IITDescriptor, 8> descriptors;
    getIntrinsicInfoTableEntries(intrinsicID, descriptors);
    ArrayRef<Intrinsic::IITDescriptor> remaining = descriptors;
    auto sigMatch = Intrinsic::matchIntrinsicSignature(wantedFnTy, remaining, overloadTys);
    assert(sigMatch == Intrinsic::MatchIntrinsicTypes_Match);
    (void)sigMatch;
    // matchIntrinsicVarArg's result is negated so that `true` means "matched".
    bool varArgOk = !Intrinsic::matchIntrinsicVarArg(wantedFnTy->isVarArg(), remaining);
    assert(varArgOk);
    (void)varArgOk;

    Function *replacementFn = Intrinsic::getDeclaration(call->getModule(), intrinsicID, overloadTys);
    assert(replacementFn->getFunctionType() == wantedFnTy);
    replacementFn->setCallingConv(call->getCallingConv());

    // The trailing element of `args` is the callee, not an argument.
    assert(args.back() == call->getCalledFunction());
    CallInst *replacement = CallInst::Create(replacementFn, args.drop_back(), "", call);
    replacement->setTailCallKind(call->getTailCallKind());

    // Keep function and return attributes; drop parameter attributes
    const auto origAttrs = call->getAttributes();
    replacement->setAttributes(AttributeList::get(call->getContext(),
                                                  getFnAttrs(origAttrs),
                                                  getRetAttrs(origAttrs),
                                                  {}));
    return replacement;
}


// Emit IR that performs the cast `opcode` from `V` to `DestTy`, where one side of the
// cast is Float16, by calling conversion functions named "julia__gnu_h2f_ieee",
// "julia__gnu_f2h_ieee" or "julia__truncdfhf2" instead of emitting the cast directly.
//
//  - Constants are folded with ConstantExpr::getCast when that yields a result.
//  - Vector casts are expanded lane by lane (extract, scalar cast, insert).
//  - Half values cross the call boundary as i16, via bitcasts on either side.
//  - Integer <-> half casts go through Float32: the integer side of the conversion
//    is done with a regular LLVM cast to/from float around the call.
//
// Instructions are inserted at `builder`'s current insertion point. Unsupported
// opcodes or type combinations print a description to errs() and hit llvm_unreachable.
static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
{
    Type *CallArgTy = V->getType();
    Type *CallRetTy = DestTy;
    if (auto *asConst = dyn_cast<Constant>(V)) {
        // The input IR often has things of the form
        //   fcmp olt half %0, 0xH7C00
        // and we would like to avoid turning that constant into a call here
        // if we can simply constant fold it to the new type.
        if (Constant *folded = ConstantExpr::getCast(opcode, asConst, DestTy, true))
            return folded;
    }

    const bool srcIsVector = CallArgTy->isVectorTy();
    assert(srcIsVector == DestTy->isVectorTy());
    if (srcIsVector) {
        // Scalarize: convert each lane with a recursive scalar cast.
        const unsigned laneCount = cast<FixedVectorType>(CallArgTy)->getNumElements();
        assert(cast<FixedVectorType>(DestTy)->getNumElements() == laneCount && "Mismatched cast");
        Type *laneDestTy = CallRetTy->getScalarType();
        Value *assembled = UndefValue::get(DestTy);
        for (unsigned lane = 0; lane != laneCount; ++lane) {
            Value *laneIdx = builder.getInt32(lane);
            Value *laneVal = builder.CreateExtractElement(V, laneIdx);
            Value *laneCast = CreateFPCast(opcode, laneVal, laneDestTy, builder);
            assembled = builder.CreateInsertElement(assembled, laneCast, laneIdx);
        }
        return assembled;
    }

    Module &M = *builder.GetInsertBlock()->getModule();
    LLVMContext &ctx = M.getContext();

    // Shared prefix of the diagnostics printed before aborting on a bad cast.
    auto describeCast = [&]() {
        errs() << Instruction::getOpcodeName(opcode) << ' ';
        V->getType()->print(errs());
        errs() << " to ";
        DestTy->print(errs());
        errs() << " is an ";
    };

    // Pick the Function to call in the Julia runtime
    StringRef Name;
    switch (opcode) {
    // FPExt: this is exact, so we only need one conversion
    // FPToSI/FPToUI: All F16 fit exactly in Int32 (-65504 to 65504)
    case Instruction::FPExt:
    case Instruction::FPToSI:
    case Instruction::FPToUI:
        assert(CallArgTy->isHalfTy());
        Name = "julia__gnu_h2f_ieee";
        CallRetTy = Type::getFloatTy(ctx);
        break;
    case Instruction::FPTrunc:
        assert(DestTy->isHalfTy());
        // Any other source type leaves Name empty and is rejected below.
        if (CallArgTy->isFloatTy())
            Name = "julia__gnu_f2h_ieee";
        else if (CallArgTy->isDoubleTy())
            Name = "julia__truncdfhf2";
        break;
    case Instruction::SIToFP:
    case Instruction::UIToFP:
        assert(DestTy->isHalfTy());
        Name = "julia__gnu_f2h_ieee";
        // The integer is first converted to float, which is what the callee takes.
        CallArgTy = Type::getFloatTy(ctx);
        break;
    default:
        describeCast();
        llvm_unreachable("invalid cast");
    }
    if (Name.empty()) {
        describeCast();
        llvm_unreachable("illegal cast");
    }

    // Coerce the source to the required size and type
    Type *T_int16 = Type::getInt16Ty(ctx);
    if (CallArgTy->isHalfTy())
        CallArgTy = T_int16;
    switch (opcode) {
    case Instruction::SIToFP:
        V = builder.CreateSIToFP(V, CallArgTy);
        break;
    case Instruction::UIToFP:
        V = builder.CreateUIToFP(V, CallArgTy);
        break;
    default:
        V = builder.CreateBitCast(V, CallArgTy);
        break;
    }

    // Call our intrinsic
    if (CallRetTy->isHalfTy())
        CallRetTy = T_int16;
    FunctionType *FT = FunctionType::get(CallRetTy, {CallArgTy}, false);
    FunctionCallee F = M.getOrInsertFunction(Name, FT);
    Value *converted = builder.CreateCall(F, {V});

    // Coerce the result to the expected type
    switch (opcode) {
    case Instruction::FPToSI:
        return builder.CreateFPToSI(converted, DestTy);
    case Instruction::FPToUI:
        return builder.CreateFPToUI(converted, DestTy);
    case Instruction::FPExt:
        return builder.CreateFPCast(converted, DestTy);
    default:
        return builder.CreateBitCast(converted, DestTy);
    }
}

static bool demoteFloat16(Function &F)
{
auto &ctx = F.getContext();
auto T_float16 = Type::getHalfTy(ctx);
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
}
if (!Float16)
continue;

if (auto CI = dyn_cast<CastInst>(&I)) {
if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
++TotalChanged;
IRBuilder<> builder(&I);
Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder);
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
continue;
}

switch (I.getOpcode()) {
case Instruction::FNeg:
case Instruction::FAdd:
Expand All @@ -64,6 +243,9 @@ static bool demoteFloat16(Function &F)
case Instruction::FCmp:
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
if (intrinsic->getIntrinsicID())
break;
continue;
}

Expand All @@ -75,72 +257,78 @@ static bool demoteFloat16(Function &F)
IRBuilder<> builder(&I);

// extend Float16 operands to Float32
bool OperandsChanged = false;
// XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType() == T_float16) {
if (Op->getType()->getScalarType()->isHalfTy()) {
++TotalExt;
Op = builder.CreateFPExt(Op, T_float32);
OperandsChanged = true;
Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder);
}
Operands[i] = (Op);
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
if (OperandsChanged) {
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
// XXX: this is not correct in general
// some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
Type *RetTy = I.getType();
if (RetTy->getScalarType()->isHalfTy())
RetTy = RetTy->getWithNewType(T_float32);
NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands);
break;
default:
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = builder.CreateFPTrunc(NewI, I.getType());
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder);
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
}

Expand Down
Loading

0 comments on commit f2c627e

Please sign in to comment.