diff --git a/docs/api_reference.md b/docs/api_reference.md index ed8303e..6381e93 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -16,6 +16,8 @@ hide: ::: diffwofost.physical_models.crop.partitioning.DVS_Partitioning +::: diffwofost.physical_models.crop.respiration.WOFOST_Maintenance_Respiration + ## **Utility (under development)** ::: diffwofost.physical_models.config.Configuration diff --git a/src/diffwofost/physical_models/crop/respiration.py b/src/diffwofost/physical_models/crop/respiration.py new file mode 100644 index 0000000..77f7fe1 --- /dev/null +++ b/src/diffwofost/physical_models/crop/respiration.py @@ -0,0 +1,160 @@ +"""Maintenance respiration for the WOFOST crop model.""" + +import datetime +import torch +from pcse.base import ParamTemplate +from pcse.base import RatesTemplate +from pcse.base import SimulationObject +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.base.weather import WeatherDataContainer +from pcse.decorators import prepare_rates +from pcse.traitlets import Any +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_drv +from diffwofost.physical_models.utils import _get_params_shape + + +class WOFOST_Maintenance_Respiration(SimulationObject): + """Maintenance respiration in WOFOST. + + WOFOST calculates the maintenance respiration as proportional to the dry + weights of the plant organs to be maintained, where each plant organ can be + assigned a different maintenance coefficient. Multiplying organ weight + with the maintenance coeffients yields the relative maintenance respiration + (`RMRES`) which is than corrected for senescence (parameter `RFSETB`). Finally, + the actual maintenance respiration rate is calculated using the daily mean + temperature, assuming a relative increase for each 10 degrees increase + in temperature as defined by `Q10`. + + **Simulation parameters** (provide in cropdata dictionary) + + | Name | Description | Type | Unit | + |--------|---------------------------------------------------------- |------|------------------| + | Q10 | Relative increase in maintenance respiration rate with | SCr | - | + | | each 10 degrees increase in temperature | | - | + | RMR | Relative maintenance respiration rate for roots | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RMS | Relative maintenance respiration rate for stems | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RML | Relative maintenance respiration rate for leaves | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RMO | Relative maintenance respiration rate for storage organs | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RFSETB | Reduction factor for senescence | TCr | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|--------------------------------------------|----|-------------------| + | PMRES | Potential maintenance respiration rate | N | kg CH₂O ha⁻¹ d⁻¹ | + + **Signals send or handled** + + None + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|-------------------------------------|--------------------------------|-----------| + | DVS | Crop development stage | DVS_Phenology | - | + | WRT | Dry weight of living roots | WOFOST_Root_Dynamics | kg ha⁻¹ | + | WST | Dry weight of living stems | WOFOST_Stem_Dynamics | kg ha⁻¹ | + | WLV | Dry weight of living leaves | WOFOST_Leaf_Dynamics | kg ha⁻¹ | + | WSO | Dry weight of living storage organs | WOFOST_Storage_Organ_Dynamics | kg ha⁻¹ | + + **Outputs** + + | Name | Description | Pbl | Unit | + |-------|--------------------------------------------|----|---------------------| + | PMRES | Potential maintenance respiration rate | N | kg CH₂O ha⁻¹ d⁻¹ | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|------------------------------------------| + | PMRES | Q10, RMR, RML, RMS, RMO, RFSETB | + """ + + params_shape = None # Shape of the parameters tensors + + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + + class Parameters(ParamTemplate): + Q10 = Any() + RMR = Any() + RML = Any() + RMS = Any() + RMO = Any() + RFSETB = AfgenTrait() + + class RateVariables(RatesTemplate): + PMRES = Any() + + def __init__(self, kiosk, publish=None): + self.PMRES = torch.tensor( + 0.0, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device() + ) + super().__init__(kiosk, publish=publish) + + def initialize(self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider): + """Initialize the maintenance respiration module. + + Args: + day: Start date of the simulation + kiosk: Variable kiosk of this PCSE instance + parvalues: ParameterProvider object providing parameters as key/value pairs + """ + self.params = self.Parameters(parvalues) + self.rates = self.RateVariables(kiosk, publish=["PMRES"]) + self.kiosk = kiosk + self.params_shape = _get_params_shape(self.params) + + @prepare_rates + def calc_rates(self, day: datetime.date, drv: WeatherDataContainer): + """Calculate maintenance respiration rates. + + Args: + day: Current date + drv: Weather data for the current day + """ + p = self.params + kk = self.kiosk + r = self.rates + + Q10 = _broadcast_to(p.Q10, self.params_shape, dtype=self.dtype, device=self.device) + RMR = _broadcast_to(p.RMR, self.params_shape, dtype=self.dtype, device=self.device) + RML = _broadcast_to(p.RML, self.params_shape, dtype=self.dtype, device=self.device) + RMS = _broadcast_to(p.RMS, self.params_shape, dtype=self.dtype, device=self.device) + RMO = _broadcast_to(p.RMO, self.params_shape, dtype=self.dtype, device=self.device) + + WRT = _broadcast_to(kk["WRT"], self.params_shape, dtype=self.dtype, device=self.device) + WLV = _broadcast_to(kk["WLV"], self.params_shape, dtype=self.dtype, device=self.device) + WST = _broadcast_to(kk["WST"], self.params_shape, dtype=self.dtype, device=self.device) + WSO = _broadcast_to(kk["WSO"], self.params_shape, dtype=self.dtype, device=self.device) + DVS = _broadcast_to(kk["DVS"], self.params_shape, dtype=self.dtype, device=self.device) + + TEMP = _get_drv(drv.TEMP, self.params_shape, dtype=self.dtype, device=self.device) + + RMRES = RMR * WRT + RML * WLV + RMS * WST + RMO * WSO + RMRES = RMRES * p.RFSETB(DVS) + TEFF = Q10 ** ((TEMP - 25.0) / 10.0) + PMRES = RMRES * TEFF + + # No maintenance respiration before emergence (DVS < 0). + r.PMRES = torch.where(DVS < 0, torch.zeros_like(PMRES), PMRES) + + def __call__(self, day: datetime.date, drv: WeatherDataContainer): + """Calculate and return maintenance respiration (PMRES).""" + self.calc_rates(day, drv) + return self.rates.PMRES + + def integrate(self, day: datetime.date, delt: float = 1.0): + """No state variables to integrate for this module.""" + return diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index 862c62e..9f830b6 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -14,6 +14,7 @@ "phenology", "partitioning", "assimilation", + "respiration", ] FILE_NAMES = [ f"test_{model_name}_wofost72_{i:02d}.yaml" for model_name in model_names for i in range(1, 45) diff --git a/tests/physical_models/crop/test_respiration.py b/tests/physical_models/crop/test_respiration.py new file mode 100644 index 0000000..4345929 --- /dev/null +++ b/tests/physical_models/crop/test_respiration.py @@ -0,0 +1,466 @@ +import copy +import warnings +from unittest.mock import patch +import pytest +import torch +from pcse.models import Wofost72_PP +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.respiration import WOFOST_Maintenance_Respiration +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +respiration_config = Configuration( + CROP=WOFOST_Maintenance_Respiration, + OUTPUT_VARS=["PMRES"], +) + + +def get_test_diff_respiration_model(device: str = "cpu"): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + return DiffRespiration( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + respiration_config, + copy.deepcopy(external_states), + device=device, + ) + + +class DiffRespiration(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config, + external_states, + device: str = "cpu", + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config = config + self.external_states = external_states + self.device = device + + def forward(self, params_dict): + for name, value in params_dict.items(): + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config, + self.external_states, + device=self.device, + ) + engine.run_till_terminate() + results = engine.get_output() + + return {"PMRES": torch.stack([item["PMRES"] for item in results])} + + +class TestRespiration: + respiration_data_urls = [ + f"{phy_data_folder}/test_respiration_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" + for i in range(1, 45) # there are 44 test files + ] + + @pytest.mark.parametrize("test_data_url", respiration_data_urls) + def test_respiration_with_testengine(self, test_data_url, device): + """EngineTestHelper (not Engine) allows forcing `external_states` from YAML.""" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + abs(reference[var] - model_cpu[var]) < precision + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize("param", ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB", "TEMP"]) + def test_respiration_with_one_parameter_vector(self, param, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + if param == "TEMP": + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + with pytest.raises(ValueError): + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + _ = engine.get_output() + return + + if param == "RFSETB": + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + all(abs(reference[var] - model_cpu[var]) < precision) + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize( + "param,delta", + [ + ("Q10", 0.2), + ("RMR", 0.002), + ("RML", 0.002), + ("RMS", 0.002), + ("RMO", 0.002), + ], + ) + def test_respiration_with_different_parameter_values(self, param, delta, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + test_value = crop_model_params_provider[param] + param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) + crop_model_params_provider.set_override(param, param_vec, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + abs(reference[var] - model_cpu[var][-1]) < precision + for var, precision in expected_precision.items() + ) + + def test_respiration_with_multiple_parameter_vectors(self, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + for param in ("Q10", "RMR", "RML", "RMS", "RMO"): + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + crop_model_params_provider.set_override( + "RFSETB", crop_model_params_provider["RFSETB"].repeat(10, 1), check=False + ) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + + def test_respiration_with_multiple_parameter_arrays(self, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + for param in ("Q10", "RMR", "RML", "RMS", "RMO"): + repeated = crop_model_params_provider[param].broadcast_to((30, 5)) + crop_model_params_provider.set_override(param, repeated, check=False) + crop_model_params_provider.set_override( + "RFSETB", crop_model_params_provider["RFSETB"].repeat(30, 5, 1), check=False + ) + + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + torch.all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + assert all(model[var].shape == (30, 5) for var in expected_precision.keys()) + + def test_respiration_with_incompatible_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + + crop_model_params_provider.set_override( + "RMR", crop_model_params_provider["RMR"].repeat(10), check=False + ) + crop_model_params_provider.set_override( + "RML", crop_model_params_provider["RML"].repeat(5), check=False + ) + + with pytest.raises(AssertionError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device="cpu", + ) + + def test_respiration_with_incompatible_weather_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + + crop_model_params_provider.set_override( + "RMR", crop_model_params_provider["RMR"].repeat(10), check=False + ) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(5, dtype=torch.float64) * wdc.TEMP + + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device="cpu", + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) + def test_wofost_pp_with_leaf_dynamics(self, test_data_url): + # prepare model input + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params) + ) + + # get expected results from YAML test data + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.MaintenanceRespiration", WOFOST_Maintenance_Respiration): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + abs(reference[var] - model[var]) < precision + for var, precision in expected_precision.items() + ) + + +class TestDiffRespirationGradients: + """Parametrized tests for gradient calculations in maintenance respiration.""" + + param_configs = { + "single": { + "Q10": (2.0, torch.float64), + "RMR": (0.015, torch.float64), + "RML": (0.03, torch.float64), + "RMS": (0.02, torch.float64), + "RMO": (0.01, torch.float64), + }, + "tensor": { + "Q10": ([1.5, 2.0, 2.5], torch.float64), + "RMR": ([0.01, 0.015, 0.02], torch.float64), + "RML": ([0.02, 0.03, 0.04], torch.float64), + "RMS": ([0.01, 0.02, 0.03], torch.float64), + "RMO": ([0.005, 0.01, 0.02], torch.float64), + }, + } + + @pytest.mark.parametrize("param_name", ["Q10", "RMR", "RML", "RMS", "RMO"]) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, config_type, device): + model = get_test_diff_respiration_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output["PMRES"].sum() + + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + assert grads is not None, f"Gradients for {param_name} should not be None" + + param.grad = None + loss.backward() + grad_backward = param.grad + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), ( + f"Forward and backward gradients for {param_name} should match" + ) + + @pytest.mark.parametrize("param_name", ["Q10", "RMR", "RML", "RMS", "RMO"]) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, config_type, device): + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + + numerical_grad = calculate_numerical_grad( + lambda: get_test_diff_respiration_model(device=device), + param_name, + param, + "PMRES", + ) + + model = get_test_diff_respiration_model(device=device) + output = model({param_name: param}) + loss = output["PMRES"].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + torch.testing.assert_close( + numerical_grad.detach().cpu(), + grads.detach().cpu(), + rtol=1e-3, + atol=1e-3, + ) + + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' with" + + f"respect to output 'PMRES' is zero: {grads.data}", + UserWarning, + )