Skip to content

Commit ae252a4

Browse files
committed
Test type inference in Julia v.16 upwards only
1 parent 8d4457b commit ae252a4

File tree

4 files changed

+48
-36
lines changed

4 files changed

+48
-36
lines changed

test/test_fwd_back.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ include("testfuncs.jl")
1616
ΔΩy = SVector(ΔΩ)
1717

1818

19-
@test @inferred(ForwardDiffPullbacks.forwarddiff_fwd_back(f, xs, Val(1), ΔΩ)) == 280
20-
@test @inferred(ForwardDiffPullbacks.forwarddiff_fwd_back(f, xs, Val(2), ΔΩ)) == (600, 1040)
21-
@test @inferred(ForwardDiffPullbacks.forwarddiff_fwd_back(f, xs, Val(3), ΔΩ)) == SVector(1600, 2280, 3080)
19+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_fwd_back(f, xs, Val(1), ΔΩ)) == 280
20+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_fwd_back(f, xs, Val(2), ΔΩ)) == (600, 1040)
21+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_fwd_back(f, xs, Val(3), ΔΩ)) == SVector(1600, 2280, 3080)
2222
end

test/test_rrules.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using ChainRulesCore
88
include("testfuncs.jl")
99

1010
@testset "ChainRulesCore" begin
11-
@test @inferred(ChainRulesCore.rrule(fwddiff(f), xs...)) isa Tuple{Tuple, Function}
12-
@test @inferred((ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ)) isa Tuple{ChainRulesCore.Zero, ForwardDiffPullbacks.FwdDiffPullbackThunk, ForwardDiffPullbacks.FwdDiffPullbackThunk,ForwardDiffPullbacks.FwdDiffPullbackThunk}
13-
@test @inferred(map(unthunk, (ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ))) == (Zero(), 280, (600, 1040), SVector(1600, 2280, 3080))
11+
@test @tinferred(ChainRulesCore.rrule(fwddiff(f), xs...)) isa Tuple{Tuple, Function}
12+
@test @tinferred((ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ)) isa Tuple{ChainRulesCore.Zero, ForwardDiffPullbacks.FwdDiffPullbackThunk, ForwardDiffPullbacks.FwdDiffPullbackThunk,ForwardDiffPullbacks.FwdDiffPullbackThunk}
13+
@test @tinferred(map(unthunk, (ChainRulesCore.rrule(fwddiff(f), xs...)[2])(ΔΩ))) == (Zero(), 280, (600, 1040), SVector(1600, 2280, 3080))
1414
end

test/test_zygote.jl

+30-30
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ include("testfuncs.jl")
4444
end
4545

4646

47-
@test @inferred(cr_fwd_and_back(fwddiff(f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3, Float32}}}
48-
@test @inferred(zg_fwd_and_back(fwddiff(f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3, Float32}}}
47+
@test @tinferred(cr_fwd_and_back(fwddiff(f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3, Float32}}}
48+
@test @tinferred(zg_fwd_and_back(fwddiff(f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3, Float32}}}
4949

5050
@test cr_fwd_and_back(fwddiff(f), xs, ΔΩ) == ((139, 783, 42), (280, (600, 1040), SVector(1600, 2280, 3080)))
5151
@test zg_fwd_and_back(fwddiff(f), xs, ΔΩ) == ((139, 783, 42), (280, (600, 1040), SVector(1600, 2280, 3080))) # == zg_fwd_and_back(f, xs, ΔΩ)
@@ -73,30 +73,30 @@ include("testfuncs.jl")
7373
Xs = map(x -> fill(x, 5), xs)
7474
ΔΩA = fill(ΔΩ, 5)
7575

