diff --git a/docs/notebooks/optimization_leaf_dynamics.ipynb b/docs/notebooks/optimization_leaf_dynamics.ipynb index df0cd1c..b54c9b8 100644 --- a/docs/notebooks/optimization_leaf_dynamics.ipynb +++ b/docs/notebooks/optimization_leaf_dynamics.ipynb @@ -50,10 +50,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "e4049fea-1d05-41f1-bf9d-f030ae83a324", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: diffwofost in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (0.2.0)\n", + "Requirement already satisfied: torch in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from diffwofost) (2.9.0)\n", + "Requirement already satisfied: pcse in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from diffwofost) (6.0.9)\n", + "Requirement already satisfied: SQLAlchemy<2.0,>=1.3.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (1.4.54)\n", + "Requirement already satisfied: PyYAML>=5.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (6.0.3)\n", + "Requirement already satisfied: openpyxl>=3.0.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (3.1.5)\n", + "Requirement already satisfied: requests>=2.0.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (2.32.5)\n", + "Requirement already satisfied: pandas>=0.25 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (2.3.3)\n", + "Requirement already satisfied: traitlets-pcse==5.0.0.dev in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (5.0.0.dev0)\n", + "Requirement already satisfied: dotmap>=1.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (1.3.30)\n", + "Requirement already satisfied: ipython_genutils in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (0.2.0)\n", + "Requirement already satisfied: six in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (1.17.0)\n", + "Requirement already satisfied: decorator in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (5.2.1)\n", + "Requirement already satisfied: filelock in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (4.15.0)\n", + "Requirement already satisfied: setuptools in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (80.9.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (1.14.0)\n", + "Requirement already satisfied: networkx>=2.5.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.5)\n", + "Requirement already satisfied: jinja2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.1.6)\n", + "Requirement already satisfied: fsspec>=0.8.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (2025.9.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.93)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (11.3.3.83)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (10.3.9.90)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (11.7.3.90)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.5.8.93)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (2.27.5)\n", + "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.3.20)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.93)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (1.13.1.3)\n", + "Requirement already satisfied: triton==3.5.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.5.0)\n", + "Requirement already satisfied: et-xmlfile in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from openpyxl>=3.0.0->pcse->diffwofost) (2.0.0)\n", + "Requirement already satisfied: numpy>=1.26.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2.3.4)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2025.2)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (2025.10.5)\n", + "Requirement already satisfied: greenlet!=0.4.17 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from SQLAlchemy<2.0,>=1.3.0->pcse->diffwofost) (3.2.4)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from sympy>=1.13.3->torch->diffwofost) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from jinja2->torch->diffwofost) (3.0.3)\n" + ] + } + ], "source": [ "# install diffwofost\n", "!pip install diffwofost" @@ -61,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "21731653-3976-4bb9-b83b-b11d78211700", "metadata": {}, "outputs": [], @@ -72,7 +127,7 @@ "import numpy\n", "import yaml\n", "from pathlib import Path\n", - "from diffwofost.physical_models.config import Configuration\n", + "from diffwofost.physical_models.config import Configuration, ComputeConfig\n", "from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics\n", "from diffwofost.physical_models.utils import EngineTestHelper\n", "from diffwofost.physical_models.utils import prepare_engine_input\n", @@ -81,19 +136,7 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "9c2ca761-22b0-4707-9f55-619f62617d12", - "metadata": {}, - "outputs": [], - "source": [ - "# --- run on CPU ------\n", - "from diffwofost.physical_models.config import ComputeConfig\n", - "ComputeConfig.set_device('cpu')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "82a1ef6b-336e-4902-8bd1-2a1ed2020f9d", "metadata": {}, "outputs": [], @@ -123,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "0233a048-e5a2-4249-887d-35a37284769c", "metadata": {}, "outputs": [ @@ -147,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "5a459489-bfcb-4ad6-9102-1b6be5edeb52", "metadata": {}, "outputs": [], @@ -158,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "9f3105fb-4fbe-4405-9fd4-e8255b4b119e", "metadata": {}, "outputs": [], @@ -171,7 +214,9 @@ "\n", "expected_results = test_data[\"ModelResults\"]\n", "expected_lai_twlv = torch.tensor(\n", - " [[float(item[\"LAI\"]), float(item[\"TWLV\"])] for item in expected_results], dtype=torch.float32\n", + " [[float(item[\"LAI\"]), float(item[\"TWLV\"])] for item in expected_results],\n", + " dtype=ComputeConfig.get_dtype(),\n", + " device=ComputeConfig.get_device(),\n", ").unsqueeze(0) # shape: [1, time_steps, 2]\n", "\n", "# ---- dont change this: in this config file we specified the diffrentiable version of leaf_dynamics ----\n", @@ -197,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "e4610238-de0d-42cf-9689-3c074eb2cc0e", "metadata": {}, "outputs": [], @@ -217,10 +262,17 @@ " init_norm = (init_value - low) / (high - low)\n", "\n", " # Parameter in raw logit space\n", - " self.raw = torch.nn.Parameter(torch.logit(torch.tensor(init_norm, dtype=torch.float32), eps=1e-6))\n", + " self.raw = torch.nn.Parameter(\n", + " torch.logit(\n", + " torch.tensor(\n", + " init_norm, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device()\n", + " ),\n", + " eps=1e-6,\n", + " )\n", + " )\n", "\n", " def forward(self):\n", - " return self.low + (self.high - self.low) * torch.sigmoid(self.raw)\n" + " return self.low + (self.high - self.low) * torch.sigmoid(self.raw)" ] }, { @@ -234,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "36dd6463-4812-41c0-b2bf-d4769df1136f", "metadata": {}, "outputs": [], @@ -282,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "3d34c3e8-a8d7-4bc9-94ed-bd2e0234e95c", "metadata": {}, "outputs": [], @@ -299,7 +351,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "78d797f5-4ac4-4380-85f3-6622a7b0f7fb", "metadata": {}, "outputs": [ @@ -312,22 +364,22 @@ "Step 2, Loss 0.2117, TDWI 0.4727, SPAN 28.3303\n", "Step 3, Loss 0.1790, TDWI 0.4962, SPAN 29.5140\n", "Step 4, Loss 0.1502, TDWI 0.5190, SPAN 30.7290\n", - "Step 5, Loss 0.1148, TDWI 0.5351, SPAN 31.8314\n", - "Step 6, Loss 0.0873, TDWI 0.5459, SPAN 32.8227\n", - "Step 7, Loss 0.0633, TDWI 0.5525, SPAN 33.6916\n", - "Step 8, Loss 0.0380, TDWI 0.5559, SPAN 34.4434\n", - "Step 9, Loss 0.0164, TDWI 0.5564, SPAN 35.0537\n", - "Step 10, Loss 0.0016, TDWI 0.5543, SPAN 35.2340\n", - "Step 11, Loss 0.0050, TDWI 0.5500, SPAN 35.0621\n", - "Step 12, Loss 0.0018, TDWI 0.5436, SPAN 34.6755\n", - "Step 13, Loss 0.0109, TDWI 0.5356, SPAN 34.2452\n", - "Step 14, Loss 0.0230, TDWI 0.5261, SPAN 33.8167\n", - "Step 15, Loss 0.0341, TDWI 0.5152, SPAN 33.4017\n", - "Step 16, Loss 0.0472, TDWI 0.5031, SPAN 33.0102\n", - "Step 17, Loss 0.0559, TDWI 0.4950, SPAN 32.9639\n", - "Step 18, Loss 0.0571, TDWI 0.4903, SPAN 33.1766\n", - "Step 19, Loss 0.0519, TDWI 0.4888, SPAN 33.5989\n", - "Step 20, Loss 0.0400, TDWI 0.4899, SPAN 34.1872\n", + "Step 5, Loss 0.1148, TDWI 0.5351, SPAN 31.8327\n", + "Step 6, Loss 0.0873, TDWI 0.5459, SPAN 32.8254\n", + "Step 7, Loss 0.0633, TDWI 0.5525, SPAN 33.6958\n", + "Step 8, Loss 0.0375, TDWI 0.5559, SPAN 34.4362\n", + "Step 9, Loss 0.0164, TDWI 0.5564, SPAN 35.0038\n", + "Step 10, Loss 0.0001, TDWI 0.5543, SPAN 35.2185\n", + "Step 11, Loss 0.0048, TDWI 0.5500, SPAN 35.0892\n", + "Step 12, Loss 0.0019, TDWI 0.5436, SPAN 34.7440\n", + "Step 13, Loss 0.0092, TDWI 0.5356, SPAN 34.3535\n", + "Step 14, Loss 0.0175, TDWI 0.5260, SPAN 33.9622\n", + "Step 15, Loss 0.0296, TDWI 0.5151, SPAN 33.5733\n", + "Step 16, Loss 0.0409, TDWI 0.5030, SPAN 33.2059\n", + "Step 17, Loss 0.0512, TDWI 0.4949, SPAN 33.2068\n", + "Step 18, Loss 0.0512, TDWI 0.4903, SPAN 33.4762\n", + "Step 19, Loss 0.0422, TDWI 0.4888, SPAN 33.9525\n", + "Step 20, Loss 0.0299, TDWI 0.4900, SPAN 34.5985\n", "Early stopping at step 20\n" ] } @@ -369,7 +421,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "c2d3a463-43a4-4b29-a71f-696c019343d3", "metadata": {}, "outputs": [ @@ -397,7 +449,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "dwof", "language": "python", "name": "python3" }, @@ -411,7 +463,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.11" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/docs/notebooks/optimization_phenology.ipynb b/docs/notebooks/optimization_phenology.ipynb index 32551d4..46fca6b 100644 --- a/docs/notebooks/optimization_phenology.ipynb +++ b/docs/notebooks/optimization_phenology.ipynb @@ -45,10 +45,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "e4049fea-1d05-41f1-bf9d-f030ae83a324", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: diffwofost in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (0.2.0)\n", + "Requirement already satisfied: torch in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from diffwofost) (2.9.0)\n", + "Requirement already satisfied: pcse in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from diffwofost) (6.0.9)\n", + "Requirement already satisfied: SQLAlchemy<2.0,>=1.3.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (1.4.54)\n", + "Requirement already satisfied: PyYAML>=5.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (6.0.3)\n", + "Requirement already satisfied: openpyxl>=3.0.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (3.1.5)\n", + "Requirement already satisfied: requests>=2.0.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (2.32.5)\n", + "Requirement already satisfied: pandas>=0.25 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (2.3.3)\n", + "Requirement already satisfied: traitlets-pcse==5.0.0.dev in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (5.0.0.dev0)\n", + "Requirement already satisfied: dotmap>=1.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (1.3.30)\n", + "Requirement already satisfied: ipython_genutils in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (0.2.0)\n", + "Requirement already satisfied: six in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (1.17.0)\n", + "Requirement already satisfied: decorator in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (5.2.1)\n", + "Requirement already satisfied: filelock in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (4.15.0)\n", + "Requirement already satisfied: setuptools in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (80.9.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (1.14.0)\n", + "Requirement already satisfied: networkx>=2.5.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.5)\n", + "Requirement already satisfied: jinja2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.1.6)\n", + "Requirement already satisfied: fsspec>=0.8.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (2025.9.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.93)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (11.3.3.83)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (10.3.9.90)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (11.7.3.90)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.5.8.93)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (2.27.5)\n", + "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.3.20)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.93)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (1.13.1.3)\n", + "Requirement already satisfied: triton==3.5.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.5.0)\n", + "Requirement already satisfied: et-xmlfile in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from openpyxl>=3.0.0->pcse->diffwofost) (2.0.0)\n", + "Requirement already satisfied: numpy>=1.26.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2.3.4)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2025.2)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (2025.10.5)\n", + "Requirement already satisfied: greenlet!=0.4.17 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from SQLAlchemy<2.0,>=1.3.0->pcse->diffwofost) (3.2.4)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from sympy>=1.13.3->torch->diffwofost) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from jinja2->torch->diffwofost) (3.0.3)\n" + ] + } + ], "source": [ "# install diffwofost\n", "!pip install diffwofost" @@ -56,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "21731653-3976-4bb9-b83b-b11d78211700", "metadata": {}, "outputs": [], @@ -66,7 +121,7 @@ "import torch\n", "import numpy\n", "from pathlib import Path\n", - "from diffwofost.physical_models.config import Configuration\n", + "from diffwofost.physical_models.config import Configuration, ComputeConfig\n", "from diffwofost.physical_models.crop.phenology import DVS_Phenology\n", "from diffwofost.physical_models.utils import EngineTestHelper\n", "from diffwofost.physical_models.utils import prepare_engine_input\n", @@ -75,19 +130,7 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "8a06565c-139a-4039-92d5-bcfd7bcf8344", - "metadata": {}, - "outputs": [], - "source": [ - "# --- run on CPU ------\n", - "from diffwofost.physical_models.config import ComputeConfig\n", - "ComputeConfig.set_device('cpu')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "82a1ef6b-336e-4902-8bd1-2a1ed2020f9d", "metadata": {}, "outputs": [], @@ -118,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "id": "0233a048-e5a2-4249-887d-35a37284769c", "metadata": {}, "outputs": [ @@ -142,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "id": "5a459489-bfcb-4ad6-9102-1b6be5edeb52", "metadata": {}, "outputs": [], @@ -153,7 +196,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "id": "a39f030b-ca6f-4535-8692-7883476ae7a4", "metadata": {}, "outputs": [], @@ -182,7 +225,10 @@ ")\n", "\n", "expected_results = test_data[\"ModelResults\"]\n", - "expected_dvs = torch.tensor([float(item[\"DVS\"]) for item in expected_results], dtype=torch.float32\n", + "expected_dvs = torch.tensor(\n", + " [float(item[\"DVS\"]) for item in expected_results],\n", + " dtype=ComputeConfig.get_dtype(),\n", + " device=ComputeConfig.get_device(),\n", ") # shape: [time_steps]\n", "\n", "# ---- don't change this: in this config file we specify the differentiable version of DVS_Phenology ----\n", @@ -208,7 +254,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "id": "e4610238-de0d-42cf-9689-3c074eb2cc0e", "metadata": {}, "outputs": [], @@ -230,10 +276,17 @@ " init_norm = (init_value - low) / (high - low)\n", "\n", " # Parameter in raw logit space\n", - " self.raw = torch.nn.Parameter(torch.logit(torch.tensor(init_norm, dtype=torch.float32), eps=1e-6))\n", + " self.raw = torch.nn.Parameter(\n", + " torch.logit(\n", + " torch.tensor(\n", + " init_norm, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device()\n", + " ),\n", + " eps=1e-6,\n", + " )\n", + " )\n", "\n", " def forward(self):\n", - " return self.low + (self.high - self.low) * torch.sigmoid(self.raw)\n" + " return self.low + (self.high - self.low) * torch.sigmoid(self.raw)" ] }, { @@ -247,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "id": "36dd6463-4812-41c0-b2bf-d4769df1136f", "metadata": {}, "outputs": [], @@ -296,7 +349,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "id": "2a0754ac-4cf1-4ed7-9059-af80484beb33", "metadata": {}, "outputs": [], @@ -312,7 +365,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "id": "124a6077-64b7-4816-b42f-538e3f8e0538", "metadata": {}, "outputs": [ @@ -321,21 +374,21 @@ "output_type": "stream", "text": [ "Step 0: duration mismatch (260 vs 279).\n", - "Step 0, Loss 0.1490, TSUMEM 85.0787, TBASEM 1.8448, TSUM1 815.5215, TSUM2 815.5215,\n", + "Step 0, Loss 0.1490, TSUMEM 85.0787, TBASEM 1.8448, TSUM1 815.5214, TSUM2 815.5214,\n", "Step 1: duration mismatch (262 vs 279).\n", "Step 1, Loss 0.1348, TSUMEM 80.2344, TBASEM 1.6999, TSUM1 830.0543, TSUM2 830.0643,\n", "Step 2: duration mismatch (263 vs 279).\n", - "Step 2, Loss 0.1197, TSUMEM 77.2860, TBASEM 1.6076, TSUM1 843.6052, TSUM2 843.6012,\n", + "Step 2, Loss 0.1197, TSUMEM 77.2860, TBASEM 1.6076, TSUM1 843.6053, TSUM2 843.6012,\n", "Step 3: duration mismatch (264 vs 279).\n", - "Step 3, Loss 0.1147, TSUMEM 76.5338, TBASEM 1.5720, TSUM1 856.1688, TSUM2 856.1740,\n", + "Step 3, Loss 0.1147, TSUMEM 76.5338, TBASEM 1.5720, TSUM1 856.1687, TSUM2 856.1739,\n", "Step 4: duration mismatch (266 vs 279).\n", - "Step 4, Loss 0.1019, TSUMEM 77.1810, TBASEM 1.5731, TSUM1 867.7785, TSUM2 867.8158,\n", + "Step 4, Loss 0.1019, TSUMEM 77.1810, TBASEM 1.5731, TSUM1 867.7784, TSUM2 867.8158,\n", "Step 5: duration mismatch (267 vs 279).\n", - "Step 5, Loss 0.0881, TSUMEM 78.6763, TBASEM 1.5976, TSUM1 878.4762, TSUM2 878.5369,\n", + "Step 5, Loss 0.0881, TSUMEM 78.6763, TBASEM 1.5976, TSUM1 878.4762, TSUM2 878.5370,\n", "Step 6: duration mismatch (268 vs 279).\n", "Step 6, Loss 0.0830, TSUMEM 80.7683, TBASEM 1.6402, TSUM1 888.2892, TSUM2 888.3950,\n", "Step 7: duration mismatch (269 vs 279).\n", - "Step 7, Loss 0.0698, TSUMEM 82.9896, TBASEM 1.6870, TSUM1 897.2725, TSUM2 897.4227,\n", + "Step 7, Loss 0.0698, TSUMEM 82.9896, TBASEM 1.6870, TSUM1 897.2726, TSUM2 897.4227,\n", "Step 8: duration mismatch (270 vs 279).\n", "Step 8, Loss 0.0568, TSUMEM 84.5758, TBASEM 1.7161, TSUM1 905.4835, TSUM2 905.6589,\n", "Step 9: duration mismatch (271 vs 279).\n", @@ -343,41 +396,41 @@ "Step 10: duration mismatch (271 vs 279).\n", "Step 10, Loss 0.0480, TSUMEM 84.3238, TBASEM 1.6843, TSUM1 919.7631, TSUM2 920.0091,\n", "Step 11: duration mismatch (273 vs 279).\n", - "Step 11, Loss 0.0381, TSUMEM 83.4182, TBASEM 1.6478, TSUM1 925.9421, TSUM2 926.2325,\n", + "Step 11, Loss 0.0381, TSUMEM 83.4182, TBASEM 1.6478, TSUM1 925.9420, TSUM2 926.2325,\n", "Step 12: duration mismatch (273 vs 279).\n", "Step 12, Loss 0.0355, TSUMEM 82.3086, TBASEM 1.6063, TSUM1 931.5499, TSUM2 931.8865,\n", "Step 13: duration mismatch (273 vs 279).\n", - "Step 13, Loss 0.0324, TSUMEM 81.3026, TBASEM 1.5680, TSUM1 936.6345, TSUM2 937.0161,\n", + "Step 13, Loss 0.0324, TSUMEM 81.3026, TBASEM 1.5680, TSUM1 936.6345, TSUM2 937.0160,\n", "Step 14: duration mismatch (275 vs 279).\n", - "Step 14, Loss 0.0245, TSUMEM 80.8495, TBASEM 1.5439, TSUM1 941.2473, TSUM2 941.6774,\n", + "Step 14, Loss 0.0245, TSUMEM 80.8495, TBASEM 1.5439, TSUM1 941.2473, TSUM2 941.6775,\n", "Step 15: duration mismatch (275 vs 279).\n", "Step 15, Loss 0.0220, TSUMEM 81.1065, TBASEM 1.5381, TSUM1 945.4302, TSUM2 945.9092,\n", "Step 16: duration mismatch (275 vs 279).\n", - "Step 16, Loss 0.0197, TSUMEM 81.9637, TBASEM 1.5478, TSUM1 949.2226, TSUM2 949.7485,\n", + "Step 16, Loss 0.0197, TSUMEM 81.9637, TBASEM 1.5478, TSUM1 949.2227, TSUM2 949.7485,\n", "Step 17: duration mismatch (276 vs 279).\n", - "Step 17, Loss 0.0103, TSUMEM 83.1409, TBASEM 1.5657, TSUM1 952.6663, TSUM2 953.2308,\n", + "Step 17, Loss 0.0103, TSUMEM 83.1409, TBASEM 1.5657, TSUM1 952.6663, TSUM2 953.2309,\n", "Step 18: duration mismatch (277 vs 279).\n", - "Step 18, Loss 0.0093, TSUMEM 84.1272, TBASEM 1.5787, TSUM1 955.4659, TSUM2 956.3961,\n", + "Step 18, Loss 0.0093, TSUMEM 84.1272, TBASEM 1.5787, TSUM1 955.4660, TSUM2 956.3960,\n", "Step 19: duration mismatch (277 vs 279).\n", "Step 19, Loss 0.0093, TSUMEM 84.7385, TBASEM 1.5820, TSUM1 957.7150, TSUM2 959.2729,\n", "Step 20: duration mismatch (277 vs 279).\n", - "Step 20, Loss 0.0093, TSUMEM 85.0120, TBASEM 1.5765, TSUM1 959.5129, TSUM2 961.8885,\n", + "Step 20, Loss 0.0093, TSUMEM 85.0120, TBASEM 1.5765, TSUM1 959.5129, TSUM2 961.8886,\n", "Step 21: duration mismatch (277 vs 279).\n", "Step 21, Loss 0.0092, TSUMEM 84.9791, TBASEM 1.5633, TSUM1 960.9411, TSUM2 964.2680,\n", "Step 22: duration mismatch (277 vs 279).\n", "Step 22, Loss 0.0091, TSUMEM 84.6666, TBASEM 1.5432, TSUM1 962.0599, TSUM2 966.4341,\n", "Step 23: duration mismatch (278 vs 279).\n", - "Step 23, Loss 0.0090, TSUMEM 84.0982, TBASEM 1.5171, TSUM1 962.9180, TSUM2 968.4114,\n", - "Step 24, Loss 0.0078, TSUMEM 83.4926, TBASEM 1.4905, TSUM1 963.5505, TSUM2 970.0585,\n", - "Step 25, Loss 0.0082, TSUMEM 83.2271, TBASEM 1.4719, TSUM1 963.9872, TSUM2 971.4006,\n", + "Step 23, Loss 0.0090, TSUMEM 84.0982, TBASEM 1.5171, TSUM1 962.9181, TSUM2 968.4114,\n", + "Step 24, Loss 0.0078, TSUMEM 83.4926, TBASEM 1.4905, TSUM1 963.5505, TSUM2 970.0586,\n", + "Step 25, Loss 0.0082, TSUMEM 83.2271, TBASEM 1.4719, TSUM1 963.9873, TSUM2 971.4005,\n", "Step 26, Loss 0.0086, TSUMEM 83.4078, TBASEM 1.4639, TSUM1 964.2517, TSUM2 972.4788,\n", "Step 27, Loss 0.0090, TSUMEM 83.9896, TBASEM 1.4651, TSUM1 964.3623, TSUM2 973.3393,\n", "Step 28, Loss 0.0092, TSUMEM 84.8013, TBASEM 1.4715, TSUM1 964.3331, TSUM2 974.0173,\n", - "Step 29, Loss 0.0093, TSUMEM 85.4506, TBASEM 1.4742, TSUM1 964.1751, TSUM2 974.5405,\n", - "Step 30, Loss 0.0094, TSUMEM 85.9211, TBASEM 1.4726, TSUM1 963.8970, TSUM2 974.9305,\n", + "Step 29, Loss 0.0093, TSUMEM 85.4506, TBASEM 1.4742, TSUM1 964.1752, TSUM2 974.5405,\n", + "Step 30, Loss 0.0094, TSUMEM 85.9211, TBASEM 1.4726, TSUM1 963.8970, TSUM2 974.9304,\n", "Step 31, Loss 0.0094, TSUMEM 86.0486, TBASEM 1.4633, TSUM1 963.5046, TSUM2 975.2042,\n", "Step 32, Loss 0.0093, TSUMEM 85.8631, TBASEM 1.4471, TSUM1 963.0023, TSUM2 975.3755,\n", - "Step 33, Loss 0.0092, TSUMEM 85.6006, TBASEM 1.4293, TSUM1 962.3921, TSUM2 975.4550,\n", + "Step 33, Loss 0.0092, TSUMEM 85.6006, TBASEM 1.4293, TSUM1 962.3921, TSUM2 975.4549,\n", "Step 34, Loss 0.0090, TSUMEM 85.2661, TBASEM 1.4101, TSUM1 961.6753, TSUM2 975.4510,\n", "Early stopping at step 34\n", "duration (model 279 vs test 279).\n" @@ -434,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "id": "5b0a6b12-3ca9-4cf3-9dd5-ac2c11bf4fc6", "metadata": {}, "outputs": [ @@ -455,19 +508,11 @@ " f\"TSUM2 {crop_model_params_provider[\"TSUM2\"].item():.4f}\"\n", ")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b1e06239-c037-4433-9da2-feb46b52a8e4", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "dwof", "language": "python", "name": "python3" }, @@ -481,7 +526,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.11" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/docs/notebooks/optimization_root_dynamics.ipynb b/docs/notebooks/optimization_root_dynamics.ipynb index 0c22f7d..19da376 100644 --- a/docs/notebooks/optimization_root_dynamics.ipynb +++ b/docs/notebooks/optimization_root_dynamics.ipynb @@ -50,10 +50,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "9732ef36-f46f-4c5d-b789-ba0ff0eab0d8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: diffwofost in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (0.2.0)\n", + "Requirement already satisfied: torch in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from diffwofost) (2.9.0)\n", + "Requirement already satisfied: pcse in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from diffwofost) (6.0.9)\n", + "Requirement already satisfied: SQLAlchemy<2.0,>=1.3.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (1.4.54)\n", + "Requirement already satisfied: PyYAML>=5.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (6.0.3)\n", + "Requirement already satisfied: openpyxl>=3.0.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (3.1.5)\n", + "Requirement already satisfied: requests>=2.0.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (2.32.5)\n", + "Requirement already satisfied: pandas>=0.25 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (2.3.3)\n", + "Requirement already satisfied: traitlets-pcse==5.0.0.dev in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (5.0.0.dev0)\n", + "Requirement already satisfied: dotmap>=1.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pcse->diffwofost) (1.3.30)\n", + "Requirement already satisfied: ipython_genutils in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (0.2.0)\n", + "Requirement already satisfied: six in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (1.17.0)\n", + "Requirement already satisfied: decorator in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from traitlets-pcse==5.0.0.dev->pcse->diffwofost) (5.2.1)\n", + "Requirement already satisfied: filelock in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (4.15.0)\n", + "Requirement already satisfied: setuptools in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (80.9.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (1.14.0)\n", + "Requirement already satisfied: networkx>=2.5.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.5)\n", + "Requirement already satisfied: jinja2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.1.6)\n", + "Requirement already satisfied: fsspec>=0.8.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (2025.9.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.93)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (11.3.3.83)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (10.3.9.90)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (11.7.3.90)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.5.8.93)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (2.27.5)\n", + "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.3.20)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.90)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (12.8.93)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (1.13.1.3)\n", + "Requirement already satisfied: triton==3.5.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from torch->diffwofost) (3.5.0)\n", + "Requirement already satisfied: et-xmlfile in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from openpyxl>=3.0.0->pcse->diffwofost) (2.0.0)\n", + "Requirement already satisfied: numpy>=1.26.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2.3.4)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from pandas>=0.25->pcse->diffwofost) (2025.2)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from requests>=2.0.0->pcse->diffwofost) (2025.10.5)\n", + "Requirement already satisfied: greenlet!=0.4.17 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from SQLAlchemy<2.0,>=1.3.0->pcse->diffwofost) (3.2.4)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from sympy>=1.13.3->torch->diffwofost) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/simone/.virtualenvs/dwof/lib/python3.12/site-packages (from jinja2->torch->diffwofost) (3.0.3)\n" + ] + } + ], "source": [ "# install diffwofost\n", "!pip install diffwofost" @@ -81,25 +136,13 @@ "import numpy\n", "import yaml\n", "from pathlib import Path\n", - "from diffwofost.physical_models.config import Configuration\n", + "from diffwofost.physical_models.config import Configuration, ComputeConfig\n", "from diffwofost.physical_models.crop.root_dynamics import WOFOST_Root_Dynamics\n", "from diffwofost.physical_models.utils import EngineTestHelper\n", "from diffwofost.physical_models.utils import prepare_engine_input\n", "from diffwofost.physical_models.utils import get_test_data" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "10c6361e-ad31-455f-aed1-b7a282cb793c", - "metadata": {}, - "outputs": [], - "source": [ - "# --- run on CPU ------\n", - "from diffwofost.physical_models.config import ComputeConfig\n", - "ComputeConfig.set_device('cpu')" - ] - }, { "cell_type": "code", "execution_count": 3, @@ -186,7 +229,9 @@ "\n", "expected_results = test_data[\"ModelResults\"]\n", "expected_twrt = torch.tensor(\n", - " [float(item[\"TWRT\"]) for item in expected_results], dtype=torch.float32\n", + " [float(item[\"TWRT\"]) for item in expected_results],\n", + " dtype=ComputeConfig.get_dtype(),\n", + " device=ComputeConfig.get_device(),\n", ") # shape: [1, time_steps]\n", "\n", "# ---- dont change this: in this config file we specified the diffrentiable version of root_dynamics ----\n", @@ -231,10 +276,17 @@ " init_norm = (init_value - low) / (high - low)\n", "\n", " # Parameter in raw logit space\n", - " self.raw = torch.nn.Parameter(torch.logit(torch.tensor(init_norm, dtype=torch.float32), eps=1e-6))\n", + " self.raw = torch.nn.Parameter(\n", + " torch.logit(\n", + " torch.tensor(\n", + " init_norm, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device()\n", + " ),\n", + " eps=1e-6,\n", + " )\n", + " )\n", "\n", " def forward(self):\n", - " return self.low + (self.high - self.low) * torch.sigmoid(self.raw)\n" + " return self.low + (self.high - self.low) * torch.sigmoid(self.raw)" ] }, { @@ -317,15 +369,15 @@ "output_type": "stream", "text": [ "Step 0, Loss 0.00004644, TDWI 0.3214\n", - "Step 1, Loss 0.00004170, TDWI 0.3436\n", - "Step 2, Loss 0.00003679, TDWI 0.3665\n", - "Step 3, Loss 0.00003172, TDWI 0.3901\n", - "Step 4, Loss 0.00002650, TDWI 0.4143\n", + "Step 1, Loss 0.00004171, TDWI 0.3436\n", + "Step 2, Loss 0.00003680, TDWI 0.3665\n", + "Step 3, Loss 0.00003173, TDWI 0.3901\n", + "Step 4, Loss 0.00002651, TDWI 0.4143\n", "Step 5, Loss 0.00002116, TDWI 0.4389\n", - "Step 6, Loss 0.00001571, TDWI 0.4639\n", + "Step 6, Loss 0.00001572, TDWI 0.4639\n", "Step 7, Loss 0.00001019, TDWI 0.4891\n", "Step 8, Loss 0.00000461, TDWI 0.5144\n", - "Step 9, Loss 0.00000099, TDWI 0.5316\n", + "Step 9, Loss 0.00000098, TDWI 0.5316\n", "Step 10, Loss 0.00000479, TDWI 0.5424\n", "Early stopping at step 10\n" ] @@ -383,19 +435,11 @@ "# ---- validate the results using test data ---- \n", "print(f\"Actual TDWI {crop_model_params_provider[\"TDWI\"].item():.4f}\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a6a511a4-f269-4af4-9f51-2dafa9ba38c0", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "dwof", "language": "python", "name": "python3" }, @@ -409,7 +453,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.11" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/src/diffwofost/physical_models/config.py b/src/diffwofost/physical_models/config.py index ed348bb..534b8f1 100644 --- a/src/diffwofost/physical_models/config.py +++ b/src/diffwofost/physical_models/config.py @@ -12,63 +12,55 @@ class ComputeConfig: """Central configuration for device and dtype settings. - This class provides a centralized way to control PyTorch device and dtype - settings across all simulation objects in diffWOFOST. Instead of setting - device and dtype individually for each class, use this central configuration - to apply settings globally. + This class acts as a factory for default configuration settings that are + captured by simulation objects upon initialization. This enables precise + control over where (device) and how (dtype) each model computation occurs, + allowing for multiple models with different configurations to coexist. + + **Key Concept: Configuration Capture** + + When a simulation object (e.g., `WOFOST_Leaf_Dynamics`) is initialized, it + queries `ComputeConfig` for the current device and dtype. The model *captures* + and stores these settings for its lifetime. Subsequent changes to + `ComputeConfig` will only affect *newly created* objects, leaving existing + ones unchanged. **Default Behavior:** - - **Device**: Automatically defaults to 'cuda' if available, otherwise 'cpu' - - **Dtype**: Defaults to torch.float64 + - **Device**: Defaults to torch.get_default_device() + - **Dtype**: Defaults to torch.get_default_dtype() **Basic Usage:** >>> from diffwofost.physical_models.config import ComputeConfig >>> import torch >>> - >>> # Set device to CPU - >>> ComputeConfig.set_device('cpu') - >>> - >>> # Or use a torch.device object - >>> ComputeConfig.set_device(torch.device('cuda')) - >>> - >>> # Set dtype to float32 + >>> # Configure defaults for new models + >>> ComputeConfig.set_device('cuda') >>> ComputeConfig.set_dtype(torch.float32) >>> - >>> # Get current settings - >>> device = ComputeConfig.get_device() # Returns: torch.device('cpu') - >>> dtype = ComputeConfig.get_dtype() # Returns: torch.float32 + >>> # Get current defaults + >>> device = ComputeConfig.get_device() + >>> dtype = ComputeConfig.get_dtype() - **Using with Simulation Objects:** + **Creating Models with Different Settings:** - All simulation objects (e.g., WOFOST_Leaf_Dynamics, WOFOST_Phenology) - automatically use the settings from ComputeConfig. No changes needed to - instantiation code: + Because models capture the configuration at initialization, you can create + instances with different settings in the same process: - >>> from diffwofost.physical_models.config import ComputeConfig >>> from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics >>> - >>> # Set global compute settings + >>> # Create a model on GPU (float32) >>> ComputeConfig.set_device('cuda') >>> ComputeConfig.set_dtype(torch.float32) + >>> model_gpu = WOFOST_Leaf_Dynamics(...) >>> - >>> # Instantiate objects - they automatically use global settings - >>> leaf_dynamics = WOFOST_Leaf_Dynamics() - - **Switching Between Devices:** - - Useful for switching between GPU training and CPU evaluation: - - >>> # Train on GPU - >>> ComputeConfig.set_device('cuda') - >>> ComputeConfig.set_dtype(torch.float32) - >>> # ... run training ... - >>> - >>> # Evaluate on CPU + >>> # Create a model on CPU (float64) >>> ComputeConfig.set_device('cpu') >>> ComputeConfig.set_dtype(torch.float64) - >>> # ... run evaluation ... + >>> model_cpu = WOFOST_Leaf_Dynamics(...) + >>> + >>> # model_gpu remains on cuda, model_cpu stays on cpu. **Resetting to Defaults:** @@ -83,9 +75,9 @@ class ComputeConfig: def _initialize_defaults(cls): """Initialize default device and dtype if not already set.""" if cls._device is None: - cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls._device = torch.get_default_device() if cls._dtype is None: - cls._dtype = torch.float64 + cls._dtype = torch.get_default_dtype() @classmethod def get_device(cls) -> torch.device: diff --git a/src/diffwofost/physical_models/crop/assimilation.py b/src/diffwofost/physical_models/crop/assimilation.py index 6b093fc..f829ce4 100644 --- a/src/diffwofost/physical_models/crop/assimilation.py +++ b/src/diffwofost/physical_models/crop/assimilation.py @@ -234,12 +234,12 @@ class WOFOST72_Assimilation(SimulationObject): @property def device(self): """Get device from ComputeConfig.""" - return ComputeConfig.get_device() + return getattr(self, "_device", ComputeConfig.get_device()) @property def dtype(self): """Get dtype from ComputeConfig.""" - return ComputeConfig.get_dtype() + return getattr(self, "_dtype", ComputeConfig.get_dtype()) class Parameters(ParamTemplate): AMAXTB = AfgenTrait() @@ -264,6 +264,9 @@ def initialize( self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider ) -> None: """Initialize the assimilation module.""" + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + self.kiosk = kiosk self.params = self.Parameters(parvalues) self.params_shape = _get_params_shape(self.params) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index a40b567..667cc1f 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -124,12 +124,12 @@ class WOFOST_Leaf_Dynamics(SimulationObject): @property def device(self): """Get device from ComputeConfig.""" - return ComputeConfig.get_device() + return getattr(self, "_device", ComputeConfig.get_device()) @property def dtype(self): """Get dtype from ComputeConfig.""" - return ComputeConfig.get_dtype() + return getattr(self, "_dtype", ComputeConfig.get_dtype()) class Parameters(ParamTemplate): RGRLAI = Any() @@ -250,6 +250,11 @@ def initialize( """ self.START_DATE = day self.kiosk = kiosk + + # Get defaults from ComputeConfig if not already set + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + # TODO check if parvalues are already torch.nn.Parameters self.params = self.Parameters(parvalues) self.rates = self.RateVariables(kiosk) diff --git a/src/diffwofost/physical_models/crop/partitioning.py b/src/diffwofost/physical_models/crop/partitioning.py index 52fbbc6..07c6058 100644 --- a/src/diffwofost/physical_models/crop/partitioning.py +++ b/src/diffwofost/physical_models/crop/partitioning.py @@ -40,12 +40,12 @@ class _BaseDVSPartitioning(SimulationObject): @property def device(self): """Get device from ComputeConfig.""" - return ComputeConfig.get_device() + return getattr(self, "_device", ComputeConfig.get_device()) @property def dtype(self): """Get dtype from ComputeConfig.""" - return ComputeConfig.get_dtype() + return getattr(self, "_dtype", ComputeConfig.get_dtype()) class Parameters(ParamTemplate): FRTB = AfgenTrait() @@ -127,6 +127,9 @@ def _compute_partitioning_from_tables(self, DVS): return FR, FL, FS, FO def _initialize_from_tables(self, kiosk, parvalues): + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + self.params = self.Parameters(parvalues) self.kiosk = kiosk self.params_shape = _get_params_shape(self.params) diff --git a/src/diffwofost/physical_models/crop/phenology.py b/src/diffwofost/physical_models/crop/phenology.py index bc30b64..627e18e 100644 --- a/src/diffwofost/physical_models/crop/phenology.py +++ b/src/diffwofost/physical_models/crop/phenology.py @@ -95,12 +95,12 @@ class Vernalisation(SimulationObject): @property def device(self): """Get device from ComputeConfig.""" - return ComputeConfig.get_device() + return getattr(self, "_device", ComputeConfig.get_device()) @property def dtype(self): """Get dtype from ComputeConfig.""" - return ComputeConfig.get_dtype() + return getattr(self, "_dtype", ComputeConfig.get_dtype()) class Parameters(ParamTemplate): VERNSAT = Any() @@ -179,6 +179,9 @@ def initialize(self, day, kiosk, parvalues, dvs_shape=None): ISVERNALISED = False. """ + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + self.params = self.Parameters(parvalues) self.params_shape = _get_params_shape(self.params) @@ -427,12 +430,12 @@ class DVS_Phenology(SimulationObject): @property def device(self): """Get device from ComputeConfig.""" - return ComputeConfig.get_device() + return getattr(self, "_device", ComputeConfig.get_device()) @property def dtype(self): """Get dtype from ComputeConfig.""" - return ComputeConfig.get_dtype() + return getattr(self, "_dtype", ComputeConfig.get_dtype()) class Parameters(ParamTemplate): TSUMEM = Any() @@ -563,6 +566,9 @@ def initialize(self, day, kiosk, parvalues): :param parvalues: `ParameterProvider` object providing parameters as key/value pairs """ + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + self.params = self.Parameters(parvalues) self.params_shape = _get_params_shape(self.params) diff --git a/src/diffwofost/physical_models/crop/root_dynamics.py b/src/diffwofost/physical_models/crop/root_dynamics.py index b38604f..036f7e8 100644 --- a/src/diffwofost/physical_models/crop/root_dynamics.py +++ b/src/diffwofost/physical_models/crop/root_dynamics.py @@ -121,12 +121,12 @@ class WOFOST_Root_Dynamics(SimulationObject): @property def device(self): """Get device from ComputeConfig.""" - return ComputeConfig.get_device() + return getattr(self, "_device", ComputeConfig.get_device()) @property def dtype(self): """Get dtype from ComputeConfig.""" - return ComputeConfig.get_dtype() + return getattr(self, "_dtype", ComputeConfig.get_dtype()) class Parameters(ParamTemplate): RDI = Any() @@ -214,6 +214,9 @@ def initialize( all parameter sets (crop, soil, site) as key/value. The values are arrays or scalars. See PCSE documentation for details. """ + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + self.kiosk = kiosk self.params = self.Parameters(parvalues) self.rates = self.RateVariables(kiosk, publish=["DRRT", "GRRT"]) diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index f3229c7..66cab2d 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -25,6 +25,7 @@ from pcse.timer import Timer from pcse.traitlets import Enum from pcse.traitlets import TraitType +from .config import ComputeConfig from .config import Configuration from .engine import Engine @@ -96,8 +97,6 @@ def __init__( agromanagement, config, external_states=None, - device=None, - dtype=None, ): BaseEngine.__init__(self) @@ -109,12 +108,6 @@ def __init__( self.parameterprovider = parameterprovider - # Configure device and dtype on crop module class if it supports them - if hasattr(self.mconf.CROP, "device") and device is not None: - self.mconf.CROP.device = device - if hasattr(self.mconf.CROP, "dtype") and dtype is not None: - self.mconf.CROP.dtype = dtype - # Variable kiosk for registering and publishing variables self.kiosk = VariableKioskTestHelper(external_states) @@ -202,9 +195,15 @@ def __init__(self, yaml_weather, meteo_range_checks=True): def prepare_engine_input( - test_data, crop_model_params, meteo_range_checks=True, dtype=torch.float64, device="cpu" + test_data, crop_model_params, device=None, dtype=None, meteo_range_checks=True ): """Prepare the inputs for the engine from the YAML file.""" + # If not specified, use default dtype and device + if device is None: + device = ComputeConfig.get_device() + if dtype is None: + dtype = ComputeConfig.get_dtype() + agro_management_inputs = test_data["AgroManagement"] cropd = test_data["ModelParameters"] @@ -224,7 +223,10 @@ def prepare_engine_input( # convert external states to tensors tensor_external_states = [ - {k: v if k == "DAY" else torch.tensor(v, dtype=dtype) for k, v in item.items()} + { + k: v if k == "DAY" else torch.tensor(v, dtype=dtype, device=device) + for k, v in item.items() + } for item in external_states ] return ( diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index 862c62e..9ceb78b 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -62,3 +62,9 @@ def device(request): # Reset to defaults after the test ComputeConfig.reset_to_defaults() + + +@pytest.fixture(autouse=True) +def configure_compute_config_dtype(): + """Ensure all tests run with float64 precision.""" + ComputeConfig.set_dtype(torch.float64) diff --git a/tests/physical_models/crop/test_assimilation.py b/tests/physical_models/crop/test_assimilation.py index 8eb16b7..cb5f74a 100644 --- a/tests/physical_models/crop/test_assimilation.py +++ b/tests/physical_models/crop/test_assimilation.py @@ -18,12 +18,12 @@ ) -def get_test_diff_assimilation_model(device: str = "cpu"): +def get_test_diff_assimilation_model(): test_data_url = f"{phy_data_folder}/test_assimilation_wofost72_01.yaml" test_data = get_test_data(test_data_url) crop_model_params = ["AMAXTB", "EFFTB", "KDIFTB", "TMPFTB", "TMNFTB"] (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( - prepare_engine_input(test_data, crop_model_params, device=device) + prepare_engine_input(test_data, crop_model_params) ) return DiffAssimilation( copy.deepcopy(crop_model_params_provider), @@ -31,7 +31,6 @@ def get_test_diff_assimilation_model(device: str = "cpu"): agro_management_inputs, assimilation_config, copy.deepcopy(external_states), - device=device, ) @@ -43,7 +42,6 @@ def __init__( agro_management_inputs, config, external_states, - device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider @@ -51,7 +49,6 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states - self.device = device def forward(self, params_dict): for name, value in params_dict.items(): @@ -63,7 +60,6 @@ def forward(self, params_dict): self.agro_management_inputs, self.config, self.external_states, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -98,7 +94,6 @@ def test_assimilation_with_testengine(self, test_data_url, device): agro_management_inputs, assimilation_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -137,7 +132,6 @@ def test_assimilation_with_one_parameter_vector(self, param, device): agro_management_inputs, assimilation_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -187,7 +181,6 @@ def test_assimilation_with_different_parameter_values(self, param, delta, device agro_management_inputs, assimilation_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -226,7 +219,6 @@ def test_assimilation_with_multiple_parameter_vectors(self, device): agro_management_inputs, assimilation_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -269,7 +261,6 @@ def test_assimilation_with_multiple_parameter_arrays(self, device): agro_management_inputs, assimilation_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -310,7 +301,6 @@ def test_assimilation_with_incompatible_parameter_vectors(self): agro_management_inputs, assimilation_config, external_states, - device="cpu", ) def test_assimilation_with_incompatible_weather_parameter_vectors(self): @@ -337,7 +327,6 @@ def test_assimilation_with_incompatible_weather_parameter_vectors(self): agro_management_inputs, assimilation_config, external_states, - device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) @@ -433,7 +422,7 @@ class TestDiffAssimilationGradients: @pytest.mark.parametrize("param_name,output_name", no_gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_no_gradients(self, param_name, output_name, config_type, device): - model = get_test_diff_assimilation_model(device=device) + model = get_test_diff_assimilation_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -446,7 +435,7 @@ def test_no_gradients(self, param_name, output_name, config_type, device): @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): - model = get_test_diff_assimilation_model(device=device) + model = get_test_diff_assimilation_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -471,7 +460,7 @@ def test_gradients_numerical(self, param_name, output_name, config_type, device) param_value = torch.tensor(value, dtype=torch.float64, device=device) def get_model_fn(): - return get_test_diff_assimilation_model(device=device) + return get_test_diff_assimilation_model() grad_num = calculate_numerical_grad(get_model_fn, param_name, param_value, output_name) diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index dba524c..5fe67ca 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -18,12 +18,12 @@ ) -def get_test_diff_leaf_model(device: str = "cpu"): +def get_test_diff_leaf_model(): test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( - prepare_engine_input(test_data, crop_model_params, device=device) + prepare_engine_input(test_data, crop_model_params) ) return DiffLeafDynamics( copy.deepcopy(crop_model_params_provider), @@ -31,7 +31,6 @@ def get_test_diff_leaf_model(device: str = "cpu"): agro_management_inputs, leaf_dynamics_config, copy.deepcopy(external_states), - device=device, ) @@ -43,7 +42,6 @@ def __init__( agro_management_inputs, config, external_states, - device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider @@ -51,7 +49,6 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states - self.device = device def forward(self, params_dict): # pass new value of parameters to the model @@ -64,7 +61,6 @@ def forward(self, params_dict): self.agro_management_inputs, self.config, self.external_states, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -102,7 +98,6 @@ def test_leaf_dynamics_with_testengine(self, test_data_url, device): agro_management_inputs, leaf_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -136,9 +131,7 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input( - test_data, crop_model_params, meteo_range_checks=False, device=device - ) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) # Setting a vector (with one value) for the selected parameter if param == "TEMP": @@ -163,7 +156,6 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param, device): agro_management_inputs, leaf_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -174,7 +166,6 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param, device): agro_management_inputs, leaf_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -239,7 +230,6 @@ def test_leaf_dynamics_with_different_parameter_values(self, param, delta, devic agro_management_inputs, leaf_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -289,7 +279,6 @@ def test_leaf_dynamics_with_multiple_parameter_vectors(self, device): agro_management_inputs, leaf_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -336,7 +325,6 @@ def test_leaf_dynamics_with_multiple_parameter_arrays(self, device): agro_management_inputs, leaf_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -384,7 +372,6 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): agro_management_inputs, leaf_dynamics_config, external_states, - device="cpu", ) def test_leaf_dynamics_with_incompatible_weather_parameter_vectors(self): @@ -413,7 +400,6 @@ def test_leaf_dynamics_with_incompatible_weather_parameter_vectors(self): agro_management_inputs, leaf_dynamics_config, external_states, - device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) @@ -470,7 +456,6 @@ def test_leaf_dynamics_with_sigmoid_approx(self, test_data_url): agro_management_inputs, leaf_dynamics_config, external_states, - device="cpu", ) engine.run_till_terminate() actual_results = engine.get_output() @@ -552,7 +537,7 @@ class TestDiffLeafDynamicsGradients: @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_no_gradients(self, param_name, output_name, config_type, device): """Test cases where parameters should not have gradients for specific outputs.""" - model = get_test_diff_leaf_model(device=device) + model = get_test_diff_leaf_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -568,7 +553,7 @@ def test_no_gradients(self, param_name, output_name, config_type, device): @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): """Test that forward and backward gradients match for parameter-output pairs.""" - model = get_test_diff_leaf_model(device=device) + model = get_test_diff_leaf_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -601,10 +586,10 @@ def test_gradients_numerical(self, param_name, output_name, config_type, device) # for parameter `SPAN` param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) numerical_grad = calculate_numerical_grad( - lambda: get_test_diff_leaf_model(device=device), param_name, param, output_name + lambda: get_test_diff_leaf_model(), param_name, param, output_name ) - model = get_test_diff_leaf_model(device=device) + model = get_test_diff_leaf_model() output = model({param_name: param}) loss = output[output_name].sum() diff --git a/tests/physical_models/crop/test_partitioning.py b/tests/physical_models/crop/test_partitioning.py index ae5c6e4..e6d2502 100644 --- a/tests/physical_models/crop/test_partitioning.py +++ b/tests/physical_models/crop/test_partitioning.py @@ -16,7 +16,7 @@ partitioning_config = Configuration(CROP=DVS_Partitioning, OUTPUT_VARS=["FR", "FL", "FS", "FO"]) -def get_test_diff_partitioning(device: str = "cpu"): +def get_test_diff_partitioning(): """Build a small wrapper module for differentiable tests.""" test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -26,14 +26,13 @@ def get_test_diff_partitioning(device: str = "cpu"): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) return DiffPartitioning( copy.deepcopy(crop_model_params_provider), weather_data_provider, agro_management_inputs, partitioning_config, copy.deepcopy(external_states), - device=device, ) @@ -45,7 +44,6 @@ def __init__( agro_management_inputs, config, external_states, - device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider @@ -53,13 +51,10 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states - self.device = device def forward(self, params_dict: dict[str, torch.Tensor]): # pass new value of parameters to the model for name, value in params_dict.items(): - if isinstance(value, torch.Tensor) and value.device.type != self.device: - value = value.to(self.device) self.crop_model_params_provider.set_override(name, value, check=False) engine = EngineTestHelper( @@ -68,7 +63,6 @@ def forward(self, params_dict: dict[str, torch.Tensor]): self.agro_management_inputs, self.config, self.external_states, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -99,7 +93,7 @@ def test_partitioning_with_testengine(self, test_data_url, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) engine = EngineTestHelper( crop_model_params_provider, @@ -107,7 +101,6 @@ def test_partitioning_with_testengine(self, test_data_url, device): agro_management_inputs, partitioning_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -141,7 +134,7 @@ def test_partitioning_with_one_parameter_vector(self, param, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) # AfgenTrait parameters need to have shape (N, M) repeated = crop_model_params_provider[param].repeat(10, 1) @@ -153,7 +146,6 @@ def test_partitioning_with_one_parameter_vector(self, param, device): agro_management_inputs, partitioning_config, external_states, - device=device, ) engine.run_till_terminate() results = engine.get_output() @@ -172,7 +164,7 @@ def test_partitioning_with_different_parameter_values(self, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) # Setting vectors with multiple values for each table parameter for param in ("FRTB", "FLTB", "FSTB", "FOTB"): @@ -188,7 +180,6 @@ def test_partitioning_with_different_parameter_values(self, device): agro_management_inputs, partitioning_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -221,7 +212,6 @@ def test_partitioning_with_multiple_parameter_vectors(self): agro_management_inputs, partitioning_config, external_states, - device="cpu", ) engine.run_till_terminate() actual_results = engine.get_output() @@ -254,7 +244,6 @@ def test_partitioning_with_multiple_parameter_arrays(self): agro_management_inputs, partitioning_config, external_states, - device="cpu", ) engine.run_till_terminate() actual_results = engine.get_output() @@ -296,7 +285,6 @@ def test_partitioning_with_incompatible_parameter_vectors(self): agro_management_inputs, partitioning_config, external_states, - device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls[:1]) @@ -382,7 +370,7 @@ class TestDiffPartitioningGradients: @pytest.mark.parametrize("param_name,output_name", no_gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_no_gradients(self, param_name, output_name, config_type, device): - model = get_test_diff_partitioning(device=device) + model = get_test_diff_partitioning() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -406,7 +394,7 @@ def test_no_gradients(self, param_name, output_name, config_type, device): @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): - model = get_test_diff_partitioning(device=device) + model = get_test_diff_partitioning() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -432,10 +420,10 @@ def test_gradients_numerical(self, param_name, output_name, config_type, device) value, _ = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) numerical_grad = calculate_numerical_grad( - lambda: get_test_diff_partitioning(device=device), param_name, param.data, output_name + lambda: get_test_diff_partitioning(), param_name, param.data, output_name ) - model = get_test_diff_partitioning(device=device) + model = get_test_diff_partitioning() output = model({param_name: param}) loss = output[output_name].sum() diff --git a/tests/physical_models/crop/test_phenology.py b/tests/physical_models/crop/test_phenology.py index 12cd78f..040bf07 100644 --- a/tests/physical_models/crop/test_phenology.py +++ b/tests/physical_models/crop/test_phenology.py @@ -38,7 +38,7 @@ def assert_reference_match(reference, model, expected_precision): assert torch.all(torch.abs(ref_t - model_t) < precision) -def get_test_diff_phenology_model(device: str = "cpu"): +def get_test_diff_phenology_model(): test_data_url = f"{phy_data_folder}/test_phenology_wofost72_01.yaml" test_data = get_test_data(test_data_url) # Phenology-related crop model parameters @@ -60,14 +60,13 @@ def get_test_diff_phenology_model(device: str = "cpu"): "VERNDVS", ] (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( - prepare_engine_input(test_data, crop_model_params, device=device) + prepare_engine_input(test_data, crop_model_params) ) return DiffPhenologyDynamics( copy.deepcopy(crop_model_params_provider), weather_data_provider, agro_management_inputs, phenology_config, - device=device, ) @@ -78,14 +77,12 @@ def __init__( weather_data_provider, agro_management_inputs, config, - device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider self.weather_data_provider = weather_data_provider self.agro_management_inputs = agro_management_inputs self.config = config - self.device = device def forward(self, params_dict): # pass new value of parameters to the model @@ -97,7 +94,6 @@ def forward(self, params_dict): self.weather_data_provider, self.agro_management_inputs, self.config, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -140,14 +136,13 @@ def test_phenology_with_testengine(self, test_data_url, device): weather_data_provider, agro_management_inputs, _, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) engine = EngineTestHelper( crop_model_params_provider, weather_data_provider, agro_management_inputs, phenology_config, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -209,9 +204,7 @@ def test_phenology_with_one_parameter_vector(self, param, device): weather_data_provider, agro_management_inputs, _, - ) = prepare_engine_input( - test_data, crop_model_params, meteo_range_checks=False, device=device - ) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) if param == "TEMP": if device == "cuda": @@ -234,7 +227,6 @@ def test_phenology_with_one_parameter_vector(self, param, device): weather_data_provider, agro_management_inputs, phenology_config, - device=device, ) engine.run_till_terminate() _ = engine.get_output() @@ -244,7 +236,6 @@ def test_phenology_with_one_parameter_vector(self, param, device): weather_data_provider, agro_management_inputs, phenology_config, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -297,7 +288,7 @@ def test_phenology_with_different_parameter_values(self, param, delta, device): weather_data_provider, agro_management_inputs, _, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) test_value = crop_model_params_provider[param] if param == "DTSMTB": @@ -314,7 +305,6 @@ def test_phenology_with_different_parameter_values(self, param, delta, device): weather_data_provider, agro_management_inputs, phenology_config, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -372,7 +362,6 @@ def test_phenology_with_multiple_parameter_vectors(self, device): weather_data_provider, agro_management_inputs, phenology_config, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -438,7 +427,6 @@ def test_phenology_with_multiple_parameter_arrays(self, device): weather_data_provider, agro_management_inputs, phenology_config, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -492,7 +480,6 @@ def test_phenology_with_incompatible_parameter_vectors(self): weather_data_provider, agro_management_inputs, phenology_config, - device="cpu", ) def test_phenology_with_incompatible_weather_parameter_vectors(self): @@ -533,7 +520,6 @@ def test_phenology_with_incompatible_weather_parameter_vectors(self): weather_data_provider, agro_management_inputs, phenology_config, - device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) @@ -556,7 +542,7 @@ def test_wofost_pp_with_phenology(self, test_data_url, monkeypatch): "VERNDVS", ] (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( - prepare_engine_input(test_data, crop_model_params, device="cpu") + prepare_engine_input(test_data, crop_model_params) ) expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] @@ -658,7 +644,7 @@ class TestDiffPhenologyDynamicsGradients: @pytest.mark.parametrize("param_name,output_name", no_gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_no_gradients(self, param_name, output_name, config_type, device): - model = get_test_diff_phenology_model(device=device) + model = get_test_diff_phenology_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -674,7 +660,7 @@ def test_no_gradients(self, param_name, output_name, config_type, device): @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): - model = get_test_diff_phenology_model(device=device) + model = get_test_diff_phenology_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -693,12 +679,12 @@ def test_gradients_numerical(self, param_name, output_name, config_type, device) value, _ = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) numerical_grad = calculate_numerical_grad( - lambda: get_test_diff_phenology_model(device=device), + lambda: get_test_diff_phenology_model(), param_name, param.data, output_name, ) - model = get_test_diff_phenology_model(device=device) + model = get_test_diff_phenology_model() output = model({param_name: param}) loss = output[output_name].sum() grads = torch.autograd.grad(loss, param, retain_graph=True)[0] diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index 69b7c50..9e752c8 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -19,12 +19,12 @@ ) -def get_test_diff_root_model(device: str = "cpu"): +def get_test_diff_root_model(): test_data_url = f"{phy_data_folder}/test_rootdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) crop_model_params = ["RDI", "RRI", "RDMCR", "RDMSOL", "TDWI", "IAIRDU"] (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( - prepare_engine_input(test_data, crop_model_params, device=device) + prepare_engine_input(test_data, crop_model_params) ) return DiffRootDynamics( copy.deepcopy(crop_model_params_provider), @@ -32,7 +32,6 @@ def get_test_diff_root_model(device: str = "cpu"): agro_management_inputs, root_dynamics_config, copy.deepcopy(external_states), - device=device, ) @@ -44,7 +43,6 @@ def __init__( agro_management_inputs, config, external_states, - device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider @@ -52,7 +50,6 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states - self.device = device def forward(self, params_dict): # pass new value of parameters to the model @@ -65,7 +62,6 @@ def forward(self, params_dict): self.agro_management_inputs, self.config, self.external_states, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -103,7 +99,6 @@ def test_root_dynamics_with_testengine(self, test_data_url, device): agro_management_inputs, root_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -151,7 +146,6 @@ def test_root_dynamics_with_one_parameter_vector(self, param, device): agro_management_inputs, root_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -212,7 +206,6 @@ def test_root_dynamics_with_different_parameter_values(self, param, delta, devic agro_management_inputs, root_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -261,7 +254,6 @@ def test_root_dynamics_with_multiple_parameter_vectors(self, device): agro_management_inputs, root_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -304,7 +296,6 @@ def test_root_dynamics_with_multiple_parameter_arrays(self, device): agro_management_inputs, root_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -352,7 +343,6 @@ def test_root_dynamics_with_incompatible_parameter_vectors(self, device): agro_management_inputs, root_dynamics_config, external_states, - device=device, ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) @@ -443,7 +433,7 @@ class TestDiffRootDynamicsGradients: @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_no_gradients(self, param_name, output_name, config_type, device): """Test cases where parameters should not have gradients for specific outputs.""" - model = get_test_diff_root_model(device=device) + model = get_test_diff_root_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -476,7 +466,7 @@ def test_no_gradients(self, param_name, output_name, config_type, device): @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): """Test that forward and backward gradients match for parameter-output pairs.""" - model = get_test_diff_root_model(device=device) + model = get_test_diff_root_model() value, dtype = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) @@ -506,10 +496,10 @@ def test_gradients_numerical(self, param_name, output_name, config_type, device) value, _ = self.param_configs[config_type][param_name] param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) numerical_grad = calculate_numerical_grad( - lambda: get_test_diff_root_model(device=device), param_name, param.data, output_name + lambda: get_test_diff_root_model(), param_name, param.data, output_name ) - model = get_test_diff_root_model(device=device) + model = get_test_diff_root_model() output = model({param_name: param}) loss = output[output_name].sum() diff --git a/tests/physical_models/test_config.py b/tests/physical_models/test_config.py index 397d1b4..fc19f74 100644 --- a/tests/physical_models/test_config.py +++ b/tests/physical_models/test_config.py @@ -59,15 +59,15 @@ def test_output_variables_can_be_updated(self): class TestComputeConfig: - def test_default_device_is_cuda_or_cpu(self): + def test_default_device(self): ComputeConfig.reset_to_defaults() device = ComputeConfig.get_device() - assert device.type in ["cpu", "cuda"] + assert device == torch.get_default_device() - def test_default_dtype_is_float64(self): + def test_default_dtype(self): ComputeConfig.reset_to_defaults() dtype = ComputeConfig.get_dtype() - assert dtype == torch.float64 + assert dtype == torch.get_default_dtype() def test_set_device_with_string(self): ComputeConfig.set_device("cpu") @@ -92,5 +92,71 @@ def test_reset_to_defaults(self): device = ComputeConfig.get_device() dtype = ComputeConfig.get_dtype() - assert device.type in ["cpu", "cuda"] - assert dtype == torch.float64 + assert device == torch.get_default_device() + assert dtype == torch.get_default_dtype() + + def test_models_capture_config_at_initialization(self): + """Test that models capture the device/dtype at initialization time.""" + import datetime + from pcse.base.variablekiosk import VariableKiosk + from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics + + # Setup mocks + day = datetime.date(2000, 1, 1) + + class MockKiosk(VariableKiosk): + pass + + mock_kiosk = MockKiosk() + mock_kiosk.update( + { + "FL": 0.5, + "FR": 0.5, + "DVS": 0.5, + "SAI": 0.5, + "PAI": 0.5, + "ADMI": 0.5, + "RFTRA": 1.0, + "RF_FROST": 1.0, + } + ) + + mock_parvalues = { + "RGRLAI": torch.tensor(0.01), + "SPAN": torch.tensor(30.0), + "TBASE": torch.tensor(5.0), + "PERDL": torch.tensor(0.05), + "TDWI": torch.tensor(50.0), + "SLATB": [0.0, 20.0, 2.0, 20.0], + "KDIFTB": [0.0, 0.6, 2.0, 0.6], + } + + # 1. Config = float32 + ComputeConfig.set_dtype(torch.float32) + model1 = WOFOST_Leaf_Dynamics(day, mock_kiosk, mock_parvalues) + + # 2. Config = float64 + ComputeConfig.set_dtype(torch.float64) + mock_kiosk2 = MockKiosk() + mock_kiosk2.update( + { + "FL": 0.5, + "FR": 0.5, + "DVS": 0.5, + "SAI": 0.5, + "PAI": 0.5, + "ADMI": 0.5, + "RFTRA": 1.0, + "RF_FROST": 1.0, + } + ) + model2 = WOFOST_Leaf_Dynamics(day, mock_kiosk2, mock_parvalues) + + # 3. Assertions + assert model1.dtype == torch.float32, "Model 1 should retain float32" + assert model2.dtype == torch.float64, "Model 2 should use float64" + assert model1.states.LV[0].dtype == torch.float32, "Model 1 states should be float32" + assert model2.states.LV[0].dtype == torch.float64, "Model 2 states should be float64" + + # Cleanup + ComputeConfig.reset_to_defaults() diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index 805dc21..29dcf2d 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -10,6 +10,7 @@ from diffwofost.physical_models.utils import get_test_data from . import phy_data_folder +ComputeConfig.set_dtype(torch.float64) DTYPE = ComputeConfig.get_dtype() @@ -189,6 +190,9 @@ class TestAfgenTrait: def test_default_value(self): """Test that the default value is set correctly.""" + # Ensure default_value matches current config + AfgenTrait.default_value = Afgen([0, 0, 1, 1]) + trait = AfgenTrait() assert isinstance(trait.default_value, Afgen) @@ -277,7 +281,7 @@ def test_x_breakpoint_at_clamp(self): # Keep this example deterministic across environments. old_device = ComputeConfig.get_device() - old_dtype = ComputeConfig.get_dtype() + old_dtype = DTYPE ComputeConfig.set_device("cpu") ComputeConfig.set_dtype(torch.float64) try: