Attempt to fix #62 (#63)

Draft · wants to merge 2 commits into master
9 changes: 6 additions & 3 deletions src/destructure.jl
@@ -91,10 +91,13 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) =

function _trainable_biwalk(f, x, aux)
ch, re = functor(typeof(x), x)
- au, _ = functor(typeof(x), aux)
+ au = _aux_children(aux)
_trainmap(f, ch, _trainable(x), au) |> re
end

_aux_children(off) = functor(off)[1]
_aux_children(off::AbstractArray) = off # leaflike according to Functors, but we need to see each offset

function _trainmap(f, ch, tr, aux)
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
isnothing(t) ? c : f(t, a)
@@ -103,7 +106,7 @@ end

function _Tangent_biwalk(f, x, aux) # use with prune = NoT
ch, re = functor(typeof(x), x)
- au, _ = functor(typeof(x), aux)
+ au = _aux_children(aux)
y = _trainmap(f, ch, _trainable(x), au)
y isa Tuple{} && return NoT
p = ProjectTo(x)
@@ -126,7 +129,7 @@ ChainRulesCore.@non_differentiable _zero(x)
function _grad!(x, dx, off, flat::AbstractVector)
x′, _ = functor(typeof(x), x)
dx′, _ = functor(typeof(x), base(dx))
- off′, _ = functor(typeof(x), off)
+ off′ = _aux_children(off)
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
flat
end
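For reviewers, a minimal sketch of what _aux_children works around, using the Skip type this PR adds to the test suite below (the offset values shown are illustrative, not taken from the PR):

    using Functors

    struct Skip{T}   # as in test/runtests.jl below: children are the layers Tuple, not the fields
        layers::T
        Skip(ls...) = new{typeof(ls)}(ls)
    end
    Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)

    sk  = Skip([1.0, 2.0], (x = 3, y = [4.0, 5.0]))
    aux = (0, (x = (), y = 2))   # destructure's offsets mirror sk's *children*, hence a Tuple, not a Skip

    # The old code re-destructured the offsets with the primal's type, which calls
    # `aux.layers` on a Tuple and throws:
    # functor(typeof(sk), aux)   # ERROR: type Tuple has no field layers

    # The generic functor walks the Tuple itself, which is what _aux_children(aux) now returns:
    functor(aux)[1]              # (0, (x = (), y = 2))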
43 changes: 42 additions & 1 deletion test/destructure.jl
@@ -79,7 +79,7 @@ end
g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
@test g8[1].x == [2,4,6]
@test g8[2].b.x == [8]
- @test g8[3] == [[10.0]]
+ @test g8[3] == [[10.0]] # fails

g9 = gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
@test g9.c === nothing
@@ -180,3 +180,44 @@ end
4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one
end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],)
end

@testset "issue 62" begin
# Flux.Chain used to have children which weren't its own fields, which Skip imitates.

sk = Skip([1.0, 2.0], (x=3, y=[4.0, 5.0]))
@test fmap(identity, sk) == sk

gk = gradient(x -> sum(x[2].y), sk)[1]
@test fmap(Zygote.accum, sk, gk) isa Skip # this relies on functor(typeof(x), dx)

st = fmapstructure(identity, sk)
@test st isa Tuple{Vector, NamedTuple}
@test_throws Exception fmap(+, sk, st) # this fails because of functor(typeof(x), dx)

v, re = destructure(sk)
@test v == [1,2,4,5]
@test re(10v) isa Skip
@test re(10v)[1] == [10, 20]

@test gradient(zero(v)) do w
re(w)[2].y[1]
end == ([0,0,1,0],)

gradient(sk) do x
w, _ = destructure(x)
w[1]
end
#=

ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}.
Stacktrace:
[1] _backing_error(P::Type, G::Type, E::Type)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:62
[2] ChainRulesCore.Tangent{Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}}(backing::Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:36
[3] _Tangent_biwalk(f::Function, x::Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, aux::Tuple{Int64, NamedTuple{(:x, :y), Tuple{Tuple{}, Int64}}})
@ Optimisers ~/.julia/dev/Optimisers/src/destructure.jl:116

=#

end
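The commented-out failure above is separate from the offsets fix: ChainRulesCore requires a Tangent for a struct to be backed by a NamedTuple over the struct's fields, while _Tangent_biwalk rebuilds the tangent from Skip's functor children, which form a plain Tuple. A minimal sketch of that constraint (the values are illustrative):

    using ChainRulesCore

    struct Skip{T}; layers::T; Skip(ls...) = new{typeof(ls)}(ls); end   # as in test/runtests.jl below
    P = typeof(Skip([1.0, 2.0]))

    # NamedTuple backing keyed on the struct's field is accepted:
    Tangent{P}(layers = ([2.0, 4.0],))

    # Tuple backing -- the shape of Skip's functor children -- is rejected:
    # Tangent{P}([2.0, 4.0])
    # ERROR: ArgumentError: Tangent for the primal Skip{...} should be backed by a NamedTuple type ...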
17 changes: 17 additions & 0 deletions test/runtests.jl
@@ -13,6 +13,13 @@ struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

struct Skip{T} # like Flux 0.12's Chain
layers::T
Skip(ls...) = new{typeof(ls)}(ls)
end
Base.getindex(x::Skip, i::Integer) = x.layers[i]
Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin

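A quick check of what makes Skip awkward for destructure, assuming the definitions just above are loaded: its functor children are the layers Tuple, not a NamedTuple of its fields.

    sk = Skip([1.0, 2.0], (x = 3, y = [4.0, 5.0]))

    Functors.functor(sk)[1]   # ([1.0, 2.0], (x = 3, y = [4.0, 5.0])) -- a Tuple of layers
    fieldnames(typeof(sk))    # (:layers,)                            -- a single field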
@@ -165,6 +172,16 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@test_throws ArgumentError Optimisers.setup(ADAMW(), m2)
end

@testset "issue 62" begin
m62 = (s = Skip([1.0, 2.0], Foo([3.0], false)), t = [4.0, 5.0])
s62 = Optimisers.setup(Descent(), m62)
g62 = gradient(m -> m.s[2].x[1] + 3 * m.t[2], m62)
s, m = Optimisers.update(s62, m62, g62...)
@test m.s isa Skip
@test m.s[2].x ≈ [2.9]
@test m.t ≈ [4, 4.7]
end

end
@testset verbose=true "Destructure" begin
include("destructure.jl")
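For reference, the expected numbers in the runtests "issue 62" testset follow from Descent()'s default learning rate of 0.1 (assumed here); the loss m.s[2].x[1] + 3 * m.t[2] has gradient 1 w.r.t. s[2].x[1], 3 w.r.t. t[2], and zero elsewhere:

    3.0 - 0.1 * 1   # 2.9  -> @test m.s[2].x ≈ [2.9]
    5.0 - 0.1 * 3   # 4.7  -> @test m.t ≈ [4, 4.7]  (t[1] has zero gradient, so it stays at 4)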