Skip to content

Commit e389e4f

Browse files
authored
Merge branch 'main' into hmc_proposed_sample_handling_refactor
2 parents 71d6dee + bf65463 commit e389e4f

File tree

5 files changed

+14
-7
lines changed

5 files changed

+14
-7
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
with:
6767
fail_ci_if_error: true
6868
token: ${{ secrets.CODECOV_TOKEN }}
69-
file: lcov.info
69+
files: lcov.info
7070
docs:
7171
name: Documentation
7272
runs-on: ubuntu-latest

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BAT"
22
uuid = "c0cd4b16-88b7-57fa-983b-ab80aecada7e"
3-
version = "3.3.1"
3+
version = "4.0.0"
44

55
[deps]
66
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -120,7 +120,7 @@ Folds = "0.2"
120120
ForwardDiff = "0.10"
121121
ForwardDiffPullbacks = "0.1.1, 0.2"
122122
FunctionChains = "0.1.4"
123-
Functors = "0.2, 0.3, 0.4"
123+
Functors = "0.2, 0.3, 0.4, 0.5"
124124
HDF5 = "0.15, 0.16, 0.17"
125125
HeterogeneousComputing = "0.2"
126126
HypothesisTests = "0.10, 0.11"

ext/BATOptimizationExt.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ function test_bat_optimization_ext()
2424
end
2525

2626
AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg)
27-
convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg
27+
Base.convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg
2828

2929
BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead()
3030

3131

3232
function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector)
33-
adm = convert_ad(ADTypes.AbstractADType, adsel)
33+
adm = convert(ADTypes.AbstractADType, reverse_ad_selector(adsel))
3434
optimization_function = Optimization.OptimizationFunction(f, adm)
3535
return optimization_function
3636
end
@@ -59,7 +59,9 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg,
5959
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)
6060

6161
algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol)
62-
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...)
62+
# Not all algorithms support abstol, just filter all NaN-valued opts out:
63+
filtered_algopts = NamedTuple(filter(p -> !isnan(p[2]), pairs(algopts)))
64+
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; filtered_algopts..., algorithm.kwargs...)
6365

6466
transformed_mode = optimization_result.u
6567
result_mode = inv_trafo(transformed_mode)

src/BAT.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ using DomainSets: UnitInterval, UnitCube, Rectangle, FullSpace, RealNumbers
9999

100100
using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent, AbstractThunk, unthunk
101101

102-
using Functors: fmap, @functor
102+
using Functors: fmap
103103

104104
# For Dual specializations:
105105
import ForwardDiff

test/optimization/test_mode_estimators.jl

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using BAT
22
using Test
33

4+
using AutoDiffOperators
45
using LinearAlgebra, Distributions, StatsBase, ValueShapes, Random123, DensityInterface
56
using UnPack, InverseFunctions
67
import ForwardDiff
@@ -101,6 +102,10 @@ using Optim, OptimizationOptimJL
101102
context = BATContext(rng = Philox4x((0, 0)))
102103
# result is not type-stable:
103104
test_findmode(posterior, OptimizationAlg(optalg = OptimizationOptimJL.NelderMead(), pretransform = DoNotTransform()), 0.01, context, inferred = false)
105+
106+
context = BATContext(rng = Philox4x((0, 0)), ad = ADSelector(ForwardDiff))
107+
# result is not type-stable:
108+
test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), pretransform = DoNotTransform()), 0.01, context, inferred = false)
104109
end
105110

106111
@testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl

0 commit comments

Comments
 (0)