diff --git a/docs/api_reference.md b/docs/api_reference.md index 4eff37d..7575569 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -16,6 +16,8 @@ hide: ::: diffwofost.physical_models.crop.root_dynamics.WOFOST_Root_Dynamics +::: diffwofost.physical_models.crop.storage_organ_dynamics.WOFOST_Storage_Organ_Dynamics + ## **Utility (under development)** diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py new file mode 100644 index 0000000..4ef523e --- /dev/null +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -0,0 +1,233 @@ +import datetime +import torch +from pcse.base import ParamTemplate +from pcse.base import RatesTemplate +from pcse.base import SimulationObject +from pcse.base import StatesTemplate +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.base.weather import WeatherDataContainer +from pcse.decorators import prepare_rates +from pcse.decorators import prepare_states +from pcse.traitlets import Any +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_params_shape + + +class WOFOST_Storage_Organ_Dynamics(SimulationObject): + """Implementation of storage organ dynamics. + + Storage organs are the most simple component of the plant in WOFOST and + consist of a static pool of biomass. Growth of the storage organs is the + result of assimilate partitioning. Death of storage organs is not + implemented and the corresponding rate variable (DRSO) is always set to + zero. + + Pods are green elements of the plant canopy and can as such contribute + to the total photosynthetic active area. This is expressed as the Pod + Area Index which is obtained by multiplying pod biomass with a fixed + Specific Pod Area (SPA). + + **Simulation parameters** + + | Name | Description | Type | Unit | + |------|===============================================|========|=============| + | TDWI | Initial total crop dry weight | SCr | kg ha⁻¹ | + | SPA | Specific Pod Area | SCr | ha kg⁻¹ | + + **State variables** + + | Name | Description | Pbl | Unit | + |------|==================================================|======|=============| + | PAI | Pod Area Index | Y | - | + | WSO | Weight of living storage organs | Y | kg ha⁻¹ | + | DWSO | Weight of dead storage organs | N | kg ha⁻¹ | + | TWSO | Total weight of storage organs | Y | kg ha⁻¹ | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |------|==================================================|======|=============| + | GRSO | Growth rate storage organs | N | kg ha⁻¹ d⁻¹ | + | DRSO | Death rate storage organs | N | kg ha⁻¹ d⁻¹ | + | GWSO | Net change in storage organ biomass | N | kg ha⁻¹ d⁻¹ | + + **Signals send or handled** + + None + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|====================================|=====================|=============| + | ADMI | Above-ground dry matter increase | CropSimulation | kg ha⁻¹ d⁻¹ | + | FO | Fraction biomass to storage organs | DVS_Partitioning | - | + | FR | Fraction biomass to roots | DVS_Partitioning | - | + + **Outputs:** + + | Name | Description | Provided by | Unit | + |------|------------------------------|-------------|--------------| + | PAI | Pod Area Index | Y | - | + | TWSO | Total weight storage organs | Y | kg ha⁻¹ | + | WSO | Weight living storage organs | Y | kg ha⁻¹ | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|----------------------------| + | PAI | SPA | + | TWSO | TDWI | + | WSO | TDWI | + """ + + params_shape = None # Shape of the parameters tensors + + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + + class Parameters(ParamTemplate): + SPA = Any() + TDWI = Any() + + def __init__(self, parvalues): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.SPA = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.TDWI = [torch.tensor(-99.0, dtype=dtype, device=device)] + + # Call parent init + super().__init__(parvalues) + + class StateVariables(StatesTemplate): + WSO = Any() # Weight living storage organs + DWSO = Any() # Weight dead storage organs + TWSO = Any() # Total weight storage organs + PAI = Any() # Pod Area Index + + def __init__(self, kiosk, publish=None, **kwargs): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + if "WSO" not in kwargs: + self.WSO = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "DWSO" not in kwargs: + self.DWSO = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "TWSO" not in kwargs: + self.TWSO = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "PAI" not in kwargs: + self.PAI = [torch.tensor(-99.0, dtype=dtype, device=device)] + + # Call parent init + super().__init__(kiosk, publish=publish, **kwargs) + + class RateVariables(RatesTemplate): + GRSO = Any() + DRSO = Any() + GWSO = Any() + + def __init__(self, kiosk, publish=None): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.GRSO = torch.tensor(0.0, dtype=dtype, device=device) + self.DRSO = torch.tensor(0.0, dtype=dtype, device=device) + self.GWSO = torch.tensor(0.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish) + + def initialize( + self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + ) -> None: + """Initialize the storage organ dynamics model. + + :param day: start date of the simulation + :param kiosk: variable kiosk of this PCSE instance + :param parvalues: `ParameterProvider` object providing parameters as + key/value pairs + """ + self.kiosk = kiosk + self.params = self.Parameters(parvalues) + self.rates = self.RateVariables(kiosk, publish=["GRSO"]) + + # INITIAL STATES + params = self.params + self.params_shape = _get_params_shape(params) + shape = self.params_shape + + # Initial storage organ biomass + TDWI = _broadcast_to(params.TDWI, shape, dtype=self.dtype, device=self.device) + SPA = _broadcast_to(params.SPA, shape, dtype=self.dtype, device=self.device) + FO = _broadcast_to(self.kiosk["FO"], shape, dtype=self.dtype, device=self.device) + FR = _broadcast_to(self.kiosk["FR"], shape, dtype=self.dtype, device=self.device) + + WSO = (TDWI * (1 - FR)) * FO + DWSO = torch.zeros(shape, dtype=self.dtype, device=self.device) + TWSO = WSO + DWSO + # Initial Pod Area Index + PAI = WSO * SPA + + self.states = self.StateVariables( + kiosk, publish=["TWSO", "WSO", "PAI"], WSO=WSO, DWSO=DWSO, TWSO=TWSO, PAI=PAI + ) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None) -> None: + """Calculate the rates of change of the state variables. + + Args: + day (datetime.date, optional): The current date of the simulation. + drv (WeatherDataContainer, optional): A dictionary-like container holding + weather data elements as key/value. + """ + rates = self.rates + k = self.kiosk + + FO = _broadcast_to(k["FO"], self.params_shape, dtype=self.dtype, device=self.device) + ADMI = _broadcast_to(k["ADMI"], self.params_shape, dtype=self.dtype, device=self.device) + REALLOC_SO = _broadcast_to( + k.get("REALLOC_SO", 0.0), self.params_shape, dtype=self.dtype, device=self.device + ) + + # Growth/death rate organs + rates.GRSO = ADMI * FO + rates.DRSO = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + rates.GWSO = rates.GRSO - rates.DRSO + REALLOC_SO + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Integrate the state variables. + + Args: + day (datetime.date, optional): The current date of the simulation. + delt (float, optional): The time step for integration. Defaults to 1.0. + """ + params = self.params + rates = self.rates + states = self.states + + SPA = _broadcast_to(params.SPA, self.params_shape, dtype=self.dtype, device=self.device) + + # Stem biomass (living, dead, total) + states.WSO = states.WSO + rates.GWSO + states.DWSO = states.DWSO + rates.DRSO + states.TWSO = states.WSO + states.DWSO + + # Calculate Pod Area Index (SAI) + states.PAI = states.WSO * SPA diff --git a/tests/physical_models/crop/test_storage_organ_dynamics.py b/tests/physical_models/crop/test_storage_organ_dynamics.py new file mode 100644 index 0000000..4a20334 --- /dev/null +++ b/tests/physical_models/crop/test_storage_organ_dynamics.py @@ -0,0 +1,572 @@ +import copy +import warnings +from unittest.mock import patch +import pytest +import torch +from pcse.models import Wofost72_PP +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.storage_organ_dynamics import WOFOST_Storage_Organ_Dynamics +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +storage_dynamics_config = Configuration( + CROP=WOFOST_Storage_Organ_Dynamics, + OUTPUT_VARS=["PAI", "TWSO", "WSO", "DWSO"], +) + +# [!] Notice that the storage organ module does not have dedicated test data. +# This means that we can only test the execution of the module, +# but not the correctness of its results (except when used within Wofost72_PP). + + +def _prepare_common_storage_inputs(test_data_url, device, meteo_range_checks=True): + # prepare model input + test_data = get_test_data(test_data_url) + crop_model_params = ["TDWI", "SPA"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=meteo_range_checks, device=device + ) + + # Patch missing states + for state in external_states: + if "FO" not in state: + state["FO"] = 0.5 + if "FR" not in state: + state["FR"] = 0.5 + if "ADMI" not in state: + state["ADMI"] = 100.0 + # DVS is unused in storage organ dynamics but good to have if something changes + if "DVS" not in state: + state["DVS"] = 0.0 + + # Patch missing parameters + if "SPA" not in crop_model_params_provider: + crop_model_params_provider.set_override( + "SPA", + torch.tensor(0.01, dtype=torch.float64, device=device), + check=False, + ) + if "TDWI" not in crop_model_params_provider: + crop_model_params_provider.set_override( + "TDWI", torch.tensor(20.0, dtype=torch.float64, device=device), check=False + ) + + return ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) + + +def get_test_diff_storage_model(device: str = "cpu"): + # [!] The storage organ module does not have dedicated test data. + # We reuse the partitioning test data as they contain relevant parameters and states. + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + + ( + _, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + return DiffStorageDynamics( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + copy.deepcopy(external_states), + device=device, + ) + + +class DiffStorageDynamics(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config, + external_states, + device: str = "cpu", + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config = config + self.external_states = external_states + self.device = device + + def forward(self, params_dict): + # pass new value of parameters to the model + for name, value in params_dict.items(): + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config, + self.external_states, + device=self.device, + ) + engine.run_till_terminate() + results = engine.get_output() + + return { + var: torch.stack([item[var] for item in results]) + for var in ["PAI", "TWSO", "WSO", "DWSO"] + } + + +class TestStorageOrganDynamics: + # [!] The storage module does not have dedicated test data. + # We reuse the partitioning test data as they contain relevant parameters and states. + storage_dynamics_data_urls = [ + f"{phy_data_folder}/test_partitioning_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" + for i in range(1, 45) # there are 44 test files + ] + + @pytest.mark.parametrize("test_data_url", storage_dynamics_data_urls) + def test_storage_dynamics_with_testengine(self, test_data_url, device): + """EngineTestHelper and not Engine because it allows to specify `external_states`.""" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + @pytest.mark.parametrize("param", ["TDWI", "SPA", "TEMP"]) + def test_storage_dynamics_with_one_parameter_vector(self, param, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device, meteo_range_checks=False) + + # Setting a vector (with one value) for the selected parameter + if param == "TEMP": + # Vectorize weather variable + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + else: + # Broadcast all parameters to match the batch size of 10 + for p_name in ["TDWI", "SPA"]: + if p_name in crop_model_params_provider: + p_val = crop_model_params_provider[p_name] + if p_val.dim() == 0: # scalar + crop_model_params_provider.set_override( + p_name, p_val.repeat(10), check=False + ) + elif p_val.dim() == 2 and p_val.shape[0] == 1: # table (1, M) -> (10, M) + crop_model_params_provider.set_override( + p_name, p_val.repeat(10, 1), check=False + ) + + if param == "TEMP": + # Vectorize weather variable + # We expect the model to handle scalar parameters with vectorized weather + # via implicit broadcasting or explicit checks passing. + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + else: + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + @pytest.mark.parametrize( + "param,delta", + [ + ("TDWI", 0.1), + ("SPA", 0.0001), + ], + ) + def test_storage_dynamics_with_different_parameter_values(self, param, delta, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + # Setting a vector with multiple values for the selected parameter + test_value = crop_model_params_provider[param] + + param_vec = torch.tensor( + [test_value - delta, test_value + delta, test_value], + device=device, + dtype=torch.float64, + ) + target_batch_size = 3 + crop_model_params_provider.set_override(param, param_vec, check=False) + + # Broadcast all other params + for p_name in ["TDWI", "SPA"]: + if p_name == param: + continue + if p_name not in crop_model_params_provider: + continue + + p_val = crop_model_params_provider[p_name] + if p_val.dim() == 0: + crop_model_params_provider.set_override( + p_name, p_val.repeat(target_batch_size), check=False + ) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + def test_storage_dynamics_with_multiple_parameter_vectors(self, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + # Setting a vector (with one value) for the TDWI and SPA parameters + for param in ("TDWI", "SPA"): + if param == "SPA" and crop_model_params_provider[param].dim() == 2: + # In case SPA is treated as table somehow, though here it is scalar + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + def test_storage_dynamics_with_multiple_parameter_arrays(self, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device, meteo_range_checks=False) + + # Setting an array with arbitrary shape (and one value) + for param in ("TDWI", "SPA"): + repeated = crop_model_params_provider[param].broadcast_to((30, 5)) + crop_model_params_provider.set_override(param, repeated, check=False) + + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + def test_storage_dynamics_with_incompatible_parameter_vectors(self): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + _, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device="cpu") + + # Setting a vector (with one value) for the TDWI and SPA parameters, + # but with different lengths + crop_model_params_provider.set_override( + "TDWI", crop_model_params_provider["TDWI"].repeat(10), check=False + ) + crop_model_params_provider.set_override( + "SPA", crop_model_params_provider["SPA"].repeat(5), check=False + ) + + with pytest.raises(AssertionError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device="cpu", + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) + def test_wofost_pp_with_storage_dynamics(self, test_data_url): + # prepare model input + test_data = get_test_data(test_data_url) + crop_model_params = ["TDWI", "SPA"] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params) + ) + + # get expected results from YAML test data + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.Storage_Organ_Dynamics", WOFOST_Storage_Organ_Dynamics): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + abs(reference[var] - model[var]) < precision + for var, precision in expected_precision.items() + ) + + +class TestDiffStorageDynamicsGradients: + """Parametrized tests for gradient calculations in storage organ dynamics.""" + + # Define parameters and outputs + param_names = ["TDWI", "SPA"] + output_names = ["PAI", "TWSO", "WSO"] + + # Define parameter configurations (value, dtype) + param_configs = { + "single": { + "TDWI": (0.2, torch.float64), + "SPA": (0.01, torch.float64), + }, + "tensor": { + "TDWI": ([0.1, 0.2, 0.3], torch.float64), + "SPA": ([0.01, 0.02, 0.03], torch.float64), + }, + } + + # Define which parameter-output pairs should have gradients + # Format: {param_name: [list of outputs that should have gradients]} + gradient_mapping = { + "TDWI": ["PAI", "TWSO", "WSO", "DWSO"], + "SPA": ["PAI"], + } + + # Generate all combinations + gradient_params = [] + no_gradient_params = [] + for param_name in param_names: + for output_name in output_names: + if output_name in gradient_mapping.get(param_name, []): + gradient_params.append((param_name, output_name)) + else: + no_gradient_params.append((param_name, output_name)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type, device): + """Test cases where parameters should not have gradients for specific outputs.""" + model = get_test_diff_storage_model(device=device) + + if config_type == "tensor": + for p_name, (p_val, p_dtype) in self.param_configs["tensor"].items(): + if p_name != param_name: + model.crop_model_params_provider.set_override( + p_name, torch.tensor(p_val, dtype=p_dtype, device=device), check=False + ) + + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output[output_name].sum() + + if not loss.requires_grad: + return + + try: + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + except RuntimeError as e: + if "does not require grad" in str(e): + return + raise e + + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): + """Test that forward and backward gradients match for parameter-output pairs.""" + model = get_test_diff_storage_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + + overrides = {param_name: param} + if config_type == "tensor": + for p_name, (p_val, p_dtype) in self.param_configs["tensor"].items(): + if p_name != param_name: + overrides[p_name] = torch.tensor(p_val, dtype=p_dtype, device=device) + + output = model(overrides) + loss = output[output_name].sum() + + # this is ∂loss/∂param + # this is called forward gradient here because it is calculated without backpropagation. + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + assert grads is not None, f"Gradients for {param_name} should not be None" + + param.grad = None # clear any existing gradient + loss.backward() + + # this is ∂loss/∂param calculated using backpropagation + grad_backward = param.grad + + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), ( + f"Forward and backward gradients for {param_name} should match" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type, device): + """Test that analytical gradients match numerical gradients.""" + value, _ = self.param_configs[config_type][param_name] + + # we pass `param` and not `param.data` because we want `requires_grad=True` + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + + def model_factory(): + m = get_test_diff_storage_model(device=device) + if config_type == "tensor": + for p_name, (p_val, p_dtype) in self.param_configs["tensor"].items(): + if p_name != param_name: + m.crop_model_params_provider.set_override( + p_name, torch.tensor(p_val, dtype=p_dtype, device=device), check=False + ) + return m + + numerical_grad = calculate_numerical_grad(model_factory, param_name, param, output_name) + + model = model_factory() + output = model({param_name: param}) + loss = output[output_name].sum() + + # this is ∂loss/∂param, for comparison with numerical gradient + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + torch.testing.assert_close( + numerical_grad.detach().cpu(), + grads.detach().cpu(), + rtol=1e-3, + atol=1e-3, + ) + + # Warn if gradient is zero (but this shouldn't happen for gradient_params) + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' with respect to output " + + f"'{output_name}' is zero: {grads.data}", + UserWarning, + )