diff --git a/test/test_ops.py b/test/test_ops.py
index 88124f7ba17..dd184ddde2d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -228,12 +228,12 @@ def func(z):
         ):
             gradcheck(func, (x,))
 
-    @needs_cuda
+    @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
     @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
-    def test_autocast(self, x_dtype, rois_dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
+    def test_autocast(self, device, x_dtype, rois_dtype):
+        with torch.amp.autocast(device):
+            self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
 
     def _helper_boxes_shape(self, func):
         # test boxes as Tensor[N, 5]
@@ -490,32 +490,18 @@ def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois
             aligned=aligned,
         )
 
-    @needs_cuda
+    @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("aligned", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
-    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half, torch.bfloat16))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half, torch.bfloat16))
     @pytest.mark.opcheck_only_one()
-    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(
-                torch.device("cuda"),
-                contiguous=False,
-                deterministic=deterministic,
-                aligned=aligned,
-                x_dtype=x_dtype,
-                rois_dtype=rois_dtype,
-            )
-
-    @pytest.mark.skip(reason="1/5000 flaky failure")
-    @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("deterministic", (True, False))
-    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
-    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
-    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
-        with torch.cpu.amp.autocast():
+    def test_autocast(self, device, aligned, deterministic, x_dtype, rois_dtype):
+        if device == "cpu" and x_dtype is torch.bfloat16:
+            pytest.skip("1/5000 flaky failure")
+        with torch.amp.autocast(device):
             self.test_forward(
-                torch.device("cpu"),
+                torch.device(device),
                 contiguous=False,
                 deterministic=deterministic,
                 aligned=aligned,
@@ -856,14 +842,14 @@ def test_nms_gpu(self, iou, device, dtype=torch.float64):
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     @pytest.mark.opcheck_only_one()
     def test_autocast(self, iou, dtype):
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast("cuda"):
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
 
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
     def test_autocast_cpu(self, iou, dtype):
         boxes, scores = self._create_tensors_with_iou(1000, iou)
-        with torch.cpu.amp.autocast():
+        with torch.amp.autocast("cpu"):
             keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
             keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
             torch.testing.assert_close(keep_ref_float, keep_dtype)
@@ -1188,13 +1174,13 @@ def test_compare_cpu_cuda_grads(self, contiguous):
         res_grads = init_weight.grad.to("cpu")
         torch.testing.assert_close(true_cpu_grads, res_grads)
 
-    @needs_cuda
+    @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("batch_sz", (0, 33))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     @pytest.mark.opcheck_only_one()
-    def test_autocast(self, batch_sz, dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
+    def test_autocast(self, device, batch_sz, dtype):
+        with torch.amp.autocast(device):
+            self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype)
 
     def test_forward_scriptability(self):
         # Non-regression test for https://github.com/pytorch/vision/issues/4078
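The test changes above lean on the device-generic torch.amp.autocast(device_type)
context manager, which subsumes the older torch.cuda.amp.autocast() and
torch.cpu.amp.autocast() entry points; that is what lets a single parametrized
test_autocast cover both devices. A minimal sketch of the behavior (not part of
the patch):

    import torch

    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    for device in devices:
        with torch.amp.autocast(device):
            a = torch.randn(8, 8, device=device)
            b = torch.randn(8, 8, device=device)
            # Autocast-eligible ops run in the device's default low-precision
            # dtype: float16 on CUDA, bfloat16 on CPU.
            print(device, (a @ b).dtype)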
@pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.parametrize("dtype", (torch.float, torch.half)) @pytest.mark.opcheck_only_one() - def test_autocast(self, batch_sz, dtype): - with torch.cuda.amp.autocast(): - self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) + def test_autocast(self, device, batch_sz, dtype): + with torch.amp.autocast(device): + self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype) def test_forward_scriptability(self): # Non-regression test for https://github.com/pytorch/vision/issues/4078 diff --git a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp index 4f082fa0006..fb7b953cd2d 100644 --- a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp +++ b/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp @@ -9,6 +9,7 @@ namespace ops { namespace { +template at::Tensor deform_conv2d_autocast( const at::Tensor& input, const at::Tensor& weight, @@ -24,13 +25,13 @@ at::Tensor deform_conv2d_autocast( int64_t groups, int64_t offset_groups, bool use_mask) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key); return deform_conv2d( - at::autocast::cached_cast(at::kFloat, input), - at::autocast::cached_cast(at::kFloat, weight), - at::autocast::cached_cast(at::kFloat, offset), - at::autocast::cached_cast(at::kFloat, mask), - at::autocast::cached_cast(at::kFloat, bias), + at::autocast::cached_cast(at::kFloat, input, device_type), + at::autocast::cached_cast(at::kFloat, weight, device_type), + at::autocast::cached_cast(at::kFloat, offset, device_type), + at::autocast::cached_cast(at::kFloat, mask, device_type), + at::autocast::cached_cast(at::kFloat, bias, device_type), stride_h, stride_w, pad_h, @@ -48,7 +49,25 @@ at::Tensor deform_conv2d_autocast( TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl( TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), - TORCH_FN(deform_conv2d_autocast)); + TORCH_FN((deform_conv2d_autocast< + c10::DispatchKey::Autocast, + c10::DeviceType::CUDA>))); +} + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN((deform_conv2d_autocast< + c10::DispatchKey::AutocastCPU, + c10::DeviceType::CPU>))); +} + +TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN((deform_conv2d_autocast< + c10::DispatchKey::AutocastXPU, + c10::DeviceType::XPU>))); } } // namespace ops diff --git a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp index bce987b0f71..d6cd0c471d1 100644 --- a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp +++ b/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp @@ -9,6 +9,7 @@ namespace ops { namespace { +template std::tuple ps_roi_align_autocast( const at::Tensor& input, const at::Tensor& rois, @@ -16,10 +17,10 @@ std::tuple ps_roi_align_autocast( int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key); auto result = ps_roi_align( - at::autocast::cached_cast(at::kFloat, input), - at::autocast::cached_cast(at::kFloat, rois), + at::autocast::cached_cast(at::kFloat, input, device_type), + at::autocast::cached_cast(at::kFloat, rois, device_type), 
diff --git a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
index bce987b0f71..d6cd0c471d1 100644
--- a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
@@ -9,6 +9,7 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
@@ -16,10 +17,10 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
     int64_t pooled_height,
     int64_t pooled_width,
     int64_t sampling_ratio) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   auto result = ps_roi_align(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width,
@@ -35,7 +36,25 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
-      TORCH_FN(ps_roi_align_autocast));
+      TORCH_FN((ps_roi_align_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
+      TORCH_FN((ps_roi_align_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
+      TORCH_FN((ps_roi_align_autocast<
+                c10::DispatchKey::AutocastXPU,
+                c10::DeviceType::XPU>)));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
index 3cf1e7f80d7..a623c42312b 100644
--- a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
@@ -9,16 +9,17 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   auto result = ps_roi_pool(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width);
@@ -33,7 +34,25 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
-      TORCH_FN(ps_roi_pool_autocast));
+      TORCH_FN((ps_roi_pool_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
+      TORCH_FN((ps_roi_pool_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
+      TORCH_FN((ps_roi_pool_autocast<
+                c10::DispatchKey::AutocastXPU,
+                c10::DeviceType::XPU>)));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
index 3aaa038a9b4..936ce1dc5f5 100644
--- a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
@@ -9,16 +9,17 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   auto result = roi_pool(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width);
@@ -33,7 +34,25 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
-      TORCH_FN(roi_pool_autocast));
+      TORCH_FN((roi_pool_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
+      TORCH_FN((roi_pool_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
+      TORCH_FN((roi_pool_autocast<
+                c10::DispatchKey::AutocastXPU,
+                c10::DeviceType::XPU>)));
 }
 
 } // namespace ops
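One way to sanity-check the new registrations from Python is the dispatcher
introspection helper torch._C._dispatch_dump. It is a private, unstable API, so
the sketch below is a debugging aid under that assumption, not part of the
patch or its test suite:

    import torch
    import torchvision  # noqa: F401  -- loads the torchvision custom ops

    for op in ("deform_conv2d", "ps_roi_align", "ps_roi_pool", "roi_pool"):
        dump = torch._C._dispatch_dump(f"torchvision::{op}")
        # After this patch, each op should report kernels for Autocast (CUDA),
        # AutocastCPU, and AutocastXPU.
        print(op, [k for k in ("AutocastCPU", "AutocastXPU") if k in dump])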