@@ -44,8 +44,8 @@ include("testfuncs.jl")
44
44
end
45
45
46
46
47
- @test @inferred (cr_fwd_and_back (fwddiff (f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3 , Float32}}}
48
- @test @inferred (zg_fwd_and_back (fwddiff (f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3 , Float32}}}
47
+ @test @tinferred (cr_fwd_and_back (fwddiff (f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3 , Float32}}}
48
+ @test @tinferred (zg_fwd_and_back (fwddiff (f), xs, ΔΩ)) isa Tuple{Tuple{Float32, Float32, Int64}, Tuple{Float32, Tuple{Float32, Float32}, SVector{3 , Float32}}}
49
49
50
50
@test cr_fwd_and_back (fwddiff (f), xs, ΔΩ) == ((139 , 783 , 42 ), (280 , (600 , 1040 ), SVector (1600 , 2280 , 3080 )))
51
51
@test zg_fwd_and_back (fwddiff (f), xs, ΔΩ) == ((139 , 783 , 42 ), (280 , (600 , 1040 ), SVector (1600 , 2280 , 3080 ))) # == zg_fwd_and_back(f, xs, ΔΩ)
@@ -73,30 +73,30 @@ include("testfuncs.jl")
73
73
Xs = map (x -> fill (x, 5 ), xs)
74
74
ΔΩA = fill (ΔΩ, 5 )
75
75
76
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, (Xs[1 ], Ref (xs[2 ]), Xs[1 ]), Val (3 ), ΔΩA)) == fill (280 , 5 )
77
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, (Xs[1 ], Ref (xs[2 ]), Xs[2 ]), Val (3 ), ΔΩA)) == fill ((600 , 1040 ), 5 )
78
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, (Xs[1 ], Ref (xs[2 ]), Xs[3 ]), Val (3 ), ΔΩA)) == fill (SVector (1600 , 2280 , 3080 ), 5 )
76
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, (Xs[1 ], Ref (xs[2 ]), Xs[1 ]), Val (3 ), ΔΩA)) == fill (280 , 5 )
77
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, (Xs[1 ], Ref (xs[2 ]), Xs[2 ]), Val (3 ), ΔΩA)) == fill ((600 , 1040 ), 5 )
78
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, (Xs[1 ], Ref (xs[2 ]), Xs[3 ]), Val (3 ), ΔΩA)) == fill (SVector (1600 , 2280 , 3080 ), 5 )
79
79
80
80
for args in (Xs, (Xs[1 ], Ref (xs[2 ]), Xs[3 ]), map (Ref, xs))
81
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (1 ), ΔΩA)) == fill (280 , 5 )
82
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (2 ), ΔΩA)) == fill ((600 , 1040 ), 5 )
83
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (3 ), ΔΩA)) == fill (SVector (1600 , 2280 , 3080 ), 5 )
81
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (1 ), ΔΩA)) == fill (280 , 5 )
82
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (2 ), ΔΩA)) == fill ((600 , 1040 ), 5 )
83
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (3 ), ΔΩA)) == fill (SVector (1600 , 2280 , 3080 ), 5 )
84
84
end
85
85
86
86
for args in (Xs, (Xs[1 ], Ref (xs[2 ]), Xs[3 ]))
87
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (1 ), Ref (ΔΩ))) == fill (280 , 5 )
88
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (2 ), Ref (ΔΩ))) == fill ((600 , 1040 ), 5 )
89
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (3 ), Ref (ΔΩ))) == fill (SVector (1600 , 2280 , 3080 ), 5 )
87
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (1 ), Ref (ΔΩ))) == fill (280 , 5 )
88
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (2 ), Ref (ΔΩ))) == fill ((600 , 1040 ), 5 )
89
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (3 ), Ref (ΔΩ))) == fill (SVector (1600 , 2280 , 3080 ), 5 )
90
90
end
91
91
92
92
let args = map (Ref, xs), ΔY = Ref (ΔΩ)
93
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (1 ), Ref (ΔΩ))) == 280
94
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (2 ), Ref (ΔΩ))) == (600 , 1040 )
95
- @test @inferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (3 ), Ref (ΔΩ))) == SVector (1600 , 2280 , 3080 )
93
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (1 ), Ref (ΔΩ))) == 280
94
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (2 ), Ref (ΔΩ))) == (600 , 1040 )
95
+ @test @tinferred (ForwardDiffPullbacks. forwarddiff_bc_fwd_back (f, args, Val (3 ), Ref (ΔΩ))) == SVector (1600 , 2280 , 3080 )
96
96
end
97
97
98
98
99
- # @inferred cr_bc_fwd_and_back(fwddiff(f), Xs, ΔΩA)
99
+ # @tinferred cr_bc_fwd_and_back(fwddiff(f), Xs, ΔΩA)
100
100
# zg_bc_fwd_and_back(fwddiff(f), Xs, ΔΩA)
101
101
# zg_bc_fwd_and_back(f, Xs, ΔΩA)
102
102
@@ -146,14 +146,14 @@ include("testfuncs.jl")
146
146
x = 0.5
147
147
148
148
ΔΩ = (y = 1 , ladj = nothing )
149
- #= @inferred =# cr_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, missing ), ΔΩ)
150
- #= @inferred =# zg_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, missing ), ΔΩ)
149
+ #= @tinferred =# cr_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, missing ), ΔΩ)
150
+ #= @tinferred =# zg_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, missing ), ΔΩ)
151
151
# compare with (ignore missings):
152
152
zg_fwd_and_back (disttrafo, (trg_d, src_d, x, missing ), ΔΩ)
153
153
154
154
ΔΩ = (y = 1 , ladj = 42 )
155
- #= @inferred =# cr_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, 7 ), ΔΩ)
156
- #= @inferred =# zg_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, 7 ), ΔΩ)
155
+ #= @tinferred =# cr_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, 7 ), ΔΩ)
156
+ #= @tinferred =# zg_fwd_and_back (fwddiff (disttrafo), (trg_d, src_d, x, 7 ), ΔΩ)
157
157
# compare with (use deep approx):
158
158
zg_fwd_and_back (disttrafo, (trg_d, src_d, x, 7 ), ΔΩ)
159
159
@@ -164,14 +164,14 @@ include("testfuncs.jl")
164
164
X = randn (n)
165
165
166
166
ΔΩs = fill ((y = 1 , ladj = nothing ), n)
167
- #= @inferred =# (cr_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, missing ), ΔΩs))
168
- #= @inferred =# (zg_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, missing ), ΔΩs))
167
+ #= @tinferred =# (cr_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, missing ), ΔΩs))
168
+ #= @tinferred =# (zg_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, missing ), ΔΩs))
169
169
# compare with (ignore missings):
170
170
zg_bc_fwd_and_back (disttrafo, (trg_D, src_D, X, missing ), ΔΩs)
171
171
172
172
ΔΩs = fill ((y = 1 , ladj = 42 ), n)
173
- #= @inferred =# (cr_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, 7 ), ΔΩs))
174
- #= @inferred =# (zg_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, 7 ), ΔΩs))
173
+ #= @tinferred =# (cr_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, 7 ), ΔΩs))
174
+ #= @tinferred =# (zg_bc_fwd_and_back (fwddiff (disttrafo), (trg_D, src_D, X, 7 ), ΔΩs))
175
175
# compare with (use deep approx):
176
176
zg_bc_fwd_and_back (disttrafo, (trg_D, src_D, X, 7 ), ΔΩs)
177
177
@@ -189,8 +189,8 @@ include("testfuncs.jl")
189
189
ismissing (ladj) ? 2 * y : typeof (y)(y^ 2 + ladj^ 2 )
190
190
end
191
191
192
- #= @inferred =# Zygote. gradient (dummy_loss, trg_d, src_d, x, 42 )
193
- #= @inferred =# Zygote. gradient (dummy_loss, trg_d, src_d, x, missing )
192
+ #= @tinferred =# Zygote. gradient (dummy_loss, trg_d, src_d, x, 42 )
193
+ #= @tinferred =# Zygote. gradient (dummy_loss, trg_d, src_d, x, missing )
194
194
195
195
function bc_dummy_loss (trg_D:: AbstractVector{<:Distribution} , src_D:: AbstractVector{<:Distribution} , X:: AbstractVector{<:Real} , prev_ladj:: Union{Real,Missing} )
196
196
Y_ladj = fwddiff (disttrafo).(trg_D, src_D, X, prev_ladj)
@@ -200,13 +200,13 @@ include("testfuncs.jl")
200
200
any (ismissing, ladj) ? T (sum (Y)) : T (sum (Y)^ 2 + sum (ladj)^ 2 )
201
201
end
202
202
203
- #= @inferred =# bc_dummy_loss (trg_D, src_D, X, 42 )
204
- #= @inferred =# bc_dummy_loss (trg_D, src_D, X, missing )
203
+ #= @tinferred =# bc_dummy_loss (trg_D, src_D, X, 42 )
204
+ #= @tinferred =# bc_dummy_loss (trg_D, src_D, X, missing )
205
205
206
- #= @inferred =# Zygote. gradient (bc_dummy_loss, trg_D, src_D, X, 42 )
207
- #= @inferred =# Zygote. gradient (bc_dummy_loss, trg_D, src_D, X, missing )
206
+ #= @tinferred =# Zygote. gradient (bc_dummy_loss, trg_D, src_D, X, 42 )
207
+ #= @tinferred =# Zygote. gradient (bc_dummy_loss, trg_D, src_D, X, missing )
208
208
209
- #= @inferred =# Zygote. gradient (X -> bc_dummy_loss (trg_D, src_D, X, 42 ), X)
209
+ #= @tinferred =# Zygote. gradient (X -> bc_dummy_loss (trg_D, src_D, X, 42 ), X)
210
210
211
211
# Requires Zygote to utilize thunks (https://github.com/FluxML/Zygote.jl/pull/966):
212
212
if isdefined (Zygote, Symbol (" @_adjoint_keepthunks" ))
0 commit comments