any function is slow #109

Closed
Zentrik opened this issue Mar 26, 2023 · 5 comments

@Zentrik
Contributor

Zentrik commented Mar 26, 2023

This is the assembly generated when I used `if any(x > y)` as part of a larger function:

  vcmpltps        ymm15, ymm2, ymm14
  vextractf128    xmm2, ymm15, 1
  vpackssdw       xmm2, xmm15, xmm2
  vpacksswb       xmm2, xmm2, xmm2
  vpand   xmm2, xmm2, xmm1
  vmovq   rsi, xmm2
  test    rsi, rsi
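The sequence above narrows the eight 32-bit comparison lanes down to bytes before doing a single 64-bit test. The following is a scalar model of that data flow, written only for illustration: `sat16`, `sat8` and `any_via_pack` are hypothetical helpers, and the `vpand` step is left out because the constant held in `xmm1` isn't shown.

```julia
# Hypothetical helpers modelling signed saturation, as done lane-wise by
# vpackssdw (Int32 -> Int16) and vpacksswb (Int16 -> Int8).
sat16(x::Int32) = Int16(clamp(x, typemin(Int16), typemax(Int16)))
sat8(x::Int16) = Int8(clamp(x, typemin(Int8), typemax(Int8)))

# Hypothetical model of the pack-then-test `any`: each 32-bit mask lane
# (-1 for true, 0 for false) becomes one byte of a UInt64, which is then
# tested against zero (vmovq + test). The vpand step is omitted.
function any_via_pack(mask::NTuple{8,Int32})
    packed = UInt64(0)
    for (i, lane) in enumerate(mask)
        byte = reinterpret(UInt8, sat8(sat16(lane)))  # -1 -> 0xff, 0 -> 0x00
        packed |= UInt64(byte) << (8 * (i - 1))
    end
    return packed != 0
end
```

Every lane has to pass through two packing steps before the final test, which is the overhead the `vtestps` version avoids.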

If we instead do the following, we get faster assembly:

using SIMD

# `x > y` as an LLVM `fcmp ogt`, sign-extended to i8 so that each true lane
# is 0xff (all bits set) rather than 0x01.
@eval @generated function fcmp_ogt(x::SIMD.LVec{N, T}, y::SIMD.LVec{N, T}, ::F=nothing) where {N, T <: SIMD.FloatingTypes, F<:SIMD.Intrinsics.FPFlags}
    fpflags = SIMD.Intrinsics.fp_str(F)
    fflag = $(QuoteNode(:ogt))
    s = """
    %res = fcmp $(fpflags) $(fflag) <$(N) x $(SIMD.Intrinsics.d[T])> %0, %1
    %resb = sext <$(N) x i1> %res to <$(N) x i8>
    ret <$(N) x i8> %resb
    """
    return :(
        $(Expr(:meta, :inline));
        Base.llvmcall($s, SIMD.LVec{N, Bool}, Tuple{SIMD.LVec{N, T}, SIMD.LVec{N, T}}, x, y)
    )
end

# Reduce an 8-lane mask with AVX `vtestps`: view the lanes as Float32 and test
# their sign bits; vtestz returns 0 when at least one sign bit is set.
function horizontal_or(x::SIMD.Vec)
    b = SIMD.Intrinsics.bitcast(SIMD.LVec{8, Float32}, convert(SIMD.Vec{8, Int32}, x).data)
    return ccall("llvm.x86.avx.vtestz.ps.256", llvmcall, Int32, (SIMD.LVec{8, Float32}, SIMD.LVec{8, Float32}), b, b) == 0
end

function test7(x, y)
    horizontal_or(SIMD.Vec(fcmp_ogt(x.data, y.data)))
end

@code_native debuginfo=:none syntax=:intel test7(Vec{8, Float32}(1), Vec{8, Float32}(0))

The relevant instructions here are:

vcmpltps ymm0, ymm0, ymmword ptr [rcx]
vtestps ymm0, ymm0

Here we still use `vcmpltps`, but `any` is implemented with `vtestps`, which ends up faster than the instruction sequence used currently.
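Why `vtestz(b, b) == 0` computes `any` can be checked with a scalar model. This is only a sketch of the instruction's semantics, not a call into the intrinsic; `vtestz_model`, `mask_lane` and `any_via_vtest` are hypothetical names.

```julia
# Hypothetical scalar model of what llvm.x86.avx.vtestz.ps.256 returns (ZF):
# 1 when no lane of (a AND b) has its sign bit (bit 31) set, otherwise 0.
function vtestz_model(a::NTuple{8,Float32}, b::NTuple{8,Float32})
    for (ai, bi) in zip(a, b)
        bits = reinterpret(UInt32, ai) & reinterpret(UInt32, bi)
        if (bits & 0x80000000) != 0
            return Int32(0)
        end
    end
    return Int32(1)
end

# Hypothetical helper: a vector comparison such as vcmpltps writes all ones
# into true lanes and zero into false lanes; view that pattern as Float32.
mask_lane(t::Bool) = reinterpret(Float32, t ? Int32(-1) : Int32(0))

# Hypothetical `any`, mirroring horizontal_or: test the mask against itself
# and report true when ZF == 0.
function any_via_vtest(m::NTuple{8,Bool})
    b = map(mask_lane, m)
    return vtestz_model(b, b) == 0
end
```

Testing the mask against itself means the AND is a no-op, so ZF simply reports whether every sign bit is clear; one flag check then replaces the whole narrowing sequence.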

@Zentrik
Contributor Author

Zentrik commented Mar 29, 2023

Looking at #95, we can avoid redefining the comparison operators by instead using:

# Sign-extend a Bool mask: truncate each i8 lane to i1, then sext to <N x T>,
# so true lanes become all ones (-1) and false lanes become 0.
@generated function sext(::Type{T}, x::SIMD.Vec{N, Bool}) where {N,T}
    t = SIMD.Intrinsics.llvm_type(T)
    s = """
    %2 = trunc <$N x i8> %0 to <$N x i1>
    %3 = sext  <$N x i1> %2 to <$N x $t>
    ret <$N x $t> %3
    """
    return :( $(Expr(:meta,:inline)); Vec(Base.llvmcall($s, SIMD.LVec{$N,$T}, Tuple{SIMD.LVec{$N,Bool}}, x.data)) )
end

# 8-lane `any` via AVX `vtestps` on the sign bits of the sign-extended mask.
function SIMD.any(x::SIMD.Vec{8, Bool})
    y = SIMD.Intrinsics.bitcast(SIMD.LVec{8, Float32}, sext(Int32, x).data)
    return ccall("llvm.x86.avx.vtestz.ps.256", llvmcall, Int32, (SIMD.LVec{8, Float32}, SIMD.LVec{8, Float32}), y, y) == 0
end

Even when the mask isn't obtained from a comparison, this implementation should still be faster than the current one. Current: https://bit.ly/3nwEnf0, new: https://bit.ly/40F69VH
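A mask that doesn't come from a comparison stores each true lane as 0x01, so it has to be sign-extended before `vtestps` can see it. A minimal scalar sketch of why; `sext_lane`, `zext_lane` and `lane_signbit` are hypothetical helpers, not SIMD.jl functions.

```julia
# Hypothetical scalar model of the `sext` helper for one lane: the Bool is
# truncated to i1 and sign-extended, so true becomes -1 (all 32 bits set).
sext_lane(b::Bool) = b ? Int32(-1) : Int32(0)

# Hypothetical zero extension, for contrast: true becomes 1, bit 31 clear.
zext_lane(b::Bool) = Int32(b)

# Hypothetical helper: vtestps only inspects bit 31 of each lane viewed as Float32.
lane_signbit(x::Int32) = signbit(reinterpret(Float32, x))
```

With zero extension a true lane would leave bit 31 clear and `vtestps` would report "no lanes set", which is why the `sext` step uses `sext` rather than `zext`.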

@KristofferC
Collaborator

KristofferC commented Mar 29, 2023

> return ccall("llvm.x86.avx.vtestz.ps.256", llvmcall, Int32, (SIMD.LVec{8, Float32}, SIMD.LVec{8, Float32}), b, b) == 0

This is not portable.

I think this is mostly fixed in LLVM 14, which ships with Julia 1.9.

using SIMD
v = Vec(Tuple(rand(Float32, 8))...)
f(x, y) = any(x < y) 
@code_native f(v, v)

Julia 1.8:

        vmovups (%rdi), %ymm0
        vcmpltps        (%rsi), %ymm0, %ymm0
        movabsq $.LCPI0_0, %rax
        vextractf128    $1, %ymm0, %xmm1
        vpackssdw       %xmm1, %xmm0, %xmm0
        vpacksswb       %xmm0, %xmm0, %xmm0
        vpand   (%rax), %xmm0, %xmm0
        vpshufd $85, %xmm0, %xmm1               # xmm1 = xmm0[1,1,1,1]
        vpor    %xmm1, %xmm0, %xmm0
        vpsrld  $16, %xmm0, %xmm1
        vpor    %xmm1, %xmm0, %xmm0
        vpsrlw  $8, %xmm0, %xmm1
        vpor    %xmm1, %xmm0, %xmm0
        vmovd   %xmm0, %eax
                                        # kill: def $al killed $al killed $eax
        vzeroupper
        retq

Julia 1.9:

        pushq   %rbp
        movq    %rsp, %rbp
        vmovups (%rdi), %ymm0
        vcmpltps        (%rsi), %ymm0, %ymm0
        vmovmskps       %ymm0, %eax
        testl   %eax, %eax
        setne   %al
        popq    %rbp
        vzeroupper
        retq
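In the Julia 1.9 output, the byte packing is replaced by `vmovmskps`, which gathers the sign bit of each lane into a general-purpose register, followed by `testl`/`setne`. A scalar sketch of that reduction; `movmsk_model`, `any_via_movmsk` and `cmp_lane` are hypothetical names used only for illustration.

```julia
# Hypothetical scalar model of vmovmskps: copy the sign bit of lane i
# into bit i-1 of an integer.
function movmsk_model(v::NTuple{8,Float32})
    m = UInt32(0)
    for (i, x) in enumerate(v)
        m |= UInt32(signbit(x)) << (i - 1)
    end
    return m
end

# Hypothetical model of `testl %eax, %eax; setne %al`.
any_via_movmsk(v::NTuple{8,Float32}) = movmsk_model(v) != 0

# Hypothetical helper: lanes as vcmpltps leaves them, all ones for true, zero for false.
cmp_lane(t::Bool) = reinterpret(Float32, t ? Int32(-1) : Int32(0))
```

Like `vtestps`, this reads only the sign bits, so a single instruction replaces the shuffle/OR reduction in the Julia 1.8 output.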

@Zentrik
Contributor Author

Zentrik commented Mar 29, 2023

The Julia 1.9 version is nearly as fast as mine, so that's good. What do you mean by "not portable"? Is it because it only works for 8-wide booleans?

@KristofferC
Collaborator

It only works on x86 with AVX, etc.

@KristofferC
Collaborator

I think we can close this, since it is fixed by the new LLVM version that ships with Julia 1.9.
