@pstjohn pstjohn commented Dec 24, 2025

TransformerEngine requires that we pass device="meta" (and importantly not device=torch.device("meta")) to layer constructors to initialize parameters on the meta device.

This makes sure we pass the right device to the layer constructor and adds tests to ensure the parameters are actually being placed on the right devices.
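A device-placement check of the kind described above could look roughly like the following. This is a minimal sketch using plain `torch.nn` layers as stand-ins; the helper name `devices_of` is hypothetical and is not taken from this PR's test suite.

```python
import torch
from torch import nn


def devices_of(module: nn.Module) -> set:
    """Return the device types that hold the module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device.type for t in tensors}


# Pass the device as the string "meta" to each constructor, as the description requires.
# Plain torch layers stand in for TransformerEngine layers here.
model = nn.Sequential(
    nn.Embedding(16, 8, device="meta"),
    nn.Linear(8, 8, device="meta"),
    nn.LayerNorm(8, device="meta"),
)

meta_devices = devices_of(model)
```

Asserting that `meta_devices` contains only `"meta"` before materialization, and only the target device afterwards, catches parameters that were silently allocated in the wrong place.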

For TransformerEngine layers, we want to move parameters from the meta device to CUDA with

for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

while for HF layers (e.g., nn.Embedding), we want to do

model.to_empty(device="cuda")
model.apply(model._init_weights)

to ensure that we pick up the config.initializer_range initialization correctly.
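The HF-style path can be sketched end to end with a toy module. Everything here is an assumption for illustration: `TinyModel` is hypothetical, its `_init_weights` only imitates the usual HF pattern of reading `initializer_range`, and `"cpu"` is used as the target so the sketch runs without a GPU (swap in `"cuda"` on a GPU machine).

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for an HF model built on the meta device."""

    def __init__(self, initializer_range: float = 0.02):
        super().__init__()
        self.initializer_range = initializer_range
        self.embed = nn.Embedding(1000, 64, device="meta")
        self.norm = nn.LayerNorm(64, device="meta")

    def _init_weights(self, module: nn.Module) -> None:
        # Imitates the HF convention: normal init scaled by initializer_range.
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


torch.manual_seed(0)
tiny = TinyModel()
tiny.to_empty(device="cpu")      # allocate uninitialized storage on the target device
tiny.apply(tiny._init_weights)   # then fill it using the configured range

embed_std = tiny.embed.weight.std().item()
```

Note that `to_empty` takes `device` as a keyword-only argument, and that the values are undefined until `_init_weights` has run.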

The issue is that we can't call to_empty(device="cuda") or _init_weights on TE layers, nor can we call reset_parameters() on the HF layers without the preceding to_empty. Even with it, reset_parameters() doesn't use the HF config when creating initial values.

Closes BIO-1

@pstjohn pstjohn force-pushed the pstjohn/meta-device-fix branch 5 times, most recently from 8ca75dc to 919f984 on January 5, 2026 16:29
pstjohn added 16 commits January 5, 2026 21:49
Signed-off-by: Peter St. John <[email protected]>
…asses to use a consistent initialization function. Update tests to ensure weight initialization matches expected values with adjusted tolerances for standard deviation.

@pstjohn pstjohn force-pushed the pstjohn/meta-device-fix branch from adc9525 to 48f573d on January 6, 2026 19:16
@pstjohn pstjohn added this pull request to the merge queue Jan 6, 2026
Merged via the queue into NVIDIA:main with commit aeed0de Jan 6, 2026
19 checks passed
@pstjohn pstjohn deleted the pstjohn/meta-device-fix branch January 6, 2026 20:35
3 participants