76-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, (Xs[1], Ref(xs[2]), Xs[1]), Val(3), ΔΩA)) == fill(280, 5)
77-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, (Xs[1], Ref(xs[2]), Xs[2]), Val(3), ΔΩA)) == fill((600, 1040), 5)
78-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, (Xs[1], Ref(xs[2]), Xs[3]), Val(3), ΔΩA)) == fill(SVector(1600, 2280, 3080), 5)
76+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, (Xs[1], Ref(xs[2]), Xs[1]), Val(3), ΔΩA)) == fill(280, 5)
77+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, (Xs[1], Ref(xs[2]), Xs[2]), Val(3), ΔΩA)) == fill((600, 1040), 5)
78+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, (Xs[1], Ref(xs[2]), Xs[3]), Val(3), ΔΩA)) == fill(SVector(1600, 2280, 3080), 5)
7979

8080
for args in (Xs, (Xs[1], Ref(xs[2]), Xs[3]), map(Ref, xs))
81-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(1), ΔΩA)) == fill(280, 5)
82-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(2), ΔΩA)) == fill((600, 1040), 5)
83-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(3), ΔΩA)) == fill(SVector(1600, 2280, 3080), 5)
81+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(1), ΔΩA)) == fill(280, 5)
82+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(2), ΔΩA)) == fill((600, 1040), 5)
83+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(3), ΔΩA)) == fill(SVector(1600, 2280, 3080), 5)
8484
end
8585

8686
for args in (Xs, (Xs[1], Ref(xs[2]), Xs[3]))
87-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(1), Ref(ΔΩ))) == fill(280, 5)
88-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(2), Ref(ΔΩ))) == fill((600, 1040), 5)
89-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(3), Ref(ΔΩ))) == fill(SVector(1600, 2280, 3080), 5)
87+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(1), Ref(ΔΩ))) == fill(280, 5)
88+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(2), Ref(ΔΩ))) == fill((600, 1040), 5)
89+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(3), Ref(ΔΩ))) == fill(SVector(1600, 2280, 3080), 5)
9090
end
9191

9292
let args = map(Ref, xs), ΔY = Ref(ΔΩ)
93-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(1), Ref(ΔΩ))) == 280
94-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(2), Ref(ΔΩ))) == (600, 1040)
95-
@test @inferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(3), Ref(ΔΩ))) == SVector(1600, 2280, 3080)
93+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(1), Ref(ΔΩ))) == 280
94+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(2), Ref(ΔΩ))) == (600, 1040)
95+
@test @tinferred(ForwardDiffPullbacks.forwarddiff_bc_fwd_back(f, args, Val(3), Ref(ΔΩ))) == SVector(1600, 2280, 3080)
9696
end
9797

9898

99-
# @inferred cr_bc_fwd_and_back(fwddiff(f), Xs, ΔΩA)
99+
# @tinferred cr_bc_fwd_and_back(fwddiff(f), Xs, ΔΩA)
100100
# zg_bc_fwd_and_back(fwddiff(f), Xs, ΔΩA)
101101
# zg_bc_fwd_and_back(f, Xs, ΔΩA)
102102

@@ -146,14 +146,14 @@ include("testfuncs.jl")
146146
x = 0.5
147147

148148
ΔΩ = (y = 1, ladj = nothing)
149-
#=@inferred=# cr_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, missing), ΔΩ)
150-
#=@inferred=# zg_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, missing), ΔΩ)
149+
#=@tinferred=# cr_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, missing), ΔΩ)
150+
#=@tinferred=# zg_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, missing), ΔΩ)
151151
#compare with (ignore missings):
152152
zg_fwd_and_back(disttrafo, (trg_d, src_d, x, missing), ΔΩ)
153153

154154
ΔΩ = (y = 1, ladj = 42)
155-
#=@inferred=# cr_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, 7), ΔΩ)
156-
#=@inferred=# zg_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, 7), ΔΩ)
155+
#=@tinferred=# cr_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, 7), ΔΩ)
156+
#=@tinferred=# zg_fwd_and_back(fwddiff(disttrafo), (trg_d, src_d, x, 7), ΔΩ)
157157
#compare with (use deep approx):
158158
zg_fwd_and_back(disttrafo, (trg_d, src_d, x, 7), ΔΩ)
159159

@@ -164,14 +164,14 @@ include("testfuncs.jl")
164164
X = randn(n)
165165

