Skip to content

Commit 29fa32a

Browse files
authored
Merge pull request #1419 from mcabbott/withgrad3
Allow `f` to return a Tuple in `withgradient(f, args...)`
2 parents 2f49370 + e0d3d8b commit 29fa32a

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

src/compiler/interface.jl

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,34 @@ as a named tuple.
115115
julia> y, ∇ = withgradient(/, 1, 2)
116116
(val = 0.5, grad = (0.5, -0.25))
117117
118-
julia> ∇ == gradient(/, 1, 2) # explicit mode
118+
julia> ∇ == gradient(/, 1, 2)
119119
true
120+
```
121+
122+
Allows you to capture auxiliary outputs, in addition to the scalar
123+
used by `gradient`. To do this, `f` must return a Tuple or NamedTuple.
124+
Then it calculates `grad = gradient(first∘f, args...)`
125+
but returns the whole `val = f(args...)`:
126+
127+
```jldoctest; setup=:(using Zygote)
128+
julia> withgradient([1,2,4]) do x
129+
z = 1 ./ x
130+
sum(z), z # here z is an auxiliary output
131+
end
132+
(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))
133+
134+
julia> withgradient(3.0, 4.0) do x, y
135+
(div = x/y, mul = x*y)
136+
end
137+
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
138+
```
139+
140+
Also supports implicit mode:
120141
142+
```jldoctest; setup=:(using Zygote)
121143
julia> w = [3.0];
122144
123-
julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode
145+
julia> res = withgradient(() -> sum(abs2, w), Params([w]))
124146
(val = 9.0, grad = Grads(...))
125147
126148
julia> res.grad[w]
@@ -130,7 +152,15 @@ julia> res.grad[w]
130152
"""
131153
function withgradient(f, args...)
132154
y, back = pullback(f, args...)
133-
grad = back(sensitivity(y))
155+
grad = if y isa Tuple
156+
dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...)
157+
back(dy)
158+
elseif y isa NamedTuple
159+
dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...)
160+
back(NamedTuple{propertynames(y), typeof(dy)}(dy))
161+
else
162+
back(sensitivity(y))
163+
end
134164
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
135165
(val=y, grad=results)
136166
end

test/features.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,3 +866,21 @@ end
866866
end
867867
@test gradient(f760, 3)[1] ≈ 123.93054835019153
868868
end
869+
870+
@testset "withgradient" begin
871+
@test withgradient([1,2,4]) do x
872+
z = 1 ./ x
873+
sum(z), z
874+
end == (val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))
875+
876+
@test withgradient(3.0, 4.0) do x, y
877+
(div = x/y, mul = x*y)
878+
end == (val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
879+
880+
f3(x) = sum(sin, x), sum(cos, x), sum(tan, x)
881+
g1 = gradient(first∘f3, [1,2,3.0])
882+
y2, g2 = withgradient(first∘f3, [1,2,3.0])
883+
y3, g3 = withgradient(f3, [1,2,3.0])
884+
@test g1[1] ≈ g2[1] ≈ g3[1]
885+
end
886+

0 commit comments

Comments
 (0)