Open
Description
It looks like the finite-difference implementation has a hard time going through `iterate`
(see the MWE and full stacktrace below):
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch: second dimension of A, 2, does not match length of x, 1
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
...
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
...
Below I provide an `rrule` implementation for `iterate` on tuples for convenience, but perhaps the example can be narrowed down to a direct invocation of `_make_j′vp_call()`. I also see the same error when testing with arrays.
MWE
using ChainRulesCore
import ChainRulesCore.rrule
using ChainRulesTestUtils
"""
    ungetfield(dy, s::Tuple, f::Integer)

Build the cotangent of the tuple `s` from the cotangent `dy` of its `f`-th field.

Return a `Tangent{typeof(s)}` whose `f`-th entry is `dy` and whose other entries are
`ZeroTangent()`. If `f` is not an index of `s`, every entry is `ZeroTangent()`.
"""
function ungetfield(dy, s::Tuple, f::Integer)
    T = typeof(s)
    # `ntuple` produces the fields as a Tuple directly, instead of splatting an
    # untyped intermediate Vector built by a comprehension.
    fields = ntuple(i -> i == f ? dy : ZeroTangent(), length(s))
    return Tangent{T}(fields...)
end
"""
    rrule(::typeof(iterate), t::Tuple)

Reverse rule for the first step of iterating a tuple, `iterate(t)`.

The pullback unthunks the output cotangent and routes its first entry (the cotangent
of the yielded element) to field 1 of `t`. The cotangent of the returned state
(second output entry) is ignored. `iterate` itself gets `NoTangent()`.
"""
function rrule(::typeof(iterate), t::Tuple)
    iterate_pullback(dy) = (NoTangent(), ungetfield(unthunk(dy)[1], t, 1))
    return iterate(t), iterate_pullback
end
"""
    rrule(::typeof(iterate), t::Tuple, i::Integer)

Reverse rule for a subsequent step of iterating a tuple, `iterate(t, i)`.

The pullback unthunks the output cotangent and routes its first entry (the cotangent
of the yielded element) to field `i` of `t`. The integer state `i` is not
differentiable, so its cotangent is `NoTangent()` rather than `ZeroTangent()`.
"""
function rrule(::typeof(iterate), t::Tuple, i::Integer)
    y = iterate(t, i)
    function iterate_pullback(dy)
        Δ = unthunk(dy)
        # Integer iteration state has no tangent space: use NoTangent, not ZeroTangent.
        return NoTangent(), ungetfield(Δ[1], t, i), NoTangent()
    end
    return y, iterate_pullback
end
# Exercise the one-argument `iterate` rule on a 2-tuple of floats, with
# ChainRulesTestUtils' type-inference check switched off via `check_inferred`.
test_rrule(iterate, (3.0, 5.0); check_inferred = false)
Complete stacktrace
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch: second dimension of A, 2, does not match length of x, 1
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
[2] mul!
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
[3] mul!
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
[4] *(tA::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{NamedTuple{(), Tuple{}}}, Tuple{typeof(iterate), Tuple{Float64, Float64}}, Tuple{Bool, Bool}}, ȳ::Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}, x::Tuple{Float64, Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:73
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
[8] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:224 [inlined]
[9] macro expansion
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/Test/src/Test.jl:1357 [inlined]
[10] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
[11] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
[12] top-level scope
@ REPL[1]:1
[13] eval
@ ./boot.jl:368 [inlined]
[14] eval
@ ./Base.jl:65 [inlined]
[15] repleval(m::Module, code::Expr, #unused#::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:222
[16] (::VSCodeServer.var"#107#109"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:186
[17] with_logstate(f::Function, logstate::Any)
@ Base.CoreLogging ./logging.jl:511
[18] with_logger
@ ./logging.jl:623 [inlined]
[19] (::VSCodeServer.var"#106#108"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:187
[20] #invokelatest#2
@ ./essentials.jl:729 [inlined]
[21] invokelatest(::Any)
@ Base ./essentials.jl:726
[22] macro expansion
@ ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
[23] (::VSCodeServer.var"#61#62")()
@ VSCodeServer ./task.jl:484
Test Summary: | Pass Error Total Time
test_rrule: iterate on Float64,Float64 | 3 1 4 0.0s
ERROR: Some tests did not pass: 3 passed, 0 failed, 1 errored, 0 broken.