diff --git a/stdlib/LinearAlgebra/src/transpose.jl b/stdlib/LinearAlgebra/src/transpose.jl index 8aa04f7d34b48..a36919b2e557a 100644 --- a/stdlib/LinearAlgebra/src/transpose.jl +++ b/stdlib/LinearAlgebra/src/transpose.jl @@ -74,27 +74,32 @@ julia> A ``` """ adjoint!(B::AbstractMatrix, A::AbstractMatrix) = transpose_f!(adjoint, B, A) + +@noinline function check_transpose_axes(axesA, axesB) + axesB == reverse(axesA) || throw(DimensionMismatch("axes of the destination are incompatible with that of the source")) +end + function transpose!(B::AbstractVector, A::AbstractMatrix) - axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes((axes(B,1), axes(B,2)), axes(A)) copyto!(B, A) end function transpose!(B::AbstractMatrix, A::AbstractVector) - axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes(axes(B), (axes(A,1), axes(A,2))) copyto!(B, A) end function adjoint!(B::AbstractVector, A::AbstractMatrix) - axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes((axes(B,1), axes(B,2)), axes(A)) ccopy!(B, A) end function adjoint!(B::AbstractMatrix, A::AbstractVector) - axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes(axes(B), (axes(A,1), axes(A,2))) ccopy!(B, A) end const transposebaselength=64 function transpose_f!(f, B::AbstractMatrix, A::AbstractMatrix) inds = axes(A) - axes(B,1) == inds[2] && axes(B,2) == inds[1] || throw(DimensionMismatch(string(f))) + check_transpose_axes(axes(B), inds) m, n = length(inds[1]), length(inds[2]) if m*n<=4*transposebaselength diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index 2c533af37f912..1a66c7430723e 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -703,4 +703,14 @@ end @test B == At end +@testset "error message in transpose" begin + v = zeros(2) + A = zeros(1,1) + B = zeros(2,3) + for (t1, t2) in Any[(A, v), (v, A), (A, B)] + @test_throws "axes of the destination are incompatible with that of the source" transpose!(t1, t2) + @test_throws "axes of the destination are incompatible with that of the source" adjoint!(t1, t2) + end +end + end # module TestAdjointTranspose