4 changes: 0 additions & 4 deletions src/metatrain/deprecated/nanopet/model.py
@@ -613,10 +613,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
# Move dataset info to CPU so that it can be saved
self.dataset_info = self.dataset_info.to(device="cpu")

# Additionally, the composition model contains some `TensorMap`s that cannot
# be registered correctly with Pytorch. This funciton moves them:
self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)

interaction_ranges = [self.hypers["num_gnn_layers"] * self.hypers["cutoff"]]
for additive_model in self.additive_models:
if hasattr(additive_model, "cutoff_radius"):
3 changes: 2 additions & 1 deletion src/metatrain/deprecated/nanopet/modules/attention.py
@@ -1,7 +1,8 @@
import torch
from metatensor.torch.learn.nn import Module


class AttentionBlock(torch.nn.Module):
class AttentionBlock(Module):
"""
A single transformer attention block. We are not using the
MultiHeadAttention module from torch.nn because we need to apply a
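The same swap appears across the remaining modules below: classes that hold metatensor objects (`Labels`, `TensorMap`) now derive from `metatensor.torch.learn.nn.Module` instead of `torch.nn.Module`. A minimal sketch of why this matters follows; it assumes (as the removed `weights_to` calls in `model.py` above suggest) that the metatensor-aware `Module` moves metatensor attributes together with the module's parameters when `.to()` is called. The class and attribute names here are illustrative, not from the PR.

```python
# Minimal sketch, not part of the diff. Assumption: metatensor.torch.learn.nn.Module
# moves metatensor objects (Labels, TensorMap) stored as attributes along with the
# rest of the module, which a plain torch.nn.Module does not do.
import torch
from metatensor.torch import Labels
from metatensor.torch.learn.nn import Module


class WithMetatensorData(Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        # Labels are metatensor data, not torch parameters or buffers
        self.species = Labels(
            ["species"], torch.tensor([[1], [6], [8]], dtype=torch.int32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = WithMetatensorData().to(device)
# With a plain torch.nn.Module, model.species would stay on CPU and need manual
# moving (the role of the removed weights_to helpers); with the metatensor-aware
# Module it is expected to follow the module.
```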
3 changes: 2 additions & 1 deletion src/metatrain/deprecated/nanopet/modules/encoder.py
@@ -1,9 +1,10 @@
from typing import Dict

import torch
from metatensor.torch.learn.nn import Module


class Encoder(torch.nn.Module):
class Encoder(Module):
"""
An encoder of edges. It generates a fixed-size representation of the
interatomic vector, the chemical element of the center and the chemical
3 changes: 2 additions & 1 deletion src/metatrain/deprecated/nanopet/modules/feedforward.py
@@ -1,7 +1,8 @@
import torch
from metatensor.torch.learn.nn import Module


class FeedForwardBlock(torch.nn.Module):
class FeedForwardBlock(Module):
"""A single transformer feed forward block."""

def __init__(
5 changes: 3 additions & 2 deletions src/metatrain/deprecated/nanopet/modules/transformer.py
@@ -1,10 +1,11 @@
import torch
from metatensor.torch.learn.nn import Module

from .attention import AttentionBlock
from .feedforward import FeedForwardBlock


class TransformerLayer(torch.nn.Module):
class TransformerLayer(Module):
"""A single transformer layer."""

def __init__(
@@ -40,7 +41,7 @@ def forward(
return output


class Transformer(torch.nn.Module):
class Transformer(Module):
"""A transformer model."""

def __init__(
2 changes: 0 additions & 2 deletions src/metatrain/deprecated/nanopet/trainer.py
@@ -148,12 +148,10 @@ def train(

# Extract additive models and scaler and move them to CPU/float64 so they
# can be used in the collate function
model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
additive_models = copy.deepcopy(
model.additive_models.to(dtype=torch.float64, device="cpu")
)
model.additive_models.to(device)
model.additive_models[0].weights_to(device=device, dtype=torch.float64)
model.scaler.scales_to(device="cpu", dtype=torch.float64)
scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
model.scaler.to(device)
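With that base class in place, the trainer no longer needs the explicit `weights_to` round-trips around the CPU/float64 copies: a single `.to()` on the additive models (and on the scaler) is assumed to move the metatensor weights as well. A hedged sketch of the resulting preparation step, written as a standalone helper (the helper name is hypothetical; the attributes mirror the code shown above):

```python
# Sketch of the simplified collate preparation. Assumes a model exposing
# `additive_models` and `scaler` as in the diff, and that `.to()` now also moves
# the metatensor weights held by these modules.
import copy

import torch


def prepare_collate_components(model, device: torch.device):
    # CPU/float64 copies for use inside the collate function
    additive_models = copy.deepcopy(
        model.additive_models.to(dtype=torch.float64, device="cpu")
    )
    model.additive_models.to(device)

    scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
    model.scaler.to(device)

    return additive_models, scaler
```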
4 changes: 0 additions & 4 deletions src/metatrain/experimental/flashmd/model.py
@@ -1175,10 +1175,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
# float64
self.to(dtype)

# Additionally, the composition model contains some `TensorMap`s that cannot
# be registered correctly with Pytorch. This function moves them:
self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)

interaction_ranges = [self.num_gnn_layers * self.cutoff]
for additive_model in self.additive_models:
if hasattr(additive_model, "cutoff_radius"):
3 changes: 2 additions & 1 deletion src/metatrain/experimental/flashmd/modules/additive.py
@@ -3,6 +3,7 @@
import metatensor.torch as mts
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.learn.nn import Module
from metatomic.torch import ModelOutput, NeighborListOptions, System
from pydantic import TypeAdapter
from typing_extensions import TypedDict
@@ -14,7 +15,7 @@ class PositionAdditiveHypers(TypedDict):
also_momenta: bool


class PositionAdditive(torch.nn.Module):
class PositionAdditive(Module):
"""
A simple additive model that adds the positions of the system to any outputs that
is either "positions" or one of its variants.
3 changes: 2 additions & 1 deletion src/metatrain/experimental/flashmd/modules/encoder.py
@@ -1,7 +1,8 @@
import torch
from metatensor.torch.learn.nn import Module


class NodeEncoder(torch.nn.Module):
class NodeEncoder(Module):
"""
An encoder of edges. It generates a fixed-size representation of the
interatomic vector, the chemical element of the center and the chemical
@@ -131,7 +131,6 @@ def test_torchscript_save_load(tmpdir):
)
model = FlashMD(MODEL_HYPERS, dataset_info)
model.to(torch.float64)
model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
model.scaler.scales_to(device="cpu", dtype=torch.float64)

with tmpdir.as_cwd():
2 changes: 0 additions & 2 deletions src/metatrain/experimental/flashmd/trainer.py
@@ -234,12 +234,10 @@ def train(

# Extract additive models and scaler and move them to CPU/float64 so they
# can be used in the collate function
model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
additive_models = copy.deepcopy(
model.additive_models.to(dtype=torch.float64, device="cpu")
)
model.additive_models.to(device)
model.additive_models[0].weights_to(device=device, dtype=torch.float64)
model.scaler.scales_to(device="cpu", dtype=torch.float64)
scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
model.scaler.to(device)
14 changes: 4 additions & 10 deletions src/metatrain/gap/model.py
@@ -6,6 +6,7 @@
import scipy
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.learn.nn import Module
from metatomic.torch import (
AtomisticModel,
ModelCapabilities,
@@ -284,10 +285,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
interaction_ranges.append(additive_model.cutoff_radius)
interaction_range = max(interaction_ranges)

# Additionally, the composition model contains some `TensorMap`s that cannot
# be registered correctly with Pytorch. This funciton moves them:
self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)

capabilities = ModelCapabilities(
outputs=self.outputs,
atomic_types=sorted(self.dataset_info.atomic_types),
@@ -399,7 +396,7 @@ def predict(
return KTM @ self._weights


class AggregateKernel(torch.nn.Module):
class AggregateKernel(Module):
"""
A kernel that aggregates values in a kernel over :param aggregate_names: using
the sum as aggregate function
@@ -457,7 +454,7 @@ def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap) -> TensorMap:
return mts.pow(mts.dot(tensor1, tensor2), self._degree)


class TorchAggregateKernel(torch.nn.Module):
class TorchAggregateKernel(Module):
"""
A kernel that aggregates values in a kernel over :param aggregate_names: using
the sum as aggregate function
@@ -796,7 +793,7 @@ def export_torch_script_model(self) -> "TorchSubsetofRegressors":
)


class TorchSubsetofRegressors(torch.nn.Module):
class TorchSubsetofRegressors(Module):
def __init__(
self,
weights: TensorMap,
@@ -821,9 +818,6 @@ def forward(self, T: TensorMap) -> TensorMap:
:return:
TensorMap with the predictions
"""
# move weights and X_pseudo to the same device as T
self._weights = self._weights.to(T.device)
self._X_pseudo = self._X_pseudo.to(T.device)

k_tm = self._kernel(T, self._X_pseudo, are_pseudo_points=(False, True))
return mts.dot(k_tm, self._weights)
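The same reasoning removes the per-call device transfers from `TorchSubsetofRegressors.forward`: once the class derives from the metatensor-aware `Module`, its weight and pseudo-point `TensorMap`s are assumed to already live wherever the exported model was moved. A sketch of the resulting forward path (a paraphrase of the class above, with the kernel left abstract):

```python
# Sketch only. Assumes the metatensor-aware Module moves TensorMap attributes
# when the model itself is moved, so no per-call .to(T.device) is needed.
import metatensor.torch as mts
from metatensor.torch import TensorMap
from metatensor.torch.learn.nn import Module


class SubsetOfRegressorsSketch(Module):
    def __init__(self, weights: TensorMap, X_pseudo: TensorMap, kernel) -> None:
        super().__init__()
        self._weights = weights
        self._X_pseudo = X_pseudo
        self._kernel = kernel

    def forward(self, T: TensorMap) -> TensorMap:
        # no self._weights.to(T.device) / self._X_pseudo.to(T.device) here:
        # both are assumed to follow the module when it is moved
        k_tm = self._kernel(T, self._X_pseudo, are_pseudo_points=(False, True))
        return mts.dot(k_tm, self._weights)
```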
10 changes: 0 additions & 10 deletions src/metatrain/llpr/model.py
@@ -658,16 +658,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
# float64
self.to(dtype)

# Additionally, the composition model contains some `TensorMap`s that cannot
# be registered correctly with Pytorch. This function moves them:
try:
self.model.additive_models[0]._move_weights_to_device_and_dtype(
torch.device("cpu"), torch.float64
)
except Exception:
# no weights to move
pass

metadata = merge_metadata(
merge_metadata(self.__default_metadata__, metadata),
self.model.export().metadata(),
4 changes: 0 additions & 4 deletions src/metatrain/pet/model.py
@@ -1159,10 +1159,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
# float64
self.to(dtype)

# Additionally, the composition model contains some `TensorMap`s that cannot
# be registered correctly with Pytorch. This function moves them:
self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)

interaction_ranges = [self.num_gnn_layers * self.cutoff]
for additive_model in self.additive_models:
if hasattr(additive_model, "cutoff_radius"):
7 changes: 4 additions & 3 deletions src/metatrain/pet/modules/transformer.py
@@ -2,6 +2,7 @@

import torch
import torch.nn.functional as F
from metatensor.torch.learn.nn import Module
from torch import nn

from .utilities import DummyModule
@@ -105,7 +106,7 @@ def forward(
return x


class TransformerLayer(torch.nn.Module):
class TransformerLayer(Module):
"""
Single layer of a Transformer.

@@ -247,7 +248,7 @@ def forward(
return node_embeddings, edge_embeddings


class Transformer(torch.nn.Module):
class Transformer(Module):
"""
Transformer implementation.

@@ -338,7 +339,7 @@ def forward(
return node_embeddings, edge_embeddings


class CartesianTransformer(torch.nn.Module):
class CartesianTransformer(Module):
"""
Cartesian Transformer implementation for handling 3D coordinates.

3 changes: 2 additions & 1 deletion src/metatrain/pet/modules/utilities.py
@@ -1,4 +1,5 @@
import torch
from metatensor.torch.learn.nn import Module


def cutoff_func(grid: torch.Tensor, r_cut: float, delta: float) -> torch.Tensor:
@@ -20,7 +21,7 @@ def cutoff_func(grid: torch.Tensor, r_cut: float, delta: float) -> torch.Tensor:
return f


class DummyModule(torch.nn.Module):
class DummyModule(Module):
"""Dummy torch module to make torchscript happy.
This model should never be run"""

2 changes: 0 additions & 2 deletions src/metatrain/pet/trainer.py
@@ -208,12 +208,10 @@ def train(

# Extract additive models and scaler and move them to CPU/float64 so they
# can be used in the collate function
model.additive_models[0].weights_to(device="cpu", dtype=torch.float64)
additive_models = copy.deepcopy(
model.additive_models.to(dtype=torch.float64, device="cpu")
)
model.additive_models.to(device)
model.additive_models[0].weights_to(device=device, dtype=torch.float64)
model.scaler.scales_to(device="cpu", dtype=torch.float64)
scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu"))
model.scaler.to(device)
10 changes: 3 additions & 7 deletions src/metatrain/soap_bpnn/model.py
@@ -5,7 +5,7 @@
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.learn.nn import Linear as LinearMap
from metatensor.torch.learn.nn import ModuleMap
from metatensor.torch.learn.nn import Module, ModuleMap
from metatensor.torch.operations._add import _add_block_block
from metatomic.torch import (
AtomisticModel,
@@ -32,7 +32,7 @@
from .spherical import TensorBasis


class Identity(torch.nn.Module):
class Identity(Module):
def __init__(self) -> None:
super().__init__()

@@ -48,7 +48,7 @@ def __init__(self, atomic_types: List[int], hypers: dict) -> None:
# Build a neural network for each species
nns_per_species = []
for _ in atomic_types:
module_list: List[torch.nn.Module] = []
module_list: List[Module] = []
for _ in range(hypers["num_hidden_layers"]):
if len(module_list) == 0:
module_list.append(
@@ -720,10 +720,6 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
# float64
self.to(dtype)

# Additionally, the composition model contains some `TensorMap`s that cannot
# be registered correctly with Pytorch. This funciton moves them:
self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)

interaction_ranges = [self.hypers["soap"]["cutoff"]["radius"]]
for additive_model in self.additive_models:
if hasattr(additive_model, "cutoff_radius"):
18 changes: 8 additions & 10 deletions src/metatrain/soap_bpnn/spherical.py
@@ -10,12 +10,13 @@
import wigners
from metatensor.torch import Labels, TensorMap
from metatensor.torch.learn.nn import Linear as LinearMap
from metatensor.torch.learn.nn import Module
from spex.metatensor import SphericalExpansion

from .documentation import SOAPConfig


class VectorBasis(torch.nn.Module):
class VectorBasis(Module):
"""
This module creates a basis of 3 vectors for each atomic environment.

@@ -103,9 +104,6 @@ def forward(
"""
device = interatomic_vectors.device

if self.neighbor_species_labels.device != device:
self.neighbor_species_labels = self.neighbor_species_labels.to(device)

spherical_expansion = self.soap_calculator(
interatomic_vectors,
centers,
@@ -141,7 +139,7 @@
return basis_vectors_as_tensor # [n_atoms, 3(yzx), 3]


class TensorBasis(torch.nn.Module):
class TensorBasis(Module):
"""
Creates a basis of spherical tensors for each atomic environment. Internally, it
uses one (for proper tensors) or two (for pseudotensors) VectorBasis objects to
@@ -286,8 +284,8 @@ def forward(
device = interatomic_vectors.device
dtype = interatomic_vectors.dtype
for k, v in self.cgs.items():
if v.device != device or v.dtype != dtype:
self.cgs[k] = v.to(device, dtype)
if v.dtype != dtype:
self.cgs[k] = v.to(dtype=dtype)

if selected_atoms is None:
num_atoms = len(atom_index_in_structure)
@@ -589,7 +587,7 @@ def _complex_clebsch_gordan_matrix(l1: int, l2: int, L: int) -> np.ndarray:
return wigners.clebsch_gordan_array(l1, l2, L)


class FakeVectorBasis(torch.nn.Module):
class FakeVectorBasis(Module):
# fake class to make torchscript work

def forward(
@@ -605,7 +603,7 @@ def forward(
return torch.tensor(0)


class FakeSphericalExpansion(torch.nn.Module):
class FakeSphericalExpansion(Module):
# Dummy class to make torchscript work
def forward(
self,
@@ -621,7 +619,7 @@ def forward(
)


class FakeLinearMap(torch.nn.Module):
class FakeLinearMap(Module):
# fake class to make torchscript work

def forward(
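Note that `TensorBasis.forward` keeps a dtype check for the Clebsch-Gordan coefficients while dropping the device check: device placement is presumably handled when the module itself is moved, but a float32/float64 mismatch with the incoming vectors still has to be fixed on the fly. A small sketch of that remaining conversion (hypothetical helper name):

```python
# Sketch of the dtype-only conversion that remains in forward. Assumes device
# placement of the coefficients is already handled by moving the module, so only
# a dtype mismatch with the input tensors still needs fixing here.
from typing import Dict

import torch


def match_cg_dtype(cgs: Dict[str, torch.Tensor], dtype: torch.dtype) -> None:
    for key, value in cgs.items():
        if value.dtype != dtype:
            cgs[key] = value.to(dtype=dtype)
```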