Skip to content

test_rrule() fails in iterate() #263

Open
@dfdx

Description

@dfdx

It looks like the finite-difference implementation has a hard time going through iterate (see the MRE and full stacktrace below):

julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  DimensionMismatch: second dimension of A, 2, does not match length of x, 1
  Stacktrace:
    [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
      @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
...
    [7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
...

Below I provide an rrule() implementation for iterate on tuples for convenience, but perhaps the example can be narrowed down to a direct invocation of _make_j′vp_call(). I also see the same error when testing with arrays.

MWE
using ChainRulesCore
import ChainRulesCore.rrule
using ChainRulesTestUtils

"""
    ungetfield(dy, s::Tuple, f::Integer)

Build the cotangent of the tuple `s` for a pullback through reading its `f`-th
element: the result is a `Tangent{typeof(s)}` whose `f`-th entry is `dy` and whose
other entries are `ZeroTangent()`.

If `f` is outside `1:length(s)`, every entry is `ZeroTangent()`.
"""
function ungetfield(dy, s::Tuple, f::Integer)
    # Using `ntuple` over the tuple's length (rather than splatting a comprehension
    # `Vector` with a `Union`/`Any` eltype) keeps the entries a concrete tuple, so the
    # resulting `Tangent` type can be inferred.
    entries = ntuple(i -> i == f ? dy : ZeroTangent(), length(s))
    return Tangent{typeof(s)}(entries...)
end

"""
    rrule(::typeof(iterate), t::Tuple)

Reverse rule for the first step of iterating a tuple, `iterate(t)`.

The pullback unthunks the incoming cotangent, takes its first component (the
cotangent of the yielded element) and routes it to position 1 of `t`. It returns
`NoTangent()` for the function itself.
"""
function rrule(::typeof(iterate), t::Tuple)
    primal_out = iterate(t)
    iterate_pullback(ȳ) = (NoTangent(), ungetfield(unthunk(ȳ)[1], t, 1))
    return primal_out, iterate_pullback
end

"""
    rrule(::typeof(iterate), t::Tuple, i::Integer)

Reverse rule for a subsequent step of iterating a tuple, `iterate(t, i)`.

The pullback unthunks the incoming cotangent, takes its first component (the
cotangent of the yielded element) and routes it to position `i` of `t`.
Returns `(NoTangent(), cotangent_of_t, NoTangent())`.
"""
function rrule(::typeof(iterate), t::Tuple, i::Integer)
    y = iterate(t, i)
    function iterate_pullback(dy)
        dy = unthunk(dy)
        # `i` is an integer iteration state, not a differentiable quantity, so by
        # ChainRules convention its cotangent is `NoTangent()` rather than the
        # "derivative happens to be zero" `ZeroTangent()`.
        return NoTangent(), ungetfield(dy[1], t, i), NoTangent()
    end
    return y, iterate_pullback
end

# Reproducer from the report: check the single-argument `iterate` rule on a 2-tuple.
# The stacktrace shows the generated output tangent is
# `Tangent{Tuple{Float64, Int64}}(Float64, NoTangent)`, i.e. `NoTangent()` for the
# Int iteration state. Inside FiniteDifferences.j′vp this appears to flatten to a
# shorter vector (length 1) than the primal output (length 2), hence the
# DimensionMismatch.
let f = Base.iterate, x = (3.0, 5.0)
    test_rrule(f, x; check_inferred=false)
end
Complete stacktrace
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  DimensionMismatch: second dimension of A, 2, does not match length of x, 1
  Stacktrace:
    [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
      @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
    [2] mul!
      @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
    [3] mul!
      @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
    [4] *(tA::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
      @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
    [5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:80
    [6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{NamedTuple{(), Tuple{}}}, Tuple{typeof(iterate), Tuple{Float64, Float64}}, Tuple{Bool, Bool}}, ȳ::Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}, x::Tuple{Float64, Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:73
    [7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
    [8] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:224 [inlined]
    [9] macro expansion
      @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/Test/src/Test.jl:1357 [inlined]
   [10] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
   [11] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
   [12] top-level scope
      @ REPL[1]:1
   [13] eval
      @ ./boot.jl:368 [inlined]
   [14] eval
      @ ./Base.jl:65 [inlined]
   [15] repleval(m::Module, code::Expr, #unused#::String)
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:222
   [16] (::VSCodeServer.var"#107#109"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:186
   [17] with_logstate(f::Function, logstate::Any)
      @ Base.CoreLogging ./logging.jl:511
   [18] with_logger
      @ ./logging.jl:623 [inlined]
   [19] (::VSCodeServer.var"#106#108"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:187
   [20] #invokelatest#2
      @ ./essentials.jl:729 [inlined]
   [21] invokelatest(::Any)
      @ Base ./essentials.jl:726
   [22] macro expansion
      @ ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
   [23] (::VSCodeServer.var"#61#62")()
      @ VSCodeServer ./task.jl:484
Test Summary:                          | Pass  Error  Total  Time
test_rrule: iterate on Float64,Float64 |    3      1      4  0.0s
ERROR: Some tests did not pass: 3 passed, 0 failed, 1 errored, 0 broken.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions