Skip to content

Commit

Permalink
Call MulAddMul instead of multiplication in _generic_matmatmul! (#5…
Browse files Browse the repository at this point in the history
…6089)

Fix https://github.com/JuliaLang/julia/issues/56085 by calling a newly
created `MulAddMul` object that only wraps the `alpha` (with `beta` set
to `false`). This avoids the explicit multiplication if `alpha` is known
to be `isone`.
  • Loading branch information
jishnub authored Oct 15, 2024
1 parent d749f0e commit 0af99e6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
6 changes: 4 additions & 2 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))

@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
_add::MulAddMul) where {T,S,R}
_add::MulAddMul{ais1}) where {T,S,R,ais1}
AxM = axes(A, 1)
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
BxK = axes(B, 1)
Expand All @@ -935,11 +935,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
if BxN != CxN
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
end
_rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false)
if isbitstype(R) && sizeof(R) 16 && !(A isa Adjoint || A isa Transpose)
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
@inbounds for n in BxN, k in BxK
Balpha = B[k,n]*_add.alpha
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
Balpha = _rmul_alpha(B[k,n])
@simd for m in AxM
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
end
Expand Down
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1130,4 +1130,22 @@ end
@test a * transpose(B) A * transpose(B)
end

@testset "issue #56085" begin
struct Thing
data::Float64
end

Base.zero(::Type{Thing}) = Thing(0.)
Base.zero(::Thing) = Thing(0.)
Base.one(::Type{Thing}) = Thing(1.)
Base.one(::Thing) = Thing(1.)
Base.:+(t::Thing...) = +(getfield.(t, :data)...)
Base.:*(t::Thing...) = *(getfield.(t, :data)...)

M = Float64[1 2; 3 4]
A = Thing.(M)

@test A * A M * M
end

end # module TestMatmul

0 comments on commit 0af99e6

Please sign in to comment.