@@ -115,12 +115,34 @@ as a named tuple.
115115julia> y, ∇ = withgradient(/, 1, 2)
116116(val = 0.5, grad = (0.5, -0.25))
117117
118- julia> ∇ == gradient(/, 1, 2) # explicit mode
118+ julia> ∇ == gradient(/, 1, 2)
119119true
120+ ```
121+
122+ Allows you to capture auxiliary outputs, in addition to the scalar
123+ used by `gradient`. To do this, `f` must return a Tuple or NamedTuple.
124+ Then it calculates `grad = gradient(first∘f, args...)`
125+ but returns the whole `val = f(args...)`:
126+
127+ ```jldoctest; setup=:(using Zygote)
128+ julia> withgradient([1,2,4]) do x
129+ z = 1 ./ x
130+ sum(z), z # here z is an auxiliary output
131+ end
132+ (val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))
133+
134+ julia> withgradient(3.0, 4.0) do x, y
135+ (div = x/y, mul = x*y)
136+ end
137+ (val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
138+ ```
139+
140+ Also supports implicit mode:
120141
142+ ```jldoctest; setup=:(using Zygote)
121143julia> w = [3.0];
122144
123- julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode
145+ julia> res = withgradient(() -> sum(abs2, w), Params([w]))
124146(val = 9.0, grad = Grads(...))
125147
126148julia> res.grad[w]
@@ -130,7 +152,15 @@ julia> res.grad[w]
130152"""
function withgradient(f, args...)
  y, back = pullback(f, args...)
  # When `f` returns a Tuple or NamedTuple, only the first element is
  # differentiated; every remaining entry is an auxiliary output whose
  # cotangent is `nothing`.
  grad = if y isa Tuple
    dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...)
    back(dy)
  elseif y isa NamedTuple
    dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...)
    back(NamedTuple{propertynames(y), typeof(dy)}(dy))
  else
    back(sensitivity(y))
  end
  # A `nothing` from the pullback means no argument received a gradient.
  results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
  (val = y, grad = results)
end
0 commit comments