
Commit 525b645

mcabbott and ToucheSir authored
Replace @adjoint with rrule (#1863)
* replace at-adjoint with rrule
* fixup
* onecold was missing
* rm comment

Co-authored-by: Brian Chen <[email protected]>
1 parent 57ef5c0 commit 525b645

14 files changed: +50 −39 lines

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,7 @@ version = "0.13.0-DEV"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -26,6 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 Adapt = "3.0"
 ArrayInterface = "3.1, 4"
 CUDA = "3"
+ChainRulesCore = "1.12"
 Functors = "0.2.1"
 MLUtils = "0.1.4"
 MacroTools = "0.5"
@@ -35,7 +37,7 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
-Zygote = "0.6"
+Zygote = "0.6.34"
 julia = "1.6"

 [extras]

src/Flux.jl

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ using MLUtils

 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
+using ChainRulesCore

 export Chain, Dense, Maxout, SkipConnection, Parallel,
   RNN, LSTM, GRU, GRUv3,

src/cuda/cuda.jl

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@ module CUDAint
 using ..CUDA

 import ..Flux: Flux
-import Zygote
-using Zygote: @adjoint
+using ChainRulesCore
 import NNlib, NNlibCUDA

 include("cudnn.jl")

src/cuda/cudnn.jl

Lines changed: 3 additions & 2 deletions
@@ -11,10 +11,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
                          training=Flux._isactive(BN)))
 end

-@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
   y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
   function batchnorm_pullback(Δ)
-    ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
+    grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
+    (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
   end
   y, batchnorm_pullback
 end
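This hunk shows the pattern the whole commit repeats: where `@adjoint f(args...)` returned `(y, Δ -> grads_for_args)`, `ChainRulesCore.rrule(::typeof(f), args...)` returns the same pair, except the pullback's tuple gains a leading `NoTangent()` for the function object itself, and `nothing` for non-differentiable arguments becomes `NoTangent()`. A minimal sketch of that convention with a made-up function (not part of Flux), which an rrule-aware AD such as Zygote picks up automatically:

using ChainRulesCore, Zygote

myscale(a, x) = a .* x   # hypothetical example function

function ChainRulesCore.rrule(::typeof(myscale), a, x)
  y = myscale(a, x)
  function myscale_pullback(ȳ)
    # first slot: tangent for `myscale` itself, then one slot per argument
    (NoTangent(), sum(ȳ .* x), ȳ .* a)
  end
  return y, myscale_pullback
end

Zygote.gradient((a, x) -> sum(myscale(a, x)), 2.0, [1.0, 2.0])  # (3.0, [2.0, 2.0])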

src/deprecations.jl

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,13 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,

 # v0.13 deprecations

+function Broadcast.broadcasted(f::Recur, args...)
+  # This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
+  Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
+    Re-writing this as a comprehension would be better.""", :broadcasted)
+  map(f, args...)  # map isn't really safe either, but
+end
+
 @deprecate frequencies(xs) group_counts(xs)

 # Channel notation: Changed to match Conv, but very softly deprecated!
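The warning's point is that a `Recur` layer mutates its hidden state on every call, so broadcasting, which promises no iteration order, is unsafe. A rough illustration of the suggested rewrite, assuming an already-constructed recurrent layer and a vector of time steps:

using Flux

m  = RNN(2, 3)                           # any Recur-wrapped cell behaves the same
xs = [rand(Float32, 2) for _ in 1:5]

ys = m.(xs)                  # now hits the depwarn above; order of state updates not guaranteed
ys = [m(x) for x in xs]      # explicit left-to-right iteration, the recommended spelling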

src/functor.jl

Lines changed: 5 additions & 5 deletions
@@ -122,12 +122,12 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
 adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x

-Zygote.@adjoint function Array(x::CUDA.CuArray)
-  Array(x), d -> (CUDA.cu(d),)
+function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray)
+  Array(x), d -> (NoTangent(), CUDA.cu(d),)
 end

-Zygote.@adjoint function Adapt.adapt_storage(to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d),)
+function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
+  adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
 end

 # CPU/GPU movement conveniences
@@ -204,7 +204,7 @@ function check_use_cuda()
     end
   end
 end
-Zygote.@nograd check_use_cuda
+ChainRulesCore.@non_differentiable check_use_cuda()

 # Precision
src/layers/conv.jl

Lines changed: 1 addition & 2 deletions
@@ -273,8 +273,7 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   )
 end

-# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
-@nograd conv_transpose_dims
+ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

 function (c::ConvTranspose)(x::AbstractArray)
   b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
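Unlike `@nograd`, which took bare function names, `ChainRulesCore.@non_differentiable` takes a call signature, so the argument types (`::Any`, `::Any...`) say which methods the rule covers. A small sketch of the effect, using a hypothetical stand-in rather than Flux's internal helpers: the call still runs in the forward pass, but contributes no gradient.

using ChainRulesCore, Zygote

_use_fast_path() = true    # hypothetical stand-in for helpers like check_use_cuda / conv_transpose_dims
ChainRulesCore.@non_differentiable _use_fast_path()

# Differentiation passes straight through the non-differentiable call:
Zygote.gradient(x -> _use_fast_path() ? 2x : 3x, 1.0)   # (2.0,)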

src/layers/normalise.jl

Lines changed: 4 additions & 9 deletions
@@ -1,6 +1,6 @@
 istraining() = false

-@adjoint istraining() = true, _ -> nothing
+ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

 _isactive(m) = isnothing(m.active) ? istraining() : m.active

@@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
 end
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

-@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x, Δ -> (Δ, nothing)
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y, Δ -> (nothing, Δ .* y, nothing)
-end
-
 dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 dropout_mask(rng, x::CuArray, p; kwargs...) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
@@ -56,7 +50,7 @@ function _dropout_mask(rng, x, p; dims=:)
 end

 # TODO move this to NNlib
-Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

 """
     Dropout(p; dims=:, rng = rng_from_array())
@@ -234,7 +228,8 @@ function _track_stats!(
   bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
   return nothing
 end
-Zygote.@nograd _track_stats!
+
+ChainRulesCore.@non_differentiable _track_stats!(::Any...)

 """
     BatchNorm(channels::Integer, λ=identity;
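`istraining` is the trick that lets normalisation layers auto-detect AD: the plain function returns `false`, but its rrule reports `true` as the primal, so `_isactive` sees `true` only while a gradient is being taken. A minimal reproduction of the idea with hypothetical names (not Flux's exact code, just the same two lines applied to a stand-in):

using ChainRulesCore, Zygote

_training() = false
ChainRulesCore.rrule(::typeof(_training)) = true, _ -> (NoTangent(),)

_training()                                        # false in ordinary code
Zygote.gradient(x -> _training() ? x^2 : x, 3.0)   # (6.0,) -- the rrule's `true` is seen under AD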

src/layers/recurrent.jl

Lines changed: 6 additions & 10 deletions
@@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
 # AD-friendly helper for dividing monolithic RNN params into equally sized gates
 multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)

-@adjoint function multigate(x::AbstractArray, h, c)
+function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
   function multigate_pullback(dy)
-    dx = Zygote._zero(x, eltype(x))
-    map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
-      dyᵢ !== nothing && (dxᵢ .= Zygote.accum.(dxᵢ, dyᵢ));
+    dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
+    foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
+      dyᵢ isa AbstractZero && return
+      @. dxᵢ += dyᵢ
     end
-    return (dx, nothing, nothing)
+    return (NoTangent(), dx, NoTangent(), NoTangent())
   end
   return multigate(x, h, c), multigate_pullback
 end
@@ -379,8 +380,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
 """
 GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
 Recur(m::GRUv3Cell) = Recur(m, m.state0)
-
-
-@adjoint function Broadcast.broadcasted(f::Recur, args...)
-  Zygote.∇map(__context__, f, args...)
-end
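In the new `multigate` pullback, each gate's cotangent is written through the views returned by `multigate(dx, h, c)` into one dense array, and gates whose cotangent is an `AbstractZero` are skipped. A rough usage check of that behaviour (assuming this patched Flux), where only two of the three gates are used:

using Flux, Zygote

x = Float32.(1:6)
g = Zygote.gradient(x) do x
  a, b, c = Flux.multigate(x, 2, Val(3))   # three length-2 views of x
  sum(a) + 2 * sum(c)                      # gate `b` is unused
end
# g[1] ≈ Float32[1, 1, 0, 0, 2, 2]  -- per-gate cotangents accumulated into one dense array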

src/losses/Losses.jl

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ module Losses
 using Statistics
 using Zygote
 using Zygote: @adjoint
+using ChainRulesCore
 using ..Flux: ofeltype, epseltype
 using CUDA
 using NNlib: logsoftmax, logσ

src/losses/ctc.jl

Lines changed: 4 additions & 4 deletions
@@ -133,10 +133,10 @@ for mathematical details.
 """
 ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

-@adjoint function ctc_loss(ŷ, y)
-  out = ctc_alpha(ŷ, y)
-  ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing)
-  return out.loss, ctc_loss_pullback
+function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
+  tmp = ctc_alpha(ŷ, y)
+  ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
+  return tmp.loss, ctc_loss_pullback
 end
src/losses/utils.jl

Lines changed: 4 additions & 1 deletion
@@ -23,6 +23,9 @@ end
   res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
 end

+ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y)  # should help Diffractor's broadcasting
+ChainRulesCore.@scalar_rule xlogx(x) (log(x) + true)
+
 # This can be made an error in Flux v0.13, for now just a warning
 function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
   for d in 1:max(ndims(ŷ), ndims(y))
@@ -33,4 +36,4 @@ function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
 end
 _check_sizes(ŷ, y) = nothing  # pass-through, for constant label e.g. y = 1

-Zygote.@nograd _check_sizes
+ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
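`@scalar_rule` declares the partial derivatives of a scalar function and generates both the forward and reverse rules from them, which is lighter than a hand-written pullback and, per the comment, friendlier to Diffractor-style broadcasting. A quick sketch of the same macro applied to a stand-in definition of `xlogy`, checked against Zygote:

using ChainRulesCore, Zygote

_xlogy(x, y) = x * log(y)                                # stand-in; Flux's xlogy also handles x == 0
ChainRulesCore.@scalar_rule _xlogy(x, y) (log(y), x/y)   # partials ∂/∂x and ∂/∂y

Zygote.gradient(_xlogy, 2.0, 3.0)                        # ≈ (log(3.0), 2/3)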

src/onehot.jl

Lines changed: 5 additions & 1 deletion
@@ -230,7 +230,11 @@ function _fast_argmax(x::OneHotLike)
   end
 end

-@nograd OneHotArray, onecold, onehot, onehotbatch
+ChainRulesCore.@non_differentiable onehot(::Any...)
+ChainRulesCore.@non_differentiable onehotbatch(::Any...)
+ChainRulesCore.@non_differentiable onecold(::Any...)
+
+ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)

 function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
   _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
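Since one-hot encoding and decoding carry no meaningful gradient, these rules let `onehot`, `onehotbatch`, `onecold` and the `OneHotArray` constructor appear inside differentiated code as constants. A rough usage sketch (assuming this patched Flux):

using Flux, Zygote

y = Flux.onehotbatch([1, 3, 2], 1:3)      # 3×3 one-hot matrix, treated as a constant under AD
W = rand(Float32, 2, 3)

g = Zygote.gradient(W -> sum(W * y), W)   # gradient flows through the product,
                                          # not into the encoding itself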

src/utils.jl

Lines changed: 5 additions & 2 deletions
@@ -472,7 +472,7 @@ function _restructure(m, xs)
   return m̄
 end

-@adjoint function _restructure(m, xs)
+@adjoint function _restructure(m, xs)  # TODO ChainRulesCore.rrule
   m̄, numel = _restructure(m, xs), length(xs)
   function _restructure_pullback(dm)
     xs′ = destructure(dm)[1]
@@ -603,7 +603,10 @@ true
 """
 modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]

-@nograd modules
+@nograd modules  # TODO: is this correct? might fail with explicit parameters.
+function ChainRulesCore.rrule(::typeof(modules), m)
+  modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
+end

 isleaflike(x) = Functors.isleaf(x)
 isleaflike(::Tuple{Vararg{<:Number}}) = true
