sqrt_fastmath does not emit fast flag to LLVM, prevents lowering to rsqrt #33220

Closed
smallnamespace opened this issue Sep 11, 2019 · 6 comments

@smallnamespace
Contributor

LLVM will lower an inverse sqrt to an rsqrt instruction in fast mode, e.g.: https://godbolt.org/z/wMLxq1

However, sqrt_fast does not emit the fast flag, which prevents LLVM from doing this lowering.

On Skylake+ this is less important, but on older architectures using rsqrt can be ~2-3x as fast, e.g. 6 vs. 14 cycles on Ivy Bridge.

@vchuravy
Member

This will require adding a new intrinsic, sqrt_fast_llvm, and doing the appropriate codegen in intrinsics.cpp, if you want to take a stab at it.

smallnamespace added a commit to smallnamespace/julia that referenced this issue Sep 12, 2019
Note: requires LLVM 7+ to generate rsqrt intrinsics
@smallnamespace
Contributor Author

smallnamespace commented Sep 12, 2019

Thanks, added a PR. I verified that the fast flag is generated in the IR:

rfsq(x::Float32) = 1.0f0 / sqrt(x)
@fastmath rfsq_fast(x::Float32) = 1.0f0 / sqrt(x)

code_llvm(rfsq, (Float32,))
code_llvm(rfsq_fast, (Float32,))

outputs (no fast-math):

define float @julia_rfsq_16835(float) {
top:
; ┌ @ math.jl:493 within `sqrt'
; │┌ @ float.jl:457 within `<'
    %1 = fcmp uge float %0, 0.000000e+00
; │└
   br i1 %1, label %L5, label %L3

L3:                                               ; preds = %top
   call void @julia_throw_complex_domainerror_16836(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 4618632984 to %jl_value_t*) to %jl_value_t addrspace(10)*), float %0)
   call void @llvm.trap()
   unreachable

L5:                                               ; preds = %top
; │ @ math.jl:494 within `sqrt'
   %2 = call float @llvm.sqrt.f32(float %0)
; └
; ┌ @ float.jl:406 within `/'
   %3 = fdiv float 1.000000e+00, %2
; └
  ret float %3
}

and (fast-math)

define float @julia_rfsq_fast_16837(float) {
top:
; ┌ @ fastmath.jl:280 within `sqrt_fast'
   %1 = call fast float @llvm.sqrt.f32(float %0)
; └
; ┌ @ fastmath.jl:164 within `div_fast'
   %2 = fdiv fast float 1.000000e+00, %1
; └
  ret float %2
}
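
The same check can be scripted instead of read off the printed IR. This is a minimal sketch, assuming a Julia build that already includes this change; the helper name has_fast_sqrt is made up for illustration, and the regex only matches the textual form of the call shown above.

```julia
using InteractiveUtils  # provides code_llvm outside the REPL

rfsq(x::Float32) = 1.0f0 / sqrt(x)
@fastmath rfsq_fast(x::Float32) = 1.0f0 / sqrt(x)

# Capture the IR as a string instead of printing it.
ir_plain = sprint(code_llvm, rfsq, (Float32,))
ir_fast = sprint(code_llvm, rfsq_fast, (Float32,))

# Hypothetical helper: true if the IR contains a sqrt call carrying the `fast` flag.
has_fast_sqrt(ir) = occursin(r"call fast \w+ @llvm\.sqrt", ir)

println("plain: ", has_fast_sqrt(ir_plain))
println("fast:  ", has_fast_sqrt(ir_fast))
```

Whether the fast call is then lowered to an rsqrt instruction still depends on the LLVM version and the target CPU, so the native code would need a separate look.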

vchuravy added a commit that referenced this issue Nov 14, 2019
@non-Jedi
Contributor

non-Jedi commented Apr 7, 2020

I see the call fast when using sqrt but not when using √ with Julia 1.4. Is this expected? As far as I know, the two forms are supposed to be identical.

 
julia> f(x) = @fastmath 1 / sqrt(x)
f (generic function with 1 method)

julia> g(x) = @fastmath 1 / √x
g (generic function with 1 method)

julia> @code_llvm f(2f0)

;  @ REPL[1]:1 within `f'
define float @julia_f_17320(float) {
top:
; ┌ @ fastmath.jl:280 within `sqrt_fast'
   %1 = call fast float @llvm.sqrt.f32(float %0)
; └
; ┌ @ fastmath.jl:263 within `div_fast' @ fastmath.jl:164
   %2 = fdiv fast float 1.000000e+00, %1
; └
  ret float %2
}

julia> @code_llvm g(2f0)

;  @ REPL[2]:1 within `g'
define float @julia_g_17321(float) {
top:
; ┌ @ math.jl:557 within `sqrt'
; │┌ @ float.jl:457 within `<'
    %1 = fcmp uge float %0, 0.000000e+00
; │└
   br i1 %1, label %L5, label %L3

L3:                                               ; preds = %top
   %2 = call nonnull %jl_value_t addrspace(10)* @julia_throw_complex_domainerror_17322(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140619689496976 to %jl_value_t*) to %jl_value_t addrspace(10)*), float %0)
   call void @llvm.trap()
   unreachable

L5:                                               ; preds = %top
; │ @ math.jl:558 within `sqrt'
   %3 = call float @llvm.sqrt.f32(float %0)
; └
; ┌ @ fastmath.jl:263 within `div_fast' @ fastmath.jl:164
   %4 = fdiv fast float 1.000000e+00, %3
; └
  ret float %4
}

@yuyichao
Contributor

yuyichao commented Apr 7, 2020

The fast math macro rewrites calls based on the function name. It seems that the Unicode symbol isn't handled.
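
Since the rewrite is keyed on the name, two ways to be sure the fast variant is used are to spell the call as sqrt inside @fastmath, or to call the functions in Base.FastMath directly. The following is only a sketch of those two options; the function names inv_sqrt and inv_sqrt_explicit are made up for illustration, and Base.FastMath is not exported API.

```julia
# Option 1: use the ASCII name so @fastmath can match it.
inv_sqrt(x) = @fastmath 1 / sqrt(x)

# Option 2: skip the macro and call the fast variants explicitly.
inv_sqrt_explicit(x) = Base.FastMath.div_fast(one(x), Base.FastMath.sqrt_fast(x))

println(inv_sqrt(4.0f0))
println(inv_sqrt_explicit(4.0f0))
```

The expansion of the first form can be confirmed with @macroexpand @fastmath sqrt(x).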

KristofferC pushed a commit that referenced this issue Apr 11, 2020
Note: requires LLVM 7+ to generate rsqrt intrinsics
@ron-wolf
Contributor

@non-Jedi @yuyichao Is this Unicode issue still the case? I'm not sure how to inspect the LLVM output to verify.

@chriselrod
Contributor

chriselrod commented Feb 11, 2021

You don't need to inspect LLVM, just checking the macro shows that it doesn't transform the unicode symbol:

julia> @macroexpand @fastmath sqrt(x)
:(Base.FastMath.sqrt_fast(x))

julia> @macroexpand @fastmath √(x)
:(√x)

But if you want to:

; julia> @code_llvm Base.FastMath.sqrt_fast(1.2)
;  @ fastmath.jl:284 within `sqrt_fast'
define double @julia_sqrt_fast_2302(double %0) {
top:
  %1 = call fast double @llvm.sqrt.f64(double %0)
  ret double %1
}

Note the call fast double @llvm.sqrt.f64 above (particularly the "fast" part), versus:

; julia> @code_llvm √(1.2)
;  @ math.jl:590 within `sqrt'
define double @julia_sqrt_2371(double %0) {
top:
;  @ math.jl:591 within `sqrt'
; ┌ @ float.jl:395 within `<'
   %1 = fcmp uge double %0, 0.000000e+00
; └
  br i1 %1, label %L5, label %L3

L3:                                               ; preds = %top
  %2 = call nonnull {}* @j_throw_complex_domainerror_2373({}* inttoptr (i64 139974361904792 to {}*), double %0)
  call void @llvm.trap()
  unreachable

L5:                                               ; preds = %top
;  @ math.jl:592 within `sqrt'
  %3 = call double @llvm.sqrt.f64(double %0)
  ret double %3
}

The error checking is still there, and it's just call double @llvm.sqrt.f64.
If you want to skip the domain check but don't want the fast flag:

julia> sqrt_llvm(x) = Base.sqrt_llvm(x)
sqrt_llvm (generic function with 1 method)

julia> @code_llvm sqrt_llvm(1.2)

Base.sqrt_llvm isn't a generic function, so I need to wrap it for @code_llvm to work.

;  @ REPL[15]:1 within `sqrt_llvm'
define double @julia_sqrt_llvm_2366(double %0) {
top:
  %1 = call double @llvm.sqrt.f64(double %0)
  ret double %1
}
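
The practical difference between the checked and unchecked versions shows up on negative input. A small sketch, assuming the wrapper is given the hypothetical name unchecked_sqrt and that llvm.sqrt follows IEEE semantics for Float64 (so a negative argument gives NaN rather than an error):

```julia
# Base.sqrt performs the domain check and throws on negative input.
checked = try
    sqrt(-1.0)
catch err
    err
end

# Wrapper around the intrinsic, which has no domain check.
unchecked_sqrt(x) = Base.sqrt_llvm(x)

println(checked isa DomainError)
println(isnan(unchecked_sqrt(-1.0)))
```

This does not apply to sqrt_fast: with the fast flag LLVM is allowed to assume there are no NaNs, so its result for negative input should not be relied on.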
