Skip to content

Commit 1f85977

Browse files
committed
Minor fixes
1 parent 7ab4c5f commit 1f85977

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

onnxscript/rewriter/ort_fusions/fuse_xformers_test.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44

55
import unittest
66

7-
from parameterized import parameterized
8-
97
import onnxscript.optimizer
8+
from onnxscript.rewriter.ort_fusions._core import fuse_xformers
109
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
1110
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
12-
from onnxscript.rewriter.ort_fusions._core import fuse_xformers
1311

14-
class TestTransformerFusion(unittest.TestCase):
1512

16-
def test_transformer_fusion(self):
13+
class TestFuseXformers(unittest.TestCase):
14+
def test_fuse_xformers(self):
1715
test = smollm_test_1()
1816
model = test.get_onnx_model()
1917
onnxscript.optimizer.optimize(model)

onnxscript/rewriter/pattern.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1566,7 +1566,7 @@ def copy_value(value: ir.Value | None) -> ir.Value | None:
15661566
return None
15671567
if value not in value_map:
15681568
const_value = value.const_value
1569-
if isinstance(const_value, ir.Tensor):
1569+
if isinstance(const_value, (ir.Tensor, ir.TensorProtoTensor)):
15701570
# create a Constant node to represent the value
15711571
value_attr = ir.AttrTensor("value", const_value)
15721572
const_node = ir.Node("", "Constant", [], [value_attr])

0 commit comments

Comments
 (0)