Skip to content

Commit

Permalink
Move all platforms to use llvm.minimum/llvm.maximum for fmin/fmax (#56371)

Browse files Browse the repository at this point in the history

This used to not work, but LLVM now has support for this on all platforms
we care about.

Maybe this should be a builtin.

This allows for more vectorization opportunities, since LLVM understands
the code better.

Fixes #48487.

---------

Co-authored-by: Mosè Giordano <[email protected]>
Co-authored-by: oscarddssmith <[email protected]>
  • Loading branch information
3 people authored Jan 14, 2025
1 parent 664528a commit a861a55
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 47 deletions.
4 changes: 4 additions & 0 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ add_tfunc(add_float, 2, 2, math_tfunc, 2)
add_tfunc(sub_float, 2, 2, math_tfunc, 2)
add_tfunc(mul_float, 2, 2, math_tfunc, 8)
add_tfunc(div_float, 2, 2, math_tfunc, 10)
add_tfunc(min_float, 2, 2, math_tfunc, 1)
add_tfunc(max_float, 2, 2, math_tfunc, 1)
add_tfunc(fma_float, 3, 3, math_tfunc, 8)
add_tfunc(muladd_float, 3, 3, math_tfunc, 8)

Expand All @@ -198,6 +200,8 @@ add_tfunc(add_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(sub_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(mul_float_fast, 2, 2, math_tfunc, 8)
add_tfunc(div_float_fast, 2, 2, math_tfunc, 10)
add_tfunc(min_float_fast, 2, 2, math_tfunc, 1)
add_tfunc(max_float_fast, 2, 2, math_tfunc, 1)

# bitwise operators
# -----------------
Expand Down
10 changes: 4 additions & 6 deletions base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ module FastMath
export @fastmath

import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast,
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, min_float_fast, max_float_fast,
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
import Base: afoldl

Expand Down Expand Up @@ -168,6 +168,9 @@ add_fast(x::T, y::T) where {T<:FloatTypes} = add_float_fast(x, y)
sub_fast(x::T, y::T) where {T<:FloatTypes} = sub_float_fast(x, y)
mul_fast(x::T, y::T) where {T<:FloatTypes} = mul_float_fast(x, y)
div_fast(x::T, y::T) where {T<:FloatTypes} = div_float_fast(x, y)
# Fast-math min/max: lowered to the llvm.minimum/llvm.maximum intrinsics
# with the `fast` flag set (see the min_float_fast/max_float_fast cases in
# src/intrinsics.cpp), so NaN/signed-zero ordering may be relaxed by LLVM.
max_fast(x::T, y::T) where {T<:FloatTypes} = max_float_fast(x, y)
min_fast(x::T, y::T) where {T<:FloatTypes} = min_float_fast(x, y)
minmax_fast(x::T, y::T) where {T<:FloatTypes} = (min_fast(x, y), max_fast(x, y))

@fastmath begin
cmp_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(x==y, 0, ifelse(x<y, -1, +1))
Expand Down Expand Up @@ -236,11 +239,6 @@ ComplexTypes = Union{ComplexF32, ComplexF64}

ne_fast(x::T, y::T) where {T<:ComplexTypes} = !(x==y)

# Note: we use the same comparison for min, max, and minmax, so
# that the compiler can convert between them
max_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, y, x)
min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y)
minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x))
end

# fall-back implementations and type promotion
Expand Down
45 changes: 5 additions & 40 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using .Base: sign_mask, exponent_mask, exponent_one,
significand_bits, exponent_bits, exponent_bias,
exponent_max, exponent_raw_max, clamp, clamp!

using Core.Intrinsics: sqrt_llvm
using Core.Intrinsics: sqrt_llvm, min_float, max_float

using .Base: IEEEFloat

