
Make failure to infer a test failure rather than an error #187

Open
@nickrobinson251

Description


Currently, if the `rrule`, its pullback, or its thunks fail to infer, `test_rrule` throws an error. The same is true for `frule` and `test_frule`.

I've seen a few reports of users being confused about what exactly this error means, and I think it would be better if this resulted in a test failure.

Here's the current behaviour:

julia> A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0];

julia> test_rrule(exp, A)
test_rrule: exp on Matrix{Float64}: Error During Test at /Users/npr/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:227
  Got exception outside of a @test
  return type Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}} does not match inferred return type Union{Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback_hermitian#1536"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:33
    [2] _test_inferred(::Function, ::ChainRulesTestUtils.ADviaRuleConfig, ::Vararg{Any, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:324
    [3] _test_inferred(::Function, ::ChainRulesTestUtils.ADviaRuleConfig, ::Vararg{Any, N} where N)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:323
    [4] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:235 [inlined]
    [5] macro expansion
      @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
    [6] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::typeof(exp), args::Matrix{Float64}; output_tangent::ChainRulesTestUtils.Auto, tangent_transforms::Vector{Function}, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:230
    [7] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::Function, args::Matrix{Float64})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:222
    [8] #test_rrule#44
      @ ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:204 [inlined]
    [9] test_rrule(::Function, ::Matrix{Float64})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/AX7fv/src/testers.jl:203
   [10] top-level scope
      @ REPL[6]:1
   [11] eval
      @ ./boot.jl:360 [inlined]
   [12] eval_user_input(ast::Any, backend::REPL.REPLBackend)
      @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:139
   [13] repl_backend_loop(backend::REPL.REPLBackend)
      @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:200
   [14] start_repl_backend(backend::REPL.REPLBackend, consumer::Any)
      @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:185
   [15] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool)
      @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:317
   [16] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:305
   [17] (::Base.var"#874#876"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:387
   [18] #invokelatest#2
      @ ./essentials.jl:708 [inlined]
   [19] invokelatest
      @ ./essentials.jl:706 [inlined]
   [20] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:372
   [21] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:302
   [22] _start()
      @ Base ./client.jl:485
Test Summary:                      | Error  Total
test_rrule: exp on Matrix{Float64} |     1      1
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken.
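The error above comes from `Test.@inferred` itself: on a mismatch it throws an `ErrorException` (frame [1] is `error(s::String)`) instead of recording a test result, so the test set reports "Error During Test". Here is a minimal reproduction without ChainRules. It assumes a deliberately type-unstable toy function `unstable`, which is hypothetical and only for illustration:

```julia
using Test

# Hypothetical toy function: for an Int argument it returns either that Int or
# a Float64 depending on the value, so inference gives Union{Float64, Int64}.
unstable(x) = x > 0 ? x : 0.0

# `@inferred` signals the mismatch by throwing, not by returning a failed test.
err = try
    Test.@inferred unstable(1)
    nothing
catch e
    e
end

println(typeof(err))  # ErrorException
println(err.msg)      # "return type ... does not match inferred return type ..."
```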

I think output like the following would be better:

julia> test_rrule(exp, A)
type stable pullback: Test Failed at /Users/npr/repos/ChainRulesTestUtils.jl/src/testers.jl:333
  Expression: false
  Problem: The pullback should be type stable. Or use `test_rrule` with `check_inferred=false`. `@inferred` gave:
  return type Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}} does not match inferred return type Union{Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback_hermitian#1536"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}
Stacktrace:
 [1] macro expansion
   @ ~/repos/ChainRulesTestUtils.jl/src/testers.jl:333 [inlined]
 [2] macro expansion
   @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] _test_inferred(::Function, ::ChainRulesTestUtils.ADviaRuleConfig, ::Vararg{Any, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ ChainRulesTestUtils ~/repos/ChainRulesTestUtils.jl/src/testers.jl:324
Test Summary:                      | Pass  Fail  Total
test_rrule: exp on Matrix{Float64} |    8     1      9
  type stable rrule                |    1            1
  type stable pullback             |          1      1
  type stable thunk                |    1            1
ERROR: Some tests did not pass: 8 passed, 1 failed, 0 errored, 0 broken.

(^ I constructed this output by hand, but hopefully it gives an idea of what I think would be helpful.)

For implementing this, I think the main issue is that `Test.@inferred` throws an `ErrorException`. One option would be to add a `@test_inferred` macro to the Test stdlib that records a test failure instead, e.g. something like:

julia> @test_inferred pullback(ȳ)
Test Failed at REPL[16]:1
  Expression: actual === inferred
   Evaluated: Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}} === Union{Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback#1537"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LU{Float64, Matrix{Float64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Matrix{Float64}, ChainRules.var"#exp_pullback_hermitian#1536"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, Symmetric{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}}}}
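A minimal sketch of what such a macro could look like, assuming it reuses the same ingredients as `@inferred` (`Base.return_types` and the Base-internal `Base.typesof`) and supports only positional arguments. The name `@test_inferred` is hypothetical here; it is not part of the Test stdlib:

```julia
using Test

# Hypothetical macro (NOT part of the Test stdlib): like `@inferred`, but records
# a Pass/Fail via `@test` instead of throwing an ErrorException on a mismatch.
macro test_inferred(ex)
    Meta.isexpr(ex, :call) || error("@test_inferred requires a call expression")
    # Keep the sketch small: keyword arguments are not supported.
    any(a -> Meta.isexpr(a, (:kw, :parameters)), ex.args) &&
        error("@test_inferred sketch does not support keyword arguments")
    f = esc(ex.args[1])
    args = map(esc, ex.args[2:end])
    quote
        local fval = $f
        local argvals = ($(args...),)
        local result = fval(argvals...)
        # Same convention as `@inferred`: a returned type `T` is compared as `Type{T}`.
        local actual = result isa Type ? Type{result} : typeof(result)
        local inferred = only(Base.return_types(fval, Base.typesof(argvals...)))
        @test actual === inferred
        result
    end
end
```

Inside a `@testset`, calling `@test_inferred` on a type-unstable call would then be counted in the Fail column of the summary rather than the Error column.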

Another option would be to fix this here, in ChainRulesTestUtils (then we have full control and don't need to wait for newer Julia versions). For example, we could wrap our `_test_inferred` helper (which calls `@maybe_inferred`) in a `try`-`catch` block to catch the `@inferred` error, and then use `@test_msg` to provide a useful message and record a test failure (this is what I tried to mock up above).
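A minimal sketch of the try-catch approach, assuming a hypothetical helper `test_inferred_nothrow` that stands in for `_test_inferred`. It uses a plain `@test` rather than ChainRulesTestUtils' `@test_msg`, so the failure report should show the inference message on the Evaluated line of `@test`. Matching on the error text is an assumption about the wording `@inferred` uses and may be fragile across Julia versions:

```julia
using Test

# Hypothetical stand-in for ChainRulesTestUtils' `_test_inferred`: run `@inferred`,
# but turn its inference-mismatch ErrorException into a recorded test failure.
function test_inferred_nothrow(f, args...)
    msg = try
        Test.@inferred f(args...)
        nothing  # inferred type matched the actual return type
    catch err
        # Only swallow the mismatch error from `@inferred`; anything else
        # (including errors thrown by `f` itself) is rethrown unchanged.
        is_mismatch = err isa ErrorException &&
                      occursin("does not match inferred return type", err.msg)
        is_mismatch || rethrow()
        err.msg
    end
    # On failure, the Evaluated line of the report shows the inference message.
    @test msg === nothing
    return msg
end
```

In `test_rrule`, each such check could be wrapped in its own nested `@testset` (e.g. "type stable pullback") to get the per-check summary shown in the mock-up above.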

Metadata


Assignees

No one assigned

Labels

enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
