diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 29bf767f9542..b346b6681b03 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -549,7 +549,23 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder&
   return TensorStructInfo(data_sinfo->dtype, n_axis, data_sinfo->vdevice);
 }
 
-// TODO(tvm-team): Register FRelaxInferLayout, TMixedPrecisionPolicy
+InferLayoutOutput InferLayoutDynStridedSlice(
+    const Call& call, const ffi::Map<String, ffi::Array<String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  CHECK(tensor_sinfo) << "Invalid Call";
+  CHECK(!tensor_sinfo->IsUnknownNdim()) << "Layout inference only supports known dimensionality, "
+                                        << "but expression " << call << " has argument "
+                                        << call->args[0] << " of unknown dimensionality.";
+  int ndim = tensor_sinfo->ndim;
+  // Since begin/end/strides are dynamic tensors, we cannot transform
+  // them at compile time. Fall back to the initial layout.
+  LayoutDecision initial = LayoutDecision(InitialLayout(ndim));
+  return InferLayoutOutput({initial}, {initial}, Attrs());
+}
+
 TVM_REGISTER_OP("relax.dynamic_strided_slice")
     .set_num_inputs(4)
     .add_argument("x", "Tensor", "The source tensor to be sliced.")
@@ -557,6 +573,8 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice")
     .add_argument("end", "Tensor", "Indices indicating end of the slice.")
     .add_argument("strides", "Tensor", "The stride values.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDynStridedSlice)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutDynStridedSlice)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
 }  // namespace relax
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 42e1cff284b1..5ba0c4d86771 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5231,5 +5231,57 @@ def main(
     verify(Input, Expected)
 
 
+def test_conv2d_dynamic_strided_slice():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            begin: R.Tensor((4,), "int64"),
+            end: R.Tensor((4,), "int64"),
+            strides: R.Tensor((4,), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2 = R.dynamic_strided_slice(gv, begin, end, strides)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            begin: R.Tensor((4,), dtype="int64"),
+            end: R.Tensor((4,), dtype="int64"),
+            strides: R.Tensor((4,), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    gv, axes=[0, 3, 1, 2]
+                )
+                gv2 = R.dynamic_strided_slice(lv2, begin, end, strides)
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 4e90216f9bc0..2a23890d7f62 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -1064,5 +1064,58 @@ def tir_identity(
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_dynamic_strided_slice():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            begin: R.Tensor((4,), "int64"),
+            end: R.Tensor((4,), "int64"),
+            strides: R.Tensor((4,), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv = R.dynamic_strided_slice(lv, begin, end, strides)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            begin: R.Tensor((4,), dtype="int64"),
+            end: R.Tensor((4,), dtype="int64"),
+            strides: R.Tensor((4,), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
+                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv2, dtype="float16")
+                lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3, dtype="float32")
+                gv: R.Tensor(None, dtype="float32", ndim=4) = R.dynamic_strided_slice(
+                    lv4, begin, end, strides
+                )
+                R.output(gv)
+            return gv
+
+    _assert_test(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()