Skip to content

Commit 923eca0

Browse files
committed
fixup
1 parent 5e8009c commit 923eca0

File tree

5 files changed

+16
-11
lines changed

5 files changed

+16
-11
lines changed

src/deprecations.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,11 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,
3636

3737

3838
# v0.13 deprecations
39-
@deprecate Maxout(layers::Tuple) Maxout(layers...)
39+
@deprecate Maxout(layers::Tuple) Maxout(layers...)
40+
41+
function Broadcast.broadcasted(f::Recur, args...)
42+
# This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
43+
Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
44+
Re-writing this as a comprehension would be better.""", :broadcasted)
45+
map(f, args...) # map isn't really safe either, but
46+
end

src/layers/recurrent.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
434434
"""
435435
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
436436
Recur(m::GRUv3Cell) = Recur(m, m.state0)
437-
438-
# TODO move to ChainRulesCore?
439-
@adjoint function Broadcast.broadcasted(f::Recur, args...)
440-
Zygote.∇map(__context__, f, args...)
441-
end

src/losses/ctc.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,9 @@ for mathematical details.
134134
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss
135135

136136
function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
137-
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, out), NoTangent())
138-
return ctc_loss(ŷ, y), ctc_loss_pullback
137+
tmp = ctc_alpha(ŷ, y)
138+
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
139+
return tmp.loss, ctc_loss_pullback
139140
end
140141

141142

src/losses/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
2424
end
2525

26-
ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # is this good enough?
26+
ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting
2727
ChainRulesCore.@scalar_rule xlogx(x) (log(y) + true)
2828

2929
# This can be made an error in Flux v0.13, for now just a warning

src/utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -788,8 +788,10 @@ L2 (generic function with 1 method)
788788
"""
789789
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
790790

791-
@nograd modules
792-
ChainRulesCore.@non_differentiable modules(::Any) # is this correct?
791+
@nograd modules # TODO: is this correct? might fail with explicit parameters.
792+
function ChainRulesCore.rrule(::typeof(modules), m)
793+
modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
794+
end
793795

794796
isleaflike(x) = Functors.isleaf(x)
795797
isleaflike(::Tuple{Vararg{<:Number}}) = true

0 commit comments

Comments
 (0)