Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ share/python-wheels/
# jupyter notebook
.ipynb_checkpoints
docs/notebooks/test*
.diffwofost-ml-models/

# Unit test / coverage reports
htmlcov/
Expand Down
4 changes: 4 additions & 0 deletions docs/api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ hide:

::: diffwofost.physical_models.utils.EngineTestHelper

::: diffwofost.ml_models.io.save_model

::: diffwofost.ml_models.io.load_model

## **Other classes (for developers)**

::: diffwofost.physical_models.base.states_rates.TensorStatesTemplate
Expand Down
168 changes: 145 additions & 23 deletions docs/notebooks/hybrid_partitioning_wofost72.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
]
dependencies = [
"safetensors",
"torch",
"pcse",
]
Expand Down
4 changes: 4 additions & 0 deletions src/diffwofost/ml_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from diffwofost.ml_models.io import load_model
from diffwofost.ml_models.io import save_model

__all__ = ["save_model", "load_model"]
4 changes: 4 additions & 0 deletions src/diffwofost/ml_models/crop/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(self, hidden_size=8):
hidden_size (int): Width of the hidden layer. Defaults to 8.
"""
super().__init__()
self.hidden_size = hidden_size
self.init_kwargs = {"hidden_size": hidden_size}

self.network = torch.nn.Sequential(
torch.nn.Linear(1, hidden_size),
Expand Down Expand Up @@ -97,6 +99,8 @@ def __init__(self, hidden_size=32):
hidden_size (int): Width of the shared hidden layers. Defaults to 32.
"""
super().__init__()
self.hidden_size = hidden_size
self.init_kwargs = {"hidden_size": hidden_size}

