Use of nn.Module #82

@ronvree

Description


Just wanted to post this to clarify my comment about nn.Module this morning and why I think it could be very useful to integrate into the diffwofost submodules as well.

Here is some code that reflects how I currently envision the implementation structure. The code is self-contained and should be runnable if copied into a Colab notebook, for example.

I'm sure many, if not all, of these things are known to you, but I figured it might help structure the discussion or clarify if I'm missing something.


"""
    Toy example of structuring diffwofost
"""

import torch
import torch.nn as nn

# Wofost simulation object
class SimulationObject:
  pass

# Some differentiable phenology module
class DiffPhenology(SimulationObject, nn.Module):
  def __init__(self, dtype=None, device=None) -> None:
    super().__init__()

    self._p = nn.Parameter(torch.tensor(1.0, dtype=dtype, device=device,))

  def forward(self, x):
    return self._p * x

# Some differentiable partitioning module
class DiffPartitioning(SimulationObject, nn.Module):
  def __init__(self, dtype=None, device=None) -> None:
    super().__init__()

    self._p = nn.Parameter(torch.tensor(1.0, dtype=dtype, device=device,))
    
  def forward(self, x):
    return self._p * x + torch.ones_like(x, device=x.device, dtype=x.dtype)

# Differentiable wofost that is composed of differentiable submodules and determines how they are related
class DiffWofost(nn.Module):
  def __init__(self, dtype=None, device=None) -> None:
    super().__init__()

    self._module_phenology = DiffPhenology(dtype=dtype, device=device)
    self._module_partitioning = DiffPartitioning(dtype=dtype, device=device)

  def forward(self, x):
    x = self._module_phenology(x)
    x = self._module_partitioning(x)
    return x

# Helper to summarize a parameter's state as a string
def param_summary(p: torch.Tensor) -> str:
  info = [
      f'value: {p.item() if sum(p.shape) <= 1 else "-"}',
      f'dtype: {p.dtype}',
      f'device: {p.device}',
      f'requires_grad: {p.requires_grad}',
      f'grad: {p.grad if sum(p.shape) <= 1 else "-"}',
      f'shape: {tuple(i for i in p.shape)}',
  ]
  return ' | '.join(info)

Using this structure has several benefits that are built into PyTorch:

  • calling model.to to change tensor dtype or device recursively visits all submodules of the model (any attribute that is itself a torch.nn.Module) and converts their parameters as well
  • calling model.parameters() lists all parameters, including those in any submodules

The following code demonstrates this behavior:


model = DiffWofost()

for p in model.parameters():
  print(param_summary(p))

if torch.cuda.is_available():
  model.to('cuda')

print()
for p in model.parameters():
  print(param_summary(p))

model.to(torch.float64)
print()
for p in model.parameters():
  print(param_summary(p))

  • Saving model weights becomes easier, since model.state_dict() collects all model parameters and can easily be transferred to other sessions/machines (see the sketch after the drop-in replacement example below)

  • Some commonly used model components (e.g. batchnorm) behave differently during training versus evaluation, so PyTorch modules have the built-in methods model.train and model.eval so that model behavior can be set accordingly. Calling model.train recursively calls it on all submodules as well, which could be very useful if some model component is replaced that should take this into account (also covered in the sketch below).

  • Drop-in replacement of model components becomes easier if all these things are handled by PyTorch, e.g. in the following use case:


class PhenologyNet(nn.Module):
  def __init__(self, dtype=None, device=None) -> None:
    super().__init__()

    self._p = nn.Parameter(torch.tensor(2.0, dtype=dtype, device=device,))

  def forward(self, x):
    return self._p * x

model = DiffWofost()

for p in model.parameters():
  print(param_summary(p))

del model._module_phenology  # De-register the old phenology module and its parameters
model._module_phenology = PhenologyNet()

print()
for p in model.parameters():
  print(param_summary(p))

The model component is updated in place and model.parameters() adapts accordingly. (dtype and device are still the programmer's responsibility.)
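
To make the state_dict and train/eval points above concrete, here is a minimal sketch (reusing the DiffWofost class defined above; the io.BytesIO buffer just stands in for a file on disk):


import io

model = DiffWofost()

# state_dict() collects the parameters of all registered submodules
print(model.state_dict().keys())  # '_module_phenology._p', '_module_partitioning._p'

# Save the weights and load them into a fresh model instance
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

buffer.seek(0)
restored = DiffWofost()
restored.load_state_dict(torch.load(buffer))

# train()/eval() propagate recursively to all submodules
model.eval()
print(model._module_phenology.training)  # False
model.train()
print(model._module_phenology.training)  # True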

What could cause problems is changing parameter values in place, for example through some auxiliary model:


class ParameterModel(nn.Module):
  def __init__(self, dtype=None, device=None) -> None:
    super().__init__()

    self._p = nn.Parameter(torch.tensor(2.0, dtype=dtype, device=device,))

  def forward(self, z):
    return self._p * z

model = DiffWofost()

x = torch.tensor(2.0)
y_true = torch.tensor(4.0)

loss = abs(model(x) - y_true)
loss.backward()

for p in model.parameters():
  print(param_summary(p))

model_p = ParameterModel()
z = torch.tensor(1.0) # Some contextual input to base parameters on

# This raises an error! (autograd forbids in-place operations on a leaf tensor that requires grad)
model._module_phenology._p.copy_(model_p(z))

print()
for p in model.parameters():
  print(param_summary(p))

So here is a model M with parameters P that uses some contextual variables z to "predict" a parameter that should be used by one of the diffwofost components during the forward pass. However, gradients should still flow w.r.t. P. The code above attempts to do this, but raises an error because autograd does not allow in-place operations on a leaf tensor that requires grad. I think we can circumvent this, but I'm not sure what the best approach is.
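
One possible direction (just a sketch, not necessarily the best approach; the class name DiffPhenologyExternal and the forward(x, p=None) signature are only illustrative): instead of copying the predicted value into the nn.Parameter in place, the component could accept an externally supplied parameter tensor in its forward pass. The output of ParameterModel then stays part of the autograd graph, so gradients flow back into P:


class DiffPhenologyExternal(SimulationObject, nn.Module):
  def __init__(self, dtype=None, device=None) -> None:
    super().__init__()
    # Default parameter, used only when no external value is supplied
    self._p = nn.Parameter(torch.tensor(1.0, dtype=dtype, device=device))

  def forward(self, x, p=None):
    # Use the externally predicted parameter if given, otherwise the internal one
    p = self._p if p is None else p
    return p * x

phenology = DiffPhenologyExternal()
model_p = ParameterModel()

x = torch.tensor(2.0)
y_true = torch.tensor(5.0)  # arbitrary target for illustration
z = torch.tensor(1.0)       # some contextual input to base parameters on

# The predicted parameter is passed into the forward pass, not copied into the nn.Parameter
loss = abs(phenology(x, p=model_p(z)) - y_true)
loss.backward()

# Gradients now reach the auxiliary model's parameters
print(param_summary(model_p._p))

This keeps the component's own nn.Parameter as a default/fallback, while the value actually used in the forward pass comes from the auxiliary model and remains in the autograd graph.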

Curious to hear your thoughts!
