Skip to content

Commit 0599968

Browse files
committed
fixup
1 parent df4019d commit 0599968

File tree

5 files changed

+14
-10
lines changed

5 files changed

+14
-10
lines changed

src/deprecations.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,9 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,
3636

3737

3838
# v0.13 deprecations
39+
function Broadcast.broadcasted(f::Recur, args...)
40+
# This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
41+
Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
42+
Re-writing this as a comprehension would be better.""", :broadcasted)
43+
map(f, args...) # map isn't really safe either, but it keeps existing broadcast calls working while deprecated
44+
end

src/layers/recurrent.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
435435
"""
436436
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
437437
Recur(m::GRUv3Cell) = Recur(m, m.state0)
438-
439-
# TODO move to ChainRulesCore?
440-
@adjoint function Broadcast.broadcasted(f::Recur, args...)
441-
Zygote.∇map(__context__, f, args...)
442-
end

src/losses/ctc.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,9 @@ for mathematical details.
134134
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss
135135

136136
function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
137-
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, out), NoTangent())
138-
return ctc_loss(ŷ, y), ctc_loss_pullback
137+
tmp = ctc_alpha(ŷ, y)
138+
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
139+
return tmp.loss, ctc_loss_pullback
139140
end
140141

141142

src/losses/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
2424
end
2525

26-
ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # is this good enough?
26+
ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting
2727
# Scalar derivative rule for `xlogx` (presumably x * log(x), mirroring the `xlogy` rule):
# d/dx = log(x) + 1. The partial must be written in terms of the input `x`; the name `y`
# is not bound in a one-argument rule, so the previous `log(y)` would fail when the
# rule is actually evaluated.
ChainRulesCore.@scalar_rule xlogx(x) (log(x) + true)
2828

2929
# This can be made an error in Flux v0.13, for now just a warning

src/utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -793,8 +793,10 @@ true
793793
"""
794794
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
795795

796-
@nograd modules
797-
ChainRulesCore.@non_differentiable modules(::Any) # is this correct?
796+
@nograd modules # TODO: is this correct? might fail with explicit parameters.
797+
function ChainRulesCore.rrule(::typeof(modules), m)
798+
modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
799+
end
798800

799801
isleaflike(x) = Functors.isleaf(x)
800802
isleaflike(::Tuple{Vararg{<:Number}}) = true

0 commit comments

Comments
 (0)