docs/src/tutorials/gradient_zoo.md
also known as reverse-mode automatic differentiation.
Given a model, some data, and a loss function, this answers the question
"what direction, in the space of the model's parameters, reduces the loss fastest?"

This page is a brief overview of ways to perform automatic differentiation in Julia,
and how they relate to Flux.

### `gradient(f, x)` interface

Julia's ecosystem has many versions of `gradient(f, x)`, which evaluates `y = f(x)` then returns `∂y_∂x`. The details of how they do this vary, but the interface is similar. An incomplete list is (alphabetically): Diffractor.jl, Enzyme.jl, ForwardDiff.jl, Mooncake.jl, ReverseDiff.jl, Tracker.jl, Zygote.jl.
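All of these share essentially the same calling convention. A minimal sketch comparing two of them (`f` here is a stand-in loss function of our own, not from the original page):

```julia
using Zygote, ForwardDiff

f(x) = sum(abs2, x)           # a stand-in loss; y = f(x) is a scalar

x = Float32[1, 2, 3]

Zygote.gradient(f, x)[1]      # Float32[2.0, 4.0, 6.0], i.e. 2 .* x
ForwardDiff.gradient(f, x)    # the same answer, computed by forward mode
```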
In this case they are all identical, but there are some caveats, explored below.

Both Zygote and Tracker were written for Flux, and at present, Flux loads Zygote and exports `Zygote.gradient`, and calls this within `Flux.train!`. But apart from that, there is very little coupling between Flux and the automatic differentiation package.
This page has very brief notes on how all these packages compare, as a guide for anyone wanting to experiment with them. We stress "experiment" since Zygote is (at present) by far the best-tested. All notes are from February 2024.

### Zygote.jl
Reverse-mode source-to-source automatic differentiation, written by hooking into Julia's compiler.

* By far the best-tested option for Flux models.

* Long compilation times, on the first call.

* Allows mutation of structs, but not of arrays. This leads to the most common error... sometimes this happens because you mutate an array, often because you call some function which, internally, creates the array it wants to return & then fills it in.
```julia
function mysum2(x::AbstractMatrix)  # implements y = vec(sum(x; dims=2))
  y = fill!(similar(x, size(x, 1)), 0)  # zero-initialise; `similar` alone leaves undefined entries
  for col in eachcol(x)
    y .+= col  # mutates y, Zygote will not allow this
  end
  return y
end

Zygote.jacobian(x -> sum(x; dims=2).^2, Float32[1 2 3; 4 5 6])[1]  # returns a 2×6 Matrix
Zygote.jacobian(x -> mysum2(x).^2, Float32[1 2 3; 4 5 6])[1]  # ERROR: Mutating arrays is not supported
```
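The usual fix is simply to write the function without mutation; a sketch (the name `mysum2_nomut` is ours, for illustration):

```julia
mysum2_nomut(x::AbstractMatrix) = vec(sum(x; dims=2))  # same result, no mutation

Zygote.jacobian(x -> mysum2_nomut(x).^2, Float32[1 2 3; 4 5 6])[1]  # works: a 2×6 Matrix
```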
* Custom rules via `ZygoteRules.@adjoint` or (equivalently) `ChainRulesCore.rrule`.
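For instance, a custom rule can make the mutating `mysum2` above differentiable, since Zygote calls the rule instead of tracing through the loop. A minimal sketch using `ChainRulesCore.rrule` (the pullback formula is our own illustration, not from the original page):

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(mysum2), x::AbstractMatrix)
    y = vec(sum(x; dims=2))  # recompute the primal without mutation
    function mysum2_pullback(ȳ)
        # every column of x contributes ȳ to the row-sums, so x̄[i, j] = ȳ[i]
        return NoTangent(), repeat(ȳ, 1, size(x, 2))
    end
    return y, mysum2_pullback
end
```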
* Returns nested NamedTuples and Tuples, and uses `nothing` to mean zero.

* Does not track shared arrays, hence may return different contributions for each appearance of the same array, as in the example below.
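A sketch of what this means, with a toy `NamedTuple` of our own tying one array to two fields:

```julia
shared = Float32[1, 2]
tied = (a = shared, b = shared)  # the very same array, used twice

Zygote.gradient(m -> sum(m.a .* m.b), tied)[1]
# (a = Float32[1.0, 2.0], b = Float32[1.0, 2.0]): two separate contributions,
# not one accumulated gradient 2 .* shared; adding them up is the caller's job
```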
!!! compat "Deprecated: Zygote's implicit mode"
    Flux's default used to work like this, instead of using deeply nested trees for gradients as above.
### Tracker.jl

```julia
julia> model_tracked = Flux.fmap(x -> x isa Array ? Tracker.param(x) : x, model)
Chain(
  Embedding(3 => 2),                    # 6 parameters
  NNlib.softmax,
)

julia> val_tracked = loss(model_tracked)
0.6067761f0 (tracked)
```
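From here one would backpropagate from the tracked scalar and read gradients off the tracked parameters. A sketch, assuming Tracker's classic `back!` and `grad` functions (the field access `model_tracked.layers[1].weight` is our illustration):

```julia
julia> Tracker.back!(val_tracked)  # backpropagate from the tracked loss

julia> Tracker.grad(model_tracked.layers[1].weight)  # gradient stored in the Embedding's weight
```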
### Enzyme.jl

New package which works on the LLVM code which Julia compiles down to.

* Returns another struct of the same type as the model, such as `Chain` above. Non-differentiable objects are left alone, not replaced by a zero.
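A sketch of calling it directly (Enzyme's `gradient` front-end has changed shape between releases, so treat the exact form as an assumption):

```julia
using Enzyme

f(x) = sum(abs2, x)
Enzyme.gradient(Reverse, f, Float32[1, 2, 3])  # reverse mode: one sweep gives all entries
```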
### ForwardDiff.jl

* Needs a simple array of parameters, i.e. supports only `gradient(f, x::AbstractArray{<:Real})`.

* Forward mode is generally not what you want for neural networks! It's ideal for ``ℝ → ℝᴺ`` functions, but the wrong algorithm for ``ℝᴺ → ℝ``.

* `gradient(f, x)` will call `f(x)` multiple times. Layers like `BatchNorm` with state may get confused.
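A sketch of the interface, and of why the cost scales badly for a loss of many parameters (the array size is our own toy choice):

```julia
using ForwardDiff

f(x) = sum(abs2, x)                           # ℝᴺ → ℝ, like a loss function
ForwardDiff.gradient(f, rand(Float32, 1000))  # correct, but needs ~N/chunk evaluations of f
```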
This year's new attempt to build a simpler one?

Really `rrule_via_ad` is another mechanism, but only for 3 systems.

Sold as an attempt at unification, but its design of extensible `rrule`s turned out to be too closely tied to Zygote/Diffractor style AD, and not a good fit for Enzyme/Mooncake which therefore use their own rule systems. Also not a natural fit for Tracker/ReverseDiff/ForwardDiff style of operator overloading AD.