Skip to content

Commit e0f6c9d

Browse files
committed
mark dropout_mask as non-differentiable
1 parent 0b7e1b6 commit e0f6c9d

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/layers/normalise.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ function _dropout_mask(rng, x, p; dims=:)
5555
return y
5656
end
5757

58+
# TODO move this to NNlib
59+
Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
60+
5861
"""
5962
Dropout(p; dims=:, rng = rng_from_array())
6063

test/layers/normalisation.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ end
273273
x = reshape(collect(1:prod(sizes)), sizes)
274274

275275
@test Flux.hasaffine(m) == true
276-
@test length(Flux.params(m)) == 2
276+
@test length(Flux.params(m)) == 2
277277
x = Float64.(x)
278278
y = m(x)
279279
μ = mean(x, dims=1)
@@ -287,7 +287,7 @@ end
287287
x = reshape(collect(1:prod(sizes)), sizes)
288288
@test Flux.hasaffine(m) == false
289289
@test length(Flux.params(m)) == 0
290-
290+
291291
x = Float64.(x)
292292
y = m(x)
293293
μ = mean(x, dims=1)
@@ -458,3 +458,8 @@ end
458458
@test BN(x) GN(x)
459459
end
460460
end
461+
462+
@testset "second derivatives" begin
463+
m1 = Dropout(0.5)
464+
@test Zygote.hessian_reverse(summ1, [1.0,2.0,3.0]) == ones(3, 3)
465+
end

0 commit comments

Comments
 (0)