166166
ΔΩs = fill((y = 1, ladj = nothing), n)
167-
#=@inferred=#(cr_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, missing), ΔΩs))
168-
#=@inferred=#(zg_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, missing), ΔΩs))
167+
#=@tinferred=#(cr_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, missing), ΔΩs))
168+
#=@tinferred=#(zg_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, missing), ΔΩs))
169169
#compare with (ignore missings):
170170
zg_bc_fwd_and_back(disttrafo, (trg_D, src_D, X, missing), ΔΩs)
171171

172172
ΔΩs = fill((y = 1, ladj = 42), n)
173-
#=@inferred=#(cr_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, 7), ΔΩs))
174-
#=@inferred=#(zg_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, 7), ΔΩs))
173+
#=@tinferred=#(cr_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, 7), ΔΩs))
174+
#=@tinferred=#(zg_bc_fwd_and_back(fwddiff(disttrafo), (trg_D, src_D, X, 7), ΔΩs))
175175
#compare with (use deep approx):
176176
zg_bc_fwd_and_back(disttrafo, (trg_D, src_D, X, 7), ΔΩs)
177177

@@ -189,8 +189,8 @@ include("testfuncs.jl")
189189
ismissing(ladj) ? 2 * y : typeof(y)(y^2 + ladj^2)
190190
end
191191

192-
#=@inferred=# Zygote.gradient(dummy_loss, trg_d, src_d, x, 42)
193-
#=@inferred=# Zygote.gradient(dummy_loss, trg_d, src_d, x, missing)
192+
#=@tinferred=# Zygote.gradient(dummy_loss, trg_d, src_d, x, 42)
193+
#=@tinferred=# Zygote.gradient(dummy_loss, trg_d, src_d, x, missing)
194194

195195
function bc_dummy_loss(trg_D::AbstractVector{<:Distribution}, src_D::AbstractVector{<:Distribution}, X::AbstractVector{<:Real}, prev_ladj::Union{Real,Missing})
196196
Y_ladj = fwddiff(disttrafo).(trg_D, src_D, X, prev_ladj)
@@ -200,13 +200,13 @@ include("testfuncs.jl")
200200
any(ismissing, ladj) ? T(sum(Y)) : T(sum(Y)^2 + sum(ladj)^2)
201201
end
202202

203-
#=@inferred=# bc_dummy_loss(trg_D, src_D, X, 42)
204-
#=@inferred=# bc_dummy_loss(trg_D, src_D, X, missing)
203+
#=@tinferred=# bc_dummy_loss(trg_D, src_D, X, 42)
204+
#=@tinferred=# bc_dummy_loss(trg_D, src_D, X, missing)
205205

206-
#=@inferred=# Zygote.gradient(bc_dummy_loss, trg_D, src_D, X, 42)
207-
#=@inferred=# Zygote.gradient(bc_dummy_loss, trg_D, src_D, X, missing)
206+
#=@tinferred=# Zygote.gradient(bc_dummy_loss, trg_D, src_D, X, 42)
207+
#=@tinferred=# Zygote.gradient(bc_dummy_loss, trg_D, src_D, X, missing)
208208

209-
#=@inferred=# Zygote.gradient(X -> bc_dummy_loss(trg_D, src_D, X, 42), X)
209+
#=@tinferred=# Zygote.gradient(X -> bc_dummy_loss(trg_D, src_D, X, 42), X)
210210

211211
# Requires Zygote to utilize thunks (https://github.com/FluxML/Zygote.jl/pull/966):
212212
if isdefined(Zygote, Symbol("@_adjoint_keepthunks"))

test/testfuncs.jl

+12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
# This file is a part of ForwardDiffPullbacks.jl, licensed under the MIT License (MIT).
22

3+
import Test
4+
5+
# Tests type inference on Julia >= v1.6 only:
6+
macro tinferred(expr)
7+
if VERSION >= v"1.6"
8+
esc(:(Test.@inferred($expr)))
9+
else
10+
esc(:($expr))
11+
end
12+
end
13+
14+
315
using LinearAlgebra, StaticArrays
416

517
sum_pow2(x) = sum(map(x -> x^2, x))

0 commit comments

Comments
 (0)