
Nested derivatives do not work for functions with arrays as argument #148

Open · MariusDrulea opened this issue Sep 6, 2023 · 19 comments

MariusDrulea commented Sep 6, 2023

See the following code.
Pseudocode + explanations:

f = x1 * x2
grad_f = (x2, x1)
fg = sum(grad_f) = x2+x1
grad_fg = [1, 1]

Actual code:

using Tracker

x = param([1, 2])  # Tracked 2-element Vector{Float64}
f(x) = prod(x)
g(x) = gradient(f, x, nest=true)
fg(x) = sum(g(x)[1])
gg = gradient(fg, x)

The value of gg is [-2.0, -0.5] instead of the expected [1, 1].
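
As a quick arithmetic check (added for reference), the wrong value coincides with -prod(x) ./ x.^2, which hints that the nested pass differentiates prod(x) ./ x while holding the prod(x) factor constant:

x0 = [1.0, 2.0]
-prod(x0) ./ x0 .^ 2  # [-2.0, -0.5], matching the incorrect gg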

@MariusDrulea changed the title from "Incorrect hessian" to "Nested derivatives do not work" on Sep 6, 2023
@MariusDrulea (Author) commented:

Another related example, using sum instead of prod:
Pseudocode + explanations:

f = x1 + x2
grad_f = (1, 1)
fg = sum(grad_f) = 2
grad_fg = [0, 0]

Actual code:

using Tracker

x = param([1, 2])  # Tracked 2-element Vector{Float64}
f(x) = sum(x)
g(x) = gradient(f, x, nest=true)
fg(x) = sum(g(x)[1])
gg = gradient(fg, x)

This code gives the following: ERROR: MethodError: no method matching back!(::Float64)

@MariusDrulea (Author) commented:

As side info, the Hessian of the prod function works in AutoGrad.jl:

using AutoGrad

x = Param([1,2,3])		# user declares parameters
p(x) = prod(x)

hess(f,i=1) = grad((x...)->grad(f)(x...)[i])
hess(p, 1)(x)
hess(p, 2)(x)
hess(p, 3)(x)

Returns the correct result:

3-element Vector{Float64}:
 0.0
 3.0
 2.0

3-element Vector{Float64}:
 3.0
 0.0
 1.0

3-element Vector{Float64}:
 2.0
 1.0
 0.0
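
Indeed, for f(x) = prod(x) the analytic Hessian is H[i,j] = prod(x)/(x[i]*x[j]) for i ≠ j and 0 on the diagonal; at x = [1, 2, 3] this gives exactly the three columns above.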

MariusDrulea (Author) commented Sep 13, 2023

Some really good news: if I use two separate params, the Hessian works. The next step is to adjust this to work for arrays.

using Tracker

x1 = param(1)
x2 = param(2)

f(x1, x2) = x1*x2
∇f(x1, x2) = gradient(f, x1, x2, nest=true)

H11 = gradient(x1->∇f(x1, x2)[1], x1)[1]
H12 = gradient(x2->∇f(x1, x2)[1], x2)[1]
H21 = gradient(x1->∇f(x1, x2)[2], x1)[1]
H22 = gradient(x2->∇f(x1, x2)[2], x2)[1]

H = [H11 H12; H21 H22]

Result:

2×2 Matrix{Tracker.TrackedReal{Float64}}:
 0.0  1.0
 1.0  0.0

@MariusDrulea (Author) commented:

Higher-order derivatives of functions with a single scalar argument work correctly:

using Tracker

f(x) = sin(cos(x^2))
df(x) = gradient(f, x, nest=true)[1]
d2f(x) = gradient(u->df(u), x, nest=true)[1]
d3f(x) = gradient(u->d2f(u), x, nest=true)[1]

x0 = param(1)
df(x0) # -1.4432122981268867 (tracked)
d2f(x0) # -4.7534826540186135 (tracked)
d3f(x0) # -5.683220233612525 (tracked)
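
As a sanity check (added for reference; this is a straightforward chain rule), the first derivative is f'(x) = cos(cos(x^2)) * (-sin(x^2)) * 2x, which matches the tracked value at x = 1:

df_analytic(x) = cos(cos(x^2)) * (-sin(x^2)) * 2x
df_analytic(1.0)  # -1.4432122981268867, matching df(x0)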

@MariusDrulea changed the title from "Nested derivatives do not work" to "hessian of the prod function gives incorrect results" on Sep 14, 2023
@MariusDrulea changed the title from "hessian of the prod function gives incorrect results" to "Nested derivatives do not work for functions with arrays as argument" on Sep 14, 2023
MariusDrulea (Author) commented Sep 14, 2023

> Some really good news: if I use two separate params, the Hessian works. The next step is to adjust this to work for arrays.
>
> f(x1, x2) = x1*x2
> ∇f(x1, x2) = gradient(f, x1, x2, nest=true)
> H11 = gradient(x1->∇f(x1, x2)[1], x1)[1]

I tried to replicate this approach with the following code, but it yields incorrect results (the first row of the Hessian of prod at [1, 2] should be [0.0, 1.0]):

using Tracker
f(x) = prod(x)

∇f = x->gradient(f, x, nest=true)[1]
hess1 = x->gradient(u->∇f(u)[1], x)
hess1([1, 2]) # ([-2.0, 0.0] (tracked),)

MariusDrulea (Author) commented Sep 16, 2023

Getting closer. I have manually created the corresponding graph for the gradient and the Hessian. The rrules used are correct, so something is wrong with the Tracker algorithm for "nested" gradients; the extra graph that records the derivatives is perhaps not created correctly.

using Tracker

x = param([1, 2])
p = x -> prod(x)
q = x -> p(x) ./ x  # the hard-coded gradient: ∂prod(x)/∂xᵢ = prod(x)/xᵢ (for nonzero xᵢ)
jacobian(q, x)      # Jacobian of the gradient = Hessian

Gives the correct result:

Tracked 2×2 Matrix{Float64}:
 0.0  1.0
 1.0  0.0

MariusDrulea (Author) commented Sep 20, 2023

I have created a function to print the graph under a specific node:
https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L140

One can notice that Tracker does not record the methods performed, but rather the pullbacks of these methods. This is because only the pullbacks are needed when back-propagating through the graph.

using Tracker
f(x) = prod(x)
∇f = x->gradient(f, x, nest=true)[1]
hess1 = x->gradient(u->∇f(u)[1], x, nest=true)

h1 = hess1([3, 4])
Tracker.print_graph(stdout, h1[1])

# Prints the following
TrackedData
-data=[-1.3333333333333333, 0.0]
-Tracker=
--isleaf=false
--grad=UndefInitializer()
--Call=
---f
----back
-----func.f=+
-----func.args=([0.0, 0.0], [-1.3333333333333333, -0.0] (tracked))
---args
----nothing
----Tracker=
-----isleaf=false
-----grad=UndefInitializer()
-----Call=
------f
-------back
--------func.f=partial
--------func.args=(Base.RefValue{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(*)}}(Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(*)}(Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}(Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}(Base.Broadcast.var"#11#13"()), Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}(Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}(Base.Broadcast.var"#17#19"())), Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}(Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}(Base.Broadcast.var"#27#28"())), Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}(Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}(Base.Broadcast.var"#23#24"())), /), *)), [1.0, 0.0] (tracked), 2, 12.0, [3.0, 4.0] (tracked), 1)
------args
-------nothing
-------Tracker=
--------isleaf=false
--------grad=UndefInitializer()
--------Call=
---------f
----------back
-----------func.f=partial
-----------func.args=(Base.RefValue{typeof(+)}(+), [1.0, 0.0] (tracked), 2, [0.0, 0.0], [4.0, 3.0] (tracked))
---------args
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=UndefInitializer()
-----------Call=
------------f
-------------#714
--------------func.f=getindex
------------args
-------------nothing
-------------nothing
----------nothing
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=UndefInitializer()
-----------Call=
------------f
-------------back
--------------func.f=#12
--------------func.args=(12.0, [3.0, 4.0] (tracked), 1)
------------args
-------------nothing
-------------Tracker=
--------------isleaf=false
--------------grad=UndefInitializer()
--------------Call=
---------------f
----------------#718
---------------args
----------------Tracker=
-----------------isleaf=true
-----------------grad=[0.0, 0.0]
-----------------Call=
------------------f
------------------args
-------------nothing
-------nothing
-------nothing
-------Tracker=
--------isleaf=false
--------grad=UndefInitializer()
--------Call=
---------f
----------#718
---------args
----------Tracker=
-----------isleaf=true
-----------grad=[0.0, 0.0]
-----------Call=
------------f
------------args
-------nothing

@mcabbott (Member) commented:

Thanks for digging! As you can see, this package doesn't get a lot of attention, but fixes are very welcome.

I don't know what's going wrong in these cases; I never thought much about how this package handles second derivatives. There is a way to mark rules as only suitable for first derivatives, which the prod rule does not use.

@MariusDrulea (Author) commented:

I suspect some broadcasting rule does not work correctly. If you look at the graph above, there is a long chain of broadcasts, which looks weird.

[Screenshot from 2023-09-20 21-31-34: visualization of the graph]

MariusDrulea (Author) commented Sep 20, 2023

This is the graph of the example that works correctly. It looks clean.

using Tracker
x1 = param(1)
x2 = param(2)

f_(x1, x2) = x1*x2
∇f_(x1, x2) = gradient(f_, x1, x2, nest=true)
H12 = gradient(x2->∇f_(x1, x2)[1], x2, nest=true)[1]

Tracker.print_graph(stdout, H12)

# Prints the following
TrackedData
-data=1.0
-Tracker=
--isleaf=false
--grad=0.0
--Call=
---f
----back
-----func.f=+
-----func.args=(0.0, 1.0 (tracked))
---args
----nothing
----Tracker=
-----isleaf=false
-----grad=0.0
-----Call=
------f
-------#229
--------func.b=1
------args
-------Tracker=
--------isleaf=false
--------grad=0.0
--------Call=
---------f
----------back
-----------func.f=partial
-----------func.args=(Base.RefValue{typeof(+)}(+), 1, 2, 0.0, 2.0 (tracked))
---------args
----------nothing
----------nothing
----------nothing
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=0.0
-----------Call=
------------f
-------------#230
--------------func.a=1
------------args
-------------nothing
-------------Tracker=
--------------isleaf=false
--------------grad=0.0
--------------Call=
---------------f
----------------#718
---------------args
----------------Tracker=
-----------------isleaf=false
-----------------grad=0.0
-----------------Call=
------------------f
-------------------#718
------------------args
-------------------Tracker=
--------------------isleaf=true
--------------------grad=1.0
--------------------Call=
---------------------f
---------------------args
-------nothing

MariusDrulea (Author) commented Sep 20, 2023

> Thanks for digging! As you can see, this package doesn't get a lot of attention, but fixes are very welcome.

I have reviewed the back-propagation algorithm for both simple and nested gradients, and everything seems correct. So it's probably a good idea to switch to ChainRules now, as those rules are more robust.

Another internal design decision is whether to record the original function or its pullback in the graph; the algorithm stays mostly the same either way. Currently the pullback is stored, but this makes printing and visualizing the graph a bit more difficult (and so is debugging). I also don't know whether storing the pullback has any advantage over storing the original function, so I would be tempted to store the original function instead and invoke its rrule only during back-propagation through the graph.

To summarize, three actions are needed to improve the robustness and ease of development of this package:

  • switch to ChainRules
  • in the graph, store the original function instead of the pullback
  • document the code and the public functions, generate docs

@ToucheSir (Member) commented:

Calling the rrule only during the backward pass would probably not work great, because it'd end up with us recomputing the primal. Did you perhaps mean calling the pullback function rrule returns? I know Zygote and Tracker have not been great at giving those names, but ChainRules is pretty good about doing so. If I call rrule(f, args...), I'll probably get a (primal, pullback::typeof(f_pullback)) back.
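
For instance (a minimal illustration of the ChainRules calling convention, not from the original comment):

using ChainRules: rrule

# rrule returns the primal value together with a pullback closure
y, pb = rrule(sin, 1.0)
y == sin(1.0)  # true
pb(1.0)        # (NoTangent(), cos(1.0)): cotangents for (sin, x)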

@MariusDrulea
Copy link
Author

@ToucheSir you are right, we have to store the pullback, not the original function.
This is also clear to me now: https://discourse.julialang.org/t/second-order-derivatives-with-chainrules/103606

Integrating ChainRules is pretty easy: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L91
Second-order derivatives still do not work, because not all operations are tracked and sent to ChainRules; I have only added +(x, y) so far: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/lib/array.jl#L455.

Next is to:

  • continue the integration of ChainRules: track the missing operations (like x + y = track(+, x, y), where x and y are TrackedArrays, etc.); see the sketch after this list
  • make sure all first order derivatives work
  • fix second order derivatives
  • remove obsolete rules defined by Tracker
  • document the code & generate documentation
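
A rough sketch of that tracking pattern (my illustration, assuming Tracker's internal track/_forward hooks; the actual code in the fork may differ):

using Tracker: TrackedArray, track, data
import Tracker: _forward
using ChainRules: rrule

# route tracked addition through Tracker's generic recording machinery
Base.:+(a::TrackedArray, b::TrackedArray) = track(+, a, b)

# first-order version: run the primal on plain data and delegate the
# pullback to ChainRules, dropping the cotangent for `+` itself
# (thunks, if any, would still need unthunking in a real integration)
function _forward(::typeof(+), a, b)
    y, pb = rrule(+, data(a), data(b))
    return y, Δ -> pb(Δ)[2:end]
end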

The following first-order derivative works via ChainRules:

using Tracker
f2(x, y) = prod(x+y)
dx = gradient(f2, [3, 4], [7, 9], nest=true)

Prints:

[ Info: Chainrules for +
[ Info: Chainrules for prod
([13.0, 10.0], [13.0, 10.0])

@MariusDrulea (Author) commented:

> Integrating ChainRules is pretty easy: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L91

Not that easy, son. It is easy only for first-order derivatives. I have to further track the operations performed by these first-order derivatives, so that second-order derivatives can be computed.

MariusDrulea (Author) commented Sep 22, 2023

I'm currently stuck, as I don't know how to deal with the ChainRules rrule(s) here.

The logic in Tracker is to define an untracked forward pass and a tracked pullback, and the AD engine will do the rest:

  • f(x::TrackedReal) = f(data(x)) # untracked forward pass; data(x) is a Float64
  • back(d) = d*∂f(x) # tracked pullback; x is a TrackedReal

What ChainRules offers:

y, back = rrule(sin, x::TrackedReal)

This would yield sin(x) (tracked), with back also tracked. What we want is sin(data(x)) (not tracked) and back tracked. So basically I want to be able to modify the forward pass. With DiffRules this is easily achievable, as we can define our forward pass and pullback separately, e.g.

sin(x) = sin(data(x))                # untracked forward pass
DiffRules.diffrule(:Base, :sin, :x)  # :(cos(x)), the derivative expression provided by DiffRules; Tracker splices it into a tracked pullback

@ToucheSir, @mcabbott Any idea? Or ping somebody who can help?

@ToucheSir (Member) commented:

> This would yield sin(x) (tracked), with back also tracked. What we want is sin(data(x)) (not tracked) and back tracked.

I'm not sure I understand; why shouldn't it be sin(x) (tracked)? If the primal output is not tracked, AD will just stop working for subsequent operations.

MariusDrulea (Author) commented Sep 23, 2023

The tracking of sin(x) is done by the Tracker.jl logic, not by rrule. For sin, this is what happens:

sin(xs::TrackedReal) = track(sin, xs)

The track method for sin behaves like this:

function track(::typeof(sin), xs::TrackedReal)
  # y is not tracked, as sin is applied only to data(xs);
  # the pullback uses cos(xs), which IS tracked, since cos is applied to a
  # TrackedReal: a track call fires for it, just like for sin
  y, back = sin(data(xs)), d -> d * cos(xs)  # cos(xs) is the derivative provided by DiffRules

  # create another TrackedReal whose data is y, recording the pullback and
  # the previous node in the graph (tracker.(xs))
  track(Call(back, tracker.(xs)), y)  # the tracking of the primal y happens here
end
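
The observable behavior (a small usage check added for illustration; back! and grad are Tracker's standard entry points):

using Tracker

xs = param(1.0)
y = sin(xs)       # tracked result; data(y) == sin(1.0)
Tracker.back!(y)  # back-propagate through the recorded pullback
Tracker.grad(xs)  # cos(1.0) ≈ 0.5403023058681398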

MariusDrulea (Author) commented Sep 25, 2023

I have checked the AutoGrad.jl engine and also the HIPS/autograd engine. These engines call the differentiation rules (rrules) in the backward pass, and so does Tracker.jl.

The right thing to do here is to thunk the primal computation for all rrules in ChainRules, or any other implementation that avoids the primal computation. Yes, the primal is needed most of the time with rrules, but not always!

I think we can currently achieve the following:

  1. If only first-order derivatives are needed, we can simply call the rrule during the forward pass and register its pullback for the backward pass. This way the primal is computed only once.
  2. If higher-order derivatives are needed, we have to run the primal on the plain data while calling the rrule on TrackedReal etc. in the backward pass. This means the primal function is called twice.

Sample code for item 2:

using ChainRules
using ChainRules: rrule
using ChainRulesCore
import Base: +, *

struct _Tracked <: Real
    data::Float64
    f::Any
    _Tracked(data, f) = new(data, f)
    _Tracked(data) = new(data, nothing)
end

function track(f, a, b)
    data, delayed = _forward(f, a, b)
    return _Tracked(data, delayed)
end

function _forward(f, a, b)
    data = f(a.data, b.data) # primal on plain data, no tracking
    delayed = ()->rrule(f, a, b) # delayed rrule on the tracked values: recomputes the primal, but with tracking
    return data, delayed
end

a::_Tracked + b::_Tracked = track(+, a, b)
a::_Tracked * b::_Tracked = track(*, a, b)

##
a = _Tracked(10)
b = _Tracked(20)
tr = a*b
back2 = tr.f()[2]
da, db = back2(_Tracked(1.0))[2:end] 
# result: da =_Tracked(20.0, ...), db =_Tracked(10.0, ...)

@MariusDrulea (Author) commented:

Status:

Test Summary:                | Pass  Fail  Error  Broken  Total   Time
Tracker                      |  357     1     15       3    376  22.0s
  gradtests 1                |   16                          16   0.6s
  gradtests 1.1              |   10                          10   0.4s
  indexing & slicing         |    1                           1   0.1s
  concat                     |  181                         181   2.1s
  getindex (Nabla.jl - #139) |    1                           1   0.0s
  gradtests 2                |   22            2             24   0.9s
  mean                       |    5                           5   0.1s
  maximum                    |    5                           5   0.1s
  minimum                    |    5                           5   0.1s
  gradtests 3                |    9            4             13   0.5s
  transpose                  |   48                          48   0.0s
  conv, 1d                   |    2                    1      3   2.3s
  conv, 2d                   |    2                    1      3   1.1s
  conv, 3d                   |    2                    1      3   8.0s
  pooling                    |    4                           4   2.0s
  equality & order           |   16                          16   0.0s
  reshape                    |    8                           8   0.0s
  Intermediates              |    2                           2   0.0s
  Fallbacks                  |    1                           1   0.0s
  collect                    |                 1              1   0.1s
  Hooks                      |                 1              1   0.1s
  Checkpointing              |    2     1      1              4   0.2s
  Updates                    |    2                           2   0.0s
  Params                     |    1                           1   0.0s
  Forward                    |                 2              2   0.2s
  Custom Sensitivities       |                 1              1   0.1s
  PDMats                     |                 1              1   0.1s
  broadcast                  |    2            1              3   0.0s
  logabsgamma                |    2                           2   0.1s
  Jacobian                   |    1                           1   0.8s
  withgradient               |    3                           3   1.3s
  NNlib.within_gradient      |    2                           2   0.0s
ERROR: Some tests did not pass: 357 passed, 1 failed, 15 errored, 3 broken
