From de9f9e074ab01c1b0f11c339aa0193826bef956f Mon Sep 17 00:00:00 2001
From: abeeha123
Date: Thu, 30 Oct 2025 15:12:36 +0500
Subject: [PATCH 1/4] Implement a runtime fix for the Hugging Face export that
 addresses the KeyError (position_ids, token_type_ids) in the Relax frontend.

---
 position_id_fix.py | 78 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)
 create mode 100644 position_id_fix.py

diff --git a/position_id_fix.py b/position_id_fix.py
new file mode 100644
index 000000000000..b0744568da7f
--- /dev/null
+++ b/position_id_fix.py
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape {tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)
+
+    # EP wrapper to override state_dict access
+    class EPWrapper:
+        def __init__(self, ep, sd_wrapped):
+            self.__dict__["_ep"] = ep
+            self.__dict__["_sd"] = sd_wrapped
+
+        def __getattr__(self, name):
+            if name == "state_dict":
+                return self._sd
+            return getattr(self._ep, name)
+
+    ep_wrapped = EPWrapper(ep, sd_wrapped)
+
+    # Import to TVM
+    try:
+        mod = from_exported_program(ep_wrapped)
+        print("\n TVM import succeeded — all non-persistent buffers injected!\n")
+    except Exception as e:
+        print("\n TVM import failed with exception:")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    main()
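
A note on the failure this patch works around: Hugging Face's BertEmbeddings
registers position_ids and token_type_ids via register_buffer(...,
persistent=False), so they are omitted from the module's state_dict and, in
turn, from ExportedProgram.state_dict, while the exported graph still refers
to them; the Relax importer then hits a KeyError when it looks those names up.
A minimal sketch of the observation (assumes transformers is installed; the
exact buffer names can vary across transformers versions):

    from transformers import AutoModel

    bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
    emb = bert.embeddings
    # Non-persistent buffers are tracked here but skipped by state_dict().
    print(emb._non_persistent_buffers_set)     # typically {'position_ids', 'token_type_ids'}
    print("position_ids" in emb.state_dict())  # False: the key the importer later misses
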
From 0f3d0436c589e72afd8f4b60c566d87a40c5106d Mon Sep 17 00:00:00 2001
From: abeeha123
Date: Mon, 3 Nov 2025 19:08:48 +0500
Subject: [PATCH 2/4] Refactor: applied Gemini suggestions and improved
 readability

- Used collections.ChainMap for buffer injection
- Added property-based EPWrapper
- Removed hardcoded vocab size
- Moved imports to comply with PEP 8
---
 position_id_patch.py | 78 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)
 create mode 100644 position_id_patch.py

diff --git a/position_id_patch.py b/position_id_patch.py
new file mode 100644
index 000000000000..b0744568da7f
--- /dev/null
+++ b/position_id_patch.py
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape {tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)
+
+    # EP wrapper to override state_dict access
+    class EPWrapper:
+        def __init__(self, ep, sd_wrapped):
+            self.__dict__["_ep"] = ep
+            self.__dict__["_sd"] = sd_wrapped
+
+        def __getattr__(self, name):
+            if name == "state_dict":
+                return self._sd
+            return getattr(self._ep, name)
+
+    ep_wrapped = EPWrapper(ep, sd_wrapped)
+
+    # Import to TVM
+    try:
+        mod = from_exported_program(ep_wrapped)
+        print("\n TVM import succeeded — all non-persistent buffers injected!\n")
+    except Exception as e:
+        print("\n TVM import failed with exception:")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    main()
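
The ChainMap and property-based wrapper mentioned in the message above can be
sketched as follows (a minimal illustration, not necessarily the shape of the
actual refactor; it assumes the same ep and extra objects built in main()).
collections.ChainMap searches its maps left to right, so the injected buffers
take precedence, and, unlike StateDictWrapper, membership tests and key
iteration (in, .keys(), .items()) see the injected entries too:

    from collections import ChainMap

    # Injected buffers first, exported weights second.
    sd_wrapped = ChainMap(extra, ep.state_dict)

    class EPWrapper:
        """Delegate everything to the ExportedProgram except state_dict."""

        def __init__(self, ep, state_dict):
            self._ep = ep
            self._state_dict = state_dict

        @property
        def state_dict(self):
            return self._state_dict

        def __getattr__(self, name):
            # Only invoked when normal attribute lookup fails, so the
            # property above takes precedence for state_dict.
            return getattr(self._ep, name)

    ep_wrapped = EPWrapper(ep, sd_wrapped)
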
From e114e9caff48f6272812badb133eb765614eb811 Mon Sep 17 00:00:00 2001
From: abeeha123
Date: Tue, 4 Nov 2025 11:36:12 +0500
Subject: [PATCH 3/4] Format position_id_patch.py to pass lint

---
 position_id_patch.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/position_id_patch.py b/position_id_patch.py
index b0744568da7f..980f80a6681f 100644
--- a/position_id_patch.py
+++ b/position_id_patch.py
@@ -5,8 +5,10 @@
 from torch.export import export as torch_export
 from tvm.relax.frontend.torch import from_exported_program
 
+
 class StateDictWrapper(dict):
     """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+
     def __init__(self, base_dict, extra):
         super().__init__(base_dict)
         self.extra = extra
@@ -21,6 +23,7 @@ def get(self, key, default=None):
             return self.extra[key]
         return super().get(key, default)
 
+
 class M(nn.Module):
     def __init__(self):
         super().__init__()
@@ -31,6 +34,7 @@ def forward(self, x, mask=None):
         out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
         return self.cls(out)
 
+
 def main():
     torch.manual_seed(0)
     m = M().eval()
@@ -72,7 +76,9 @@ def __getattr__(self, name):
     except Exception as e:
         print("\n TVM import failed with exception:")
         import traceback
+
         traceback.print_exc()
 
+
 if __name__ == "__main__":
     main()

From 943c0a70c9a46a2a3cd792b4a54a5230629f17e1 Mon Sep 17 00:00:00 2001
From: abeeha123
Date: Tue, 4 Nov 2025 15:24:46 +0500
Subject: [PATCH 4/4] Format position_id files to pass CI Black 22.12.0

---
 .../relax/frontend/torch/position_id_fix.py   | 27 +++++++++++++++++--
 .../relax/frontend/torch/position_id_patch.py | 21 +++++++++++++--
 2 files changed, 44 insertions(+), 4 deletions(-)
 rename position_id_fix.py => python/tvm/relax/frontend/torch/position_id_fix.py (75%)
 rename position_id_patch.py => python/tvm/relax/frontend/torch/position_id_patch.py (75%)

diff --git a/position_id_fix.py b/python/tvm/relax/frontend/torch/position_id_fix.py
similarity index 75%
rename from position_id_fix.py
rename to python/tvm/relax/frontend/torch/position_id_fix.py
index b0744568da7f..d9ee93ab72cc 100644
--- a/position_id_fix.py
+++ b/python/tvm/relax/frontend/torch/position_id_fix.py
@@ -1,12 +1,31 @@
-# sol-script-fixed.py
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 import torch
 import torch.nn as nn
-from transformers import AutoModel
 from torch.export import export as torch_export
+from transformers import AutoModel
+
 from tvm.relax.frontend.torch import from_exported_program
 
+
 class StateDictWrapper(dict):
     """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+
     def __init__(self, base_dict, extra):
         super().__init__(base_dict)
         self.extra = extra
@@ -21,6 +40,7 @@ def get(self, key, default=None):
             return self.extra[key]
         return super().get(key, default)
 
+
 class M(nn.Module):
     def __init__(self):
         super().__init__()
@@ -31,6 +51,7 @@ def forward(self, x, mask=None):
         out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
         return self.cls(out)
 
+
 def main():
     torch.manual_seed(0)
     m = M().eval()
@@ -72,7 +93,9 @@ def __getattr__(self, name):
     except Exception as e:
         print("\n TVM import failed with exception:")
         import traceback
+
         traceback.print_exc()
 
+
 if __name__ == "__main__":
     main()
diff --git a/position_id_patch.py b/python/tvm/relax/frontend/torch/position_id_patch.py
similarity index 75%
rename from position_id_patch.py
rename to python/tvm/relax/frontend/torch/position_id_patch.py
index 980f80a6681f..d9ee93ab72cc 100644
--- a/position_id_patch.py
+++ b/python/tvm/relax/frontend/torch/position_id_patch.py
@@ -1,8 +1,25 @@
-# sol-script-fixed.py
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 import torch
 import torch.nn as nn
-from transformers import AutoModel
 from torch.export import export as torch_export
+from transformers import AutoModel
+
 from tvm.relax.frontend.torch import from_exported_program
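
For reference, an alternative that avoids wrapping the ExportedProgram at all:
re-register the offending buffers as persistent before calling torch.export,
so they land in ExportedProgram.state_dict on their own. This is a hedged
sketch, not part of the patches above; it relies on nn.Module.register_buffer
accepting an already-registered name and updating its persistence flag, and
whether it fully satisfies the Relax importer is untested here. m, x, and mask
are the objects built in main():

    emb = m.bert.embeddings
    for name in list(emb._non_persistent_buffers_set):
        # Re-registering with persistent=True drops the name from
        # _non_persistent_buffers_set, so state_dict() will include it.
        emb.register_buffer(name, emb._buffers[name], persistent=True)

    ep = torch_export(m, (x, mask))  # position_ids etc. now appear in ep.state_dict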