diff --git a/.copier-answers.yml b/.copier-answers.yml
index b2bdcfa2..eeff466a 100644
--- a/.copier-answers.yml
+++ b/.copier-answers.yml
@@ -1,6 +1,6 @@
# WARNING: Do not edit this file manually.
# Any changes will be overwritten by Copier.
-_commit: v0.0.5
+_commit: v0.4.2
_src_path: gh:easyscience/templates
app_docs_url: https://easyscience.github.io/dynamics-app
app_doi: 10.5281/zenodo.18163581
diff --git a/.github/actions/publish-to-pypi/action.yml b/.github/actions/publish-to-pypi/action.yml
index 522e3a02..719928d9 100644
--- a/.github/actions/publish-to-pypi/action.yml
+++ b/.github/actions/publish-to-pypi/action.yml
@@ -1,13 +1,14 @@
name: 'Publish to PyPI'
-description: 'Publish a built distribution to PyPI using pypa/gh-action-pypi-publish'
+description: 'Publish dist/ to PyPI via Trusted Publishing (OIDC)'
inputs:
- password:
- description: 'PyPI API token (or password) for authentication'
- required: true
+ packages_dir:
+ description: 'Directory containing the built packages to upload'
+ required: false
+ default: 'dist'
runs:
using: 'composite'
steps:
- uses: pypa/gh-action-pypi-publish@release/v1
with:
- password: ${{ inputs.password }}
+ packages-dir: ${{ inputs.packages_dir }}
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 05fef7db..4056c1c0 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -107,11 +107,8 @@ jobs:
- name: Pre-build site step
run: pixi run python -c "import easydynamics"
- # Convert Python scripts in the docs/docs/tutorials/ directory to Jupyter
- # notebooks.
- # This step also strips any existing output from the notebooks and
- # prepares them for documentation.
- - name: Convert tutorial scripts to notebooks
+ # Prepare the Jupyter notebooks for documentation (strip output, etc.).
+ - name: Prepare notebooks
run: pixi run notebook-prepare
# Execute all Jupyter notebooks to generate output cells (plots, tables, etc.).
diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml
index 15a2c6ed..6e48e610 100644
--- a/.github/workflows/pypi-publish.yml
+++ b/.github/workflows/pypi-publish.yml
@@ -14,6 +14,10 @@ jobs:
pypi-publish:
runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
+
steps:
- name: Check-out repository
uses: actions/checkout@v5
@@ -23,10 +27,18 @@ jobs:
- name: Set up pixi
uses: ./.github/actions/setup-pixi
+ # Build the Python package (to dist/ folder)
- name: Create Python package
run: pixi run default-build
+ # Publish the package to PyPI (from dist/ folder)
+ # Instead of publishing with personal access token, we use
+ # GitHub Actions OIDC to get a short-lived token from PyPI.
+ # New publisher must be previously configured in PyPI at
+ # https://pypi.org/manage/project/easydynamics/settings/publishing/
+ # Use the following data:
+ # Owner: easyscience
+ # Repository name: dynamics-lib
+ # Workflow name: pypi-publish.yml
- name: Publish to PyPI
uses: ./.github/actions/publish-to-pypi
- with:
- password: ${{ secrets.PYPI_PASSWORD }}
diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml
index 1397f485..201dace4 100644
--- a/.github/workflows/quality.yml
+++ b/.github/workflows/quality.yml
@@ -79,6 +79,14 @@ jobs:
continue-on-error: true
shell: bash
run: pixi run nonpy-format-check
+ # Check formatting of Jupyter Notebooks in the tutorials folder
+ - name: Prepare notebooks and check formatting
+ id: check_notebooks_formatting
+ continue-on-error: true
+ shell: bash
+ run: |
+ pixi run notebook-prepare
+ pixi run notebook-format-check
# Add summary
- name: Add quality checks summary
diff --git a/.github/workflows/tutorial-tests.yml b/.github/workflows/tutorial-tests.yml
index a3454fe7..55998847 100644
--- a/.github/workflows/tutorial-tests.yml
+++ b/.github/workflows/tutorial-tests.yml
@@ -46,7 +46,7 @@ jobs:
shell: bash
run: pixi run script-tests
- - name: Convert tutorial scripts to notebooks
+ - name: Prepare notebooks
shell: bash
run: pixi run notebook-prepare
diff --git a/.gitignore b/.gitignore
index 7e0f2da3..f7ce4ac2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,11 @@ __pycache__/
.venv/
.coverage
+# PyInstaller
+dist/
+build/
+*.spec
+
# MkDocs
docs/site/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c3d471cd..007d2389 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,57 +1,54 @@
repos:
- repo: local
hooks:
- # -----------------
- # Pre-commit checks
- # -----------------
+ # -------------
+ # Manual checks
+ # -------------
- id: pixi-pyproject-check
name: pixi run pyproject-check
entry: pixi run pyproject-check
language: system
pass_filenames: false
- stages: [pre-commit]
+ stages: [manual]
- - id: pixi-py-lint-check-staged
- name: pixi run py-lint-check-staged
- entry: pixi run py-lint-check-pre
+ - id: pixi-py-lint-check
+ name: pixi run py-lint-check
+ entry: pixi run py-lint-check
language: system
pass_filenames: false
- stages: [pre-commit]
+ stages: [manual]
- - id: pixi-py-format-check-staged
- name: pixi run py-format-check-staged
- entry: pixi run py-format-check-pre
+ - id: pixi-py-format-check
+ name: pixi run py-format-check
+ entry: pixi run py-format-check
language: system
pass_filenames: false
- stages: [pre-commit]
+ stages: [manual]
- - id: pixi-nonpy-format-check-modified
- name: pixi run nonpy-format-check-modified
- entry: pixi run nonpy-format-check-modified
+ - id: pixi-nonpy-format-check
+ name: pixi run nonpy-format-check
+ entry: pixi run nonpy-format-check
language: system
pass_filenames: false
- stages: [pre-commit]
+ stages: [manual]
- id: pixi-docs-format-check
name: pixi run docs-format-check
entry: pixi run docs-format-check
language: system
pass_filenames: false
- stages: [pre-commit]
+ stages: [manual]
- # ----------------
- # Pre-push checks
- # ----------------
- - id: pixi-nonpy-format-check
- name: pixi run nonpy-format-check
- entry: pixi run nonpy-format-check
+ - id: pixi-notebook-format-check
+ name: pixi run notebook-format-check
+ entry: pixi run notebook-format-check
language: system
pass_filenames: false
- stages: [pre-push]
+ stages: [manual]
- id: pixi-unit-tests
name: pixi run unit-tests
entry: pixi run unit-tests
language: system
pass_filenames: false
- stages: [pre-push]
+ stages: [manual]
diff --git a/README.md b/README.md
index 2f55c56f..373d3828 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,18 @@
-
+
-
+
-
+
-**EasyDynamics** is a scientific software for plotting and fitting qens
-and ins powder data.
+**EasyDynamics** is a scientific software for plotting and fitting QENS
+and INS powder data.
+
+
**EasyDynamics** is available both as a Python library and as a
cross-platform desktop application.
diff --git a/docs/docs/assets/stylesheets/extra.css b/docs/docs/assets/stylesheets/extra.css
index 1c199950..a625be80 100644
--- a/docs/docs/assets/stylesheets/extra.css
+++ b/docs/docs/assets/stylesheets/extra.css
@@ -222,9 +222,27 @@ Adjust the margins and paddings to fit the defaults in MkDocs Material and do no
width: 100% !important;
display: flex !important;
}
+
.jp-Notebook {
padding: 0 !important;
margin-top: -3em !important;
+
+ /* Ensure notebook content stretches across the page */
+ width: 100% !important;
+ max-width: 100% !important;
+
+ /* mkdocs-material + some notebook HTML end up as flex */
+ align-items: stretch !important;
+}
+
+.jp-Notebook .jp-Cell {
+ /* Key: flex children often need min-width: 0 to prevent weird shrink */
+ width: 100% !important;
+ max-width: 100% !important;
+ min-width: 0 !important;
+
+ /* Removes jupyter cell paddings */
+ padding-left: 0 !important;
}
/* Removes jupyter cell prefixes, like In[123]: */
@@ -234,11 +252,6 @@ Adjust the margins and paddings to fit the defaults in MkDocs Material and do no
display: none !important;
}
-/* Removes jupyter cell paddings */
-.jp-Cell {
- padding-left: 0 !important;
-}
-
/* Removes jupyter output cell padding to align with input cell text */
.jp-RenderedText {
padding-left: 0.85em !important;
diff --git a/docs/docs/installation-and-setup/index.md b/docs/docs/installation-and-setup/index.md
index 3513f6e9..420ef07e 100644
--- a/docs/docs/installation-and-setup/index.md
+++ b/docs/docs/installation-and-setup/index.md
@@ -8,8 +8,8 @@ icon: material/cog-box
**Python 3.11** through **3.12**.
To install and set up EasyDynamics, we recommend using
-[**Pixi**](https://prefix.dev), a modern package manager for Windows,
-macOS, and Linux.
+[**Pixi**](https://pixi.prefix.dev), a modern package manager for
+Windows, macOS, and Linux.
!!! note "Main benefits of using Pixi"
@@ -46,16 +46,9 @@ This section describes the simplest way to set up EasyDynamics using
```txt
pixi add python=3.12
```
-- Add the GNU Scientific Library (GSL) dependency:
+- Add EasyDynamics to the Pixi environment from PyPI:
```txt
- pixi add gsl
- ```
-- Add EasyDynamics with the `visualization` extras, which include
- optional dependencies used for simplified visualization of charts and
- tables. This can be especially useful for running the Jupyter Notebook
- examples:
- ```txt
- pixi add --pypi "easydynamics[visualization]"
+ pixi add --pypi easydynamics
```
- Add a Pixi task to run EasyDynamics commands easily:
```txt
@@ -160,20 +153,7 @@ simply delete and recreate the environment.
### Installing from PyPI { #from-pypi }
EasyDynamics is available on **PyPI (Python Package Index)** and can be
-installed using `pip`.
-
-We recommend installing the latest release of EasyDynamics with the
-`visualization` extras, which include optional dependencies used for
-simplified visualization of charts and tables. This can be especially
-useful for running the Jupyter Notebook examples. To do so, use the
-following command:
-
-```txt
-pip install 'easydynamics[visualization]'
-```
-
-If only the core functionality is needed, the library can be installed
-simply with:
+installed using `pip`. To do so, use the following command:
```txt
pip install easydynamics
@@ -216,10 +196,10 @@ example:
pip install git+https://github.com/easyscience/dynamics-lib@develop
```
-To include extra dependencies (e.g., visualization):
+To include extra dependencies (e.g., dev):
```txt
-pip install 'easydynamics[visualization] @ git+https://github.com/easyscience/dynamics-lib@develop'
+pip install 'easydynamics[dev] @ git+https://github.com/easyscience/dynamics-lib@develop'
```
## How to Run Tutorials
diff --git a/docs/docs/introduction/index.md b/docs/docs/introduction/index.md
index 58240514..740d4b0d 100644
--- a/docs/docs/introduction/index.md
+++ b/docs/docs/introduction/index.md
@@ -6,8 +6,8 @@ icon: material/information-slab-circle
## Description
-**EasyDynamics** is a scientific software for plotting and fitting qens
-and ins powder data.
+**EasyDynamics** is a scientific software for plotting and fitting QENS
+and INS powder data.
**EasyDynamics** is available both as a Python library and as a
cross-platform desktop application.
diff --git a/docs/docs/tutorials/analysis old parameter bug.ipynb b/docs/docs/tutorials/analysis old parameter bug.ipynb
new file mode 100644
index 00000000..85bddaaa
--- /dev/null
+++ b/docs/docs/tutorials/analysis old parameter bug.ipynb
@@ -0,0 +1,384 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "8643b10c",
+ "metadata": {},
+ "source": [
+    "# Scratch notebook: reproduces the old-parameter uniqueness bug across Q indices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bca91d3c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from easydynamics.analysis.analysis1d import Analysis1d\n",
+ "from easydynamics.experiment import Experiment\n",
+ "from easydynamics.sample_model import ComponentCollection\n",
+ "from easydynamics.sample_model import DeltaFunction\n",
+ "from easydynamics.sample_model import Gaussian\n",
+ "from easydynamics.sample_model import Polynomial\n",
+ "from easydynamics.sample_model.background_model import BackgroundModel\n",
+ "from easydynamics.sample_model.resolution_model import ResolutionModel\n",
+ "from easydynamics.sample_model.sample_model import SampleModel\n",
+ "from easydynamics.sample_model.instrument_model import InstrumentModel\n",
+ "from easydynamics.analysis.analysis import Analysis\n",
+ "%matplotlib widget"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8deca9b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vanadium_experiment = Experiment('Vanadium')\n",
+ "vanadium_experiment.load_hdf5(filename='vanadium_data_example.h5')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "41f842f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# # Create a diffusion_model and components for the SampleModel\n",
+ "\n",
+ "# # Creating components\n",
+ "# component_collection = ComponentCollection()\n",
+ "# delta_function = DeltaFunction(display_name='DeltaFunction', area=1)\n",
+ "# gaussian = Gaussian(display_name='Gaussian', width=0.1, center=0.5, area=0.5)\n",
+ "\n",
+ "# # Adding components to the component collection\n",
+ "# component_collection.append_component(delta_function)\n",
+ "\n",
+ "\n",
+ "# sample_model = SampleModel(\n",
+ "# components=component_collection,\n",
+ "# unit='meV',\n",
+ "# display_name='MySampleModel',\n",
+ "# )\n",
+ "\n",
+ "# res_gauss = Gaussian(width=0.1)\n",
+ "# res_gauss.area.fixed = True\n",
+ "# resolution_model = ResolutionModel(components=res_gauss)\n",
+ "\n",
+ "\n",
+ "# background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n",
+ "\n",
+ "# instrument_model = InstrumentModel(\n",
+ "# resolution_model=resolution_model,\n",
+ "# background_model=background_model,\n",
+ "# )\n",
+ "\n",
+ "# my_analysis = Analysis1d(\n",
+ "# experiment=vanadium_experiment,\n",
+ "# sample_model=sample_model,\n",
+ "# instrument_model=instrument_model,\n",
+ "# Q_index=5,\n",
+ "# )\n",
+ "\n",
+ "\n",
+ "# values = my_analysis.calculate()\n",
+ "# sample_values, background_values = my_analysis.calculate_individual_components()\n",
+ "\n",
+ "# plt.figure()\n",
+ "# plt.plot(my_analysis.energy.values, values, label='Total Model')\n",
+ "# for component_index in range(len(sample_values)):\n",
+ "# plt.plot(\n",
+ "# my_analysis.energy.values,\n",
+ "# sample_values[component_index],\n",
+ "# label=f'Sample Component {component_index}',\n",
+ "# linestyle='--',\n",
+ "# )\n",
+ "\n",
+ "# for component_index in range(len(background_values)):\n",
+ "# plt.plot(\n",
+ "# my_analysis.energy.values,\n",
+ "# background_values[component_index],\n",
+ "# label=f'Background Component {component_index}',\n",
+ "# linestyle=':',\n",
+ "# )\n",
+ "# plt.xlabel('Energy (meV)')\n",
+ "# plt.ylabel('Intensity')\n",
+ "# plt.title(f'Q index: {5}')\n",
+ "# plt.legend()\n",
+ "# plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6762faba",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02702f95",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# my_analysis.plot_data_and_model()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "70091539",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# my_analysis.fit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2ad6384e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# my_analysis.plot_data_and_model()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2dfb1f90",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# my_analysis.get_all_variables()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5afefbab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# my_analysis.get_fit_parameters()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "465c0e1e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# for Q_index in range(len(my_analysis.Q)):\n",
+ "# my_analysis.Q_index = Q_index\n",
+ "# my_analysis.fit()\n",
+ "# my_analysis.plot_data_and_model()\n",
+ "# print(my_analysis.get_fit_parameters())\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9bdeed2b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a diffusion_model and components for the SampleModel\n",
+ "\n",
+ "# Creating components\n",
+ "component_collection = ComponentCollection()\n",
+ "delta_function = DeltaFunction(display_name='DeltaFunction', area=1)\n",
+ "gaussian = Gaussian(display_name='Gaussian', width=0.1, center=0.5, area=0.5)\n",
+ "\n",
+ "# Adding components to the component collection\n",
+ "component_collection.append_component(delta_function)\n",
+ "\n",
+ "\n",
+ "sample_model = SampleModel(\n",
+ " components=component_collection,\n",
+ " unit='meV',\n",
+ " display_name='MySampleModel',\n",
+ ")\n",
+ "\n",
+ "res_gauss = Gaussian(width=0.1)\n",
+ "res_gauss.area.fixed = True\n",
+ "resolution_model = ResolutionModel(components=res_gauss)\n",
+ "\n",
+ "\n",
+ "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n",
+ "\n",
+ "instrument_model = InstrumentModel(\n",
+ " resolution_model=resolution_model,\n",
+ " background_model=background_model,\n",
+ ")\n",
+ "\n",
+ "my_full_analysis = Analysis(\n",
+ " experiment=vanadium_experiment,\n",
+ " sample_model=sample_model,\n",
+ " instrument_model=instrument_model,\n",
+ ")\n",
+ "\n",
+ "# my_full_analysis._fit_all_Q_independently()\n",
+ "my_full_analysis._fit_all_Q_simultaneously()\n",
+ "for analysis_object in my_full_analysis._analysis_list:\n",
+ " analysis_object.plot_data_and_model()\n",
+ " print(analysis_object.get_fit_parameters())\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0a727fc3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for analysis_object in my_full_analysis._analysis_list:\n",
+ " print(analysis_object.get_fit_parameters())\n",
+ "\n",
+ "for analysis_object in my_full_analysis._analysis_list:\n",
+ " print(analysis_object.get_fit_parameters()[0].unique_name)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d0ceec1d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "p1=my_full_analysis._analysis_list[1].get_fit_parameters()[0]\n",
+ "print(p1)\n",
+ "print(p1.unique_name)\n",
+ "p2 = my_full_analysis._analysis_list[9].get_fit_parameters()[0]\n",
+ "print(p2)\n",
+ "print(p2.unique_name)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d792eee3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "my_full_analysis.Q"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4217d56d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from easydynamics.sample_model import ComponentCollection\n",
+ "from easydynamics.sample_model import DeltaFunction\n",
+ "from easydynamics.sample_model.model_base import ModelBase\n",
+ "%matplotlib widget\n",
+ "import numpy as np\n",
+ "Q=np.linspace(0.1,15,31)\n",
+ "component_collection = ComponentCollection()\n",
+ "delta_function = DeltaFunction(display_name='DeltaFunction', area=1)\n",
+ "\n",
+ "component_collection.append_component(delta_function)\n",
+ "\n",
+ "\n",
+ "# sample_model = SampleModel(\n",
+ "sample_model = ModelBase(\n",
+ " components=component_collection,\n",
+ " unit='meV',\n",
+ " display_name='MySampleModel',\n",
+ " Q=Q,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "for Q_index in range(len(sample_model.Q)):\n",
+ " pars = sample_model.get_all_variables(Q_index=Q_index) \n",
+ " pars[0].value=pars[0].value+Q_index\n",
+ "\n",
+ "for Q_index in range(len(sample_model.Q)):\n",
+ " pars = sample_model.get_all_variables(Q_index=Q_index)\n",
+ " print(pars[0].unique_name)\n",
+ " print(pars[0])\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "35c89ce3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vars2=sample_model._component_collections[1].get_all_variables()\n",
+ "for var in vars2:\n",
+ " print(var)\n",
+ " print(var.unique_name)\n",
+ "\n",
+ "var3=sample_model._component_collections[10].get_all_variables()\n",
+ "for var in var3:\n",
+ " print(var)\n",
+ " print(var.unique_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02320e75",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a=vanadium_experiment.binned_data.coords['energy']\n",
+ "a"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5451bbf3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import scipp as sc\n",
+ "x_pixel_range = [-10, -5, 0, 5, 10]\n",
+ "a,b=sc.array(values=x_pixel_range, dims='x')\n",
+ "print(a)\n",
+ "print(b)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "easydynamics_newbase",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/docs/tutorials/analysis.ipynb b/docs/docs/tutorials/analysis.ipynb
new file mode 100644
index 00000000..27b0cdb6
--- /dev/null
+++ b/docs/docs/tutorials/analysis.ipynb
@@ -0,0 +1,259 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "8643b10c",
+   "metadata": {},
+   "source": [
+    "# Analysis\n",
+    "It is time to analyse some data. We here show how to set up an Analysis object and use it to first fit an artificial vanadium measurements, and next an artificial measurement of a model with diffusion and some elastic scattering.\n",
+    "\n",
+    "In the near future, it will be possible to fit the width and area of the Lorentzian to the diffusion model, as well as fitting the diffusion model directly to the data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bca91d3c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from easydynamics.analysis.analysis1d import Analysis1d\n",
+    "from easydynamics.experiment import Experiment\n",
+    "from easydynamics.sample_model import ComponentCollection\n",
+    "from easydynamics.sample_model import DeltaFunction\n",
+    "from easydynamics.sample_model import Lorentzian\n",
+    "from easydynamics.sample_model import Gaussian\n",
+    "from easydynamics.sample_model import Polynomial\n",
+    "from easydynamics.sample_model.background_model import BackgroundModel\n",
+    "from easydynamics.sample_model.resolution_model import ResolutionModel\n",
+    "from easydynamics.sample_model.sample_model import SampleModel\n",
+    "from easydynamics.sample_model.instrument_model import InstrumentModel\n",
+    "from easydynamics.analysis.analysis import Analysis\n",
+    "from copy import copy\n",
+    "%matplotlib widget"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8deca9b6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vanadium_experiment = Experiment('Vanadium')\n",
+    "vanadium_experiment.load_hdf5(filename='vanadium_data_example.h5')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6762faba",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Example of Analysis with a simple sample model and instrument model\n",
+    "delta_function = DeltaFunction(display_name='DeltaFunction', area=1)\n",
+    "sample_model = SampleModel(\n",
+    "    components=delta_function,\n",
+    ")\n",
+    "\n",
+    "res_gauss = Gaussian(width=0.1)\n",
+    "res_gauss.area.fixed=True\n",
+    "resolution_model = ResolutionModel(components=res_gauss)\n",
+    "\n",
+    "\n",
+    "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n",
+    "\n",
+    "instrument_model = InstrumentModel(\n",
+    "    resolution_model=resolution_model,\n",
+    "    background_model=background_model,\n",
+    ")\n",
+    "\n",
+    "vanadium_analysis = Analysis(\n",
+    "    display_name='Vanadium Full Analysis',\n",
+    "    experiment=vanadium_experiment,\n",
+    "    sample_model=sample_model,\n",
+    "    instrument_model=instrument_model,\n",
+    ")\n",
+    "\n",
+    "fit_result_independent_single_Q = vanadium_analysis.fit(fit_method=\"independent\", Q_index=5)\n",
+    "vanadium_analysis.plot_data_and_model(Q_index=5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e98e3d65",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fit_result_independent_all_Q = vanadium_analysis.fit(fit_method=\"independent\")\n",
+    "vanadium_analysis.plot_data_and_model()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "af13afce",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fit_result_simultaneous = vanadium_analysis.fit(fit_method=\"simultaneous\")\n",
+    "fit_result_simultaneous\n",
+    "vanadium_analysis.plot_data_and_model()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "133e682e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Inspect the Parameters as a scipp Dataset\n",
+    "vanadium_analysis.parameters_to_dataset()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dfacdf24",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot some of fitted parameters as a function of Q\n",
+    "vanadium_analysis.plot_parameters(names=[\"DeltaFunction area\"])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b6f9f316",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vanadium_analysis.plot_parameters(names=[\"Gaussian width\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3609e6c1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set up the diffusion analysis with the same resolution model as the\n",
+    "# vanadium analysis\n",
+    "diffusion_experiment = Experiment('Diffusion')\n",
+    "diffusion_experiment.load_hdf5(filename='diffusion_data_example.h5')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e685909a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# We set up the model first.\n",
+    "delta_function = DeltaFunction(display_name='DeltaFunction', area=0.2)\n",
+    "lorentzian = Lorentzian(display_name='Lorentzian', area=0.5, width=0.3)\n",
+    "component_collection=ComponentCollection(\n",
+    "    components=[delta_function, lorentzian],\n",
+    ")\n",
+    "sample_model = SampleModel(\n",
+    "    components=component_collection,\n",
+    ")\n",
+    "\n",
+    "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n",
+    "\n",
+    "instrument_model = InstrumentModel(\n",
+    "    background_model=background_model,\n",
+    ")\n",
+    "\n",
+    "diffusion_analysis = Analysis(\n",
+    "    display_name='Diffusion Full Analysis',\n",
+    "    experiment=diffusion_experiment,\n",
+    "    sample_model=sample_model,\n",
+    "    instrument_model=instrument_model,\n",
+    ")\n",
+    "\n",
+    "# We need to hack in the resolution model from the vanadium analysis,\n",
+    "# since the setters and getters overwrite the model. This will be fixed\n",
+    "# asap.\n",
+    "diffusion_analysis.instrument_model._resolution_model = vanadium_analysis.instrument_model.resolution_model\n",
+    "diffusion_analysis.instrument_model.resolution_model.fix_all_parameters()\n",
+    "diffusion_analysis.plot_parameters(names=[\"Gaussian width\"])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c66828eb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Let us see how good the starting parameters are\n",
+    "diffusion_analysis.plot_data_and_model()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "197b44c5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Now we fit the data and plot the result. Looks good!\n",
+    "diffusion_analysis.fit(fit_method=\"independent\")\n",
+    "diffusion_analysis.plot_data_and_model()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "df14b5c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Let us look at the most interesting fit parameters\n",
+    "diffusion_analysis.plot_parameters(names=[\"Lorentzian width\", \"Lorentzian area\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eb226c8f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# It will be possible to fit this to a DiffusionModel, but that will\n",
+    "# come later."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "easydynamics_newbase",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/docs/tutorials/analysis1d.ipynb b/docs/docs/tutorials/analysis1d.ipynb
new file mode 100644
index 00000000..8a695913
--- /dev/null
+++ b/docs/docs/tutorials/analysis1d.ipynb
@@ -0,0 +1,104 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "8643b10c",
+ "metadata": {},
+ "source": [
+ "# Analysis1d\n",
+ "Sometimes, you will only be interested in a particular Q, not the full dataset. For this, use the Analysis1d object. Here we show how to set it up to fit an artificial vanadium measurement."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bca91d3c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from easydynamics.analysis.analysis1d import Analysis1d\n",
+ "from easydynamics.experiment import Experiment\n",
+ "from easydynamics.sample_model import ComponentCollection\n",
+ "from easydynamics.sample_model import DeltaFunction\n",
+ "from easydynamics.sample_model import Gaussian\n",
+ "from easydynamics.sample_model import Polynomial\n",
+ "from easydynamics.sample_model.background_model import BackgroundModel\n",
+ "from easydynamics.sample_model.resolution_model import ResolutionModel\n",
+ "from easydynamics.sample_model.sample_model import SampleModel\n",
+ "from easydynamics.sample_model.instrument_model import InstrumentModel\n",
+ "from easydynamics.analysis.analysis import Analysis\n",
+ "%matplotlib widget"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8deca9b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vanadium_experiment = Experiment('Vanadium')\n",
+ "vanadium_experiment.load_hdf5(filename='vanadium_data_example.h5')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "41f842f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example of Analysis1d with a simple sample model and instrument model\n",
+ "delta_function = DeltaFunction(display_name='DeltaFunction', area=1)\n",
+ "sample_model = SampleModel(\n",
+ " components=delta_function,\n",
+ ")\n",
+ "\n",
+ "res_gauss = Gaussian(width=0.1)\n",
+ "resolution_model = ResolutionModel(components=res_gauss)\n",
+ "\n",
+ "\n",
+ "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n",
+ "\n",
+ "instrument_model = InstrumentModel(\n",
+ " resolution_model=resolution_model,\n",
+ " background_model=background_model,\n",
+ ")\n",
+ "\n",
+ "my_analysis = Analysis1d(\n",
+ " display_name='Vanadium Analysis',\n",
+ " experiment=vanadium_experiment,\n",
+ " sample_model=sample_model,\n",
+ " instrument_model=instrument_model,\n",
+ " Q_index=5,\n",
+ ")\n",
+ "\n",
+ "fit_result = my_analysis.fit()\n",
+ "my_analysis.plot_data_and_model()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "easydynamics_newbase",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/docs/tutorials/components.ipynb b/docs/docs/tutorials/components.ipynb
index 7815bcd4..83278fc4 100644
--- a/docs/docs/tutorials/components.ipynb
+++ b/docs/docs/tutorials/components.ipynb
@@ -21,6 +21,7 @@
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
+ "import scipp as sc\n",
"\n",
"from easydynamics.sample_model import DampedHarmonicOscillator\n",
"from easydynamics.sample_model import DeltaFunction\n",
@@ -105,8 +106,6 @@
"metadata": {},
"outputs": [],
"source": [
- "import scipp as sc\n",
- "\n",
"x1 = sc.linspace(dim='x', start=-2.0, stop=2.0, num=100, unit='meV')\n",
"x2 = sc.linspace(dim='x', start=-2.0 * 1e3, stop=2.0 * 1e3, num=101, unit='microeV')\n",
"\n",
diff --git a/docs/docs/tutorials/convolution.ipynb b/docs/docs/tutorials/convolution.ipynb
index 922970f9..b13d7973 100644
--- a/docs/docs/tutorials/convolution.ipynb
+++ b/docs/docs/tutorials/convolution.ipynb
@@ -109,7 +109,7 @@
"\n",
"\n",
"temperature = 10.0 # Temperature in Kelvin\n",
- "offset = 0.5\n",
+ "energy_offset = 0.5\n",
"upsample_factor = 5\n",
"extension_factor = 0.5\n",
"plt.figure()\n",
@@ -119,7 +119,7 @@
"convolver = Convolution(\n",
" sample_components=sample_components,\n",
" resolution_components=resolution_components,\n",
- " energy=energy - offset,\n",
+ " energy=energy - energy_offset,\n",
" upsample_factor=upsample_factor,\n",
" extension_factor=extension_factor,\n",
" temperature=temperature,\n",
@@ -132,8 +132,8 @@
"\n",
"plt.plot(\n",
" energy,\n",
- " sample_components.evaluate(energy - offset)\n",
- " * detailed_balance_factor(energy - offset, temperature),\n",
+ " sample_components.evaluate(energy - energy_offset)\n",
+ " * detailed_balance_factor(energy - energy_offset, temperature),\n",
" label='Sample Model with DB',\n",
" linestyle='--',\n",
")\n",
@@ -145,6 +145,70 @@
"plt.ylim(0, 2.5)\n",
"plt.show()"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c318f9b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use some of the extra settings for the numerical convolution\n",
+ "sample_components = ComponentCollection()\n",
+ "gaussian = Gaussian(display_name='Gaussian', width=0.3, area=1)\n",
+ "dho = DampedHarmonicOscillator(display_name='DHO', center=1.0, width=0.3, area=2.0)\n",
+ "lorentzian = Lorentzian(display_name='Lorentzian', center=-1.0, width=0.2, area=1.0)\n",
+ "delta = DeltaFunction(display_name='Delta', center=0.4, area=0.5)\n",
+ "sample_components.append_component(gaussian)\n",
+ "# sample_components.append_component(dho)\n",
+ "sample_components.append_component(lorentzian)\n",
+ "# sample_components.append_component(delta)\n",
+ "\n",
+ "resolution_components = ComponentCollection()\n",
+ "resolution_gaussian = Gaussian(display_name='Resolution Gaussian', width=0.15, area=0.8)\n",
+ "resolution_lorentzian = Lorentzian(display_name='Resolution Lorentzian', width=0.25, area=0.2)\n",
+ "resolution_components.append_component(resolution_gaussian)\n",
+ "# resolution_components.append_component(resolution_lorentzian)\n",
+ "\n",
+ "energy = np.linspace(-2, 2, 100)\n",
+ "\n",
+ "\n",
+ "temperature = 10.0 # Temperature in Kelvin\n",
+ "energy_offset = 0.2\n",
+ "upsample_factor = 5\n",
+ "extension_factor = 0.5\n",
+ "plt.figure()\n",
+ "plt.xlabel('Energy (meV)')\n",
+ "plt.ylabel('Intensity (arb. units)')\n",
+ "\n",
+ "convolver = Convolution(\n",
+ " sample_components=sample_components,\n",
+ " resolution_components=resolution_components,\n",
+ " energy=energy,\n",
+ " upsample_factor=upsample_factor,\n",
+ " extension_factor=extension_factor,\n",
+ " energy_offset=energy_offset,\n",
+ " temperature=temperature,\n",
+ ")\n",
+ "y = convolver.convolution()\n",
+ "\n",
+ "\n",
+ "plt.plot(energy, y, label='Convoluted Model')\n",
+ "\n",
+ "plt.plot(\n",
+ " energy,\n",
+ " sample_components.evaluate(energy - energy_offset),\n",
+ " label='Sample Model',\n",
+ " linestyle='--',\n",
+ ")\n",
+ "\n",
+ "plt.plot(energy, resolution_components.evaluate(energy), label='Resolution Model', linestyle=':')\n",
+ "plt.title('Convolution of Sample Model with Resolution Model')\n",
+ "\n",
+ "plt.legend()\n",
+ "plt.ylim(0, 2.5)\n",
+ "plt.show()"
+ ]
}
],
"metadata": {
diff --git a/docs/docs/tutorials/detailed_balance.ipynb b/docs/docs/tutorials/detailed_balance.ipynb
index d09a2546..135894d3 100644
--- a/docs/docs/tutorials/detailed_balance.ipynb
+++ b/docs/docs/tutorials/detailed_balance.ipynb
@@ -23,11 +23,11 @@
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
- "\n",
- "%matplotlib widget\n",
"import numpy as np\n",
"\n",
- "from easydynamics.utils import _detailed_balance_factor as detailed_balance_factor"
+ "from easydynamics.utils import _detailed_balance_factor as detailed_balance_factor\n",
+ "\n",
+ "%matplotlib widget"
]
},
{
diff --git a/docs/docs/tutorials/diffusion_model.ipynb b/docs/docs/tutorials/diffusion_model.ipynb
index 9277486e..f3d1571b 100644
--- a/docs/docs/tutorials/diffusion_model.ipynb
+++ b/docs/docs/tutorials/diffusion_model.ipynb
@@ -39,13 +39,11 @@
"energy = np.linspace(-2, 2, 501)\n",
"scale = 1.0\n",
"diffusion_coefficient = 2.4e-9 # m^2/s\n",
- "diffusion_unit = 'm**2/s'\n",
"\n",
"diffusion_model = BrownianTranslationalDiffusion(\n",
" display_name='DiffusionModel',\n",
" scale=scale,\n",
" diffusion_coefficient=diffusion_coefficient,\n",
- " diffusion_unit=diffusion_unit,\n",
")\n",
"\n",
"component_collections = diffusion_model.create_component_collections(Q)\n",
@@ -69,7 +67,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a50c67ec",
+ "id": "3",
"metadata": {},
"outputs": [],
"source": [
diff --git a/docs/docs/tutorials/experiment.ipynb b/docs/docs/tutorials/experiment.ipynb
index 6319c61f..f6e185df 100644
--- a/docs/docs/tutorials/experiment.ipynb
+++ b/docs/docs/tutorials/experiment.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "906b959a",
+ "id": "0",
"metadata": {},
"source": [
"# Experiment\n",
@@ -12,7 +12,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c7d23add",
+ "id": "1",
"metadata": {},
"outputs": [],
"source": [
@@ -24,7 +24,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2b7c5ca8",
+ "id": "2",
"metadata": {},
"outputs": [],
"source": [
@@ -38,7 +38,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "238ba6ee",
+ "id": "3",
"metadata": {},
"outputs": [],
"source": [
@@ -50,7 +50,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bc32ab1f",
+ "id": "4",
"metadata": {},
"outputs": [],
"source": [
diff --git a/docs/docs/tutorials/instrument_model.ipynb b/docs/docs/tutorials/instrument_model.ipynb
new file mode 100644
index 00000000..a56b300f
--- /dev/null
+++ b/docs/docs/tutorials/instrument_model.ipynb
@@ -0,0 +1,101 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0",
+ "metadata": {},
+ "source": [
+ "# Instrument Model\n",
+ "We here introduce the InstrumentModel, which contains all information related to the instrument: the BackgroundModel, ResolutionModel and also a fittable offset in the energy transfer due to slight instrument misalignment.\n",
+ "\n",
+ "The InstrumentModel does not itself do any calculations; it is merely a container for all information about the instrument.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "from easydynamics.sample_model import Gaussian\n",
+ "from easydynamics.sample_model import Polynomial\n",
+ "from easydynamics.sample_model.background_model import BackgroundModel\n",
+ "from easydynamics.sample_model.instrument_model import InstrumentModel\n",
+ "from easydynamics.sample_model.resolution_model import ResolutionModel\n",
+ "\n",
+ "%matplotlib widget"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a BackgroundModel and a ResolutionModel and add them to an\n",
+ "# InstrumentModel\n",
+ "\n",
+ "Q = np.linspace(0.1, 2.0, 5)\n",
+ "\n",
+ "background_model = BackgroundModel()\n",
+ "background_model.components = Polynomial(coefficients=[1, 0.1, 0.01])\n",
+ "\n",
+ "resolution_model = ResolutionModel()\n",
+ "resolution_model.append_component(Gaussian(width=0.05))\n",
+ "\n",
+ "instrument_model = InstrumentModel(\n",
+ " Q=Q,\n",
+ " resolution_model=resolution_model,\n",
+ " background_model=background_model,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "instrument_model.get_all_variables(Q_index=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3eca4688",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "instrument_model.fix_resolution_parameters()\n",
+ "instrument_model.get_all_variables(Q_index=1)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "easydynamics_newbase",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/docs/tutorials/sample_model.ipynb b/docs/docs/tutorials/sample_model.ipynb
index 802aff0b..5371f7df 100644
--- a/docs/docs/tutorials/sample_model.ipynb
+++ b/docs/docs/tutorials/sample_model.ipynb
@@ -48,12 +48,10 @@
"\n",
"scale = 1.0\n",
"diffusion_coefficient = 2.4e-9 # m^2/s\n",
- "diffusion_unit = 'm**2/s'\n",
"diffusion_model = BrownianTranslationalDiffusion(\n",
" display_name='DiffusionModel',\n",
" scale=scale,\n",
" diffusion_coefficient=diffusion_coefficient,\n",
- " diffusion_unit=diffusion_unit,\n",
")\n",
"\n",
"\n",
diff --git a/pixi.lock b/pixi.lock
index d9461637..da8aee45 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -4091,7 +4091,7 @@ packages:
requires_python: '>=3.5'
- pypi: ./
name: easydynamics
- version: 0.1.0+devdirty7
+ version: 0.1.1+devdirty2
sha256: de299c914d4a865b9e2fdefa5e3947f37b1f26f73ff9087f7918ee417f3dd288
requires_dist:
- darkdetect
diff --git a/pixi.toml b/pixi.toml
index baa0ea35..d280c259 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -76,41 +76,41 @@ default = { features = ['default', 'py-max'] }
[tasks]
+##################
# ๐งช Testing Tasks
-unit-tests = 'python -m pytest tests/unit/ --color=yes --cov= --cov-report='
+##################
+
+unit-tests = 'python -m pytest tests/unit/ --color=yes -v'
integration-tests = 'python -m pytest tests/integration/ --color=yes -n auto -v'
notebook-tests = 'python -m pytest --nbmake docs/docs/tutorials/ --nbmake-timeout=600 --color=yes -n auto -v'
script-tests = 'python -m pytest tools/test_scripts.py --color=yes -n auto -v'
test = { depends-on = ['unit-tests'] }
+###########
# โ๏ธ Checks
+###########
+
pyproject-check = 'python -m validate_pyproject pyproject.toml'
-py-lint-check-pre = "python -m ruff check"
-py-lint-check = 'pixi run py-lint-check-pre .'
-py-format-check-pre = "python -m ruff format --check"
-py-format-check = "pixi run py-format-check-pre ."
-nonpy-format-check-pre = "npx prettier --list-different --config=prettierrc.toml"
-nonpy-format-check-modified = "pixi run nonpy-format-check-pre $(git diff --diff-filter=d --name-only HEAD | grep -E '\\.(json|ya?ml|toml|md|css|html)$' || echo .)"
-nonpy-format-check = "pixi run nonpy-format-check-pre ."
+docs-format-check = 'docformatter --check src/ docs/docs/tutorials/'
notebook-format-check = 'nbqa ruff docs/docs/tutorials/'
-docs-format-check = 'docformatter src/ docs/docs/tutorials/ --check'
+py-lint-check = 'ruff check src/ tests/ docs/docs/tutorials/'
+py-format-check = "ruff format --check src/ tests/ docs/docs/tutorials/"
+nonpy-format-check = "npx prettier --list-different --config=prettierrc.toml --ignore-unknown ."
+nonpy-format-check-modified = "python tools/nonpy_prettier_modified.py"
-check = { depends-on = [
- 'docs-format-check',
- 'py-format-check',
- 'py-lint-check',
- 'nonpy-format-check-modified',
-] }
+check = 'pre-commit run --hook-stage manual --all-files'
+##########
# ๐ ๏ธ Fixes
-py-lint-fix = 'pixi run py-lint-check --fix'
-#py-format-fix = "python -m ruff format $(git diff --cached --name-only -- '*.py')"
-py-format-fix = "python -m ruff format"
-nonpy-format-fix = 'pixi run nonpy-format-check --write'
-nonpy-format-fix-modified = "pixi run nonpy-format-check-modified --write"
-notebook-format-fix = 'pixi run notebook-format-check --fix'
-docs-format-fix = 'docformatter src/ docs/docs/tutorials/ --in-place'
+##########
+
+docs-format-fix = 'docformatter --in-place src/ docs/docs/tutorials/'
+notebook-format-fix = 'nbqa ruff --fix docs/docs/tutorials/'
+py-lint-fix = 'ruff check --fix src/ tests/ docs/docs/tutorials/'
+py-format-fix = "ruff format src/ tests/ docs/docs/tutorials/"
+nonpy-format-fix = 'npx prettier --write --list-different --config=prettierrc.toml --ignore-unknown .'
+nonpy-format-fix-modified = "python tools/nonpy_prettier_modified.py --write"
success-message-fix = 'echo "โ
All code auto-formatting steps have been applied."'
fix = { depends-on = [
@@ -118,10 +118,14 @@ fix = { depends-on = [
'docs-format-fix',
'py-lint-fix',
'nonpy-format-fix',
+ 'notebook-format-fix',
'success-message-fix',
] }
+####################
# ๐งฎ Code Complexity
+####################
+
complexity-check = 'radon cc -s src/'
complexity-check-json = 'radon cc -s -j src/'
maintainability-check = 'radon mi src/'
@@ -129,8 +133,11 @@ maintainability-check-json = 'radon mi -j src/'
raw-metrics = 'radon raw -s src/'
raw-metrics-json = 'radon raw -s -j src/'
+#############
# ๐ Coverage
-unit-tests-coverage = 'python -m pytest tests/unit/ --color=yes --cov=src/easydynamics --cov-report=term-missing'
+#############
+
+unit-tests-coverage = 'pixi run unit-tests --cov=src/easydynamics --cov-report=term-missing'
integration-tests-coverage = 'pixi run integration-tests --cov=src/easydynamics --cov-report=term-missing'
docstring-coverage = 'interrogate -c pyproject.toml src/easydynamics'
@@ -140,19 +147,25 @@ cov = { depends-on = [
'integration-tests-coverage',
] }
+########################
# ๐ Notebook Management
+########################
+
notebook-convert = 'jupytext docs/docs/tutorials/*.py --from py:percent --to ipynb'
notebook-strip = 'nbstripout docs/docs/tutorials/*.ipynb'
notebook-tweak = 'python tools/tweak_notebooks.py tutorials/'
notebook-exec = 'python -m pytest --nbmake docs/docs/tutorials/ --nbmake-timeout=600 --overwrite --color=yes -n auto -v'
notebook-prepare = { depends-on = [
- ###'notebook-convert',
+ #'notebook-convert',
'notebook-strip',
- ###'notebook-tweak',
+ #'notebook-tweak',
] }
+########################
# ๐ Documentation Tasks
+########################
+
docs-vars = "JUPYTER_PLATFORM_DIRS=1 PYTHONWARNINGS='ignore::RuntimeWarning'"
docs-pre = "pixi run docs-vars python -m mkdocs"
docs-serve = "pixi run docs-pre serve -f docs/mkdocs.yml"
@@ -163,29 +176,19 @@ docs-build-local = "pixi run docs-build --no-directory-urls"
docs-deploy-pre = 'mike deploy -F docs/mkdocs.yml --push --branch gh-pages --update-aliases --alias-type redirect'
docs-set-default-pre = 'mike set-default -F docs/mkdocs.yml --push --branch gh-pages'
-docs-update-assets = 'pixi run python tools/update_docs_assets.py'
+docs-update-assets = 'python tools/update_docs_assets.py'
+##############################
# ๐ฆ Template Management Tasks
-copier-copy = "pixi run copier copy gh:easyscience/templates . --data-file ../dynamics/.copier-answers.yml --data template_type=lib"
-copier-recopy = "pixi run copier recopy --data-file ../dynamics/.copier-answers.yml --data template_type=lib"
-copier-update = "pixi run copier update --data-file ../dynamics/.copier-answers.yml --data template_type=lib"
+##############################
-# ๐ Development & Build Tasks
-default-build = 'python -m build'
-dist-build = 'python -m build --wheel --outdir dist'
+copier-copy = "copier copy gh:easyscience/templates . --data-file ../dynamics/.copier-answers.yml --data template_type=lib"
+copier-recopy = "copier recopy --data-file ../dynamics/.copier-answers.yml --data template_type=lib"
+copier-update = "copier update --data-file ../dynamics/.copier-answers.yml --data template_type=lib"
-npm-config = 'npm config set registry https://registry.npmjs.org/'
-prettier-install = 'npm install --no-save --no-audit --no-fund prettier prettier-plugin-toml'
-
-clean-pycache = "find . -type d -name '__pycache__' -prune -exec rm -rf '{}' +"
-spdx-update = 'python tools/update_spdx.py'
-
-# Run like a real commit: staged files only (almost)
-pre-commit-check = 'pre-commit run --hook-stage pre-commit'
-# CI check: lint/format everything
-pre-commit-check-all = 'pre-commit run --all-files --hook-stage pre-commit'
-# Pre-push check: lint/format everything
-pre-push-check = 'pre-commit run --all-files --hook-stage pre-push'
+#####################
+# ๐ช Pre-commit Hooks
+#####################
pre-commit-clean = 'pre-commit clean'
pre-commit-install = 'pre-commit install --hook-type pre-commit --hook-type pre-push --overwrite'
@@ -196,11 +199,28 @@ pre-commit-setup = { depends-on = [
'pre-commit-install',
] }
+####################################
+# ๐ Other Development & Build Tasks
+####################################
+
+github-labels = 'python tools/update_github_labels.py'
+
+default-build = 'python -m build'
+dist-build = 'python -m build --wheel --outdir dist'
+
+npm-config = 'npm config set registry https://registry.npmjs.org/'
+prettier-install = 'npm install --no-save --no-audit --no-fund prettier prettier-plugin-toml'
+
+clean-pycache = "find . -type d -name '__pycache__' -prune -exec rm -rf '{}' +"
+spdx-update = 'python tools/update_spdx.py'
+
post-install = { depends-on = [
'npm-config',
'prettier-install',
- 'pre-commit-setup',
+ #'pre-commit-setup',
] }
+##########################
# ๐ Main Package Shortcut
+##########################
easydynamics = 'python -m easydynamics'
diff --git a/src/easydynamics/analysis/__init__.py b/src/easydynamics/analysis/__init__.py
new file mode 100644
index 00000000..4cb511b4
--- /dev/null
+++ b/src/easydynamics/analysis/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
+# SPDX-License-Identifier: BSD-3-Clause
+
+from .analysis import Analysis
+
+__all__ = [
+ 'Analysis',
+]
diff --git a/src/easydynamics/analysis/analysis old.py b/src/easydynamics/analysis/analysis old.py
new file mode 100644
index 00000000..9d9039ea
--- /dev/null
+++ b/src/easydynamics/analysis/analysis old.py
@@ -0,0 +1,497 @@
+# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
+# SPDX-License-Identifier: BSD-3-Clause
+
+
+import numpy as np
+import plopp as pp
+import scipp as sc
+from easyscience.base_classes.model_base import ModelBase as EasyScienceModelBase
+from easyscience.fitting.fitter import Fitter as EasyScienceFitter
+from easyscience.variable import Parameter
+
+from easydynamics.convolution import Convolution
+from easydynamics.experiment import Experiment
+from easydynamics.sample_model import BackgroundModel
+from easydynamics.sample_model import ResolutionModel
+from easydynamics.sample_model import SampleModel
+
+
+class Analysis(EasyScienceModelBase):
+ """For analysing data."""
+
+ def __init__(
+ self,
+ display_name: str = "MyAnalysis",
+ unique_name: str | None = None,
+ experiment: Experiment | None = None,
+ sample_model: SampleModel | None = None,
+ resolution_model: ResolutionModel | None = None,
+ background_model: BackgroundModel | None = None,
+ energy_offset: float | None = None,
+ ):
+
+ super().__init__(display_name=display_name, unique_name=unique_name)
+
+ if experiment is not None and not isinstance(experiment, Experiment):
+ raise TypeError("experiment must be an instance of Experiment or None.")
+
+ self._experiment = experiment
+
+ if sample_model is not None and not isinstance(sample_model, SampleModel):
+ raise TypeError("sample_model must be an instance of SampleModel or None.")
+ sample_model.Q = self.Q
+ self._sample_model = sample_model
+
+ if resolution_model is not None and not isinstance(
+ resolution_model, ResolutionModel
+ ):
+ raise TypeError(
+ "resolution_model must be an instance of ResolutionModel or None."
+ )
+ resolution_model.Q = self.Q
+ self._resolution_model = resolution_model
+
+ if background_model is not None and not isinstance(
+ background_model, BackgroundModel
+ ):
+ raise TypeError(
+ "background_model must be an instance of BackgroundModel or None."
+ )
+ background_model.Q = self.Q
+ self._background_model = background_model
+
+ self._convolvers = [None] * (len(self.Q) if self.Q is not None else 0)
+ self._update_models()
+
+ #############
+ # Properties
+ #############
+
+ @property
+ def experiment(self) -> Experiment | None:
+ """The Experiment associated with this Analysis."""
+ return self._experiment
+
+ @experiment.setter
+ def experiment(self, value: Experiment | None) -> None:
+ if value is not None and not isinstance(value, Experiment):
+ raise TypeError("experiment must be an instance of Experiment or None.")
+ self._experiment = value
+ self._update_models()
+
+ @property
+ def sample_model(self) -> SampleModel | None:
+ """The SampleModel associated with this Analysis."""
+ return self._sample_model
+
+ @sample_model.setter
+ def sample_model(self, value: SampleModel | None) -> None:
+ if value is not None and not isinstance(value, SampleModel):
+ raise TypeError("sample_model must be an instance of SampleModel or None.")
+ self._sample_model = value
+ self._update_models()
+
+ @property
+ def resolution_model(self) -> ResolutionModel | None:
+ """The ResolutionModel associated with this Analysis."""
+ return self._resolution_model
+
+ @resolution_model.setter
+ def resolution_model(self, value: ResolutionModel | None) -> None:
+ if value is not None and not isinstance(value, ResolutionModel):
+ raise TypeError(
+ "resolution_model must be an instance of ResolutionModel or None."
+ )
+ self._resolution_model = value
+ self._update_models()
+
+ @property
+ def background_model(self) -> BackgroundModel | None:
+ """The BackgroundModel associated with this Analysis."""
+ return self._background_model
+
+ @background_model.setter
+ def background_model(self, value: BackgroundModel | None) -> None:
+ if value is not None and not isinstance(value, BackgroundModel):
+ raise TypeError(
+ "background_model must be an instance of BackgroundModel or None."
+ )
+ self._background_model = value
+ self._update_models()
+
+ @property
+ def Q(self) -> sc.Variable | None:
+ """The Q values from the associated Experiment, if available."""
+ if self.experiment is not None:
+ return self.experiment.Q
+ return None
+
+ @Q.setter
+ def Q(self, value) -> None:
+ """Q is a read-only property derived from the Experiment."""
+ raise AttributeError("Q is a read-only property derived from the Experiment.")
+
+ @property
+ def energy(self) -> sc.Variable | None:
+ """The energy values from the associated Experiment, if
+ available.
+ """
+ if self.experiment is not None:
+ return self.experiment.energy
+ return None
+
+ @energy.setter
+ def energy(self, value) -> None:
+ """Energy is a read-only property derived from the
+ Experiment.
+ """
+ raise AttributeError(
+ "energy is a read-only property derived from the Experiment."
+ )
+
+ # TODO: make it use experiment temperature
+ @property
+ def temperature(self) -> Parameter | None:
+ """The temperature from the associated Experiment, if
+ available.
+ """
+ return None
+
+ @temperature.setter
+ def temperature(self, value) -> None:
+ """Temperature is a read-only property derived from the
+ Experiment.
+ """
+ raise AttributeError(
+ "temperature is a read-only property derived from the Experiment."
+ )
+
+ # # TODO: make it use experiment temperature
+ # @property def temperature(self) -> Parameter | None: """The
+ # temperature from the associated Experiment, if available.""" if
+ # self.experiment is not None: return
+ # self.experiment.temperature return None
+
+ # @temperature.setter def temperature(self, value) -> None:
+ # """temperature is a read-only property derived from the
+ # Experiment.""" raise AttributeError( "temperature is a
+ # read-only property derived from the Experiment." )
+
+ #############
+ # Other methods
+ #############
+
+ def calculate(self, energy: float | None, Q_index: int) -> np.ndarray:
+ """Calculate the model prediction for a given Q index.
+
+ Args:
+ energy (float | None): The energy values to evaluate; if None, self.energy is used.
+ Q_index (int): The index of the Q value to calculate the
+ model for.
+ Returns:
+ np.ndarray: The calculated model prediction (sample plus background).
+ """
+ if energy is None:
+ energy = self.energy
+
+ if self.sample_model is None:
+ sample_intensity = np.zeros_like(energy)
+ else:
+ if self.resolution_model is None:
+ sample_intensity = self.sample_model._component_collections[
+ Q_index
+ ].evaluate(energy)
+ else:
+ convolver = self._create_convolver(Q_index)
+ sample_intensity = convolver.convolution()
+
+ if self.background_model is None:
+ background_intensity = np.zeros_like(energy)
+ else:
+ background_intensity = self.background_model._component_collections[
+ Q_index
+ ].evaluate(energy)
+
+ sample_plus_background = sample_intensity + background_intensity
+
+ return sample_plus_background
+
+ def calculate_individual_components(
+ self, Q_index: int
+ ) -> tuple[list[np.ndarray], list[np.ndarray]]:
+ """Calculate the model prediction for a given Q index for each
+ individual component.
+
+ Args:
+ Q_index (int): The index of the Q value to calculate the
+ model for.
+ Returns:
+ tuple[list[np.ndarray], list[np.ndarray]]: per-component sample
+ and background model predictions.
+ """
+ sample_results = []
+ background_results = []
+
+ if self.sample_model is not None:
+ # Calculate sample components
+ for component in self.sample_model._component_collections[
+ Q_index
+ ]._components:
+ if self.resolution_model is None:
+ component_intensity = component.evaluate(self.energy)
+ else:
+ convolver = Convolution(
+ sample_components=component,
+ resolution_components=self.resolution_model._component_collections[
+ Q_index
+ ],
+ energy=self.energy,
+ temperature=self.temperature,
+ )
+ component_intensity = convolver.convolution()
+ sample_results.append(component_intensity)
+
+ if self.background_model is not None:
+ # Calculate background components
+ for component in self.background_model._component_collections[
+ Q_index
+ ]._components:
+ component_intensity = component.evaluate(self.energy)
+ background_results.append(component_intensity)
+
+ return sample_results, background_results
+
+ def calculate_all_Q(self) -> list[np.ndarray]:
+ """Calculate the model prediction for all Q indices.
+
+ Returns:
+ list[np.ndarray]: The calculated model predictions for all Q
+ indices.
+ """
+ results = []
+ for Q_index in range(len(self.Q)):
+ result = self.calculate(energy=None, Q_index=Q_index)
+ results.append(result)
+ return results
+
+ # def calculate_individual_components_all_Q(
+ # self,
+ # add_background: bool = True,
+ # ) -> list[tuple[list[np.ndarray], list[np.ndarray]]]:
+ # """Calculate the model prediction for all Q indices for each
+ # individual component.
+
+ # Returns: list[tuple[list[np.ndarray], list[np.ndarray]]]: The
+ # calculated model predictions for each individual component
+ # at all Q indices. """ all_results = [] for Q_index in
+ # range(len(self.Q)): sample_results, background_results =
+ # self.calculate_individual_components( Q_index ) if
+ # add_background: sample_results = sample_results +
+ # background_results all_results.append((sample_results,
+ # background_results)) return all_results
+
+ def calculate_single_component_all_Q(
+ self,
+ component_index: int,
+ ) -> sc.DataArray:
+ """Calculate the model prediction for all Q indices for a single
+ component.
+
+ Args:
+ component_index (int): The index of the component
+ Returns:
+ sc.DataArray: the calculated model prediction for the
+ specified component at all Q indices.
+ """
+
+ results = []
+ for Q_index in range(len(self.Q)):
+ if self.sample_model is not None:
+ component = self.sample_model._component_collections[
+ Q_index
+ ]._components[component_index]
+ if self.resolution_model is None:
+ component_intensity = component.evaluate(self.energy)
+ else:
+ convolver = Convolution(
+ sample_components=component,
+ resolution_components=self.resolution_model._component_collections[
+ Q_index
+ ],
+ energy=self.energy,
+ temperature=self.temperature,
+ )
+ component_intensity = convolver.convolution()
+ results.append(component_intensity)
+ else:
+ results.append(np.zeros_like(self.energy))
+
+ model_data_array = sc.DataArray(
+ data=sc.array(dims=["Q", "energy"], values=results),
+ coords={
+ "Q": self.Q,
+ "energy": self.energy,
+ },
+ )
+ return model_data_array
+
    def fit(self, Q_index: int):
        """Fit the model to the experimental data for a given Q index.

        Args:
            Q_index (int): The index of the Q value to fit the model
                to.
        Returns:
            FitResult: The result of the fit.
        Raises:
            ValueError: If no experiment is attached or Q_index is
                not a valid integer index into self.Q.
        """
        if self._experiment is None:
            raise ValueError("No experiment is associated with this Analysis.")

        # Q_index must be a plain int within range of the available Q values.
        if not isinstance(Q_index, int) or Q_index < 0 or Q_index >= len(self.Q):
            raise ValueError("Q_index must be a valid index for the Q values.")

        # Extract the 1D slice of the experimental data at this Q.
        data = self.experiment.data["Q", Q_index]
        x = data.coords["energy"].values
        y = data.values
        # Standard deviations from the stored variances.
        e = data.variances**0.5

        def fit_func(x_vals):
            # Evaluate the model for this Q index on the given energies.
            return self.calculate_theory(energy=x_vals, Q_index=Q_index)

        fitter = EasyScienceFitter(
            fit_object=self,
            fit_function=fit_func,
        )

        # Perform the fit; weights are 1/sigma as expected by the fitter.
        fit_result = fitter.fit(x=x, y=y, weights=1.0 / e)

        # Store result for later inspection.
        self.fit_result = fit_result

        return fit_result
+
    def plot_data_and_model(
        self,
        plot_individual_components: bool = True,
    ) -> None:
        """Plot the experimental data and the model prediction.

        Args:
            plot_individual_components (bool): Whether to plot
                individual components. Default is True.
        Raises:
            TypeError: If plot_individual_components is not a bool.
            ValueError: If no experiment data is available.
        """
        if not isinstance(plot_individual_components, bool):
            raise TypeError("plot_individual_components must be True or False.")

        model_data_array = self._create_model_data_group(
            individual_components=plot_individual_components
        )
        if self.experiment is None or self.experiment.data is None:
            raise ValueError("Experiment data is not available for plotting.")

        # Imported lazily so IPython is only required in notebook contexts.
        from IPython.display import display

        # Interactive slicer: one 1D cut per Q value, data and model overlaid.
        fig = pp.slicer(
            {"Data": self.experiment.data, "Model": model_data_array},
            color={"Data": "black", "Model": "red"},
            linestyle={"Data": "none", "Model": "solid"},
            marker={"Data": "o", "Model": "None"},
        )
        display(fig)
+
+ #############
+ # Private methods
+ #############
+
+ def _update_models(self):
+ """Update models based on the current experiment."""
+ if self.experiment is None:
+ return
+
+ for Q_index in range(len(self.Q)):
+ self._convolvers[Q_index] = self._create_convolver(Q_index)
+
    def _create_convolver(self, Q_index: int):
        """Initialize and return a Convolution object for the given Q
        index.

        Args:
            Q_index (int): Index of the Q value the convolver is built for.
        Returns:
            Convolution: Convolver pairing the sample components with
            the resolution components at this Q.
        """
        # TODO(review): add checks for empty sample/resolution models here;
        # currently a missing model raises AttributeError on the next lines.
        sample_components = self.sample_model._component_collections[Q_index]
        resolution_components = self.resolution_model._component_collections[Q_index]
        energy = self.energy
        convolver = Convolution(
            sample_components=sample_components,
            resolution_components=resolution_components,
            energy=energy,
            temperature=self.temperature,
        )
        return convolver
+
    def _create_model_data_group(self, individual_components=True) -> sc.DataArray:
        """Create a Scipp DataArray representing the model over all Q
        and energy values.

        Args:
            individual_components (bool): When True, also evaluate and
                attach the individual sample/background components.
        Returns:
            sc.DataArray: The combined model data.
        Raises:
            ValueError: If Q or energy is not defined.
        """
        if self.Q is None or self.energy is None:
            raise ValueError("Q and energy must be defined in the experiment.")

        # Evaluate the full model for each Q index.
        model_data = []
        for Q_index in range(len(self.Q)):
            model_at_Q = self.calculate(Q_index)
            model_data.append(model_at_Q)

        model_data_array = sc.DataArray(
            data=sc.array(dims=["Q", "energy"], values=model_data),
            coords={
                "Q": self.Q,
                "energy": self.energy,
            },
        )
        model_group = sc.DataGroup({"Model": model_data_array})

        if individual_components:
            # NOTE(review): calculate_individual_components_all_Q appears
            # commented out elsewhere in this module — confirm this call
            # resolves before relying on this branch.
            components = self.calculate_individual_components_all_Q()
            for Q_index, (sample_comps, background_comps) in enumerate(components):
                for samp_index, samp_comp in enumerate(sample_comps):
                    model_data_array[samp_comp.display_name] = sc.zeros_like(
                        model_data_array.data
                    )
                    model_data_array[samp_comp.display_name].data[
                        Q_index, :
                    ] = samp_comp
                for back_index, back_comp in enumerate(background_comps):
                    model_data_array[back_comp.display_name] = sc.zeros_like(
                        model_data_array.data
                    )
                    model_data_array[back_comp.display_name].data[
                        Q_index, :
                    ] = back_comp

        # NOTE(review): flagged "WRONG" by the author — adding a DataArray and
        # a DataGroup is likely not the intended combination; verify intent
        # (possibly the DataGroup should be returned instead).
        model_data_array = model_data_array + model_group  # WRONG BUT LINT
        return model_data_array
+
+ # def _create_convolvers(
+ # self, energy: np.ndarray | sc.Variable | None = None
+ # ) -> None:
+ # """Create Convolution objects for each Q value."""
+ # num_Q = len(self.Q) if self.Q is not None else 0
+ # self._convolvers = [
+ # self._create_convolver(i, energy=energy) for i in range(num_Q)
+ # ]
diff --git a/src/easydynamics/analysis/analysis.py b/src/easydynamics/analysis/analysis.py
new file mode 100644
index 00000000..81921b20
--- /dev/null
+++ b/src/easydynamics/analysis/analysis.py
@@ -0,0 +1,410 @@
+# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
+# SPDX-License-Identifier: BSD-3-Clause
+
+
+import numpy as np
+import plopp as pp
+import scipp as sc
+from easyscience.fitting.minimizers.utils import FitResults
+from easyscience.fitting.multi_fitter import MultiFitter
+from easyscience.variable import Parameter
+from scipp import UnitError
+
+from easydynamics.analysis.analysis1d import Analysis1d
+from easydynamics.analysis.analysis_base import AnalysisBase
+from easydynamics.experiment import Experiment
+from easydynamics.sample_model import SampleModel
+from easydynamics.sample_model.instrument_model import InstrumentModel
+from easydynamics.utils.utils import _in_notebook
+
+
+class Analysis(AnalysisBase):
+ """For analysing data."""
+
    def __init__(
        self,
        display_name: str = "MyAnalysis",
        unique_name: str | None = None,
        experiment: Experiment | None = None,
        sample_model: SampleModel | None = None,
        instrument_model: InstrumentModel | None = None,
        extra_parameters: Parameter | list[Parameter] | None = None,
    ):
        """Create an Analysis that wraps one Analysis1d per Q value.

        Args:
            display_name (str): Human-readable name of the analysis.
            unique_name (str | None): Unique identifier.
            experiment (Experiment | None): Experiment providing the data.
            sample_model (SampleModel | None): Model of the sample.
            instrument_model (InstrumentModel | None): Instrument model.
            extra_parameters (Parameter | list[Parameter] | None):
                Additional fit parameters not owned by the models.
        Raises:
            TypeError: If experiment is neither an Experiment nor None.
        """

        super().__init__(
            display_name=display_name,
            unique_name=unique_name,
            experiment=experiment,
            sample_model=sample_model,
            instrument_model=instrument_model,
            extra_parameters=extra_parameters,
        )

        # NOTE(review): the base class already validates experiment, so this
        # check after super().__init__ looks redundant — confirm before
        # removing.
        if experiment is not None and not isinstance(experiment, Experiment):
            raise TypeError("experiment must be an instance of Experiment or None.")

        # One Analysis1d per Q index, all sharing the same experiment/models.
        self._analysis_list = []
        if self.Q is not None:
            for Q_index in range(len(self.Q)):
                analysis = Analysis1d(
                    display_name=f"{self.display_name}_Q{Q_index}",
                    unique_name=(f"{self.unique_name}_Q{Q_index}"),
                    experiment=self.experiment,
                    sample_model=self.sample_model,
                    instrument_model=self.instrument_model,
                    extra_parameters=self._extra_parameters,
                    Q_index=Q_index,
                )
                self._analysis_list.append(analysis)
+
+ #############
+ # Properties
+ #############
+
+ @property
+ def analysis_list(self) -> list[Analysis1d]:
+ """List of Analysis1d objects, one for each Q index."""
+ return self._analysis_list
+
+ @analysis_list.setter
+ def analysis_list(self, value: list[Analysis1d]) -> None:
+ """analysis_list is read-only. To change the analysis list,
+ modify the experiment, sample model, or instrument model."""
+
+ raise AttributeError(
+ "analysis_list is read-only. "
+ "To change the analysis list, modify the experiment, sample model, "
+ "or instrument model."
+ )
+
+ #############
+ # Other methods
+ #############
+ def calculate(
+ self,
+ Q_index: int | None = None,
+ ) -> list[np.ndarray] | np.ndarray:
+ """Calculate model data for a specific Q index.
+ If Q_index is None, calculate for all Q indices and return a
+ list of arrays.
+
+ Parameters: Q_index: Index of the Q value to calculate for. If
+ None, calculate for all Q values.
+
+ Returns: If Q_index is None, returns a list of numpy arrays, one
+ for each Q index. If Q_index is an integer, returns a single
+ numpy array for that Q index.
+ """
+
+ if Q_index is None:
+ return [analysis.calculate() for analysis in self.analysis_list]
+
+ Q_index = self._verify_Q_index(Q_index)
+ return self.analysis_list[Q_index].calculate()
+
    def fit(
        self,
        fit_method: str = "independent",
        Q_index: int | None = None,
    ) -> FitResults | list[FitResults]:
        """Fit the model to the experimental data.

        Parameters:
        ---------------
        fit_method: string, optional
            Method to use for fitting. Options are "independent" (fit
            each Q index independently, one after the other) or
            "simultaneous" (fit all Q indices simultaneously).
        Q_index: int or None, optional
            If fit_method is "independent", specify which Q index to
            fit. If None, fit all Q indices independently.

        Returns: Fit results, which may be a list of FitResults if
        fitting independently, or a single FitResults object if
        fitting simultaneously.

        Raises:
            ValueError: If no Q values are available or fit_method is
                not one of the supported options.
        """

        if self.Q is None:
            raise ValueError(
                "No Q values available for fitting. Please check the experiment data."
            )

        # Validate the requested Q index; the None default is passed through
        # unchanged (the branch below depends on that).
        Q_index = self._verify_Q_index(Q_index)

        if fit_method == "independent":
            if Q_index is not None:
                return self._fit_single_Q(Q_index)
            else:
                return self._fit_all_Q_independently()
        elif fit_method == "simultaneous":
            # NOTE(review): Q_index is ignored for simultaneous fits —
            # consider rejecting a non-None Q_index here to avoid surprises.
            return self._fit_all_Q_simultaneously()
        else:
            raise ValueError(
                "Invalid fit method. Choose 'independent' or 'simultaneous'."
            )
+
    def plot_data_and_model(
        self,
        Q_index: int | None = None,
        plot_components: bool = True,
        add_background: bool = True,
        **kwargs,
    ) -> None:
        """Plot the data and model using plopp.

        Args:
            Q_index (int | None): When given, delegate to the single-Q
                Analysis1d plot for that index; when None, show an
                interactive slicer over all Q values.
            plot_components (bool): Whether to also plot the individual
                model components. Default is True.
            add_background (bool): Whether to add the background to the
                components. Default is True.
            kwargs: Extra keyword arguments forwarded to plopp.
        """

        # Single-Q request: delegate to the per-Q analysis object.
        if Q_index is not None:
            Q_index = self._verify_Q_index(Q_index)
            return self.analysis_list[Q_index].plot_data_and_model(
                plot_components=plot_components,
                add_background=add_background,
                **kwargs,
            )

        if self.experiment.binned_data is None:
            raise ValueError("No data to plot. Please load data first.")

        # The slicer widget requires a notebook front end.
        if not _in_notebook():
            raise RuntimeError(
                "plot_data() can only be used in a Jupyter notebook environment."
            )

        if self.Q is None:
            raise ValueError(
                "No Q values available for plotting. Please check the experiment data."
            )

        # NOTE(review): these flag checks run after the flags were already
        # forwarded in the single-Q branch above — consider validating first.
        if not isinstance(plot_components, bool):
            raise TypeError("plot_components must be True or False.")

        if not isinstance(add_background, bool):
            raise TypeError("add_background must be True or False.")

        # Imported lazily so IPython is only required in notebooks.
        from IPython.display import display

        plot_kwargs_defaults = {
            "title": self.display_name,
            "linestyle": {"Data": "none", "Model": "-"},
            "marker": {"Data": "o", "Model": None},
            "color": {"Data": "black", "Model": "red"},
            "markerfacecolor": {"Data": "none", "Model": "none"},
        }
        data_and_model = {
            "Data": self.experiment.binned_data,
            "Model": self._create_model_array(),
        }

        if plot_components:
            # Dashed lines, no markers, for every individual component.
            components = self._create_components_dataset(add_background=add_background)
            for key in components.keys():
                data_and_model[key] = components[key]
                plot_kwargs_defaults["linestyle"][key] = "--"
                plot_kwargs_defaults["marker"][key] = None

        # Overwrite defaults with any user-provided kwargs
        plot_kwargs_defaults.update(kwargs)

        fig = pp.slicer(
            data_and_model,
            **plot_kwargs_defaults,
        )
        display(fig)
+
    def parameters_to_dataset(self) -> sc.Dataset:
        """
        Creates a scipp dataset with copies of the Parameters in the
        model. Ensures unit consistency across Q.

        Returns:
            sc.Dataset: One variable per parameter name, with dimension
            "Q"; parameters missing at a given Q are padded with NaN.
        """

        ds = sc.Dataset(coords={"Q": self.Q})

        # Collect all parameter names
        all_names = {
            param.name
            for analysis in self.analysis_list
            for param in analysis.get_all_parameters()
        }

        # Storage: per-name value/variance lists, plus the first-seen unit.
        values = {name: [] for name in all_names}
        variances = {name: [] for name in all_names}
        units = {}

        for analysis in self.analysis_list:
            pars = {p.name: p for p in analysis.get_all_parameters()}

            for name in all_names:
                if name in pars:
                    p = pars[name]

                    # Unit consistency check
                    if name not in units:
                        units[name] = p.unit
                    elif units[name] != p.unit:
                        # NOTE(review): the conversion result is discarded and
                        # the raw value is appended below, so convertible-but-
                        # different units would mix values in one variable —
                        # verify whether the value should be converted here.
                        try:
                            p.unit.convert(units[name])
                        except Exception as e:
                            raise UnitError(
                                f"Inconsistent units for parameter '{name}': "
                                f"{units[name]} vs {p.unit}"
                            ) from e

                    values[name].append(p.value)
                    variances[name].append(p.variance)
                else:
                    # Parameter absent at this Q: pad with NaN.
                    values[name].append(np.nan)
                    variances[name].append(np.nan)

        # Build dataset variables
        for name in all_names:
            ds[name] = sc.Variable(
                dims=["Q"],
                values=np.asarray(values[name], dtype=float),
                variances=np.asarray(variances[name], dtype=float),
                unit=units.get(name, None),
            )

        return ds
+
+ def plot_parameters(
+ self,
+ names: str | list[str] | None = None,
+ **kwargs,
+ ) -> None:
+ """
+ Plot fitted parameters as a function of Q.
+
+ Parameters:
+ ---------------
+ names: str or list of str
+ Name(s) of the parameter(s) to plot. If None, plots all
+ parameters.
+ kwargs: Additional keyword arguments passed to plopp.slicer for
+ customizing the plot (e.g., title, linestyle, marker,
+ color).
+
+ Returns: A plopp figure.
+ """
+
+ ds = self.parameters_to_dataset()
+
+ if not names:
+ names = list(ds.keys())
+
+ if isinstance(names, str):
+ names = [names]
+
+ if not isinstance(names, list) or not all(
+ isinstance(name, str) for name in names
+ ):
+ raise TypeError("names must be a string or a list of strings.")
+
+ for name in names:
+ if name not in ds:
+ raise ValueError(f"Parameter '{name}' not found in dataset.")
+
+ data_to_plot = {name: ds[name] for name in names}
+ plot_kwargs_defaults = {
+ "linestyle": {name: "none" for name in names},
+ "marker": {name: "o" for name in names},
+ "markerfacecolor": {name: "none" for name in names},
+ }
+
+ plot_kwargs_defaults.update(kwargs)
+ fig = pp.plot(
+ data_to_plot,
+ **plot_kwargs_defaults,
+ )
+ return fig
+
+ #############
+ # Private methods
+ #############
+
+ def _fit_single_Q(self, Q_index: int) -> FitResults:
+ """Fit data for a single Q index."""
+
+ Q_index = self._verify_Q_index(Q_index)
+
+ return self.analysis_list[Q_index].fit()
+
+ def _fit_all_Q_independently(self) -> list[FitResults]:
+ """Fit data for all Q indices independently."""
+ return [analysis.fit() for analysis in self.analysis_list]
+
    def _fit_all_Q_simultaneously(self) -> FitResults:
        """Fit data for all Q indices simultaneously.

        Builds per-Q x/y/weight arrays, refreshes each per-Q convolver,
        then hands everything to the MultiFitter with one fit function
        per Q index.

        Returns:
            FitResults: The result of the simultaneous fit.
        """

        xs = []
        ys = []
        ws = []

        for analysis in self.analysis_list:
            # 1D slice of the data at this analysis' Q index.
            data = analysis.experiment.data["Q", analysis.Q_index]

            x = data.coords["energy"].values
            y = data.values
            # Standard deviations from the stored variances.
            e = np.sqrt(data.variances)

            # Make sure the convolver is up to date for this Q index
            analysis._convolver = analysis._create_convolver()

            xs.append(x)
            ys.append(y)
            # Weights are 1/sigma, as expected by the fitter.
            ws.append(1.0 / e)

        mf = MultiFitter(
            fit_objects=self.analysis_list,
            fit_functions=self.get_fit_functions(),
        )

        results = mf.fit(
            x=xs,
            y=ys,
            weights=ws,
        )
        return results
+
+ def get_fit_functions(self) -> list[callable]:
+ """
+ Get fit functions for all Q indices, which can be used for
+ simultaneous fitting.
+ """
+ return [analysis.as_fit_function() for analysis in self.analysis_list]
+
+ def _create_model_array(self) -> sc.DataArray:
+ """Create a scipp array for the model"""
+
+ model = sc.array(dims=["Q", "energy"], values=self.calculate())
+ model_data_array = sc.DataArray(
+ data=model,
+ coords={"Q": self.Q, "energy": self.experiment.energy},
+ )
+ return model_data_array
+
    def _create_components_dataset(self, add_background: bool = True) -> sc.Dataset:
        """
        Create a scipp dataset containing the individual components of
        the model for plotting.

        Parameters:
        ---------------
        add_background: bool, optional
            Whether to add background components to the sample model
            components. Default is True.

        Returns: A scipp Dataset where each variable is a component of
        the model, with dimensions "Q" and "energy".

        Raises:
            TypeError: If add_background is not a bool.
        """
        if not isinstance(add_background, bool):
            raise TypeError("add_background must be True or False.")

        # One single-Q dataset per Analysis1d, concatenated along Q.
        datasets = [
            analysis._create_components_dataset_single_Q(add_background=add_background)
            for analysis in self.analysis_list
        ]

        return sc.concat(datasets, dim="Q")
+
+ #############
+ # Dunder methods
+ #############
diff --git a/src/easydynamics/analysis/analysis1d.py b/src/easydynamics/analysis/analysis1d.py
new file mode 100644
index 00000000..c4127960
--- /dev/null
+++ b/src/easydynamics/analysis/analysis1d.py
@@ -0,0 +1,495 @@
+# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
+# SPDX-License-Identifier: BSD-3-Clause
+
+
from inspect import Parameter

import numpy as np
import scipp as sc
from easyscience.fitting.fitter import Fitter as EasyScienceFitter
from easyscience.fitting.minimizers.utils import FitResults
from easyscience.variable import DescriptorNumber
# Rebinds Parameter to the easyscience class (shadowing inspect.Parameter),
# matching the Parameter used in analysis.py and analysis_base.py.
from easyscience.variable import Parameter

from easydynamics.analysis.analysis_base import AnalysisBase
from easydynamics.convolution.convolution import Convolution
from easydynamics.experiment import Experiment
from easydynamics.sample_model import InstrumentModel
from easydynamics.sample_model import SampleModel
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
+
+
+class Analysis1d(AnalysisBase):
+ """For analysing data."""
+
+ def __init__(
+ self,
+ display_name: str = "MyAnalysis",
+ unique_name: str | None = None,
+ experiment: Experiment | None = None,
+ sample_model: SampleModel | None = None,
+ instrument_model: InstrumentModel | None = None,
+ Q_index: int | None = None,
+ extra_parameters: Parameter | list[Parameter] | None = None,
+ ):
+ super().__init__(
+ display_name=display_name,
+ unique_name=unique_name,
+ experiment=experiment,
+ sample_model=sample_model,
+ instrument_model=instrument_model,
+ )
+
+ self._Q_index = self._verify_Q_index(Q_index)
+
+ self._fit_result = None
+
+ self._convolver = self._create_convolver()
+
+ #############
+ # Properties
+ #############
+
    @property
    def Q_index(self) -> int | None:
        """Get the Q index for single Q analysis."""
        return self._Q_index

    @Q_index.setter
    def Q_index(self, value: int | None) -> None:
        """Set the Q index for single Q analysis.

        Args:
            value (int | None): The Q index.
        """
        self._Q_index = self._verify_Q_index(value)
        # Rebuild the convolver for the new Q index.
        self._on_Q_index_changed()
+
+ #############
+ # Other methods
+ #############
+
    def calculate(self) -> np.ndarray:
        """Calculate the model prediction for a given Q index.
        Makes sure the convolver is up to date before calculating.

        Returns:
            np.ndarray: The calculated model prediction.
        """

        # Rebuild the convolver so model/parameter changes are picked up.
        self._convolver = self._create_convolver()

        return self._calculate()
+
+ def _calculate(self) -> np.ndarray:
+ """Calculate the model prediction for a given Q index.
+
+ Args:
+ energy (float): The energy value to calculate the model for.
+ Returns:
+ np.ndarray: The calculated model prediction.
+ """
+
+ sample_intensity = self._evaluate_sample()
+
+ background_intensity = self._evaluate_background()
+
+ sample_plus_background = sample_intensity + background_intensity
+
+ return sample_plus_background
+
    def fit(self) -> FitResults:
        """Fit the model to the experimental data for a given Q index.

        Returns:
            FitResult: The result of the fit.

        Raises:
            ValueError: If no experiment is attached or Q_index is unset.

        Notes
        -----
        The energy grid is fixed for the duration of the fit.
        Convolution objects are created once and reused during
        parameter optimization for performance reasons.
        """
        if self._experiment is None:
            raise ValueError("No experiment is associated with this Analysis.")

        Q_index = self._require_Q_index()

        # 1D slice of the experimental data at this Q.
        data = self.experiment.data["Q", Q_index]
        x = data.coords["energy"].values
        y = data.values
        # Standard deviations from the stored variances.
        e = data.variances**0.5

        # Create convolver once to reuse during fitting
        self._convolver = self._create_convolver()

        fitter = EasyScienceFitter(
            fit_object=self,
            fit_function=self.as_fit_function(),
        )

        # Weights are 1/sigma, as expected by the fitter.
        fit_result = fitter.fit(x=x, y=y, weights=1.0 / e)

        self._fit_result = fit_result

        return fit_result
+
    def as_fit_function(self, x=None, **kwargs):
        """
        Return self._calculate as a fit function.

        The EasyScience fitter requires x as input, but
        self._calculate() already uses the correct energy from the
        experiment. So we ignore the x input and just return the
        calculated model.

        Returns:
            callable: A function (x, **kwargs) -> np.ndarray.
        """

        def fit_function(x, **kwargs):
            # x is intentionally ignored; the energy grid is fixed.
            return self._calculate()

        return fit_function
+
    def get_all_variables(self) -> list[DescriptorNumber]:
        """Get all variables used in the analysis.

        Combines the sample-model and instrument-model variables at the
        current Q index, plus any extra parameters.

        Returns:
            List[Descriptor]: A list of all variables.
        """
        variables = self.sample_model.get_all_variables(Q_index=self.Q_index)

        variables.extend(self.instrument_model.get_all_variables(Q_index=self.Q_index))

        # Extra parameters are appended last, if any were stored.
        if self._extra_parameters:
            variables.extend(self._extra_parameters)

        return variables
+
+ def plot_data_and_model(
+ self,
+ plot_components: bool = True,
+ add_background=True,
+ **kwargs,
+ ):
+ """Plot the experimental data and the model prediction for a
+ given Q index.
+
+ Uses Plopp for plotting.
+
+ Args:
+ add_background (bool): Whether to add the background to the
+ model prediction when plotting individual components.
+
+ kwargs: Keyword arguments to pass to the plotting
+ function.
+ Returns:
+ A plot of the data and model.
+ """
+ import plopp as pp
+
+ if self.experiment.data is None:
+ raise ValueError("No data to plot. Please load data first.")
+
+ data = self.experiment.data["Q", self.Q_index]
+ model_array = self._create_sample_scipp_array()
+
+ component_dataset = self._create_components_dataset_single_Q(
+ add_background=add_background
+ )
+
+ # Create a dataset containing the data, model, and individual
+ # components for plotting.
+ data_and_model = sc.Dataset(
+ {
+ "Data": data,
+ "Model": model_array,
+ }
+ )
+
+ data_and_model = sc.merge(data_and_model, component_dataset)
+ plot_kwargs_defaults = {
+ "title": self.display_name,
+ "linestyle": {"Data": "none", "Model": "-"},
+ "marker": {"Data": "o", "Model": "none"},
+ "color": {"Data": "black", "Model": "red"},
+ "markerfacecolor": {"Data": "none", "Model": "none"},
+ }
+
+ if plot_components:
+ for comp_name in component_dataset.keys():
+ plot_kwargs_defaults["linestyle"][comp_name] = "--"
+ plot_kwargs_defaults["marker"][comp_name] = None
+
+ # Overwrite defaults with any user-provided kwargs
+ plot_kwargs_defaults.update(kwargs)
+
+ fig = pp.plot(
+ data_and_model,
+ **plot_kwargs_defaults,
+ )
+ return fig
+
+ #############
+ # Private methods: small utilities
+ #############
+
+ def _require_Q_index(self) -> int:
+ """
+ Get the Q index, ensuring it is set.
+ Raises a ValueError if the Q index is not set.
+ Returns:
+ int: The Q index.
+ """
+ if self._Q_index is None:
+ raise ValueError("Q_index must be set.")
+ return self._Q_index
+
    def _on_Q_index_changed(self) -> None:
        """
        Handle changes to the Q index.

        This method is called whenever the Q index is changed. It
        updates the Convolution object for the new Q index.
        """
        # Rebuild the convolver; it caches per-Q components.
        self._convolver = self._create_convolver()
+
+ #############
+ # Private methods: evaluation
+ #############
+
    def _evaluate_components(
        self,
        components: ComponentCollection | ModelComponent,
        convolver: Convolution | None = None,
        convolve: bool = True,
    ) -> np.ndarray:
        """
        Calculate the contribution of a set of components, optionally
        convolving with the resolution.

        Behavior:
        - convolve=True with a provided convolver: reuse it (fast path
          for full-model evaluation during fitting).
        - convolve=True without a convolver: build a fresh Convolution
          for these components (used for individual components).
        - convolve=False: evaluate the components directly, without
          convolution (used for the background).

        Args:
            components (ComponentCollection | ModelComponent):
                The components to evaluate.
            convolver (Convolution | None): An optional Convolution
                object to use for convolution. If None, a new
                Convolution object will be created if convolve is True.
            convolve (bool):
                Whether to perform convolution with the resolution.
                Default is True.
        Returns:
            np.ndarray: Intensities on the experiment's energy grid.
        """
        Q_index = self._require_Q_index()
        energy = self.energy.values
        energy_offset = self.instrument_model.get_energy_offset_at_Q(Q_index).value

        # If there are no components, return zero
        if isinstance(components, ComponentCollection) and components.is_empty:
            return np.zeros_like(energy)

        # No convolution requested: evaluate on the offset-corrected grid.
        if not convolve:
            return components.evaluate(energy - energy_offset)

        resolution = self.instrument_model.resolution_model.get_component_collection(
            Q_index
        )
        # Without resolution components there is nothing to convolve with.
        if resolution.is_empty:
            return components.evaluate(energy - energy_offset)

        # If a convolver is provided, use it. This allows reusing the
        # same convolver for multiple evaluations during fitting for
        # performance reasons.
        if convolver is not None:
            return convolver.convolution()

        # If no convolver is provided, create a new one. This is for
        # evaluating individual components for plotting, where
        # performance is not important.
        conv = Convolution(
            sample_components=components,
            resolution_components=resolution,
            energy=energy,
            temperature=self.temperature,
            energy_offset=energy_offset,
        )
        return conv.convolution()
+
+ def _evaluate_sample(self) -> np.ndarray:
+ """
+ Evaluate the sample contribution for a given Q index.
+
+ Assumes that self._convolver is up to date.
+
+ Returns:
+ np.ndarray: The evaluated sample contribution.
+ """
+ Q_index = self._require_Q_index()
+ components = self.sample_model.get_component_collection(Q_index=Q_index)
+ return self._evaluate_components(
+ components=components,
+ convolver=self._convolver,
+ convolve=True,
+ )
+
+ def _evaluate_sample_component(
+ self,
+ component: ModelComponent,
+ ) -> np.ndarray:
+ """
+ Evaluate a single sample component for a given Q index.
+
+ Args:
+ component: The sample component to evaluate.
+ Returns:
+ np.ndarray: The evaluated sample component contribution.
+ """
+ return self._evaluate_components(
+ components=component,
+ convolver=None,
+ convolve=True,
+ )
+
+ def _evaluate_background(self) -> np.ndarray:
+ """
+ Evaluate the background contribution for a given Q index.
+
+ Returns:
+ np.ndarray: The evaluated background contribution.
+ """
+ Q_index = self._require_Q_index()
+ background_components = (
+ self.instrument_model.background_model.get_component_collection(
+ Q_index=Q_index
+ )
+ )
+ return self._evaluate_components(
+ components=background_components,
+ convolver=None,
+ convolve=False,
+ )
+
+ def _evaluate_background_component(
+ self,
+ component: ModelComponent,
+ ) -> np.ndarray:
+ """
+ Evaluate a single background component for a given Q index.
+
+ Args:
+ component: The background component to evaluate.
+ Returns:
+ np.ndarray: The evaluated background component contribution.
+ """
+
+ return self._evaluate_components(
+ components=component,
+ convolver=None,
+ convolve=False,
+ )
+
    def _create_convolver(self) -> Convolution | None:
        """
        Initialize and return a Convolution object for the given Q
        index. If the necessary components for convolution are not
        available, return None.

        Returns:
            Convolution | None: The initialized Convolution object or
            None if not available.
        """
        Q_index = self._require_Q_index()

        # No sample components: nothing to convolve.
        sample_components = self.sample_model.get_component_collection(Q_index)
        if sample_components.is_empty:
            return None

        # No resolution components: convolution is not possible either.
        resolution_components = (
            self.instrument_model.resolution_model.get_component_collection(Q_index)
        )
        if resolution_components.is_empty:
            return None
        energy = self.energy
        # TODO: allow convolution options to be set.
        convolver = Convolution(
            sample_components=sample_components,
            resolution_components=resolution_components,
            energy=energy,
            temperature=self.temperature,
            energy_offset=self.instrument_model.get_energy_offset_at_Q(Q_index),
        )
        return convolver
+
+ #############
+ # Private methods: create scipp arrays for plotting
+ #############
+
+ def _create_component_scipp_array(
+ self,
+ component: ModelComponent,
+ background: np.ndarray | None = None,
+ ) -> sc.DataArray:
+ values = self._evaluate_sample_component(component)
+ if background is not None:
+ values += background
+ return self._to_scipp_array(values)
+
+ def _create_background_component_scipp_array(
+ self,
+ component: ModelComponent,
+ ) -> sc.DataArray:
+ values = self._evaluate_background_component(component)
+ return self._to_scipp_array(values)
+
+ def _create_sample_scipp_array(self) -> sc.DataArray:
+ values = self._calculate()
+ return self._to_scipp_array(values)
+
+ def _create_components_dataset_single_Q(
+ self, add_background: bool = True
+ ) -> dict[str, sc.DataArray]:
+ """Create sc.DataArrays for all sample and background
+ components."""
+ scipp_arrays = {}
+ sample_components = self.sample_model.get_component_collection(
+ Q_index=self.Q_index
+ ).components
+
+ background_components = (
+ self.instrument_model.background_model.get_component_collection(
+ Q_index=self.Q_index
+ ).components
+ )
+ background = self._evaluate_background() if add_background else None
+ for component in sample_components:
+ scipp_arrays[component.display_name] = self._create_component_scipp_array(
+ component, background=background
+ )
+ for component in background_components:
+ scipp_arrays[component.display_name] = (
+ self._create_background_component_scipp_array(component)
+ )
+ return sc.Dataset(scipp_arrays)
+
+ def _to_scipp_array(self, values: np.ndarray) -> sc.DataArray:
+ """
+ Convert a numpy array of values to a sc.DataArray with the
+ correct coordinates for energy and Q.
+
+ Args:
+ values (np.ndarray): The values to convert.
+ Returns:
+ sc.DataArray: The converted sc.DataArray.
+ """
+ return sc.DataArray(
+ data=sc.array(dims=["energy"], values=values),
+ coords={
+ "energy": self.energy,
+ "Q": self.Q[self.Q_index],
+ },
+ )
diff --git a/src/easydynamics/analysis/analysis_base.py b/src/easydynamics/analysis/analysis_base.py
new file mode 100644
index 00000000..e6f63939
--- /dev/null
+++ b/src/easydynamics/analysis/analysis_base.py
@@ -0,0 +1,205 @@
+# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
+# SPDX-License-Identifier: BSD-3-Clause
+
+
+import scipp as sc
+from easyscience.base_classes.model_base import ModelBase as EasyScienceModelBase
+from easyscience.variable import Parameter
+
+from easydynamics.experiment import Experiment
+from easydynamics.sample_model import InstrumentModel
+from easydynamics.sample_model import SampleModel
+
+
+class AnalysisBase(EasyScienceModelBase):
+ """For analysing data."""
+
+ def __init__(
+ self,
+ display_name: str = "MyAnalysis",
+ unique_name: str | None = None,
+ experiment: Experiment | None = None,
+ sample_model: SampleModel | None = None,
+ instrument_model: InstrumentModel | None = None,
+ extra_parameters: Parameter | list[Parameter] | None = None,
+ ):
+ super().__init__(display_name=display_name, unique_name=unique_name)
+
+ if experiment is None:
+ self._experiment = Experiment()
+ elif isinstance(experiment, Experiment):
+ self._experiment = experiment
+ else:
+ raise TypeError("experiment must be an instance of Experiment or None.")
+
+ if sample_model is None:
+ self._sample_model = SampleModel()
+ elif isinstance(sample_model, SampleModel):
+ self._sample_model = sample_model
+ else:
+ raise TypeError("sample_model must be an instance of SampleModel or None.")
+
+ if instrument_model is None:
+ self._instrument_model = InstrumentModel()
+ elif isinstance(instrument_model, InstrumentModel):
+ self._instrument_model = instrument_model
+ else:
+ raise TypeError(
+ "instrument_model must be an instance of InstrumentModel or None."
+ )
+
+ if extra_parameters is not None:
+ if isinstance(extra_parameters, Parameter):
+ self._extra_parameters = [extra_parameters]
+ elif isinstance(extra_parameters, list) and all(
+ isinstance(p, Parameter) for p in extra_parameters
+ ):
+ self._extra_parameters = extra_parameters
+ else:
+ raise TypeError(
+ "extra_parameters must be a Parameter or a list of Parameters."
+ )
+ else:
+ self._extra_parameters = []
+
+ self._on_experiment_changed()
+
+ #############
+ # Properties
+ #############
+
+ @property
+ def experiment(self) -> Experiment | None:
+ """The Experiment associated with this Analysis."""
+ return self._experiment
+
+ @experiment.setter
+ def experiment(self, value: Experiment) -> None:
+ if not isinstance(value, Experiment):
+ raise TypeError("experiment must be an instance of Experiment")
+ self._experiment = value
+ self._on_experiment_changed()
+
+ @property
+ def sample_model(self) -> SampleModel:
+ """The SampleModel associated with this Analysis."""
+ return self._sample_model
+
+ @sample_model.setter
+ def sample_model(self, value: SampleModel) -> None:
+ if not isinstance(value, SampleModel):
+ raise TypeError("sample_model must be an instance of SampleModel")
+ self._sample_model = value
+ self._on_sample_model_changed()
+
+ @property
+ def instrument_model(self) -> InstrumentModel:
+ """The InstrumentModel associated with this Analysis."""
+ return self._instrument_model
+
+ @instrument_model.setter
+ def instrument_model(self, value: InstrumentModel) -> None:
+ if not isinstance(value, InstrumentModel):
+ raise TypeError("instrument_model must be an instance of InstrumentModel")
+ self._instrument_model = value
+ self._on_instrument_model_changed()
+
+ @property
+ def Q(self) -> sc.Variable | None:
+ """The Q values from the associated Experiment, if available."""
+ return self.experiment.Q
+
+ @Q.setter
+ def Q(self, value) -> None:
+ """Q is a read-only property derived from the Experiment."""
+ raise AttributeError("Q is a read-only property derived from the Experiment.")
+
+ @property
+ def energy(self) -> sc.Variable | None:
+ """The energy values from the associated Experiment, if
+ available.
+ """
+ return self.experiment.energy
+
+ @energy.setter
+ def energy(self, value) -> None:
+ """Energy is a read-only property derived from the
+ Experiment.
+ """
+ raise AttributeError(
+ "energy is a read-only property derived from the Experiment."
+ )
+
+ @property
+ def temperature(self) -> Parameter | None:
+ """
+ The temperature from the associated SampleModel, if available.
+ """
+ return self.sample_model.temperature if self.sample_model is not None else None
+
+ @temperature.setter
+ def temperature(self, value) -> None:
+ """
+ Temperature is a read-only property derived from the
+ SampleModel.
+ """
+ raise AttributeError(
+ "temperature is a read-only property derived from the sample model."
+ )
+
+ #############
+ # Other methods
+ #############
+
+ #############
+ # Private methods
+ #############
+
+ def _on_experiment_changed(self) -> None:
+ """
+ Update the Q values in the sample and instrument models when the
+ experiment changes.
+ """
+ self._sample_model.Q = self.Q
+ self._instrument_model.Q = self.Q
+
+ def _on_sample_model_changed(self) -> None:
+ """
+ Update the Q values in the sample model when the sample model
+ changes.
+ """
+ self._sample_model.Q = self.Q
+
+ def _on_instrument_model_changed(self) -> None:
+ """
+ Update the Q values in the instrument model when the instrument
+ model changes.
+ """
+ self._instrument_model.Q = self.Q
+
+ def _verify_Q_index(self, Q_index: int | None) -> int | None:
+ """
+ Verify that the Q index is valid.
+
+ Params:
+ Q_index (int | None): The Q index to verify.
+ Returns:
+ int | None: The verified Q index.
+ Raises:
+ ValueError: If the Q index is not valid.
+ """
+ if Q_index is not None:
+ if (
+ not isinstance(Q_index, int)
+ or Q_index < 0
+ or (self.Q is not None and Q_index >= len(self.Q))
+ ):
+ raise ValueError("Q_index must be a valid index for the Q values.")
+ return Q_index
+
+ #############
+ # Dunder methods
+ #############
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}(display_name={self.display_name}, unique_name={self.unique_name})"
diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py
index cfa56c9f..c24a50f2 100644
--- a/src/easydynamics/convolution/analytical_convolution.py
+++ b/src/easydynamics/convolution/analytical_convolution.py
@@ -3,6 +3,7 @@
import numpy as np
import scipp as sc
+from easyscience.variable import Parameter
from scipy.special import voigt_profile
from easydynamics.convolution.convolution_base import ConvolutionBase
@@ -12,8 +13,7 @@
from easydynamics.sample_model import Voigt
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
-
-Numerical = float | int
+from easydynamics.utils.utils import Numeric
class AnalyticalConvolution(ConvolutionBase):
@@ -35,26 +35,28 @@ class AnalyticalConvolution(ConvolutionBase):
# Mapping of supported component type pairs to convolution methods.
# Delta functions are handled separately.
_CONVOLUTIONS = {
- ('Gaussian', 'Gaussian'): '_convolute_gaussian_gaussian',
- ('Gaussian', 'Lorentzian'): '_convolute_gaussian_lorentzian',
- ('Gaussian', 'Voigt'): '_convolute_gaussian_voigt',
- ('Lorentzian', 'Lorentzian'): '_convolute_lorentzian_lorentzian',
- ('Lorentzian', 'Voigt'): '_convolute_lorentzian_voigt',
- ('Voigt', 'Voigt'): '_convolute_voigt_voigt',
+ ("Gaussian", "Gaussian"): "_convolute_gaussian_gaussian",
+ ("Gaussian", "Lorentzian"): "_convolute_gaussian_lorentzian",
+ ("Gaussian", "Voigt"): "_convolute_gaussian_voigt",
+ ("Lorentzian", "Lorentzian"): "_convolute_lorentzian_lorentzian",
+ ("Lorentzian", "Voigt"): "_convolute_lorentzian_voigt",
+ ("Voigt", "Voigt"): "_convolute_voigt_voigt",
}
def __init__(
self,
energy: np.ndarray | sc.Variable,
- energy_unit: str | sc.Unit = 'meV',
+ energy_unit: str | sc.Unit = "meV",
sample_components: ComponentCollection | ModelComponent | None = None,
resolution_components: ComponentCollection | ModelComponent | None = None,
+ energy_offset: Numeric | Parameter = 0.0,
):
super().__init__(
energy=energy,
energy_unit=energy_unit,
sample_components=sample_components,
resolution_components=resolution_components,
+ energy_offset=energy_offset,
)
def convolution(
@@ -142,8 +144,8 @@ def _convolute_analytic_pair(
if isinstance(resolution_component, DeltaFunction):
raise ValueError(
- 'Analytical convolution with a delta function \
- in the resolution model is not supported.'
+ "Analytical convolution with a delta function \
+ in the resolution model is not supported."
)
# Delta function + anything -->
@@ -169,8 +171,8 @@ def _convolute_analytic_pair(
if func_name is None:
raise ValueError(
- f'Analytical convolution not supported for component pair: '
- f'{type(sample_component).__name__}, {type(resolution_component).__name__}'
+ f"Analytical convolution not supported for component pair: "
+ f"{type(sample_component).__name__}, {type(resolution_component).__name__}"
)
# Call the corresponding method
@@ -199,7 +201,7 @@ def _convolute_delta_any(
The evaluated convolution values at self.energy.
"""
return sample_component.area.value * resolution_components.evaluate(
- self.energy.values - sample_component.center.value
+ self.energy_with_offset.values - sample_component.center.value
)
def _convolute_gaussian_gaussian(
@@ -223,7 +225,9 @@ def _convolute_gaussian_gaussian(
The evaluated convolution values at self.energy.
"""
- width = np.sqrt(sample_component.width.value**2 + resolution_component.width.value**2)
+ width = np.sqrt(
+ sample_component.width.value**2 + resolution_component.width.value**2
+ )
area = sample_component.area.value * resolution_component.area.value
@@ -284,7 +288,8 @@ def _convolute_gaussian_voigt(
center = sample_component.center.value + resolution_component.center.value
gaussian_width = np.sqrt(
- sample_component.width.value**2 + resolution_component.gaussian_width.value**2
+ sample_component.width.value**2
+ + resolution_component.gaussian_width.value**2
)
lorentzian_width = resolution_component.lorentzian_width.value
@@ -384,11 +389,13 @@ def _convolute_voigt_voigt(
center = sample_component.center.value + resolution_component.center.value
gaussian_width = np.sqrt(
- sample_component.gaussian_width.value**2 + resolution_component.gaussian_width.value**2
+ sample_component.gaussian_width.value**2
+ + resolution_component.gaussian_width.value**2
)
lorentzian_width = (
- sample_component.lorentzian_width.value + resolution_component.lorentzian_width.value
+ sample_component.lorentzian_width.value
+ + resolution_component.lorentzian_width.value
)
return self._voigt_eval(
area=area,
@@ -420,7 +427,7 @@ def _gaussian_eval(
"""
normalization = 1 / (np.sqrt(2 * np.pi) * width)
- exponent = -0.5 * ((self.energy.values - center) / width) ** 2
+ exponent = -0.5 * ((self.energy_with_offset.values - center) / width) ** 2
return area * normalization * np.exp(exponent)
@@ -443,7 +450,7 @@ def _lorentzian_eval(self, area: float, center: float, width: float) -> np.ndarr
"""
normalization = width / np.pi
- denominator = (self.energy.values - center) ** 2 + width**2
+ denominator = (self.energy_with_offset.values - center) ** 2 + width**2
return area * normalization / denominator
@@ -471,4 +478,6 @@ def _voigt_eval(
The evaluated Voigt profile values at self.energy.
"""
- return area * voigt_profile(self.energy.values - center, gaussian_width, lorentzian_width)
+ return area * voigt_profile(
+ self.energy_with_offset.values - center, gaussian_width, lorentzian_width
+ )
diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py
index 542515e9..827087c1 100644
--- a/src/easydynamics/convolution/convolution.py
+++ b/src/easydynamics/convolution/convolution.py
@@ -14,8 +14,7 @@
from easydynamics.sample_model import Lorentzian
from easydynamics.sample_model import Voigt
from easydynamics.sample_model.components.model_component import ModelComponent
-
-Numerical = float | int
+from easydynamics.utils.utils import Numeric
class Convolution(NumericalConvolutionBase):
@@ -60,16 +59,16 @@ class Convolution(NumericalConvolutionBase):
# When these attributes are changed, the convolution plan
# needs to be rebuilt
_invalidate_plan_on_change = {
- 'energy',
- '_energy',
- '_energy_grid',
- '_sample_components',
- '_resolution_components',
- '_temperature',
- '_upsample_factor',
- '_extension_factor',
- '_energy_unit',
- '_normalize_detailed_balance',
+ "energy",
+ "_energy",
+ "_energy_grid",
+ "_sample_components",
+ "_resolution_components",
+ "_temperature",
+ "_upsample_factor",
+ "_extension_factor",
+ "_energy_unit",
+ "_normalize_detailed_balance",
}
def __init__(
@@ -77,11 +76,12 @@ def __init__(
energy: np.ndarray | sc.Variable,
sample_components: ComponentCollection | ModelComponent,
resolution_components: ComponentCollection | ModelComponent,
- upsample_factor: Numerical = 5,
- extension_factor: Numerical = 0.2,
- temperature: Parameter | Numerical | None = None,
- temperature_unit: str | sc.Unit = 'K',
- energy_unit: str | sc.Unit = 'meV',
+ energy_offset: Numeric | Parameter = 0.0,
+ upsample_factor: Numeric = 5,
+ extension_factor: Numeric = 0.2,
+ temperature: Parameter | Numeric | None = None,
+ temperature_unit: str | sc.Unit = "K",
+ energy_unit: str | sc.Unit = "meV",
normalize_detailed_balance: bool = True,
):
self._convolution_plan_is_valid = False
@@ -90,6 +90,7 @@ def __init__(
energy=energy,
sample_components=sample_components,
resolution_components=resolution_components,
+ energy_offset=energy_offset,
upsample_factor=upsample_factor,
extension_factor=extension_factor,
temperature=temperature,
@@ -136,11 +137,13 @@ def convolution(
def _convolve_delta_functions(self) -> np.ndarray:
"Convolve delta function components of the sample model with"
- 'the resolution components.'
- 'No detailed balance correction is applied to delta functions.'
+ "the resolution components."
+ "No detailed balance correction is applied to delta functions."
return sum(
delta.area.value
- * self._resolution_components.evaluate(self.energy.values - delta.center.value)
+ * self._resolution_components.evaluate(
+ self.energy_with_offset.values - delta.center.value
+ )
for delta in self._delta_sample_components.components
)
@@ -165,19 +168,19 @@ def _check_if_pair_is_analytic(
if not isinstance(sample_component, ModelComponent):
raise TypeError(
- f'`sample_component` is an instance of {type(sample_component).__name__}, \
- but must be a ModelComponent.'
+ f"`sample_component` is an instance of {type(sample_component).__name__}, \
+ but must be a ModelComponent."
)
if not isinstance(resolution_component, ModelComponent):
raise TypeError(
- f'`resolution_component` is an instance of {type(resolution_component).__name__}, \
- but must be a ModelComponent.'
+ f"`resolution_component` is an instance of {type(resolution_component).__name__}, \
+ but must be a ModelComponent."
)
if isinstance(resolution_component, DeltaFunction):
raise TypeError(
- 'resolution components contains delta functions. This is not supported.'
+ "resolution components contains delta functions. This is not supported."
)
analytical_types = (Gaussian, Lorentzian, Voigt)
@@ -216,7 +219,9 @@ def _build_convolution_plan(self) -> None:
pair_is_analytic = []
for resolution_component in self._resolution_components.components:
pair_is_analytic.append(
- self._check_if_pair_is_analytic(sample_component, resolution_component)
+ self._check_if_pair_is_analytic(
+ sample_component, resolution_component
+ )
)
# If all resolution components can be convolved analytically
# with this sample component, add it to analytical
@@ -245,6 +250,7 @@ def _set_convolvers(self) -> None:
if self._analytical_sample_components.components:
self._analytical_convolver = AnalyticalConvolution(
energy=self.energy,
+ energy_offset=self.energy_offset,
sample_components=self._analytical_sample_components,
resolution_components=self._resolution_components,
)
@@ -254,6 +260,7 @@ def _set_convolvers(self) -> None:
if self._numerical_sample_components.components:
self._numerical_convolver = NumericalConvolution(
energy=self.energy,
+ energy_offset=self.energy_offset,
sample_components=self._numerical_sample_components,
resolution_components=self._resolution_components,
upsample_factor=self.upsample_factor,
@@ -278,5 +285,8 @@ def __setattr__(self, name, value):
if name in self._invalidate_plan_on_change:
self._convolution_plan_is_valid = False
- if getattr(self, '_reactions_enabled', False) and name in self._invalidate_plan_on_change:
+ if (
+ getattr(self, "_reactions_enabled", False)
+ and name in self._invalidate_plan_on_change
+ ):
self._build_convolution_plan()
diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py
index 34eab3f4..be5cff06 100644
--- a/src/easydynamics/convolution/convolution_base.py
+++ b/src/easydynamics/convolution/convolution_base.py
@@ -3,11 +3,11 @@
import numpy as np
import scipp as sc
+from easyscience.variable import Parameter
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
-
-Numerical = float | int
+from easydynamics.utils.utils import Numeric
class ConvolutionBase:
@@ -30,30 +30,42 @@ def __init__(
energy: np.ndarray | sc.Variable,
sample_components: ComponentCollection | ModelComponent = None,
resolution_components: ComponentCollection | ModelComponent = None,
- energy_unit: str | sc.Unit = 'meV',
+ energy_unit: str | sc.Unit = "meV",
+ energy_offset: Numeric | Parameter = 0.0,
):
- if isinstance(energy, Numerical):
+ if isinstance(energy, Numeric):
energy = np.array([float(energy)])
if not isinstance(energy, (np.ndarray, sc.Variable)):
- raise TypeError('Energy must be a numpy ndarray or a scipp Variable.')
+ raise TypeError("Energy must be a numpy ndarray or a scipp Variable.")
if not isinstance(energy_unit, (str, sc.Unit)):
- raise TypeError('Energy_unit must be a string or sc.Unit.')
+ raise TypeError("Energy_unit must be a string or sc.Unit.")
if isinstance(energy, np.ndarray):
- energy = sc.array(dims=['energy'], values=energy, unit=energy_unit)
+ energy = sc.array(dims=["energy"], values=energy, unit=energy_unit)
+
+ if isinstance(energy_offset, Numeric):
+ energy_offset = Parameter(
+ name="energy_offset", value=float(energy_offset), unit=energy_unit
+ )
+
+ if not isinstance(energy_offset, Parameter):
+ raise TypeError("Energy_offset must be a number or a Parameter.")
self._energy = energy
self._energy_unit = energy_unit
+ self._energy_offset = energy_offset
if sample_components is not None and not (
isinstance(sample_components, ComponentCollection)
or isinstance(sample_components, ModelComponent)
):
raise TypeError(
- f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501
+ f"`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501
)
+ if isinstance(sample_components, ModelComponent):
+ sample_components = ComponentCollection(components=[sample_components])
self._sample_components = sample_components
if resolution_components is not None and not (
@@ -61,10 +73,53 @@
             or isinstance(resolution_components, ModelComponent)
         ):
             raise TypeError(
-                f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501
+                f"`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501
+            )
+        if isinstance(resolution_components, ModelComponent):
+            resolution_components = ComponentCollection(
+                components=[resolution_components]
             )
self._resolution_components = resolution_components
+ @property
+ def energy_offset(self) -> Parameter:
+ """Get the energy offset."""
+ return self._energy_offset
+
+ @energy_offset.setter
+ def energy_offset(self, energy_offset: Numeric | Parameter) -> None:
+ """Set the energy offset.
+ Args:
+ energy_offset : Number or Parameter
+ The energy offset to apply to the convolution.
+
+ Raises:
+ TypeError: If energy_offset is not a number or a Parameter.
+ """
+ if not isinstance(energy_offset, Parameter | Numeric):
+ raise TypeError("Energy_offset must be a number or a Parameter.")
+
+ if isinstance(energy_offset, Numeric):
+ self._energy_offset.value = float(energy_offset)
+
+ if isinstance(energy_offset, Parameter):
+ self._energy_offset = energy_offset
+
+ @property
+ def energy_with_offset(self) -> sc.Variable:
+ """Get the energy with the offset applied."""
+ energy_with_offset = self.energy.copy()
+ energy_with_offset.values = self.energy.values - self.energy_offset.value
+ return energy_with_offset
+
+ @energy_with_offset.setter
+ def energy_with_offset(self, value) -> None:
+ """Energy with offset is a read-only property derived from
+ energy and energy_offset."""
+ raise AttributeError(
+ "Energy with offset is a read-only property derived from energy and energy_offset."
+ )
+
@property
def energy(self) -> sc.Variable:
"""Get the energy."""
@@ -84,14 +141,18 @@ def energy(self, energy: np.ndarray) -> None:
scipp Variable.
"""
- if isinstance(energy, Numerical):
+ if isinstance(energy, Numeric):
energy = np.array([float(energy)])
if not isinstance(energy, (np.ndarray, sc.Variable)):
- raise TypeError('Energy must be a Number, a numpy ndarray or a scipp Variable.')
+ raise TypeError(
+ "Energy must be a Number, a numpy ndarray or a scipp Variable."
+ )
if isinstance(energy, np.ndarray):
- self._energy = sc.array(dims=['energy'], values=energy, unit=self._energy.unit)
+ self._energy = sc.array(
+ dims=["energy"], values=energy, unit=self._energy.unit
+ )
if isinstance(energy, sc.Variable):
self._energy = energy
@@ -106,8 +167,8 @@ def energy_unit(self) -> str:
def energy_unit(self, unit_str: str) -> None:
raise AttributeError(
(
- f'Unit is read-only. Use convert_unit to change the unit between allowed types '
- f'or create a new {self.__class__.__name__} with the desired unit.'
+ f"Unit is read-only. Use convert_unit to change the unit between allowed types "
+ f"or create a new {self.__class__.__name__} with the desired unit."
)
) # noqa: E501
@@ -121,7 +182,7 @@ def convert_energy_unit(self, energy_unit: str | sc.Unit) -> None:
TypeError: If energy_unit is not a string or scipp unit.
"""
if not isinstance(energy_unit, (str, sc.Unit)):
- raise TypeError('Energy unit must be a string or scipp unit.')
+ raise TypeError("Energy unit must be a string or scipp unit.")
self.energy = sc.to_unit(self.energy, energy_unit)
self._energy_unit = energy_unit
@@ -132,7 +193,9 @@ def sample_components(self) -> ComponentCollection | ModelComponent:
return self._sample_components
@sample_components.setter
- def sample_components(self, sample_components: ComponentCollection | ModelComponent) -> None:
+ def sample_components(
+ self, sample_components: ComponentCollection | ModelComponent
+ ) -> None:
"""Set the sample model.
Args:
sample_components : ComponentCollection or ModelComponent
@@ -144,7 +207,7 @@ def sample_components(self, sample_components: ComponentCollection | ModelCompon
"""
if not isinstance(sample_components, (ComponentCollection, ModelComponent)):
raise TypeError(
- f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501
+ f"`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501
)
self._sample_components = sample_components
@@ -169,6 +232,6 @@ def resolution_components(
"""
if not isinstance(resolution_components, (ComponentCollection, ModelComponent)):
raise TypeError(
- f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501
+ f"`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501
)
self._resolution_components = resolution_components
diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py
index 125c4451..95d75917 100644
--- a/src/easydynamics/convolution/numerical_convolution.py
+++ b/src/easydynamics/convolution/numerical_convolution.py
@@ -9,9 +9,10 @@
from easydynamics.convolution.numerical_convolution_base import NumericalConvolutionBase
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
-from easydynamics.utils.detailed_balance import _detailed_balance_factor as detailed_balance_factor
-
-Numerical = float | int
+from easydynamics.utils.detailed_balance import (
+ _detailed_balance_factor as detailed_balance_factor,
+)
+from easydynamics.utils.utils import Numeric
class NumericalConvolution(NumericalConvolutionBase):
@@ -53,17 +54,19 @@ def __init__(
energy: np.ndarray | sc.Variable,
sample_components: ComponentCollection | ModelComponent,
resolution_components: ComponentCollection | ModelComponent,
- upsample_factor: Numerical = 5,
- extension_factor: float = 0.2,
- temperature: Parameter | float | None = None,
- temperature_unit: str | sc.Unit = 'K',
- energy_unit: str | sc.Unit = 'meV',
+ energy_offset: Numeric | Parameter = 0.0,
+ upsample_factor: Numeric = 5,
+ extension_factor: Numeric = 0.2,
+ temperature: Parameter | Numeric | None = None,
+ temperature_unit: str | sc.Unit = "K",
+ energy_unit: str | sc.Unit = "meV",
normalize_detailed_balance: bool = True,
):
super().__init__(
energy=energy,
sample_components=sample_components,
resolution_components=resolution_components,
+ energy_offset=energy_offset,
upsample_factor=upsample_factor,
extension_factor=extension_factor,
temperature=temperature,
@@ -87,23 +90,25 @@ def convolution(
# Give warnings if peaks are very wide or very narrow
self._check_width_thresholds(
model=self.sample_components,
- model_name='sample model',
+ model_name="sample model",
)
self._check_width_thresholds(
model=self.resolution_components,
- model_name='resolution model',
+ model_name="resolution model",
)
# Evaluate sample model. If called via the Convolution class,
# delta functions are already filtered out.
sample_vals = self.sample_components.evaluate(
- self._energy_grid.energy_dense - self._energy_grid.energy_even_length_offset
+ self._energy_grid.energy_dense
+ - self._energy_grid.energy_even_length_offset
+ - self.energy_offset.value
)
# Detailed balance correction
if self.temperature is not None:
detailed_balance_factor_correction = detailed_balance_factor(
- energy=self._energy_grid.energy_dense,
+ energy=self._energy_grid.energy_dense - self.energy_offset.value,
temperature=self.temperature,
energy_unit=self.energy.unit,
divide_by_temperature=self.normalize_detailed_balance,
@@ -116,7 +121,7 @@ def convolution(
)
# Convolution
- convolved = fftconvolve(sample_vals, resolution_vals, mode='same')
+ convolved = fftconvolve(sample_vals, resolution_vals, mode="same")
convolved *= self._energy_grid.energy_dense_step # normalize
if self.upsample_factor is not None:
diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py
index ffcf0058..ba40f456 100644
--- a/src/easydynamics/convolution/numerical_convolution_base.py
+++ b/src/easydynamics/convolution/numerical_convolution_base.py
@@ -63,11 +63,12 @@ def __init__(
energy: np.ndarray | sc.Variable,
sample_components: ComponentCollection | ModelComponent,
resolution_components: ComponentCollection | ModelComponent,
+ energy_offset: Numerical | Parameter = 0.0,
upsample_factor: Numerical = 5,
extension_factor: float = 0.2,
temperature: Parameter | float | None = None,
- temperature_unit: str | sc.Unit = 'K',
- energy_unit: str | sc.Unit = 'meV',
+ temperature_unit: str | sc.Unit = "K",
+ energy_unit: str | sc.Unit = "meV",
normalize_detailed_balance: bool = True,
):
super().__init__(
@@ -75,13 +76,16 @@ def __init__(
sample_components=sample_components,
resolution_components=resolution_components,
energy_unit=energy_unit,
+ energy_offset=energy_offset,
)
- if temperature is not None and not isinstance(temperature, (Numerical, Parameter)):
- raise TypeError('Temperature must be None, a number or a Parameter.')
+ if temperature is not None and not isinstance(
+ temperature, (Numerical, Parameter)
+ ):
+ raise TypeError("Temperature must be None, a number or a Parameter.")
if not isinstance(temperature_unit, (str, sc.Unit)):
- raise TypeError('Temperature_unit must be a string or sc.Unit.')
+ raise TypeError("Temperature_unit must be a string or sc.Unit.")
self._temperature_unit = temperature_unit
self._temperature = None
self.temperature = temperature
@@ -117,10 +121,10 @@ def upsample_factor(self, factor: Numerical) -> None:
return
if not isinstance(factor, Numerical):
- raise TypeError('Upsample factor must be a numerical value or None.')
+ raise TypeError("Upsample factor must be a numerical value or None.")
factor = float(factor)
if factor <= 1.0:
- raise ValueError('Upsample factor must be greater than 1.')
+ raise ValueError("Upsample factor must be greater than 1.")
self._upsample_factor = factor
@@ -156,9 +160,9 @@ def extension_factor(self, factor: Numerical) -> None:
TypeError: If factor is not a number.
"""
if not isinstance(factor, Numerical):
- raise TypeError('Extension factor must be a number.')
+ raise TypeError("Extension factor must be a number.")
if factor < 0.0:
- raise ValueError('Extension factor must be non-negative.')
+ raise ValueError("Extension factor must be non-negative.")
self._extension_factor = factor
# Recreate dense grid when extension factor is updated
@@ -192,7 +196,7 @@ def temperature(self, temp: Parameter | float | None) -> None:
self._temperature.value = float(temp)
else:
self._temperature = Parameter(
- name='temperature',
+ name="temperature",
value=float(temp),
unit=self._temperature_unit,
fixed=True,
@@ -200,7 +204,7 @@ def temperature(self, temp: Parameter | float | None) -> None:
elif isinstance(temp, Parameter):
self._temperature = temp
else:
- raise TypeError('Temperature must be None, a float or a Parameter.')
+ raise TypeError("Temperature must be None, a float or a Parameter.")
@property
def normalize_detailed_balance(self) -> bool:
@@ -221,7 +225,7 @@ def normalize_detailed_balance(self, normalize: bool) -> None:
"""
if not isinstance(normalize, bool):
- raise TypeError('normalize_detailed_balance must be True or False.')
+ raise TypeError("normalize_detailed_balance must be True or False.")
self._normalize_detailed_balance = normalize
@@ -239,9 +243,9 @@ def _create_energy_grid(
The dense grid created by upsampling and extending
energy.
The EnergyGrid has the following attributes:
- energy_dense : np.ndarray
+ energy_dense : np.ndarray
The upsampled and extended energy array.
- energy_dense_centered : np.ndarray
+ energy_dense_centered : np.ndarray
The centered version of energy_dense
(used for resolution evaluation).
energy_dense_step : float
@@ -259,7 +263,7 @@ def _create_energy_grid(
is_uniform = np.allclose(energy_diff, energy_diff[0])
if not is_uniform:
raise ValueError(
- 'Input array `energy` must be uniformly spaced if upsample_factor is not given.' # noqa: E501
+ "Input array `energy` must be uniformly spaced if upsample_factor is not given." # noqa: E501
)
energy_dense = self.energy.values
@@ -276,7 +280,7 @@ def _create_energy_grid(
energy_span_dense = extended_max - extended_min
if len(energy_dense) < 2:
- raise ValueError('Energy array must have at least two points.')
+ raise ValueError("Energy array must have at least two points.")
energy_dense_step = energy_dense[1] - energy_dense[0]
# Handle offset for even length of energy_dense in convolution.
@@ -346,35 +350,41 @@ def _check_width_thresholds(
components = [model] # Treat single ModelComponent as a list
for comp in components:
- if hasattr(comp, 'width'):
- if comp.width.value > LARGE_WIDTH_THRESHOLD * self._energy_grid.energy_span_dense:
+ if hasattr(comp, "width"):
+ if (
+ comp.width.value
+ > LARGE_WIDTH_THRESHOLD * self._energy_grid.energy_span_dense
+ ):
warnings.warn(
f"The width of the {model_name} component '{comp.unique_name}' \
({comp.width.value}) is large compared to the span of the input "
- f'array ({self._energy_grid.energy_span_dense}). \
+ f"array ({self._energy_grid.energy_span_dense}). \
This may lead to inaccuracies in the convolution. \
- Increase extension_factor to improve accuracy.',
+ Increase extension_factor to improve accuracy.",
UserWarning,
)
- if comp.width.value < SMALL_WIDTH_THRESHOLD * self._energy_grid.energy_dense_step:
+ if (
+ comp.width.value
+ < SMALL_WIDTH_THRESHOLD * self._energy_grid.energy_dense_step
+ ):
warnings.warn(
f"The width of the {model_name} component '{comp.unique_name}' \
({comp.width.value}) is small compared to the spacing of the input "
- f'array ({self._energy_grid.energy_dense_step}). \
+ f"array ({self._energy_grid.energy_dense_step}). \
This may lead to inaccuracies in the convolution. \
- Increase upsample_factor to improve accuracy.',
+ Increase upsample_factor to improve accuracy.",
UserWarning,
)
def __repr__(self) -> str:
return (
- f'{self.__class__.__name__}('
- f'energy=array of shape {self.energy.values.shape},\n '
- f'sample_components={repr(self.sample_components)}, \n'
- f'resolution_components={repr(self.resolution_components)},\n '
- f'energy_unit={self._energy_unit}, '
- f'upsample_factor={self.upsample_factor}, '
- f'extension_factor={self.extension_factor}, '
- f'temperature={self.temperature}, '
- f'normalize_detailed_balance={self.normalize_detailed_balance})'
+ f"{self.__class__.__name__}("
+ f"energy=array of shape {self.energy.values.shape},\n "
+ f"sample_components={repr(self.sample_components)}, \n"
+ f"resolution_components={repr(self.resolution_components)},\n "
+ f"energy_unit={self._energy_unit}, "
+ f"upsample_factor={self.upsample_factor}, "
+ f"extension_factor={self.extension_factor}, "
+ f"temperature={self.temperature}, "
+ f"normalize_detailed_balance={self.normalize_detailed_balance})"
)
diff --git a/src/easydynamics/experiment/experiment.py b/src/easydynamics/experiment/experiment.py
index b3df2a11..771656b0 100644
--- a/src/easydynamics/experiment/experiment.py
+++ b/src/easydynamics/experiment/experiment.py
@@ -1,6 +1,4 @@
import os
-import warnings
-from typing import Optional
import plopp as pp
import scipp as sc
@@ -8,6 +6,8 @@
from scipp.io import load_hdf5 as sc_load_hdf5
from scipp.io import save_hdf5 as sc_save_hdf5
+from easydynamics.utils.utils import _in_notebook
+
class Experiment(NewBase):
"""Holds data from an experiment as a sc.DataArray along with
@@ -19,7 +19,7 @@ class Experiment(NewBase):
def __init__(
self,
- display_name: str = 'MyExperiment',
+ display_name: str = "MyExperiment",
unique_name: str | None = None,
data: sc.DataArray | str | None = None,
):
@@ -29,7 +29,7 @@ def __init__(
)
if data is None:
- self._data: Optional[sc.DataArray] = None
+ self._data = None
elif isinstance(data, str):
self.load_hdf5(filename=data)
elif isinstance(data, sc.DataArray):
@@ -37,7 +37,7 @@ def __init__(
self._data = data
else:
raise TypeError(
- f'Data must be a sc.DataArray or a filename string, not {type(data).__name__}'
+ f"Data must be a sc.DataArray or a filename string, not {type(data).__name__}"
)
self._binned_data = (
@@ -57,7 +57,7 @@ def data(self) -> sc.DataArray | None:
def data(self, value: sc.DataArray):
"""Set the dataset associated with this experiment."""
if not isinstance(value, sc.DataArray):
- raise TypeError(f'Data must be a sc.DataArray, not {type(value).__name__}')
+ raise TypeError(f"Data must be a sc.DataArray, not {type(value).__name__}")
self._validate_coordinates(value)
self._data = value
self._binned_data = (
@@ -72,33 +72,35 @@ def binned_data(self) -> sc.DataArray | None:
@binned_data.setter
def binned_data(self, value: sc.DataArray):
"""Set the binned dataset associated with this experiment."""
- raise AttributeError('binned_data is a read-only property. Use rebin() to rebin the data')
+ raise AttributeError(
+ "binned_data is a read-only property. Use rebin() to rebin the data"
+ )
@property
def Q(self) -> sc.Variable | None:
"""Get the Q values from the dataset."""
if self._data is None:
- warnings.warn('No data loaded.', UserWarning)
+ # warnings.warn("No data loaded.", UserWarning)
return None
- return self._binned_data.coords['Q']
+ return self._binned_data.coords["Q"]
@Q.setter
def Q(self, value: sc.Variable):
"""Set the Q values for the dataset."""
- raise AttributeError('Q is a read-only property derived from the data.')
+ raise AttributeError("Q is a read-only property derived from the data.")
@property
def energy(self) -> sc.Variable:
"""Get the energy values from the dataset."""
if self._data is None:
- warnings.warn('No data loaded.', UserWarning)
+ # warnings.warn("No data loaded.", UserWarning)
return None
- return self._binned_data.coords['energy']
+ return self._binned_data.coords["energy"]
@energy.setter
def energy(self, value: sc.Variable):
"""Set the energy values for the dataset."""
- raise AttributeError('energy is a read-only property derived from the data.')
+ raise AttributeError("energy is a read-only property derived from the data.")
###########
# Handle data
@@ -113,19 +115,19 @@ def load_hdf5(self, filename: str, display_name: str | None = None):
experiment.
"""
if not isinstance(filename, str):
- raise TypeError(f'Filename must be a string, not {type(filename).__name__}')
+ raise TypeError(f"Filename must be a string, not {type(filename).__name__}")
if display_name is not None:
if not isinstance(display_name, str):
raise TypeError(
- f'Display name must be a string, not {type(display_name).__name__}'
+ f"Display name must be a string, not {type(display_name).__name__}"
)
self.display_name = display_name
loaded_data = sc_load_hdf5(filename)
if not isinstance(loaded_data, sc.DataArray):
raise TypeError(
- f'Loaded data must be a sc.DataArray, not {type(loaded_data).__name__}'
+ f"Loaded data must be a sc.DataArray, not {type(loaded_data).__name__}"
)
self._validate_coordinates(loaded_data)
self.data = loaded_data
@@ -138,13 +140,13 @@ def save_hdf5(self, filename: str | None = None):
"""
if filename is None:
- filename = f'{self.unique_name}.h5'
+ filename = f"{self.unique_name}.h5"
if not isinstance(filename, str):
- raise TypeError(f'Filename must be a string, not {type(filename).__name__}')
+ raise TypeError(f"Filename must be a string, not {type(filename).__name__}")
if self._data is None:
- raise ValueError('No data to save.')
+ raise ValueError("No data to save.")
dir_name = os.path.dirname(filename)
if dir_name:
@@ -172,31 +174,33 @@ def rebin(self, dimensions: dict[str, int | sc.Variable]) -> None:
if not isinstance(dimensions, dict):
raise TypeError(
- 'dimensions must be a dictionary mapping dimension names '
- 'to number of bins or bin values as sc.Variable.'
+ "dimensions must be a dictionary mapping dimension names "
+ "to number of bins or bin values as sc.Variable."
)
if self._data is None:
- raise ValueError('No data to rebin. Please load data first.')
+ raise ValueError("No data to rebin. Please load data first.")
binned_data = self._data.copy()
dim_copy = dimensions.copy()
for dim, value in dim_copy.items():
if not isinstance(dim, str):
raise TypeError(
- f'Dimension keys must be strings. Got {type(dim)} for {dim} instead.'
+ f"Dimension keys must be strings. Got {type(dim)} for {dim} instead."
)
if dim not in self._data.dims:
raise KeyError(
f"Dimension '{dim}' not a valid dimension for rebinning. "
- f'Should be one of {self._data.dims}.'
+ f"Should be one of {self._data.dims}."
)
- if isinstance(value, float) and value.is_integer(): # I allow eg. 2.0 as well as 2
+ if (
+ isinstance(value, float) and value.is_integer()
+ ): # I allow eg. 2.0 as well as 2
value = int(value)
# This line can be removed when scipp resize support
# resizing with coordinates
dimensions[dim] = value
if not (isinstance(value, int) or isinstance(value, sc.Variable)):
raise TypeError(
- f'Dimension values must be integers or sc.Variable. '
+ f"Dimension values must be integers or sc.Variable. "
f"Got {type(value)} for dimension '{dim}' instead."
)
binned_data = binned_data.bin({dim: value})
@@ -213,15 +217,17 @@ def plot_data(self, slicer=False, **kwargs) -> None:
"""Plot the dataset using plopp."""
if self._binned_data is None:
- raise ValueError('No data to plot. Please load data first.')
+ raise ValueError("No data to plot. Please load data first.")
- if not self._in_notebook():
- raise RuntimeError('plot_data() can only be used in a Jupyter notebook environment.')
+ if not _in_notebook():
+ raise RuntimeError(
+ "plot_data() can only be used in a Jupyter notebook environment."
+ )
from IPython.display import display
plot_kwargs_defaults = {
- 'title': self.display_name,
+ "title": self.display_name,
}
# Overwrite defaults with any user-provided kwargs
plot_kwargs_defaults.update(kwargs)
@@ -232,7 +238,7 @@ def plot_data(self, slicer=False, **kwargs) -> None:
)
else:
fig = pp.plot(
- self._binned_data.transpose(dims=['energy', 'Q']),
+ self._binned_data.transpose(dims=["energy", "Q"]),
**plot_kwargs_defaults,
)
display(fig)
@@ -241,26 +247,6 @@ def plot_data(self, slicer=False, **kwargs) -> None:
# private methods
###########
- @staticmethod
- def _in_notebook() -> bool:
- """Check if the code is running in a Jupyter notebook.
-
- Returns:
- bool: True if in a Jupyter notebook, False otherwise.
- """
- try:
- from IPython import get_ipython
-
- shell = get_ipython().__class__.__name__
- if shell == 'ZMQInteractiveShell':
- return True # Jupyter notebook or JupyterLab
- elif shell == 'TerminalInteractiveShell':
- return False # Terminal IPython
- else:
- return False
- except (NameError, ImportError):
- return False # Standard Python (no IPython)
-
@staticmethod
def _validate_coordinates(data: sc.DataArray) -> None:
"""Validate that required coordinates are present in the data.
@@ -269,9 +255,9 @@ def _validate_coordinates(data: sc.DataArray) -> None:
ValueError: If required coordinates are missing.
"""
if not isinstance(data, sc.DataArray):
- raise TypeError('Data must be a sc.DataArray.')
+ raise TypeError("Data must be a sc.DataArray.")
- required_coords = ['Q', 'energy']
+ required_coords = ["Q", "energy"]
for coord in required_coords:
if coord not in data.coords:
raise ValueError(f"Data is missing required coordinate: '{coord}'")
@@ -297,11 +283,11 @@ def _convert_to_bin_centers(self, data: sc.DataArray) -> sc.DataArray:
###########
def __repr__(self) -> str:
- return f'Experiment `{self.unique_name}` with data: {self._data}'
+ return f"Experiment `{self.unique_name}` with data: {self._data}"
- def __copy__(self) -> 'Experiment':
+ def __copy__(self) -> "Experiment":
"""Return a copy of the object."""
- temp = self.to_dict(skip=['unique_name'])
+ temp = self.to_dict(skip=["unique_name"])
new_obj = self.__class__.from_dict(temp)
new_obj.data = self.data.copy() if self.data is not None else None
return new_obj
diff --git a/src/easydynamics/sample_model/__init__.py b/src/easydynamics/sample_model/__init__.py
index 5929fc50..443c1982 100644
--- a/src/easydynamics/sample_model/__init__.py
+++ b/src/easydynamics/sample_model/__init__.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
# SPDX-License-Identifier: BSD-3-Clause
+from .background_model import BackgroundModel
from .component_collection import ComponentCollection
from .components import DampedHarmonicOscillator
from .components import DeltaFunction
@@ -8,15 +9,24 @@
from .components import Lorentzian
from .components import Polynomial
from .components import Voigt
-from .diffusion_model.brownian_translational_diffusion import BrownianTranslationalDiffusion
+from .diffusion_model.brownian_translational_diffusion import (
+ BrownianTranslationalDiffusion,
+)
+from .instrument_model import InstrumentModel
+from .resolution_model import ResolutionModel
+from .sample_model import SampleModel
__all__ = [
- 'ComponentCollection',
- 'Gaussian',
- 'Lorentzian',
- 'Voigt',
- 'DeltaFunction',
- 'DampedHarmonicOscillator',
- 'Polynomial',
- 'BrownianTranslationalDiffusion',
+ "ComponentCollection",
+ "Gaussian",
+ "Lorentzian",
+ "Voigt",
+ "DeltaFunction",
+ "DampedHarmonicOscillator",
+ "Polynomial",
+ "BrownianTranslationalDiffusion",
+ "SampleModel",
+ "ResolutionModel",
+ "BackgroundModel",
+ "InstrumentModel",
]
diff --git a/src/easydynamics/sample_model/component_collection.py b/src/easydynamics/sample_model/component_collection.py
index 5978539d..a0b1e668 100644
--- a/src/easydynamics/sample_model/component_collection.py
+++ b/src/easydynamics/sample_model/component_collection.py
@@ -31,8 +31,8 @@ class ComponentCollection(ModelBase):
def __init__(
self,
- unit: str | sc.Unit = 'meV',
- display_name: str = 'MyComponentCollection',
+ unit: str | sc.Unit = "meV",
+ display_name: str = "MyComponentCollection",
unique_name: str | None = None,
components: List[ModelComponent] | None = None,
):
@@ -54,7 +54,7 @@ def __init__(
if unit is not None and not isinstance(unit, (str, sc.Unit)):
raise TypeError(
- f'unit must be None, a string, or a scipp Unit, got {type(unit).__name__}'
+ f"unit must be None, a string, or a scipp Unit, got {type(unit).__name__}"
)
self._unit = unit
self._components = []
@@ -62,31 +62,37 @@ def __init__(
# Add initial components if provided. Used for serialization.
if components is not None:
if not isinstance(components, list):
- raise TypeError('components must be a list of ModelComponent instances.')
+ raise TypeError(
+ "components must be a list of ModelComponent instances."
+ )
for comp in components:
self.append_component(comp)
- def append_component(self, component: ModelComponent | 'ComponentCollection') -> None:
+ def append_component(
+ self, component: ModelComponent | "ComponentCollection"
+ ) -> None:
match component:
case ModelComponent():
components = (component,)
case ComponentCollection(components=components):
pass
case _:
- raise TypeError('Component must be a ModelComponent or ComponentCollection.')
+ raise TypeError(
+ "Component must be a ModelComponent or ComponentCollection."
+ )
for comp in components:
if comp in self._components:
raise ValueError(
f"Component '{comp.unique_name}' is already in the collection. "
- f'Existing components: {self.list_component_names()}'
+ f"Existing components: {self.list_component_names()}"
)
self._components.append(comp)
def remove_component(self, unique_name: str) -> None:
if not isinstance(unique_name, str):
- raise TypeError('Component name must be a string.')
+ raise TypeError("Component name must be a string.")
for comp in self._components:
if comp.unique_name == unique_name:
@@ -95,8 +101,8 @@ def remove_component(self, unique_name: str) -> None:
raise KeyError(
f"No component named '{unique_name}' exists. "
- f'Did you accidentally use the display_name? '
- f'Here is a list of the components in the collection: {self.list_component_names()}'
+ f"Did you accidentally use the display_name? "
+ f"Here is a list of the components in the collection: {self.list_component_names()}"
)
@property
@@ -106,16 +112,27 @@ def components(self) -> list[ModelComponent]:
@components.setter
def components(self, components: List[ModelComponent]) -> None:
if not isinstance(components, list):
- raise TypeError('components must be a list of ModelComponent instances.')
+ raise TypeError("components must be a list of ModelComponent instances.")
for comp in components:
if not isinstance(comp, ModelComponent):
raise TypeError(
- 'All items in components must be instances of ModelComponent. '
- f'Got {type(comp).__name__} instead.'
+ "All items in components must be instances of ModelComponent. "
+ f"Got {type(comp).__name__} instead."
)
self._components = components
+ @property
+ def is_empty(self) -> bool:
+ return not self._components
+
+ @is_empty.setter
+ def is_empty(self, value: bool) -> None:
+ raise AttributeError(
+ "is_empty is a read-only property that indicates "
+ "whether the collection has components."
+ )
+
def list_component_names(self) -> List[str]:
"""List the names of all components in the model.
@@ -135,27 +152,27 @@ def normalize_area(self) -> None:
# Useful for convolutions.
"""Normalize the areas of all components so they sum to 1."""
if not self.components:
- raise ValueError('No components in the model to normalize.')
+ raise ValueError("No components in the model to normalize.")
area_params = []
- total_area = Parameter(name='total_area', value=0.0, unit=self._unit)
+ total_area = Parameter(name="total_area", value=0.0, unit=self._unit)
for component in self.components:
- if hasattr(component, 'area'):
+ if hasattr(component, "area"):
area_params.append(component.area)
total_area += component.area
else:
warnings.warn(
f"Component '{component.unique_name}' does not have an 'area' attribute "
- f'and will be skipped in normalization.',
+ f"and will be skipped in normalization.",
UserWarning,
)
if total_area.value == 0:
- raise ValueError('Total area is zero; cannot normalize.')
+ raise ValueError("Total area is zero; cannot normalize.")
if not np.isfinite(total_area.value):
- raise ValueError('Total area is not finite; cannot normalize.')
+ raise ValueError("Total area is not finite; cannot normalize.")
for param in area_params:
param.value /= total_area.value
@@ -167,7 +184,11 @@ def get_all_variables(self) -> list[DescriptorBase]:
List[Parameter]: List of parameters in the component.
"""
- return [var for component in self.components for var in component.get_all_variables()]
+ return [
+ var
+ for component in self.components
+ for var in component.get_all_variables()
+ ]
@property
def unit(self) -> str | sc.Unit:
@@ -183,8 +204,8 @@ def unit(self) -> str | sc.Unit:
def unit(self, unit_str: str) -> None:
raise AttributeError(
(
- f'Unit is read-only. Use convert_unit to change the unit between allowed types '
- f'or create a new {self.__class__.__name__} with the desired unit.'
+ f"Unit is read-only. Use convert_unit to change the unit between allowed types "
+ f"or create a new {self.__class__.__name__} with the desired unit."
)
) # noqa: E501
@@ -208,7 +229,9 @@ def convert_unit(self, unit: str | sc.Unit) -> None:
pass # Best effort rollback
raise e
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
+ def evaluate(
+ self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray
+ ) -> np.ndarray:
"""Evaluate the sum of all components.
Parameters
@@ -223,7 +246,7 @@ def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray)
"""
if not self.components:
- raise ValueError('No components in the model to evaluate.')
+ return np.zeros_like(x)
return sum(component.evaluate(x) for component in self.components)
def evaluate_component(
@@ -246,11 +269,13 @@ def evaluate_component(
Evaluated values for the specified component.
"""
if not self.components:
- raise ValueError('No components in the model to evaluate.')
+ raise ValueError("No components in the model to evaluate.")
if not isinstance(unique_name, str):
raise TypeError(
- (f'Component unique name must be a string, got {type(unique_name)} instead.')
+ (
+ f"Component unique name must be a string, got {type(unique_name)} instead."
+ )
)
matches = [comp for comp in self.components if comp.unique_name == unique_name]
@@ -303,6 +328,8 @@ def __repr__(self) -> str:
-------
str
"""
- comp_names = ', '.join(c.unique_name for c in self.components) or 'No components'
+ comp_names = (
+ ", ".join(c.unique_name for c in self.components) or "No components"
+ )
return f""
diff --git a/src/easydynamics/sample_model/diffusion_model/__init__.py b/src/easydynamics/sample_model/diffusion_model/__init__.py
index 6fd920dc..dc0a469c 100644
--- a/src/easydynamics/sample_model/diffusion_model/__init__.py
+++ b/src/easydynamics/sample_model/diffusion_model/__init__.py
@@ -2,9 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause
from .brownian_translational_diffusion import BrownianTranslationalDiffusion
-from .diffusion_model_base import DiffusionModelBase
+from .jump_translational_diffusion import JumpTranslationalDiffusion
__all__ = [
- 'DiffusionModelBase',
'BrownianTranslationalDiffusion',
+ 'JumpTranslationalDiffusion',
]
diff --git a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py
index 749f8de4..d277d227 100644
--- a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py
+++ b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py
@@ -3,24 +3,20 @@
from typing import Dict
from typing import List
-from typing import Union
import numpy as np
import scipp as sc
from easyscience.variable import DescriptorNumber
from easyscience.variable import Parameter
-from numpy.typing import ArrayLike
from scipp.constants import hbar as scipp_hbar
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components import Lorentzian
from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase
+from easydynamics.utils.utils import Numeric
+from easydynamics.utils.utils import Q_type
from easydynamics.utils.utils import _validate_and_convert_Q
-Numeric = Union[float, int]
-
-Q_type = np.ndarray | Numeric | list | ArrayLike
-
class BrownianTranslationalDiffusion(DiffusionModelBase):
"""Model of Brownian translational diffusion, consisting of a
@@ -46,7 +42,6 @@ def __init__(
unit: str | sc.Unit = 'meV',
scale: Numeric = 1.0,
diffusion_coefficient: Numeric = 1.0,
- diffusion_unit: str = 'm**2/s',
):
"""Initialize a new BrownianTranslationalDiffusion model.
@@ -62,65 +57,35 @@ def __init__(
Defaults to "meV".
scale : float or Parameter, optional
Scale factor for the diffusion model.
- diffusion_coefficient : float or Parameter, optional
- Diffusion coefficient D. If a number is provided,
- it is assumed to be in the unit given by diffusion_unit.
+ diffusion_coefficient : Number, optional
+ Diffusion coefficient D in m^2/s.
Defaults to 1.0.
- diffusion_unit : str, optional
- Unit for the diffusion coefficient D. Default is m**2/s.
- Options are 'meV*Å**2' or 'm**2/s'
"""
- if not isinstance(scale, (Parameter, Numeric)):
+ if not isinstance(scale, Numeric):
raise TypeError('scale must be a number.')
- if not isinstance(diffusion_coefficient, (Parameter, Numeric)):
+ if not isinstance(diffusion_coefficient, Numeric):
raise TypeError('diffusion_coefficient must be a number.')
- if not isinstance(diffusion_unit, str):
- raise TypeError("diffusion_unit must be 'meV*Å**2' or 'm**2/s'.")
-
- if diffusion_unit == 'meV*Å**2' or diffusion_unit == 'meV*angstrom**2':
- # In this case, hbar is absorbed in the unit of D
- self._hbar = DescriptorNumber('hbar', 1.0)
- elif diffusion_unit == 'm**2/s' or diffusion_unit == 'm^2/s':
- self._hbar = DescriptorNumber.from_scipp('hbar', scipp_hbar)
- else:
- raise ValueError("diffusion_unit must be 'meV*Å**2' or 'm**2/s'.")
-
- scale = Parameter(name='scale', value=float(scale), fixed=False, min=0.0)
-
diffusion_coefficient = Parameter(
name='diffusion_coefficient',
value=float(diffusion_coefficient),
fixed=False,
- unit=diffusion_unit,
+ unit='m**2/s',
)
super().__init__(
display_name=display_name,
unique_name=unique_name,
unit=unit,
+ scale=scale,
)
+ self._hbar = DescriptorNumber.from_scipp('hbar', scipp_hbar)
self._angstrom = DescriptorNumber('angstrom', 1e-10, unit='m')
- self._scale = scale
self._diffusion_coefficient = diffusion_coefficient
- @property
- def scale(self) -> Parameter:
- """Get the scale parameter of the diffusion model.
-
- Returns
- -------
- Parameter
- Scale parameter.
- """
- return self._scale
-
- @scale.setter
- def scale(self, scale: Numeric) -> None:
- """Set the scale parameter of the diffusion model."""
- if not isinstance(scale, (Numeric)):
- raise TypeError('scale must be a number.')
- self._scale.value = scale
+ # ------------------------------------------------------------------
+ # Properties
+ # ------------------------------------------------------------------
@property
def diffusion_coefficient(self) -> Parameter:
@@ -136,10 +101,14 @@ def diffusion_coefficient(self) -> Parameter:
@diffusion_coefficient.setter
def diffusion_coefficient(self, diffusion_coefficient: Numeric) -> None:
"""Set the diffusion coefficient parameter D."""
- if not isinstance(diffusion_coefficient, (Numeric)):
+ if not isinstance(diffusion_coefficient, Numeric):
raise TypeError('diffusion_coefficient must be a number.')
self._diffusion_coefficient.value = diffusion_coefficient
+ # ------------------------------------------------------------------
+ # Other methods
+ # ------------------------------------------------------------------
+
def calculate_width(self, Q: Q_type) -> np.ndarray:
"""Calculate the half-width at half-maximum (HWHM) for the
diffusion model.
@@ -265,6 +234,10 @@ def create_component_collections(
return component_collection_list
+ # ------------------------------------------------------------------
+ # Private methods
+ # ------------------------------------------------------------------
+
def _write_width_dependency_expression(self, Q: float) -> str:
"""Write the dependency expression for the width as a function
of Q to make dependent Parameters.
@@ -316,6 +289,10 @@ def _write_area_dependency_map_expression(self) -> Dict[str, DescriptorNumber]:
'scale': self.scale,
}
+ # ------------------------------------------------------------------
+ # dunder methods
+ # ------------------------------------------------------------------
+
def __repr__(self):
"""String representation of the BrownianTranslationalDiffusion
model.
diff --git a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py
index 18b5bce8..a6711334 100644
--- a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py
+++ b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py
@@ -1,17 +1,14 @@
# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
# SPDX-License-Identifier: BSD-3-Clause
-import numpy as np
import scipp as sc
from easyscience.base_classes.model_base import ModelBase
from easyscience.variable import DescriptorNumber
-from numpy.typing import ArrayLike
+from easyscience.variable import Parameter
from scipp import UnitError
from easydynamics.utils.utils import Numeric
-Q_type = np.ndarray | Numeric | list | ArrayLike
-
class DiffusionModelBase(ModelBase):
"""Base class for constructing diffusion models."""
@@ -20,6 +17,7 @@ def __init__(
self,
display_name='MyDiffusionModel',
unique_name: str | None = None,
+ scale: Numeric = 1.0,
unit: str | sc.Unit = 'meV',
):
"""Initialize a new DiffusionModel.
@@ -31,6 +29,10 @@ def __init__(
unit : str or sc.Unit, optional
Unit of the diffusion model. Defaults to "meV".
"""
+ if not isinstance(scale, Numeric):
+ raise TypeError('scale must be a number.')
+
+ scale = Parameter(name='scale', value=float(scale), fixed=False, min=0.0)
try:
test = DescriptorNumber(name='test', value=1, unit=unit)
@@ -42,6 +44,11 @@ def __init__(
super().__init__(display_name=display_name, unique_name=unique_name)
self._unit = unit
+ self._scale = scale
+
+ # ------------------------------------------------------------------
+ # Properties
+ # ------------------------------------------------------------------
@property
def unit(self) -> str:
@@ -62,6 +69,28 @@ def unit(self, unit_str: str) -> None:
)
) # noqa: E501
+ @property
+ def scale(self) -> Parameter:
+ """Get the scale parameter of the diffusion model.
+
+ Returns
+ -------
+ Parameter
+ Scale parameter.
+ """
+ return self._scale
+
+ @scale.setter
+ def scale(self, scale: Numeric) -> None:
+ """Set the scale parameter of the diffusion model."""
+ if not isinstance(scale, Numeric):
+ raise TypeError('scale must be a number.')
+ self._scale.value = scale
+
+ # ------------------------------------------------------------------
+ # dunder methods
+ # ------------------------------------------------------------------
+
def __repr__(self):
"""String representation of the Diffusion model."""
return f'{self.__class__.__name__}(display_name={self.display_name}, unit={self.unit})'
diff --git a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py
new file mode 100644
index 00000000..286ea486
--- /dev/null
+++ b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py
@@ -0,0 +1,342 @@
+from typing import Dict
+from typing import List
+
+import numpy as np
+import scipp as sc
+from easyscience.variable import DescriptorNumber
+from easyscience.variable import Parameter
+from scipp.constants import hbar as scipp_hbar
+
+from easydynamics.sample_model.component_collection import ComponentCollection
+from easydynamics.sample_model.components import Lorentzian
+from easydynamics.sample_model.diffusion_model.diffusion_model_base import (
+ DiffusionModelBase,
+)
+from easydynamics.utils.utils import Numeric
+from easydynamics.utils.utils import Q_type
+from easydynamics.utils.utils import _validate_and_convert_Q
+
+
+class JumpTranslationalDiffusion(DiffusionModelBase):
+ """Model of Jump translational diffusion, consisting of a Lorentzian
+ function for each Q-value, where the width is given by :math:`D
+ Q^2/(1+D t Q^2)`. Q is assumed to have units of 1/angstrom. Creates
+ ComponentCollections with Lorentzian components for given Q-values.
+
+ Example usage: Q=np.linspace(0.5,2,7) energy=np.linspace(-2, 2, 501)
+ scale=1.0 diffusion_coefficient = 2.4e-9 # m^2/s
+ diffusion_model=JumpTranslationalDiffusion(display_name="DiffusionModel",
+ scale=scale, diffusion_coefficient= diffusion_coefficient)
+ component_collections=diffusion_model.create_component_collections(Q)
+ See also the examples.
+ """
+
+ def __init__(
+ self,
+ display_name: str | None = "JumpTranslationalDiffusion",
+ unique_name: str | None = None,
+ unit: str | sc.Unit = "meV",
+ scale: Numeric = 1.0,
+ diffusion_coefficient: Numeric = 1.0,
+ relaxation_time: Numeric = 1.0,
+ ):
+ """Initialize a new JumpTranslationalDiffusion model.
+
+ Parameters
+ ----------
+ display_name : str
+ Display name of the diffusion model.
+ unique_name : str or None
+ Unique name of the diffusion model. If None, a unique name
+ is automatically generated.
+ unit : str or sc.Unit, optional
+ Energy unit for the underlying Lorentzian components.
+ Defaults to "meV".
+ scale : float, optional
+ Scale factor for the diffusion model.
+ diffusion_coefficient : float, optional
+ Diffusion coefficient D in m^2/s. Defaults to 1.0.
+ relaxation_time : float, optional
+ Relaxation time t in ps. Defaults to 1.0.
+ """
+ super().__init__(
+ display_name=display_name,
+ unique_name=unique_name,
+ unit=unit,
+ scale=scale,
+ )
+
+ if not isinstance(diffusion_coefficient, Numeric):
+ raise TypeError("diffusion_coefficient must be a number.")
+
+ if not isinstance(relaxation_time, Numeric):
+ raise TypeError("relaxation_time must be a number.")
+
+ diffusion_coefficient = Parameter(
+ name="diffusion_coefficient",
+ value=float(diffusion_coefficient),
+ fixed=False,
+ unit="m**2/s",
+ )
+
+ relaxation_time = Parameter(
+ name="relaxation_time",
+ value=float(relaxation_time),
+ fixed=False,
+ unit="ps",
+ )
+
+ self._hbar = DescriptorNumber.from_scipp("hbar", scipp_hbar)
+ self._angstrom = DescriptorNumber("angstrom", 1e-10, unit="m")
+ self._diffusion_coefficient = diffusion_coefficient
+ self._relaxation_time = relaxation_time
+
+ ################################
+ # Properties
+ ################################
+
+ @property
+ def diffusion_coefficient(self) -> Parameter:
+ """Get the diffusion coefficient parameter D.
+
+ Returns
+ -------
+ Parameter
+ Diffusion coefficient D.
+ """
+ return self._diffusion_coefficient
+
+ @diffusion_coefficient.setter
+ def diffusion_coefficient(self, diffusion_coefficient: Numeric) -> None:
+ """Set the diffusion coefficient parameter D."""
+ if not isinstance(diffusion_coefficient, Numeric):
+ raise TypeError("diffusion_coefficient must be a number.")
+ self._diffusion_coefficient.value = diffusion_coefficient
+
+ @property
+ def relaxation_time(self) -> Parameter:
+ """Get the relaxation time parameter t.
+
+ Returns
+ -------
+ Parameter
+ Relaxation time t.
+ """
+ return self._relaxation_time
+
+ @relaxation_time.setter
+ def relaxation_time(self, relaxation_time: Numeric) -> None:
+ """Set the relaxation time parameter t."""
+ if not isinstance(relaxation_time, Numeric):
+ raise TypeError("relaxation_time must be a number.")
+ self._relaxation_time.value = relaxation_time
+
+ ################################
+ # Other methods
+ ################################
+
+ def calculate_width(self, Q: Q_type) -> np.ndarray:
+ """Calculate the half-width at half-maximum (HWHM) for the
+ diffusion model. Equation: :math:`\\Gamma(Q) = \\hbar D Q^2/(1+D
+ t Q^2)`
+
+ Parameters
+ ----------
+ Q : np.ndarray | Numeric | list | ArrayLike
+ Scattering vector in 1/angstrom
+
+ Returns
+ -------
+ np.ndarray
+ HWHM values in the unit of the model (e.g., meV).
+ """
+
+ Q = _validate_and_convert_Q(Q)
+
+ unit_conversion_factor_numerator = (
+ self._hbar * self.diffusion_coefficient / (self._angstrom**2)
+ )
+ unit_conversion_factor_numerator.convert_unit(self.unit)
+
+ numerator = unit_conversion_factor_numerator.value * Q**2
+
+ unit_conversion_factor_denominator = (
+ self.diffusion_coefficient / self._angstrom**2 * self.relaxation_time
+ )
+ unit_conversion_factor_denominator.convert_unit("dimensionless")
+
+ denominator = 1 + unit_conversion_factor_denominator.value * Q**2
+
+ width = numerator / denominator
+ return width
+
+ def calculate_EISF(self, Q: Q_type) -> np.ndarray:
+ """Calculate the Elastic Incoherent Structure Factor (EISF).
+
+ Parameters
+ ----------
+ Q : np.ndarray | Numeric | list | ArrayLike
+ Scattering vector in 1/angstrom
+
+ Returns
+ -------
+ np.ndarray
+ EISF values (dimensionless).
+ """
+ Q = _validate_and_convert_Q(Q)
+ EISF = np.zeros_like(Q)
+ return EISF
+
+ def calculate_QISF(self, Q: Q_type) -> np.ndarray:
+ """Calculate the Quasi-Elastic Incoherent Structure Factor
+ (QISF).
+
+ Parameters
+ ----------
+ Q : np.ndarray | Numeric | list | ArrayLike
+ Scattering vector in 1/angstrom
+
+ Returns
+ -------
+ np.ndarray
+ QISF values (dimensionless).
+ """
+
+ Q = _validate_and_convert_Q(Q)
+ QISF = np.ones_like(Q)
+ return QISF
+
+ def create_component_collections(
+ self,
+ Q: Q_type,
+ component_display_name: str = "Jump translational diffusion",
+ ) -> List[ComponentCollection]:
+ """Create ComponentCollection components for the diffusion model
+ at given Q values.
+
+        Parameters
+ ----------
+ Q : Number, list, or np.ndarray
+ Scattering vector values.
+ component_display_name : str
+ Name of the Jump Diffusion Lorentzian component.
+ Returns
+ -------
+ List[ComponentCollection]
+ List of ComponentCollections with Jump Diffusion
+ Lorentzian components.
+ """
+ Q = _validate_and_convert_Q(Q)
+
+ if not isinstance(component_display_name, str):
+            raise TypeError("component_display_name must be a string.")
+
+ component_collection_list = [None] * len(Q)
+ # In more complex models, this is used to scale the area of the
+ # Lorentzians and the delta function.
+ QISF = self.calculate_QISF(Q)
+
+ # Create a Lorentzian component for each Q-value, with width
+ # D*Q^2 and area equal to scale. No delta function, as the EISF
+ # is 0.
+ for i, Q_value in enumerate(Q):
+ component_collection_list[i] = ComponentCollection(
+ display_name=f"{self.display_name}_Q{Q_value:.2f}", unit=self.unit
+ )
+
+ lorentzian_component = Lorentzian(
+ display_name=component_display_name,
+ unit=self.unit,
+ )
+
+ # Make the width dependent on Q
+ dependency_expression = self._write_width_dependency_expression(Q[i])
+ dependency_map = self._write_width_dependency_map_expression()
+
+ lorentzian_component.width.make_dependent_on(
+ dependency_expression=dependency_expression,
+ dependency_map=dependency_map,
+ )
+
+ # Make the area dependent on Q
+ area_dependency_map = self._write_area_dependency_map_expression()
+ lorentzian_component.area.make_dependent_on(
+ dependency_expression=self._write_area_dependency_expression(QISF[i]),
+ dependency_map=area_dependency_map,
+ )
+
+ # Resolving the dependency can do weird things to the units,
+ # so we make sure it's correct.
+ lorentzian_component.width.convert_unit(self.unit)
+ component_collection_list[i].append_component(lorentzian_component)
+
+ return component_collection_list
+
+ ################################
+ # Private methods
+ ################################
+
+ def _write_width_dependency_expression(self, Q: float) -> str:
+ """Write the dependency expression for the width as a function
+ of Q to make dependent Parameters.
+
+ Parameters
+ ----------
+ Q : float
+ Scattering vector in 1/angstrom
+ Returns
+ -------
+ str
+ Dependency expression for the width.
+ """
+ if not isinstance(Q, (float)):
+ raise TypeError("Q must be a float.")
+
+ # Q is given as a float, so we need to add the units
+ return f"hbar * D* {Q} **2/(angstrom**2)/(1 + (D * t* {Q} **2/(angstrom**2)))"
+
+ def _write_width_dependency_map_expression(self) -> Dict[str, DescriptorNumber]:
+ """Write the dependency map expression to make dependent
+ Parameters.
+ """
+ return {
+ "D": self._diffusion_coefficient,
+ "t": self._relaxation_time,
+ "hbar": self._hbar,
+ "angstrom": self._angstrom,
+ }
+
+ def _write_area_dependency_expression(self, QISF: float) -> str:
+ """Write the dependency expression for the area to make
+ dependent Parameters.
+
+ Returns
+ -------
+ str
+ Dependency expression for the area.
+ """
+ if not isinstance(QISF, (float)):
+ raise TypeError("QISF must be a float.")
+
+ return f"{QISF} * scale"
+
+ def _write_area_dependency_map_expression(self) -> Dict[str, DescriptorNumber]:
+ """Write the dependency map expression to make dependent
+ Parameters.
+ """
+ return {
+ "scale": self._scale,
+ }
+
+ ################################
+ # dunder methods
+ ################################
+
+ def __repr__(self):
+ """String representation of the JumpTranslationalDiffusion
+ model.
+ """
+ return (
+ f"JumpTranslationalDiffusion(display_name={self.display_name}, "
+ f"diffusion_coefficient={self._diffusion_coefficient}, scale={self._scale})"
+ )
diff --git a/src/easydynamics/sample_model/instrument_model.py b/src/easydynamics/sample_model/instrument_model.py
index 9f3eb1d2..4c767331 100644
--- a/src/easydynamics/sample_model/instrument_model.py
+++ b/src/easydynamics/sample_model/instrument_model.py
@@ -1,5 +1,337 @@
# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
# SPDX-License-Identifier: BSD-3-Clause
-# instrument_model will contain resolution_model and background_model
-# as well as offset
+from copy import copy
+
+import numpy as np
+import scipp as sc
+from easyscience.base_classes.new_base import NewBase
+from easyscience.variable import Parameter
+
+from easydynamics.sample_model.background_model import BackgroundModel
+from easydynamics.sample_model.resolution_model import ResolutionModel
+from easydynamics.utils.utils import Numeric
+from easydynamics.utils.utils import Q_type
+from easydynamics.utils.utils import _validate_and_convert_Q
+from easydynamics.utils.utils import _validate_unit
+
+
+class InstrumentModel(NewBase):
+ """InstrumentModel represents a model of the instrument in an
+ experiment at various Q. It can contain a model of the resolution
+ function for convolutions, of the background and an offset in the
+ energy axis.
+
+ Parameters
+ ----------
+ display_name : str, optional
+ The display name of the InstrumentModel. Default is
+ "MyInstrumentModel".
+ unique_name : str or None, optional
+ The unique name of the InstrumentModel. Default is None.
+ Q : np.ndarray, list, scipp Variable or None, optional
+ The Q values where the instrument is modelled.
+ resolution_model : ResolutionModel or None, optional
+ The resolution model of the instrument. If None, an empty
+ resolution model is created and no resolution convolution is
+ carried out. Default is None.
+ background_model : BackgroundModel or None, optional
+ The background model of the instrument. If None, an empty
+ background model is created, and the background evaluates to 0.
+ Default is None.
+ energy_offset : float, int or None, optional
+ Template energy offset of the instrument. Will be copied to each
+ Q value. If None, the energy offset will be 0. Default is None.
+ unit : str or sc.Unit, optional
+ The unit of the energy axis. Default is 'meV'.
+ """
+
+ def __init__(
+ self,
+ display_name: str = "MyInstrumentModel",
+ unique_name: str | None = None,
+ Q: Q_type | None = None,
+ resolution_model: ResolutionModel | None = None,
+ background_model: BackgroundModel | None = None,
+ energy_offset: Numeric | None = None,
+ unit: str | sc.Unit = "meV",
+ ):
+ super().__init__(
+ display_name=display_name,
+ unique_name=unique_name,
+ )
+
+ self._unit = _validate_unit(unit)
+
+ if resolution_model is None:
+ self._resolution_model = ResolutionModel()
+ else:
+ if not isinstance(resolution_model, ResolutionModel):
+ raise TypeError(
+ f"resolution_model must be a ResolutionModel or None, "
+ f"got {type(resolution_model).__name__}"
+ )
+ self._resolution_model = resolution_model
+
+ if background_model is None:
+ self._background_model = BackgroundModel()
+ else:
+ if not isinstance(background_model, BackgroundModel):
+ raise TypeError(
+ f"background_model must be a BackgroundModel or None, "
+ f"got {type(background_model).__name__}"
+ )
+ self._background_model = background_model
+
+ if energy_offset is None:
+ energy_offset = 0.0
+
+ if not isinstance(energy_offset, Numeric):
+ raise TypeError("energy_offset must be a number or None")
+
+ self._energy_offset = Parameter(
+ name="energy_offset",
+ value=float(energy_offset),
+ unit=self.unit,
+ fixed=False,
+ )
+ self._Q = _validate_and_convert_Q(Q)
+ self._on_Q_change()
+
+ # -------------------------------------------------------------
+ # Properties
+ # -------------------------------------------------------------
+
+ @property
+ def resolution_model(self) -> ResolutionModel:
+ """Get the resolution model of the instrument."""
+ return self._resolution_model
+
+ @resolution_model.setter
+ def resolution_model(self, value: ResolutionModel):
+ """Set the resolution model of the instrument."""
+ if not isinstance(value, ResolutionModel):
+ raise TypeError(
+ f"resolution_model must be a ResolutionModel, got {type(value).__name__}"
+ )
+ self._resolution_model = value
+ self._on_resolution_model_change()
+
+ @property
+ def background_model(self) -> BackgroundModel:
+ """The background model of the instrument."""
+ return self._background_model
+
+ @background_model.setter
+ def background_model(self, value: BackgroundModel):
+ """Set the background model of the instrument."""
+ if not isinstance(value, BackgroundModel):
+ raise TypeError(
+ f"background_model must be a BackgroundModel, got {type(value).__name__}"
+ )
+ self._background_model = value
+ self._on_background_model_change()
+
+ @property
+ def Q(self) -> np.ndarray | None:
+ """Get the Q values of the InstrumentModel."""
+ return self._Q
+
+ @Q.setter
+ def Q(self, value: Q_type | None) -> None:
+ """Set the Q values of the InstrumentModel."""
+ self._Q = _validate_and_convert_Q(value)
+ self._on_Q_change()
+
+ @property
+ def unit(self) -> sc.Unit:
+ """Get the unit of the InstrumentModel.
+
+ Returns
+ -------
+ str or sc.Unit or None
+ """
+ return self._unit
+
+ @unit.setter
+ def unit(self, unit_str: str) -> None:
+ raise AttributeError(
+ (
+ f"Unit is read-only. Use convert_unit to change the unit between allowed types "
+ f"or create a new {self.__class__.__name__} with the desired unit."
+ )
+ ) # noqa: E501
+
+ @property
+ def energy_offset(self) -> Parameter:
+ """The energy offset template parameter of the instrument
+ model.
+ """
+ return self._energy_offset
+
+ @energy_offset.setter
+ def energy_offset(self, value: Numeric):
+        """Set the energy offset parameter of the instrument model.
+
+ Parameters
+ ----------
+ value : float or int
+ The new value for the energy offset parameter. Will be
+ copied to all Q values.
+ Raises
+ ------
+ TypeError
+ If value is not a number.
+ """
+ if not isinstance(value, Numeric):
+ raise TypeError(
+ f"energy_offset must be a number, got {type(value).__name__}"
+ )
+ self._energy_offset.value = value
+
+ self._on_energy_offset_change()
+
+ # --------------------------------------------------------------
+ # Other methods
+ # --------------------------------------------------------------
+
+ def convert_unit(self, unit_str: str | sc.Unit) -> None:
+ """Convert the unit of the InstrumentModel.
+
+ Parameters
+ ----------
+ unit_str : str or sc.Unit
+ The unit to convert to.
+
+ Raises
+ ------
+ TypeError
+ If unit_str is not a string or scipp Unit.
+ """
+ unit = _validate_unit(unit_str)
+ if unit is None:
+ raise ValueError("unit_str must be a valid unit string or scipp Unit")
+
+ self._background_model.convert_unit(unit)
+ self._resolution_model.convert_unit(unit)
+ self._energy_offset.convert_unit(unit)
+ for offset in self._energy_offsets:
+ offset.convert_unit(unit)
+
+ self._unit = unit
+
+ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
+ """Get all variables in the InstrumentModel.
+
+ Parameters
+ ----------
+ Q_index : int | None
+ The index of the Q value to get variables for. If None, get
+ variables for all Q values.
+ Returns
+ -------
+ list of Parameter
+ All variables in the InstrumentModel.
+ """
+ if self._Q is None:
+ return []
+
+ if Q_index is None:
+ variables = [self._energy_offsets[i] for i in range(len(self._Q))]
+ else:
+ if not isinstance(Q_index, int):
+ raise TypeError(
+ f"Q_index must be an int or None, got {type(Q_index).__name__}"
+ )
+ if Q_index < 0 or Q_index >= len(self._Q):
+ raise IndexError(
+ f"Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}"
+ )
+ variables = [self._energy_offsets[Q_index]]
+
+ variables.extend(self._background_model.get_all_variables(Q_index=Q_index))
+ variables.extend(self._resolution_model.get_all_variables(Q_index=Q_index))
+
+ return variables
+
+ def fix_resolution_parameters(self) -> None:
+ """Fix all parameters in the resolution model."""
+ self.resolution_model.fix_all_parameters()
+
+ def free_resolution_parameters(self) -> None:
+ """Free all parameters in the resolution model."""
+ self.resolution_model.free_all_parameters()
+
+ def get_energy_offset_at_Q(self, Q_index: int) -> Parameter:
+ """Get the energy offset Parameter at a specific Q index.
+
+ Parameters
+ ----------
+ Q_index : int
+ The index of the Q value to get the energy offset for.
+
+ Returns
+ -------
+ Parameter
+ The energy offset Parameter at the specified Q index.
+
+ Raises
+ ------
+ IndexError
+ If Q_index is out of bounds.
+ """
+ if self._Q is None:
+ raise ValueError("No Q values are set in the InstrumentModel.")
+
+ if Q_index < 0 or Q_index >= len(self._Q):
+ raise IndexError(
+ f"Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}"
+ )
+
+ return self._energy_offsets[Q_index]
+
+ # --------------------------------------------------------------
+ # Private methods
+ # --------------------------------------------------------------
+
+ def _generate_energy_offsets(self) -> None:
+ """Generate energy offset Parameters for each Q value."""
+ if self._Q is None:
+ self._energy_offsets = []
+ return
+
+ self._energy_offsets = [copy(self._energy_offset) for _ in self._Q]
+
+ def _on_Q_change(self) -> None:
+ """Handle changes to the Q values."""
+ self._generate_energy_offsets()
+ self._resolution_model.Q = self._Q
+ self._background_model.Q = self._Q
+
+ def _on_energy_offset_change(self) -> None:
+ """Handle changes to the energy offset."""
+ for offset in self._energy_offsets:
+ offset.value = self._energy_offset.value
+
+ def _on_resolution_model_change(self) -> None:
+ """Handle changes to the resolution model."""
+ self._resolution_model.Q = self._Q
+
+ def _on_background_model_change(self) -> None:
+ """Handle changes to the background model."""
+ self._background_model.Q = self._Q
+
+ # -------------------------------------------------------------
+ # Dunder methods
+ # -------------------------------------------------------------
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}("
+ f"unique_name={self.unique_name!r}, "
+ f"unit={self.unit}, "
+ f"Q_len={None if self._Q is None else len(self._Q)}, "
+ f"resolution_model={self._resolution_model!r}, "
+ f"background_model={self._background_model!r}"
+ f")"
+ )
diff --git a/src/easydynamics/sample_model/model_base.py b/src/easydynamics/sample_model/model_base.py
index c99e3f62..039da8bd 100644
--- a/src/easydynamics/sample_model/model_base.py
+++ b/src/easydynamics/sample_model/model_base.py
@@ -1,12 +1,12 @@
# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
# SPDX-License-Identifier: BSD-3-Clause
-import warnings
from copy import copy
import numpy as np
import scipp as sc
from easyscience.base_classes.model_base import ModelBase as EasyScienceModelBase
+from easyscience.variable import Parameter
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
@@ -42,9 +42,9 @@ class ModelBase(EasyScienceModelBase):
def __init__(
self,
- display_name: str = 'MyModelBase',
+ display_name: str = "MyModelBase",
unique_name: str | None = None,
- unit: str | sc.Unit | None = 'meV',
+ unit: str | sc.Unit | None = "meV",
components: ModelComponent | ComponentCollection | None = None,
Q: Q_type | None = None,
):
@@ -59,8 +59,8 @@ def __init__(
components, (ModelComponent, ComponentCollection)
):
raise TypeError(
- f'Components must be a ModelComponent, a ComponentCollection or None, '
- f'got {type(components).__name__}'
+ f"Components must be a ModelComponent, a ComponentCollection or None, "
+ f"got {type(components).__name__}"
)
self._components = ComponentCollection()
@@ -87,8 +87,8 @@ def evaluate(
if not self._component_collections:
raise ValueError(
- 'No components in the model to evaluate. '
- 'Run generate_component_collections() first'
+ "No components in the model to evaluate. "
+ "Run generate_component_collections() first"
)
y = [collection.evaluate(x) for collection in self._component_collections]
@@ -106,7 +106,7 @@ def append_component(self, component: ModelComponent | ComponentCollection) -> N
The ModelComponent or ComponentCollection to append.
"""
self._components.append_component(component)
- self._generate_component_collections()
+ self._on_components_change()
def remove_component(self, unique_name: str) -> None:
"""Remove a ModelComponent from the SampleModel by its unique
@@ -117,12 +117,12 @@ def remove_component(self, unique_name: str) -> None:
to remove.
"""
self._components.remove_component(unique_name)
- self._generate_component_collections()
+ self._on_components_change()
def clear_components(self) -> None:
"""Clear all ModelComponents from the SampleModel."""
self._components.clear_components()
- self._generate_component_collections()
+ self._on_components_change()
# ------------------------------------------------------------------
# Properties
@@ -142,8 +142,8 @@ def unit(self) -> str | sc.Unit:
def unit(self, unit_str: str) -> None:
raise AttributeError(
(
- f'Unit is read-only. Use convert_unit to change the unit between allowed types '
- f'or create a new {self.__class__.__name__} with the desired unit.'
+ f"Unit is read-only. Use convert_unit to change the unit between allowed types "
+ f"or create a new {self.__class__.__name__} with the desired unit."
)
) # noqa: E501
@@ -166,7 +166,7 @@ def convert_unit(self, unit: str | sc.Unit) -> None:
except Exception: # noqa: S110
pass # Best effort rollback
raise e
- self._generate_component_collections()
+ self._on_components_change()
@property
def components(self) -> list[ModelComponent]:
@@ -177,7 +177,9 @@ def components(self) -> list[ModelComponent]:
def components(self, value: ModelComponent | ComponentCollection | None) -> None:
"""Set the components of the SampleModel."""
if not isinstance(value, (ModelComponent, ComponentCollection, type(None))):
- raise TypeError('Components must be a ModelComponent or a ComponentCollection')
+ raise TypeError(
+ "Components must be a ModelComponent or a ComponentCollection"
+ )
self.clear_components()
if value is not None:
@@ -191,8 +193,88 @@ def Q(self) -> np.ndarray | None:
@Q.setter
def Q(self, value: Q_type | None) -> None:
"""Set the Q values of the SampleModel."""
- self._Q = _validate_and_convert_Q(value)
- self._generate_component_collections()
+ old_Q = self._Q
+ new_Q = _validate_and_convert_Q(value)
+
+ if (
+ old_Q is not None
+ and new_Q is not None
+ and len(old_Q) == len(new_Q)
+ and all(np.isclose(old_Q, new_Q))
+ ):
+ return # No change in Q, so do nothing
+ self._Q = new_Q
+ self._on_Q_change()
+
+ # ------------------------------------------------------------------
+ # Other methods
+ # ------------------------------------------------------------------
+ def fix_all_parameters(self) -> None:
+ """Fix all Parameters in all ComponentCollections."""
+ for par in self.get_all_variables():
+ par.fixed = True
+
+ def free_all_parameters(self) -> None:
+ """Free all Parameters in all ComponentCollections."""
+ for par in self.get_all_variables():
+ par.fixed = False
+
+ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
+ """Get all Parameters and Descriptors from all
+        ComponentCollections in the ModelBase. Ignores the
+ Parameters and Descriptors in self._components as these are just
+ templates.
+
+ Parameters
+ ----------
+ Q_index : int | None
+ If int, get variables for the ComponentCollection at
+ this index. If None, get variables for all
+ ComponentCollections.
+ Returns
+ -------
+ list[Parameter]
+ """
+ if Q_index is None:
+ all_vars = [
+ var
+ for collection in self._component_collections
+ for var in collection.get_all_variables()
+ ]
+ else:
+ if not isinstance(Q_index, int):
+ raise TypeError(
+ f"Q_index must be an int or None, got {type(Q_index).__name__}"
+ )
+ if Q_index < 0 or Q_index >= len(self._component_collections):
+ raise IndexError(
+ f"Q_index {Q_index} is out of bounds for component collections "
+ f"of length {len(self._component_collections)}"
+ )
+ all_vars = self._component_collections[Q_index].get_all_variables()
+ return all_vars
+
+ def get_component_collection(self, Q_index: int) -> ComponentCollection:
+ """Get the ComponentCollection at the given Q index.
+
+ Parameters
+ ----------
+ Q_index : int
+ The index of the desired ComponentCollection.
+
+ Returns
+ -------
+ ComponentCollection
+ The ComponentCollection at the specified Q index.
+ """
+ if not isinstance(Q_index, int):
+ raise TypeError(f"Q_index must be an int, got {type(Q_index).__name__}")
+ if Q_index < 0 or Q_index >= len(self._component_collections):
+ raise IndexError(
+ f"Q_index {Q_index} is out of bounds for component collections "
+ f"of length {len(self._component_collections)}"
+ )
+ return self._component_collections[Q_index]
# ------------------------------------------------------------------
# Private methods
@@ -203,32 +285,24 @@ def _generate_component_collections(self) -> None:
# TODO regenerate automatically if Q or components have changed
if self._Q is None:
- warnings.warn('Q is not set. No component collections generated', UserWarning)
+            # Q is not set: silently generate no component collections.
+            # (The previous UserWarning was removed because Q is commonly
+            # assigned later via the Q setter, making the warning noisy.)
self._component_collections = []
return
- self._component_collections = [ComponentCollection() for _ in self._Q]
-
- # Add copies of components from self._components to each
- # component collection
- for collection in self._component_collections:
- for component in self._components.components:
- collection.append_component(copy(component))
-
- def get_all_variables(self):
- """Get all Parameters and Descriptors from all
- ComponentCollections in the ModelBase.
+        # Copy the template components once per Q value so each collection has independent parameters.
+ self._component_collections = []
+ for _ in self._Q:
+ self._component_collections.append(copy(self._components))
- Ignores the Parameters and Descriptors in self._components as
- these are just templates.
- """
+ def _on_Q_change(self) -> None:
+ """Handle changes to the Q values."""
+ self._generate_component_collections()
- all_vars = [
- var
- for collection in self._component_collections
- for var in collection.get_all_variables()
- ]
- return all_vars
+ def _on_components_change(self) -> None:
+ """Handle changes to the components."""
+ self._generate_component_collections()
# ------------------------------------------------------------------
# dunder methods
@@ -236,6 +310,6 @@ def get_all_variables(self):
def __repr__(self):
return (
- f'{self.__class__.__name__}(unique_name={self.unique_name}, '
- f'unit={self.unit}), Q = {self.Q}, components = {self.components}'
+ f"{self.__class__.__name__}(unique_name={self.unique_name}, "
+ f"unit={self.unit}), Q = {self.Q}, components = {self.components}"
)
diff --git a/src/easydynamics/sample_model/sample_model.py b/src/easydynamics/sample_model/sample_model.py
index af5550a9..346bd7a4 100644
--- a/src/easydynamics/sample_model/sample_model.py
+++ b/src/easydynamics/sample_model/sample_model.py
@@ -5,7 +5,7 @@
import scipp as sc
from easyscience.variable import Parameter
-from easydynamics.sample_model.diffusion_model import DiffusionModelBase
+from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase
from easydynamics.sample_model.model_base import ModelBase
from easydynamics.utils import _detailed_balance_factor
from easydynamics.utils.utils import Numeric
@@ -175,7 +175,7 @@ def diffusion_models(
'or None'
)
self._diffusion_models = value
- self._generate_component_collections()
+ self._on_diffusion_models_change()
@property
def temperature(self) -> Parameter | None:
@@ -286,7 +286,7 @@ def evaluate(
return y
- def get_all_variables(self):
+ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
"""Get all Parameters and Descriptors from all
ComponentCollections in the SampleModel.
@@ -294,7 +294,8 @@ def get_all_variables(self):
diffusion models. Ignores the Parameters and Descriptors in
self._components as these are just templates.
"""
- all_vars = super().get_all_variables()
+
+ all_vars = super().get_all_variables(Q_index=Q_index)
if self._temperature is not None:
all_vars.append(self._temperature)
@@ -325,6 +326,10 @@ def _generate_component_collections(self) -> None:
for component in source.components:
target.append_component(component)
+ def _on_diffusion_models_change(self) -> None:
+ """Handle changes to the diffusion models."""
+ self._generate_component_collections()
+
# ------------------------------------------------------------------
# dunder methods
# ------------------------------------------------------------------
diff --git a/src/easydynamics/utils/utils.py b/src/easydynamics/utils/utils.py
index 576b451d..bcd44c14 100644
--- a/src/easydynamics/utils/utils.py
+++ b/src/easydynamics/utils/utils.py
@@ -26,7 +26,7 @@ def _validate_and_convert_Q(Q: Q_type | None) -> np.ndarray | None:
if Q is None:
return None
if not isinstance(Q, (Numeric, list, np.ndarray, sc.Variable)):
- raise TypeError('Q must be a number, list, numpy array, or scipp array.')
+ raise TypeError("Q must be a number, list, numpy array, or scipp array.")
if isinstance(Q, Numeric):
Q = np.array([Q])
@@ -34,14 +34,14 @@ def _validate_and_convert_Q(Q: Q_type | None) -> np.ndarray | None:
Q = np.array(Q)
if isinstance(Q, np.ndarray):
if Q.ndim > 1:
- raise ValueError('Q must be a 1-dimensional array.')
+ raise ValueError("Q must be a 1-dimensional array.")
- Q = sc.array(dims=['Q'], values=Q, unit='1/angstrom')
+ Q = sc.array(dims=["Q"], values=Q, unit="1/angstrom")
if isinstance(Q, sc.Variable):
- if Q.dims != ('Q',):
+ if Q.dims != ("Q",):
raise ValueError("Q must have a single dimension named 'Q'.")
- Q = Q.to(unit='1/angstrom')
+ Q = Q.to(unit="1/angstrom")
return Q.values
@@ -64,7 +64,29 @@ def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None:
"""
if unit is not None and not isinstance(unit, (str, sc.Unit)):
- raise TypeError(f'unit must be None, a string, or a scipp Unit, got {type(unit).__name__}')
+ raise TypeError(
+ f"unit must be None, a string, or a scipp Unit, got {type(unit).__name__}"
+ )
if isinstance(unit, str):
unit = sc.Unit(unit)
return unit
+
+
+def _in_notebook() -> bool:
+ """Check if the code is running in a Jupyter notebook.
+
+ Returns:
+ bool: True if in a Jupyter notebook, False otherwise.
+ """
+ try:
+ from IPython import get_ipython
+
+ shell = get_ipython().__class__.__name__
+ if shell == "ZMQInteractiveShell":
+ return True # Jupyter notebook or JupyterLab
+ elif shell == "TerminalInteractiveShell":
+ return False # Terminal IPython
+ else:
+ return False
+ except (NameError, ImportError):
+ return False # Standard Python (no IPython)
diff --git a/tests/conftest.py b/tests/conftest.py
index aefc6c0b..d11735d3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,21 +5,73 @@
# TODO: remove once weakref bug is fixed
-import easyscience.global_object
+# import easyscience.global_object
+# import pytest
+
+
+# @pytest.fixture(autouse=True)
+# def reset_global_object():
+# easyscience.global_object.map._clear()
+
+from unittest.mock import patch
+
import pytest
-# from easyscience.global_object.map import Map
+@pytest.fixture(autouse=True)
+def patch_easyscience_map():
+ """Patch the problematic Map methods."""
+ from easyscience.global_object.map import Map
-# @pytest.fixture(autouse=True)
-# def reset_global_object(monkeypatch):
-# # Before each test
-# monkeypatch.setattr(easyscience.global_object, 'map', Map())
-# yield
-# # After each test (cleanup)
-# monkeypatch.setattr(easyscience.global_object, 'map', Map())
+ # Store the original methods
+ original_add_vertex = Map.add_vertex
+ # original_vertices = Map.vertices
+
+ def safe_add_vertex(self, obj: object, obj_type: str = None):
+ try:
+ return original_add_vertex(self, obj, obj_type)
+ except KeyError:
+ # Object was garbage collected during setup
+ name = obj.unique_name
+ # Clean up any partial state
+ if hasattr(self, '_Map__type_dict') and name in self._Map__type_dict:
+ del self._Map__type_dict[name]
+ if name in self._store:
+ del self._store[name]
+
+ def safe_vertices(self):
+ """Safe version of vertices() that handles dictionary changes
+ during iteration."""
+ max_retries = 3
+ for attempt in range(max_retries):
+ try:
+ return list(self._store.keys())
+ except RuntimeError as e:
+ if 'dictionary changed size during iteration' in str(e):
+ if attempt < max_retries - 1:
+ # Force cleanup and try again
+ import gc
+ gc.collect()
+ continue
+ else:
+ # Last attempt - return what we can get
+ try:
+ # Try to get keys in a different way
+ keys = []
+ for k in list(self._store.data.keys()):
+ if k in self._store:
+ keys.append(k)
+ return keys
+ except: # noqa: E722
+ return []
+ else:
+ raise
+ return []
-@pytest.fixture(autouse=False)
-def reset_global_object():
- easyscience.global_object.map._clear()
+ # Apply the patches
+ with (
+ patch.object(Map, 'add_vertex', safe_add_vertex),
+ patch.object(Map, 'vertices', safe_vertices),
+ ):
+ yield
diff --git a/tests/unit/easydynamics/experiment/test_experiment.py b/tests/unit/easydynamics/experiment/test_experiment.py
index 067a2017..05aa2470 100644
--- a/tests/unit/easydynamics/experiment/test_experiment.py
+++ b/tests/unit/easydynamics/experiment/test_experiment.py
@@ -12,12 +12,12 @@
class TestExperiment:
@pytest.fixture
def experiment(self):
- Q = sc.linspace('Q', 0.5, 1.5, num=10, unit='1/Angstrom')
- energy = sc.linspace('energy', -5, 5, num=11, unit='meV')
- values = sc.array(dims=['Q', 'energy'], values=np.ones((10, 11)))
- data = sc.DataArray(data=values, coords={'Q': Q, 'energy': energy})
+ Q = sc.linspace("Q", 0.5, 1.5, num=10, unit="1/Angstrom")
+ energy = sc.linspace("energy", -5, 5, num=11, unit="meV")
+ values = sc.array(dims=["Q", "energy"], values=np.ones((10, 11)))
+ data = sc.DataArray(data=values, coords={"Q": Q, "energy": energy})
- experiment = Experiment(display_name='test_experiment', data=data)
+ experiment = Experiment(display_name="test_experiment", data=data)
return experiment
##############
@@ -27,51 +27,51 @@ def experiment(self):
def test_init_array(self, experiment):
"Test initialization with a Scipp DataArray"
# WHEN THEN EXPECT
- assert experiment.display_name == 'test_experiment'
+ assert experiment.display_name == "test_experiment"
assert isinstance(experiment._data, sc.DataArray)
- assert 'Q' in experiment._data.dims
- assert 'energy' in experiment._data.dims
- assert experiment._data.sizes['Q'] == 10
- assert experiment._data.sizes['energy'] == 11
+ assert "Q" in experiment._data.dims
+ assert "energy" in experiment._data.dims
+ assert experiment._data.sizes["Q"] == 10
+ assert experiment._data.sizes["energy"] == 11
assert sc.identical(
experiment._data.data,
- sc.array(dims=['Q', 'energy'], values=np.ones((10, 11))),
+ sc.array(dims=["Q", "energy"], values=np.ones((10, 11))),
)
def test_init_string(self, tmp_path):
"Test initialization with a filename string,"
- 'should load the file'
+ "should load the file"
# WHEN
- Q = sc.linspace('Q', 0.5, 1.5, num=10, unit='1/Angstrom')
- energy = sc.linspace('energy', -5, 5, num=11, unit='meV')
- values = sc.array(dims=['Q', 'energy'], values=np.ones((10, 11)))
- data = sc.DataArray(data=values, coords={'Q': Q, 'energy': energy})
+ Q = sc.linspace("Q", 0.5, 1.5, num=10, unit="1/Angstrom")
+ energy = sc.linspace("energy", -5, 5, num=11, unit="meV")
+ values = sc.array(dims=["Q", "energy"], values=np.ones((10, 11)))
+ data = sc.DataArray(data=values, coords={"Q": Q, "energy": energy})
- filename = tmp_path / 'test_experiment.h5'
+ filename = tmp_path / "test_experiment.h5"
sc.io.save_hdf5(data, filename)
# THEN
- experiment = Experiment(display_name='loaded_experiment', data=str(filename))
+ experiment = Experiment(display_name="loaded_experiment", data=str(filename))
# EXPECT
- assert experiment.display_name == 'loaded_experiment'
+ assert experiment.display_name == "loaded_experiment"
assert isinstance(experiment._data, sc.DataArray)
- assert 'Q' in experiment._data.dims
- assert 'energy' in experiment._data.dims
- assert experiment._data.sizes['Q'] == 10
- assert experiment._data.sizes['energy'] == 11
+ assert "Q" in experiment._data.dims
+ assert "energy" in experiment._data.dims
+ assert experiment._data.sizes["Q"] == 10
+ assert experiment._data.sizes["energy"] == 11
assert sc.identical(
experiment._data.data,
- sc.array(dims=['Q', 'energy'], values=np.ones((10, 11))),
+ sc.array(dims=["Q", "energy"], values=np.ones((10, 11))),
)
def test_init_no_data(self):
"Test initialization with no data"
# WHEN
- experiment = Experiment(display_name='empty_experiment')
+ experiment = Experiment(display_name="empty_experiment")
# THEN EXPECT
- assert experiment.display_name == 'empty_experiment'
+ assert experiment.display_name == "empty_experiment"
assert experiment._data is None
def test_init_invalid_data(self):
@@ -86,34 +86,34 @@ def test_init_invalid_data(self):
def test_load_hdf5(self, tmp_path, experiment):
"Test loading data from an HDF5 file."
- 'First use scipp to save data to a file, '
- 'then load it using the method.'
+ "First use scipp to save data to a file, "
+ "then load it using the method."
# WHEN
# First create a file to load from
- filename = tmp_path / 'test.h5'
+ filename = tmp_path / "test.h5"
data_to_save = experiment.data
sc.io.save_hdf5(data_to_save, filename)
# THEN
- new_experiment = Experiment(display_name='new_experiment')
- new_experiment.load_hdf5(str(filename), display_name='loaded_data')
+ new_experiment = Experiment(display_name="new_experiment")
+ new_experiment.load_hdf5(str(filename), display_name="loaded_data")
loaded_data = new_experiment.data
# EXPECT
assert sc.identical(data_to_save, loaded_data)
- assert new_experiment.display_name == 'loaded_data'
+ assert new_experiment.display_name == "loaded_data"
def test_load_hdf5_invalid_name_raises(self, experiment):
"Test loading data from an HDF5 file,"
- 'giving the Experiment an invalid name'
+ "giving the Experiment an invalid name"
# WHEN / THEN EXPECT
with pytest.raises(TypeError):
- experiment.load_hdf5('some_file.h5', display_name=123)
+ experiment.load_hdf5("some_file.h5", display_name=123)
def test_load_hdf5_invalid_filename_raises(self, experiment):
"Test loading data from an HDF5 file with an invalid filename"
# WHEN / THEN EXPECT
- with pytest.raises(TypeError, match='must be a string'):
+ with pytest.raises(TypeError, match="must be a string"):
experiment.load_hdf5(123)
def test_load_hdf5_invalid_file_raises(self, experiment):
@@ -121,13 +121,13 @@ def test_load_hdf5_invalid_file_raises(self, experiment):
# WHEN / THEN EXPECT
with pytest.raises(OSError):
- experiment.load_hdf5('non_existent_file.h5')
+ experiment.load_hdf5("non_existent_file.h5")
def test_save_hdf5(self, tmp_path, experiment):
"Test saving data to an HDF5 file. Load the saved file"
- 'using scipp and compare to the original data.'
+ "using scipp and compare to the original data."
# WHEN THEN
- filename = tmp_path / 'saved_data.h5'
+ filename = tmp_path / "saved_data.h5"
experiment.save_hdf5(str(filename))
# EXPECT
@@ -144,25 +144,25 @@ def test_save_hdf5_default_filename(self, tmp_path, experiment, monkeypatch):
experiment.save_hdf5()
# EXPECT
- expected_filename = tmp_path / f'{experiment.unique_name}.h5'
+ expected_filename = tmp_path / f"{experiment.unique_name}.h5"
loaded_data = sc.io.load_hdf5(str(expected_filename))
original_data = experiment.data
assert sc.identical(original_data, loaded_data)
def test_save_hdf5_no_data_raises(self):
"Test saving data to an HDF5 file when no data is present"
- 'in the experiment'
+ "in the experiment"
# WHEN
experiment = Experiment()
# THEN EXPECT
with pytest.raises(ValueError):
- experiment.save_hdf5('should_fail.h5')
+ experiment.save_hdf5("should_fail.h5")
def test_save_hdf5_invalid_filename_raises(self, experiment):
"Test saving data to an HDF5 file with an invalid filename"
# WHEN / THEN EXPECT
- with pytest.raises(TypeError, match='must be a string'):
+ with pytest.raises(TypeError, match="must be a string"):
experiment.save_hdf5(123)
def test_remove_data(self, experiment):
@@ -174,11 +174,11 @@ def test_remove_data(self, experiment):
assert experiment._data is None
@pytest.mark.parametrize(
- 'new_Q_bins, new_energy_bins',
+ "new_Q_bins, new_energy_bins",
[
(
- sc.linspace('Q', 0.5, 1.5, num=7, unit='1/Angstrom'),
- sc.linspace('energy', -5, 5, num=8, unit='meV'),
+ sc.linspace("Q", 0.5, 1.5, num=7, unit="1/Angstrom"),
+ sc.linspace("energy", -5, 5, num=8, unit="meV"),
),
(
6,
@@ -189,23 +189,23 @@ def test_remove_data(self, experiment):
7.0,
),
(
- sc.linspace('Q', 0.5, 1.5, num=7, unit='1/Angstrom'),
+ sc.linspace("Q", 0.5, 1.5, num=7, unit="1/Angstrom"),
7,
),
],
- ids=['sc_bins', 'integers_bins', 'float_bins', 'mixed_bins'],
+ ids=["sc_bins", "integers_bins", "float_bins", "mixed_bins"],
)
def test_rebin(self, experiment, new_Q_bins, new_energy_bins):
"Test rebinning data in the experiment"
# WHEN
# THEN
- experiment.rebin({'Q': new_Q_bins, 'energy': new_energy_bins})
+ experiment.rebin({"Q": new_Q_bins, "energy": new_energy_bins})
# EXPECT
rebinned_data = experiment.binned_data
- assert rebinned_data.sizes['Q'] == 6
- assert rebinned_data.sizes['energy'] == 7
+ assert rebinned_data.sizes["Q"] == 6
+ assert rebinned_data.sizes["energy"] == 7
def test_rebin_no_data_raises(self):
"Test rebinning data when no data is present"
@@ -214,34 +214,34 @@ def test_rebin_no_data_raises(self):
# THEN EXPECT
with pytest.raises(ValueError):
- experiment.rebin({'Q': 6, 'energy': 7})
+ experiment.rebin({"Q": 6, "energy": 7})
def test_rebin_invalid_dimensions_raises(self, experiment):
"Test rebinning data with invalid dimensions"
# WHEN / THEN EXPECT
with pytest.raises(TypeError):
- experiment.rebin('invalid_dimensions')
+ experiment.rebin("invalid_dimensions")
def test_rebin_invalid_dimension_name_raises(self, experiment):
"Test rebinning data with invalid dimension name"
# WHEN / THEN EXPECT
- with pytest.raises(TypeError, match='Dimension keys must be strings'):
- experiment.rebin({123: 6, 'energy': 7})
+ with pytest.raises(TypeError, match="Dimension keys must be strings"):
+ experiment.rebin({123: 6, "energy": 7})
def test_rebin_dimension_not_in_data_raises(self, experiment):
"Test rebinning data with a dimension not in the data"
# WHEN / THEN EXPECT
with pytest.raises(KeyError, match="Dimension 'time' not a valid"):
- experiment.rebin({'time': 6, 'energy': 7})
+ experiment.rebin({"time": 6, "energy": 7})
def test_rebin_invalid_bin_values_raises(self, experiment):
"Test rebinning data with invalid bin values"
# WHEN / THEN EXPECT
with pytest.raises(
TypeError,
- match='Dimension values must be integers or',
+ match="Dimension values must be integers or",
):
- experiment.rebin({'Q': [0.5, 1.0, 1.5], 'energy': 7})
+ experiment.rebin({"Q": [0.5, 1.0, 1.5], "energy": 7})
##############
# test setters and getters
@@ -271,24 +271,6 @@ def test_Q_setter_raises(self, experiment):
with pytest.raises(AttributeError):
experiment.Q = experiment.Q
- def test_Q_getter_warns_no_data(self):
- "Test that getting Q data with no data raises Warning"
- # WHEN
- experiment = Experiment()
-
- # THEN EXPECT
- with pytest.warns(UserWarning, match='No data loaded'):
- _ = experiment.Q
-
- def test_energy_getter_warns_no_data(self):
- "Test that getting energy data with no data raises Warning"
- # WHEN
- experiment = Experiment()
-
- # THEN EXPECT
- with pytest.warns(UserWarning, match='No data loaded'):
- _ = experiment.energy
-
##############
# test plotting
##############
@@ -297,9 +279,9 @@ def test_plot_data_success(self, experiment):
"Test plotting data successfully when in notebook environment"
# WHEN
with (
- patch.object(Experiment, '_in_notebook', return_value=True),
- patch('plopp.plot') as mock_plot,
- patch('IPython.display.display') as mock_display,
+ patch(f"{Experiment.__module__}._in_notebook", return_value=True),
+ patch("plopp.plot") as mock_plot,
+ patch("IPython.display.display") as mock_display,
):
mock_fig = MagicMock()
mock_plot.return_value = mock_fig
@@ -311,7 +293,7 @@ def test_plot_data_success(self, experiment):
mock_plot.assert_called_once()
args, kwargs = mock_plot.call_args
assert sc.identical(args[0], experiment._data.transpose())
- assert kwargs['title'] == f'{experiment.display_name}'
+ assert kwargs["title"] == f"{experiment.display_name}"
mock_display.assert_called_once_with(mock_fig)
def test_plot_data_no_data_raises(self):
@@ -320,18 +302,18 @@ def test_plot_data_no_data_raises(self):
experiment = Experiment()
# THEN EXPECT
- with pytest.raises(ValueError, match='No data to plot'):
+ with pytest.raises(ValueError, match="No data to plot"):
experiment.plot_data()
def test_plot_data_not_in_notebook_raises(self, experiment):
"Test plotting data raises RuntimeError"
- 'when not in notebook environment'
+ "when not in notebook environment"
# WHEN
- with patch.object(Experiment, '_in_notebook', return_value=False):
+ with patch(f"{Experiment.__module__}._in_notebook", return_value=False):
# THEN EXPECT
with pytest.raises(
RuntimeError,
- match='plot_data\\(\\) can only be used in a Jupyter notebook environment',
+ match="plot_data\\(\\) can only be used in a Jupyter notebook environment",
):
experiment.plot_data()
@@ -339,62 +321,6 @@ def test_plot_data_not_in_notebook_raises(self, experiment):
# test private methods
##############
- def test_in_notebook_returns_true_for_jupyter(self, monkeypatch):
- """Should return True when IPython shell is
- ZMQInteractiveShell (Jupyter)."""
-
- # WHEN
- class ZMQInteractiveShell:
- __name__ = 'ZMQInteractiveShell'
-
- # THEN
- monkeypatch.setattr('IPython.get_ipython', lambda: ZMQInteractiveShell())
-
- # EXPECT
- assert Experiment._in_notebook() is True
-
- def test_in_notebook_returns_false_for_terminal_ipython(self, monkeypatch):
- """Should return False when IPython shell is
- TerminalInteractiveShell."""
-
- # WHEN
- class TerminalInteractiveShell:
- __name__ = 'TerminalInteractiveShell'
-
- # THEN
-
- monkeypatch.setattr('IPython.get_ipython', lambda: TerminalInteractiveShell())
-
- # EXPECT
- assert Experiment._in_notebook() is False
-
- def test_in_notebook_returns_false_for_unknown_shell(self, monkeypatch):
- """Should return False when IPython shell type is
- unrecognized."""
-
- # WHEN
- class UnknownShell:
- __name__ = 'UnknownShell'
-
- # THEN
- monkeypatch.setattr('IPython.get_ipython', lambda: UnknownShell())
- # EXPECT
- assert Experiment._in_notebook() is False
-
- def test_in_notebook_returns_false_when_no_ipython(self, monkeypatch):
- """Should return False when IPython is not installed or
- available."""
-
- # WHEN
- def raise_import_error(*args, **kwargs):
- raise ImportError
-
- # THEN
- monkeypatch.setattr('builtins.__import__', raise_import_error)
-
- # EXPECT
- assert Experiment._in_notebook() is False
-
def test_validate_coordinates(self, experiment):
"Test that _validate_coordinates does not raise for valid data"
# WHEN / THEN EXPECT
@@ -402,40 +328,42 @@ def test_validate_coordinates(self, experiment):
def test_validate_coordinates_raises_missing_Q(self, experiment):
"Test that _validate_coordinates raises ValueError when Q coord"
- 'is missing'
+ "is missing"
# WHEN
invalid_data = experiment._data.copy()
- invalid_data.coords.pop('Q')
+ invalid_data.coords.pop("Q")
# THEN EXPECT
- with pytest.raises(ValueError, match='missing required coordinate'):
+ with pytest.raises(ValueError, match="missing required coordinate"):
experiment._validate_coordinates(invalid_data)
def test_validate_coordinates_raises_missing_energy(self, experiment):
"Test that _validate_coordinates raises ValueError when energy"
- 'coord is missing'
+ "coord is missing"
# WHEN
invalid_data = experiment._data.copy()
- invalid_data.coords.pop('energy')
+ invalid_data.coords.pop("energy")
# THEN EXPECT
- with pytest.raises(ValueError, match='missing required coordinate'):
+ with pytest.raises(ValueError, match="missing required coordinate"):
experiment._validate_coordinates(invalid_data)
def test_validate_coordinates_raises_not_DataArray(self):
"Test that _validate_coordinates raises TypeError when data is"
- 'not a Scipp DataArray'
+ "not a Scipp DataArray"
# WHEN THEN EXPECT
- with pytest.raises(TypeError, match='must be a'):
- Experiment()._validate_coordinates('not_a_data_array')
+ with pytest.raises(TypeError, match="must be a"):
+ Experiment()._validate_coordinates("not_a_data_array")
def test_convert_to_bin_centers(self, experiment):
"Test that _convert_to_bin_centers converts edges to centers"
# WHEN
- Q_edges = sc.linspace('Q', 0.0, 2.0, num=11, unit='1/Angstrom')
- energy_edges = sc.linspace('energy', -6, 6, num=13, unit='meV')
- values = sc.array(dims=['Q', 'energy'], values=np.ones((10, 12)))
- binned_data = sc.DataArray(data=values, coords={'Q': Q_edges, 'energy': energy_edges})
+ Q_edges = sc.linspace("Q", 0.0, 2.0, num=11, unit="1/Angstrom")
+ energy_edges = sc.linspace("energy", -6, 6, num=13, unit="meV")
+ values = sc.array(dims=["Q", "energy"], values=np.ones((10, 12)))
+ binned_data = sc.DataArray(
+ data=values, coords={"Q": Q_edges, "energy": energy_edges}
+ )
# THEN
experiment._data = binned_data # Set data to avoid warnings
@@ -445,8 +373,8 @@ def test_convert_to_bin_centers(self, experiment):
expected_Q = 0.5 * (Q_edges[:-1] + Q_edges[1:])
expected_energy = 0.5 * (energy_edges[:-1] + energy_edges[1:])
- assert sc.identical(converted_data.coords['Q'], expected_Q)
- assert sc.identical(converted_data.coords['energy'], expected_energy)
+ assert sc.identical(converted_data.coords["Q"], expected_Q)
+ assert sc.identical(converted_data.coords["energy"], expected_energy)
assert sc.identical(converted_data.data, binned_data.data)
##############
@@ -458,12 +386,15 @@ def test_repr(self, experiment):
repr_str = repr(experiment)
# THEN EXPECT
- assert repr_str == f'Experiment `{experiment.unique_name}` with data: {experiment._data}'
+ assert (
+ repr_str
+ == f"Experiment `{experiment.unique_name}` with data: {experiment._data}"
+ )
def test_copy_experiment(self, experiment):
"Test copying an Experiment object."
- 'The copied object should have the same attributes '
- 'but be a different object in memory.'
+ "The copied object should have the same attributes "
+ "but be a different object in memory."
# WHEN
copied_experiment = copy(experiment)
diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py
index 7476755b..0d0963c0 100644
--- a/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py
+++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py
@@ -37,7 +37,6 @@ def test_init_default(self, brownian_diffusion_model):
'unit': 123,
'scale': 1.0,
'diffusion_coefficient': 1.0,
- 'diffusion_unit': 'm**2/s',
},
UnitError,
'Invalid unit',
@@ -47,7 +46,6 @@ def test_init_default(self, brownian_diffusion_model):
'unit': 123,
'scale': 'invalid',
'diffusion_coefficient': 1.0,
- 'diffusion_unit': 'm**2/s',
},
TypeError,
'scale must be a number',
@@ -57,50 +55,16 @@ def test_init_default(self, brownian_diffusion_model):
'unit': 123,
'scale': 1.0,
'diffusion_coefficient': 'invalid',
- 'diffusion_unit': 'm**2/s',
},
TypeError,
'diffusion_coefficient must be a number',
),
- (
- {
- 'unit': 123,
- 'scale': 1.0,
- 'diffusion_coefficient': 1.0,
- 'diffusion_unit': 123,
- },
- TypeError,
- 'diffusion_unit must be ',
- ),
],
)
def test_input_type_validation_raises(self, kwargs, expected_exception, expected_message):
with pytest.raises(expected_exception, match=expected_message):
BrownianTranslationalDiffusion(display_name='BrownianTranslationalDiffusion', **kwargs)
- def test_diffusion_unit_value_error(self):
- # WHEN THEN EXPECT
- with pytest.raises(ValueError, match='diffusion_unit must be .'):
- BrownianTranslationalDiffusion(
- display_name='BrownianTranslationalDiffusion',
- unit='meV',
- scale=1.0,
- diffusion_coefficient=1.0,
- diffusion_unit='invalid_unit',
- )
-
- def test_scale_setter(self, brownian_diffusion_model):
- # WHEN
- brownian_diffusion_model.scale = 2.0
-
- # THEN EXPECT
- assert brownian_diffusion_model.scale.value == 2.0
-
- def test_scale_setter_raises(self, brownian_diffusion_model):
- # WHEN THEN EXPECT
- with pytest.raises(TypeError, match='scale must be a number.'):
- brownian_diffusion_model.scale = 'invalid' # Invalid type
-
def test_diffusion_coefficient_setter(self, brownian_diffusion_model):
# WHEN
brownian_diffusion_model.diffusion_coefficient = 3.0
@@ -136,20 +100,6 @@ def test_calculate_width(self, brownian_diffusion_model):
expected_widths = 1.0 * unit_conversion_factor.value * (Q_values**2)
np.testing.assert_allclose(widths, expected_widths, rtol=1e-5)
- def test_calculate_width_diffusion_unit_mev_angstrom2(self):
- # WHEN
- diffusion_model = BrownianTranslationalDiffusion(
-     diffusion_coefficient=2.0, diffusion_unit='meV*Å**2'
- )
-     Q_values = np.array([0.1, 0.2, 0.3])  # Example Q values in Å^-1
-
- # WHEN
- widths = diffusion_model.calculate_width(Q_values)
-
- # THEN EXPECT
- expected_widths = 2.0 * (Q_values**2)
- np.testing.assert_allclose(widths, expected_widths, rtol=1e-5)
-
def test_calculate_EISF(self, brownian_diffusion_model):
# WHEN
Q_values = np.array([0.1, 0.2, 0.3])  # Example Q values in Å^-1
diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py
index e7bca65a..b8eb0956 100644
--- a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py
+++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py
@@ -23,3 +23,15 @@ def test_unit_setter_raises(self, diffusion_model):
match='Unit is read-only. Use convert_unit to change the unit between allowed types',
):
diffusion_model.unit = 'eV'
+
+ def test_scale_setter(self, diffusion_model):
+ # WHEN
+ diffusion_model.scale = 2.0
+
+ # THEN EXPECT
+ assert diffusion_model.scale.value == 2.0
+
+ def test_scale_setter_raises(self, diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='scale must be a number.'):
+ diffusion_model.scale = 'invalid' # Invalid type
diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py
new file mode 100644
index 00000000..90a842d6
--- /dev/null
+++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py
@@ -0,0 +1,254 @@
+import numpy as np
+import pytest
+import scipp as sc
+from easyscience.variable import DescriptorNumber
+from scipp import UnitError
+from scipp.constants import hbar as scipp_hbar
+
+from easydynamics.sample_model.diffusion_model.jump_translational_diffusion import (
+ JumpTranslationalDiffusion,
+)
+
+hbar_1 = DescriptorNumber('hbar', 1.0)
+hbar = DescriptorNumber.from_scipp('hbar', scipp_hbar)
+angstrom = DescriptorNumber('angstrom', 1e-10, unit='m')
+
+
+class TestJumpTranslationalDiffusion:
+ @pytest.fixture
+ def jump_diffusion_model(self):
+ return JumpTranslationalDiffusion()
+
+ def test_init_default(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ assert jump_diffusion_model.display_name == 'JumpTranslationalDiffusion'
+ assert jump_diffusion_model.unit == 'meV'
+ assert jump_diffusion_model.scale.value == 1.0
+ assert jump_diffusion_model.diffusion_coefficient.value == 1.0
+ assert jump_diffusion_model.relaxation_time.value == 1.0
+
+ @pytest.mark.parametrize(
+ 'kwargs,expected_exception, expected_message',
+ [
+ (
+ {
+ 'unit': 123,
+ 'scale': 1.0,
+ 'diffusion_coefficient': 1.0,
+ 'relaxation_time': 1.0,
+ },
+ UnitError,
+ 'Invalid unit',
+ ),
+ (
+ {
+ 'unit': 'meV',
+ 'scale': 'invalid',
+ 'diffusion_coefficient': 1.0,
+ 'relaxation_time': 1.0,
+ },
+ TypeError,
+ 'scale must be a number',
+ ),
+ (
+ {
+ 'unit': 'meV',
+ 'scale': 1.0,
+ 'diffusion_coefficient': 'invalid',
+ 'relaxation_time': 1.0,
+ },
+ TypeError,
+ 'diffusion_coefficient must be a number',
+ ),
+ (
+ {
+ 'unit': 'meV',
+ 'scale': 1.0,
+ 'diffusion_coefficient': 1.0,
+ 'relaxation_time': 'invalid',
+ },
+ TypeError,
+ 'relaxation_time must be a number',
+ ),
+ ],
+ )
+ def test_input_type_validation_raises(self, kwargs, expected_exception, expected_message):
+ with pytest.raises(expected_exception, match=expected_message):
+ JumpTranslationalDiffusion(display_name='JumpTranslationalDiffusion', **kwargs)
+
+ def test_diffusion_coefficient_setter(self, jump_diffusion_model):
+ # WHEN
+ jump_diffusion_model.diffusion_coefficient = 3.0
+
+ # THEN EXPECT
+ assert jump_diffusion_model.diffusion_coefficient.value == 3.0
+
+ def test_diffusion_coefficient_setter_raises(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='diffusion_coefficient must be a number.'):
+ jump_diffusion_model.diffusion_coefficient = 'invalid' # Invalid type
+
+ def test_relaxation_time_setter(self, jump_diffusion_model):
+ # WHEN
+ jump_diffusion_model.relaxation_time = 2.5
+
+ # THEN EXPECT
+ assert jump_diffusion_model.relaxation_time.value == 2.5
+
+ def test_relaxation_time_setter_raises(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='relaxation_time must be a number.'):
+ jump_diffusion_model.relaxation_time = 'invalid' # Invalid type
+
+ def test_calculate_width_type_error(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='Q must be '):
+ jump_diffusion_model.calculate_width(Q='invalid') # Invalid type
+
+ def test_calculate_width(self, jump_diffusion_model):
+ "Test the calculation relying solely on a scipp implementation"
+ 'instead of our Parameters'
+ # WHEN
+ Q_values = sc.linspace('Q', 0.5, 1.5, num=6, unit='1/angstrom')
+ relaxation_time_sc = jump_diffusion_model.relaxation_time.value * sc.Unit(
+ jump_diffusion_model.relaxation_time.unit
+ )
+ diffusion_coefficient_sc = jump_diffusion_model.diffusion_coefficient.value * sc.Unit(
+ jump_diffusion_model.diffusion_coefficient.unit
+ )
+
+ # THEN
+ widths = jump_diffusion_model.calculate_width(Q_values)
+
+ denominator = diffusion_coefficient_sc * relaxation_time_sc * Q_values**2
+ denominator = denominator.to(unit='1')
+
+ # EXPECT
+ expected_widths = scipp_hbar * diffusion_coefficient_sc * (Q_values**2) / (1 + denominator)
+
+ expected_widths = expected_widths.to(unit=jump_diffusion_model.unit)
+
+ np.testing.assert_allclose(widths, expected_widths.values, rtol=1e-5)
+
+ def test_calculate_EISF(self, jump_diffusion_model):
+ # WHEN
+         Q_values = np.array([0.1, 0.2, 0.3])  # Example Q values in Å^-1
+
+ # THEN
+ EISF = jump_diffusion_model.calculate_EISF(Q_values)
+
+ # EXPECT
+ expected_EISF = np.zeros_like(Q_values)
+ np.testing.assert_array_equal(EISF, expected_EISF)
+
+ def test_calculate_EISF_type_error(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='Q must be '):
+ jump_diffusion_model.calculate_EISF(Q='invalid') # Invalid type
+
+ def test_calculate_QISF(self, jump_diffusion_model):
+ # WHEN
+         Q_values = np.array([0.1, 0.2, 0.3])  # Example Q values in Å^-1
+
+ # THEN
+ QISF = jump_diffusion_model.calculate_QISF(Q_values)
+
+ # EXPECT
+ expected_QISF = np.ones_like(Q_values)
+ np.testing.assert_array_equal(QISF, expected_QISF)
+
+ def test_calculate_QISF_type_error(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='Q must be '):
+ jump_diffusion_model.calculate_QISF(Q='invalid') # Invalid type
+
+ @pytest.mark.parametrize(
+ 'Q',
+ [
+ (0.5),
+ ([1.0, 2.0, 3.0]),
+ (np.array([1.0, 2.0, 3.0])),
+ ],
+ ids=[
+ 'python_scalar',
+ 'python_list',
+ 'numpy_array',
+ ],
+ )
+ def test_create_component_collections(self, jump_diffusion_model, Q):
+ # WHEN
+
+ # THEN
+ component_collections = jump_diffusion_model.create_component_collections(Q=Q)
+
+ # EXPECT
+ expected_widths = jump_diffusion_model.calculate_width(Q)
+ for model_index in range(len(component_collections)):
+ model = component_collections[model_index]
+ assert len(model.components) == 1
+ component = model.components[0]
+ assert component.width.unit == jump_diffusion_model.unit
+ assert np.isclose(component.width.value, expected_widths[model_index])
+ assert component.width.independent is False
+
+ def test_create_component_collections_component_name_must_be_string(
+ self, jump_diffusion_model
+ ):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='component_name must be a string.'):
+ jump_diffusion_model.create_component_collections(
+ Q=np.array([0.1, 0.2, 0.3]), component_display_name=123
+ )
+
+ def test_create_component_collections_Q_type_error(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='Q must be a '):
+ jump_diffusion_model.create_component_collections(Q='invalid') # Invalid type
+
+ def test_create_component_collections_Q_1dimensional_error(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(ValueError, match='Q must be a 1-dimensional array.'):
+ jump_diffusion_model.create_component_collections(
+ Q=np.array([[0.1, 0.2], [0.3, 0.4]])
+ ) # Invalid shape
+
+ def test_write_width_dependency_expression(self, jump_diffusion_model):
+ # WHEN THEN
+ expression = jump_diffusion_model._write_width_dependency_expression(0.5)
+
+ # EXPECT
+ expected_expression = (
+ 'hbar * D* 0.5 **2/(angstrom**2)/(1 + (D * t* 0.5 **2/(angstrom**2)))'
+ )
+ assert expression == expected_expression
+
+ def test_write_width_dependency_map_expression(self, jump_diffusion_model):
+ # WHEN THEN
+ expression_map = jump_diffusion_model._write_width_dependency_map_expression()
+
+ # EXPECT
+ expected_map = {
+ 'D': jump_diffusion_model.diffusion_coefficient,
+ 't': jump_diffusion_model.relaxation_time,
+ 'hbar': jump_diffusion_model._hbar,
+ 'angstrom': jump_diffusion_model._angstrom,
+ }
+
+ assert expression_map == expected_map
+
+ def test_write_width_dependency_expression_raises(self, jump_diffusion_model):
+ with pytest.raises(TypeError, match='Q must be a float'):
+ jump_diffusion_model._write_width_dependency_expression('invalid')
+
+ def test_write_area_dependency_expression_raises(self, jump_diffusion_model):
+ with pytest.raises(TypeError, match='QISF must be a float'):
+ jump_diffusion_model._write_area_dependency_expression('invalid')
+
+ def test_repr(self, jump_diffusion_model):
+ # WHEN THEN
+ repr_str = repr(jump_diffusion_model)
+
+ # EXPECT
+ assert 'JumpTranslationalDiffusion' in repr_str
+ assert 'diffusion_coefficient' in repr_str
+ assert 'scale=' in repr_str
diff --git a/tests/unit/easydynamics/sample_model/test_component_collection.py b/tests/unit/easydynamics/sample_model/test_component_collection.py
index 926adfa6..42a66f6a 100644
--- a/tests/unit/easydynamics/sample_model/test_component_collection.py
+++ b/tests/unit/easydynamics/sample_model/test_component_collection.py
@@ -216,13 +216,14 @@ def test_evaluate(self, component_collection):
) + component_collection.components[1].evaluate(x)
np.testing.assert_allclose(result, expected_result, rtol=1e-5)
- def test_evaluate_no_components_raises(self):
+ def test_evaluate_no_components_returns_zero(self):
# WHEN THEN
component_collection = ComponentCollection(display_name='EmptyModel')
x = np.linspace(-5, 5, 100)
# EXPECT
- with pytest.raises(ValueError, match='No components in the model to evaluate.'):
- component_collection.evaluate(x)
+ result = component_collection.evaluate(x)
+ assert np.all(result == 0.0)
+ assert result.shape == x.shape
def test_evaluate_component(self, component_collection):
# WHEN THEN
diff --git a/tests/unit/easydynamics/sample_model/test_instrument_model.py b/tests/unit/easydynamics/sample_model/test_instrument_model.py
new file mode 100644
index 00000000..00f036cd
--- /dev/null
+++ b/tests/unit/easydynamics/sample_model/test_instrument_model.py
@@ -0,0 +1,398 @@
+# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors
+# SPDX-License-Identifier: BSD-3-Clause
+
+
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+from easydynamics.sample_model import Gaussian
+from easydynamics.sample_model import Polynomial
+from easydynamics.sample_model.background_model import BackgroundModel
+from easydynamics.sample_model.instrument_model import InstrumentModel
+from easydynamics.sample_model.resolution_model import ResolutionModel
+
+
+class TestInstrumentModel:
+ @pytest.fixture
+ def instrument_model(self):
+ Q = np.array([1.0, 2.0, 3.0])
+ component1 = Polynomial(coefficients=[1.0, 2.0])
+ background_model = BackgroundModel(components=component1, Q=Q)
+
+ component2 = Gaussian()
+ resolution_model = ResolutionModel(components=component2, Q=Q)
+
+ instrument_model = InstrumentModel(
+ display_name='TestInstrumentModel',
+ background_model=background_model,
+ resolution_model=resolution_model,
+ Q=Q,
+ )
+
+ return instrument_model
+
+ @pytest.fixture
+ def resolution_model(self):
+ component = Gaussian()
+ resolution_model = ResolutionModel(components=component)
+ return resolution_model
+
+ @pytest.fixture
+ def background_model(self):
+ component = Polynomial(coefficients=[1.0, 2.0])
+ background_model = BackgroundModel(components=component)
+ return background_model
+
+ def test_init(self, instrument_model):
+ # WHEN THEN
+ model = instrument_model
+
+ # EXPECT
+ assert model.display_name == 'TestInstrumentModel'
+ np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+ assert isinstance(model.background_model, BackgroundModel)
+ assert isinstance(model.resolution_model, ResolutionModel)
+ np.testing.assert_array_equal(model.background_model.Q, np.array([1.0, 2.0, 3.0]))
+ np.testing.assert_array_equal(model.resolution_model.Q, np.array([1.0, 2.0, 3.0]))
+ np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+
+ def test_init_defaults(self):
+ # WHEN THEN
+ model = InstrumentModel()
+
+ # EXPECT
+ assert model.display_name == 'MyInstrumentModel'
+ assert isinstance(model.background_model, BackgroundModel)
+ assert isinstance(model.resolution_model, ResolutionModel)
+ assert model.Q is None
+
+ @pytest.mark.parametrize(
+ 'kwargs, expected_exception, expected_message',
+ [
+ (
+ {'resolution_model': 123},
+ TypeError,
+ 'resolution_model must be a ResolutionModel',
+ ),
+ (
+ {'background_model': 'not a model'},
+ TypeError,
+ 'background_model must be a BackgroundModel',
+ ),
+ (
+ {'energy_offset': 'abc'},
+ TypeError,
+ 'energy_offset must be a number',
+ ),
+ (
+ {'unit': 123},
+ TypeError,
+ 'unit must be',
+ ),
+ ],
+ ids=[
+ 'invalid resolution_model',
+ 'invalid background_model',
+ 'invalid energy_offset',
+ 'invalid unit',
+ ],
+ )
+ def test_instrument_model_init_invalid_inputs(
+ self, kwargs, expected_exception, expected_message
+ ):
+ with pytest.raises(expected_exception, match=expected_message):
+ InstrumentModel(**kwargs)
+
+ def test_resolution_model_setter_calls_update(self, instrument_model, resolution_model):
+ # WHEN
+ instrument_model._on_resolution_model_change = MagicMock()
+
+ # THEN
+ instrument_model.resolution_model = resolution_model
+
+ # EXPECT
+ assert instrument_model._resolution_model is resolution_model
+ instrument_model._on_resolution_model_change.assert_called_once()
+
+ def test_resolution_model_setter_raises(self, instrument_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ TypeError,
+ match='resolution_model must be a ResolutionModel',
+ ):
+ instrument_model.resolution_model = 'invalid_model'
+
+ def test_background_model_setter_calls_update(self, instrument_model, background_model):
+ # WHEN
+ instrument_model._on_background_model_change = MagicMock()
+
+ # THEN
+ instrument_model.background_model = background_model
+
+ # EXPECT
+ assert instrument_model._background_model is background_model
+ instrument_model._on_background_model_change.assert_called_once()
+
+ def test_background_model_setter_raises(self, instrument_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ TypeError,
+ match='background_model must be a BackgroundModel',
+ ):
+ instrument_model.background_model = 123
+
+ def test_Q_setter(self, instrument_model):
+ "Test that Q setter calls the appropriate methods."
+ # WHEN
+ new_Q = np.array([4.0, 5.0, 6.0])
+
+ instrument_model._on_Q_change = MagicMock()
+
+ # THEN EXPECT
+ with patch(
+ 'easydynamics.sample_model.instrument_model._validate_and_convert_Q',
+ return_value=new_Q,
+ ) as mock_validate:
+ instrument_model.Q = new_Q
+
+ np.testing.assert_array_equal(instrument_model.Q, new_Q)
+ mock_validate.assert_called_once_with(new_Q)
+ instrument_model._on_Q_change.assert_called_once()
+
+ def test_unit_setter_raises(self, instrument_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ AttributeError,
+ match='Unit is read-only. Use convert_unit to change the unit between allowed types ',
+ ):
+ instrument_model.unit = 'meV'
+
+ def test_energy_offset_setter(self, instrument_model):
+ # WHEN
+ instrument_model._on_energy_offset_change = MagicMock()
+
+ # THEN
+ instrument_model.energy_offset = 1.0
+
+ # EXPECT
+ assert instrument_model.energy_offset.value == 1.0
+ instrument_model._on_energy_offset_change.assert_called_once()
+
+ def test_energy_offset_setter_raises(self, instrument_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ TypeError,
+ match='energy_offset must be a number',
+ ):
+ instrument_model.energy_offset = 'invalid_offset'
+
+ def test_convert_unit_calls_all_children(self, instrument_model):
+ # WHEN
+ new_unit = 'eV'
+
+ # THEN
+ # Mock downstream convert_unit calls
+ instrument_model._background_model.convert_unit = MagicMock()
+ instrument_model._resolution_model.convert_unit = MagicMock()
+ instrument_model._energy_offset.convert_unit = MagicMock()
+ for offset in instrument_model._energy_offsets:
+ offset.convert_unit = MagicMock()
+
+ with patch(
+ 'easydynamics.sample_model.instrument_model._validate_unit',
+ return_value=new_unit,
+ ) as mock_validate:
+ instrument_model.convert_unit(new_unit)
+
+ # EXPECT
+ mock_validate.assert_called_once_with(new_unit)
+
+ instrument_model._background_model.convert_unit.assert_called_once_with(new_unit)
+ instrument_model._resolution_model.convert_unit.assert_called_once_with(new_unit)
+ instrument_model._energy_offset.convert_unit.assert_called_once_with(new_unit)
+
+ for offset in instrument_model._energy_offsets:
+ offset.convert_unit.assert_called_once_with(new_unit)
+
+ # final state
+ assert instrument_model.unit == new_unit
+
+ def test_convert_unit_None_raises(self, instrument_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ ValueError,
+ match=' must be a valid unit',
+ ):
+ instrument_model.convert_unit(None)
+
+ def test_fix_resolution_parameters(self, instrument_model):
+ # WHEN
+ instrument_model.resolution_model.fix_all_parameters = MagicMock()
+
+ # THEN
+ instrument_model.fix_resolution_parameters()
+
+ # EXPECT
+ instrument_model.resolution_model.fix_all_parameters.assert_called_once()
+
+ def test_free_all_resolution_parameters(self, instrument_model):
+ # WHEN
+ instrument_model.resolution_model.free_all_parameters = MagicMock()
+
+ # THEN
+ instrument_model.free_resolution_parameters()
+
+ # EXPECT
+ instrument_model.resolution_model.free_all_parameters.assert_called_once()
+
+ def test_get_all_variables(self, instrument_model):
+ # WHEN
+ all_vars = instrument_model.get_all_variables()
+
+ # THEN
+ expected_var_names = {
+ 'energy_offset',
+ 'Polynomial_c0',
+ 'Polynomial_c1',
+ 'Gaussian area',
+ 'Gaussian center',
+ 'Gaussian width',
+ }
+
+ retrieved_var_names = {var.name for var in all_vars}
+
+ assert expected_var_names == retrieved_var_names
+ assert len(all_vars) == 18
+
+ def test_get_all_variables_no_Q(self, instrument_model):
+ # WHEN
+ instrument_model.Q = None
+
+ # THEN
+ all_vars = instrument_model.get_all_variables()
+
+ # EXPECT
+ assert all_vars == []
+
+ def test_get_all_variables_with_Q_index(self, instrument_model):
+ # WHEN
+ all_vars = instrument_model.get_all_variables(Q_index=1)
+
+ # THEN
+ expected_var_names = {
+ 'energy_offset',
+ 'Polynomial_c0',
+ 'Polynomial_c1',
+ 'Gaussian area',
+ 'Gaussian center',
+ 'Gaussian width',
+ }
+
+ retrieved_var_names = {var.name for var in all_vars}
+
+ assert expected_var_names == retrieved_var_names
+ assert len(all_vars) == 6
+
+ def test_get_all_variables_with_invalid_Q_index_raises(self, instrument_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ IndexError,
+ match='Q_index 5 is out of bounds',
+ ):
+ instrument_model.get_all_variables(Q_index=5)
+
+ def test_get_all_variables_with_nonint_Q_index_raises(self, instrument_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ TypeError,
+ match='Q_index must be an int or None, got str',
+ ):
+ instrument_model.get_all_variables(Q_index='invalid_index')
+
+ def test_generate_energy_offsets_Q_none(self, instrument_model):
+ # WHEN
+ instrument_model._Q = None
+
+ # THEN
+ instrument_model._generate_energy_offsets()
+
+ # EXPECT
+ assert instrument_model._energy_offsets == []
+
+ def test_generate_energy_offsets(self, instrument_model):
+ # WHEN
+ instrument_model._Q = np.array([1.0, 2.0, 3.0, 4.0])
+
+ # THEN
+ instrument_model._generate_energy_offsets()
+
+ # EXPECT
+ assert len(instrument_model._energy_offsets) == 4
+ for offset in instrument_model._energy_offsets:
+ assert offset.name == 'energy_offset'
+ assert offset.unit == instrument_model.unit
+ assert offset.value == instrument_model.energy_offset.value
+
+ def test_on_Q_change(self, instrument_model):
+ # WHEN
+ instrument_model._generate_energy_offsets = MagicMock()
+ new_Q = np.array([1.0, 2.0, 3.0, 4.0])
+
+ # THEN
+ instrument_model._Q = new_Q
+ instrument_model._on_Q_change()
+
+ # EXPECT
+ instrument_model._generate_energy_offsets.assert_called_once()
+ instrument_model._background_model.Q = new_Q
+ instrument_model._resolution_model.Q = new_Q
+
+ def test_on_energy_offset_change(self, instrument_model):
+ # WHEN
+ new_offset = 2.0
+
+ # THEN
+ instrument_model._energy_offset.value = new_offset
+ instrument_model._on_energy_offset_change()
+
+ # EXPECT
+ for offset in instrument_model._energy_offsets:
+ assert offset.value == new_offset
+
+ def test_on_resolution_model_change(self, instrument_model, resolution_model):
+ # WHEN
+ new_resolution_model = resolution_model
+
+ # THEN
+ instrument_model._resolution_model = new_resolution_model
+ instrument_model._on_resolution_model_change()
+
+ # EXPECT
+ assert instrument_model._resolution_model is new_resolution_model
+
+ def test_on_background_model_change(self, instrument_model, background_model):
+ # WHEN
+ new_background_model = background_model
+
+ # THEN
+ instrument_model._background_model = new_background_model
+ instrument_model._on_background_model_change()
+
+ # EXPECT
+ assert instrument_model._background_model is new_background_model
+
+ def test_repr_contains_expected_fields(self, instrument_model):
+ # WHEN THEN
+ repr_str = repr(instrument_model)
+
+ # EXPECT
+ assert repr_str.startswith('InstrumentModel(')
+ assert f'unique_name={instrument_model.unique_name!r}' in repr_str
+ assert f'unit={instrument_model.unit}' in repr_str
+ assert 'Q_len=3' in repr_str
+ assert f'resolution_model={instrument_model._resolution_model!r}' in repr_str
+ assert f'background_model={instrument_model._background_model!r}' in repr_str
+ assert repr_str.endswith(')')
diff --git a/tests/unit/easydynamics/sample_model/test_model_base.py b/tests/unit/easydynamics/sample_model/test_model_base.py
index 31feb66a..0a5ec3f5 100644
--- a/tests/unit/easydynamics/sample_model/test_model_base.py
+++ b/tests/unit/easydynamics/sample_model/test_model_base.py
@@ -14,28 +14,28 @@
class TestModelBase:
@pytest.fixture
- def model_base(self, reset_global_object):
+ def model_base(self):
component1 = Gaussian(
- display_name='TestGaussian1',
+ display_name="TestGaussian1",
area=1.0,
center=0.0,
width=1.0,
- unit='meV',
+ unit="meV",
)
component2 = Lorentzian(
- display_name='TestLorentzian1',
+ display_name="TestLorentzian1",
area=2.0,
center=1.0,
width=0.5,
- unit='meV',
+ unit="meV",
)
component_collection = ComponentCollection()
component_collection.append_component(component1)
component_collection.append_component(component2)
model_base = ModelBase(
- display_name='InitModel',
+ display_name="InitModel",
components=component_collection,
- unit='meV',
+ unit="meV",
Q=np.array([1.0, 2.0, 3.0]),
)
@@ -46,8 +46,8 @@ def test_init(self, model_base):
model = model_base
# EXPECT
- assert model.display_name == 'InitModel'
- assert model.unit == 'meV'
+ assert model.display_name == "InitModel"
+ assert model.unit == "meV"
assert len(model.components) == 2
np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
@@ -55,9 +55,9 @@ def test_init_raises_with_invalid_components(self):
# WHEN / THEN / EXPECT
with pytest.raises(
TypeError,
- match='Components must be ',
+ match="Components must be ",
):
- ModelBase(components='invalid_component')
+ ModelBase(components="invalid_component")
def test_evaluate_calls_all_component_collections(self, model_base):
# WHEN
@@ -88,7 +88,7 @@ def test_evaluate_no_component_collections_raises(self, model_base):
model_base._component_collections = []
# THEN / EXPECT
- with pytest.raises(ValueError, match='No components'):
+ with pytest.raises(ValueError, match="No components"):
model_base.evaluate(x)
def test_generate_component_collections_with_Q(self, model_base):
@@ -101,17 +101,39 @@ def test_generate_component_collections_with_Q(self, model_base):
assert isinstance(collection, ComponentCollection)
assert len(collection.components) == 2
assert isinstance(collection.components[0], Gaussian)
- assert collection.components[0].display_name == 'TestGaussian1'
+ assert collection.components[0].display_name == "TestGaussian1"
assert isinstance(collection.components[1], Lorentzian)
- assert collection.components[1].display_name == 'TestLorentzian1'
+ assert collection.components[1].display_name == "TestLorentzian1"
- def test_generate_component_collections_without_Q_warns(self, model_base):
+ def test_fix_free_all_parameters(self, model_base):
# WHEN
- model_base._Q = None
+ model_base.fix_all_parameters()
- # THEN / EXPECT
- with pytest.warns(UserWarning, match='Q is not set'):
- model_base._generate_component_collections()
+ # THEN
+ for par in model_base.get_all_variables():
+ assert par.fixed is True
+
+ # WHEN
+ model_base.free_all_parameters()
+
+ # THEN
+ for par in model_base.get_all_variables():
+ assert par.fixed is False
+
+ def test_fix_free_all_parameters(self, model_base):
+ # WHEN
+ model_base.fix_all_parameters()
+
+ # THEN
+ for par in model_base.get_all_variables():
+ assert par.fixed is True
+
+ # WHEN
+ model_base.free_all_parameters()
+
+ # THEN
+ for par in model_base.get_all_variables():
+ assert par.fixed is False
def test_get_all_variables(self, model_base):
# WHEN
@@ -119,12 +141,12 @@ def test_get_all_variables(self, model_base):
# THEN
expected_var_display_names = {
- 'TestGaussian1 area',
- 'TestGaussian1 center',
- 'TestGaussian1 width',
- 'TestLorentzian1 area',
- 'TestLorentzian1 center',
- 'TestLorentzian1 width',
+ "TestGaussian1 area",
+ "TestGaussian1 center",
+ "TestGaussian1 width",
+ "TestLorentzian1 area",
+ "TestLorentzian1 center",
+ "TestLorentzian1 width",
}
retrieved_var_display_names = {var.display_name for var in all_vars}
@@ -132,9 +154,44 @@ def test_get_all_variables(self, model_base):
assert expected_var_display_names == retrieved_var_display_names
assert len(all_vars) == 18
+ def test_get_all_variables_with_Q_index(self, model_base):
+ # WHEN
+ all_vars = model_base.get_all_variables(Q_index=1)
+
+ # THEN
+ expected_var_display_names = {
+ "TestGaussian1 area",
+ "TestGaussian1 center",
+ "TestGaussian1 width",
+ "TestLorentzian1 area",
+ "TestLorentzian1 center",
+ "TestLorentzian1 width",
+ }
+
+ retrieved_var_display_names = {var.display_name for var in all_vars}
+
+ assert expected_var_display_names == retrieved_var_display_names
+ assert len(all_vars) == 6
+
+ def test_get_all_variables_with_invalid_Q_index_raises(self, model_base):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ IndexError,
+ match="Q_index 5 is out of bounds for component collections of length 3",
+ ):
+ model_base.get_all_variables(Q_index=5)
+
+ def test_get_all_variables_with_nonint_Q_index_raises(self, model_base):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(
+ TypeError,
+ match="Q_index must be an int or None, got str",
+ ):
+ model_base.get_all_variables(Q_index="invalid_index")
+
def test_append_and_remove_and_clear_component(self, model_base):
# WHEN
- new_component = Gaussian(unique_name='NewGaussian')
+ new_component = Gaussian(unique_name="NewGaussian")
# THEN
model_base.append_component(new_component)
@@ -144,7 +201,7 @@ def test_append_and_remove_and_clear_component(self, model_base):
assert model_base.components[-1] is new_component
# THEN
- model_base.remove_component('NewGaussian')
+ model_base.remove_component("NewGaussian")
# EXPECT
assert len(model_base.components) == 2
@@ -173,38 +230,40 @@ def test_append_component_collection(self, model_base):
def test_append_component_invalid_type_raises(self, model_base):
# WHEN / THEN / EXPECT
- with pytest.raises(TypeError, match=' must be a ModelComponent or ComponentCollection'):
- model_base.append_component('invalid_component')
+ with pytest.raises(
+ TypeError, match=" must be a ModelComponent or ComponentCollection"
+ ):
+ model_base.append_component("invalid_component")
def test_unit_property(self, model_base):
# WHEN
unit = model_base.unit
# THEN / EXPECT
- assert unit == 'meV'
+ assert unit == "meV"
def test_unit_setter_raises(self, model_base):
# WHEN / THEN / EXPECT
- with pytest.raises(AttributeError, match='Use convert_unit to change '):
- model_base.unit = 'K'
+ with pytest.raises(AttributeError, match="Use convert_unit to change "):
+ model_base.unit = "K"
def test_convert_unit(self, model_base):
# WHEN
- model_base.convert_unit('eV')
+ model_base.convert_unit("eV")
# THEN / EXPECT
- assert model_base.unit == 'eV'
+ assert model_base.unit == "eV"
for component in model_base.components:
- assert component.unit == 'eV'
+ assert component.unit == "eV"
def test_convert_unit_invalid_raises(self, model_base):
# WHEN / THEN / EXPECT
with pytest.raises(Exception):
- model_base.convert_unit('invalid_unit')
+ model_base.convert_unit("invalid_unit")
def test_components_setter(self, model_base):
# WHEN
- new_component = Lorentzian(unique_name='NewLorentzian')
+ new_component = Lorentzian(unique_name="NewLorentzian")
model_base.components = new_component
# THEN / EXPECT
@@ -230,9 +289,9 @@ def test_components_setter_invalid_raises(self, model_base):
# WHEN / THEN / EXPECT
with pytest.raises(
TypeError,
- match='Components must be ',
+ match="Components must be ",
):
- model_base.components = 'invalid_component'
+ model_base.components = "invalid_component"
def test_Q_setter(self, model_base):
# WHEN
@@ -247,7 +306,7 @@ def test_repr(self, model_base):
repr_str = repr(model_base)
# THEN / EXPECT
- assert 'unique_name' in repr_str
- assert 'unit' in repr_str
- assert 'Q = ' in repr_str
- assert 'components = ' in repr_str
+ assert "unique_name" in repr_str
+ assert "unit" in repr_str
+ assert "Q = " in repr_str
+ assert "components = " in repr_str
diff --git a/tests/unit/easydynamics/sample_model/test_resolution_model.py b/tests/unit/easydynamics/sample_model/test_resolution_model.py
index 120cbf9b..d45eee19 100644
--- a/tests/unit/easydynamics/sample_model/test_resolution_model.py
+++ b/tests/unit/easydynamics/sample_model/test_resolution_model.py
@@ -89,7 +89,7 @@ def test_init_raises_with_invalid_components(self, invalid_component, expected_e
collection.append_component(invalid_component)
ResolutionModel(components=collection)
- def test_append_and_remove_and_clear_component(self, resolution_model, reset_global_object):
+ def test_append_and_remove_and_clear_component(self, resolution_model):
# WHEN
new_component = Gaussian(unique_name='NewGaussian')
@@ -136,9 +136,7 @@ def test_append_component_collection(self, resolution_model):
],
ids=['DeltaFunction', 'Polynomial'],
)
- def test_append_invalid_component_type_raises(
- self, resolution_model, invalid_component, reset_global_object
- ):
+ def test_append_invalid_component_type_raises(self, resolution_model, invalid_component):
# WHEN / THEN / EXPECT
# appending a single component
with pytest.raises(
diff --git a/tests/unit/easydynamics/sample_model/test_sample_model.py b/tests/unit/easydynamics/sample_model/test_sample_model.py
index 8a383ee3..e5f7a9a7 100644
--- a/tests/unit/easydynamics/sample_model/test_sample_model.py
+++ b/tests/unit/easydynamics/sample_model/test_sample_model.py
@@ -22,6 +22,7 @@ class TestSampleModel:
def sample_model(self):
component1 = Gaussian(
display_name='TestGaussian1',
+ unique_name='TestGaussian1',
area=1.0,
center=0.0,
width=1.0,
@@ -29,6 +30,7 @@ def sample_model(self):
)
component2 = Lorentzian(
display_name='TestLorentzian1',
+ unique_name='TestLorentzian1',
area=2.0,
center=1.0,
width=0.5,
@@ -38,7 +40,9 @@ def sample_model(self):
component_collection.append_component(component1)
component_collection.append_component(component2)
- diffusion_model = BrownianTranslationalDiffusion(display_name='DiffusionModel')
+ diffusion_model = BrownianTranslationalDiffusion(
+ display_name='DiffusionModel', unique_name='DiffusionModel'
+ )
sample_model = SampleModel(
display_name='InitModel',
@@ -52,6 +56,7 @@ def sample_model(self):
return sample_model
def test_init(self, sample_model):
+
# WHEN THEN
model = sample_model
diff --git a/tests/unit/easydynamics/utils/test_utils.py b/tests/unit/easydynamics/utils/test_utils.py
index 97a6c36c..76c967c8 100644
--- a/tests/unit/easydynamics/utils/test_utils.py
+++ b/tests/unit/easydynamics/utils/test_utils.py
@@ -5,13 +5,14 @@
import pytest
import scipp as sc
+from easydynamics.utils.utils import _in_notebook
from easydynamics.utils.utils import _validate_and_convert_Q
from easydynamics.utils.utils import _validate_unit
class TestValidateAndConvertQ:
@pytest.mark.parametrize(
- 'Q_input, expected',
+ "Q_input, expected",
[
(1.0, np.array([1.0])),
(2, np.array([2])),
@@ -29,7 +30,7 @@ def test_validate_and_convert_Q_numeric_and_array(self, Q_input, expected):
def test_validate_and_convert_Q_scipp_variable(self):
# WHEN
- Q = sc.array(dims=['Q'], values=[1.0, 2.0], unit='1/angstrom')
+ Q = sc.array(dims=["Q"], values=[1.0, 2.0], unit="1/angstrom")
# THEN
result = _validate_and_convert_Q(Q)
@@ -43,29 +44,29 @@ def test_validate_and_convert_Q_none(self):
assert _validate_and_convert_Q(None) is None
@pytest.mark.parametrize(
- 'Q_input',
+ "Q_input",
[
- 'invalid',
- {'a': 1},
+ "invalid",
+ {"a": 1},
(1, 2),
object(),
],
)
def test_validate_and_convert_Q_invalid_type(self, Q_input):
# WHEN THEN EXPECT
- with pytest.raises(TypeError, match='Q must be a number'):
+ with pytest.raises(TypeError, match="Q must be a number"):
_validate_and_convert_Q(Q_input)
def test_validate_and_convert_Q_ndarray_wrong_dim(self):
# WHEN THEN
Q = np.array([[1.0, 2.0]])
# EXPECT
- with pytest.raises(ValueError, match='Q must be a 1-dimensional array'):
+ with pytest.raises(ValueError, match="Q must be a 1-dimensional array"):
_validate_and_convert_Q(Q)
def test_validate_and_convert_Q_scipp_wrong_dims(self):
# WHEN THEN
- Q = sc.array(dims=['x'], values=[1.0, 2.0], unit='1/angstrom')
+ Q = sc.array(dims=["x"], values=[1.0, 2.0], unit="1/angstrom")
# EXPECT
with pytest.raises(ValueError, match="single dimension named 'Q'"):
@@ -77,12 +78,12 @@ def test_validate_and_convert_Q_scipp_wrong_dims(self):
class TestValidateUnit:
@pytest.mark.parametrize(
- 'unit_input',
+ "unit_input",
[
None,
- '1/angstrom',
- 'meV',
- sc.Unit('meV'),
+ "1/angstrom",
+ "meV",
+ sc.Unit("meV"),
],
)
def test_validate_unit_valid(self, unit_input):
@@ -94,13 +95,13 @@ def test_validate_unit_valid(self, unit_input):
assert isinstance(unit, sc.Unit)
def test_validate_unit_string_conversion(self):
- unit = _validate_unit('meV')
+ unit = _validate_unit("meV")
assert isinstance(unit, sc.Unit)
- assert unit == sc.Unit('meV')
+ assert unit == sc.Unit("meV")
@pytest.mark.parametrize(
- 'unit_input',
+ "unit_input",
[
123,
45.6,
@@ -110,5 +111,69 @@ def test_validate_unit_string_conversion(self):
],
)
def test_validate_unit_invalid_type(self, unit_input):
- with pytest.raises(TypeError, match='unit must be None, a string, or a scipp Unit'):
+ with pytest.raises(
+ TypeError, match="unit must be None, a string, or a scipp Unit"
+ ):
_validate_unit(unit_input)
+
+
+# -----------------------------
+
+
+class TestInNotebook:
+
+ def test_in_notebook_returns_true_for_jupyter(self, monkeypatch):
+ """Should return True when IPython shell is
+ ZMQInteractiveShell (Jupyter)."""
+
+ # WHEN
+ class ZMQInteractiveShell:
+ __name__ = "ZMQInteractiveShell"
+
+ # THEN
+ monkeypatch.setattr("IPython.get_ipython", lambda: ZMQInteractiveShell())
+
+ # EXPECT
+ assert _in_notebook() is True
+
+ def test_in_notebook_returns_false_for_terminal_ipython(self, monkeypatch):
+ """Should return False when IPython shell is
+ TerminalInteractiveShell."""
+
+ # WHEN
+ class TerminalInteractiveShell:
+ __name__ = "TerminalInteractiveShell"
+
+ # THEN
+
+ monkeypatch.setattr("IPython.get_ipython", lambda: TerminalInteractiveShell())
+
+ # EXPECT
+ assert _in_notebook() is False
+
+ def test_in_notebook_returns_false_for_unknown_shell(self, monkeypatch):
+ """Should return False when IPython shell type is
+ unrecognized."""
+
+ # WHEN
+ class UnknownShell:
+ __name__ = "UnknownShell"
+
+ # THEN
+ monkeypatch.setattr("IPython.get_ipython", lambda: UnknownShell())
+ # EXPECT
+ assert _in_notebook() is False
+
+ def test_in_notebook_returns_false_when_no_ipython(self, monkeypatch):
+ """Should return False when IPython is not installed or
+ available."""
+
+ # WHEN
+ def raise_import_error(*args, **kwargs):
+ raise ImportError
+
+ # THEN
+ monkeypatch.setattr("builtins.__import__", raise_import_error)
+
+ # EXPECT
+ assert _in_notebook() is False
diff --git a/tools/update_github_labels.py b/tools/update_github_labels.py
new file mode 100644
index 00000000..a18043d0
--- /dev/null
+++ b/tools/update_github_labels.py
@@ -0,0 +1,254 @@
+"""
+Set/update GitHub labels for current or specified easyscience
+repository.
+
+Requires:
+ - gh CLI installed
+ - gh auth login completed
+
+Usage:
+ python update_github_labels.py
+ python update_github_labels.py --dry-run
+ python update_github_labels.py --repo easyscience/my-repo
+ python update_github_labels.py --repo easyscience/my-repo --dry-run
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import shlex
+import subprocess
+import sys
+from dataclasses import dataclass
+from typing import Iterable
+
+
+EASYSCIENCE_ORG = 'easyscience'
+
+
+# --- Label definitions -----------------------------------------------------------
+
+BASIC_GITHUB_LABELS = [
+ 'bug',
+ 'documentation',
+ 'duplicate',
+ 'enhancement',
+ 'good first issue',
+ 'help wanted',
+ 'invalid',
+ 'question',
+ 'wontfix',
+]
+
+NEW_BASIC_LABEL_NAMES = [
+ '[scope] bug',
+ '[scope] documentation',
+ '[maintainer] duplicate',
+ '[scope] enhancement',
+ '[maintainer] good first issue',
+ '[maintainer] help wanted',
+ '[maintainer] invalid',
+ '[maintainer] question',
+ '[maintainer] wontfix',
+]
+
+SCOPE_LABELS = [
+ ('bug', 'Bug report or fix (major.minor.PATCH)'),
+ ('documentation', 'Documentation only changes (major.minor.patch.POST)'),
+ ('enhancement', 'Adds/improves features (major.MINOR.patch)'),
+ ('maintenance', 'Code/tooling cleanup, no feature or bugfix (major.minor.PATCH)'),
+ ('significant', 'Breaking or major changes (MAJOR.minor.patch)'),
+ ('โ ๏ธ label needed', 'Automatically added to issues and PRs without a [scope] label'),
+]
+
+MAINTAINER_LABELS = [
+ ('duplicate', 'Already reported or submitted'),
+ ('good first issue', 'Good entry-level issue for newcomers'),
+ ('help wanted', 'Needs additional help to resolve or implement'),
+ ('invalid', 'Invalid, incorrect or outdated'),
+ ('question', 'Needs clarification, discussion, or more information'),
+ ('wontfix', 'Will not be fixed or continued'),
+]
+
+PRIORITY_LABELS = [
+ ('lowest', 'Very low urgency'),
+ ('low', 'Low importance'),
+ ('medium', 'Normal/default priority'),
+ ('high', 'Should be prioritized soon'),
+ ('highest', 'Urgent. Needs attention ASAP'),
+ ('โ ๏ธ label needed', 'Automatically added to issues without a [priority] label'),
+]
+
+BOT_LABEL = (
+ '[bot] pull request',
+ 'Automated release PR. Excluded from changelog/versioning',
+)
+
+COLORS = {
+ 'scope': 'd73a4a',
+ 'maintainer': '0e8a16',
+ 'priority': 'fbca04',
+ 'bot': '5319e7',
+}
+
+
+# --- Helpers --------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class CmdResult:
+ returncode: int
+ stdout: str
+ stderr: str
+
+
+def run_cmd(args: list[str], *, dry_run: bool, check: bool = True) -> CmdResult:
+ """Run a command (or print it in dry-run mode)."""
+ cmd_str = ' '.join(shlex.quote(a) for a in args)
+
+ if dry_run:
+ print(f'{cmd_str}')
+ return CmdResult(0, '', '')
+
+ proc = subprocess.run(
+ args,
+ text=True,
+ capture_output=True,
+ )
+ res = CmdResult(proc.returncode, proc.stdout.strip(), proc.stderr.strip())
+
+ if check and proc.returncode != 0:
+ raise RuntimeError(f'Command failed ({proc.returncode}): {cmd_str}\n{res.stderr}')
+
+ return res
+
+
+def get_current_repo_name_with_owner() -> str:
+ res = subprocess.run(
+ ['gh', 'repo', 'view', '--json', 'nameWithOwner'],
+ text=True,
+ capture_output=True,
+ check=True,
+ )
+ data = json.loads(res.stdout)
+ nwo = data.get('nameWithOwner')
+ if not nwo or '/' not in nwo:
+ raise RuntimeError('Could not determine current repository name')
+ return nwo
+
+
+def try_rename_label(repo: str, old: str, new: str, *, dry_run: bool) -> None:
+ try:
+ run_cmd(
+ ['gh', 'label', 'edit', old, '--name', new, '--repo', repo],
+ dry_run=dry_run,
+ )
+ print(f'Rename: {old!r} โ {new!r}')
+ except Exception:
+ print(f'Skip rename (label not found): {old!r}')
+
+
+def upsert_label(
+ repo: str,
+ name: str,
+ color: str,
+ description: str,
+ *,
+ dry_run: bool,
+) -> None:
+ run_cmd(
+ [
+ 'gh',
+ 'label',
+ 'create',
+ name,
+ '--color',
+ color,
+ '--description',
+ description,
+ '--force',
+ '--repo',
+ repo,
+ ],
+ dry_run=dry_run,
+ )
+ print(f'Upsert label: {name!r}')
+
+
+def upsert_group(
+ repo: str,
+ prefix: str,
+ color: str,
+ items: Iterable[tuple[str, str]],
+ *,
+ dry_run: bool,
+) -> None:
+ for short, desc in items:
+ upsert_label(
+ repo,
+ f'[{prefix}] {short}',
+ color,
+ desc,
+ dry_run=dry_run,
+ )
+
+
+# --- Main -----------------------------------------------------------------------
+
+
+def main() -> int:
+ parser = argparse.ArgumentParser(description='Sync GitHub labels for easyscience repos')
+ parser.add_argument(
+ '--repo',
+ help='Target repository in the form easyscience/',
+ )
+ parser.add_argument(
+ '--dry-run',
+ action='store_true',
+ help='Print actions without applying changes',
+ )
+ args = parser.parse_args()
+
+ if args.repo:
+ repo = args.repo
+ else:
+ repo = get_current_repo_name_with_owner()
+
+ org, _ = repo.split('/', 1)
+
+ if org.lower() != EASYSCIENCE_ORG:
+ print(
+ f"Refusing to run: repository {repo!r} is not under '{EASYSCIENCE_ORG}'.",
+ file=sys.stderr,
+ )
+ return 2
+
+ print(f'Target repository: {repo}')
+ if args.dry_run:
+ print('Running in DRY-RUN mode (no changes will be made)\n')
+
+ # 1) Rename basic labels
+ for old, new in zip(BASIC_GITHUB_LABELS, NEW_BASIC_LABEL_NAMES, strict=True):
+ try_rename_label(repo, old, new, dry_run=args.dry_run)
+
+ # 2) Scope / Maintainer / Priority groups
+ upsert_group(repo, 'scope', COLORS['scope'], SCOPE_LABELS, dry_run=args.dry_run)
+ upsert_group(repo, 'maintainer', COLORS['maintainer'], MAINTAINER_LABELS, dry_run=args.dry_run)
+ upsert_group(repo, 'priority', COLORS['priority'], PRIORITY_LABELS, dry_run=args.dry_run)
+
+ # 3) Bot label
+ upsert_label(
+ repo,
+ BOT_LABEL[0],
+ COLORS['bot'],
+ BOT_LABEL[1],
+ dry_run=args.dry_run,
+ )
+
+ print('\nDone.')
+ return 0
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())