Commit 96a905d: "tweaks, but more is needed"
1 parent bd54862

1 file changed: docs/src/tutorials/gradient_zoo.md (28 additions, 8 deletions)
@@ -157,7 +157,7 @@ Reverse-mode source-to-source automatic differentiation, written by hooking into

* By far the best-tested option for Flux models.

-* Long compilation times, on the first call.
+* Medium compilation times, on the first call.

* Allows mutation of structs, but not of arrays. This leads to the most common error... sometimes this happens because you mutate an array, often because you call some function which, internally, creates the array it wants to return & then fills it in.

@@ -175,13 +175,16 @@ Zygote.jacobian(x -> mysum2(x).^2, Float32[1 2 3; 4 5 6])[1] # ERROR: Mutating
```

* Custom rules via `ZygoteRules.@adjoint` or (equivalently) `ChainRulesCore.rrule`.
+  Among other things, this lets you wrap functions which internally mutate an array, so that Zygote need not look inside.

* Returns nested NamedTuples and Tuples, and uses `nothing` to mean zero.

* Does not track shared arrays, hence may return different contributions.

```julia
-
+shared = [2.0]
+nt = (a = shared, b = shared, c = [2.0])
+Zygote.gradient(x -> sum(abs2, x.a + 2*x.b + 3*x.c), nt)[1] # (a = [24.0], b = [48.0], c = [72.0])
```

!!! compat "Deprecated: Zygote's implicit mode"
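
The custom-rule bullet added above mentions wrapping a function which internally mutates an array. As a minimal sketch (not part of this commit; `cumsum_slow` is a hypothetical name), such a rule could look like:

```julia
using ChainRulesCore, Zygote

# Hypothetical function which fills in the array it returns; Zygote alone cannot differentiate this.
function cumsum_slow(x::Vector)
    y = similar(x)
    acc = zero(eltype(x))
    for i in eachindex(x)
        acc += x[i]
        y[i] = acc          # array mutation happens here
    end
    return y
end

# Custom reverse rule, so Zygote uses this pullback and never looks inside the loop above.
function ChainRulesCore.rrule(::typeof(cumsum_slow), x::Vector)
    y = cumsum_slow(x)
    cumsum_slow_pullback(dy) = (NoTangent(), reverse(cumsum(reverse(unthunk(dy)))))
    return y, cumsum_slow_pullback
end

Zygote.gradient(x -> sum(cumsum_slow(x)), [1.0, 2.0, 3.0])  # ([3.0, 2.0, 1.0],)
```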
@@ -255,8 +258,16 @@ New package which works on the LLVM code which Julia compiles down to.

* Returns another struct of the same type as the model, such as `Chain` above. Non-differentiable objects are left alone, not replaced by a zero.

+* Shared arrays are shared in the gradient:
+
+```julia
+shared = [2.0]
+nt = (a = shared, b = shared, c = [2.0])
+Enzyme.gradient(Reverse, x -> sum(abs2, x.a + 2*x.b + 3*x.c), nt)[1] # (a = [72.0], b = [72.0], c = [72.0])
+```
+
Enzyme likes to work in-place, with objects and their gradients stored together in a `Duplicated(x, dx)`.
-Flux has an interface which uses this:
+Flux now has an interface which uses this:
```julia
julia> Flux.train!((m,x) -> sum(abs2, m(1)), model, 1:1, opt_state) # train! with Zygote

@@ -274,7 +285,15 @@ julia> Flux.withgradient(loss, Duplicated(model))
### [Mooncake.jl](https://github.com/compintell/Mooncake.jl)

Another new AD to watch. Many similarities in its approach to Enzyme.jl, but operates all in Julia.
+[Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) has an interface to try this out:

+```julia
+julia> grads_m2 = Flux.gradient(loss, Moonduo(model))
+((layers = ((weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), nothing),),)
+
+julia> Flux.withgradient(loss, Moonduo(model))
+(val = 0.5665111155481435, grad = ((layers = ((weight = [-0.15810298866515066 0.0 0.0; 0.1581029886651505 0.0 0.0],), nothing),),))
+```

### [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)

@@ -303,7 +322,8 @@ Another Julia source-to-source reverse-mode AD.

### [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)

-Forward mode is a different algorithm...
+Forward mode AD is a different algorithm, which is easier to implement. This is a reliable old package,
+but is of limited interest for use with Flux:

* Needs a simple array of parameters, i.e. supports only `gradient(f, x::AbstractArray{<:Real})`.

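A minimal sketch (not part of this commit) of how ForwardDiff can be driven from such a flat parameter vector for a Flux model, using `Flux.destructure`; the model and loss below are illustrative:

```julia
using Flux, ForwardDiff

model = Chain(Dense(3 => 2, tanh), Dense(2 => 1))   # illustrative model
flat, rebuild = Flux.destructure(model)             # flat Vector of parameters, plus a re-builder

loss(m) = sum(abs2, m(Float32[0.1, 0.2, 0.3]))      # illustrative scalar loss

grad_flat = ForwardDiff.gradient(v -> loss(rebuild(v)), flat)  # gradient w.r.t. the flat vector
```
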
@@ -316,9 +336,9 @@ Forward mode is a different algorithm...

* Like Tracker this passes a special TrackedArray type through your function. Allows you to record & compile the tape, and pre-allocate things.

-* Needs a flat vector
+* Like ForwardDiff it needs a flat vector, only `gradient(f, x::AbstractArray{<:Real})`.

-* No support for GPU
+* No support for GPU operations.


<hr/>
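
For the ReverseDiff bullets above, a minimal sketch (not part of this commit) of its record-and-compile workflow on a flat vector; the function below is illustrative:

```julia
using ReverseDiff

f(x) = sum(abs2, 2 .* x)                 # illustrative scalar function of a flat vector
x = rand(10)

tape  = ReverseDiff.GradientTape(f, x)   # record the operations once
ctape = ReverseDiff.compile(tape)        # compile the recorded tape for repeated use

dx = similar(x)                          # pre-allocated gradient buffer
ReverseDiff.gradient!(dx, ctape, x)      # fills dx in-place, reusing the compiled tape
```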
@@ -343,15 +363,15 @@ I haven't tried really, but I think it ought to work.

## Meta-packages

-Besides AD packages, several packages have been written aiming to provide a unified interface to many options. These may offer useful ways to quickly switch between things you are trying.
+Besides AD packages, several packages have been written aiming to provide a unified interface to many options. These may offer useful ways to quickly switch between things you are trying. However, Flux does not directly interface with any of them.

### [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl)

The original meta-package for calling any of several engines.

### [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl)

-This year's new attempt to build a simpler one?
+This year's new attempt to build a simpler such meta-package. However, from Flux's point of view

### [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)

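As a minimal sketch (not part of this commit) of the backend-switching idea behind these meta-packages, DifferentiationInterface.jl can call two of the engines discussed above through one function; the toy function below is illustrative:

```julia
using DifferentiationInterface
import ADTypes, ForwardDiff, Zygote

f(x) = sum(abs2, 2 .* x)
x = [1.0, 2.0, 3.0]

DifferentiationInterface.gradient(f, ADTypes.AutoForwardDiff(), x)  # [8.0, 16.0, 24.0]
DifferentiationInterface.gradient(f, ADTypes.AutoZygote(), x)       # same result, different engine
```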