We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4fb6daa commit 6d983d5Copy full SHA for 6d983d5
src/compiler/chainrules.jl
@@ -152,6 +152,7 @@ Convert `dx` from the format Zygote uses internally to differentials types Chain
152
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
153
@inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent()
154
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
155
+@inline wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero} = first(dxs)
156
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
157
xp = map(wrap_chainrules_input, dxs)
158
# This produces Tangent{Any} since it does not get to see the primal, `x`.
0 commit comments