|
1 | 1 | """ |
2 | | -Recipe functions to easily build commonly used inversions. |
| 2 | +Recipe functions to easily build commonly used inversions and objective functions. |
3 | 3 | """ |
4 | 4 |
|
5 | 5 | from collections.abc import Callable |
6 | 6 |
|
7 | | -from .base import Minimizer, Objective |
| 7 | +import numpy as np |
| 8 | +import numpy.typing as npt |
| 9 | + |
| 10 | +from .base import Combo, Minimizer, Objective |
8 | 11 | from .conditions import ChiTarget, ObjectiveChanged |
9 | 12 | from .data_misfit import DataMisfit |
10 | 13 | from .directives import Irls, MultiplierCooler |
11 | 14 | from .inversion import Inversion |
12 | 15 | from .inversion_log import Column |
13 | 16 | from .preconditioners import JacobiPreconditioner |
| 17 | +from .regularization import Flatness, Smallness |
14 | 18 | from .typing import Model, Preconditioner |
15 | 19 |
|
16 | 20 |
|
@@ -256,3 +260,82 @@ def create_sparse_inversion( |
256 | 260 | ), |
257 | 261 | ) |
258 | 262 | return inversion |
| 263 | + |
| 264 | + |
| 265 | +def create_tikhonov_regularization( |
| 266 | + mesh, |
| 267 | + *, |
| 268 | + active_cells: npt.NDArray[np.bool] | None = None, |
| 269 | + cell_weights: npt.NDArray | dict[str, npt.NDArray] | None = None, |
| 270 | + reference_model: Model | None = None, |
| 271 | + alpha_s: float | None = None, |
| 272 | + alpha_x: float | None = None, |
| 273 | + alpha_y: float | None = None, |
| 274 | + alpha_z: float | None = None, |
| 275 | + reference_model_in_flatness: bool = False, |
| 276 | +) -> Combo: |
| 277 | + """ |
| 278 | + Create a linear combination of Tikhonov (L2) regularization terms. |
| 279 | +
|
| 280 | + Define a :class:`inversion_ideas.base.Combo` with L2 smallness and flatness |
| 281 | + regularization terms. |
| 282 | +
|
| 283 | + Parameters |
| 284 | + ---------- |
| 285 | + mesh : discretize.base.BaseMesh |
| 286 | + Mesh to use in the regularization. |
| 287 | + active_cells : (n_params) array or None, optional |
| 288 | + Array full of bools that indicate the active cells in the mesh. |
| 289 | + cell_weights : (n_params) array or dict of (n_params) arrays or None, optional |
| 290 | + Array with cell weights. |
| 291 | + For multiple cell weights, pass a dictionary where keys are strings and values |
| 292 | + are the different weights arrays. |
| 293 | + If None, no cell weights are going to be used. |
| 294 | + reference_model : (n_params) array or None, optional |
| 295 | + Array with values for the reference model. |
| 296 | + alpha_s : float or None, optional |
| 297 | + Multiplier for the smallness term. |
| 298 | + alpha_x, alpha_y, alpha_z : float or None, optional |
| 299 | + Multipliers for the flatness terms. |
| 300 | +
|
| 301 | + Returns |
| 302 | + ------- |
| 303 | + inversion_ideas.base.Combo |
| 304 | + Combo of L2 regularization terms. |
| 305 | +
|
| 306 | + Notes |
| 307 | + ----- |
| 308 | + TODO |
| 309 | + """ |
| 310 | + # TODO: raise errors: |
| 311 | + # if dims == 2 and alpha_z is passed |
| 312 | + # if dims == 1 and alpha_y or alpha_z are passed |
| 313 | + smallness = Smallness( |
| 314 | + mesh, |
| 315 | + active_cells=active_cells, |
| 316 | + cell_weights=cell_weights, |
| 317 | + reference_model=reference_model, |
| 318 | + ) |
| 319 | + if alpha_s is not None: |
| 320 | + smallness = alpha_s * smallness |
| 321 | + |
| 322 | + kwargs = { |
| 323 | + "active_cells": active_cells, |
| 324 | + "cell_weights": cell_weights, |
| 325 | + } |
| 326 | + if reference_model_in_flatness: |
| 327 | + kwargs["reference_model"] = reference_model |
| 328 | + |
| 329 | + flatness_x = Flatness(mesh, **kwargs, direction="x") |
| 330 | + if alpha_x is not None: |
| 331 | + flatness_x = alpha_x * flatness_x |
| 332 | + |
| 333 | + flatness_y = Flatness(mesh, **kwargs, direction="y") |
| 334 | + if alpha_y is not None: |
| 335 | + flatness_y = alpha_y * flatness_y |
| 336 | + |
| 337 | + flatness_z = Flatness(mesh, **kwargs, direction="z") |
| 338 | + if alpha_z is not None: |
| 339 | + flatness_z = alpha_z * flatness_z |
| 340 | + |
| 341 | + return (smallness + flatness_x + flatness_y + flatness_z).flatten() |
0 commit comments