Skip to content

Commit

Permalink
Merge pull request #430 from JuliaArrays/aos_to_soa_ReverseDiff
Browse files Browse the repository at this point in the history
Specialize Array of Structs to Struct of Array for ReverseDiff
  • Loading branch information
ChrisRackauckas authored Mar 4, 2024
2 parents 7c5d183 + 80e220d commit dfa6732
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
17 changes: 10 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.7.1"
version = "7.8.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -14,6 +14,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

Expand All @@ -22,16 +23,17 @@ ArrayInterfaceBandedMatricesExt = "BandedMatrices"
ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices"
ArrayInterfaceCUDAExt = "CUDA"
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
ArrayInterfaceReverseDiffExt = "ReverseDiff"
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
ArrayInterfaceTrackerExt = "Tracker"

[compat]
Adapt = "3, 4"
LinearAlgebra = "1.9"
Adapt = "4"
LinearAlgebra = "1.10"
Requires = "1"
SparseArrays = "1.9"
SuiteSparse = "1.9"
julia = "1.9"
SparseArrays = "1.10"
SuiteSparse = "1.10"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand All @@ -41,6 +43,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -50,4 +53,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker"]
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker", "ReverseDiff"]
24 changes: 24 additions & 0 deletions ext/ArrayInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module ArrayInterfaceReverseDiffExt

if isdefined(Base, :get_extension)
using ArrayInterface
import ReverseDiff
else
using ..ArrayInterface
import ..ReverseDiff
end

ArrayInterface.ismutable(::Type{<:ReverseDiff.TrackedArray}) = false
ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false
ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false
function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal,N}) where {N}
if length(x) > 1
reduce(vcat,x)
else
@show "here?"
reduce(vcat,[x[1],x[1]])[1:1]
end
end

end # module
20 changes: 20 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using ArrayInterface, ReverseDiff, Tracker, Test
x = ReverseDiff.track([4.0])
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
x = reduce(vcat, ReverseDiff.track([4.0,4.0]))
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
x = [ReverseDiff.track([4.0])[1]]
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
x = reduce(vcat, ReverseDiff.track([4.0,4.0]))
x = [x[1],x[2]]
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray

x = Tracker.TrackedArray([4.0])
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
x = [Tracker.TrackedArray([4.0])[1]]
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
x = Tracker.TrackedArray([4.0,4.0])
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
x = reduce(vcat, Tracker.TrackedArray([4.0,4.0]))
x = [x[1],x[2]]
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ end
@time @safetestset "BandedMatrices" begin include("bandedmatrices.jl") end
@time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end
@time @safetestset "Core" begin include("core.jl") end
@time @safetestset "AD Integration" begin include("ad.jl") end
@time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end
end

if GROUP == "GPU"
activate_gpu_env()
@time @safetestset "CUDA" begin include("gpu/cuda.jl") end
end
end
end

0 comments on commit dfa6732

Please sign in to comment.