Description
Currently, if the `rrule` (and/or its pullback and/or thunks) fails to infer, then `test_rrule` throws an error rather than recording a test failure. The same goes for `frule` and `test_frule`.
I've seen a few reports of users being confused about what exactly this error means, and I think it would probably be better if this resulted in a test failure.
Here's the current behaviour:
julia> A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0];
julia> test_rrule(exp, A)
test_rrule: exp on Matrix{Float64}: Error During Test at /Users/npr/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:227
Got exception outside of a @test
return type Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}} does not match inferred return type Union{Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback_hermitian#1536"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] _test_inferred(::Function, ::ChainRulesTestUtils.ADviaRuleConfig, ::Vararg{Any, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:324
[3] _test_inferred(::Function, ::ChainRulesTestUtils.ADviaRuleConfig, ::Vararg{Any, N} where N)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:323
[4] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:235 [inlined]
[5] macro expansion
@ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[6] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::typeof(exp), args::Matrix{Float64}; output_tangent::ChainRulesTestUtils.Auto, tangent_transforms::Vector{Function}, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:230
[7] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::Function, args::Matrix{Float64})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:222
[8] #test_rrule#44
@ ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:204 [inlined]
[9] test_rrule(::Function, ::Matrix{Float64})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:203
[10] top-level scope
@ REPL[6]:1
[11] eval
@ ./boot.jl:360 [inlined]
[12] eval_user_input(ast::Any, backend::REPL.REPLBackend)
@ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:139
[13] repl_backend_loop(backend::REPL.REPLBackend)
@ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:200
[14] start_repl_backend(backend::REPL.REPLBackend, consumer::Any)
@ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:185
[15] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool)
@ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:317
[16] run_repl(repl::REPL.AbstractREPL, consumer::Any)
@ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:305
[17] (::Base.var"#874#876"{Bool, Bool, Bool})(REPL::Module)
@ Base ./client.jl:387
[18] #invokelatest#2
@ ./essentials.jl:708 [inlined]
[19] invokelatest
@ ./essentials.jl:706 [inlined]
[20] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
@ Base ./client.jl:372
[21] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:302
[22] _start()
@ Base ./client.jl:485
Test Summary:                      | Error  Total
test_rrule: exp on Matrix{Float64} |     1      1
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken.
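For reference, the same Error-instead-of-Fail behaviour can be reproduced with just the Test stdlib. This is a minimal sketch, assuming Julia 1.6 or later; `unstable` and `CollectingTestSet` are hypothetical names used only for illustration, and the custom test set exists solely so the recorded results can be inspected. Because `@inferred` reports a mismatch by throwing an `ErrorException` outside of any `@test`, the enclosing `@testset` records it as an `Error` and skips the rest of its body.

```julia
using Test

# Hypothetical test set that only collects results, so they can be inspected
# afterwards instead of being printed and turned into an exception.
struct CollectingTestSet <: Test.AbstractTestSet
    description::String
    results::Vector{Any}
end
CollectingTestSet(description; kwargs...) = CollectingTestSet(description, Any[])
Test.record(ts::CollectingTestSet, res) = (push!(ts.results, res); res)
Test.finish(ts::CollectingTestSet) = ts

# The return type depends on a runtime value, so for an `Int` argument inference
# gives `Union{Float64, Int}` even though the actual result is an `Int`.
unstable(x) = x > 0 ? 1 : 1.0

ts = @testset CollectingTestSet "inference demo" begin
    @test unstable(1) == 1     # recorded as a Pass
    @inferred unstable(1)      # throws an ErrorException outside of any @test
    @test unstable(-1) == 1.0  # never runs: the exception ended the block
end

# The results are a Pass followed by an Error; no Fail is ever recorded.
map(typeof, ts.results)
```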
I think a better output would be something like:
julia> test_rrule(exp, A)
type stable pullback: Test Failed at /Users/npr/repos/ChainRulesTestUtils.jl/src/testers.jl:333
Expression: false
Problem: The pullback should be type stable. Or use `test_rrule` with `check_inferred=false`. `@inferred` gave:
return type Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}} does not match inferred return type Union{Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback_hermitian#1536"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}
Stacktrace:
[1] macro expansion
@ ~/repos/ChainRulesTestUtils.jl/src/testers.jl:333 [inlined]
[2] macro expansion
@ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[3] _test_inferred(::Function, ::ChainRulesTestUtils.ADviaRuleConfig, ::Vararg{Any, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ChainRulesTestUtils ~/repos/ChainRulesTestUtils.jl/src/testers.jl:324
Test Summary:                      | Pass  Fail  Total
test_rrule: exp on Matrix{Float64} |    8     1      9
  type stable rrule                |    1            1
  type stable pullback             |          1      1
  type stable thunk                |    1            1
ERROR: Some tests did not pass: 8 passed, 1 failed, 0 errored, 0 broken.
(^ I constructed this output manually, but hopefully it gives an idea of what I think would be helpful.)
For implementing this, I think the main issue is that `Test.@inferred` throws an `ErrorException` rather than recording a test result. So one option would be to add a `@test_inferred` macro to the Test stdlib that returns a test failure, e.g. something like:
julia> @test_inferred pullback(ȳ)
Test Failed at REPL[16]:1
Expression: actual === inferred
Evaluated: Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}} === Union{Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback_hermitian#1536"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, Symmetric{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}}}}
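A rough sketch of what such a macro might look like, assuming only the Test stdlib. `@test_inferred` is hypothetical here (it is not part of Test); it compares `typeof(result)` with the single return type reported by `Base.return_types` and routes that comparison through `@test`, so a mismatch is recorded as a `Fail`. Unlike `Test.@inferred`, this sketch does not handle keyword arguments, the optional `allow` type, or calls whose result is itself a type.

```julia
using Test

# Hypothetical macro (not part of the Test stdlib). It evaluates the call,
# compares `typeof(result)` with the return type the compiler infers for the
# argument types, and records the comparison via `@test`, so a mismatch is a
# test failure instead of a thrown ErrorException. Like `Test.@inferred`, it
# returns the value of the call.
macro test_inferred(ex)
    Meta.isexpr(ex, :call) || error("@test_inferred requires a call expression")
    any(a -> Meta.isexpr(a, :kw) || Meta.isexpr(a, :parameters), ex.args[2:end]) &&
        error("this sketch of @test_inferred does not support keyword arguments")
    f = esc(ex.args[1])
    args = map(esc, ex.args[2:end])
    quote
        let fval = $f, argvals = ($(args...),)
            result = fval(argvals...)
            actual = typeof(result)
            # `return_types` gives one entry per matching method; expect exactly one.
            inferred = only(Base.return_types(fval, Base.typesof(argvals...)))
            @test actual === inferred
            result
        end
    end
end
```

With this, a type-unstable call should be reported roughly as in the output above: `Expression: actual === inferred`, with the concrete and inferred types on the `Evaluated:` line.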
Another option would be to fix things here in ChainRulesTestUtils so that they behave how we want (then we have full control and don't need to wait for newer Julia versions). For example, we could wrap our `_test_inferred` helper (which calls `@maybe_inferred`) in a `try`/`catch` that catches the `@inferred` error, and then use `@test_msg` to provide useful output and a test failure (this is what I tried to mock up above).
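A minimal sketch of the try/catch route, assuming only the Test stdlib. `test_inferred_or_fail` is a hypothetical stand-in for `_test_inferred`; it uses a plain `@test` on the captured message rather than `@test_msg`, so the message ends up on the `Evaluated:` line instead of in a custom `Problem:` line. Matching on the text of the `@inferred` error message is also an assumption, made so that an unrelated `ErrorException` thrown by the function itself is rethrown and still reported as an error.

```julia
using Test

# Hypothetical stand-in for a `_test_inferred`-style helper: runs `f(args...)`
# under `@inferred` and turns an inference mismatch into a test failure.
function test_inferred_or_fail(f, args...)
    msg = try
        @inferred f(args...)
        nothing  # inference matched; nothing to report
    catch err
        # Assumption: `@inferred` reports mismatches via `error(...)` with a message
        # containing this phrase. Anything else is rethrown and stays a test error.
        if err isa ErrorException && occursin("does not match inferred return type", err.msg)
            err.msg
        else
            rethrow()
        end
    end
    # On failure, `@test` shows the captured message on its "Evaluated:" line.
    @test msg === nothing
    return msg
end
```

Called inside a `@testset`, a type-unstable call then shows up as one failed test, and the remaining tests in the set still run.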