Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Parallel Vararg #1698

Closed
wants to merge 8 commits into from
Closed

Conversation

DhairyaLGandhi
Copy link
Member

@DhairyaLGandhi DhairyaLGandhi commented Aug 23, 2021

This PR makes it so vararg inputs and layers are treated as zip(layers, inputs) which are then splat into the connection.

Unverified

This commit is not signed, but one or more authors requires that any commit attributed to them is signed.
@ToucheSir
Copy link
Member

This seems out of sync with master? We already have https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl#L447-L449.

@DhairyaLGandhi
Copy link
Member Author

Yeah, there was a small conflict, but I fixed that.

@DhairyaLGandhi DhairyaLGandhi linked an issue Aug 23, 2021 that may be closed by this pull request
@DhairyaLGandhi DhairyaLGandhi removed a link to an issue Aug 23, 2021
@DhairyaLGandhi
Copy link
Member Author

master:

julia> Parallel(+, Dense(3,3))
Error showing value of type Parallel{typeof(+), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}:
ERROR: MethodError: no method matching iterate(::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] trainable(m::Parallel{typeof(+), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}})
    @ Flux ~/Downloads/new_clones/Flux.jl/src/layers/basic.jl:456

@darsnack
Copy link
Member

darsnack commented Aug 23, 2021

This will be breaking since the previous semantics only required that connection was binary and associative. Now connection must handle the vararg case too. This is true for the most common connections like concat and arithmetic operations, so it should only be slightly breaking.

We should still make this change though.

@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented Aug 23, 2021

connections already need to handle binary ops, so that would still be the exact same way. It doesn't need to handle vararg, but now they can.

Of course, the N>3 case would need all the outputs. Maybe vararg is the wrong thing to call this.

@darsnack
Copy link
Member

darsnack commented Aug 23, 2021

It doesn't need to handle vararg, but now they can.

Not quite, if there are 5 branches, then connection will need to handle 5 arguments now. Previously, connection could only be binary and the 5 branch case would still reduce correctly.

@mcabbott
Copy link
Member

Agree we should do this, and that it's breaking.

While breaking things, it should probably be adjusted to handle the 3rd case of #1685

The contract for Parallel has 3 supported cases:

N layers, 1 input
N layers, N inputs
1 layer, N inputs

Needs tests & doc updates, obviously.

@DhairyaLGandhi
Copy link
Member Author

I think the last case is something better handled on the user side. If someone has multiple inputs, it's easier to have a method that accepts a tuple of arguments and forward that as necessary than it is for us to guess where these arguments go.

(m::Layer)(x::Tuple) = m(x[1], x[2], ...)

For the Parallel layer,

struct MultiInput{T}
  W::T
end
(m::MultiInput)(x, y) = m.W * x, m.W * y, x * y

x = (rand(3,3), rand(3,3))

l = Parallel((x...) -> identity.(x),         # This will be annoying to deal with
              MultiInput(rand(3,3)),
              MultiInput(rand(3,3)))

This would need two things: One is that we remove (m::Parallel)(xs::Tuple) = m(xs...) which is a win since it is plausible to want multiple inputs and they usually (or preferably) come as tuples. And two, is that users have a method that accepts Tuple.

Another case that this handles better is N inputs M outputs. All these cases are subsets of treating inputs and outputs as something generic that is better left to the user.

@darsnack
Copy link
Member

darsnack commented Aug 24, 2021

As discussed here, the tuple method exists to make Chain work. And it seems like the existing behavior is the desired behavior based on feedback from users in that issue. The previous approach handled MIMO cases too. I would revert the changes made to remove this method.

I am not sure what 1 Layer, N Inputs is even supposed to do? Replicate the layer onto each input? That seems out of the contract for Parallel.

@DhairyaLGandhi
Copy link
Member Author

Right and the distinction to make there is considering tuple as a single input always. If users want the elements of the tuple to be inputs, they can splat, else pass the tuple along.

That way we don't have to wrap tuples multiple times so that the automatic splat produces the correct tuple inputs to layers that expect multiple inputs.

