Skip to content

Commit fb502b2

Browse files
committed
Fix Optimization ext
1 parent 159f24f commit fb502b2

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

ext/BATOptimizationExt.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ function test_bat_optimization_ext()
2424
end
2525

2626
AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg)
27-
convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg
27+
Base.convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg
2828

2929
BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead()
3030

3131

3232
function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector)
33-
adm = convert_ad(ADTypes.AbstractADType, adsel)
33+
adm = convert(ADTypes.AbstractADType, reverse_ad_selector(adsel))
3434
optimization_function = Optimization.OptimizationFunction(f, adm)
3535
return optimization_function
3636
end
@@ -59,7 +59,9 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg,
5959
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)
6060

6161
algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol)
62-
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...)
62+
# Not all algorithms support abstol, just filter all NaN-valued opts out:
63+
filtered_algopts = NamedTuple(filter(p -> !isnan(p[2]), pairs(algopts)))
64+
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; filtered_algopts..., algorithm.kwargs...)
6365

6466
transformed_mode = optimization_result.u
6567
result_mode = inv_trafo(transformed_mode)

test/optimization/test_mode_estimators.jl

+4
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ using Optim, OptimizationOptimJL
101101
context = BATContext(rng = Philox4x((0, 0)))
102102
# result is not type-stable:
103103
test_findmode(posterior, OptimizationAlg(optalg = OptimizationOptimJL.NelderMead(), pretransform = DoNotTransform()), 0.01, context, inferred = false)
104+
105+
context = BATContext(rng = Philox4x((0, 0)), ad = ADSelector(ForwardDiff))
106+
# result is not type-stable:
107+
test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), trafo = DoNotTransform()), 0.01, context, inferred = false)
104108
end
105109

106110
@testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl

0 commit comments

Comments
 (0)