Skip to content

Commit 600f0d8

Browse files
committed
added .detach() to connst value, added new unit test
1 parent 51caf17 commit 600f0d8

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,32 @@ def forward(
926926
past_kv_len += 1
927927

928928

929+
@staticmethod
930+
def test_immediate_return_getattr_model():
931+
class ImmediateReturnGetAttrModel(torch.nn.Module):
932+
def __init__(self):
933+
super().__init__()
934+
self.register_buffer("my_constant_output", torch.tensor([1.0, 2.0, 3.0, 4.0]))
935+
self.register_buffer("my_constant_output2", torch.tensor([5.0, 6.0, 7.0, 8.0]))
936+
937+
def forward(self, x):
938+
# x is a dummy input, not used
939+
return self.my_constant_output, self.my_constant_output2
940+
941+
model = ImmediateReturnGetAttrModel()
942+
model.eval()
943+
dummy_input = torch.zeros(1) # Dummy input for tracing
944+
traced_model = torch.jit.trace(model, example_inputs=(dummy_input,))
945+
mlmodel = ct.convert(
946+
traced_model,
947+
inputs=[ct.TensorType(shape=(1,))],
948+
convert_to='mlprogram'
949+
)
950+
outputs = mlmodel.predict({"x": np.zeros(1)})
951+
assert "my_constant_output" in outputs
952+
assert "my_constant_output2" in outputs
953+
954+
929955
###############################################################################
930956
# Note: Stress tests for PyTorch input / output types
931957
###############################################################################

coremltools/converters/mil/frontend/torch/torchir_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None:
251251
outputs=node.outputs,
252252
kind="constant",
253253
name="internal_immediate_output_attr",
254-
attr={"value": node.parent.params[node.name]}
254+
attr={"value": node.parent.params[node.name].detach()}
255255
)
256256
)
257257
else:

0 commit comments

Comments
 (0)