Skip to content

Commit 6d983d5

Browse files
committed
Collapse arrays of CR zeros
1 parent 4fb6daa commit 6d983d5

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

src/compiler/chainrules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ Convert `dx` from the format Zygote uses internally to differentials types Chain
152152
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
153153
@inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent()
154154
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
155+
@inline wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero} = first(dxs)
155156
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
156157
xp = map(wrap_chainrules_input, dxs)
157158
# This produces Tangent{Any} since it does not get to see the primal, `x`.

0 commit comments

Comments
 (0)