head_hidden_size = max(hidden_size // 2, 8)
self.trunk = torch.nn.Sequential(
Expand Down
155 changes: 155 additions & 0 deletions src/diffwofost/ml_models/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import hashlib
import importlib
import json
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import load_file
from safetensors.torch import save_file
from diffwofost.physical_models.config import ComputeConfig


def _default_model_filename(model):
"""Build a stable default filename from the model structure.

The filename depends on the model class name and serialized constructor
kwargs, so repeated saves of the same model structure reuse the same path
unless a custom name is provided explicitly.

Args:
model (torch.nn.Module): Model instance to name.

Returns:
str: Default safetensors filename for this model structure.
"""
init_kwargs = json.dumps(dict(getattr(model, "init_kwargs", {})), sort_keys=True)
structure_digest = hashlib.sha256(init_kwargs.encode("utf-8")).hexdigest()[:12]
return f"{model.__class__.__name__.lower()}-{structure_digest}.safetensors"


def _load_model_metadata(path):
"""Load metadata stored in a safetensors file.

Args:
path (str | Path): Path to the safetensors file.

Returns:
dict: Metadata dictionary stored alongside the tensors.

Raises:
ValueError: If the file does not contain metadata.
"""
with safe_open(str(path), framework="pt", device="cpu") as handle:
metadata = handle.metadata()
if metadata is None:
raise ValueError(f"No metadata found in safetensors file: {path}")
return metadata


def _build_safetensors_metadata(model):
"""Build the metadata needed to reconstruct a saved model.

Args:
model (torch.nn.Module): Model instance to describe.

Returns:
dict: Metadata with module, class, and constructor kwargs.
"""
return {
"diffwofost.model_module": model.__class__.__module__,
"diffwofost.model_class": model.__class__.__qualname__,
"diffwofost.init_kwargs": json.dumps(
dict(getattr(model, "init_kwargs", {})),
sort_keys=True,
),
}


def save_model(model, path=None, filename=None, directory=None):
"""Persist a torch model with safetensors and constructor metadata.

If no explicit path is provided, the model is saved under a stable default
location in a hidden repository-local directory. The default filename depends on
the model class name and stored constructor kwargs so repeated saves of the
same model structure reuse the same file.

Args:
model (torch.nn.Module): Model instance to persist.
path (str | Path | None): Full target path. When provided, `filename`
and `directory` must be omitted.
filename (str | None): Optional custom filename used with `directory`.
directory (str | Path | None): Optional custom directory used with
`filename` or the default filename.

Returns:
Path: Path of the saved safetensors file.

Raises:
ValueError: If `path` is combined with `filename` or `directory`.
"""
if path is not None and (filename is not None or directory is not None):
raise ValueError("Pass either path or filename/directory, not both.")

if path is None:
target_directory = (
Path.cwd().resolve() / ".diffwofost-ml-models" if directory is None else Path(directory)
)
target_filename = _default_model_filename(model) if filename is None else filename
path = Path(target_directory) / target_filename

path = Path(path).expanduser().resolve()
path.parent.mkdir(parents=True, exist_ok=True)
tensors = {
name: tensor.detach().cpu().contiguous() for name, tensor in model.state_dict().items()
}
save_file(tensors, str(path), metadata=_build_safetensors_metadata(model))
return path


def load_model(path, model_class=None, device=None, dtype=None):
"""Load a diffWOFOST model from a safetensors file.

The model class is discovered from metadata by default. A caller can also
provide `model_class` explicitly to validate that the stored class matches
the expected one.

Args:
path (str | Path): Path to the saved safetensors file.
model_class (type[torch.nn.Module] | None): Expected model class. When
omitted, the class is resolved from the stored metadata.
device (str | torch.device | None): Target device for the restored
model. Defaults to the active `ComputeConfig` device.
dtype (torch.dtype | None): Target dtype for the restored model.
Defaults to the active `ComputeConfig` dtype.

Returns:
torch.nn.Module: Restored model instance with loaded parameters.

Raises:
ValueError: If the stored class does not match the provided
`model_class`.
"""
path = Path(path).expanduser().resolve()
metadata = _load_model_metadata(path)
stored_module_name = metadata.get("diffwofost.model_module")
stored_class_name = metadata.get("diffwofost.model_class")

if model_class is None:
module = importlib.import_module(stored_module_name)
model_class = getattr(module, stored_class_name)
elif (
stored_module_name != model_class.__module__
or stored_class_name != model_class.__qualname__
):
raise ValueError(
f"Safetensors file {path} stores {stored_module_name}.{stored_class_name}, "
f"not {model_class.__module__}.{model_class.__qualname__}."
)

init_kwargs = json.loads(metadata["diffwofost.init_kwargs"])
model = model_class(**init_kwargs)
state_dict = load_file(str(path), device="cpu")
model.load_state_dict(state_dict)
target_device = ComputeConfig.get_device() if device is None else device
target_dtype = ComputeConfig.get_dtype() if dtype is None else dtype
model.to(device=target_device, dtype=target_dtype)
return model
69 changes: 69 additions & 0 deletions tests/ml_models/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from pathlib import Path
import torch
from diffwofost.ml_models import load_model
from diffwofost.ml_models import save_model
from diffwofost.ml_models.crop.partitioning import PartitioningMLP
from diffwofost.ml_models.crop.partitioning import PartitioningNN
from diffwofost.physical_models.config import ComputeConfig


def _fill_parameters(model, scale):
for index, parameter in enumerate(model.parameters(), start=1):
parameter.data.fill_(index / scale)


def _assert_same_partition_outputs(left, right, dvs):
left_pf = left(dvs)
right_pf = right(dvs)

assert torch.allclose(left_pf.FR, right_pf.FR)
assert torch.allclose(left_pf.FL, right_pf.FL)
assert torch.allclose(left_pf.FS, right_pf.FS)
assert torch.allclose(left_pf.FO, right_pf.FO)


def test_partitioning_nn_round_trips_through_safetensors(tmp_path):
model = PartitioningNN(hidden_size=16)
_fill_parameters(model, scale=10.0)

path = save_model(model, tmp_path / "partitioning_nn.safetensors")
restored = load_model(path, model_class=PartitioningNN)
dvs = torch.tensor([0.0, 0.5, 1.0], dtype=ComputeConfig.get_dtype())

assert restored.hidden_size == 16
_assert_same_partition_outputs(model, restored, dvs)


def test_partitioning_mlp_uses_stable_default_save_path():
model = PartitioningMLP(hidden_size=12)
_fill_parameters(model, scale=20.0)

first_path = save_model(model)
second_path = save_model(model)
restored = load_model(first_path)
dvs = torch.tensor([0.2, 1.4], dtype=ComputeConfig.get_dtype())

assert isinstance(restored, PartitioningMLP)
assert restored.hidden_size == 12
assert first_path == second_path
assert first_path.parent == Path(__file__).resolve().parents[2] / ".diffwofost-ml-models"
assert first_path.suffix == ".safetensors"
_assert_same_partition_outputs(model, restored, dvs)


def test_partitioning_mlp_can_override_default_save_name(tmp_path):
model = PartitioningMLP(hidden_size=12)
_fill_parameters(model, scale=20.0)

path = save_model(
model,
directory=tmp_path,
filename="partitioning_mlp_custom.safetensors",
)
restored = load_model(path)
dvs = torch.tensor([0.2, 1.4], dtype=ComputeConfig.get_dtype())

assert isinstance(restored, PartitioningMLP)
assert restored.hidden_size == 12
assert path == tmp_path / "partitioning_mlp_custom.safetensors"
_assert_same_partition_outputs(model, restored, dvs)
Loading