One other case would be how MIMO would look when we have Parallel nested within Parallel.

@darsnack
Copy link
Member

darsnack commented Aug 24, 2021

considering tuple as a single input always

That way we don't have to wrap tuples multiple times

Wrapping the output like ((x, y),) is the Julia semantics for a single return value that is a tuple. (x, y) as a return value is always considered multiple outputs in Julia (and consequently should be multiple inputs to the successive layer).

If users want the elements of the tuple to be inputs, they can splat, else pass the tuple along.

Users don't control how Chain calls each layer, and Chain does not splat. It is not possible to insert a splatting layer before the Parallel in a Chain in this design. In contrast, with the previous design, you could insert (x...) -> tuple(x) to condense multiple outputs into a single entity (or do this when you return as well).

@darsnack
Copy link
Member

A way to handle the 1 Layer, N Inputs case is to do Iterators.cycle(m.layers) though this would break the zip-behavior that I think is more desirable.

@ToucheSir
Copy link
Member

ToucheSir commented Aug 24, 2021

I am not sure what 1 Layer, N Inputs is even supposed to do? Replicate the layer onto each input?

If by replicate you mean "apply" and not "make a copy of, and then apply", then yes. What the layer that does this should be called is up for debate, but our discussion over TimeDistributed has shown there is a desire for something like this.

RE splatting/multiarg over tuples, I feel it is such a deep rabbit hole that we ought to avoid it as much as possible. Trying to guess user intent is nothing if not fraught, and it's better to be strict and consistent rather than inconsistently lenient.

@darsnack
Copy link
Member

If by replicate you mean "apply" and not "make a copy of, and then apply", then yes.

Yeah, that's what I meant, and the RNN discussion is what was in the back of my head.

it's better to be strict and consistent rather than inconsistently lenient.

Just to clarify, what's consistent here? As I see it, Julia itself, Chain, and Bilinear all treat a tuple output as multiple values. So the consistent thing is maintain the tuple behavior in Parallel.

@DhairyaLGandhi
Copy link
Member Author

Chain passes along whatever input it gets, and explicitly expects tuples for multiple argument cases. It doesn't have a multiple argument forward pass either (partly for this reason). The first transform is expected to then handle this tuple. Julia itself has seen invariant tuples as a good design goal (see JuliaLang/julia#24614) and NamedTuples already are.

@ToucheSir
Copy link
Member

Just to clarify, what's consistent here? As I see it, Julia itself, Chain, and Bilinear all treat a tuple output as multiple values. So the consistent thing is maintain the tuple behavior in Parallel.

Currently, Chain expects every child layer to take 1 input and return 1 output. Those may be composite types like (named)tuples, but Chain itself will never attempt to unpack them. So I think we're on the same page here.

@darsnack
Copy link
Member

Chain passes along whatever input it gets, and explicitly expects tuples for multiple argument cases. It doesn't have a multiple argument forward pass either (partly for this reason). The first transform is expected to then handle this tuple. Julia itself has seen invariant tuples as a good design goal (see JuliaLang/julia#24614) and NamedTuples already are.

I don't think anyone disagrees with this. But regardless of how we pass multiple arguments, the expected (and requested) behavior of Parallel is to map over multiple inputs like zip(layers, xs). Removing the tuple method (as is currently done in this PR) is basically doing the opposite of what you are arguing. Instead of the transforms "handl[ing] this tuple" according to the Parallel semantics, it is compressing/packing the tuple into a single argument like a single array.

If we want to be really strict about "multiple inputs/outputs are always tuples" then we should eliminate the Vararg methods instead — forcing people to call a single Parallel like (m::Parallel)((x, y, z, ...)).

@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented Aug 24, 2021

Currently, Chain expects every child layer to take 1 input and return 1 output. Those may be composite types like (named)tuples, but Chain itself will never attempt to unpack them.

Correct. Passing along the composite type is what we are doing here.

par = Par(f, l1, l2)

Said another way, the tuple method doesn't override the case that a function expects multiple arguments, l1 may have a different arity to l2. So the correct thing would be to say if par is called with a single (composite)object, then send it to every layer else send the respective input to the respective layer. (zip(layers, xs))

