Commit 5e8009c
replace at-adjoint with rrule
1 parent cce7ad0

13 files changed, +43 -33 lines changed

Project.toml

Lines changed: 3 additions & 1 deletion

@@ -7,6 +7,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -33,6 +34,7 @@ AbstractTrees = "0.3"
 Adapt = "3.0"
 ArrayInterface = "3.1, 4"
 CUDA = "3"
+ChainRulesCore = "1.12"
 CodecZlib = "0.7"
 Colors = "0.12"
 Functors = "0.2.1"
@@ -43,7 +45,7 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 StatsBase = "0.33"
 ZipFile = "0.9"
-Zygote = "0.6"
+Zygote = "0.6.34"
 julia = "1.6"

 [extras]

src/Flux.jl

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ using MacroTools: @forward
 @reexport using NNlib
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
+using ChainRulesCore
 
 export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
        RNN, LSTM, GRU, GRUv3,

src/cuda/cuda.jl

Lines changed: 3 additions & 2 deletions

@@ -3,8 +3,9 @@ module CUDAint
 using ..CUDA
 
 import ..Flux: Flux
-import Zygote
-using Zygote: @adjoint
+# import Zygote
+# using Zygote: @adjoint
+using ChainRulesCore
 import NNlib, NNlibCUDA
 
 include("cudnn.jl")

src/cuda/cudnn.jl

Lines changed: 3 additions & 2 deletions

@@ -11,10 +11,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
             training=Flux._isactive(BN)))
 end
 
-@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
   y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
   function batchnorm_pullback(Δ)
-    ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
+    grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
+    (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
   end
   y, batchnorm_pullback
 end

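A note on the pattern used throughout this commit: a ChainRulesCore.rrule returns the primal value plus a pullback, and the pullback returns one (co)tangent per argument of the primal call, preceded by NoTangent() for the function object itself. That is why the batchnorm pullback above gains a leading NoTangent() and uses NoTangent() rather than nothing for its non-differentiable trailing arguments. A minimal, self-contained sketch of the convention, using a toy function rather than anything from this commit:

using ChainRulesCore, Zygote

# Toy function: y = 2x, so dy/dx = 2.
double(x) = 2 .* x

function ChainRulesCore.rrule(::typeof(double), x)
    y = double(x)
    # (tangent for the function object, tangent for x)
    double_pullback(ȳ) = (NoTangent(), 2 .* ȳ)
    return y, double_pullback
end

Zygote.gradient(x -> sum(double(x)), [1.0, 2.0])  # ([2.0, 2.0],)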
src/functor.jl

Lines changed: 5 additions & 5 deletions

@@ -120,12 +120,12 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
 adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
-Zygote.@adjoint function Array(x::CUDA.CuArray)
-  Array(x), d -> (CUDA.cu(d),)
+function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray)
+  Array(x), d -> (NoTangent(), CUDA.cu(d),)
 end
 
-Zygote.@adjoint function Adapt.adapt_storage(to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d),)
+function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
+  adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
 end
 
 # CPU/GPU movement conveniences
@@ -202,7 +202,7 @@ function check_use_cuda()
     end
   end
 end
-Zygote.@nograd check_use_cuda
+ChainRulesCore.@non_differentiable check_use_cuda()
 
 # Precision

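Zygote.@nograd took bare function names, while ChainRulesCore.@non_differentiable takes a call signature such as check_use_cuda() and marks every matching call as having no derivative: AD still executes it but propagates no gradient through it. A standalone sketch with a toy helper that is not part of Flux:

using ChainRulesCore, Zygote

# Toy logging helper; we only want AD to ignore it.
log_size(x) = println("got array of size ", size(x))
ChainRulesCore.@non_differentiable log_size(::Any)

Zygote.gradient([1.0, 2.0]) do x
    log_size(x)        # runs, but contributes no gradient
    sum(abs2, x)
end                     # ([2.0, 4.0],)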
src/layers/conv.jl

Lines changed: 1 addition & 2 deletions

@@ -274,8 +274,7 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   )
 end
 
-# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
-@nograd conv_transpose_dims
+ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
 
 function (c::ConvTranspose)(x::AbstractArray)
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)

src/layers/normalise.jl

Lines changed: 5 additions & 8 deletions

@@ -1,6 +1,6 @@
 istraining() = false
 
-@adjoint istraining() = true, _ -> nothing
+ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
 _isactive(m) = isnothing(m.active) ? istraining() : m.active
 
@@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
 end
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x, Δ -> (Δ, nothing)
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y, Δ -> (nothing, Δ .* y, nothing)
-end
-
 dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 dropout_mask(rng, x::CuArray, p; kwargs...) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
@@ -54,6 +48,8 @@ function _dropout_mask(rng, x, p; dims=:)
   return y
 end
 
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
+
 """
     Dropout(p; dims=:, rng = rng_from_array())
 
@@ -232,7 +228,8 @@ function _track_stats!(
   bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
   return nothing
 end
-Zygote.@nograd _track_stats!
+
+ChainRulesCore.@non_differentiable _track_stats!(::Any...)
 
 """
     BatchNorm(channels::Integer, λ=identity;

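Deleting the hand-written dropout adjoint is possible because dropout_mask is now marked @non_differentiable. With the mask treated as a constant, Zygote differentiates x .* mask on its own, and the gradient with respect to x is simply the mask, which is exactly what the removed adjoint returned. A small sketch of that reasoning with a fixed stand-in mask:

using Zygote

mask = [1.0, 0.0, 2.0]   # stands in for a fixed dropout mask
Zygote.gradient(x -> sum(x .* mask), [5.0, 6.0, 7.0])
# ([1.0, 0.0, 2.0],)  the mask itself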
src/layers/recurrent.jl

Lines changed: 7 additions & 6 deletions

@@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
 # AD-friendly helper for dividing monolithic RNN params into equally sized gates
 multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
 
-@adjoint function multigate(x::AbstractArray, h, c)
+function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
   function multigate_pullback(dy)
-    dx = Zygote._zero(x, eltype(x))
-    map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
-      dyᵢ !== nothing && (dxᵢ .= Zygote.accum.(dxᵢ, dyᵢ));
+    dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
+    foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
+      dyᵢ isa AbstractZero && return
+      @. dxᵢ += dyᵢ
     end
-    return (dx, nothing, nothing)
+    return (NoTangent(), dx, NoTangent(), NoTangent())
   end
   return multigate(x, h, c), multigate_pullback
 end
@@ -434,7 +435,7 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
 GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
 Recur(m::GRUv3Cell) = Recur(m, m.state0)
 
-
+# TODO move to ChainRulesCore?
 @adjoint function Broadcast.broadcasted(f::Recur, args...)
   Zygote.∇map(__context__, f, args...)
 end

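The new multigate pullback allocates a dense zero tangent for x and accumulates each gate's cotangent into the matching view, skipping gates whose cotangent is an AbstractZero (ChainRulesCore's representation of a zero gradient), instead of relying on the Zygote internals _zero and accum. A quick sanity check is to differentiate through a single gate; this sketch assumes it runs against this version of Flux, where multigate is an internal, unexported helper:

using Flux, Zygote

x = collect(1.0:6.0)
# Sum of squares of the second of three gates (elements 3:4);
# only those positions should receive a gradient.
Zygote.gradient(x) do x
    sum(abs2, Flux.multigate(x, 2, Val(3))[2])
end
# ([0.0, 0.0, 6.0, 8.0, 0.0, 0.0],)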
src/losses/Losses.jl

Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@ module Losses
 using Statistics
 using Zygote
 using Zygote: @adjoint
+using ChainRulesCore
 using ..Flux: ofeltype, epseltype
 using CUDA
 using NNlib: logsoftmax, logσ

src/losses/ctc.jl

Lines changed: 3 additions & 4 deletions

@@ -133,10 +133,9 @@ for mathematical details.
 """
 ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss
 
-@adjoint function ctc_loss(ŷ, y)
-  out = ctc_alpha(ŷ, y)
-  ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing)
-  return out.loss, ctc_loss_pullback
+function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
+  ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, out), NoTangent())
+  return ctc_loss(ŷ, y), ctc_loss_pullback
 end
 

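Hand-written rrules like the ones in this commit can be checked against finite differences with ChainRulesTestUtils.test_rrule (a separate test-only dependency, not added by this commit). A minimal sketch on a toy function rather than ctc_loss, whose rule depends on the ctc_alpha forward pass:

using ChainRulesCore, ChainRulesTestUtils

# Toy function with a hand-written rrule, following the same pattern as above.
sumsq(x) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(sumsq), x)
    sumsq_pullback(Δ) = (NoTangent(), 2 .* x .* Δ)
    return sumsq(x), sumsq_pullback
end

test_rrule(sumsq, rand(3))   # compares the pullback against finite differences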
src/losses/utils.jl

Lines changed: 4 additions & 1 deletion

@@ -23,6 +23,9 @@ end
   res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
 end
 
+ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y)  # is this good enough?
+ChainRulesCore.@scalar_rule xlogx(x) (log(x) + true)
+
 # This can be made an error in Flux v0.13, for now just a warning
 function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
   for d in 1:max(ndims(ŷ), ndims(y))
@@ -33,4 +36,4 @@ function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
 end
 _check_sizes(ŷ, y) = nothing  # pass-through, for constant label e.g. y = 1
 
-Zygote.@nograd _check_sizes
+ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)

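@scalar_rule derives both forward and reverse rules for a scalar function from its partial derivatives. For xlogy(x, y) = x*log(y) the partials are (log(y), x/y); for xlogx(x) = x*log(x) the partial is log(x) + 1, written above as (log(x) + true). A self-contained sketch with a local toy definition (Flux's actual xlogx also special-cases x == 0):

using ChainRulesCore, Zygote

myxlogx(x) = x * log(x)   # local toy, not Flux.Losses.xlogx
ChainRulesCore.@scalar_rule myxlogx(x) (log(x) + true)

Zygote.gradient(myxlogx, 2.0)   # (log(2.0) + 1 ≈ 1.693,)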
src/onehot.jl

Lines changed: 5 additions & 1 deletion

@@ -230,7 +230,11 @@ function _fast_argmax(x::OneHotLike)
   end
 end
 
-@nograd OneHotArray, onecold, onehot, onehotbatch
+ChainRulesCore.@non_differentiable onehot(::Any, ::Any)
+ChainRulesCore.@non_differentiable onehot(::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable onehotbatch(::Any, ::Any)
+ChainRulesCore.@non_differentiable onehotbatch(::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)
 
 function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
   _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)

src/utils.jl

Lines changed: 2 additions & 1 deletion

@@ -662,7 +662,7 @@ function _restructure(m, xs)
   return m̄
 end
 
-@adjoint function _restructure(m, xs)
+@adjoint function _restructure(m, xs) # TODO ChainRulesCore.rrule
   m̄, numel = _restructure(m, xs), length(xs)
   function _restructure_pullback(dm)
     xs′ = destructure(dm)[1]
@@ -789,6 +789,7 @@ L2 (generic function with 1 method)
 modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
 
 @nograd modules
+ChainRulesCore.@non_differentiable modules(::Any) # is this correct?
 
 isleaflike(x) = Functors.isleaf(x)
 isleaflike(::Tuple{Vararg{<:Number}}) = true
