From 4c16785f43c68cd7894a3a92e87f0fccd5daf2da Mon Sep 17 00:00:00 2001 From: benedict-96 Date: Wed, 11 Dec 2024 10:05:23 +0100 Subject: [PATCH] Implemented method 'create_array' for Base.ReshapedArrays. --- src/code.jl | 4 ++++ test/code.jl | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/src/code.jl b/src/code.jl index 4128a39f..7e64af81 100644 --- a/src/code.jl +++ b/src/code.jl @@ -560,6 +560,10 @@ end create_array(P, S, nd, d, elems...) end +@inline function create_array(::Type{<:Base.ReshapedArray{T, N, P}}, S, nd::Val, d, elems...) where {T, N, P} + create_array(P, S, nd, d, elems...) +end + ## SArray @inline function create_array(::Type{<:SArray}, ::Nothing, nd::Val, ::Val{dims}, elems...) where dims SArray{Tuple{dims...}}(elems...) diff --git a/test/code.jl b/test/code.jl index c0520016..209d82b7 100644 --- a/test/code.jl +++ b/test/code.jl @@ -150,6 +150,9 @@ nanmath_st.rewrites[:nanmath] = true @test eval(toexpr(Let([a ← 1, b ← 2, arr ← [1,2]], MakeArray([a b;a+b a/b], arr)))) == [1 2; 3 1/2] + @test eval(toexpr(Let([a ← 1, b ← 2, arr ← [1,2]], + MakeArray(reshape(view([a,b,a+b,a/b], :), 1, 4), arr)))) == [1 2 3 1/2] + @test eval(toexpr(Let([a ← 1, b ← 2, arr ← @SVector([1,2])], MakeArray([a,b,a+b,a/b], arr)))) === @SVector [1, 2, 3, 1/2] @@ -159,6 +162,9 @@ nanmath_st.rewrites[:nanmath] = true @test eval(toexpr(Let([a ← 1, b ← 2, arr ← @SLVector((:a, :b))(@SVector[1,2])], MakeArray([a+b,a/b], arr)))) === @SLVector((:a, :b))(@SVector [3, 1/2]) + @test eval(toexpr(Let([a ← 1, b ← 2, arr ← reshape(view([1,2], :), 1, 2)], + MakeArray([a,b,a+b,a/b], arr)))) == [1, 2, 3, 1/2] + trackedarr = eval(toexpr(Let([a ← ReverseDiff.track(1.0), b ← 2, arr ← ReverseDiff.track(ones(2))], MakeArray([a+b,a/b], arr)))) @test trackedarr isa ReverseDiff.TrackedArray