diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7c7de8655..1d17ea91d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -344,4 +344,8 @@ z2d(dx::NamedTuple{L,S}, primal::AbstractDict) where {L,S<:Tuple{Vararg{Union{Nu end end +function z2d(dx::AbstractDict, primal::T) where T<:Dict + return Tangent{T}(Dict(kk => z2d(dvv, primal[kk]) for (kk, dvv) in dx)) +end + z2d(dx::Ref, primal) = z2d(dx[], primal) # mutable structs