diff --git a/src/diffwofost/physical_models/base/__init__.py b/src/diffwofost/physical_models/base/__init__.py new file mode 100644 index 0000000..615e40c --- /dev/null +++ b/src/diffwofost/physical_models/base/__init__.py @@ -0,0 +1,5 @@ +from .states_rates import TensorParamTemplate +from .states_rates import TensorRatesTemplate +from .states_rates import TensorStatesTemplate + +__all__ = ["TensorParamTemplate", "TensorRatesTemplate", "TensorStatesTemplate"] diff --git a/src/diffwofost/physical_models/base/states_rates.py b/src/diffwofost/physical_models/base/states_rates.py new file mode 100644 index 0000000..e81e4ee --- /dev/null +++ b/src/diffwofost/physical_models/base/states_rates.py @@ -0,0 +1,109 @@ +from pcse.base import ParamTemplate +from pcse.base import RatesTemplate +from pcse.base import StatesTemplate +from pcse.traitlets import HasTraits +from ..traitlets import Tensor +from ..utils import AfgenTrait + + +class TensorContainer(HasTraits): + def __init__(self, shape=None, do_not_broadcast=None, **variables): + """Container of tensor variables. + + It includes functionality to broadcast variables to a common shape. This common shape can + be inferred from the container's tensor and AFGEN variables, or it can be set as an input + argument. + + Args: + shape (tuple | torch.Size, optional): Shape to which the variables in the container + are broadcasted. If given, it should match the shape of all the input variables that + already have dimensions. Defaults to None. + do_not_broadcast (list, optional): Name of the variables that are not broadcasted + to the container shape. Defaults to None, which means that all variables are + broadcasted. + variables (dict): Collection of variables to initialize the container, as key-value + pairs. + """ + self._shape = () + self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast + HasTraits.__init__(self, **variables) + self._broadcast(shape) + + def _broadcast(self, shape=None): + # Identify which variables should be broadcasted. Also check that the input shape is + # compatible with the existing variable shapes + vars_to_broadcast = self._get_vars_to_broadcast() + vars_shape = self._get_vars_shape() + if shape and vars_shape and vars_shape != shape: + raise ValueError(f"Input shape {shape} does not match variable shape {vars_shape}") + shape = tuple(shape or vars_shape) + + # Broadcast all required variables to the identified shape. + for varname, var in vars_to_broadcast.items(): + try: + broadcasted = var.expand(shape) + except RuntimeError as error: + raise ValueError(f"Cannot broadcast {varname} to shape {shape}") from error + setattr(self, varname, broadcasted) + + # Finally, update the shape of the container + self.shape = shape + + def _get_vars_to_broadcast(self): + vars = {} + for varname, trait in self.traits().items(): + if varname not in self._do_not_broadcast: + if isinstance(trait, Tensor): + vars[varname] = getattr(self, varname) + return vars + + def _get_vars_shape(self): + shape = () + for varname, trait in self.traits().items(): + if varname not in self._do_not_broadcast: + if isinstance(trait, Tensor) or isinstance(trait, AfgenTrait): + var = getattr(self, varname) + if not var.shape or shape == var.shape: + continue + elif var.shape and not shape: + shape = tuple(var.shape) + else: + raise ValueError( + f"Incompatible shapes within variables: {shape} and {var.shape}" + ) + return shape + + @property + def shape(self): + """Base shape of the variables in the container.""" + return self._shape + + @shape.setter + def shape(self, shape): + if self.shape and self.shape != shape: + raise ValueError(f"Container shape already set to {self.shape}") + self._shape = shape + + +class TensorParamTemplate(TensorContainer, ParamTemplate): + def __init__(self, parvalues, shape=None, do_not_broadcast=None): + self._shape = () + self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast + ParamTemplate.__init__(self, parvalues=parvalues) + self._broadcast(shape) + + +class TensorStatesTemplate(TensorContainer, StatesTemplate): + def __init__(self, kiosk=None, publish=None, shape=None, do_not_broadcast=None, **kwargs): + self._shape = () + self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast + StatesTemplate.__init__(self, kiosk=kiosk, publish=publish, **kwargs) + self._broadcast(shape) + + +class TensorRatesTemplate(TensorContainer, RatesTemplate): + def __init__(self, kiosk=None, publish=None, shape=None, do_not_broadcast=None): + self._shape = () + self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast + RatesTemplate.__init__(self, kiosk=kiosk, publish=publish) + self._broadcast(shape) diff --git a/src/diffwofost/physical_models/crop/assimilation.py b/src/diffwofost/physical_models/crop/assimilation.py index 6b093fc..7faa4bd 100644 --- a/src/diffwofost/physical_models/crop/assimilation.py +++ b/src/diffwofost/physical_models/crop/assimilation.py @@ -3,21 +3,20 @@ import datetime from collections import deque import torch -from pcse.base import ParamTemplate -from pcse.base import RatesTemplate from pcse.base import SimulationObject from pcse.base.parameter_providers import ParameterProvider from pcse.base.variablekiosk import VariableKiosk from pcse.base.weather import WeatherDataContainer from pcse.decorators import prepare_rates from pcse.decorators import prepare_states -from pcse.traitlets import Any from pcse.util import astro +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_drv -from diffwofost.physical_models.utils import _get_params_shape def _as_python_float(x) -> float: @@ -42,6 +41,8 @@ def totass7( COSLD: torch.Tensor, *, epsilon: torch.Tensor, + dtype: torch.Size | tuple, + device: str, ) -> torch.Tensor: """Calculates daily total gross CO2 assimilation. @@ -69,9 +70,9 @@ def totass7( COSLD R4 Amplitude of sine of solar height - I DTGA R4 Daily total gross assimilation kg CO2/ha/d O """ - xgauss = torch.tensor([0.1127017, 0.5000000, 0.8872983], dtype=DAYL.dtype, device=DAYL.device) - wgauss = torch.tensor([0.2777778, 0.4444444, 0.2777778], dtype=DAYL.dtype, device=DAYL.device) - pi = torch.tensor(torch.pi, dtype=DAYL.dtype, device=DAYL.device) + xgauss = torch.tensor([0.1127017, 0.5000000, 0.8872983], dtype=dtype, device=device) + wgauss = torch.tensor([0.2777778, 0.4444444, 0.2777778], dtype=dtype, device=device) + pi = torch.tensor(torch.pi, dtype=dtype, device=device) # Only compute where it can be non-zero. mask = (AMAX > 0) & (LAI > 0) & (DAYL > 0) @@ -229,8 +230,6 @@ class WOFOST72_Assimilation(SimulationObject): | PGASS | AMAXTB, EFFTB, KDIFTB, TMPFTB, TMNFTB | """ # noqa: E501 - params_shape = None - @property def device(self): """Get device from ComputeConfig.""" @@ -241,33 +240,27 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): + class Parameters(TensorParamTemplate): AMAXTB = AfgenTrait() EFFTB = AfgenTrait() KDIFTB = AfgenTrait() TMPFTB = AfgenTrait() TMNFTB = AfgenTrait() - def __init__(self, parvalues): - super().__init__(parvalues) - - class RateVariables(RatesTemplate): - PGASS = Any() - - def __init__(self, kiosk, publish=None): - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - self.PGASS = torch.tensor(0.0, dtype=dtype, device=device) - super().__init__(kiosk, publish=publish) + class RateVariables(TensorRatesTemplate): + PGASS = Tensor(0.0) def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | torch.Size | None = None, ) -> None: """Initialize the assimilation module.""" self.kiosk = kiosk - self.params = self.Parameters(parvalues) - self.params_shape = _get_params_shape(self.params) - self.rates = self.RateVariables(kiosk, publish=["PGASS"]) + self.params = self.Parameters(parvalues, shape=shape) + self.rates = self.RateVariables(kiosk, publish=["PGASS"], shape=shape) # 7-day running average buffer for TMIN (stored as tensors). self._tmn_window = deque(maxlen=7) @@ -285,16 +278,16 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None _exist_required_external_variables(k) # External states - dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) - lai = _broadcast_to(k["LAI"], self.params_shape, dtype=self.dtype, device=self.device) + dvs = _broadcast_to(k["DVS"], self.params.shape, dtype=self.dtype, device=self.device) + lai = _broadcast_to(k["LAI"], self.params.shape, dtype=self.dtype, device=self.device) # Weather drivers - irrad = _get_drv(drv.IRRAD, self.params_shape, dtype=self.dtype, device=self.device) - dtemp = _get_drv(drv.DTEMP, self.params_shape, dtype=self.dtype, device=self.device) - tmin = _get_drv(drv.TMIN, self.params_shape, dtype=self.dtype, device=self.device) + irrad = _get_drv(drv.IRRAD, self.params.shape, dtype=self.dtype, device=self.device) + dtemp = _get_drv(drv.DTEMP, self.params.shape, dtype=self.dtype, device=self.device) + tmin = _get_drv(drv.TMIN, self.params.shape, dtype=self.dtype, device=self.device) # Assimilation is zero before crop emergence (DVS < 0) - dvs_mask = (dvs >= 0).to(dtype=self.dtype) + dvs_mask = dvs >= 0 # 7-day running average of TMIN self._tmn_window.appendleft(tmin * dvs_mask) self._tmn_window_mask.appendleft(dvs_mask) @@ -307,11 +300,11 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None irrad_for_astro = _as_python_float(drv.IRRAD) dayl, _daylp, sinld, cosld, difpp, _atmtr, dsinbe, _angot = astro(day, lat, irrad_for_astro) - dayl_t = _broadcast_to(dayl, self.params_shape, dtype=self.dtype, device=self.device) - sinld_t = _broadcast_to(sinld, self.params_shape, dtype=self.dtype, device=self.device) - cosld_t = _broadcast_to(cosld, self.params_shape, dtype=self.dtype, device=self.device) - difpp_t = _broadcast_to(difpp, self.params_shape, dtype=self.dtype, device=self.device) - dsinbe_t = _broadcast_to(dsinbe, self.params_shape, dtype=self.dtype, device=self.device) + dayl_t = _broadcast_to(dayl, self.params.shape, dtype=self.dtype, device=self.device) + sinld_t = _broadcast_to(sinld, self.params.shape, dtype=self.dtype, device=self.device) + cosld_t = _broadcast_to(cosld, self.params.shape, dtype=self.dtype, device=self.device) + difpp_t = _broadcast_to(difpp, self.params.shape, dtype=self.dtype, device=self.device) + dsinbe_t = _broadcast_to(dsinbe, self.params.shape, dtype=self.dtype, device=self.device) # Parameter tables amax = p.AMAXTB(dvs) @@ -331,6 +324,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None sinld_t, cosld_t, epsilon=self._epsilon, + dtype=self.dtype, + device=self.device, ) # Correction for low minimum temperature potential diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index a40b567..8e64750 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -2,21 +2,19 @@ import datetime import torch -from pcse.base import ParamTemplate -from pcse.base import RatesTemplate from pcse.base import SimulationObject -from pcse.base import StatesTemplate from pcse.base.parameter_providers import ParameterProvider from pcse.base.variablekiosk import VariableKiosk from pcse.base.weather import WeatherDataContainer from pcse.decorators import prepare_rates from pcse.decorators import prepare_states -from pcse.traitlets import Any +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate +from diffwofost.physical_models.base import TensorStatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor from diffwofost.physical_models.utils import AfgenTrait -from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_drv -from diffwofost.physical_models.utils import _get_params_shape class WOFOST_Leaf_Dynamics(SimulationObject): @@ -119,7 +117,6 @@ class WOFOST_Leaf_Dynamics(SimulationObject): # on the leaf classes during the time integration: leaf area, age, and biomass. START_DATE = None # Start date of the simulation MAX_DAYS = 365 # Maximum number of days that can be simulated in one run (i.e. array lenghts) - params_shape = None # Shape of the parameters tensors @property def device(self): @@ -131,111 +128,47 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): - RGRLAI = Any() - SPAN = Any() - TBASE = Any() - PERDL = Any() - TDWI = Any() + class Parameters(TensorParamTemplate): + RGRLAI = Tensor(-99.0) + SPAN = Tensor(-99.0) + TBASE = Tensor(-99.0) + PERDL = Tensor(-99.0) + TDWI = Tensor(-99.0) SLATB = AfgenTrait() KDIFTB = AfgenTrait() - def __init__(self, parvalues): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - self.RGRLAI = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.SPAN = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.TBASE = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.PERDL = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.TDWI = [torch.tensor(-99.0, dtype=dtype, device=device)] - - # Call parent init - super().__init__(parvalues) - - class StateVariables(StatesTemplate): - LV = Any() - SLA = Any() - LVAGE = Any() - LAIEM = Any() - LASUM = Any() - LAIEXP = Any() - LAIMAX = Any() - LAI = Any() - WLV = Any() - DWLV = Any() - TWLV = Any() - - def __init__(self, kiosk, publish=None, **kwargs): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - if "LV" not in kwargs: - self.LV = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "SLA" not in kwargs: - self.SLA = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "LVAGE" not in kwargs: - self.LVAGE = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "LAIEM" not in kwargs: - self.LAIEM = torch.tensor(-99.0, dtype=dtype, device=device) - if "LASUM" not in kwargs: - self.LASUM = torch.tensor(-99.0, dtype=dtype, device=device) - if "LAIEXP" not in kwargs: - self.LAIEXP = torch.tensor(-99.0, dtype=dtype, device=device) - if "LAIMAX" not in kwargs: - self.LAIMAX = torch.tensor(-99.0, dtype=dtype, device=device) - if "LAI" not in kwargs: - self.LAI = torch.tensor(-99.0, dtype=dtype, device=device) - if "WLV" not in kwargs: - self.WLV = torch.tensor(-99.0, dtype=dtype, device=device) - if "DWLV" not in kwargs: - self.DWLV = torch.tensor(-99.0, dtype=dtype, device=device) - if "TWLV" not in kwargs: - self.TWLV = torch.tensor(-99.0, dtype=dtype, device=device) - - # Call parent init - super().__init__(kiosk, publish=publish, **kwargs) - - class RateVariables(RatesTemplate): - GRLV = Any() - DSLV1 = Any() - DSLV2 = Any() - DSLV3 = Any() - DSLV = Any() - DALV = Any() - DRLV = Any() - SLAT = Any() - FYSAGE = Any() - GLAIEX = Any() - GLASOL = Any() - - def __init__(self, kiosk): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - self.GRLV = torch.tensor(0.0, dtype=dtype, device=device) - self.DSLV1 = torch.tensor(0.0, dtype=dtype, device=device) - self.DSLV2 = torch.tensor(0.0, dtype=dtype, device=device) - self.DSLV3 = torch.tensor(0.0, dtype=dtype, device=device) - self.DSLV = torch.tensor(0.0, dtype=dtype, device=device) - self.DALV = torch.tensor(0.0, dtype=dtype, device=device) - self.DRLV = torch.tensor(0.0, dtype=dtype, device=device) - self.SLAT = torch.tensor(0.0, dtype=dtype, device=device) - self.FYSAGE = torch.tensor(0.0, dtype=dtype, device=device) - self.GLAIEX = torch.tensor(0.0, dtype=dtype, device=device) - self.GLASOL = torch.tensor(0.0, dtype=dtype, device=device) - - # Call parent init - super().__init__(kiosk) + class StateVariables(TensorStatesTemplate): + LV = Tensor(-99.0) + SLA = Tensor(-99.0) + LVAGE = Tensor(-99.0) + LAIEM = Tensor(-99.0) + LASUM = Tensor(-99.0) + LAIEXP = Tensor(-99.0) + LAIMAX = Tensor(-99.0) + LAI = Tensor(-99.0) + WLV = Tensor(-99.0) + DWLV = Tensor(-99.0) + TWLV = Tensor(-99.0) + + class RateVariables(TensorRatesTemplate): + GRLV = Tensor(0.0) + DSLV1 = Tensor(0.0) + DSLV2 = Tensor(0.0) + DSLV3 = Tensor(0.0) + DSLV = Tensor(0.0) + DALV = Tensor(0.0) + DRLV = Tensor(0.0) + SLAT = Tensor(0.0) + FYSAGE = Tensor(0.0) + GLAIEX = Tensor(0.0) + GLASOL = Tensor(0.0) def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | torch.Size | None = None, ) -> None: """Initialize the WOFOST_Leaf_Dynamics simulation object. @@ -247,12 +180,13 @@ def initialize( parvalues (ParameterProvider): A dictionary-like container holding all parameter sets (crop, soil, site) as key/value. The values are arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. """ self.START_DATE = day self.kiosk = kiosk # TODO check if parvalues are already torch.nn.Parameters - self.params = self.Parameters(parvalues) - self.rates = self.RateVariables(kiosk) + self.params = self.Parameters(parvalues, shape=shape) + self.rates = self.RateVariables(kiosk, shape=shape) # Create scalar constants once at the beginning to avoid recreating them self._zero = torch.tensor(0.0, dtype=self.dtype, device=self.device) @@ -263,44 +197,33 @@ def initialize( # CALCULATE INITIAL STATE VARIABLES # check for required external variables _exist_required_external_variables(self.kiosk) - # TODO check if external variables are already torch tensors - - # Get kiosk values and ensure they are on the correct device - FL = torch.as_tensor(self.kiosk["FL"], dtype=self.dtype, device=self.device) - FR = torch.as_tensor(self.kiosk["FR"], dtype=self.dtype, device=self.device) - DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) params = self.params - self.params_shape = _get_params_shape(params) # Initial leaf biomass - TDWI = _broadcast_to(params.TDWI, self.params_shape, dtype=self.dtype, device=self.device) - WLV = (TDWI * (1 - FR)) * FL - DWLV = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + WLV = (params.TDWI * (1 - self.kiosk["FR"])) * self.kiosk["FL"] + DWLV = 0.0 TWLV = WLV + DWLV # Initialize leaf classes (SLA, age and weight) - SLA = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device) - LVAGE = torch.zeros( - (*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device - ) - LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device) - SLA[..., 0] = params.SLATB(DVS).to(dtype=self.dtype, device=self.device) - LV[..., 0] = WLV + SLA = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device) + LVAGE = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device) + LV = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device) + SLA[0, ...] = params.SLATB(self.kiosk["DVS"]) + LV[0, ...] = WLV # Initial values for leaf area - LAIEM = LV[..., 0] * SLA[..., 0] + LAIEM = LV[0, ...] * SLA[0, ...] LASUM = LAIEM LAIEXP = LAIEM LAIMAX = LAIEM - SAI = torch.as_tensor(self.kiosk["SAI"], dtype=self.dtype, device=self.device) - PAI = torch.as_tensor(self.kiosk["PAI"], dtype=self.dtype, device=self.device) - LAI = LASUM + SAI + PAI + LAI = LASUM + self.kiosk["SAI"] + self.kiosk["PAI"] # Initialize StateVariables object self.states = self.StateVariables( kiosk, publish=["LAI", "TWLV", "WLV"], + do_not_broadcast=["SLA", "LVAGE", "LV"], LV=LV, SLA=SLA, LVAGE=LVAGE, @@ -312,6 +235,7 @@ def initialize( WLV=WLV, DWLV=DWLV, TWLV=TWLV, + shape=shape, ) def _calc_LAI(self): @@ -338,24 +262,17 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask # A mask (0 if DVS < 0, 1 if DVS >= 0) - DVS = torch.as_tensor(k["DVS"], dtype=self.dtype, device=self.device) - dvs_mask = (DVS >= 0).to(dtype=self.dtype).to(device=self.device) + dvs_mask = k["DVS"] >= self._zero # Growth rate leaves # weight of new leaves r.GRLV = dvs_mask * k.ADMI * k.FL # death of leaves due to water/oxygen stress - RFTRA = _broadcast_to(k.RFTRA, self.params_shape, dtype=self.dtype, device=self.device) - PERDL = _broadcast_to(p.PERDL, self.params_shape, dtype=self.dtype, device=self.device) - r.DSLV1 = dvs_mask * s.WLV * (1.0 - RFTRA) * PERDL + r.DSLV1 = dvs_mask * s.WLV * (1.0 - k.RFTRA) * p.PERDL # death due to self shading cause by high LAI - DVS = _broadcast_to( - self.kiosk["DVS"], self.params_shape, dtype=self.dtype, device=self.device - ) - KDIFTB = p.KDIFTB.to(device=self.device, dtype=self.dtype) - LAICR = 3.2 / KDIFTB(DVS) + LAICR = 3.2 / p.KDIFTB(k["DVS"]) r.DSLV2 = dvs_mask * s.WLV * torch.clamp(0.03 * (s.LAI - LAICR) / LAICR, 0.0, 0.03) # Death of leaves due to frost damage as determined by @@ -363,7 +280,7 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: if "RF_FROST" in self.kiosk: r.DSLV3 = s.WLV * k.RF_FROST else: - r.DSLV3 = torch.zeros_like(s.WLV, dtype=self.dtype) + r.DSLV3 = torch.zeros_like(s.WLV) r.DSLV3 = dvs_mask * r.DSLV3 @@ -376,9 +293,6 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # in DALV. # Note that the actual leaf death is imposed on the array LV during the # state integration step. - tSPAN = _broadcast_to( - p.SPAN, s.LVAGE.shape, dtype=self.dtype, device=self.device - ) # Broadcast to same shape # Using a sigmoid here instead of a conditional statement on the value of # SPAN because the latter would not allow for the gradient to be tracked. @@ -389,11 +303,11 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: if p.SPAN.requires_grad: # soft mask using sigmoid soft_mask = torch.sigmoid( - (s.LVAGE - tSPAN - self._sigmoid_epsilon) / self._sigmoid_sharpness - ).to(dtype=self.dtype) + (s.LVAGE - p.SPAN - self._sigmoid_epsilon) / self._sigmoid_sharpness + ) # originial hard mask - hard_mask = (s.LVAGE > tSPAN).to(dtype=self.dtype) + hard_mask = s.LVAGE > p.SPAN # STE method. Here detach is used to stop the gradient flow. This # way, during backpropagation, the gradient is computed only through @@ -401,44 +315,42 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # is used. span_mask = hard_mask.detach() + soft_mask - soft_mask.detach() else: - span_mask = (s.LVAGE > tSPAN).to(dtype=self.dtype) + span_mask = s.LVAGE > p.SPAN - r.DALV = torch.sum(span_mask * s.LV, dim=-1) + r.DALV = torch.sum(span_mask * s.LV, dim=0) r.DALV = dvs_mask * r.DALV # Total death rate leaves r.DRLV = torch.maximum(r.DSLV, r.DALV) # Get the temperature from the drv - TEMP = _get_drv(drv.TEMP, self.params_shape, self.dtype, self.device) + TEMP = _get_drv(drv.TEMP, p.shape, self.dtype, self.device) # physiologic ageing of leaves per time step - TBASE = _broadcast_to(p.TBASE, self.params_shape, dtype=self.dtype, device=self.device) - FYSAGE = (TEMP - TBASE) / (35.0 - TBASE) + FYSAGE = (TEMP - p.TBASE) / (35.0 - p.TBASE) r.FYSAGE = dvs_mask * torch.clamp(FYSAGE, 0.0) # specific leaf area of leaves per time step - SLATB = p.SLATB.to(device=self.device, dtype=self.dtype) - r.SLAT = dvs_mask * SLATB(DVS) + r.SLAT = dvs_mask * p.SLATB(k["DVS"]) # leaf area not to exceed exponential growth curve is_lai_exp = s.LAIEXP < 6.0 - DTEFF = torch.clamp(TEMP - TBASE, 0.0) + DTEFF = torch.clamp(TEMP - p.TBASE, 0.0) # NOTE: conditional statements do not allow for the gradient to be # tracked through the condition. Thus, the gradient with respect to # parameters that contribute to `is_lai_exp` (e.g. RGRLAI and TBASE) # are expected to be incorrect. - RGRLAI = _broadcast_to(p.RGRLAI, self.params_shape, dtype=self.dtype, device=self.device) + r.GLAIEX = torch.where( - dvs_mask.bool(), - torch.where(is_lai_exp, s.LAIEXP * RGRLAI * DTEFF, r.GLAIEX), + dvs_mask, + torch.where(is_lai_exp, s.LAIEXP * p.RGRLAI * DTEFF, r.GLAIEX), self._zero, ) # source-limited increase in leaf area r.GLASOL = torch.where( - dvs_mask.bool(), + dvs_mask, torch.where(is_lai_exp, r.GRLV * r.SLAT, r.GLASOL), self._zero, ) @@ -448,7 +360,7 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # adjustment of specific leaf area of youngest leaf class r.SLAT = torch.where( - dvs_mask.bool(), + dvs_mask, torch.where( is_lai_exp & (r.GRLV > self._epsilon), GLA / (r.GRLV + self._epsilon), r.SLAT ), @@ -471,22 +383,19 @@ def integrate(self, day: datetime.date, delt=1.0) -> None: tLV = states.LV.clone() tSLA = states.SLA.clone() tLVAGE = states.LVAGE.clone() - tDRLV = _broadcast_to(rates.DRLV, tLV.shape, dtype=self.dtype, device=self.device) # Leaf death is imposed on leaves from the oldest ones. # Calculate the cumulative sum of weights after leaf death, and # find out which leaf classes are dead (negative weights) - weight_cumsum = tLV.cumsum(dim=-1) - tDRLV + weight_cumsum = tLV.cumsum(dim=0) - rates.DRLV is_alive = weight_cumsum >= 0 # Adjust value of oldest leaf class, i.e. the first non-zero # weight along the time axis (the last dimension). # Cast argument to int because torch.argmax requires it to be numeric - idx_oldest = torch.argmax(is_alive.type(torch.int), dim=-1, keepdim=True).to( - device=self.device - ) - new_biomass = torch.take_along_dim(weight_cumsum, indices=idx_oldest, dim=-1) - tLV = torch.scatter(tLV, dim=-1, index=idx_oldest, src=new_biomass) + idx_oldest = torch.argmax(is_alive.type(torch.int), dim=0, keepdim=True) + new_biomass = torch.take_along_dim(weight_cumsum, indices=idx_oldest, dim=0) + tLV = torch.scatter(tLV, dim=0, index=idx_oldest, src=new_biomass) # Integration of physiological age # Zero out all dead leaf classes @@ -494,18 +403,18 @@ def integrate(self, day: datetime.date, delt=1.0) -> None: # tracked through the condition. Thus, the gradient with respect to # parameters that contribute to `is_alive` are expected to be incorrect. tLV = torch.where(is_alive, tLV, 0.0) - tLVAGE = tLVAGE + rates.FYSAGE.unsqueeze(-1) + tLVAGE = tLVAGE + rates.FYSAGE tLVAGE = torch.where(is_alive, tLVAGE, 0.0) tSLA = torch.where(is_alive, tSLA, 0.0) # --------- leave growth --------- idx = int((day - self.START_DATE).days / delt) - tLV[..., idx] = rates.GRLV - tSLA[..., idx] = rates.SLAT - tLVAGE[..., idx] = 0.0 + tLV[idx, ...] = rates.GRLV + tSLA[idx, ...] = rates.SLAT + tLVAGE[idx, ...] = 0.0 # calculation of new leaf area - states.LASUM = torch.sum(tLV * tSLA, dim=-1) + states.LASUM = torch.sum(tLV * tSLA, dim=0) states.LAI = self._calc_LAI() states.LAIMAX = torch.maximum(states.LAI, states.LAIMAX) @@ -513,7 +422,7 @@ def integrate(self, day: datetime.date, delt=1.0) -> None: states.LAIEXP = states.LAIEXP + rates.GLAIEX # Update leaf biomass states - states.WLV = torch.sum(tLV, dim=-1) + states.WLV = torch.sum(tLV, dim=0) states.DWLV = states.DWLV + rates.DRLV states.TWLV = states.WLV + states.DWLV diff --git a/src/diffwofost/physical_models/crop/partitioning.py b/src/diffwofost/physical_models/crop/partitioning.py index 52fbbc6..6fecd51 100644 --- a/src/diffwofost/physical_models/crop/partitioning.py +++ b/src/diffwofost/physical_models/crop/partitioning.py @@ -2,16 +2,15 @@ from warnings import warn import torch from pcse import exceptions as exc -from pcse.base import ParamTemplate from pcse.base import SimulationObject -from pcse.base import StatesTemplate from pcse.decorators import prepare_states -from pcse.traitlets import Any from pcse.traitlets import Instance +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorStatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to -from diffwofost.physical_models.utils import _get_params_shape # Template for namedtuple containing partitioning factors @@ -35,8 +34,6 @@ class _BaseDVSPartitioning(SimulationObject): the public partitioning classes. """ - params_shape = None # Shape of the parameters tensors - @property def device(self): """Get device from ComputeConfig.""" @@ -47,37 +44,19 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): + class Parameters(TensorParamTemplate): FRTB = AfgenTrait() FLTB = AfgenTrait() FSTB = AfgenTrait() FOTB = AfgenTrait() - def __init__(self, parvalues): - super().__init__(parvalues) - - class StateVariables(StatesTemplate): - FR = Any() - FL = Any() - FS = Any() - FO = Any() + class StateVariables(TensorStatesTemplate): + FR = Tensor(-99.0) + FL = Tensor(-99.0) + FS = Tensor(-99.0) + FO = Tensor(-99.0) PF = Instance(PartioningFactors) - def __init__(self, kiosk, publish=None, **kwargs): - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - if "FR" not in kwargs: - kwargs["FR"] = torch.tensor(-99.0, dtype=dtype, device=device) - if "FL" not in kwargs: - kwargs["FL"] = torch.tensor(-99.0, dtype=dtype, device=device) - if "FS" not in kwargs: - kwargs["FS"] = torch.tensor(-99.0, dtype=dtype, device=device) - if "FO" not in kwargs: - kwargs["FO"] = torch.tensor(-99.0, dtype=dtype, device=device) - - super().__init__(kiosk, publish=publish, **kwargs) - def _handle_partitioning_error(self, msg: str) -> None: """Hook for error handling (warn vs raise).""" warn(msg) @@ -104,13 +83,6 @@ def _check_partitioning(self): self.logger.error(msg) self._handle_partitioning_error(msg) - def _broadcast_partitioning(self, FR, FL, FS, FO): - FR = _broadcast_to(FR, self.params_shape, dtype=self.dtype, device=self.device) - FL = _broadcast_to(FL, self.params_shape, dtype=self.dtype, device=self.device) - FS = _broadcast_to(FS, self.params_shape, dtype=self.dtype, device=self.device) - FO = _broadcast_to(FO, self.params_shape, dtype=self.dtype, device=self.device) - return FR, FL, FS, FO - def _set_partitioning_states(self, FR, FL, FS, FO): self.states.FR = FR self.states.FL = FL @@ -126,15 +98,11 @@ def _compute_partitioning_from_tables(self, DVS): FO = p.FOTB(DVS) return FR, FL, FS, FO - def _initialize_from_tables(self, kiosk, parvalues): - self.params = self.Parameters(parvalues) + def _initialize_from_tables(self, kiosk, parvalues, shape=None): + self.params = self.Parameters(parvalues, shape=shape) self.kiosk = kiosk - self.params_shape = _get_params_shape(self.params) - - DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) + DVS = _broadcast_to(self.kiosk["DVS"], self.params.shape) FR, FL, FS, FO = self._compute_partitioning_from_tables(DVS) - FR, FL, FS, FO = self._broadcast_partitioning(FR, FL, FS, FO) - self.states = self.StateVariables( kiosk, publish=["FR", "FL", "FS", "FO"], @@ -143,13 +111,13 @@ def _initialize_from_tables(self, kiosk, parvalues): FS=FS, FO=FO, PF=PartioningFactors(FR, FL, FS, FO), + shape=shape, ) self._check_partitioning() def _update_from_tables(self): - DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) + DVS = _broadcast_to(self.kiosk["DVS"], self.params.shape) FR, FL, FS, FO = self._compute_partitioning_from_tables(DVS) - FR, FL, FS, FO = self._broadcast_partitioning(FR, FL, FS, FO) self._set_partitioning_states(FR, FL, FS, FO) self._check_partitioning() @@ -220,7 +188,7 @@ class DVS_Partitioning(_BaseDVSPartitioning): stems and storage organs on a given day do not add up to 1. """ - def initialize(self, day, kiosk, parvalues): + def initialize(self, day, kiosk, parvalues, shape=None): """Initialize the DVS_Partitioning simulation object. Args: @@ -228,8 +196,9 @@ def initialize(self, day, kiosk, parvalues): kiosk (VariableKiosk): Variable kiosk of this PCSE instance. parvalues (ParameterProvider): Object providing parameters as key/value pairs. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. """ - self._initialize_from_tables(kiosk, parvalues) + self._initialize_from_tables(kiosk, parvalues, shape=shape) @prepare_states def integrate(self, day, delt=1.0): @@ -317,7 +286,7 @@ class DVS_Partitioning_N(_BaseDVSPartitioning): def _handle_partitioning_error(self, msg: str) -> None: raise exc.PartitioningError(msg) - def initialize(self, day, kiosk, parameters): + def initialize(self, day, kiosk, parameters, shape=None): """Initialize the DVS_Partitioning_N simulation object. Args: @@ -325,8 +294,9 @@ def initialize(self, day, kiosk, parameters): kiosk (VariableKiosk): Variable kiosk of this PCSE instance. parameters (ParameterProvider): Dictionary with WOFOST cropdata key/value pairs. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. """ - self._initialize_from_tables(kiosk, parameters) + self._initialize_from_tables(kiosk, parameters, shape=shape) def _calculate_stressed_fr(self, DVS: torch.Tensor, RFTRA: torch.Tensor) -> torch.Tensor: """Computes the FR partitioning fraction under water/oxygen stress.""" @@ -336,15 +306,12 @@ def _calculate_stressed_fr(self, DVS: torch.Tensor, RFTRA: torch.Tensor) -> torc @prepare_states def integrate(self, day, delt=1.0): """Update partitioning factors based on DVS and water/oxygen stress.""" - DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) - RFTRA = torch.as_tensor(self.kiosk["RFTRA"], dtype=self.dtype, device=self.device) - + DVS = _broadcast_to(self.kiosk["DVS"], self.params.shape) + RFTRA = _broadcast_to(self.kiosk["RFTRA"], self.params.shape) FR = self._calculate_stressed_fr(DVS, RFTRA) FL = self.params.FLTB(DVS) FS = self.params.FSTB(DVS) FO = self.params.FOTB(DVS) - - FR, FL, FS, FO = self._broadcast_partitioning(FR, FL, FS, FO) self._set_partitioning_states(FR, FL, FS, FO) self._check_partitioning() diff --git a/src/diffwofost/physical_models/crop/phenology.py b/src/diffwofost/physical_models/crop/phenology.py index bc30b64..76d2fe8 100644 --- a/src/diffwofost/physical_models/crop/phenology.py +++ b/src/diffwofost/physical_models/crop/phenology.py @@ -10,21 +10,20 @@ import torch from pcse import exceptions as exc from pcse import signals -from pcse.base import ParamTemplate -from pcse.base import RatesTemplate from pcse.base import SimulationObject -from pcse.base import StatesTemplate from pcse.decorators import prepare_rates from pcse.decorators import prepare_states -from pcse.traitlets import Any from pcse.traitlets import Enum from pcse.traitlets import Instance from pcse.util import daylength +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate +from diffwofost.physical_models.base import TensorStatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_drv -from diffwofost.physical_models.utils import _get_params_shape from diffwofost.physical_models.utils import _restore_state from diffwofost.physical_models.utils import _snapshot_state @@ -90,8 +89,6 @@ class Vernalisation(SimulationObject): | | for vernalisation reached) | | | """ - params_shape = None # Shape of the parameters tensors - @property def device(self): """Get device from ComputeConfig.""" @@ -102,64 +99,22 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): - VERNSAT = Any() - VERNBASE = Any() + class Parameters(TensorParamTemplate): + VERNSAT = Tensor(-99.0) + VERNBASE = Tensor(-99.0) VERNRTB = AfgenTrait() - VERNDVS = Any() - - def __init__(self, parvalues): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values using the ComputeConfig dtype and device - self.VERNSAT = torch.tensor(-99.0, dtype=dtype, device=device) - self.VERNBASE = torch.tensor(-99.0, dtype=dtype, device=device) - self.VERNDVS = torch.tensor(-99.0, dtype=dtype, device=device) - self.VERNRTB = self.VERNRTB.to(device=device, dtype=dtype) - - # Call parent init - super().__init__(parvalues) - - class RateVariables(RatesTemplate): - VERNR = Any() - VERNFAC = Any() - - def __init__(self, kiosk, publish=None): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values using the ComputeConfig dtype and device - self.VERNR = torch.tensor(0.0, dtype=dtype, device=device) - self.VERNFAC = torch.tensor(0.0, dtype=dtype, device=device) - - # Call parent init - super().__init__(kiosk, publish=publish) - - class StateVariables(StatesTemplate): - VERN = Any() - DOV = Any() - ISVERNALISED = Any() - - def __init__(self, kiosk, publish=None, **kwargs): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values using the ComputeConfig dtype and device if not in kwargs - if "VERN" not in kwargs: - self.VERN = torch.tensor(-99.0, dtype=dtype, device=device) - if "DOV" not in kwargs: - self.DOV = torch.tensor(-99.0, dtype=dtype, device=device) - if "ISVERNALISED" not in kwargs: - self.ISVERNALISED = torch.tensor(False, dtype=torch.bool, device=device) - - # Call parent init - super().__init__(kiosk, publish=publish, **kwargs) - - def initialize(self, day, kiosk, parvalues, dvs_shape=None): + VERNDVS = Tensor(-99.0) + + class RateVariables(TensorRatesTemplate): + VERNR = Tensor(0.0) + VERNFAC = Tensor(0.0) + + class StateVariables(TensorStatesTemplate): + VERN = Tensor(-99.0) + DOV = Tensor(-99.0) + ISVERNALISED = Tensor(False, dtype=bool) + + def initialize(self, day, kiosk, parvalues, shape=None): """Initialize the Vernalisation sub-module. Args: @@ -167,7 +122,7 @@ def initialize(self, day, kiosk, parvalues, dvs_shape=None): kiosk: Shared PCSE kiosk for inter-module variable exchange. parvalues: ParameterProvider/dict containing VERNSAT, VERNBASE, VERNRTB and VERNDVS. - dvs_shape (torch.Size, optional): Shape of the DVS_phenology parameters + shape (tuple | torch.Size | None): Target shape for the state and rate variables. Side Effects: - Instantiates params, rates and states containers. @@ -179,58 +134,32 @@ def initialize(self, day, kiosk, parvalues, dvs_shape=None): ISVERNALISED = False. """ - self.params = self.Parameters(parvalues) - self.params_shape = _get_params_shape(self.params) + self.params = self.Parameters(parvalues, shape=shape) # Small epsilon tensor reused in multiple safe divisions. self._epsilon = torch.tensor(1e-8, dtype=self.dtype, device=self.device) - if dvs_shape is not None: - if self.params_shape == (): - self.params_shape = dvs_shape - elif self.params_shape != dvs_shape: - raise ValueError( - f"Vernalisation params shape {self.params_shape}" - + " incompatible with dvs_shape {dvs_shape}" - ) # Common constant tensors (same shape/dtype/device as this module). - self._ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) - self._zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + self._ones = torch.ones(self.params.shape, dtype=self.dtype, device=self.device) + self._zeros = torch.zeros(self.params.shape, dtype=self.dtype, device=self.device) + # Explicitly initialize rates - self.rates = self.RateVariables(kiosk, publish=["VERNFAC"]) - self.rates.VERNR = _broadcast_to( - self.rates.VERNR, self.params_shape, dtype=self.dtype, device=self.device - ) - self.rates.VERNFAC = _broadcast_to( - self.rates.VERNFAC, self.params_shape, dtype=self.dtype, device=self.device - ) - self.kiosk = kiosk + self.rates = self.RateVariables(kiosk, publish=["VERNFAC"], shape=shape) - # Explicitly broadcast all parameters to params_shape - self.params.VERNSAT = _broadcast_to( - self.params.VERNSAT, self.params_shape, dtype=self.dtype, device=self.device - ) - self.params.VERNBASE = _broadcast_to( - self.params.VERNBASE, self.params_shape, dtype=self.dtype, device=self.device - ) - self.params.VERNDVS = _broadcast_to( - self.params.VERNDVS, self.params_shape, dtype=self.dtype, device=self.device - ) - self.params.VERNRTB = self.params.VERNRTB.to(device=self.device, dtype=self.dtype) + self.kiosk = kiosk # Define initial states self.states = self.StateVariables( kiosk, - VERN=torch.zeros(self.params_shape, dtype=self.dtype, device=self.device), - DOV=torch.full( - self.params_shape, -1.0, dtype=self.dtype, device=self.device - ), # -1 indicates not yet fulfilled - ISVERNALISED=torch.zeros(self.params_shape, dtype=torch.bool, device=self.device), + VERN=0.0, + DOV=-1.0, # -1 indicates not yet fulfilled + ISVERNALISED=False, publish=["ISVERNALISED"], + shape=shape, ) # Per-element force flag (False for all elements initially) self._force_vernalisation = torch.zeros( - self.params_shape, dtype=torch.bool, device=self.device + self.params.shape, dtype=torch.bool, device=self.device ) @prepare_rates @@ -253,7 +182,7 @@ def calc_rates(self, day, drv): VERNBASE = params.VERNBASE DVS = self.kiosk["DVS"] - TEMP = _get_drv(drv.TEMP, self.params_shape, self.dtype, self.device) + TEMP = _get_drv(drv.TEMP, self.params.shape, self.dtype, self.device) # Operate elementwise only on elements not yet vernalised not_vernalised = ~self.states.ISVERNALISED @@ -320,7 +249,7 @@ def integrate(self, day, delt=1.0): states.DOV = torch.where( newly_reached_and_no_dov, torch.full( - self.params_shape, day.toordinal(), dtype=self.dtype, device=self.device + self.params.shape, day.toordinal(), dtype=self.dtype, device=self.device ), states.DOV, ) @@ -422,8 +351,6 @@ class DVS_Phenology(SimulationObject): # Placeholder for start/stop types and vernalisation module vernalisation = Instance(Vernalisation) - params_shape = None # Shape of the parameters tensors - @property def device(self): """Get device from ComputeConfig.""" @@ -434,202 +361,79 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): - TSUMEM = Any() - TBASEM = Any() - TEFFMX = Any() - TSUM1 = Any() - TSUM2 = Any() - IDSL = Any() - DLO = Any() - DLC = Any() - DVSI = Any() - DVSEND = Any() + class Parameters(TensorParamTemplate): + TSUMEM = Tensor(-99.0) + TBASEM = Tensor(-99.0) + TEFFMX = Tensor(-99.0) + TSUM1 = Tensor(-99.0) + TSUM2 = Tensor(-99.0) + IDSL = Tensor(-99.0) + DLO = Tensor(-99.0) + DLC = Tensor(-99.0) + DVSI = Tensor(-99.0) + DVSEND = Tensor(-99.0) DTSMTB = AfgenTrait() CROP_START_TYPE = Enum(["sowing", "emergence"]) CROP_END_TYPE = Enum(["maturity", "harvest", "earliest"]) - def __init__(self, parvalues): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values using the ComputeConfig dtype and device - self.TSUMEM = torch.tensor(-99.0, dtype=dtype, device=device) - self.TBASEM = torch.tensor(-99.0, dtype=dtype, device=device) - self.TEFFMX = torch.tensor(-99.0, dtype=dtype, device=device) - self.TSUM1 = torch.tensor(-99.0, dtype=dtype, device=device) - self.TSUM2 = torch.tensor(-99.0, dtype=dtype, device=device) - self.IDSL = torch.tensor(-99.0, dtype=dtype, device=device) - self.DLO = torch.tensor(-99.0, dtype=dtype, device=device) - self.DLC = torch.tensor(-99.0, dtype=dtype, device=device) - self.DVSI = torch.tensor(-99.0, dtype=dtype, device=device) - self.DVSEND = torch.tensor(-99.0, dtype=dtype, device=device) - - # Call parent init - super().__init__(parvalues) - - class RateVariables(RatesTemplate): - DTSUME = Any() - DTSUM = Any() - DVR = Any() - - def __init__(self, kiosk, publish=None): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - self.DTSUME = torch.tensor(0.0, dtype=dtype, device=device) - self.DTSUM = torch.tensor(0.0, dtype=dtype, device=device) - self.DVR = torch.tensor(0.0, dtype=dtype, device=device) - - # Call parent init - super().__init__(kiosk, publish=publish) - - class StateVariables(StatesTemplate): - DVS = Any() - TSUM = Any() - TSUME = Any() - DOS = Any() - DOE = Any() - DOA = Any() - DOM = Any() - DOH = Any() - STAGE = Any() - - def __init__(self, kiosk, publish=None, **kwargs): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - if "DVS" not in kwargs: - self.DVS = torch.tensor(-99.0, dtype=dtype, device=device) - if "TSUM" not in kwargs: - self.TSUM = torch.tensor(-99.0, dtype=dtype, device=device) - if "TSUME" not in kwargs: - self.TSUME = torch.tensor(-99.0, dtype=dtype, device=device) - if "DOS" not in kwargs: - self.DOS = torch.tensor(-99.0, dtype=dtype, device=device) - if "DOE" not in kwargs: - self.DOE = torch.tensor(-99.0, dtype=dtype, device=device) - if "DOA" not in kwargs: - self.DOA = torch.tensor(-99.0, dtype=dtype, device=device) - if "DOM" not in kwargs: - self.DOM = torch.tensor(-99.0, dtype=dtype, device=device) - if "DOH" not in kwargs: - self.DOH = torch.tensor(-99.0, dtype=dtype, device=device) - if "STAGE" not in kwargs: - self.STAGE = torch.tensor(-99, dtype=torch.long, device=device) - - # Call parent init - super().__init__(kiosk, publish=publish, **kwargs) - - def _cast_and_broadcast_params(self): - """Cast and broadcast all parameters to params_shape with correct dtype/device. - - This ensures all parameters have consistent shape, dtype, and device. - Necessary if Vernalisation changes the params_shape during initialization. - """ - p = self.params - # Broadcast numeric parameters to the final params_shape and ensure dtype/device. - for name in ( - "TSUMEM", - "TBASEM", - "TEFFMX", - "TSUM1", - "TSUM2", - "IDSL", - "DLO", - "DLC", - "DVSI", - "DVSEND", - ): - setattr( - p, - name, - _broadcast_to(getattr(p, name), self.params_shape, self.dtype, self.device), - ) - - # Move AFGEN table buffers, if present. - if hasattr(p, "DTSMTB") and hasattr(p.DTSMTB, "to"): - p.DTSMTB.to(device=self.device, dtype=self.dtype) - - def initialize(self, day, kiosk, parvalues): + class RateVariables(TensorRatesTemplate): + DTSUME = Tensor(0.0) + DTSUM = Tensor(0.0) + DVR = Tensor(0.0) + + class StateVariables(TensorStatesTemplate): + DVS = Tensor(-99.0) + TSUM = Tensor(-99.0) + TSUME = Tensor(-99.0) + DOS = Tensor(-99.0) + DOE = Tensor(-99.0) + DOA = Tensor(-99.0) + DOM = Tensor(-99.0) + DOH = Tensor(-99.0) + STAGE = Tensor(-99.0) + + def initialize(self, day, kiosk, parvalues, shape=None): """:param day: start date of the simulation :param kiosk: variable kiosk of this PCSE instance :param parvalues: `ParameterProvider` object providing parameters as key/value pairs """ - self.params = self.Parameters(parvalues) - self.params_shape = _get_params_shape(self.params) + self.params = self.Parameters(parvalues, shape=shape) # Initialize vernalisation for IDSL>=2 - # It has to be done in advance to get the correct params_shape - IDSL = _broadcast_to( - self.params.IDSL, self.params_shape, dtype=self.dtype, device=self.device - ) - self.params.IDSL = IDSL - if torch.any(IDSL >= 2): - if self.params_shape != (): - self.vernalisation = Vernalisation( - day, kiosk, parvalues, dvs_shape=self.params_shape - ) - else: - self.vernalisation = Vernalisation(day, kiosk, parvalues) - if self.vernalisation.params_shape != self.params_shape: - self.params_shape = self.vernalisation.params_shape + if torch.any(self.params.IDSL >= 2): + self.vernalisation = Vernalisation(day, kiosk, parvalues, shape=shape) else: self.vernalisation = None - # After Vernalisation initialization the final params_shape may have changed. - self._cast_and_broadcast_params() - # Create scalar constants once at the beginning to avoid recreating them - self._ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) - self._zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + self._ones = torch.ones(self.params.shape, dtype=self.dtype, device=self.device) + self._zeros = torch.zeros(self.params.shape, dtype=self.dtype, device=self.device) self._epsilon = torch.tensor(1e-8, dtype=self.dtype, device=self.device) # Initialize rates and kiosk - self.rates = self.RateVariables(kiosk) + self.rates = self.RateVariables(kiosk, shape=shape) self.kiosk = kiosk self._connect_signal(self._on_CROP_FINISH, signal=signals.crop_finish) # Define initial states DVS, DOS, DOE, STAGE = self._get_initial_stage(day) - DVS = _broadcast_to(DVS, self.params_shape, dtype=self.dtype, device=self.device) - - # Initialize all date tensors with -1 (not yet occurred) - DOS = _broadcast_to(DOS, self.params_shape, dtype=self.dtype, device=self.device) - DOE = _broadcast_to(DOE, self.params_shape, dtype=self.dtype, device=self.device) - DOA = torch.full(self.params_shape, -1.0, dtype=self.dtype, device=self.device) - DOM = torch.full(self.params_shape, -1.0, dtype=self.dtype, device=self.device) - DOH = torch.full(self.params_shape, -1.0, dtype=self.dtype, device=self.device) - STAGE = _broadcast_to(STAGE, self.params_shape, dtype=self.dtype, device=self.device) - - # Also ensure TSUM and TSUME are properly shaped - TSUM = torch.zeros( - self.params_shape, dtype=self.dtype, device=self.device, requires_grad=True - ) - TSUME = torch.zeros( - self.params_shape, dtype=self.dtype, device=self.device, requires_grad=True - ) self.states = self.StateVariables( kiosk, publish="DVS", - TSUM=TSUM, - TSUME=TSUME, + TSUM=0.0, + TSUME=0.0, DVS=DVS, DOS=DOS, DOE=DOE, - DOA=DOA, - DOM=DOM, - DOH=DOH, + DOA=-1.0, # not yet occurred + DOM=-1.0, # not yet occurred + DOH=-1.0, # not yet occurred STAGE=STAGE, + shape=shape, ) def _get_initial_stage(self, day): @@ -647,26 +451,24 @@ def _get_initial_stage(self, day): STAGE (Tensor): Integer stage code (0=emerging, 1=vegetative). """ p = self.params - day_ordinal = torch.tensor(day.toordinal(), dtype=self.dtype, device=self.device) + day_ordinal = day.toordinal() # Define initial stage type (emergence/sowing) and fill the # respective day of sowing/emergence (DOS/DOE) if p.CROP_START_TYPE == "emergence": - STAGE = torch.tensor(1, dtype=torch.long, device=self.device) # 1 = vegetative + STAGE = 1 # 1 = vegetative DOE = day_ordinal - DOS = torch.tensor(-1.0, dtype=self.dtype, device=self.device) # Not applicable + DOS = -1.0 # Not applicable DVS = p.DVSI - if not isinstance(DVS, torch.Tensor): - DVS = torch.tensor(DVS, dtype=self.dtype, device=self.device) # send signal to indicate crop emergence self._send_signal(signals.crop_emerged) elif p.CROP_START_TYPE == "sowing": - STAGE = torch.tensor(0, dtype=torch.long, device=self.device) # 0 = emerging + STAGE = 0 # 0 = emerging DOS = day_ordinal - DOE = torch.tensor(-1.0, dtype=self.dtype, device=self.device) # Not yet occurred - DVS = torch.tensor(-0.1, dtype=self.dtype, device=self.device) + DOE = -1.0 # Not yet occurred + DVS = -0.1 else: msg = f"Unknown start type: {p.CROP_START_TYPE}" @@ -701,11 +503,10 @@ def calc_rates(self, day, drv): p = self.params r = self.rates s = self.states - shape = self.params_shape # Day length sensitivity DAYLP = daylength(day, drv.LAT) - DAYLP_t = _broadcast_to(DAYLP, shape, dtype=self.dtype, device=self.device) + DAYLP_t = _broadcast_to(DAYLP, p.shape, dtype=self.dtype, device=self.device) # Compute DVRED conditionally based on IDSL >= 1 safe_den = p.DLO - p.DLC safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), self._epsilon) @@ -714,7 +515,7 @@ def calc_rates(self, day, drv): # Vernalisation factor - always compute if module exists VERNFAC = self._ones - if hasattr(self, "vernalisation") and self.vernalisation is not None: + if self.vernalisation is not None: # Always call calc_rates (it handles stage internally now) self.vernalisation.calc_rates(day, drv) # Apply vernalisation only where IDSL >= 2 AND in vegetative stage @@ -725,7 +526,7 @@ def calc_rates(self, day, drv): self._ones, ) - TEMP = _get_drv(drv.TEMP, shape, self.dtype, self.device) + TEMP = _get_drv(drv.TEMP, p.shape, self.dtype, self.device) # Initialize all rate variables r.DTSUME = self._zeros @@ -804,7 +605,6 @@ def integrate(self, day, delt=1.0): p = self.params r = self.rates s = self.states - shape = self.params_shape # Integrate vernalisation module if self.vernalisation: @@ -845,11 +645,11 @@ def integrate(self, day, delt=1.0): is_emerging = s.STAGE == 0 should_emerge = is_emerging & (s.DVS >= 0.0) s.STAGE = torch.where( - should_emerge, torch.ones(shape, dtype=torch.long, device=self.device), s.STAGE + should_emerge, torch.ones(p.shape, dtype=torch.long, device=self.device), s.STAGE ) s.DOE = torch.where( should_emerge, - torch.full(shape, day_ordinal, dtype=self.dtype, device=self.device), + torch.full(p.shape, day_ordinal, dtype=self.dtype, device=self.device), s.DOE, ) s.DVS = torch.where(should_emerge, torch.clamp(s.DVS, max=0.0), s.DVS) @@ -862,11 +662,11 @@ def integrate(self, day, delt=1.0): is_vegetative = s.STAGE == 1 should_flower = is_vegetative & (s.DVS >= 1.0) s.STAGE = torch.where( - should_flower, torch.full(shape, 2, dtype=torch.long, device=self.device), s.STAGE + should_flower, torch.full(p.shape, 2, dtype=torch.long, device=self.device), s.STAGE ) s.DOA = torch.where( should_flower, - torch.full(shape, day_ordinal, dtype=self.dtype, device=self.device), + torch.full(p.shape, day_ordinal, dtype=self.dtype, device=self.device), s.DOA, ) s.DVS = torch.where(should_flower, torch.clamp(s.DVS, max=1.0), s.DVS) @@ -875,11 +675,11 @@ def integrate(self, day, delt=1.0): is_reproductive = s.STAGE == 2 should_mature = is_reproductive & (s.DVS >= p.DVSEND) s.STAGE = torch.where( - should_mature, torch.full(shape, 3, dtype=torch.long, device=self.device), s.STAGE + should_mature, torch.full(p.shape, 3, dtype=torch.long, device=self.device), s.STAGE ) s.DOM = torch.where( should_mature, - torch.full(shape, day_ordinal, dtype=self.dtype, device=self.device), + torch.full(p.shape, day_ordinal, dtype=self.dtype, device=self.device), s.DOM, ) s.DVS = torch.where(should_mature, torch.minimum(s.DVS, p.DVSEND), s.DVS) @@ -916,7 +716,7 @@ def _on_CROP_FINISH(self, day, finish_type=None): if finish_type in ["harvest", "earliest"]: day_ordinal = torch.tensor(day.toordinal(), dtype=self.dtype, device=self.device) self._for_finalize["DOH"] = torch.full( - self.params_shape, day_ordinal, dtype=self.dtype, device=self.device + self.params.shape, day_ordinal, dtype=self.dtype, device=self.device ) def get_variable(self, varname): diff --git a/src/diffwofost/physical_models/crop/root_dynamics.py b/src/diffwofost/physical_models/crop/root_dynamics.py index b38604f..536852e 100644 --- a/src/diffwofost/physical_models/crop/root_dynamics.py +++ b/src/diffwofost/physical_models/crop/root_dynamics.py @@ -1,19 +1,17 @@ import datetime import torch -from pcse.base import ParamTemplate -from pcse.base import RatesTemplate from pcse.base import SimulationObject -from pcse.base import StatesTemplate from pcse.base.parameter_providers import ParameterProvider from pcse.base.variablekiosk import VariableKiosk from pcse.base.weather import WeatherDataContainer from pcse.decorators import prepare_rates from pcse.decorators import prepare_states -from pcse.traitlets import Any +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate +from diffwofost.physical_models.base import TensorStatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor from diffwofost.physical_models.utils import AfgenTrait -from diffwofost.physical_models.utils import _broadcast_to -from diffwofost.physical_models.utils import _get_params_shape class WOFOST_Root_Dynamics(SimulationObject): @@ -116,8 +114,6 @@ class WOFOST_Root_Dynamics(SimulationObject): better and more biophysical approach to root development in WOFOST. """ # noqa: E501 - params_shape = None # Shape of the parameters tensors - @property def device(self): """Get device from ComputeConfig.""" @@ -128,80 +124,34 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): - RDI = Any() - RRI = Any() - RDMCR = Any() - RDMSOL = Any() - TDWI = Any() - IAIRDU = Any() + class Parameters(TensorParamTemplate): + RDI = Tensor(-99.0) + RRI = Tensor(-99.0) + RDMCR = Tensor(-99.0) + RDMSOL = Tensor(-99.0) + TDWI = Tensor(-99.0) + IAIRDU = Tensor(-99.0) RDRRTB = AfgenTrait() - def __init__(self, parvalues): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - self.RDI = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.RRI = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.RDMCR = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.RDMSOL = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.TDWI = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.IAIRDU = [torch.tensor(-99.0, dtype=dtype, device=device)] - - # Call parent init - super().__init__(parvalues) - - class RateVariables(RatesTemplate): - RR = Any() - GRRT = Any() - DRRT = Any() - GWRT = Any() - - def __init__(self, kiosk, publish=None): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - self.RR = torch.tensor(0.0, dtype=dtype, device=device) - self.GRRT = torch.tensor(0.0, dtype=dtype, device=device) - self.DRRT = torch.tensor(0.0, dtype=dtype, device=device) - self.GWRT = torch.tensor(0.0, dtype=dtype, device=device) - - # Call parent init - super().__init__(kiosk, publish=publish) - - class StateVariables(StatesTemplate): - RD = Any() - RDM = Any() - WRT = Any() - DWRT = Any() - TWRT = Any() - - def __init__(self, kiosk, publish=None, **kwargs): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - if "RD" not in kwargs: - self.RD = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "RDM" not in kwargs: - self.RDM = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "WRT" not in kwargs: - self.WRT = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "DWRT" not in kwargs: - self.DWRT = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "TWRT" not in kwargs: - self.TWRT = [torch.tensor(-99.0, dtype=dtype, device=device)] - - # Call parent init - super().__init__(kiosk, publish=publish, **kwargs) + class RateVariables(TensorRatesTemplate): + RR = Tensor(0.0) + GRRT = Tensor(0.0) + DRRT = Tensor(0.0) + GWRT = Tensor(0.0) + + class StateVariables(TensorStatesTemplate): + RD = Tensor(-99.0) + RDM = Tensor(-99.0) + WRT = Tensor(-99.0) + DWRT = Tensor(-99.0) + TWRT = Tensor(-99.0) def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | torch.Size | None = None, ) -> None: """Initialize the model. @@ -213,34 +163,33 @@ def initialize( parvalues (ParameterProvider): A dictionary-like container holding all parameter sets (crop, soil, site) as key/value. The values are arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. """ self.kiosk = kiosk - self.params = self.Parameters(parvalues) - self.rates = self.RateVariables(kiosk, publish=["DRRT", "GRRT"]) + self.params = self.Parameters(parvalues, shape=shape) + self.rates = self.RateVariables(kiosk, publish=["DRRT", "GRRT"], shape=shape) # INITIAL STATES params = self.params - self.params_shape = _get_params_shape(params) - shape = self.params_shape # Initial root depth states - RDI = _broadcast_to(params.RDI, shape, dtype=self.dtype, device=self.device) - RDMCR = _broadcast_to(params.RDMCR, shape, dtype=self.dtype, device=self.device) - RDMSOL = _broadcast_to(params.RDMSOL, shape, dtype=self.dtype, device=self.device) - - rdmax = torch.maximum(RDI, torch.minimum(RDMCR, RDMSOL)) - RDM = rdmax - RD = RDI + RDM = torch.maximum(params.RDI, torch.minimum(params.RDMCR, params.RDMSOL)) + RD = params.RDI # Initial root biomass states - TDWI = _broadcast_to(params.TDWI, shape, dtype=self.dtype, device=self.device) - FR = _broadcast_to(self.kiosk["FR"], shape, dtype=self.dtype, device=self.device) - WRT = TDWI * FR - DWRT = torch.zeros(shape, dtype=self.dtype, device=self.device) + WRT = params.TDWI * self.kiosk["FR"] + DWRT = 0.0 TWRT = WRT + DWRT self.states = self.StateVariables( - kiosk, publish=["RD", "WRT", "TWRT"], RD=RD, RDM=RDM, WRT=WRT, DWRT=DWRT, TWRT=TWRT + kiosk, + publish=["RD", "WRT", "TWRT"], + RD=RD, + RDM=RDM, + WRT=WRT, + DWRT=DWRT, + TWRT=TWRT, + shape=shape, ) @prepare_rates @@ -260,25 +209,19 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask. # Make a mask (0 if DVS < 0, 1 if DVS >= 0) - DVS = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) - dvs_mask = (DVS >= 0).to(dtype=self.dtype) + dvs_mask = k["DVS"] >= 0 # Increase in root biomass - FR = _broadcast_to(k["FR"], self.params_shape, dtype=self.dtype, device=self.device) - DMI = _broadcast_to(k["DMI"], self.params_shape, dtype=self.dtype, device=self.device) - RDRRTB = p.RDRRTB.to(device=self.device, dtype=self.dtype) - - r.GRRT = dvs_mask * FR * DMI - r.DRRT = dvs_mask * s.WRT * RDRRTB(DVS) + r.GRRT = dvs_mask * k["FR"] * k["DMI"] + r.DRRT = dvs_mask * s.WRT * p.RDRRTB(k["DVS"]) r.GWRT = r.GRRT - r.DRRT # Increase in root depth - RRI = _broadcast_to(p.RRI, self.params_shape, dtype=self.dtype, device=self.device) - r.RR = dvs_mask * torch.minimum((s.RDM - s.RD), RRI) + r.RR = dvs_mask * torch.minimum((s.RDM - s.RD), p.RRI) # Do not let the roots growth if partioning to the roots # (variable FR) is zero. - mask = (FR > 0.0).to(dtype=self.dtype) + mask = k["FR"] > 0.0 r.RR = r.RR * mask * dvs_mask @prepare_states diff --git a/src/diffwofost/physical_models/engine.py b/src/diffwofost/physical_models/engine.py index 4b6920e..d570f0e 100644 --- a/src/diffwofost/physical_models/engine.py +++ b/src/diffwofost/physical_models/engine.py @@ -1,4 +1,5 @@ from pathlib import Path +import torch from pcse import signals from pcse.base import BaseEngine from pcse.base.variablekiosk import VariableKiosk @@ -27,6 +28,7 @@ def __init__( self.mconf = config self.parameterprovider = parameterprovider + self._shape = _get_params_shape(self.parameterprovider) # Variable kiosk for registering and publishing variables self.kiosk = VariableKiosk() @@ -65,3 +67,37 @@ def __init__( # Calculate initial rates self.calc_rates(self.day, self.drv) + + def _on_CROP_START( + self, day, crop_name=None, variety_name=None, crop_start_type=None, crop_end_type=None + ): + """Starts the crop.""" + self.logger.debug(f"Received signal 'CROP_START' on day {day}") + + if self.crop is not None: + raise RuntimeError( + "A CROP_START signal was received while self.cropsimulation still holds a valid " + "cropsimulation object. It looks like you forgot to send a CROP_FINISH signal with " + "option crop_delete=True" + ) + + self.parameterprovider.set_active_crop( + crop_name, variety_name, crop_start_type, crop_end_type + ) + self.crop = self.mconf.CROP(day, self.kiosk, self.parameterprovider, shape=self._shape) + + +def _get_params_shape(parameterprovider): + shape = () + for paramname in parameterprovider._unique_parameters: + param = parameterprovider[paramname] + if isinstance(param, torch.Tensor): + # We need to drop the last dimension from the Afgen table parameters + param_shape = param.shape[:-1] if paramname.endswith("TB") else param.shape + if not param_shape or shape == param_shape: + continue + elif param_shape and not shape: + shape = tuple(param_shape) + else: + raise ValueError("Non-matching shapes found in parameter provider!") + return shape diff --git a/src/diffwofost/physical_models/traitlets.py b/src/diffwofost/physical_models/traitlets.py new file mode 100644 index 0000000..efda289 --- /dev/null +++ b/src/diffwofost/physical_models/traitlets.py @@ -0,0 +1,45 @@ +import torch +from traitlets_pcse import TraitType +from traitlets_pcse import Undefined +from .config import ComputeConfig + + +class Tensor(TraitType): + info_text = "an object that could be casted into a tensor" + + def __init__( + self, + default_value=Undefined, + allow_none=False, + read_only=None, + help=None, + config=None, + dtype=None, + **kwargs, + ): + super().__init__( + default_value=default_value, + allow_none=allow_none, + read_only=read_only, + help=help, + config=config, + **kwargs, + ) + self.dtype = dtype + + def validate(self, obj, value): + """Validate input object, recasting it into a tensor if possible.""" + device = ComputeConfig.get_device() + dtype = ComputeConfig.get_dtype() if self.dtype is None else self.dtype + if isinstance(value, torch.Tensor): + casted = value.to(dtype=dtype, device=device) + return casted + try: + # Try casting value into a tensor, raise validation error if it fails + return torch.tensor(value, dtype=dtype, device=device) + except: # noqa: E722 + self.error(obj, value) + + def from_string(self, s): + """Casting tensor from string is not supported for now.""" + raise NotImplementedError diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index f3229c7..2ce9ed8 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -23,10 +23,10 @@ from pcse.engine import BaseEngine from pcse.settings import settings from pcse.timer import Timer -from pcse.traitlets import Enum from pcse.traitlets import TraitType from .config import Configuration from .engine import Engine +from .engine import _get_params_shape logging.disable(logging.CRITICAL) @@ -108,6 +108,7 @@ def __init__( self.mconf = config self.parameterprovider = parameterprovider + self._shape = _get_params_shape(self.parameterprovider) # Configure device and dtype on crop module class if it supports them if hasattr(self.mconf.CROP, "device") and device is not None: @@ -545,33 +546,6 @@ def validate(self, obj, value): self.error(obj, value) -def _get_params_shape(params): - """Get the parameters shape. - - Parameters can have arbitrary number of dimensions, but all parameters that are not zero- - dimensional should have the same shape. - - This check if fundamental for vectorized operations in the physical models. - """ - shape = () - for parname in params.trait_names(): - # Skip special traitlets attributes - if parname.startswith("trait"): - continue - param = getattr(params, parname) - # Skip Enum and str parameters - if isinstance(param, Enum) or isinstance(param, str): - continue - # Parameters that are not zero dimensional should all have the same shape - if param.shape and not shape: - shape = param.shape - elif param.shape: - assert param.shape == shape, ( - "All parameters should have the same shape (or have no dimensions)" - ) - return shape - - def _get_drv(drv_var, expected_shape, dtype, device=None): """Check that the driving variables have the expected shape and fetch them. @@ -610,33 +584,23 @@ def _get_drv(drv_var, expected_shape, dtype, device=None): ) -def _broadcast_to(x, shape, dtype, device=None): +def _broadcast_to(x, shape, dtype=None, device=None): """Create a view of tensor X with the given shape. Args: x: The tensor or value to broadcast shape: The target shape - dtype: dtype for the tensor + dtype: Optional dtype for the tensor device: Optional device for the tensor """ - # If x is not a tensor, convert it - if not isinstance(x, torch.Tensor): - x = torch.tensor(x, dtype=dtype) - # Ensure correct dtype and device - if dtype is not None: - x = x.to(dtype=dtype) + # Make sure x is a tensor + x = torch.as_tensor(x, dtype=dtype) if device is not None: x = x.to(device=device) # If already the correct shape, return as-is if x.shape == shape: return x - if x.dim() == 0: - # For 0-d tensors, we simply broadcast to the given shape - return torch.broadcast_to(x, shape) - # The given shape should match x in all but the last axis, which represents - # the dimension along which the time integration is carried out. - # We first append an axis to x, then expand to the given shape - return x.unsqueeze(-1).expand(shape) + return torch.broadcast_to(x, shape) def _snapshot_state(obj): diff --git a/tests/physical_models/base/test_states_rates.py b/tests/physical_models/base/test_states_rates.py new file mode 100644 index 0000000..dae7bbd --- /dev/null +++ b/tests/physical_models/base/test_states_rates.py @@ -0,0 +1,125 @@ +import pytest +import torch +from pcse.base import VariableKiosk +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate +from diffwofost.physical_models.base import TensorStatesTemplate +from diffwofost.physical_models.traitlets import Tensor +from diffwofost.physical_models.utils import AfgenTrait + + +class TestTensorParamTemplate: + class Params(TensorParamTemplate): + A = Tensor(0) + B = Tensor(0, dtype=int) + + class ParamsWithAfgen(TensorParamTemplate): + A = Tensor(0) + B = AfgenTrait() + + def test_template_automatically_cast_to_tensor_with_correct_type(self): + p = self.Params(dict(A=1, B=[1, 1, 1])) + assert isinstance(p.A, torch.Tensor) + assert p.A.dtype == torch.float64 + assert isinstance(p.B, torch.Tensor) + assert p.B.dtype == torch.int64 + + def test_template_can_infer_shape_from_parameters(self): + p = self.Params(dict(A=1, B=[1, 1, 1])) + assert all(param.shape == (3,) for param in (p.A, p.B)) + assert p.shape == (3,) + + def test_template_can_apply_shape_from_argument(self): + shape = (3,) + p = self.Params(dict(A=1, B=1), shape=shape) + assert all(param.shape == shape for param in (p.A, p.B)) + assert p.shape == shape + + def test_template_checks_consistency_of_parameter_and_input_shapes(self): + # Here the input shape is consistent with the parameters + shape = (3,) + p = self.Params(dict(A=1, B=[1, 1, 1]), shape=shape) + assert all(param.shape == shape for param in (p.A, p.B)) + assert p.shape == shape + + # Here it is not + with pytest.raises(ValueError): + self.Params(dict(A=1, B=[1, 1, 1]), shape=(5,)) + + def test_template_allows_to_skip_broadcasting_of_variables(self): + shape = (5,) + p = self.Params(dict(A=1, B=[1, 1, 1]), shape=shape, do_not_broadcast=["B"]) + assert p.A.shape == shape + assert p.B.shape == (3,) + + def test_template_recognizes_shape_of_afgen_tables(self): + p = self.ParamsWithAfgen(dict(A=1, B=[0, 0, 1, 1])) + assert p.shape == () + p = self.ParamsWithAfgen(dict(A=1, B=[[0, 0, 1, 1], [0, 0, 2, 2]])) + assert p.shape == (2,) + + +class TestTensorRatesTemplate: + class Rates(TensorRatesTemplate): + A = Tensor(0) + B = Tensor(0, dtype=int) + + def test_template_automatically_cast_to_tensor_with_correct_type(self): + r = self.Rates(kiosk=VariableKiosk()) + assert isinstance(r.A, torch.Tensor) + assert r.A.dtype == torch.float64 + assert isinstance(r.B, torch.Tensor) + assert r.B.dtype == torch.int64 + + def test_template_can_apply_shape_from_argument(self): + shape = (3,) + r = self.Rates(kiosk=VariableKiosk(), shape=shape) + assert all(rate.shape == shape for rate in (r.A, r.B)) + assert r.shape == shape + + def test_template_allows_to_skip_broadcasting_of_variables(self): + shape = (5,) + r = self.Rates(kiosk=VariableKiosk(), shape=shape, do_not_broadcast=["B"]) + assert r.A.shape == shape + assert r.B.shape == () + + def test_template_allows_to_publish_in_kiosk(self): + k = VariableKiosk() + r = self.Rates(kiosk=k, publish=["A"]) + assert "A" in k + assert k.A == 0.0 + r.A = torch.tensor(1.0) + assert k.A == 1.0 + + +class TestTensorStatesTemplate: + class States(TensorStatesTemplate): + A = Tensor(0) + B = Tensor(0, dtype=int) + + def test_template_automatically_cast_to_tensor_with_correct_type(self): + s = self.States(kiosk=VariableKiosk(), A=1, B=1) + assert isinstance(s.A, torch.Tensor) + assert s.A.dtype == torch.float64 + assert isinstance(s.B, torch.Tensor) + assert s.B.dtype == torch.int64 + + def test_template_can_apply_shape_from_argument(self): + shape = (3,) + s = self.States(kiosk=VariableKiosk(), shape=shape, A=1, B=1) + assert all(state.shape == shape for state in (s.A, s.B)) + assert s.shape == shape + + def test_template_allows_to_skip_broadcasting_of_variables(self): + shape = (5,) + s = self.States(kiosk=VariableKiosk(), shape=shape, do_not_broadcast=["B"], A=1, B=1) + assert s.A.shape == shape + assert s.B.shape == () + + def test_template_allows_to_publish_in_kiosk(self): + k = VariableKiosk() + s = self.States(kiosk=k, publish=["A"], A=0, B=0) + assert "A" in k + assert k.A == 0.0 + s.A = torch.tensor(1.0) + assert k.A == 1.0 diff --git a/tests/physical_models/crop/test_assimilation.py b/tests/physical_models/crop/test_assimilation.py index 8eb16b7..54ac98a 100644 --- a/tests/physical_models/crop/test_assimilation.py +++ b/tests/physical_models/crop/test_assimilation.py @@ -303,7 +303,7 @@ def test_assimilation_with_incompatible_parameter_vectors(self): "EFFTB", crop_model_params_provider["EFFTB"].repeat(5, 1), check=False ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index dba524c..0d101c8 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -377,7 +377,7 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): "SPAN", crop_model_params_provider["SPAN"].repeat(5), check=False ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, diff --git a/tests/physical_models/crop/test_partitioning.py b/tests/physical_models/crop/test_partitioning.py index ae5c6e4..db21718 100644 --- a/tests/physical_models/crop/test_partitioning.py +++ b/tests/physical_models/crop/test_partitioning.py @@ -289,7 +289,7 @@ def test_partitioning_with_incompatible_parameter_vectors(self): crop_model_params_provider.set_override("FRTB", [[0.0, 0.3, 2.0, 0.1]] * 4, check=False) crop_model_params_provider.set_override("FLTB", [[0.0, 0.3, 2.0, 0.1]] * 2, check=False) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, diff --git a/tests/physical_models/crop/test_phenology.py b/tests/physical_models/crop/test_phenology.py index 12cd78f..cd3690f 100644 --- a/tests/physical_models/crop/test_phenology.py +++ b/tests/physical_models/crop/test_phenology.py @@ -486,7 +486,7 @@ def test_phenology_with_incompatible_parameter_vectors(self): "TSUM2", crop_model_params_provider["TSUM2"].repeat(5), check=False ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index 69b7c50..b75b907 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -345,7 +345,7 @@ def test_root_dynamics_with_incompatible_parameter_vectors(self, device): "RRI", crop_model_params_provider["RRI"].repeat(5), check=False ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, diff --git a/tests/physical_models/test_engine.py b/tests/physical_models/test_engine.py index 680f860..b23c9f4 100644 --- a/tests/physical_models/test_engine.py +++ b/tests/physical_models/test_engine.py @@ -1,8 +1,12 @@ +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.phenology import DVS_Phenology from diffwofost.physical_models.engine import Engine from diffwofost.physical_models.utils import get_test_data from diffwofost.physical_models.utils import prepare_engine_input from . import phy_data_folder +config = Configuration(CROP=DVS_Phenology) + class TestEngine: def test_engine_can_be_instantiated_from_default_pcse_config(self): @@ -31,6 +35,6 @@ def test_engine_can_be_instantiated_from_default_pcse_config(self): parameterprovider=crop_model_params_provider, weatherdataprovider=weather_data_provider, agromanagement=agro_management_inputs, - config="Wofost72_Pheno.conf", + config=config, ) assert isinstance(engine, Engine)