5 changes: 5 additions & 0 deletions src/diffwofost/physical_models/base/__init__.py
@@ -0,0 +1,5 @@
from .states_rates import TensorParamTemplate
from .states_rates import TensorRatesTemplate
from .states_rates import TensorStatesTemplate

__all__ = ["TensorParamTemplate", "TensorRatesTemplate", "TensorStatesTemplate"]
109 changes: 109 additions & 0 deletions src/diffwofost/physical_models/base/states_rates.py
@@ -0,0 +1,109 @@
from pcse.base import ParamTemplate
from pcse.base import RatesTemplate
from pcse.base import StatesTemplate
from pcse.traitlets import HasTraits
from ..traitlets import Tensor
from ..utils import AfgenTrait


class TensorContainer(HasTraits):
def __init__(self, shape=None, do_not_broadcast=None, **variables):
"""Container of tensor variables.
It includes functionality to broadcast variables to a common shape. This common shape can
be inferred from the container's tensor and AFGEN variables, or it can be set as an input
argument.
Args:
shape (tuple | torch.Size, optional): Shape to which the variables in the container
are broadcasted. If given, it should match the shape of all the input variables that
already have dimensions. Defaults to None.
do_not_broadcast (list, optional): Names of the variables that are not broadcasted
to the container shape. Defaults to None, which means that all variables are
broadcasted.
variables (dict): Collection of variables to initialize the container, as key-value
pairs.
"""
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
HasTraits.__init__(self, **variables)
self._broadcast(shape)

def _broadcast(self, shape=None):
# Identify which variables should be broadcasted. Also check that the input shape is
# compatible with the existing variable shapes
vars_to_broadcast = self._get_vars_to_broadcast()
vars_shape = self._get_vars_shape()
if shape and vars_shape and vars_shape != shape:
raise ValueError(f"Input shape {shape} does not match variable shape {vars_shape}")
shape = tuple(shape or vars_shape)

# Broadcast all required variables to the identified shape.
for varname, var in vars_to_broadcast.items():
try:
broadcasted = var.expand(shape)
except RuntimeError as error:
raise ValueError(f"Cannot broadcast {varname} to shape {shape}") from error
setattr(self, varname, broadcasted)

# Finally, update the shape of the container
self.shape = shape

def _get_vars_to_broadcast(self):
vars = {}
for varname, trait in self.traits().items():
if varname not in self._do_not_broadcast:
if isinstance(trait, Tensor):
vars[varname] = getattr(self, varname)
return vars

def _get_vars_shape(self):
shape = ()
for varname, trait in self.traits().items():
if varname not in self._do_not_broadcast:
if isinstance(trait, Tensor) or isinstance(trait, AfgenTrait):
var = getattr(self, varname)
if not var.shape or shape == var.shape:
continue
elif var.shape and not shape:
shape = tuple(var.shape)
else:
raise ValueError(
f"Incompatible shapes within variables: {shape} and {var.shape}"
)
return shape

@property
def shape(self):
"""Base shape of the variables in the container."""
return self._shape

@shape.setter
def shape(self, shape):
if self.shape and self.shape != shape:
raise ValueError(f"Container shape already set to {self.shape}")
self._shape = shape


class TensorParamTemplate(TensorContainer, ParamTemplate):
def __init__(self, parvalues, shape=None, do_not_broadcast=None):
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
ParamTemplate.__init__(self, parvalues=parvalues)
self._broadcast(shape)


class TensorStatesTemplate(TensorContainer, StatesTemplate):
def __init__(self, kiosk=None, publish=None, shape=None, do_not_broadcast=None, **kwargs):
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
StatesTemplate.__init__(self, kiosk=kiosk, publish=publish, **kwargs)
self._broadcast(shape)


class TensorRatesTemplate(TensorContainer, RatesTemplate):
def __init__(self, kiosk=None, publish=None, shape=None, do_not_broadcast=None):
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
RatesTemplate.__init__(self, kiosk=kiosk, publish=publish)
self._broadcast(shape)
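
To make the broadcasting contract of `TensorContainer` concrete, here is a minimal usage sketch. `_DemoStates` is a hypothetical subclass written for illustration (not part of this PR), and it assumes the `Tensor` trait stores its values as torch tensors, which the `.shape`/`.expand` calls above already rely on.

```python
import torch

from diffwofost.physical_models.base.states_rates import TensorContainer
from diffwofost.physical_models.traitlets import Tensor


class _DemoStates(TensorContainer):
    # Hypothetical container: WLV keeps its scalar default, LAI is supplied per batch.
    WLV = Tensor(0.0)
    LAI = Tensor(0.0)


# LAI arrives with a batch dimension, so the common shape is inferred as (3,)
# and the scalar WLV is expanded to match it.
demo = _DemoStates(LAI=torch.tensor([0.5, 1.0, 1.5]))
print(demo.shape)      # (3,)
print(demo.WLV.shape)  # torch.Size([3])

# An explicit shape that conflicts with an existing variable shape raises ValueError:
# _DemoStates(shape=(2,), LAI=torch.tensor([0.5, 1.0, 1.5]))
```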
67 changes: 31 additions & 36 deletions src/diffwofost/physical_models/crop/assimilation.py
@@ -3,21 +3,20 @@
import datetime
from collections import deque
import torch
from pcse.base import ParamTemplate
from pcse.base import RatesTemplate
from pcse.base import SimulationObject
from pcse.base.parameter_providers import ParameterProvider
from pcse.base.variablekiosk import VariableKiosk
from pcse.base.weather import WeatherDataContainer
from pcse.decorators import prepare_rates
from pcse.decorators import prepare_states
from pcse.traitlets import Any
from pcse.util import astro
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.base import TensorRatesTemplate
from diffwofost.physical_models.config import ComputeConfig
from diffwofost.physical_models.traitlets import Tensor
from diffwofost.physical_models.utils import AfgenTrait
from diffwofost.physical_models.utils import _broadcast_to
from diffwofost.physical_models.utils import _get_drv
from diffwofost.physical_models.utils import _get_params_shape


def _as_python_float(x) -> float:
@@ -42,6 +41,8 @@ def totass7(
COSLD: torch.Tensor,
*,
epsilon: torch.Tensor,
dtype: torch.dtype,
device: str,
) -> torch.Tensor:
"""Calculates daily total gross CO2 assimilation.

@@ -69,9 +70,9 @@
COSLD R4 Amplitude of sine of solar height - I
DTGA R4 Daily total gross assimilation kg CO2/ha/d O
"""
xgauss = torch.tensor([0.1127017, 0.5000000, 0.8872983], dtype=DAYL.dtype, device=DAYL.device)
wgauss = torch.tensor([0.2777778, 0.4444444, 0.2777778], dtype=DAYL.dtype, device=DAYL.device)
pi = torch.tensor(torch.pi, dtype=DAYL.dtype, device=DAYL.device)
xgauss = torch.tensor([0.1127017, 0.5000000, 0.8872983], dtype=dtype, device=device)
wgauss = torch.tensor([0.2777778, 0.4444444, 0.2777778], dtype=dtype, device=device)
pi = torch.tensor(torch.pi, dtype=dtype, device=device)

# Only compute where it can be non-zero.
mask = (AMAX > 0) & (LAI > 0) & (DAYL > 0)
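
The xgauss/wgauss constants above are the 3-point Gauss nodes and weights on the unit interval; the collapsed remainder of totass7 uses them to integrate instantaneous assimilation over the daylight period. A self-contained sketch of that quadrature pattern with a generic integrand (not the actual WOFOST assimilation code):

```python
import torch


def gauss3_integral(f, DAYL, xgauss, wgauss):
    # Approximate the integral of f(t) for t in [0, DAYL]: map the unit-interval
    # nodes onto the day length, take the weighted sum, scale by the interval.
    total = torch.zeros_like(DAYL)
    for xg, wg in zip(xgauss, wgauss):
        total = total + wg * f(DAYL * xg)
    return total * DAYL


DAYL = torch.tensor(12.0)
xgauss = torch.tensor([0.1127017, 0.5000000, 0.8872983])
wgauss = torch.tensor([0.2777778, 0.4444444, 0.2777778])

# Sine-shaped test integrand; the analytic integral over [0, 12] is 24 / pi ~= 7.64.
approx = gauss3_integral(lambda t: torch.sin(torch.pi * t / 12.0), DAYL, xgauss, wgauss)
print(float(approx))  # ~7.64
```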
@@ -229,8 +230,6 @@ class WOFOST72_Assimilation(SimulationObject):
| PGASS | AMAXTB, EFFTB, KDIFTB, TMPFTB, TMNFTB |
""" # noqa: E501

params_shape = None

@property
def device(self):
"""Get device from ComputeConfig."""
@@ -241,33 +240,27 @@ def dtype(self):
"""Get dtype from ComputeConfig."""
return ComputeConfig.get_dtype()

class Parameters(ParamTemplate):
class Parameters(TensorParamTemplate):
AMAXTB = AfgenTrait()
EFFTB = AfgenTrait()
KDIFTB = AfgenTrait()
TMPFTB = AfgenTrait()
TMNFTB = AfgenTrait()

def __init__(self, parvalues):
super().__init__(parvalues)

class RateVariables(RatesTemplate):
PGASS = Any()

def __init__(self, kiosk, publish=None):
dtype = ComputeConfig.get_dtype()
device = ComputeConfig.get_device()
self.PGASS = torch.tensor(0.0, dtype=dtype, device=device)
super().__init__(kiosk, publish=publish)
class RateVariables(TensorRatesTemplate):
PGASS = Tensor(0.0)

def initialize(
self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider
self,
day: datetime.date,
kiosk: VariableKiosk,
parvalues: ParameterProvider,
shape: tuple | torch.Size | None = None,
) -> None:
"""Initialize the assimilation module."""
self.kiosk = kiosk
self.params = self.Parameters(parvalues)
self.params_shape = _get_params_shape(self.params)
self.rates = self.RateVariables(kiosk, publish=["PGASS"])
self.params = self.Parameters(parvalues, shape=shape)
self.rates = self.RateVariables(kiosk, publish=["PGASS"], shape=shape)

# 7-day running average buffer for TMIN (stored as tensors).
self._tmn_window = deque(maxlen=7)
@@ -285,16 +278,16 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None
_exist_required_external_variables(k)

# External states
dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device)
lai = _broadcast_to(k["LAI"], self.params_shape, dtype=self.dtype, device=self.device)
dvs = _broadcast_to(k["DVS"], self.params.shape, dtype=self.dtype, device=self.device)
lai = _broadcast_to(k["LAI"], self.params.shape, dtype=self.dtype, device=self.device)

# Weather drivers
irrad = _get_drv(drv.IRRAD, self.params_shape, dtype=self.dtype, device=self.device)
dtemp = _get_drv(drv.DTEMP, self.params_shape, dtype=self.dtype, device=self.device)
tmin = _get_drv(drv.TMIN, self.params_shape, dtype=self.dtype, device=self.device)
irrad = _get_drv(drv.IRRAD, self.params.shape, dtype=self.dtype, device=self.device)
dtemp = _get_drv(drv.DTEMP, self.params.shape, dtype=self.dtype, device=self.device)
tmin = _get_drv(drv.TMIN, self.params.shape, dtype=self.dtype, device=self.device)

# Assimilation is zero before crop emergence (DVS < 0)
dvs_mask = (dvs >= 0).to(dtype=self.dtype)
Review comment (Collaborator Author): Was this needed? Why can't we leave it as a tensor with dtype bool?

dvs_mask = dvs >= 0
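
As a small aside on the question above (illustrative only, not part of the diff): PyTorch promotes a bool mask to the float operand's dtype during multiplication, so the explicit `.to(dtype=self.dtype)` mainly pins the dtype up front rather than being required for the arithmetic.

```python
import torch

# Multiplying a float tensor by a bool mask promotes the mask automatically,
# so both variants below produce identical float32 results.
tmin = torch.tensor([1.5, 2.0, 3.0])
mask_bool = torch.tensor([True, False, True])           # torch.bool
mask_float = mask_bool.to(dtype=torch.float32)          # dtype pinned explicitly

assert torch.equal(tmin * mask_bool, tmin * mask_float)
print((tmin * mask_bool).dtype)  # torch.float32
```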
# 7-day running average of TMIN
self._tmn_window.appendleft(tmin * dvs_mask)
self._tmn_window_mask.appendleft(dvs_mask)
Expand All @@ -307,11 +300,11 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None
irrad_for_astro = _as_python_float(drv.IRRAD)
dayl, _daylp, sinld, cosld, difpp, _atmtr, dsinbe, _angot = astro(day, lat, irrad_for_astro)

dayl_t = _broadcast_to(dayl, self.params_shape, dtype=self.dtype, device=self.device)
sinld_t = _broadcast_to(sinld, self.params_shape, dtype=self.dtype, device=self.device)
cosld_t = _broadcast_to(cosld, self.params_shape, dtype=self.dtype, device=self.device)
difpp_t = _broadcast_to(difpp, self.params_shape, dtype=self.dtype, device=self.device)
dsinbe_t = _broadcast_to(dsinbe, self.params_shape, dtype=self.dtype, device=self.device)
dayl_t = _broadcast_to(dayl, self.params.shape, dtype=self.dtype, device=self.device)
sinld_t = _broadcast_to(sinld, self.params.shape, dtype=self.dtype, device=self.device)
cosld_t = _broadcast_to(cosld, self.params.shape, dtype=self.dtype, device=self.device)
difpp_t = _broadcast_to(difpp, self.params.shape, dtype=self.dtype, device=self.device)
dsinbe_t = _broadcast_to(dsinbe, self.params.shape, dtype=self.dtype, device=self.device)

# Parameter tables
amax = p.AMAXTB(dvs)
Expand All @@ -331,6 +324,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None
sinld_t,
cosld_t,
epsilon=self._epsilon,
dtype=self.dtype,
device=self.device,
)

# Correction for low minimum temperature potential