@darsnack
Copy link
Member

Correct. Passing along the composite type is what we are doing here.

No it isn't. Under this PR, the only way for a Chain of layers containing a Parallel to execute the connection(map(..., zip(layers, xs))...) behavior is for Chain to splat the output of every layer into the next one. That's explicitly not "passing along the composite." The Chain changes are not yet in the PR, but it will have to be, or else I don't think there's any other way to make the zip execution happen.

You're thinking about MIMO in the context of branches, but missing MIMO in the context of the complete Parallel layer.

@DhairyaLGandhi
Copy link
Member Author

See the edit, i was replying to Brian, my browser hadn't updated the comments yet.

@darsnack
Copy link
Member

darsnack commented Aug 24, 2021

So the correct thing would be to say if par is called with a single (composite)object, then send it to every layer else send the respective input to the respective layer. (zip(layers, xs))

Okay so if I have Chain(MyThreeOutputLayer, Parallel(+, Dense, Dense, Dense)), how do you propose that the three outputs from MyThreeOutputLayer are encoded and passed to the Parallel so that each Dense operates on one of the three outputs? The only way to "send the respective input to the respective layer" is for Chain to splat between MyThreeOutputLayer and Parallel. Unless I've totally missed something.

@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented Aug 24, 2021

It is likely that a model (produces multiple outputs/ receives multiple inputs) in the Chain, but the output of the Parallel is defined by the combinator, not the branches themselves. So IMO its more pressing to make the input representation reliable.

the only way for a Chain [...] splat the output of every layer into the next one.

I understand what you mean by this. The question is: how to distinguish a tuple expected to be sent to the layers as is from a tuple that needs to be mapped. The current design will always flatten a tuple input, breaking the contract of N inputs.

For the MIMO case, master would mean (::MIMO)(x,x) works but (::MIMO)(x) can't for tuple x no matter how much it is wrapped in tuples. This needs to be rectified since since then we can't share multiple arguments over several branches or mix functions with multiple and single arguments in the Parallel easily in a Chain.

@ToucheSir
Copy link
Member

I wonder if, instead of having each individual layer handle this, we define a common wrapper layer which solely handles splatting tuples into varargs:

struct Splat{T} # arbitrarily chosen name
  inner::T
end
@functor Splat
(s::Splat)(x) = s.inner(x...) # rough version, probably needs more error checking

Then instead of having (m::Parallel)(xs::Tuple) or (a::Bilinear)(x::NTuple{2, AbstractArray}), one would write just the n-ary function and pass Splat(Parallel(...)) or Splat(Binlinear(...)) to a container layer.

The main concern I have with this is compilation overhead. AIUI splatting large tuples is quite slow, and having a bunch of differing input lengths would also trigger recompilation. If those turn out to be a non-issue, however, then I would advocate for a separate layer.

@DhairyaLGandhi
Copy link
Member Author

My main concern is this makes simple operators like splats closer to DSLs, so I would avoid such an approach.

@ToucheSir
Copy link
Member

It's a fine line, isn't it? We had a similar discussion in #1289, and though I was very much against Parallel at the time, it's hard to argue with the utility of it now.

One thing to note is that this is one area other frameworks do very poorly in that we could do well. Maybe a dedicated splat layer (call it Unpack[Input][Tuple] or something if that helps reduce the DSL-ness) is crossing the line, but having orthogonal, re-usable building blocks is quite Julian.

@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented Aug 24, 2021

Well, we'll probably need a better reason or a motivating example.

@ToucheSir
Copy link
Member

ToucheSir commented Aug 24, 2021

Is being stuck in PR review limbo a good one? Because I have a feeling the discussion above will be recapitulated every time we talk about adding new container layers...

Edit: removed useless strained analogy, see my next comment about bringing this into a synchronous design discussion.

@darsnack
Copy link
Member

darsnack commented Aug 24, 2021

The current design will always flatten a tuple input, breaking the contract of N inputs.

Which contract? IMO master doesn't break any contracts, but the current version of this PR breaks the contract for Parallel and fixing this breakage means breaking the contract for Chain.

