Add tensor parameter containers#75

Draft
fnattino wants to merge 9 commits into main from tensor-param-template

Conversation

@fnattino
Collaborator

@fnattino fnattino commented Jan 15, 2026

relates #25

@fnattino
Collaborator Author

This PR adds:

  • A new parameter template, to be used in the models. This automatically broadcasts all parameters to the same shape, which can be useful in the following scenarios:
    • one parameter is a tensor and the others are scalars: the scalar parameters are automatically broadcast to the tensor shape
    • all parameters are scalars, but the driving variables (i.e. the weather data) are tensors: not yet implemented, but I imagine this happening when instantiating a parameter object, which could take a shape argument.
  • A new Tensor trait, which can be used to define parameters that are expected to be tensors. This simplifies the definition of the parameter/state/rate containers in the models, and it allows the expected dtype of a given parameter/state/rate variable to be defined. In addition, input parameters are automatically cast into tensors if they are not already.
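The first scenario above (scalar parameters broadcast to the shape of a tensor parameter) can be sketched in a few lines. This is only an illustrative sketch in plain Python, with nested lists standing in for tensors; the names `ParamTemplate`, `broadcast`, and `shape_of` are hypothetical, not the actual API added by this PR:

```python
# Illustrative sketch only: nested lists stand in for tensors, and the
# names ParamTemplate/broadcast/shape_of are hypothetical, not the PR's API.

def broadcast(value, shape):
    """Expand a scalar into a nested list of the given shape."""
    if not shape:
        return value
    return [broadcast(value, shape[1:]) for _ in range(shape[0])]

def shape_of(value):
    """Shape of a nested list; () for scalars."""
    return (len(value),) + shape_of(value[0]) if isinstance(value, list) else ()

class ParamTemplate:
    """Broadcast scalar parameters to the shape of the tensor-like ones."""

    def __init__(self, params):
        # Pick the target shape from the (last) non-scalar parameter.
        target = ()
        for value in params.values():
            if isinstance(value, list):
                target = shape_of(value)
        for name, value in params.items():
            if isinstance(value, list):
                setattr(self, name, value)
            else:
                setattr(self, name, broadcast(value, target))

params = ParamTemplate({"A": 0.5, "B": [1.0, 2.0, 3.0]})
print(params.A)  # the scalar A is broadcast to B's shape: [0.5, 0.5, 0.5]
```

The second scenario could then follow the same path, with the target shape supplied explicitly at construction instead of being inferred from a tensor parameter.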

What do you think @SarahAlidoost and @SCiarella? I think this can simplify some expressions in the models - I imagine a lot of the _broadcast_to calls can be removed, since the template already carries out the broadcasting.

@SCiarella
Collaborator

Thanks @fnattino, this looks fantastic 🚀

I really like the template to automatically broadcast to the correct shape and device at the beginning, because right now we are doing it quite a lot of times in the integration loops.

Ideally, it would be nice to remove all the calls to _broadcast_to, but right now we have to check in each module whether the state variables, the rates, the kiosk, and the parameters all have the same shape/device. Can we later add templates for everything and have the engine do the necessary shape/device checks at initialization?

@SarahAlidoost
Collaborator

This PR adds:

  • A new parameter template, to be used in the models. This automatically broadcasts all parameters to the same shape, which can be useful in the following scenarios:

    • one parameter is a tensor and the others are scalars: the scalar parameters are automatically broadcast to the tensor shape
    • all parameters are scalars, but the driving variables (i.e. the weather data) are tensors: not yet implemented, but I imagine this happening when instantiating a parameter object, which could take a shape argument.
  • A new Tensor trait, which can be used to define parameters that are expected to be tensors. This simplifies the definition of the parameter/state/rate containers in the models, and it allows the expected dtype of a given parameter/state/rate variable to be defined. In addition, input parameters are automatically cast into tensors if they are not already.

One thing is the naming "Tensor". If I understood correctly from the class definition, the init function doesn't do anything tensor-related; it is only a trait type, a subclass of TraitType, right? I found it a bit confusing because when, for example, we do a = Tensor(0.0), I would assume a is a tensor like torch.Tensor or tf.Tensor. Can we rename it to something else?

What do you think @SarahAlidoost and @SCiarella? I think this can simplify some expressions in the models - I imagine a lot of the _broadcast_to calls can be removed, since the template already carries out the broadcasting.

This is awesome! 🥇 Thanks. I like how things get simpler and cleaner. Just one comment about the naming, see above.

@fnattino
Collaborator Author

fnattino commented Jan 19, 2026

Thank you @SCiarella and @SarahAlidoost for the useful feedback!

@SCiarella :

Ideally, it would be nice to remove all the calls to _broadcast_to, but right now we have to check in each module whether the state variables, the rates, the kiosk, and the parameters all have the same shape/device. Can we later add templates for everything and have the engine do the necessary shape/device checks at initialization?

Indeed, I think it's a good idea to also add similar containers for states and rates, so all variables are initialized with the correct shape and device!

@SarahAlidoost:

One thing is the naming "Tensor". If I understood correctly from the class definition, the init function doesn't do anything tensor-related; it is only a trait type, a subclass of TraitType, right? I found it a bit confusing because when, for example, we do a = Tensor(0.0), I would assume a is a tensor like torch.Tensor or tf.Tensor. Can we rename it to something else?

My idea was to use Tensor to define variables that are expected to be tensors, similar to how pcse has pcse.traitlets.Float or pcse.traitlets.Bool for floats and booleans, respectively. Until now, all variables expected to be tensors were marked with the generic Any. Variables defined as Tensor are automatically checked to be of type torch.Tensor, or cast into that type via the validate method. So, for instance:

import torch
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.traitlets import Tensor

class Parameters(TensorParamTemplate):
    A = Tensor(0.)
    B = Tensor(0, dtype=int)

# Parameters A and B are cast into tensors
params = Parameters(dict(A=0., B=0))

params.A
# tensor(0., dtype=torch.float64)

params.B
# tensor(0)
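For readers unfamiliar with traitlets, the cast-on-assignment behaviour discussed above can be mimicked with a plain Python descriptor. This is only an illustrative sketch, using Python's float/int in place of torch.Tensor; `TensorTrait` and `Params` are hypothetical names, not the PR's actual Tensor class:

```python
# Illustrative sketch (not the PR's code) of how a trait can cast incoming
# values in a validate() step, in the spirit of traitlets' TraitType.
# Plain Python floats/ints stand in for torch tensors here.

class TensorTrait:
    """Descriptor that casts assigned values via a validate() hook."""

    def __init__(self, default=0.0, dtype=float):
        self.default = default
        self.dtype = dtype

    def __set_name__(self, owner, name):
        self.name = name

    def validate(self, value):
        # Cast to the declared dtype, analogous to casting to torch.Tensor.
        return self.dtype(value)

    def __set__(self, obj, value):
        obj.__dict__[self.name] = self.validate(value)

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return obj.__dict__.get(self.name, self.default)

class Params:
    A = TensorTrait(0.0)           # float-like parameter
    B = TensorTrait(0, dtype=int)  # integer parameter

p = Params()
p.A = "1.5"  # the string is cast to float by validate()
p.B = 2.7    # the float is cast (truncated) to int
```

One can imagine the actual Tensor trait's validate method doing the analogous check for torch.Tensor and casting with the requested dtype otherwise, as the example output above suggests.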

@sonarqubecloud

sonarqubecloud bot commented Feb 4, 2026

