Skip to content

Commit d634408

Browse files
committed
Add ForwardDiff-specific methods of second-derivative functions
1 parent 08f0b44 commit d634408

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

ext/AbstractDifferentiationForwardDiffExt.jl

+20
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,40 @@ function AD.hessian(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
6161
return (ForwardDiff.hessian(f, x, cfg),)
6262
end
6363

64+
function AD.value_and_derivative(::AD.ForwardDiffBackend, f, x::Real)
65+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
66+
ydual = f(ForwardDiff.Dual{T}(x, one(x)))
67+
return ForwardDiff.value(T, ydual), (ForwardDiff.extract_derivative(T, ydual),)
68+
end
69+
6470
function AD.value_and_gradient(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
6571
result = DiffResults.GradientResult(x)
6672
cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x))
6773
ForwardDiff.gradient!(result, f, x, cfg)
6874
return DiffResults.value(result), (DiffResults.derivative(result),)
6975
end
7076

77+
function AD.value_and_second_derivative(ba::AD.ForwardDiffBackend, f, x::Real)
78+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
79+
ydual, ddual = AD.value_and_derivative(ba, f, ForwardDiff.Dual{T}(x, one(x)))
80+
return value(T, ydual), (extract_derivative(T, ddual[1]),)
81+
end
82+
7183
function AD.value_and_hessian(ba::AD.ForwardDiffBackend, f, x)
7284
result = DiffResults.HessianResult(x)
7385
cfg = ForwardDiff.HessianConfig(f, result, x, chunk(ba, x))
7486
ForwardDiff.hessian!(result, f, x, cfg)
7587
return DiffResults.value(result), (DiffResults.hessian(result),)
7688
end
7789

90+
function AD.value_and_derivatives(ba::AD.ForwardDiffBackend, f, x::Real)
91+
T = typeof(ForwardDiff.Tag(f, typeof(x)))
92+
ydual, ddual = AD.value_and_derivative(ba, f, ForwardDiff.Dual{T}(x, one(x)))
93+
return ForwardDiff.value(T, ydual),
94+
(ForwardDiff.value(T, ddual[1]),),
95+
(ForwardDiff.extract_derivative(T, ddual[1]),)
96+
end
97+
7898
@inline step_toward(x::Number, v::Number, h) = x + h * v
7999
# support arrays and tuples
80100
@noinline step_toward(x, v, h) = x .+ h .* v

0 commit comments

Comments
 (0)