For the MIMO case, master would mean (::MIMO)(x,x) works but (::MIMO)(x) can't for tuple x no matter how much it is wrapped in tuples, which needs to be rectified since since then we can't share multiple arguments over several branches or mix functions with multiple and single arguments in the Parallel easily in a Chain.

The solution here is to be explicit. I do see what you mean though — wrapping multiple outputs (which is a tuple) in another tuple will not help when there are multiple MIMO branches.

The question is: how to distinguish a tuple expected to be sent to the layers as is from a tuple that needs to be mapped

Instead of guessing how the user wants the inputs distributed, we stay strict to the Parallel contract of zip(layers, xs). The user must specify xs accordingly. In other words, Parallel is always a MIMO layer itself. Multiple inputs by definition of zip(layers, xs) semantics, and possibly multiple outputs depending on the combinator (the output of Parallel is not so relevant to our discussion though). The original code makes this opaque, but a better implementation would be:

(m::Parallel)(xs::Tuple) = m.connection(map((f, x) -> f(x), m.layers, xs)...)
(m::Parallel)(x) = m((x,))
(m::Parallel)(xs...) = m(xs)

This makes it clear that the zip(layers, xs) contract implies that Parallel in multi-input always, and that we handle multi-input as a tuple like we've all been agreeing on. The single input and Vararg-type cases are just convenience features.

Let's use concrete examples to avoid confusion. Here are some cases:

# SingleOutput produces something other than a tuple which goes to each Dense
Chain(SingleOutput, Parallel(+, Dense, Dense, Dense)
# MultiOutput produces multiple outputs as a tuple, each of which are passed to each Dense
Chain(MultiOutput, Parallel(+, Dense, Dense, Dense)
# MultiOutput produces multiple outputs
# We explicitly state that this should be kept as one unit to each MultiInput
Chain(MultiOutput, (x...) -> (x, x, x), Parallel(+, MultiInput, MultiInput, MultiInput)
# Each branch takes in differing number of inputs
Chain(ThreeOutput, (x, y, z) -> (x, (y, z)), Parallel(+, Dense, TwoInput))
# Multiple outputs from Parallel is easily handled by the combinator
Chain(Dense, Parallel(MIMO, Dense, Dense, Dense))

Note that stuff like (x, y, z) -> (x, (y, z)) can be explicitly in the Chain if it needs customization based on the preceding and succeeding layers, or it can be sucked into the definition of ThreeOutput, etc. itself if it is always consistent. This is up to the user.

Now, I agree that some of these, like (x...) -> (x, x, x), aren't pretty. But we can always add utilities to make things more natural.

What I am trying to point out is that it is not possible to go in the other direction where we take a tuple and break it into multiple outputs to hit the vararg case. Unless we use something like @ToucheSir's Splat. What I like about the Splat proposal is that multiple inputs are always Vararg. It's very explicit and consistent, so I think it would scale to future nested MIMO behavior. The downside is that it makes since more verbose for the Chain(MultiOutput, Parallel(+, Dense, Dense, Dense) which I think is the more common case.

@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented Aug 24, 2021

Not sure how to interpret Brian's comment, but to Kyle's point about the dense layers, I had a test that showed using multiple layers each receiving multiple inputs. I would prefer to get the invariant tuples done too, I think it's doable. Either way this PR is an improvement.

Master will always flatten a tuple input, breaking the contract of N inputs, which I was trying to iterate on.

@ToucheSir
Copy link
Member

ToucheSir commented Aug 24, 2021

I think your browser may have missed another comment update. My point was that we've had a number of PRs that are stalled on design disagreements, and even as somebody who didn't author any of them I feel some frustration about not being able to come to some decision. Perhaps we should arrange some time for every ML/AD call to walk through design discussions? Synchronous communication should be far more efficient, and it'd also address weeks when there aren't enough ecosystem updates to fill the alloted time.

@ToucheSir
Copy link
Member

Funnily enough, if JuliaLang/julia#42717 lands then Base may have resolved this debate for us :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants