This issue is about the machinery for choosing backends and throwing warnings. I think it could be both optimized and simplified thanks to recent changes in DI.
Nowadays, `DI.inner` and `DI.outer` can also be called on backends that are not `SecondOrder`; they just act as the identity. Thus, you don't need to explicitly create a `SecondOrder(adtype, adtype)`. Passing `adtype` alone will be equivalent in most cases, and faster in some, because it can leverage custom Hessian implementations within a single backend (e.g. `SecondOrder(AutoForwardDiff(), AutoForwardDiff())` cannot call `ForwardDiff.hessian`, whereas `AutoForwardDiff()` can).
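For illustration, here is a minimal sketch comparing the two ways of requesting a Hessian. It assumes a recent DI version in which `DI.inner` accepts any backend (as described above), and uses a toy quadratic `f` chosen only for this example, whose Hessian is `2I`:

```julia
using DifferentiationInterface          # re-exports AutoForwardDiff and SecondOrder
import DifferentiationInterface as DI
import ForwardDiff                      # loading ForwardDiff activates DI's ForwardDiff extension
using LinearAlgebra: I

f(x) = sum(abs2, x)                     # toy objective (illustrative); its Hessian is 2I
x = [1.0, 2.0, 3.0]

# Single backend: DI is free to dispatch to a backend-specific Hessian routine
H_single = hessian(f, AutoForwardDiff(), x)

# Explicit wrapper: forward-over-forward composition of two ForwardDiff backends
H_nested = hessian(f, SecondOrder(AutoForwardDiff(), AutoForwardDiff()), x)

# Assumption: on a plain backend, inner/outer return the backend itself
@show DI.inner(AutoForwardDiff()) == AutoForwardDiff()
@show H_single ≈ H_nested
```

Both calls should give the same matrix; the point is that the first one leaves DI room to choose the faster single-backend path.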
Furthermore, DI's `hvp` and `hessian` for `AutoZygote()` already use ForwardDiff-over-Zygote.
Here are my suggestions:
- Simplify the `generate_adtype` logic and its variants to avoid creating `SecondOrder` objects altogether.
- Throw a warning based on the modes of `DI.inner` and `DI.outer`, e.g. when the inner backend is not a reverse-mode backend. This can be checked with `ADTypes.mode(DI.inner(adtype)) isa Union{ADTypes.ReverseMode,ADTypes.ForwardOrReverseMode}`. Of course you also want to allow ForwardDiff, so feel free to refine.
- Document this behavior so that users are less confused by the warnings (see this Discourse thread).
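The suggested mode check could be sketched roughly as follows. `inner_mode_ok` and `warn_if_slow_hessian` are hypothetical helper names invented for this illustration, not part of any package. The ForwardDiff refinement (accepting the case where both `DI.inner` and `DI.outer` are ForwardDiff) is one possible choice among several, and the sketch again assumes a DI version where `inner`/`outer` accept non-`SecondOrder` backends:

```julia
import ADTypes
import DifferentiationInterface as DI
using ADTypes: AbstractADType, AutoForwardDiff, AutoZygote

# Hypothetical helper: is this backend a sensible choice for second-order derivatives?
function inner_mode_ok(adtype::AbstractADType)
    inner = DI.inner(adtype)   # identity for backends that are not SecondOrder
    outer = DI.outer(adtype)
    # Refinement (one possible choice): accept pure ForwardDiff, i.e. forward-over-forward
    inner isa AutoForwardDiff && outer isa AutoForwardDiff && return true
    # Otherwise require a reverse-capable inner backend, as suggested above
    return ADTypes.mode(inner) isa Union{ADTypes.ReverseMode,ADTypes.ForwardOrReverseMode}
end

# Hypothetical helper: emit the warning only when the check fails
function warn_if_slow_hessian(adtype::AbstractADType)
    if !inner_mode_ok(adtype)
        @warn "Inner backend $(DI.inner(adtype)) is neither reverse mode nor pure ForwardDiff; second-order derivatives may be slow."
    end
    return nothing
end
```

With this check, `AutoZygote()`, `AutoForwardDiff()` and `SecondOrder(AutoForwardDiff(), AutoZygote())` pass silently, while `SecondOrder(AutoZygote(), AutoForwardDiff())` (forward-mode inner) would trigger the warning.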
What do you think @Vaibhavdixit02?