Expand Down Expand Up @@ -831,47 +831,12 @@ min(x::T, y::T) where {T<:AbstractFloat} = isnan(x) || ~isnan(y) && _isless(x, y
max(x::T, y::T) where {T<:AbstractFloat} = isnan(x) || ~isnan(y) && _isless(y, x) ? x : y
minmax(x::T, y::T) where {T<:AbstractFloat} = min(x, y), max(x, y)

_isless(x::Float16, y::Float16) = signbit(widen(x) - widen(y))

const has_native_fminmax = Sys.ARCH === :aarch64
@static if has_native_fminmax
@eval begin
Base.@assume_effects :total @inline llvm_min(x::Float64, y::Float64) = ccall("llvm.minimum.f64", llvmcall, Float64, (Float64, Float64), x, y)
Base.@assume_effects :total @inline llvm_min(x::Float32, y::Float32) = ccall("llvm.minimum.f32", llvmcall, Float32, (Float32, Float32), x, y)
Base.@assume_effects :total @inline llvm_max(x::Float64, y::Float64) = ccall("llvm.maximum.f64", llvmcall, Float64, (Float64, Float64), x, y)
Base.@assume_effects :total @inline llvm_max(x::Float32, y::Float32) = ccall("llvm.maximum.f32", llvmcall, Float32, (Float32, Float32), x, y)
end
end

function min(x::T, y::T) where {T<:Union{Float32,Float64}}
@static if has_native_fminmax
return llvm_min(x,y)
end
diff = x - y
argmin = ifelse(signbit(diff), x, y)
anynan = isnan(x)|isnan(y)
return ifelse(anynan, diff, argmin)
# Lower directly to the min_float intrinsic (llvm.minimum): propagates NaN
# from either operand and orders -0.0 below +0.0, on all supported platforms.
min(x::T, y::T) where {T<:IEEEFloat} = min_float(x, y)

function max(x::T, y::T) where {T<:Union{Float32,Float64}}
@static if has_native_fminmax
return llvm_max(x,y)
end
diff = x - y
argmax = ifelse(signbit(diff), y, x)
anynan = isnan(x)|isnan(y)
return ifelse(anynan, diff, argmax)
end

function minmax(x::T, y::T) where {T<:Union{Float32,Float64}}
@static if has_native_fminmax
return llvm_min(x, y), llvm_max(x, y)
end
diff = x - y
sdiff = signbit(diff)
min, max = ifelse(sdiff, x, y), ifelse(sdiff, y, x)
anynan = isnan(x)|isnan(y)
return ifelse(anynan, diff, min), ifelse(anynan, diff, max)
# Lower directly to the max_float intrinsic (llvm.maximum): propagates NaN
# from either operand and orders +0.0 above -0.0, on all supported platforms.
max(x::T, y::T) where {T<:IEEEFloat} = max_float(x, y)

"""
Expand Down
32 changes: 32 additions & 0 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ const auto &float_func() {
float_func[sub_float] = true;
float_func[mul_float] = true;
float_func[div_float] = true;
float_func[min_float] = true;
float_func[max_float] = true;
float_func[add_float_fast] = true;
float_func[sub_float_fast] = true;
float_func[mul_float_fast] = true;
float_func[div_float_fast] = true;
float_func[min_float_fast] = true;
float_func[max_float_fast] = true;
float_func[fma_float] = true;
float_func[muladd_float] = true;
float_func[eq_float] = true;
Expand Down Expand Up @@ -1490,6 +1494,34 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, ArrayRef<Va
case sub_float: return math_builder(ctx)().CreateFSub(x, y);
case mul_float: return math_builder(ctx)().CreateFMul(x, y);
case div_float: return math_builder(ctx)().CreateFDiv(x, y);
case min_float: {
assert(x->getType() == y->getType());
FunctionCallee minintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::minimum, ArrayRef<Type*>(t));
return ctx.builder.CreateCall(minintr, {x, y});
}
case max_float: {
assert(x->getType() == y->getType());
FunctionCallee maxintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::maximum, ArrayRef<Type*>(t));
return ctx.builder.CreateCall(maxintr, {x, y});
}
case min_float_fast: {
assert(x->getType() == y->getType());
FunctionCallee minintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::minimum, ArrayRef<Type*>(t));
auto call = ctx.builder.CreateCall(minintr, {x, y});
auto fmf = call->getFastMathFlags();
fmf.setFast();
call->copyFastMathFlags(fmf);
return call;
}
case max_float_fast: {
assert(x->getType() == y->getType());
FunctionCallee maxintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::maximum, ArrayRef<Type*>(t));
auto call = ctx.builder.CreateCall(maxintr, {x, y});
auto fmf = call->getFastMathFlags();
fmf.setFast();
call->copyFastMathFlags(fmf);
return call;
}
case add_float_fast: return math_builder(ctx, true)().CreateFAdd(x, y);
case sub_float_fast: return math_builder(ctx, true)().CreateFSub(x, y);
case mul_float_fast: return math_builder(ctx, true)().CreateFMul(x, y);
Expand Down
4 changes: 4 additions & 0 deletions src/intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
ADD_I(sub_float, 2) \
ADD_I(mul_float, 2) \
ADD_I(div_float, 2) \
ADD_I(min_float, 2) \
ADD_I(max_float, 2) \
ADD_I(fma_float, 3) \
ADD_I(muladd_float, 3) \
/* fast arithmetic */ \
Expand All @@ -25,6 +27,8 @@
ALIAS(sub_float_fast, sub_float) \
ALIAS(mul_float_fast, mul_float) \
ALIAS(div_float_fast, div_float) \
ALIAS(min_float_fast, min_float) \
ALIAS(max_float_fast, max_float) \
/* same-type comparisons */ \
ADD_I(eq_int, 2) \
ADD_I(ne_int, 2) \
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,8 @@ JL_DLLEXPORT jl_value_t *jl_add_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_sub_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_mul_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_div_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_min_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_max_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_fma_float(jl_value_t *a, jl_value_t *b, jl_value_t *c);
JL_DLLEXPORT jl_value_t *jl_muladd_float(jl_value_t *a, jl_value_t *b, jl_value_t *c);

Expand Down
39 changes: 38 additions & 1 deletion src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -1398,13 +1398,50 @@ bi_iintrinsic_fast(LLVMURem, rem, urem_int, u)
bi_iintrinsic_fast(jl_LLVMSMod, smod, smod_int, )
#define frem(a, b) \
fp_select2(a, b, fmod)

un_fintrinsic(neg_float,neg_float)
bi_fintrinsic(add,add_float)
bi_fintrinsic(sub,sub_float)
bi_fintrinsic(mul,mul_float)
bi_fintrinsic(div,div_float)

// NaN-propagating minimum for float, matching llvm.minimum semantics:
// if either operand is NaN the result is NaN (the difference x - y is NaN
// then), and -0.0 is ordered below +0.0 (via the sign bit of the difference).
float min_float(float x, float y) JL_NOTSAFEPOINT
{
    float d = x - y;
    if (isnan(x) || isnan(y))
        return d;
    return signbit(d) ? x : y;
}

// NaN-propagating minimum for double, matching llvm.minimum semantics:
// if either operand is NaN the result is NaN (the difference x - y is NaN
// then), and -0.0 is ordered below +0.0 (via the sign bit of the difference).
double min_double(double x, double y) JL_NOTSAFEPOINT
{
    double d = x - y;
    if (isnan(x) || isnan(y))
        return d;
    return signbit(d) ? x : y;
}

// Dispatch to the float or double runtime implementation by operand size.
// Parenthesize the expansion: an unparenthesized conditional expression can
// rebind if the macro is expanded inside a larger expression.
#define _min(a, b) (sizeof(a) == sizeof(float) ? min_float(a, b) : min_double(a, b))
bi_fintrinsic(_min, min_float)

// NaN-propagating maximum for float, matching llvm.maximum semantics:
// if either operand is NaN the result is NaN (the difference x - y is NaN
// then), and +0.0 is ordered above -0.0 (via the sign bit of the difference).
float max_float(float x, float y) JL_NOTSAFEPOINT
{
    float d = x - y;
    if (isnan(x) || isnan(y))
        return d;
    return signbit(d) ? y : x;
}

// NaN-propagating maximum for double, matching llvm.maximum semantics:
// NaN if either operand is NaN (the difference x - y is NaN then), and
// +0.0 ordered above -0.0 (via the sign bit of the difference).
double max_double(double x, double y) JL_NOTSAFEPOINT
{
    double diff = x - y;
    // BUG FIX: when diff is negative (x < y) the maximum is y, mirroring
    // max_float above; the previous `? x : y` selected the *minimum*.
    double argmax = signbit(diff) ? y : x;
    int is_nan = isnan(x) || isnan(y);
    return is_nan ? diff : argmax;
}

// Dispatch to the float or double runtime implementation by operand size.
// Parenthesize the expansion: an unparenthesized conditional expression can
// rebind if the macro is expanded inside a larger expression.
#define _max(a, b) (sizeof(a) == sizeof(float) ? max_float(a, b) : max_double(a, b))
bi_fintrinsic(_max, max_float)

// ternary operators //
// runtime fma is broken on windows, define julia_fma(f) ourself with fma_emulated as reference.
#if defined(_OS_WINDOWS_)
Expand Down

0 comments on commit a861a55

Please sign in to comment.