Skip to content

Commit 8d4457b

Browse files
committed
Add documentation
1 parent a7b11a0 commit 8d4457b

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
[![Build Status](https://github.com/oschulz/ForwardDiffPullbacks.jl/workflows/CI/badge.svg?branch=master)](https://github.com/oschulz/ForwardDiffPullbacks.jl/actions?query=workflow%3ACI)
77
[![Codecov](https://codecov.io/gh/oschulz/ForwardDiffPullbacks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/oschulz/ForwardDiffPullbacks.jl)
88

9+
ForwardDiffPullbacks implements pullbacks compatible with
10+
[ChainRulesCore](https://github.com/JuliaDiff/ChainRulesCore.jl) that are calculated via
11+
[ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl).
12+
13+
See the documentation for details.
914

1015
## Documentation
1116

docs/src/index.md

+8
Original file line numberDiff line numberDiff line change
@@ -1 +1,9 @@
11
# ForwardDiffPullbacks.jl
2+
3+
ForwardDiffPullbacks implements pullbacks compatible with [ChainRulesCore](https://github.com/JuliaDiff/ChainRulesCore.jl) that are calculated via [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl).
4+
5+
This package provides the function [`fwddiff`](@ref). If wrapped around a function (i.e. `fwddiff(f)`), it will cause ChainRules (and implicitly Zygote) pullbacks to be calculated using ForwardDiff (i.e. by evaluating the original function with `ForwardDiff.Dual` numbers, possibly multiple times). The pullback will return a ChainRule thunk for each argument of the function.
6+
7+
So `Zygote.gradient(fwddiff(f), xs...)` should yield the same result as `Zygote.gradient(f, xs...)`, but will typically be substantially faster a function that has a comparatively small number of arguments, especially if the function runs a deep calculation. Broadcasting (i.e. `g.(fwddiff(f))`) is supported as well.
8+
9+
Currently, ForwardDiffPullbacks supports functions with `Real`, `Tuple` and `StaticArrays.SVector` arguments. Support for `StaticArrays.SArray` and `Array`-valued arguments in general is on the to-do list.

src/with_forwarddiff.jl

+29-1
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,41 @@ end
1010
# Desireable for consistent behavior?
1111
# Base.broadcasted(wrapped_f::WithForwardDiff, xs...) = broadcast(wrapped_f.f, xs...)
1212

13+
1314
"""
1415
fwddiff(f::Base.Callable)::Function
1516
16-
Use `ForwardDiff` in `ChainRulesCore` pullback For
17+
Use `ForwardDiff` dual numbers to implement `ChainRulesCore` pullbacks For
1718
1819
* `fwddiff(f)(args...)
1920
* `fwddiff(f).(args...)
21+
22+
Example:
23+
24+
```
25+
using ForwardDiffPullbacks, StaticArrays
26+
27+
f = (xs...) -> (sum(map(x -> sum(map(x -> x^2, x)), xs)))
28+
xs = (2, (3, 4), SVector(5, 6, 7))
29+
f(xs...) == 139
30+
31+
using ChainRulesCore
32+
33+
y, back = rrule(fwddiff(f), xs...)
34+
y == 139
35+
map(unthunk, back(1)) == (Zero(), 4, (6, 8), [10, 12, 14])
36+
37+
using Zygote
38+
39+
Zygote.gradient(fwddiff(f), xs...) == Zygote.gradient(f, xs...)
40+
41+
Xs = map(x -> fill(x, 100), xs)
42+
Zygote.gradient((Xs...) -> sum(fwddiff(f).(Xs...)), Xs...) ==
43+
Zygote.gradient((Xs...) -> sum(f.(Xs...)), Xs...)
44+
```
45+
46+
The gradient is the same with and without `fwddiff`, but `fwddiff` makes the
47+
gradient calculation a lot faster here.
2048
"""
2149
fwddiff(f::Base.Callable) = WithForwardDiff(f)
2250
export fwddiff

0 commit comments

Comments
 (0)