
Commit 3cc9067

Merge pull request #1870 from FluxML/bc/nondiff_dropout_mask
Mark dropout_mask as non-differentiable
2 parents 0b7e1b6 + 5cf7d2f commit 3cc9067

5 files changed, +16 -8 lines changed

NEWS.md

Lines changed: 2 additions & 2 deletions

@@ -3,10 +3,10 @@
 ## v0.13
 * After a deprecations cycle, the datasets in `Flux.Data` have
   been removed in favour of MLDatasets.jl.
-* `params` is not exported anymore since it is a common name and is also exported by Distributions.jl 
+* `params` is not exported anymore since it is a common name and is also exported by Distributions.jl
 * `flatten` is not exported anymore due to clash with Iterators.flatten.
 * Remove Juno.jl progress bar support as it is now obsolete.
-* Improved compatibility of Dropout with Int and Complex types.
+* `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.

 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ ArrayInterface = "3.1, 4"
 CUDA = "3"
 Functors = "0.2.1"
 MacroTools = "0.5"
-NNlib = "0.8"
+NNlib = "0.8.2"
 NNlibCUDA = "0.2"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"

src/layers/normalise.jl

Lines changed: 3 additions & 0 deletions

@@ -55,6 +55,9 @@ function _dropout_mask(rng, x, p; dims=:)
   return y
 end

+# TODO move this to NNlib
+Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
+
 """
     Dropout(p; dims=:, rng = rng_from_array())
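For context on the added lines: `ChainRulesCore.@non_differentiable f(args...)` registers rules telling AD backends such as Zygote to treat `f`'s output as a constant, so no gradient is propagated into its arguments. Applied to `dropout_mask`, the random mask is held fixed during differentiation, which is what lets a second reverse-mode pass through `Dropout` succeed. A minimal sketch of the same pattern with a hypothetical `make_mask` helper (illustrative only, not code from this commit):

using ChainRulesCore, Zygote

# Hypothetical stand-in for dropout_mask: draw a random keep/drop mask.
make_mask(x, p) = rand(length(x)) .> p

# Treat the mask as a constant for AD: no pullback is generated for make_mask,
# and its arguments receive NoTangent() instead of a gradient through rand.
ChainRulesCore.@non_differentiable make_mask(x, p)

# Toy dropout: zero out dropped entries and rescale the kept ones by 1/(1 - p).
drop(x, p) = x .* make_mask(x, p) ./ (1 - p)

# The gradient flows only through the x .* mask product, never into make_mask.
g, = Zygote.gradient(x -> sum(drop(x, 0.5)), ones(3))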

test/layers/basic.jl

Lines changed: 3 additions & 3 deletions

@@ -241,7 +241,7 @@ import Flux: activations
     Parallel(f_cnt, sin)(1)
     @test CNT[] == 3
   end
-  
+
   # Ref https://github.com/FluxML/Flux.jl/issues/1673
   @testset "Input domain" begin
     struct Input
@@ -278,7 +278,7 @@ import Flux: activations
     vocab_size, embed_size = 10, 4
     m = Flux.Embedding(vocab_size, embed_size)
     @test size(m.weight) == (embed_size, vocab_size)
-    
+
     x = rand(1:vocab_size, 3)
     y = m(x)
     @test y isa Matrix{Float32}
@@ -315,7 +315,7 @@ end
   # https://github.com/FluxML/NNlib.jl/issues/362
   m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
   x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
-  @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
+  @test Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
 end

 @testset "gradients of Chain{Vector}" begin

test/layers/normalisation.jl

Lines changed: 7 additions & 2 deletions

@@ -273,7 +273,7 @@ end
     x = reshape(collect(1:prod(sizes)), sizes)

     @test Flux.hasaffine(m) == true
-    @test length(Flux.params(m)) == 2 
+    @test length(Flux.params(m)) == 2
     x = Float64.(x)
     y = m(x)
     μ = mean(x, dims=1)
@@ -287,7 +287,7 @@
     x = reshape(collect(1:prod(sizes)), sizes)
     @test Flux.hasaffine(m) == false
     @test length(Flux.params(m)) == 0
-    
+
     x = Float64.(x)
     y = m(x)
     μ = mean(x, dims=1)
@@ -458,3 +458,8 @@ end
     @test BN(x) ≈ GN(x)
   end
 end
+
+@testset "second derivatives" begin
+  m1 = Dropout(0.5)
+  @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3)
+end
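A note on why the new test expects an all-zero Hessian: once `dropout_mask` is non-differentiable, AD sees `Dropout` as multiplying `x` by a fixed mask (already rescaled by `1/(1-p)`), i.e. a map that is linear in `x`, and the second derivative of a linear map vanishes. A rough sketch of that argument with an explicitly frozen mask (illustrative names, not part of the commit):

using Zygote

mask = [2.0, 0.0, 2.0]            # a frozen dropout mask for p = 0.5, already scaled by 1/(1-p)
drop_fixed(x) = x .* mask         # linear in x once the mask is held constant

# Reverse-over-reverse second derivatives of a linear map are identically zero.
Zygote.hessian_reverse(x -> sum(drop_fixed(x)), [1.0, 2.0, 3.0]) == zeros(3, 3)  # true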
