Replace @adjoint with rrule #1863
@@ -1,6 +1,6 @@
 istraining() = false

-@adjoint istraining() = true, _ -> nothing
+ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
Review comment: I'm surprised there isn't an equivalent for this in ChainRules already.

Review comment: Somewhere I was writing a function like
 _isactive(m) = isnothing(m.active) ? istraining() : m.active
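For readers less familiar with ChainRules, here is a small sketch (not part of the diff, reusing the names above) of what the new rule does when an rrule-consuming AD such as Zygote or Diffractor reaches istraining:

using ChainRulesCore

istraining() = false
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

# Within a differentiated context the rule is used instead of the function,
# so istraining() evaluates to true there; the pullback returns only the
# tangent for the function itself, NoTangent(), since there are no arguments.
y, back = ChainRulesCore.rrule(istraining)
@assert y === true
@assert back(1.0) == (NoTangent(),)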
@@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
 end
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
-
-@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x, Δ -> (Δ, nothing)
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y, Δ -> (nothing, Δ .* y, nothing)
-end
mcabbott marked this conversation as resolved.
 dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 dropout_mask(rng, x::CuArray, p; kwargs...) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
@@ -56,7 +50,7 @@ function _dropout_mask(rng, x, p; dims=:)
 end

 # TODO move this to NNlib
-Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
""" | ||
Dropout(p; dims=:, rng = rng_from_array()) | ||
|
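As a rough sketch of why the explicit dropout adjoint above could be deleted (using a simplified, hypothetical _mask rather than Flux's dropout_mask): once the mask function is marked @non_differentiable, AD treats the mask as a constant, and differentiating x .* mask reproduces the old pullback Δ .* y on its own.

using ChainRulesCore, Zygote

_mask(x, p) = (x .> p) ./ (1 - p)          # hypothetical stand-in for dropout_mask
ChainRulesCore.@non_differentiable _mask(::Any, ::Any)

x = [0.1, 0.9, 0.4]
g = Zygote.gradient(x -> sum(x .* _mask(x, 0.5)), x)[1]
# g == _mask(x, 0.5): the mask contributes no gradient of its own,
# so only the elementwise product is differentiated.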
@@ -234,7 +228,8 @@ function _track_stats!(
   bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
   return nothing
 end
-Zygote.@nograd _track_stats!
+
+ChainRulesCore.@non_differentiable _track_stats!(::Any...)
""" | ||
BatchNorm(channels::Integer, λ=identity; | ||
|
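A minimal illustration of the Vararg signature used for _track_stats! (the Stats type and update! below are made up for the example, not Flux code): @non_differentiable f(::Any...) covers every method, so the mutating update still runs in the forward pass but is invisible to the gradient, just as @nograd was.

using ChainRulesCore, Zygote

mutable struct Stats; μ::Float64; end
function update!(s::Stats, xs...)          # hypothetical stand-in for _track_stats!
    s.μ = sum(xs) / length(xs)
    return nothing
end
ChainRulesCore.@non_differentiable update!(::Any...)

s = Stats(0.0)
g = Zygote.gradient(x -> (update!(s, x, 2x); 3x), 1.0)[1]
# g == 3.0: the update mutated s (s.μ == 1.5) but added nothing to the gradient.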
@@ -23,6 +23,9 @@ end
   res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
 end
+
+ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y)  # should help Diffractor's broadcasting
+ChainRulesCore.@scalar_rule xlogx(x) (log(x) + true)
Review comment: Can't literally translate the broadcasted rule above; I hope that Diffractor's broadcasting will work via these scalar rules.

Review comment: Are these needed if https://github.com/JuliaStats/LogExpFunctions.jl/blob/c8a4c28ffe7b6e4f8d5253e01cef091bb8d2f42c/src/chainrules.jl#L1-L2 are already loaded through a transitive dep?

Review comment: Flux could switch to those. It has branches not ifelse, and different NaN behaviour; not sure if that matters. And 5 dependencies.

Review comment: But for now perhaps it's evidence that the scalar rules are ok?

Review comment: Are you looking to do some testing soon with this and Diffractor/not Zygote? Otherwise I think it would be cleaner to have a separate PR that removes all of the code above in favour of https://github.com/FluxML/Zygote.jl/blob/master/src/lib/logexpfunctions.jl and the

Review comment: I can remove these rules for now if you prefer. The functions ought to be differentiable without special rules, mostly. The PR just wants to translate as many things as possible over for now.

Review comment: I said: this is wrong, because https://github.com/FluxML/Flux.jl/blob/master/src/losses/utils.jl#L27. While I guess these broadcasts aren't so performance-sensitive (since there will only be one, for the whole model), it would be nice if all loss functions were second-differentiable. Whether that already works, or needs to be done by fiddling with broadcasting, or rules for the loss functions themselves, I don't know.
 # This can be made an error in Flux v0.13, for now just a warning
 function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
   for d in 1:max(ndims(ŷ), ndims(y))
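One way to back up the discussion above about whether the scalar rules are ok, independent of Zygote or Diffractor, is to check them against finite differences with ChainRulesTestUtils. This sketch uses a simplified _xlogy that ignores the x == 0 branch of Flux's real xlogy, so it only illustrates the mechanics:

using ChainRulesCore, ChainRulesTestUtils

_xlogy(x, y) = x * log(y)                  # simplified stand-in, no x == 0 handling
ChainRulesCore.@scalar_rule _xlogy(x, y) (log(y), x / y)

test_frule(_xlogy, 0.3, 2.0)               # forward-mode rule vs finite differences
test_rrule(_xlogy, 0.3, 2.0)               # reverse-mode rule vs finite differences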
@@ -33,4 +36,4 @@ function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
 end
 _check_sizes(ŷ, y) = nothing  # pass-through, for constant label e.g. y = 1

-Zygote.@nograd _check_sizes
+ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
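And a last sanity-check sketch for the @nograd to @non_differentiable swap (with a hypothetical _check helper rather than Flux's _check_sizes): because Zygote consumes ChainRules rules, the explicitly-typed @non_differentiable signature behaves the same way the old @nograd did.

using ChainRulesCore, Zygote

_check(a, b) = size(a) == size(b) || @warn "size mismatch"   # hypothetical helper
ChainRulesCore.@non_differentiable _check(::Any, ::Any)

g = Zygote.gradient(x -> (_check(x, x); sum(abs2, x)), [1.0, 2.0])[1]
# g == [2.0, 4.0]: the size check runs but never shows up in the gradient.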