-
Notifications
You must be signed in to change notification settings - Fork 0
/
autograd.go
102 lines (88 loc) · 2.29 KB
/
autograd.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package gosor
import "fmt"
// GradFunc computes the local gradients of an operation with respect to its
// inputs. Backward requires the returned slice to hold exactly one entry per
// child tracker; a nil entry means no gradient is propagated to that child.
type GradFunc func() ([]*Tensor, error)
// GradientTracker is a node in the backward graph. It stores the gradient
// accumulated for one tensor and links to the trackers of the tensors that
// tensor was computed from.
type GradientTracker struct {
// children holds the trackers of the input tensors, in input order. An
// entry is nil when the corresponding input has no tracker.
children []*GradientTracker
// gradient is the accumulated gradient; nil until Backward has run, and
// nil again after ResetGradient.
gradient *Tensor
// gradFunc returns the local gradients, one per child. When it is nil,
// Backward stops after accumulating into gradient.
gradFunc GradFunc
// isNotLeaf is read by addGradientTracker to decide whether a result
// tensor gets wired to its inputs.
// NOTE(review): where this flag is set is not visible in this file.
isNotLeaf bool
}
// Gradient returns the gradient currently stored on the tracker.
//
// It returns an error when no gradient is stored, which is the case before
// Backward has run and after ResetGradient.
func (g *GradientTracker) Gradient() (*Tensor, error) {
	if grad := g.gradient; grad != nil {
		return grad, nil
	}
	return nil, fmt.Errorf("gradient not calculated")
}
// ResetGradient drops the stored gradient so that the next Backward call
// allocates a fresh one. Child trackers are not touched.
func (g *GradientTracker) ResetGradient() {
	g.gradient = nil
}
// Backward accumulates previousGradient into this tracker's gradient and then
// propagates it to the child trackers.
//
// previousGradient is the gradient flowing in from the consumer of this
// tensor. When it is nil, a tensor built from the single value 1 is used as
// the seed.
// NOTE(review): the seed is created through MustValue, which presumably
// panics if construction fails — confirm.
//
// The stored gradient is allocated on first use with the sizes of
// previousGradient and updated through AddInto. If the tracker has no
// gradFunc, propagation stops there. Otherwise gradFunc must return exactly
// one local gradient per child; each non-nil local gradient is multiplied with
// previousGradient and passed to Backward of the child at the same index.
// Children without a tracker and nil local gradients are skipped.
//
// An error is returned when g is nil, when any tensor operation fails, when
// gradFunc fails or returns the wrong number of gradients, or when a child's
// Backward fails.
func (g *GradientTracker) Backward(previousGradient *Tensor) (err error) {
	if g == nil {
		return fmt.Errorf("backwards on tensor without gradient tracker")
	}

	if previousGradient == nil {
		previousGradient = Wrap(New(WithValues(1))).MustValue()
	}

	if g.gradient == nil {
		g.gradient, err = New(WithSize(previousGradient.sizes...))
		if err != nil {
			return err
		}
	}

	if _, err = AddInto(g.gradient, g.gradient, previousGradient); err != nil {
		return err
	}

	if g.gradFunc == nil {
		return nil
	}

	localGradients, err := g.gradFunc()
	if err != nil {
		return fmt.Errorf("calculating local gradient: %w", err)
	}
	if len(localGradients) != len(g.children) {
		return fmt.Errorf("wrong amount tensors returned from grad func, should be one for each child.")
	}

	// Keep the global gradients index-aligned with g.children. Appending only
	// the non-nil results would shift every gradient that follows a nil entry
	// onto the wrong child and leave the trailing children without one.
	globalGradients := make([]*Tensor, len(localGradients))
	for i, local := range localGradients {
		// Nothing to compute when there is no local gradient or no child
		// tracker to receive the result.
		if local == nil || g.children[i] == nil {
			continue
		}
		global, err := Mul(local, previousGradient)
		if err != nil {
			return fmt.Errorf("multiplication of local and global gradient: %w", err)
		}
		globalGradients[i] = global
	}

	// Propagate only after every product has been computed, so a failing
	// multiplication aborts before any child has been updated.
	for i, global := range globalGradients {
		if global == nil {
			continue
		}
		if err := g.children[i].Backward(global); err != nil {
			return err
		}
	}
	return nil
}
// addGradientTracker wires the result tensor of an operation into the
// backward graph.
//
// It does nothing when res is nil, when res has no tracker, when res's tracker
// is not flagged isNotLeaf, or when none of the children carries a tracker.
// Otherwise res's tracker is replaced by a new one that references the
// children's trackers (nil for children without one, positions preserved) and
// uses gradFunc to compute local gradients. The new tracker is created with
// only children and gradFunc set.
func addGradientTracker(res *Tensor, children []*Tensor, gradFunc GradFunc) {
	if res == nil || res.GradientTracker == nil || !res.isNotLeaf {
		return
	}

	trackers := make([]*GradientTracker, len(children))
	anyTracked := false
	for i, child := range children {
		if t := child.GradientTracker; t != nil {
			trackers[i] = t
			anyTracked = true
		}
	}
	if !anyTracked {
		return
	}

	res.GradientTracker = &GradientTracker{
		children: trackers,
		gradFunc: gradFunc,
	}
}