
Commit b2f7190

torch.log(): cast int arguments to float32 (#2017)
Testing: https://gitlab.com/coremltools1/coremltools/-/pipelines/1043280060
1 parent 4dc2ba8 commit b2f7190

File tree: 2 files changed (+31, -1 lines)


coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 4 additions & 1 deletion
@@ -5124,7 +5124,10 @@ def reciprocal(context, node):
 @register_torch_op
 def log(context, node):
     inputs = _get_inputs(context, node, expected=1)
-    context.add(mb.log(x=inputs[0], name=node.name))
+    x = inputs[0]
+    if types.is_int(x.dtype):
+        x = mb.cast(x=x, dtype="fp32")
+    context.add(mb.log(x=x, name=node.name))
 
 
 @register_torch_op(torch_alias=["round"])
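A minimal end-to-end sketch of what this fix enables (assumed usage, not part of the commit; the model class, tensor names, and shapes below are illustrative): converting a traced model that applies torch.log to an int32 input. With the cast in place, the converter promotes the integer input to fp32 before emitting the MIL log op.

    # Hedged reproduction sketch, assuming a coremltools build that includes this fix.
    import numpy as np
    import torch
    import coremltools as ct


    class LogModel(torch.nn.Module):
        def forward(self, x):
            return torch.log(x)


    # Integer example input; torch.log itself promotes it to float when tracing.
    example = torch.arange(1, 7, dtype=torch.int32).reshape(2, 3)
    traced = torch.jit.trace(LogModel().eval(), example)

    # Declare the Core ML input as int32; the converter now inserts a cast to fp32.
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="x", shape=(2, 3), dtype=np.int32)],
        convert_to="mlprogram",
    )

    # On macOS the converted model can be exercised directly:
    # print(mlmodel.predict({"x": example.numpy()}))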

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 27 additions & 0 deletions
@@ -5546,6 +5546,33 @@ def test_elementwise_numerically_stable(
             rand_range=(20, 100),
         )
 
+    @pytest.mark.parametrize(
+        "compute_unit, backend, dtype",
+        itertools.product(
+            compute_units,
+            backends,
+            [np.int32, np.float32],
+        ),
+    )
+    def test_log_dtype(
+        self, compute_unit, backend, dtype
+    ):
+        SHAPE = (2, 3)
+
+        input_data = np.random.randint(1, 100, SHAPE).astype(dtype)
+        input_data = torch.from_numpy(input_data)
+        model = ModuleWrapper(torch.log)
+        converter_input_type = [TensorType(shape=SHAPE, dtype=dtype)]
+
+        self.run_compare_torch(
+            input_data,
+            model,
+            backend=backend,
+            compute_unit=compute_unit,
+            input_as_shape=False,
+            converter_input_type=converter_input_type
+        )
+
 
 class TestAtan2(TorchBaseTest):
     @pytest.mark.parametrize(
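A small PyTorch-only sketch (not from the commit) of the behavior the new test pins down: torch promotes integer inputs to the default float dtype before taking the log, so casting int inputs to fp32 on the coremltools side keeps the converted model's output dtype and values in line with PyTorch.

    # Hedged illustration of torch's own int-to-float promotion for log.
    import torch

    x_int = torch.randint(1, 100, (2, 3), dtype=torch.int32)
    y = torch.log(x_int)

    # torch.log on an int32 tensor returns the default float dtype (float32).
    assert y.dtype == torch.float32
    # Values match an explicit cast-then-log, mirroring the converter's new behavior.
    assert torch.allclose(y, torch.log(x_int.to(torch.float32)))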
