Skip to content

Commit 6734622

Browse files
committed
ChainRulesCore v1.0 compatibility
1 parent cd57efc commit 6734622

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1111

1212
[compat]
13-
ChainRulesCore = "0.9.44, 0.10"
13+
ChainRulesCore = "0.9.44, 0.10, 1"
1414
ForwardDiff = "0.10"
1515
Requires = "0.5, 1"
1616
StaticArrays = "0.12, 1.0"

src/ForwardDiffPullbacks.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import ChainRulesCore
1515
import ForwardDiff
1616
import StaticArrays
1717

18-
using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent
18+
using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent, AbstractThunk, unthunk
1919

2020
# using Requires
2121

src/rrules.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# This file is a part of ForwardDiffPullbacks.jl, licensed under the MIT License (MIT).
22

3+
4+
# ToDo: Use ProjectTo in rrules (requires ChainRulesCore >= v0.10.11).
5+
6+
37
struct FwdDiffPullbackThunk{F<:Base.Callable,T<:Tuple,i,U<:Any} <: ChainRulesCore.AbstractThunk
48
f::F
59
xs::T
@@ -11,9 +15,10 @@ function FwdDiffPullbackThunk(f::F, xs::T, ::Val{i}, ΔΩ::U) where {F<:Base.Cal
1115
end
1216

1317
@inline function ChainRulesCore.unthunk(tnk::FwdDiffPullbackThunk{F,T,i,U}) where {F,T,i,U}
14-
forwarddiff_fwd_back(tnk.f, tnk.xs, Val(i), tnk.ΔΩ)
18+
forwarddiff_fwd_back(tnk.f, tnk.xs, Val(i), unthunk(tnk.ΔΩ))
1519
end
1620

21+
# ToDo: Remove (obsolete with ChainRulesCore >= v0.10.):
1722
(tnk::FwdDiffPullbackThunk)() = ChainRulesCore.unthunk(tnk)
1823

1924

@@ -48,9 +53,10 @@ function FwdDiffBCPullbackThunk(f::F, Xs::T, ::Val{i}, ΔΩA::U) where {F<:Base.
4853
end
4954

5055
@inline function ChainRulesCore.unthunk(tnk::FwdDiffBCPullbackThunk{F,T,i,U}) where {F,T,i,U}
51-
forwarddiff_bc_fwd_back(tnk.f, tnk.Xs, Val(i), tnk.ΔΩA)
56+
forwarddiff_bc_fwd_back(tnk.f, tnk.Xs, Val(i), unthunk(tnk.ΔΩA))
5257
end
5358

59+
# ToDo: Remove (obsolete with ChainRulesCore >= v0.10.):
5460
(tnk::FwdDiffBCPullbackThunk)() = ChainRulesCore.unthunk(tnk)
5561

5662

0 commit comments

Comments
 (0)