Conversation

@GardevoirX
Contributor

@GardevoirX GardevoirX commented Nov 21, 2025

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Maintainer/Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?
  • GPU tests passed (maintainer comment: "cscs-ci run")?

📚 Documentation preview 📚: https://metatrain--938.org.readthedocs.build/en/938/

@pfebrer
Contributor

pfebrer commented Nov 21, 2025

Why is this needed? This might make it look more daunting for contributors of new architectures, since it seems you can no longer use raw torch; you have to understand what metatensor's torch Module is 😅

@GardevoirX
Contributor Author

Why is this needed? This might make it look more daunting for contributors of new architectures, since it seems you can no longer use raw torch; you have to understand what metatensor's torch Module is 😅

It's said that it allows us to move between devices and dtypes more easily, because torch.nn.Module doesn't support moving things like a TensorMap when you call .to()
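A minimal pure-torch sketch of the issue (hypothetical names; a plain tensor attribute stands in for a TensorMap, since .to() only converts registered parameters and buffers):

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # registered parameter: converted by .to()
        self.weight = torch.nn.Parameter(torch.zeros(3))
        # plain attribute standing in for a TensorMap: .to() ignores it
        self.extra = torch.zeros(3)

model = Model().to(torch.float64)
print(model.weight.dtype)  # torch.float64 -- converted
print(model.extra.dtype)   # torch.float32 -- left behind
```

The same silent mismatch happens when moving the module to a different device.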

@pfebrer
Contributor

pfebrer commented Nov 23, 2025

Hmm I see, could we only apply the change to the modules that need it, which as far as I understand are the CompositionModel and a few others?

@pfebrer
Contributor

pfebrer commented Nov 23, 2025

Or could this be solved by overwriting the .to method of CompositionModel ?

@ceriottm
Contributor

My understanding is that this is helpful for modules that use TensorMaps. So there's no point in converting models that are 100% torch, but those that contain some metatensor objects would benefit from being converted.

@GardevoirX
Contributor Author

My understanding is that this is helpful for modules that use TensorMaps. So there's no point in converting models that are 100% torch, but those that contain some metatensor objects would benefit from being converted.

I think having a single, no-brainer choice of Module would be better. It might be confusing to have two different Modules in the repo, and when implementing new models one would have to ask: does any part of my model involve moving a TensorMap to a new device or dtype?

@pfebrer
Contributor

pfebrer commented Nov 25, 2025

From the perspective of an outsider I think it is much better to use torch.nn.Module, and it is very unlikely that these people will find the need to have TensorMaps as buffers of their modules. It is also very unlikely that they care what the CompositionModel inherits from.

@GardevoirX
Contributor Author

Okay, if so I think I don't need to replace everything with metatensor.learn.nn.Module, given that people using TensorMap know which Module to use

@pfebrer
Contributor

pfebrer commented Nov 25, 2025

That would be my opinion, yes, but maybe others have a different opinion; let's see 🙂

@Luthaf
Member

Luthaf commented Nov 26, 2025

So mts.learn.nn.Module is 100% compatible with torch.nn.Module, but as mentioned above improves compatibility with metatensor data, both in the .to function and making sure the data is included in the state_dict.

This should remove a bunch of workarounds, xxx_buffers, and function calls at the start of forward(); instead, everything works out of the box as one would expect, at the cost of changing what you inherit from.
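For context, a hypothetical sketch (invented names, not the actual metatrain code) of the kind of buffer workaround being referred to: values that conceptually belong in a TensorMap are registered as buffers so that .to() moves them, and the richer container is rebuilt at the start of forward():

```python
import torch

class CompositionLike(torch.nn.Module):
    # Hypothetical illustration: the raw values are kept as a registered
    # buffer (so .to() converts/moves them), and a TensorMap-like
    # container is rebuilt on every forward() call.
    def __init__(self, values: torch.Tensor):
        super().__init__()
        self.register_buffer("composition_values", values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rebuild the container here (a dict stands in for a TensorMap)
        container = {"values": self.composition_values}
        return x + container["values"]

model = CompositionLike(torch.ones(3)).to(torch.float64)
out = model(torch.zeros(3, dtype=torch.float64))
```

With mts.learn.nn.Module the metatensor objects could be held directly on the module, making this rebuild step unnecessary.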

I personally think that it is best not to have two nn.Module classes, so we should always use (and enforce with linting) mts.learn.nn.Module everywhere. I don't think most people understand what torch.nn.Module even does, so I think it is fine for us to say "use this magic class instead of this other one".

Or could this be solved by overwriting the .to method of CompositionModel ?

No, because this would only override the method in Python, not in TorchScript.
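For reference, a sketch (invented names) of what such an override looks like; it behaves as hoped in eager Python, but a scripted model calls the compiled torch.nn.Module.to instead, silently bypassing the override:

```python
import torch

class Custom(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # plain attribute, invisible to the default .to()
        self.extra = torch.zeros(2)

    def to(self, *args, **kwargs):
        # eager-mode override: also move/convert the plain attribute
        self.extra = self.extra.to(*args, **kwargs)
        return super().to(*args, **kwargs)

model = Custom().to(torch.float64)
print(model.extra.dtype)  # torch.float64, but only in eager mode
```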

@pfebrer
Contributor

pfebrer commented Nov 26, 2025

This should remove a bunch of workarounds, xxx_buffers, and function calls at the start of forward(); instead, everything works out of the box as one would expect, at the cost of changing what you inherit from.

The only place where I see this PR removing code is for the CompositionModel and Scaler. Will it allow us to remove more code?

I don't think most people understand what torch.nn.Module even does, so I think it is fine for us to say "use this magic class instead of this other one".

This is not so clear to me, to be honest. The people who will contribute new architectures to metatrain surely know what torch.nn.Module does

No, because this would only override the method in Python, not in TorchScript.

Why?

@Luthaf
Member

Luthaf commented Nov 26, 2025

The only place where I see this PR removing code is for the CompositionModel and Scaler. Will it allow us to remove more code?

Any place using TensorMap/TensorBlock/Labels inside the model should become simpler. Everything else will stay the same.

I don't think most people understand what torch.nn.Module even does, so I think it is fine for us to say "use this magic class instead of this other one".

This is not so clear to me, to be honest. The people who will contribute new architectures to metatrain surely know what torch.nn.Module does

I mean most people don't know how it is implemented, or how the state_dict is generated, or why self.a = some_torch_tensor is magic inside a torch.nn.Module (which is fine!). They just use it because it is what the documentation tells them to.

No, because this would only override the method in Python, not in TorchScript.

Why?

Because that's how TorchScript works? You cannot override any of the default methods of torch.nn.Module. For mts.learn.nn.Module, implementing this required overriding it on the Python side, overriding it again for the TorchScript-in-Python execution mode, and overriding it once more in C++.

@pfebrer
Contributor

pfebrer commented Nov 27, 2025

Any place using TensorMap/TensorBlock/Labels inside the model should become simpler. Everything else will stay the same.

So, some other place apart from CompositionModel and Scaler? haha

Because that's how TorchScript works? You cannot override any of the default methods of torch.nn.Module. For mts.learn.nn.Module, implementing this required overriding it on the Python side, overriding it again for the TorchScript-in-Python execution mode, and overriding it once more in C++.

Ok, didn't know this.
