Skip to content

Conversation

@omegacoleman
Copy link

In recent versions of PyTorch, tensors initiated via self.tensor in the init function are assumed to be constants, which do not change during the runtime of the inference. The role for cache_k and cache_v should be non-persistent buffers, and be registered with self.register_buffer.

The above creates problems for, e.g., the ONNX exporter. I encountered this error to exporting layers[0] of llama3:

[rank0]: <class 'RuntimeError'>: Constant attention.cache_k is mutated in the forward method. Pls register it as buffer

This PR fixes the problem. It's also worth noticing that the llama3-multimodal already uses this approach.


The script I'm using

from models.llama3.generation import Llama3
import torch
import torch.onnx
import torch.nn.functional as F
import os
from models.datatypes import RawMessage

os.environ["MASTER_ADDR"] = 'localhost'
os.environ["MASTER_PORT"] = '12366'
os.environ["WORLD_SIZE"] = '1'
os.environ["RANK"] = '0'
torch.set_grad_enabled(False)

generator = Llama3.build(
    ckpt_dir=f"{os.environ['HOME']}/.llama/checkpoints/LLama-3.2-1B-Instruct",
    max_seq_len=512,
    max_batch_size=4
)

model = generator.model

example_prompt = RawMessage(role="user", content="Please tell me a joke.")
example_input = generator.formatter.encode_dialog_prompt([example_prompt])

tokens = torch.full((4, 512), generator.tokenizer.pad_id, dtype=torch.long)
tokens[0, : len(example_input.tokens)] = torch.tensor(example_input.tokens, dtype=torch.long)                                                                 
embeddings = model.tok_embeddings(tokens)

freqs_cis = model.freqs_cis[0:512]

torch.onnx.export(
    model.layers[0],
    (embeddings, 0, freqs_cis, None),
    "layer0.onnx",
    export_params=False,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamo=True
)

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants