Skip to content

Commit 7b7ab8a

Browse files
committed
Support new ChainRulesCore.NoTangent
1 parent f456aac commit 7b7ab8a

5 files changed

+13
-7
lines changed

src/chain_rules_aliases.jl

+6
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ else
1212
const Tangent{P,T} = ChainRulesCore.Composite{P,T}
1313
const ZeroTangent = ChainRulesCore.Zero
1414
end
15+
16+
@static if isdefined(ChainRulesCore, :NoTangent)
17+
const NoTangent = ChainRulesCore.NoTangent
18+
else
19+
const NoTangent = ZeroTangent
20+
end

src/rrules.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818

1919

2020
Base.@generated function _forwarddiff_pullback_thunks(f::Base.Callable, xs::NTuple{N,Any}, ΔΩ::Any) where N
21-
Expr(:tuple, ChainRulesCore.NO_FIELDS, (:(ForwardDiffPullbacks.FwdDiffPullbackThunk(f, xs, Val($i), ΔΩ)) for i in 1:N)...)
21+
Expr(:tuple, NoTangent(), (:(ForwardDiffPullbacks.FwdDiffPullbackThunk(f, xs, Val($i), ΔΩ)) for i in 1:N)...)
2222
end
2323

2424
g_ΔΩ = nothing
@@ -56,7 +56,7 @@ end
5656

5757

5858
Base.@generated function _forwarddiff_bc_pullback_thunks(f::Base.Callable, Xs::NTuple{N,Any}, ΔΩA::Any) where N
59-
Expr(:tuple, ChainRulesCore.NO_FIELDS, ChainRulesCore.NO_FIELDS, (:(ForwardDiffPullbacks.FwdDiffBCPullbackThunk(f, Xs, Val($i), ΔΩA)) for i in 1:N)...)
59+
Expr(:tuple, NoTangent(), NoTangent(), (:(ForwardDiffPullbacks.FwdDiffBCPullbackThunk(f, Xs, Val($i), ΔΩA)) for i in 1:N)...)
6060
end
6161

6262
function ChainRulesCore.rrule(::typeof(Base.broadcasted), wrapped_f::WithForwardDiff, Xs::Vararg{Any,N}) where N

src/with_forwarddiff.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using ChainRulesCore
3232
3333
y, back = rrule(fwddiff(f), xs...)
3434
y == 139
35-
map(unthunk, back(1)) == (ZeroTangent(), 4, (6, 8), [10, 12, 14])
35+
map(unthunk, back(1)) == (NoTangent(), 4, (6, 8), [10, 12, 14])
3636
3737
using Zygote
3838

test/test_rrules.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ include("testfuncs.jl")
99

1010
@testset "ChainRulesCore" begin
1111
@test @tinferred(ChainRulesCore.rrule(fwddiff(f), xs...)) isa Tuple{Tuple, Function}
12-
@test @tinferred((ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ)) isa Tuple{ZeroTangent, ForwardDiffPullbacks.FwdDiffPullbackThunk, ForwardDiffPullbacks.FwdDiffPullbackThunk,ForwardDiffPullbacks.FwdDiffPullbackThunk}
13-
@test @tinferred(map(unthunk, (ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ))) == (ZeroTangent(), 280, (600, 1040), SVector(1600, 2280, 3080))
12+
@test @tinferred((ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ)) isa Tuple{NoTangent, ForwardDiffPullbacks.FwdDiffPullbackThunk, ForwardDiffPullbacks.FwdDiffPullbackThunk,ForwardDiffPullbacks.FwdDiffPullbackThunk}
13+
@test @tinferred(map(unthunk, (ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ))) == (NoTangent(), 280, (600, 1040), SVector(1600, 2280, 3080))
1414
end

test/test_zygote.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ include("testfuncs.jl")
1313
y, back = ChainRulesCore.rrule(f, xs...)
1414
back_thunks = back(ΔΩ)
1515
Δx = map(unthunk, back_thunks)
16-
@assert first(Δx) == ZeroTangent()
16+
@assert first(Δx) == NoTangent()
1717
y, Base.tail(Δx)
1818
end
1919

@@ -28,7 +28,7 @@ include("testfuncs.jl")
2828
y, back = ChainRulesCore.rrule(Base.broadcasted, f, Xs...)
2929
back_thunks = back(ΔΩA)
3030
Δx = map(unthunk, back_thunks)
31-
@assert first(Δx) == ZeroTangent()
31+
@assert first(Δx) == NoTangent()
3232
y, Base.tail(Δx)
3333
end
3434

0 commit comments

Comments
 (0)