From 61ec568bab2d7eedf033d490a0c9533c725bedfb Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Sun, 22 Dec 2024 22:55:12 +0000 Subject: [PATCH 01/23] init minimal zoom backend --- aten/src/ATen/AccumulateType.cpp | 28 +- aten/src/ATen/AccumulateType.h | 31 +- aten/src/ATen/CMakeLists.txt | 49 +- aten/src/ATen/Context.cpp | 16 + aten/src/ATen/Context.h | 14 +- aten/src/ATen/EmptyTensor.cpp | 1 + aten/src/ATen/TensorIndexing.cpp | 3 +- aten/src/ATen/autocast_mode.cpp | 38 + aten/src/ATen/autocast_mode.h | 19 + aten/src/ATen/detail/ZoomHooksInterface.cpp | 48 + aten/src/ATen/detail/ZoomHooksInterface.h | 143 + aten/src/ATen/native/Copy.cpp | 6 +- aten/src/ATen/native/TensorCompare.cpp | 5 +- aten/src/ATen/native/zoom/AmpKernels.cu | 252 ++ aten/src/ATen/native/zoom/CompareEQKernel.cu | 50 + aten/src/ATen/native/zoom/CompareKernels.cu | 103 + aten/src/ATen/native/zoom/Copy.cu | 393 +++ aten/src/ATen/native/zoom/Copy.h | 11 + aten/src/ATen/native/zoom/Equal.cpp | 49 + aten/src/ATen/native/zoom/FillKernel.cu | 30 + aten/src/ATen/native/zoom/MiscUtils.h | 32 + aten/src/ATen/native/zoom/Nonzero.cu | 130 + aten/src/ATen/native/zoom/Resize.cpp | 69 + aten/src/ATen/native/zoom/Resize.h | 61 + aten/src/ATen/native/zoom/TensorCompare.cpp | 23 + aten/src/ATen/native/zoom/TensorCompare.cu | 133 + aten/src/ATen/native/zoom/TensorFactories.cu | 396 +++ aten/src/ATen/native/zoom/TensorShape.cu | 833 +++++ aten/src/ATen/native/zoom/TensorShapeZoom.cpp | 37 + .../ATen/native/zoom/TensorTransformations.cu | 154 + aten/src/ATen/zoom/ATenZoomGeneral.h | 8 + aten/src/ATen/zoom/ApplyGridUtils.cuh | 47 + aten/src/ATen/zoom/AsmUtils.cuh | 85 + aten/src/ATen/zoom/Atomic.cuh | 457 +++ aten/src/ATen/zoom/CachingHostAllocator.cpp | 266 ++ aten/src/ATen/zoom/CachingHostAllocator.h | 39 + aten/src/ATen/zoom/DeviceUtils.cuh | 75 + aten/src/ATen/zoom/EmptyTensor.cpp | 71 + aten/src/ATen/zoom/EmptyTensor.h | 14 + aten/src/ATen/zoom/HIPConfig.h | 9 + aten/src/ATen/zoom/HIPGraph.cpp | 317 ++ aten/src/ATen/zoom/HIPGraph.h | 96 + aten/src/ATen/zoom/HIPGraphsUtils.hpp | 41 + aten/src/ATen/zoom/HIPUtils.h | 20 + aten/src/ATen/zoom/NumericLimits.cuh | 121 + aten/src/ATen/zoom/PeerToPeerAccess.cpp | 59 + aten/src/ATen/zoom/PeerToPeerAccess.h | 12 + aten/src/ATen/zoom/PhiloxHIPState.h | 5 + aten/src/ATen/zoom/PhiloxUtils.hpp | 4 + aten/src/ATen/zoom/PinnedMemoryAllocator.cpp | 32 + aten/src/ATen/zoom/PinnedMemoryAllocator.h | 11 + aten/src/ATen/zoom/ScanUtils.cuh | 72 + aten/src/ATen/zoom/ThrustAllocator.h | 23 + aten/src/ATen/zoom/ZoomApplyUtils.cuh | 537 +++ aten/src/ATen/zoom/ZoomContext.cpp | 69 + aten/src/ATen/zoom/ZoomContext.h | 9 + aten/src/ATen/zoom/ZoomContextLight.h | 85 + aten/src/ATen/zoom/ZoomDataType.h | 97 + aten/src/ATen/zoom/ZoomDevice.h | 17 + aten/src/ATen/zoom/ZoomEvent.h | 213 ++ aten/src/ATen/zoom/ZoomGeneratorImpl.cpp | 512 +++ aten/src/ATen/zoom/ZoomGeneratorImpl.h | 181 + aten/src/ATen/zoom/cub-RadixSortKeys.cu | 59 + aten/src/ATen/zoom/cub-RadixSortPairs.cu | 86 + aten/src/ATen/zoom/cub.cu | 51 + aten/src/ATen/zoom/cub.cuh | 284 ++ aten/src/ATen/zoom/cub.h | 88 + aten/src/ATen/zoom/cub_definitions.cuh | 27 + .../ATen/zoom/detail/DeviceThreadHandles.h | 151 + aten/src/ATen/zoom/detail/IndexUtils.cu | 75 + aten/src/ATen/zoom/detail/IndexUtils.cuh | 36 + aten/src/ATen/zoom/detail/KernelUtils.h | 37 + .../ATen/zoom/detail/PhiloxHIPStateRaw.hpp | 43 + aten/src/ATen/zoom/detail/TensorInfo.cuh | 116 + aten/src/ATen/zoom/detail/UnpackRaw.hpp | 28 + aten/src/ATen/zoom/detail/ZoomHooks.cpp | 273 ++ 
aten/src/ATen/zoom/detail/ZoomHooks.h | 36 + aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.cpp | 13 + aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.h | 85 + aten/src/ATen/zoom/jit/HIPJitLoops.cuh | 292 ++ aten/src/ATen/zoom/jit/HIPLoops.cuh | 333 ++ aten/src/ATen/zoom/jit/IntegerDivider.cuh | 126 + aten/src/ATen/zoom/jit/JitLoops.cuh | 182 + aten/src/ATen/zoom/jit/Loops.cuh | 325 ++ aten/src/ATen/zoom/jit/MemoryAccess.cuh | 395 +++ aten/src/ATen/zoom/jit/OffsetCalculator.cuh | 115 + aten/src/ATen/zoom/jit/jit_utils.cpp | 1752 ++++++++++ aten/src/ATen/zoom/jit/jit_utils.h | 230 ++ aten/src/ATen/zoom/jit/llvm_jit_strings.cpp | 1444 ++++++++ aten/src/ATen/zoom/jit/llvm_jit_strings.h | 14 + aten/src/ATen/zoom/jit/macros.h | 4 + aten/src/ATen/zoom/jit/thread_constants.h | 16 + build_variables.bzl | 15 + c10/CMakeLists.txt | 4 + c10/core/Allocator.cpp | 15 + c10/core/Allocator.h | 14 + c10/macros/Export.h | 2 + c10/macros/Macros.h | 2 +- c10/util/generic_math.h | 6 +- c10/zoom/CMakeLists.txt | 60 + c10/zoom/HIPGraphsC10Utils.h | 77 + c10/zoom/HIPMathCompat.h | 152 + c10/zoom/ZoomAllocatorConfig.cpp | 350 ++ c10/zoom/ZoomAllocatorConfig.h | 128 + c10/zoom/ZoomCachingAllocator.cpp | 3104 +++++++++++++++++ c10/zoom/ZoomCachingAllocator.h | 480 +++ c10/zoom/ZoomDeviceAssertionHost.cpp | 344 ++ c10/zoom/ZoomDeviceAssertionHost.h | 164 + c10/zoom/ZoomException.cpp | 88 + c10/zoom/ZoomException.h | 185 + c10/zoom/ZoomFunctions.cpp | 294 ++ c10/zoom/ZoomFunctions.h | 112 + c10/zoom/ZoomGuard.h | 301 ++ c10/zoom/ZoomMacros.h | 41 + c10/zoom/ZoomMallocAsyncAllocator.cpp | 899 +++++ c10/zoom/ZoomMiscFunctions.cpp | 23 + c10/zoom/ZoomMiscFunctions.h | 8 + c10/zoom/ZoomStream.cpp | 375 ++ c10/zoom/ZoomStream.h | 221 ++ c10/zoom/impl/ZoomGuardImpl.cpp | 7 + c10/zoom/impl/ZoomGuardImpl.h | 249 ++ caffe2/CMakeLists.txt | 126 + cmake/Caffe2Config.cmake.in | 4 + cmake/Codegen.cmake | 3 + cmake/Dependencies.cmake | 14 +- cmake/External/aotriton.cmake | 5 +- cmake/Summary.cmake | 2 + cmake/public/LoadHIP.cmake | 4 + torch/CMakeLists.txt | 16 + torch/__init__.py | 1 + torch/_decomp/decompositions.py | 8 +- torch/csrc/Module.cpp | 36 +- .../autograd/python_variable_indexing.cpp | 3 +- torch/csrc/tensor/python_tensor.cpp | 14 + torch/csrc/zoom/Event.cpp | 250 ++ torch/csrc/zoom/Event.h | 18 + torch/csrc/zoom/Graph.cpp | 91 + torch/csrc/zoom/Module.cpp | 1533 ++++++++ torch/csrc/zoom/Module.h | 11 + torch/csrc/zoom/Stream.cpp | 216 ++ torch/csrc/zoom/Stream.h | 20 + torch/csrc/zoom/THCP.h | 10 + torch/csrc/zoom/Tensor.cpp | 15 + torch/csrc/zoom/ZoomPluggableAllocator.cpp | 373 ++ torch/csrc/zoom/ZoomPluggableAllocator.h | 147 + torch/csrc/zoom/comm.cpp | 508 +++ torch/csrc/zoom/comm.h | 52 + torch/csrc/zoom/device_set.h | 11 + torch/csrc/zoom/memory_snapshot.cpp | 376 ++ torch/csrc/zoom/memory_snapshot.h | 27 + torch/csrc/zoom/python_comm.cpp | 109 + torch/csrc/zoom/python_comm.h | 7 + torch/csrc/zoom/shared/hiprt.cpp | 76 + torch/csrc/zoom/utils.cpp | 41 + torch/csrc/zoom/utils.h | 4 + torch/nn/functional.py | 4 +- torch/testing/_internal/common_device_type.py | 29 + torch/testing/_internal/common_utils.py | 28 + torch/testing/_internal/opinfo/core.py | 4 +- torch/utils/cpp_extension.py | 24 +- torch/zoom/__init__.py | 577 +++ torch/zoom/_memory_viz.py | 627 ++++ torch/zoom/_utils.py | 38 + torch/zoom/graphs.py | 479 +++ torch/zoom/memory.py | 910 +++++ torch/zoom/random.py | 179 + torch/zoom/streams.py | 241 ++ 167 files changed, 28424 insertions(+), 44 deletions(-) create mode 100644 
aten/src/ATen/detail/ZoomHooksInterface.cpp create mode 100644 aten/src/ATen/detail/ZoomHooksInterface.h create mode 100644 aten/src/ATen/native/zoom/AmpKernels.cu create mode 100644 aten/src/ATen/native/zoom/CompareEQKernel.cu create mode 100644 aten/src/ATen/native/zoom/CompareKernels.cu create mode 100644 aten/src/ATen/native/zoom/Copy.cu create mode 100644 aten/src/ATen/native/zoom/Copy.h create mode 100644 aten/src/ATen/native/zoom/Equal.cpp create mode 100644 aten/src/ATen/native/zoom/FillKernel.cu create mode 100644 aten/src/ATen/native/zoom/MiscUtils.h create mode 100644 aten/src/ATen/native/zoom/Nonzero.cu create mode 100644 aten/src/ATen/native/zoom/Resize.cpp create mode 100644 aten/src/ATen/native/zoom/Resize.h create mode 100644 aten/src/ATen/native/zoom/TensorCompare.cpp create mode 100644 aten/src/ATen/native/zoom/TensorCompare.cu create mode 100644 aten/src/ATen/native/zoom/TensorFactories.cu create mode 100644 aten/src/ATen/native/zoom/TensorShape.cu create mode 100644 aten/src/ATen/native/zoom/TensorShapeZoom.cpp create mode 100644 aten/src/ATen/native/zoom/TensorTransformations.cu create mode 100644 aten/src/ATen/zoom/ATenZoomGeneral.h create mode 100644 aten/src/ATen/zoom/ApplyGridUtils.cuh create mode 100644 aten/src/ATen/zoom/AsmUtils.cuh create mode 100644 aten/src/ATen/zoom/Atomic.cuh create mode 100644 aten/src/ATen/zoom/CachingHostAllocator.cpp create mode 100644 aten/src/ATen/zoom/CachingHostAllocator.h create mode 100644 aten/src/ATen/zoom/DeviceUtils.cuh create mode 100644 aten/src/ATen/zoom/EmptyTensor.cpp create mode 100644 aten/src/ATen/zoom/EmptyTensor.h create mode 100644 aten/src/ATen/zoom/HIPConfig.h create mode 100644 aten/src/ATen/zoom/HIPGraph.cpp create mode 100644 aten/src/ATen/zoom/HIPGraph.h create mode 100644 aten/src/ATen/zoom/HIPGraphsUtils.hpp create mode 100644 aten/src/ATen/zoom/HIPUtils.h create mode 100644 aten/src/ATen/zoom/NumericLimits.cuh create mode 100644 aten/src/ATen/zoom/PeerToPeerAccess.cpp create mode 100644 aten/src/ATen/zoom/PeerToPeerAccess.h create mode 100644 aten/src/ATen/zoom/PhiloxHIPState.h create mode 100644 aten/src/ATen/zoom/PhiloxUtils.hpp create mode 100644 aten/src/ATen/zoom/PinnedMemoryAllocator.cpp create mode 100644 aten/src/ATen/zoom/PinnedMemoryAllocator.h create mode 100644 aten/src/ATen/zoom/ScanUtils.cuh create mode 100644 aten/src/ATen/zoom/ThrustAllocator.h create mode 100644 aten/src/ATen/zoom/ZoomApplyUtils.cuh create mode 100644 aten/src/ATen/zoom/ZoomContext.cpp create mode 100644 aten/src/ATen/zoom/ZoomContext.h create mode 100644 aten/src/ATen/zoom/ZoomContextLight.h create mode 100644 aten/src/ATen/zoom/ZoomDataType.h create mode 100644 aten/src/ATen/zoom/ZoomDevice.h create mode 100644 aten/src/ATen/zoom/ZoomEvent.h create mode 100644 aten/src/ATen/zoom/ZoomGeneratorImpl.cpp create mode 100644 aten/src/ATen/zoom/ZoomGeneratorImpl.h create mode 100644 aten/src/ATen/zoom/cub-RadixSortKeys.cu create mode 100644 aten/src/ATen/zoom/cub-RadixSortPairs.cu create mode 100644 aten/src/ATen/zoom/cub.cu create mode 100644 aten/src/ATen/zoom/cub.cuh create mode 100644 aten/src/ATen/zoom/cub.h create mode 100644 aten/src/ATen/zoom/cub_definitions.cuh create mode 100644 aten/src/ATen/zoom/detail/DeviceThreadHandles.h create mode 100644 aten/src/ATen/zoom/detail/IndexUtils.cu create mode 100644 aten/src/ATen/zoom/detail/IndexUtils.cuh create mode 100644 aten/src/ATen/zoom/detail/KernelUtils.h create mode 100644 aten/src/ATen/zoom/detail/PhiloxHIPStateRaw.hpp create mode 100644 
aten/src/ATen/zoom/detail/TensorInfo.cuh create mode 100644 aten/src/ATen/zoom/detail/UnpackRaw.hpp create mode 100644 aten/src/ATen/zoom/detail/ZoomHooks.cpp create mode 100644 aten/src/ATen/zoom/detail/ZoomHooks.h create mode 100644 aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.cpp create mode 100644 aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.h create mode 100644 aten/src/ATen/zoom/jit/HIPJitLoops.cuh create mode 100644 aten/src/ATen/zoom/jit/HIPLoops.cuh create mode 100644 aten/src/ATen/zoom/jit/IntegerDivider.cuh create mode 100644 aten/src/ATen/zoom/jit/JitLoops.cuh create mode 100644 aten/src/ATen/zoom/jit/Loops.cuh create mode 100644 aten/src/ATen/zoom/jit/MemoryAccess.cuh create mode 100644 aten/src/ATen/zoom/jit/OffsetCalculator.cuh create mode 100644 aten/src/ATen/zoom/jit/jit_utils.cpp create mode 100644 aten/src/ATen/zoom/jit/jit_utils.h create mode 100644 aten/src/ATen/zoom/jit/llvm_jit_strings.cpp create mode 100644 aten/src/ATen/zoom/jit/llvm_jit_strings.h create mode 100644 aten/src/ATen/zoom/jit/macros.h create mode 100644 aten/src/ATen/zoom/jit/thread_constants.h create mode 100644 c10/zoom/CMakeLists.txt create mode 100644 c10/zoom/HIPGraphsC10Utils.h create mode 100644 c10/zoom/HIPMathCompat.h create mode 100644 c10/zoom/ZoomAllocatorConfig.cpp create mode 100644 c10/zoom/ZoomAllocatorConfig.h create mode 100644 c10/zoom/ZoomCachingAllocator.cpp create mode 100644 c10/zoom/ZoomCachingAllocator.h create mode 100644 c10/zoom/ZoomDeviceAssertionHost.cpp create mode 100644 c10/zoom/ZoomDeviceAssertionHost.h create mode 100644 c10/zoom/ZoomException.cpp create mode 100644 c10/zoom/ZoomException.h create mode 100644 c10/zoom/ZoomFunctions.cpp create mode 100644 c10/zoom/ZoomFunctions.h create mode 100644 c10/zoom/ZoomGuard.h create mode 100644 c10/zoom/ZoomMacros.h create mode 100644 c10/zoom/ZoomMallocAsyncAllocator.cpp create mode 100644 c10/zoom/ZoomMiscFunctions.cpp create mode 100644 c10/zoom/ZoomMiscFunctions.h create mode 100644 c10/zoom/ZoomStream.cpp create mode 100644 c10/zoom/ZoomStream.h create mode 100644 c10/zoom/impl/ZoomGuardImpl.cpp create mode 100644 c10/zoom/impl/ZoomGuardImpl.h create mode 100644 torch/csrc/zoom/Event.cpp create mode 100644 torch/csrc/zoom/Event.h create mode 100644 torch/csrc/zoom/Graph.cpp create mode 100644 torch/csrc/zoom/Module.cpp create mode 100644 torch/csrc/zoom/Module.h create mode 100644 torch/csrc/zoom/Stream.cpp create mode 100644 torch/csrc/zoom/Stream.h create mode 100644 torch/csrc/zoom/THCP.h create mode 100644 torch/csrc/zoom/Tensor.cpp create mode 100644 torch/csrc/zoom/ZoomPluggableAllocator.cpp create mode 100644 torch/csrc/zoom/ZoomPluggableAllocator.h create mode 100644 torch/csrc/zoom/comm.cpp create mode 100644 torch/csrc/zoom/comm.h create mode 100644 torch/csrc/zoom/device_set.h create mode 100644 torch/csrc/zoom/memory_snapshot.cpp create mode 100644 torch/csrc/zoom/memory_snapshot.h create mode 100644 torch/csrc/zoom/python_comm.cpp create mode 100644 torch/csrc/zoom/python_comm.h create mode 100644 torch/csrc/zoom/shared/hiprt.cpp create mode 100644 torch/csrc/zoom/utils.cpp create mode 100644 torch/csrc/zoom/utils.h create mode 100644 torch/zoom/__init__.py create mode 100644 torch/zoom/_memory_viz.py create mode 100644 torch/zoom/_utils.py create mode 100644 torch/zoom/graphs.py create mode 100644 torch/zoom/memory.py create mode 100644 torch/zoom/random.py create mode 100644 torch/zoom/streams.py diff --git a/aten/src/ATen/AccumulateType.cpp b/aten/src/ATen/AccumulateType.cpp index c4623cc08629c7..55952a6c8ff919 
100644 --- a/aten/src/ATen/AccumulateType.cpp +++ b/aten/src/ATen/AccumulateType.cpp @@ -2,17 +2,20 @@ namespace at { +// TODO(Arham): exchange keys c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) { switch (type) { -#define DEFINE_CASE(scalar_t, TypeNum) \ - case ScalarType::TypeNum: \ - switch (device) { \ - case DeviceType::CUDA: \ - return CppTypeToScalarType>::value; \ - case DeviceType::MPS: \ - return CppTypeToScalarType>::value; \ - default: \ - return CppTypeToScalarType>::value; \ +#define DEFINE_CASE(scalar_t, TypeNum) \ + case ScalarType::TypeNum: \ + switch (device) { \ + case DeviceType::CUDA: \ + return CppTypeToScalarType>::value; \ + case DeviceType::PrivateUse1: \ + return CppTypeToScalarType>::value; \ + case DeviceType::MPS: \ + return CppTypeToScalarType>::value; \ + default: \ + return CppTypeToScalarType>::value; \ } AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(DEFINE_CASE) @@ -23,7 +26,12 @@ c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) { } c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) { - return is_cuda ? toAccumulateType(type, c10::DeviceType::CUDA) : toAccumulateType(type, c10::DeviceType::CPU); + #ifndef USE_ZOOM + return is_cuda ? toAccumulateType(type, c10::DeviceType::CUDA) : toAccumulateType(type, c10::DeviceType::CPU); + #else + // TODO(Arham): exchange keys + return is_cuda ? toAccumulateType(type, c10::DeviceType::PrivateUse1) : toAccumulateType(type, c10::DeviceType::CPU); + #endif } } diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h index 0275ef099b03d7..1cdd2423c050a0 100644 --- a/aten/src/ATen/AccumulateType.h +++ b/aten/src/ATen/AccumulateType.h @@ -67,7 +67,12 @@ struct AccumulateType { template struct AccumulateType { - using type = typename AccumulateTypeDevice::type; + #ifndef USE_ZOOM + using type = typename AccumulateTypeDevice::type; + #else + // TODO(Arham): exchange keys + using type = typename AccumulateTypeDevice::type; + #endif }; template @@ -83,6 +88,8 @@ using acc_type = typename AccumulateType::type; }; #define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS) #define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA) +// TODO(Arham): exchange keys +#define ZOOM_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::PrivateUse1) #define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU) MPS_ACC_TYPE(BFloat16, float); @@ -126,6 +133,28 @@ CUDA_ACC_TYPE(c10::complex, c10::complex); CUDA_ACC_TYPE(c10::complex, c10::complex); CUDA_ACC_TYPE(c10::complex, c10::complex); +#if defined(__HIPCC__) +ZOOM_ACC_TYPE(half, float); +#endif +ZOOM_ACC_TYPE(BFloat16, float); +ZOOM_ACC_TYPE(Half, float); +ZOOM_ACC_TYPE(Float8_e5m2, float); +ZOOM_ACC_TYPE(Float8_e4m3fn, float); +ZOOM_ACC_TYPE(Float8_e5m2fnuz, float); +ZOOM_ACC_TYPE(Float8_e4m3fnuz, float); +ZOOM_ACC_TYPE(float, float); +ZOOM_ACC_TYPE(double, double); +ZOOM_ACC_TYPE(int8_t, int64_t); +ZOOM_ACC_TYPE(uint8_t, int64_t); +ZOOM_ACC_TYPE(char, int64_t); +ZOOM_ACC_TYPE(int16_t, int64_t); +ZOOM_ACC_TYPE(int32_t, int64_t); +ZOOM_ACC_TYPE(int64_t, int64_t); +ZOOM_ACC_TYPE(bool, bool); +ZOOM_ACC_TYPE(c10::complex, c10::complex); +ZOOM_ACC_TYPE(c10::complex, c10::complex); +ZOOM_ACC_TYPE(c10::complex, c10::complex); + CPU_ACC_TYPE(BFloat16, float); CPU_ACC_TYPE(Half, float); CPU_ACC_TYPE(Float8_e5m2, float); diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9ec458fda45e4c..1cd471cee47bc0 100644 --- 
a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -82,6 +82,12 @@ file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp") file(GLOB miopen_h "miopen/*.h") file(GLOB miopen_cpp "miopen/*.cpp") +file(GLOB zoom_h "zoom/*.h" "zoom/detail/*.h" "zoom/*.cuh" "zoom/detail/*.cuh" "zoom/tunable/*.cuh" "zoom/tunable/*.h" "zoom/jit/*.cuh" "zoom/jit/*.h") +file(GLOB zoom_cpp "zoom/*.cpp" "zoom/detail/*.cpp" "zoom/tunable/*.cpp" "zoom/jit/*.cpp") +file(GLOB zoom_hip "zoom/*.cu" "zoom/detail/*.cu" "zoom/impl/*.cu" "zoom/tunable/*.cu") +file(GLOB zoom_hiprtc_stub_h "zoom/hiprtc_stub/*.h") +file(GLOB zoom_hiprtc_stub_cpp "zoom/hiprtc_stub/*.cpp") + file(GLOB mkl_cpp "mkl/*.cpp") file(GLOB mkldnn_cpp "mkldnn/*.cpp") @@ -166,6 +172,13 @@ file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp") file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp") file(GLOB native_utils_cpp "native/utils/*.cpp") +file(GLOB native_zoom_hip "native/zoom/*.cu") +file(GLOB native_zoom_hip_h "native/zoom/*.cuh") +file(GLOB native_zoom_cpp "native/zoom/*.cpp") +file(GLOB native_zoom_linalg_cpp "native/zoom/linalg/*.cpp") +file(GLOB native_sparse_zoom_hip "native/sparse/zoom/*.cu") +file(GLOB native_sparse_zoom_cpp "native/sparse/zoom/*.cpp") + # flash_attention sources file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu") file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu") @@ -342,6 +355,26 @@ if(USE_ROCM) ) endif() +if(USE_ZOOM) + list(APPEND ATen_ZOOM_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/zoom) + list(APPEND ATen_ZOOM_SRCS + ${ATen_ZOOM_SRCS} + ${zoom_hip} + ${native_zoom_hip} + ${native_zoom_hip_h} + ${native_sparse_zoom_hip} + ) + list(APPEND all_zoom_cpp + ${native_sparse_zoom_cpp} + ${zoom_cpp} + ${native_zoom_cpp} + ${native_zoom_linalg_cpp} + ${zoom_generated_sources} + ${ATen_ZOOM_SRCS} + ${all_zoom_cpp} + ) +endif() + if(USE_XPU) list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/xpu) list(APPEND ATen_XPU_SRCS ${xpu_cpp}) @@ -546,6 +579,7 @@ endif() # Include CPU paths for CUDA/HIP as well list(APPEND ATen_CUDA_INCLUDE ${ATen_CPU_INCLUDE}) list(APPEND ATen_HIP_INCLUDE ${ATen_CPU_INCLUDE}) +list(APPEND ATen_ZOOM_INCLUDE ${ATen_CPU_INCLUDE}) list(APPEND ATen_VULKAN_INCLUDE ${ATen_CPU_INCLUDE}) # We have two libraries: libATen_cpu.so and libATen_cuda.so, @@ -576,6 +610,12 @@ if(USE_ROCM) # list(APPEND ATen_HIP_DEPENDENCY_LIBS ATEN_CUDA_FILES_GEN_LIB) endif() +if(USE_ZOOM) + set(ATen_ZOOM_SRCS ${all_zoom_cpp}) + set(ATen_HIPRTC_STUB_SRCS ${zoom_hiprtc_stub_cpp}) + # list(APPEND ATen_ZOOM_DEPENDENCY_LIBS ATEN_ZOOM_FILES_GEN_LIB) +endif() + set(ATEN_INCLUDE_DIR "${CMAKE_INSTALL_PREFIX}/${AT_INSTALL_INCLUDE_DIR}") configure_file(ATenConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake") install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" @@ -583,7 +623,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) if(NOT INTERN_BUILD_MOBILE) - list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h}) + list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${zoom_h} ${xpu_h} 
${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h}) # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) @@ -611,7 +651,7 @@ foreach(HEADER ${INSTALL_HEADERS}) endforeach() # TODO: Install hip_generated_headers when we have it -foreach(HEADER ${generated_headers} ${cuda_generated_headers}) +foreach(HEADER ${generated_headers} ${cuda_generated_headers} ${zoom_generated_headers}) # NB: Assumed to be flat install(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen) endforeach() @@ -652,7 +692,10 @@ set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) +set(ATen_HIPRTC_STUB_SRCS ${ATen_HIPRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) +set(ATen_ZOOM_SRCS ${ATen_ZOOM_SRCS} PARENT_SCOPE) +set(ATen_HIPRTC_STUB_SRCS ${ATen_HIPRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_MPS_SRCS ${ATen_MPS_SRCS} PARENT_SCOPE) set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE) @@ -671,12 +714,14 @@ set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE) set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE) set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE) +set(ATen_ZOOM_INCLUDE ${ATen_ZOOM_INCLUDE} PARENT_SCOPE) set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_ZOOM_DEPENDENCY_LIBS ${ATen_ZOOM_DEPENDENCY_LIBS} PARENT_SCOPE) set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) set(ATen_ATTENTION_KERNEL_SRCS ${ATen_ATTENTION_KERNEL_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 7fd191ef3f38c3..1136b05b265491 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -153,6 +153,7 @@ static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" } bool Context::checkCuBLASConfigDeterministic() { bool cublas_config_deterministic = true; + #ifndef USE_ZOOM // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config // is set to deterministic setting if (hasCUDART() && (versionCUDART() >= 10020)) { @@ -163,6 +164,10 @@ bool Context::checkCuBLASConfigDeterministic() { ); } return cublas_config_deterministic; + #else + // Zoom uses hipBLAS with the rocBLAS backend - this is only deterministic if atomics are disabled + return checkHIPBlasDeterministic(); + #endif } void Context::alertCuBLASConfigNotDeterministic() const { @@ -171,6 +176,7 @@ void Context::alertCuBLASConfigNotDeterministic() const { return; } + #ifndef USE_ZOOM auto msg = c10::str( "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ", "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ", @@ -180,6 +186,16 @@ void Context::alertCuBLASConfigNotDeterministic() const 
{ cublas_config_var_name, "=", cublas_deterministic_configs[1], ". For more information, go to ", "https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility" ); + #else + auto msg = c10::str( + "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ", + "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ", + "it uses hipBLAS and you have atomic operations enabled. To enable deterministic behavior in this ", + "case, you must set an environment variable before running your PyTorch application: ", + "ROCBLAS_DEFAULT_ATOMICS_MODE = 0. For more information, go to ", + "https://github.com/ROCm/rocBLAS/blob/develop/docs/how-to/what-is-rocblas.rst#bitwise-reproducibility" + ); + #endif if (deterministicAlgorithmsWarnOnly()) { TORCH_WARN(msg); diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index a922bcd5922fc8..f241e91be6f731 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -126,6 +127,9 @@ class TORCH_API Context { static bool hasCuBLASLt() { return detail::getCUDAHooks().hasCuBLASLt(); } + static bool checkHIPBlasDeterministic() { + return detail::getZoomHooks().checkHIPBlasDeterministic(); + } static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } @@ -163,14 +167,18 @@ class TORCH_API Context { } void lazyInitPrivateUse1() { c10::call_once(thp_init, [&] { - if (isPrivateUse1HooksRegistered()) { - at::GetPrivateUse1HooksInterface()->initPrivateUse1(); - } + // if (isPrivateUse1HooksRegistered()) { + // at::GetPrivateUse1HooksInterface()->initPrivateUse1(); + // } + detail::getZoomHooks().initPrivateUse1(); }); } static const at::cuda::NVRTC& getNVRTC() { return detail::getCUDAHooks().nvrtc(); } + static const at::zoom::HIPRTC& getHIPRTC() { + return detail::getZoomHooks().hiprtc(); + } static bool setFlushDenormal(bool on); diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index 1eb5c070b547c9..8b5cd8e8123920 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -21,6 +21,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { } else if (at::globalContext().hasXPU()) { return at::detail::getXPUHooks().getPinnedMemoryAllocator(); } else if(at::isPrivateUse1HooksRegistered()) { + // TODO(Arham): exchange keys return at::GetPrivateUse1HooksInterface()->getPinnedMemoryAllocator(); } else { TORCH_CHECK(false, "Need to provide pin_memory allocator to use pin memory.") diff --git a/aten/src/ATen/TensorIndexing.cpp b/aten/src/ATen/TensorIndexing.cpp index bd50282b46ec6a..128298522d48f2 100644 --- a/aten/src/ATen/TensorIndexing.cpp +++ b/aten/src/ATen/TensorIndexing.cpp @@ -50,9 +50,10 @@ static inline void set_item(const Tensor& self, ArrayRef indices, c at::Device self_device = self.device(); // TODO: This qint special case looks very suspicious... 
+ // TODO(Arham): exchange keys if (isQIntType(self.scalar_type())) { value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU)); - } else if (self_device.is_cuda()) { + } else if (self_device.is_cuda() || self_device.is_privateuseone()) { value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU)); } else { value = at::indexing::scalarToTensor(v, self.options(), self_device); diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 2d01bdeca500b0..8219fafb037b98 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -202,6 +202,44 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { TORCH_FN((&at::autocast::binary_cross_entropy_banned))); } +// TODO(Arham): exchange keys +TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + +TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) { + // lower_precision_fp +#define _KERNEL_ZOOM_LOW_PRECISION_FP(...) \ + KERNEL_ZOOM(__VA_ARGS__, lower_precision_fp) + + AT_FORALL_LOWER_PRECISION_FP(_KERNEL_ZOOM_LOW_PRECISION_FP) + + // fp32 +#define _KERNEL_ZOOM_FP32(...) KERNEL_ZOOM(__VA_ARGS__, fp32) + + AT_FORALL_FP32(_KERNEL_ZOOM_FP32) + + // fp32_set_opt_dtype +#define _KERNEL_ZOOM_FP32_SET_OPT_DTYPE(...) \ + KERNEL_ZOOM(__VA_ARGS__, fp32_set_opt_dtype) + + AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_ZOOM_FP32_SET_OPT_DTYPE) + + // fp32_append_dtype + // The fp32_append_dtype wrapper overrides implicit promotion behavior. + // norm does not implicitly promote, but be aware when adding new ops to this policy. + AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE( + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_ZOOM) + + // promote +#define _KERNEL_ZOOM_PROMOTE(...) KERNEL_ZOOM(__VA_ARGS__, promote) + + AT_FORALL_PROMOTE(_KERNEL_ZOOM_PROMOTE) + + m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"), + TORCH_FN((&at::autocast::binary_cross_entropy_banned))); +} + TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { m.fallback(torch::CppFunction::makeFallthrough()); } diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index c36030db5b0489..2f897715d03b60 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -708,6 +708,25 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. REDISPATCH_SIGNATURE, \ POLICY) +// KERNEL_ZOOM/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_ZOOM +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastZOOM +// TODO(Arham): exchange keys +#define KERNEL_ZOOM(...) KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_ZOOM( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::PrivateUse1, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + // KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU #define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__) diff --git a/aten/src/ATen/detail/ZoomHooksInterface.cpp b/aten/src/ATen/detail/ZoomHooksInterface.cpp new file mode 100644 index 00000000000000..f23de3c899c165 --- /dev/null +++ b/aten/src/ATen/detail/ZoomHooksInterface.cpp @@ -0,0 +1,48 @@ +#include + +#include + +#include + +namespace at { +namespace detail { + +// NB: We purposely leak the CUDA hooks object. 
This is because under some +// situations, we may need to reference the CUDA hooks while running destructors +// of objects which were constructed *prior* to the first invocation of +// getZoomHooks. The example which precipitated this change was the fused +// kernel cache in the JIT. The kernel cache is a global variable which caches +// both CPU and CUDA kernels; CUDA kernels must interact with CUDA hooks on +// destruction. Because the kernel cache handles CPU kernels too, it can be +// constructed before we initialize CUDA; if it contains CUDA kernels at program +// destruction time, you will destruct the CUDA kernels after CUDA hooks has +// been unloaded. In principle, we could have also fixed the kernel cache store +// CUDA kernels in a separate global variable, but this solution is much +// simpler. +// +// CUDAHooks doesn't actually contain any data, so leaking it is very benign; +// you're probably losing only a word (the vptr in the allocated object.) +static ZoomHooksInterface* zoom_hooks = nullptr; + +// init and register extension hooks +void initZoomHooks() { + static c10::once_flag once; + c10::call_once(once, [] { + zoom_hooks = PrivateUse1HooksRegistry()->Create("ZoomHooks", ZoomHooksArgs{}).release(); + if (!zoom_hooks) { + zoom_hooks = new ZoomHooksInterface(); + } + RegisterPrivateUse1HooksInterface(zoom_hooks); + }); +} + +const ZoomHooksInterface& getZoomHooks() { + initZoomHooks(); + return *zoom_hooks; +} + +} // namespace detail + +C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, ZoomHooksInterface, ZoomHooksArgs) + +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/detail/ZoomHooksInterface.h b/aten/src/ATen/detail/ZoomHooksInterface.h new file mode 100644 index 00000000000000..0e971a17e5a9c9 --- /dev/null +++ b/aten/src/ATen/detail/ZoomHooksInterface.h @@ -0,0 +1,143 @@ +#pragma once + +#include +#include +#include + +#include + +// Forward-declares at::Generator and at::zoom::NVRTC +namespace at { +struct Generator; +namespace zoom { +struct HIPRTC; +} // namespace zoom +} // namespace at + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +// #ifdef _MSC_VER +// constexpr const char* ZOOM_HELP = +// "PyTorch splits its backend into two shared libraries: a CPU library " +// "and a CUDA library; this error has occurred because you are trying " +// "to use some CUDA functionality, but the CUDA library has not been " +// "loaded by the dynamic linker for some reason. The CUDA library MUST " +// "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! " +// "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ " +// "in your link arguments; many dynamic linkers will delete dynamic library " +// "dependencies if you don't depend on any of their symbols. You can check " +// "if this has occurred by using link on your binary to see if there is a " +// "dependency on *_cuda.dll library."; +// #else +constexpr const char* ZOOM_HELP = + "PyTorch splits its backend into two shared libraries: a CPU library " + "and a ZOOM library; this error has occurred because you are trying " + "to use some ZOOM functionality, but the ZOOM library has not been " + "loaded by the dynamic linker for some reason. The ZOOM library MUST " + "be loaded, EVEN IF you don't directly use any symbols from the ZOOM library! 
" + "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many " + "dynamic linkers will delete dynamic library dependencies if you don't " + "depend on any of their symbols. You can check if this has occurred by " + "using ldd on your binary to see if there is a dependency on *_cuda.so " + "library."; +// #endif + +// The ZoomHooksInterface is an omnibus interface for any ZOOM functionality +// which we may want to call into from CPU code (and thus must be dynamically +// dispatched, to allow for separate compilation of ZOOM code). How do I +// decide if a function should live in this class? There are two tests: +// +// 1. Does the *implementation* of this function require linking against +// ZOOM libraries? +// +// 2. Is this function *called* from non-ZOOM ATen code? +// +// (2) should filter out many ostensible use-cases, since many times a ZOOM +// function provided by ATen is only really ever used by actual ZOOM code. +// +// TODO: Consider putting the stub definitions in another class, so that one +// never forgets to implement each virtual function in the real implementation +// in ZOOMHooks. This probably doesn't buy us much though. +struct TORCH_API ZoomHooksInterface : PrivateUse1HooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + virtual ~ZoomHooksInterface() override = default; + + // Initialize THCState and, transitively, the ZOOM state + virtual void initZoom() const { + TORCH_CHECK(false, "Cannot initialize ZOOM without torch_zoom library. ", ZOOM_HELP); + } + + virtual void initPrivateUse1() const override { + initZoom(); + } + + virtual const Generator& getDefaultZoomGenerator(C10_UNUSED DeviceIndex device_index = -1) const { + TORCH_CHECK(false, "Cannot get default ZOOM generator without torch_zoom library. ", ZOOM_HELP); + } + + virtual const Generator& getDefaultGenerator(DeviceIndex device_index) override { return getDefaultZoomGenerator(device_index); }; + + virtual Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK(false, "Cannot get device of pointer on ZOOM without torch_zoom library. ", ZOOM_HELP); + } + + virtual bool isPinnedPtr(const void* /*data*/) const { + return false; + } + + virtual bool hasROCM() const { + return false; + } + + virtual bool checkHIPBlasDeterministic() const { + TORCH_CHECK(false, "Cannot call checkHIPBlasDeterministic without torch_zoom library", ZOOM_HELP); + } + + virtual const at::zoom::HIPRTC& hiprtc() const { + TORCH_CHECK(false, "HIPRTC requires Zoom. ", ZOOM_HELP); + } + + virtual bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without torch_zoom library. ", ZOOM_HELP); + } + + virtual DeviceIndex current_device() const { + return -1; + } + + virtual Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Pinned memory requires ZOOM. ", ZOOM_HELP); + } + + virtual Allocator* getZoomDeviceAllocator() const { + TORCH_CHECK(false, "ZoomDeviceAllocator requires ZOOM. ", ZOOM_HELP); + } + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed ZOOM version without torch_zoom library. ", ZOOM_HELP); + } + + virtual int getNumGPUs() const { + return 0; + } + + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot synchronize ZOOM device without torch_zoom library. 
", ZOOM_HELP); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." in a variadic macro" +struct TORCH_API ZoomHooksArgs {}; + +TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, ZoomHooksInterface, ZoomHooksArgs); +#define REGISTER_PRIVATEUSE1_HOOKS(clsname) \ + C10_REGISTER_CLASS(PrivateUse1HooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API void initZoomHooks(); +TORCH_API const ZoomHooksInterface& getZoomHooks(); +} // namespace detail +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index c5f81e98906dd4..416a607d5c2622 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -130,7 +130,8 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) { // (e.g. XLA) may be supported by overriding copy_ and _copy_from. bool is_supported_device(Device device) { DeviceType device_type = device.type(); - return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal || device_type == kMPS; + // TODO(Arham): exchange keys + return device_type == kPrivateUse1 || device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal || device_type == kMPS; } } // namespace @@ -288,6 +289,9 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) } else if (iter.device_type(1) == kMPS) { device_type = kMPS; } + else if (iter.device_type(1) == kPrivateUse1) { + device_type = kPrivateUse1; + } // TODO: if we need to, we can also enable this path for quantized tensor if (device_type == kCPU && copy_transpose_valid(self, src) && !self.is_quantized()) { diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 974ad302ca0c86..72336656842368 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -585,8 +585,9 @@ std::tuple mode(const Tensor& self, int64_t dim, bool keepdim) { std::tuple mode_out(const Tensor& self, int64_t dim, bool keepdim, Tensor& values, Tensor& indices) { - TORCH_CHECK(self.device().is_cpu() || self.is_cuda(), - "mode only supports CPU AND CUDA device type, got: ", self.device().type()); + // TODO(Arham): exchange keys + TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_privateuseone(), + "mode only supports CPU, CUDA, and Zoom device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "mode only supports strided layout, got: ", self.layout()); TORCH_CHECK(self.device() == values.device(), diff --git a/aten/src/ATen/native/zoom/AmpKernels.cu b/aten/src/ATen/native/zoom/AmpKernels.cu new file mode 100644 index 00000000000000..14fa799fd6d283 --- /dev/null +++ b/aten/src/ATen/native/zoom/AmpKernels.cu @@ -0,0 +1,252 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace { +// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e, +// to ensure the Cuda math library's isfinite is actually what gets called in +// _amp_non_finite_check_and_unscale_cuda_'s gpu_kernel lambda. 
+// +// isfinite_ensure_cuda_math is defined outside at::native because: +// - A bare call to "isfinite(val)" inside at::native causes nvcc to prefer the unrelated +// Tensor at::native::isfinite(const Tensor&), resulting in an error: +// "no suitable constructor exists to convert from "float" to "at::Tensor"" +// - Unfortunately, the Cuda math library documentation doesn't say how (or if) you can provide a full namespace path +// to ensure that its version of a particular function is invoked. It only shows bare (not-namespaced) +// calls to its routines inside kernel or device functions. +// - "std::isfinite(val)" in the gpu_kernel lambda causes an "unspecified launch failure" at runtime with cuda 9 on Windows. +// +// isfinite_ensure_cuda_math, declared at file scope outside the at::native region, uses isfinite as math library docs +// suggest and allows disambiguated usage in the lambda within the at::native region. +// GPU_LAMBDA is defined as __host__ __device__ (see Loops.cuh), so I need the __host__ keyword or else nvcc complains that +// "calling a __device__ function("isfinite_ensure_cuda_math") from a __host__ __device__ function("operator()") is not allowed." +static __host__ __device__ __forceinline__ int isfinite_ensure_zoom_math(float val) { + return isfinite(val); +} +} + +namespace at::native { + +namespace { +// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_zoom_. +// Handles individual tensors that are acceptable to unscale but not MTA-safe. +void _amp_non_finite_check_and_unscale_zoom_(Tensor& scaled_grad, + Tensor& found_inf, + const Tensor& inv_scale) +{ + // The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_zoom_, so no input checks. + + // It's not obvious gpu_kernel always guards onto its argument. Guarding here just in case. + const OptionalDeviceGuard device_guard(device_of(scaled_grad)); + + // Acts on scaled_grad in place. + auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + iter.dtype(), + "_amp_non_finite_check_and_unscale_zoom", + [&iter, &found_inf, &inv_scale] { + auto* found_inf_ptr = found_inf.mutable_data_ptr(); + auto* inv_scale_ptr = inv_scale.const_data_ptr(); + + using opmath_t = at::opmath_type; + + gpu_kernel(iter, + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t { + auto val = static_cast(val_in); + if (!isfinite_ensure_zoom_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); + }); +} +} // anonymous namespace + + +// Multiplies each tensor in scaled_grads by inv_scale in-place. +// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0. +// Uses multi tensor apply (MTA) to process all MTA-safe tensors. +// +// Args: +// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or NaNs. +// found_inf: A single-element float tensor to which 1.0 will be written if any gradient contain infs/nans. +// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller. +// inv_scale: The inverse of the scale factor by which scaled_grads are currently multiplied. 
+void _amp_foreach_non_finite_check_and_unscale_zoom_(TensorList scaled_grads, + Tensor& found_inf, + const Tensor& inv_scale) +{ + if (scaled_grads.size() == 0) { + return; + } + + TORCH_CHECK(inv_scale.is_privateuseone(), "inv_scale must be a Zoom tensor."); + TORCH_CHECK(found_inf.is_privateuseone(), "found_inf must be a Zoom tensor."); + TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."); + TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); + TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor."); + TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); + + // Ensures client code (GradScaler) filtered scaled_grads by dtype. + check_foreach_api_restrictions(scaled_grads); + + std::vector> tensor_lists; + + // is_non_overlapping_and_dense() is not available in Python. + // GradScaler can't filter for it. We need to filter here. + if (can_use_fast_route(scaled_grads)) { + // Hopefully common case. + // can_use_fast_route is true, which confirms: + // - all scaled_grads are strided + // - all scaled_grads are non overlapping and dense + // - all scaled_grads are on the same device + // - all scaled_grads are of the same dtype + TORCH_CHECK(scaled_grads[0].is_privateuseone(), "scaled_grads must be Zoom tensors."); + // Sets up MTA launch to use scaled_grads as-is. + tensor_lists.emplace_back(scaled_grads.vec()); + } else { + // Hopefully uncommon case. + // can_use_fast_route is an all-or-nothing check. In this path it was false, + // so any of the above confirmations could have gone wrong. + // We filter MTA-safe tensors into an MTA-able list. + // If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel. + // If a tensor is unacceptable, we throw an error to blame GradScaler. + tensor_lists.resize(1); + tensor_lists[0].reserve(scaled_grads.size()); + auto expected_device = scaled_grads[0].device(); + const auto expected_dtype = scaled_grads[0].scalar_type(); + for (const Tensor& t : scaled_grads) { + // Ensures GradScaler filtered scaled_grads by device. + TORCH_CHECK(t.is_privateuseone(), "one of scaled_grads was not a Zoom tensor."); + TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device."); + TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor."); + if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) { + // t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel. + _amp_non_finite_check_and_unscale_zoom_(const_cast(t), + found_inf, + inv_scale); + } else { + tensor_lists[0].push_back(t); + } + } + if (tensor_lists[0].size() == 0) { + return; + } + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + tensor_lists[0][0].scalar_type(), + "_amp_foreach_non_finite_check_and_unscale_zoom", + [&tensor_lists, &found_inf, &inv_scale] { + auto* found_inf_ptr = found_inf.mutable_data_ptr(); + auto* inv_scale_ptr = inv_scale.const_data_ptr(); + + using opmath_t = at::opmath_type; + + // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly. + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { + // There is a slight asymmetry here with the TensorIterator kernel above. + // MTA Functors ensure val comes in as opmath_t rather than scalar_t. 
+ if (!isfinite_ensure_zoom_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); + }); +} + + +// amp_update_scale_zoom_kernel is launched with a single thread to compute the new scale. +// The scale factor is maintained and updated on the GPU to avoid synchronization. +__global__ void amp_update_scale_zoom_kernel(float* current_scale, + int* growth_tracker, + const float* found_inf, + double growth_factor, + double backoff_factor, + int growth_interval) +{ + if (*found_inf) { + *current_scale = (*current_scale)*backoff_factor; + *growth_tracker = 0; + } else { + // Entering this branch means we just carried out a successful step, + // so growth_tracker is incremented before comparing to growth_interval. + auto successful = (*growth_tracker) + 1; + if (successful == growth_interval) { + auto new_scale = static_cast((*current_scale)*growth_factor); + // Do not grow the scale past fp32 bounds to inf. + if (isfinite_ensure_zoom_math(new_scale)) { + *current_scale = new_scale; + } + *growth_tracker = 0; + } else { + *growth_tracker = successful; + } + } +} + + +// _amp_update_scale_zoom asynchronously updates the scale tensor in place. +// +// Args: +// current_scale: A one-element zoom float tensor containing the scale value. +// growth_tracker: A one-element torch.zoom.IntTensor containing the number of recent consecutive unskipped steps. +// found_inf: A one-element zoom float tensor. If > 0, indicates that infs/nans were found by the relevant +// prior _amp_non_finite_check_and_unscale_zoom call, and 0 if no infs/nans were found. +// growth_factor: Multiplier if no infs/NaNs were found (typically slightly > 1). +// backoff_factor: Multiplier if infs/NaNs were found (typically 0.5). +// growth_interval: Number of consecutive unskipped steps that must occur for current_scale to be multiplied by +// growth_factor. 
+// +// Returns: +// current_scale +Tensor& _amp_update_scale_zoom_(Tensor& current_scale, + Tensor& growth_tracker, + const Tensor& found_inf, + double growth_factor, + double backoff_factor, + int64_t growth_interval) +{ + TORCH_CHECK(growth_tracker.is_privateuseone(), "growth_tracker must be a Zoom tensor."); + TORCH_CHECK(current_scale.is_privateuseone(), "current_scale must be a Zoom tensor."); + TORCH_CHECK(found_inf.is_privateuseone(), "found_inf must be a Zoom tensor."); + TORCH_CHECK(growth_tracker.numel() == 1, "growth_tracker must be a 1-element tensor."); + TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor."); + TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); + TORCH_CHECK(growth_tracker.scalar_type() == at::ScalarType::Int, "growth_tracker must be an int tensor."); + TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float, "current_scale must be a float tensor."); + TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); + + amp_update_scale_zoom_kernel<<<1, 1, 0, c10::zoom::getCurrentZoomStream()>>>( + current_scale.mutable_data_ptr(), + growth_tracker.mutable_data_ptr(), + found_inf.const_data_ptr(), + growth_factor, + backoff_factor, + growth_interval); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + return current_scale; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CompareEQKernel.cu b/aten/src/ATen/native/zoom/CompareEQKernel.cu new file mode 100644 index 00000000000000..b8869c0dc86b31 --- /dev/null +++ b/aten/src/ATen/native/zoom/CompareEQKernel.cu @@ -0,0 +1,50 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at::native { namespace { + +enum class EqOpType {EQ, NE}; + +template +struct CompareEqFunctor{ + CompareEqFunctor(EqOpType op): op_(op) {} + const EqOpType op_; + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + if (op_ == EqOpType::EQ) { + return a == b; + } else { //NE + return a != b; + } + + } + }; +} + +C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) { + AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_zoom", AT_WRAP([&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, CompareEqFunctor(op)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +void eq_kernel_zoom(TensorIteratorBase& iter) { + compare_eq_ne_kernel(iter, EqOpType::EQ); +} + +void ne_kernel_zoom(TensorIteratorBase& iter) { + compare_eq_ne_kernel(iter, EqOpType::NE); +} + +REGISTER_PRIVATEUSE1_DISPATCH(eq_stub, &eq_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(ne_stub, &ne_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/CompareKernels.cu b/aten/src/ATen/native/zoom/CompareKernels.cu new file mode 100644 index 00000000000000..21da608a35fc94 --- /dev/null +++ b/aten/src/ATen/native/zoom/CompareKernels.cu @@ -0,0 +1,103 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
+
+namespace at::native { namespace {
+
+enum class OpType {GE, GT, LE, LT};
+
+template <typename scalar_t>
+struct CompareFunctor{
+  constexpr CompareFunctor(OpType op): op_(op) {};
+  OpType op_;
+  __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const {
+    if (op_ == OpType::GE) {
+      return a >= b;
+    } else if (op_ == OpType::GT) {
+      return a > b;
+    } else if (op_ == OpType::LE) {
+      return a <= b;
+    } else { //LT
+      return a < b;
+    }
+  }
+};
+
+// Reflects the comparison operator, so reflect(op)(a, b) == op(b, a)
+OpType reflect(OpType x) {
+  switch (x) {
+    case OpType::GE: return OpType::LE;
+    case OpType::GT: return OpType::LT;
+    case OpType::LE: return OpType::GE;
+    case OpType::LT: return OpType::GT;
+  }
+  TORCH_INTERNAL_ASSERT(false, "Invalid OpType");
+}
+
+} // namespace (anonymous)
+
+template <typename scalar_t>
+void compare_scalar_kernel(TensorIteratorBase &iter, OpType op, scalar_t rhs) {
+  CompareFunctor<scalar_t> f(op);
+  gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t lhs) -> bool {
+    return f(lhs, rhs);
+  });
+}
+
+template <typename scalar_t>
+void compare_kernel_impl(TensorIteratorBase &iter, OpType op) {
+  // If either input is a cpu scalar, perform the equivalent comparison
+  // where the scalar is on the right hand side. This saves us from
+  // generating two otherwise identical kernels with mirrored
+  // arguments.
+  if (iter.is_cpu_scalar(1)) {
+    const scalar_t lhs = iter.scalar_value<scalar_t>(1);
+    iter.remove_operand(1);
+    const DeviceGuard device_guard(iter.device(1));
+    compare_scalar_kernel(iter, reflect(op), lhs);
+  } else if (iter.is_cpu_scalar(2)) {
+    const scalar_t rhs = iter.scalar_value<scalar_t>(2);
+    iter.remove_operand(2);
+    compare_scalar_kernel(iter, op, rhs);
+  } else {
+    CompareFunctor<scalar_t> f(op);
+    gpu_kernel(iter, f);
+  }
+}
+
+C10_NOINLINE void compare_kernel_with_scalars(TensorIteratorBase &iter, OpType op) {
+  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "compare_zoom", [&]() {
+    compare_kernel_impl<scalar_t>(iter, op);
+  });
+}
+
+
+void ge_kernel_zoom(TensorIteratorBase& iter) {
+  compare_kernel_with_scalars(iter, OpType::GE);
+}
+
+void gt_kernel_zoom(TensorIteratorBase& iter) {
+  compare_kernel_with_scalars(iter, OpType::GT);
+}
+
+void le_kernel_zoom(TensorIteratorBase& iter) {
+  compare_kernel_with_scalars(iter, OpType::LE);
+}
+
+void lt_kernel_zoom(TensorIteratorBase& iter) {
+  compare_kernel_with_scalars(iter, OpType::LT);
+}
+
+REGISTER_PRIVATEUSE1_DISPATCH(ge_stub, &ge_kernel_zoom);
+REGISTER_PRIVATEUSE1_DISPATCH(gt_stub, &gt_kernel_zoom);
+REGISTER_PRIVATEUSE1_DISPATCH(le_stub, &le_kernel_zoom);
+REGISTER_PRIVATEUSE1_DISPATCH(lt_stub, &lt_kernel_zoom);
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/zoom/Copy.cu b/aten/src/ATen/native/zoom/Copy.cu
new file mode 100644
index 00000000000000..3415806851f9fd
--- /dev/null
+++ b/aten/src/ATen/native/zoom/Copy.cu
@@ -0,0 +1,393 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include
+#else
+#include
+#endif
+
+#include
+#include
+
+namespace at::native {
+
+void neg_kernel_zoom(TensorIteratorBase &iter);
+void conj_kernel_zoom(TensorIteratorBase &iter);
+
+void float8_copy_kernel_zoom(TensorIteratorBase &iter) {
+  ScalarType dtype = iter.dtype(0);
+  ScalarType other_dtype = iter.dtype(1);
+  if (dtype == kFloat8_e4m3fn) {
+    switch (other_dtype) {
+      case kFloat:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
+          return Float8_e4m3fn(value);
+        });
+        break;
+      case kHalf:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e4m3fn(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e4m3fn(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; }); + break; + } + } else if (dtype == kFloat8_e5m2) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { +#ifdef AT_USE_NV_CVT_INTRINSICS + const auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { +#ifdef AT_USE_NV_CVT_INTRINSICS + const auto x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { +#ifdef AT_USE_NV_CVT_INTRINSICS + const auto x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; }); + break; + } + } else if (dtype == kFloat8_e4m3fnuz) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e4m3fnuz(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e4m3fnuz(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e4m3fnuz(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; }); + break; + } + } else if (dtype == kFloat8_e5m2fnuz) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e5m2fnuz(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e5m2fnuz(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e5m2fnuz(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; }); + break; + } + } else { + TORCH_CHECK(false, "This supposed ot be called only for Float8 types"); + } +} + +// TODO: We probably can use the opaque type trick to avoid creating duplicate +// kernels for equivalent bit lengths +void direct_copy_kernel_zoom(TensorIteratorBase &iter) { + ScalarType dtype = iter.dtype(0); + if (isQIntType(dtype)) { + AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); + }); + } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) { + float8_copy_kernel_zoom(iter); + } else if (isBitsType(dtype)) { + TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting " + "bits types to different bits types. 
Source dtype is ", iter.dtype(1), "target dtype is ", dtype); + AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] { + gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); + }); + } else { + AT_DISPATCH_V2( + dtype, "copy_", AT_WRAP([&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + } +} + +void neg_conj_kernel_zoom(TensorIteratorBase &iter) { + AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_zoom", [&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return -std::conj(x); }); + }); +} + +using namespace at::zoom; + +// device-to-device copy, does type conversion +void copy_device_to_device(TensorIterator& iter, + bool non_blocking, + bool p2p_enabled) { + int64_t numel = iter.numel(); + + // We can memcpy the memory if both tensors have the same type AND both + // tensors are contiguous after dimension coalescing and reordering. + bool same_type = iter.dtype(0) == iter.dtype(1); + bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj(); + bool same_neg = iter.tensor(0).is_neg() == iter.tensor(1).is_neg(); + bool memcpy_eligible = same_type && same_conj && same_neg && iter.is_contiguous(); + + Device dst_device = iter.device(0); + Device src_device = iter.device(1); + + c10::zoom::ZoomGuard device_guard(src_device); + + // We always perform the copy on the source device, using the current stream + // on the source device, and we fully synchronize on both src and dst's + // current streams for completion of the copy. We have to explicitly do this + // for non-contig copies. This mimics the behavior of cross-device + // hipMemcpyAsync on the default stream. + c10::zoom::ZoomStream copy_stream = c10::zoom::getCurrentZoomStream(src_device.index()); + if (src_device != dst_device) { + // This is a cross-device copy on the src current stream and dst current + // stream. We perform a two-way barrier between both devices' streams + // before the copy. This ensures that any write-after-write and + // write-after-read dependencies on the destination side are handled, so + // that no one is operating on the dst memory when we perform the copy. + // src waits on dst barrier (src already waits on src) + ZoomEvent dst_ready; + device_guard.set_device(dst_device); + dst_ready.record(c10::zoom::getCurrentZoomStream(dst_device.index())); + + device_guard.set_device(src_device); + dst_ready.block(copy_stream); + } + + if (memcpy_eligible) { + void *dst = iter.data_ptr(0); + void *src = iter.data_ptr(1); + size_t size = numel * iter.element_size(0); + if (src != dst || src_device != dst_device) { + // Due to bizarre cuda driver intricacies, copies of + // hipMallocAsynced memory between devices that aren't + // peer-to-peer-capable need "hipMemcpyPeerAsync". + // So we let the allocator implement the correct call + // (either hipMemcpyAsync or hipMemcpyPeerAsync) + C10_ZOOM_CHECK(c10::zoom::ZoomCachingAllocator::memcpyAsync( + dst, dst_device.index(), + src, src_device.index(), + size, copy_stream, p2p_enabled)); + } + } else { + if (same_neg) { + if (!same_conj) { + conj_kernel_zoom(iter); + } else { + direct_copy_kernel_zoom(iter); + } + } else { + if (!same_conj) { + neg_conj_kernel_zoom(iter); + } else { + neg_kernel_zoom(iter); + } + } + } + + if (src_device != dst_device) { + // dst waits on src barrier (dst already waits on dst). We cannot + // operate on dst's copy until the copy is complete. 
+ + // Still on src_device, record stream event + ZoomEvent src_ready; + src_ready.record(copy_stream); + + device_guard.set_device(dst_device); + src_ready.block(c10::zoom::getCurrentZoomStream(dst_device.index())); + } + + C10_ZOOM_CHECK(hipGetLastError()); +} + +static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) { + Device dst_device = iter.device(0); + Device src_device = iter.device(1); + + if (dst_device == src_device) { + // We never require temporaries for copies on the same GPU. + TORCH_INTERNAL_ASSERT(dst_device.is_privateuseone() && src_device.is_privateuseone()); + return false; + } + + bool same_dtype = iter.dtype(0) == iter.dtype(1); + if (same_dtype && iter.is_contiguous()) { + // Contiguous same-dtype copies can always use hipMemcpyAsync + return false; + } else if (dst_device.is_privateuseone() && src_device.is_privateuseone()) { + // Copies between GPUs can use the copy kernel if P2P is supported + return !p2p_enabled; + } else { + // The remaining cases require temporaries. For example, this includes + // non-contiguous copies between CPU and GPU. + return true; + } +} + +static bool maybe_enable_p2p_access(Device dst_device, Device src_device) { + if (dst_device.is_cpu() || src_device.is_cpu()) { + return false; + } + return at::zoom::get_p2p_access(src_device.index(), dst_device.index()); +} + +static void copy_kernel_zoom(TensorIterator& iter, bool non_blocking) { + TORCH_CHECK(iter.ntensors() == 2); + + Device dst_device = iter.device(0); + Device src_device = iter.device(1); + + // Enable p2p access between devices. (No-op if it involves the CPU) + bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device); + + if (copy_requires_temporaries(iter, p2p_enabled)) { + // NB: this involves recursive calls to copy. Be careful that those copies + // don't require temporaries or you will cause an infinite recursion! + auto& dst = iter.tensor(0); + Tensor dst_contig; + Tensor src_contig; + + // If non_blocking is true - type conversions are performed on the GPU + // For blocking transfers conversions are performed on CPU to avoid allocating + // extra GPU memory + // for GPU-GPU transfers conversions are performed on the source device + auto conversion_device = non_blocking ? DeviceType::PrivateUse1 : kCPU; + if (iter.device_type(1) == conversion_device) { + dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous(); + } else { + bool same_type = iter.dtype(0) == iter.dtype(1); + dst_contig = (dst.is_contiguous() && same_type) ? 
dst : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT); + src_contig = iter.tensor(1).expand_as(dst).contiguous(); + } + + // propagate the correct conjugate bit + dst_contig._set_conj(dst.is_conj()); + src_contig._set_conj(iter.tensor(1).is_conj()); + + dst_contig._set_neg(dst.is_neg()); + src_contig._set_neg(iter.tensor(1).is_neg()); + + // perform a same-dtype copy on contiguous tensors + TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes())); + TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type()); + dst_contig.copy_(src_contig, non_blocking); + + // if necessary, copy back into dst + if (!dst_contig.is_same(dst)) { + TORCH_INTERNAL_ASSERT(dst_contig.device() == dst.device()); + dst.copy_(dst_contig, non_blocking); + } + return; + } + + // Copy on GPU (or between GPUs) + if (dst_device.is_privateuseone() && src_device.is_privateuseone()) { + copy_device_to_device(iter, non_blocking, p2p_enabled); + return; + } + + // Copy between CPU and GPU + c10::zoom::OptionalZoomGuard device_guard; + hipMemcpyKind kind; + if (dst_device.is_privateuseone() && src_device.is_cpu()) { + device_guard.set_device(dst_device); + kind = hipMemcpyHostToDevice; + } else if (dst_device.is_cpu() && src_device.is_privateuseone()) { + device_guard.set_device(src_device); + kind = hipMemcpyDeviceToHost; + } else { + TORCH_INTERNAL_ASSERT(false, "unsupported devices in GPU copy_()"); + } + + void* dst = iter.data_ptr(0); + void* src = iter.data_ptr(1); + int64_t nbytes = iter.numel() * iter.element_size(0); + c10::zoom::ZoomStream stream = c10::zoom::getCurrentZoomStream(); + + if (non_blocking) { + C10_ZOOM_CHECK(hipMemcpyAsync(dst, src, nbytes, kind, stream)); + // we use both the storage context and the tensor data pointer as the key + // for the caching host allocator. This allows us to better attribute the + // events to the original tensor allocation correctly. The cases we seek to + // handle are: + + // 1: a user can pass a pinned memory tensor with an alternative + // context, for example if allocating memory directly from the pinned memory + // allocator and constructing a tensor with torch::from_blob. + + // 2: a user can pass a tensor with a different base pointer to the original + // allocation (via slicing). + const auto& dst_tensor = iter.tensor(0); + const auto& src_tensor = iter.tensor(1); + const auto& host_tensor = (dst_device == kCPU ? dst_tensor : src_tensor); + auto* ptr = (dst_device == kCPU ? dst : src); + auto* ctx = host_tensor.storage().data_ptr().get_context(); + // TODO: warn on the return value. 
+ CachingHostAllocator_recordEvent(ptr, ctx, stream); + + } else { + c10::zoom::memcpy_and_sync(dst, src, nbytes, kind, stream); + } + + if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) { + iter.tensor(0).conj_physical_(); + } + if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) { + iter.tensor(0).neg_(); + } +} + + REGISTER_PRIVATEUSE1_DISPATCH(copy_stub, ©_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Copy.h b/aten/src/ATen/native/zoom/Copy.h new file mode 100644 index 00000000000000..d7a7243b36dfdf --- /dev/null +++ b/aten/src/ATen/native/zoom/Copy.h @@ -0,0 +1,11 @@ +#pragma once + +namespace at { +struct TensorIteratorBase; + + namespace native { + + void direct_copy_kernel_zoom(TensorIteratorBase &iter); + + } +} \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Equal.cpp b/aten/src/ATen/native/zoom/Equal.cpp new file mode 100644 index 00000000000000..00f6acf51d0b66 --- /dev/null +++ b/aten/src/ATen/native/zoom/Equal.cpp @@ -0,0 +1,49 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#else +#include +#include +#endif + +namespace at::native { + +bool zoom_equal(const Tensor& self, const Tensor &src) { + if (!at::namedinference::are_names_equal( + self.unsafeGetTensorImpl(), src.unsafeGetTensorImpl())) { + return false; + } + at::NoNamesGuard guard; + TORCH_CHECK(self.device() == src.device(), "Cannot compare two tensors on " + "different devices. Got: ", self.device(), " and ", src.device()); + if (self.sizes() != src.sizes()) { + return false; + } + if (self.numel() == 0) { + return true; + } + + // This is the same optimization done in the cpu_equal. Since the flags like neg/conj should be already handled outside the + // cuda_equal, it should be safe to have the following fast path by + // ensuring the storage and strides exactly the same. + if (self.is_alias_of(src) + && self.storage_offset() == src.storage_offset() + && self.dtype() == src.dtype() + && self.is_contiguous() == src.is_contiguous() + && self.strides().equals(src.strides()) + // Extra checks to ensure the safety in case cuda_equal is directly called in C++. 
+ && self.layout() == src.layout() + && self.is_neg() == src.is_neg() + && self.is_conj() == src.is_conj()) { + return true; + } + + return at::eq(self, src).all().item().to(); +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/FillKernel.cu b/aten/src/ATen/native/zoom/FillKernel.cu new file mode 100644 index 00000000000000..24c0a00c54726b --- /dev/null +++ b/aten/src/ATen/native/zoom/FillKernel.cu @@ -0,0 +1,30 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +template +struct FillFunctor { + FillFunctor(scalar_t v): value(v) {} + __device__ __forceinline__ scalar_t operator() () const { + return value; + } + private: + scalar_t value; +}; + +void fill_kernel_zoom(TensorIterator& iter, const Scalar& value) { + AT_DISPATCH_V2(iter.dtype(), "fill_zoom", AT_WRAP([&]() { + gpu_kernel(iter, FillFunctor(value.to())); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +REGISTER_PRIVATEUSE1_DISPATCH(fill_stub, &fill_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/MiscUtils.h b/aten/src/ATen/native/zoom/MiscUtils.h new file mode 100644 index 00000000000000..257c488bd7e98e --- /dev/null +++ b/aten/src/ATen/native/zoom/MiscUtils.h @@ -0,0 +1,32 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once +#include +#include +#include + +namespace at { +namespace native { + +static inline int zoom_int_cast(int64_t value, const char* varname) { + auto result = static_cast(value); + TORCH_CHECK(static_cast(result) == value, + "zoom_int_cast: The value of ", varname, "(", (long long)value, + ") is too large to fit into a int (", sizeof(int), " bytes)"); + return result; +} + +// Creates an array of size elements of type T, backed by pinned memory +// wrapped in a Storage +template +static inline Storage pin_memory(int64_t size) { + auto* allocator = zoom::getPinnedMemoryAllocator(); + int64_t adjusted_size = size * sizeof(T); + return Storage( + Storage::use_byte_size_t(), + adjusted_size, + allocator, + /*resizable=*/false); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/zoom/Nonzero.cu b/aten/src/ATen/native/zoom/Nonzero.cu new file mode 100644 index 00000000000000..d735795bcc1720 --- /dev/null +++ b/aten/src/ATen/native/zoom/Nonzero.cu @@ -0,0 +1,130 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include //for MAX_DIMS +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + + +namespace at::native { + +namespace{ +template +struct NonZeroOp +{ + __host__ __device__ __forceinline__ bool operator()(const T& a) const { + return (a!=T(0)); + } +}; + +//TODO: actually support int64_t index_t +template +struct TensorDims { + index_t sizes[MAX_DIMS]; +}; + +template +__global__ void write_indices( + int64_t* inp, + TensorDims dims, + int ndim, + index_t n) { + auto index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < n) { + index_t div = 1; + int64_t idx_flat = inp[index]; +#pragma unroll + for (int dim = MAX_DIMS; dim >= 0; dim--) { + if (dim > ndim - 1) + continue; + auto dim_size = dims.sizes[dim]; + inp[index + dim * n] = (idx_flat / div) % dim_size; + div *= dim_size; + } + } +} + +} //anonymous namespace + +template +void nonzero_zoom_out_impl(const 
Tensor& self, Tensor& out){ + Tensor self_ = self.contiguous(); + int N = self_.numel(); + const hipStream_t stream = c10::zoom::getCurrentZoomStream(); +// compute number of nonzero elements + size_t temp_storage_bytes=0; + auto& allocator = *c10::zoom::ZoomCachingAllocator::get(); + auto num_nonzeros = allocator.allocate(sizeof(int)); + hipcub::TransformInputIterator, const scalar_t*> itr(self_.const_data_ptr(), NonZeroOp()); + hipcub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); + auto temp_storage = allocator.allocate(temp_storage_bytes); + hipcub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); + int num_nonzeros_h; + c10::zoom::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), hipMemcpyDeviceToHost, stream); + //expected output size is num_nonzeros x ndim + //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) + //we are able to directly use passed output with this size and strides, and we can also (per contract) + //resize passed output with incorrect sizes anyway we want. + //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. + bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); + at::Tensor out_temp = need_to_copy ? + Tensor(at::detail::empty_zoom({self.dim(), num_nonzeros_h}, out.options())) : + out.resize_({self.dim(), num_nonzeros_h}); + //Scalars are expected to produce output of size (1,0), so we can't write to it + if (self.dim() > 0) { + hipcub::CountingInputIterator counting_itr(0); + temp_storage_bytes = 0; + hipcub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, + out_temp.mutable_data_ptr(), (int*)num_nonzeros.get(), N, stream); + temp_storage = allocator.allocate(temp_storage_bytes); + hipcub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr, + out_temp.mutable_data_ptr(), (int*)num_nonzeros.get(), N, stream); + if (num_nonzeros_h > 0 && self.dim() > 1){ + TensorDims dims; + for (int i=0; i>>(out_temp.mutable_data_ptr(), + dims, self.dim(), num_nonzeros_h); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } + if (need_to_copy) { + out.copy_(out_temp.t()); + } else { + //transpose out so it is correct size + Tensor out_ = out_temp.t(); + out.set_(out_); + } +} + +Tensor& nonzero_out_zoom(const Tensor& self, Tensor& out){ + TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \ + See https://github.com/pytorch/pytorch/issues/51871"); + TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype()); + TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ", + out.device(), " and self on ", self.device()); + TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions"); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "nonzero_zoom", + [&] {nonzero_zoom_out_impl(self, out);}); + return out; +} + +Tensor nonzero_zoom(const Tensor& self){ + Tensor out = at::detail::empty_zoom({0}, self.options().dtype(kLong)); + return at::native::nonzero_out_zoom(self, 
out); +} +} //namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Resize.cpp b/aten/src/ATen/native/zoom/Resize.cpp new file mode 100644 index 00000000000000..da9a11971c86f3 --- /dev/null +++ b/aten/src/ATen/native/zoom/Resize.cpp @@ -0,0 +1,69 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +namespace at::native { + +void resize_bytes_zoom(StorageImpl* storage, size_t size_bytes) { + TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable"); + auto allocator = storage->allocator(); + TORCH_CHECK(allocator != nullptr, "Trying to resize storage without an allocator"); + + c10::Device device = storage->device(); + + if (size_bytes == 0) { + storage->set_data_ptr_noswap(at::DataPtr(nullptr, device)); + storage->set_nbytes(0); + return; + } + + c10::zoom::ZoomGuard guard(device.index()); + at::DataPtr data = allocator->allocate(size_bytes); + if (storage->data_ptr()) { + at::globalContext().lazyInitPrivateUse1(); + + C10_ZOOM_CHECK( + hipMemcpyAsync( + data.get(), + storage->data(), + std::min(storage->nbytes(), size_bytes), + hipMemcpyDeviceToDevice, + c10::zoom::getCurrentZoomStream())); + } + + // Destructively overwrite data_ptr + storage->set_data_ptr_noswap(std::move(data)); + storage->set_nbytes(size_bytes); +} + +const Tensor& resize_zoom_( + const Tensor& self, + IntArrayRef size, + std::optional optional_memory_format) { + if (self.has_names()) { + return resize_named_tensor_(self, size, optional_memory_format); + } + auto* self_ = self.unsafeGetTensorImpl(); + int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0; + resize_impl_zoom_(self_, size, /*strides=*/c10::nullopt); + if (optional_memory_format.has_value()) { + auto memory_format = + optional_memory_format.value(); + TORCH_CHECK( + memory_format != MemoryFormat::Preserve, + "Unsupported memory format", + memory_format); + self_->empty_tensor_restride(memory_format); + } + // See Note [Enabling Deterministic Operations] + if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) { + at::native::fill_resize_deterministic_(self, old_storage_nbytes); + } + return self; +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Resize.h b/aten/src/ATen/native/zoom/Resize.h new file mode 100644 index 00000000000000..01c71e3fe861ab --- /dev/null +++ b/aten/src/ATen/native/zoom/Resize.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include + +namespace at { namespace native { + +TORCH_ZOOM_API void resize_bytes_zoom(StorageImpl* storage, size_t size_bytes); + +static inline void maybe_resize_storage_zoom(TensorImpl* self, size_t new_size_bytes) { + // It does not make sense to try to resize a storage + // to hold 0 elements, and this can break + // if storage_offset is positive but + // new_size is 0, so just bail in that case + // (same comment is in Resize.h) + if (self->numel() == 0) { + return; + } + + const Storage &storage = self->unsafe_storage(); + TORCH_CHECK(storage, "Tensor: invalid null storage"); + if (new_size_bytes > storage.nbytes()) { + resize_bytes_zoom(storage.unsafeGetStorageImpl(), new_size_bytes); + } +} + +inline TensorImpl* resize_impl_zoom_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride, + bool device_guard = true) { + if (self->sizes() == size && (!stride || self->strides() == stride)) { + return self; + } 
+ + // NB: We don't need to hold the device guard when calling from TH + c10::zoom::OptionalZoomGuard guard; + if (device_guard) { + guard.set_index(self->storage().device().index()); + } + + const auto itemsize = self->dtype().itemsize(); + const auto storage_offset = self->storage_offset(); + size_t storage_size = 1; + if (stride) { + self->set_sizes_and_strides(size, *stride); + storage_size = at::detail::computeStorageNbytes( + size, *stride, itemsize, storage_offset); + } else { + self->set_sizes_contiguous(size); + storage_size = at::detail::computeStorageNbytesContiguous( + size, itemsize, storage_offset); + } + maybe_resize_storage_zoom(self, storage_size); + + return self; +} + +}} \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/TensorCompare.cpp b/aten/src/ATen/native/zoom/TensorCompare.cpp new file mode 100644 index 00000000000000..21847fa0b41229 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorCompare.cpp @@ -0,0 +1,23 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +namespace at::native { + +namespace { + +// Composite op implementation for simplicity. This materializes the cross product of elements and test elements, +// so it is not very memory efficient, but it is fast on CUDA. +void isin_default_kernel_gpu( + const Tensor& elements, const Tensor& test_elements, bool invert, const Tensor& out) { + std::vector bc_shape(elements.dim(), 1); + bc_shape.push_back(-1); + out.copy_(invert ? elements.unsqueeze(-1).ne(test_elements.view(bc_shape)).all(-1) + : elements.unsqueeze(-1).eq(test_elements.view(bc_shape)).any(-1)); +} + +} // anonymous namespace + +REGISTER_PRIVATEUSE1_DISPATCH(isin_default_stub, &isin_default_kernel_gpu); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/TensorCompare.cu b/aten/src/ATen/native/zoom/TensorCompare.cu new file mode 100644 index 00000000000000..e92d058c9b7222 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorCompare.cu @@ -0,0 +1,133 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + + +namespace at::native { + +namespace { + +void where_kernel_impl(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_zoom", [&] { + gpu_kernel( + iter, + [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t { + return cond_val ? 
self_val : other_val; + }); + }); +} + +void isposinf_kernel_impl(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_zoom", [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA (scalar_t a) -> bool { return a == std::numeric_limits::infinity(); } + ); + }); +} + +void isneginf_kernel_impl(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_zoom", [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA (scalar_t a) -> bool { return a == -std::numeric_limits::infinity(); } + ); + }); +} + +void clamp_kernel_impl(TensorIteratorBase& iter) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_zoom", [&] { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t v, scalar_t lower, scalar_t upper) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (at::_isnan(v)) { + return v; + } if (at::_isnan(lower)) { + return lower; + } if (at::_isnan(upper)) { + return upper; + } else { + return ::min(::max(v, lower), upper); + } + }); + }); +} + +void inline launch_clamp_scalar(TensorIteratorBase& iter, Scalar lim0, Scalar lim1, at::native::detail::ClampLimits minmax){ + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_zoom", [&] { + using opmath_t = at::opmath_type; + auto lim0_val = lim0.to(); + auto lim1_val = lim1.to(); + + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(static_cast(v))) { + return v; + } else if (minmax==at::native::detail::ClampLimits::Min){ + return ::max(static_cast(v), lim0_val); + } else if (minmax==at::native::detail::ClampLimits::Max){ + return ::min(static_cast(v), lim0_val); + } else { + return ::min(::max(static_cast(v), lim0_val), lim1_val); + } + }); + }); +} + + +void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min, const Scalar& max) { + launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax); +} + +void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min) { + launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min); +} + +void clamp_max_scalar_kernel_impl(TensorIteratorBase& iter, Scalar max) { + launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max); +} + +} // anonymous namespace + + +REGISTER_PRIVATEUSE1_DISPATCH(where_kernel, &where_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(isposinf_stub, &isposinf_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(isneginf_stub, &isneginf_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_stub, &clamp_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl); + +template +__global__ void _assert_async_zoom_kernel(const scalar_t* input) { + ZOOM_KERNEL_ASSERT(input[0] != 0); +} + +__global__ void _assert_async_zoom_kernel(const c10::complex* input) { + ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +} +__global__ void _assert_async_zoom_kernel(const c10::complex* input) { + ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +} + +void _assert_async_zoom(const Tensor& self_tensor) { + const TensorBase &self = get_tensor_base(self_tensor); + auto n = self.numel(); + TORCH_CHECK(n != 0, "Boolean value of Tensor with no 
values is ambiguous"); + TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous"); + auto stream = c10::zoom::getCurrentZoomStream(); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_assert_async_zoom", [&] { + _assert_async_zoom_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +// TODO (tmanlaibaatar) Ignore assert msg for now +void _assert_async_msg_zoom(const Tensor& self_tensor, c10::string_view assert_msg) { + _assert_async_zoom(self_tensor); +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/TensorFactories.cu b/aten/src/ATen/native/zoom/TensorFactories.cu new file mode 100644 index 00000000000000..7cf9b0d7ec2417 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorFactories.cu @@ -0,0 +1,396 @@ + +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#include +#include +#include + +namespace at::native { + +Tensor& eye_out_zoom(int64_t n, Tensor& result) { + // the default value of `m` equals to `n` + return at::native::eye_out_zoom(n, n, result); +} + +Tensor& eye_out_zoom(int64_t n, int64_t m, Tensor& result) { + TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); + + result.resize_({n, m}); + result.zero_(); + + int64_t sz = std::min(n, m); + int64_t stride = result.stride(0) + result.stride(1); + + Tensor diag = result.as_strided({sz}, {stride}); + diag.fill_(1); + return result; +} + +Tensor empty_zoom(IntArrayRef size, std::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + Tensor result = at::detail::zoom_empty_memory_format(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); + // See Note [Enabling Deterministic Operations] + if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) { + fill_empty_deterministic_(result); + } + return result; +} + +Tensor _efficientzerotensor_zoom(IntArrayRef size, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory) { + auto device_ = device_or_default(device); + if (!device_.has_index()) { + device_.set_index(c10::zoom::current_device()); + } + auto allocator = at::native::ZeroTensorAllocator(device_); + auto dtype_ = dtype_or_default(dtype); + auto zero_ks = at::DispatchKeySet(c10::DispatchKey::PrivateUse1) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor); + auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt); + return out; +} + + +Tensor empty_strided_zoom(IntArrayRef size, IntArrayRef stride, std::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + Tensor result = at::detail::zoom_empty_strided(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); + // See Note [Enabling Deterministic Operations] + if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) { + fill_empty_deterministic_(result); + } + return result; +} + 
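The `eye_out_zoom` overloads above build the identity without a dedicated kernel: zero-fill the result, then `fill_(1)` on an `as_strided` view of the diagonal taken with stride `stride(0) + stride(1)`. The standalone check below (illustration only, arbitrary sizes) shows why that stride walks exactly the main diagonal of a contiguous row-major result.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Flat offset of element (i, j) in a contiguous n x m tensor with strides {m, 1}.
static int64_t flat_offset(int64_t i, int64_t j, int64_t m) { return i * m + j; }

int main() {
  const int64_t n = 3, m = 5;
  const int64_t diag_stride = m + 1;  // == stride(0) + stride(1) for this layout
  for (int64_t k = 0; k < std::min(n, m); ++k) {
    // The k-th element of the strided diagonal view is element (k, k) of result.
    assert(k * diag_stride == flat_offset(k, k, m));
  }
  return 0;
}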
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +namespace { +// To find the max integer that does not exceed the root of an int64_t variable, +// we could use a loop to test one bit at a time, which takes up to 31 +// iterations. This would give the accurate result, but is relatively slow and +// is an overkill for most cases where double's precision suffice. +// +// If we directly use sqrt to calculate the root, the conversion from int64_t +// to double would lose 11 bits precision. +// +// The following solution uses sqrt directly for most cases, and would only +// special handle it if there is indeed precision loss. +__device__ +inline int64_t resolve_root_int( + int64_t b, int64_t cX4, int64_t x, int32_t sign) { + int64_t bXb_cX4 = b*b - cX4; + // potential precision loss could occur here when casting int64_t (63 bits + // precision) to double (52 bits precision) + double sr = ::sqrt((double)bXb_cX4); + int64_t res = ::__double2ll_rd((-b + sign * sr)/2); + + // have to cast double to int64_t, otherwise it would only compare up to the + // precision of a double variable, ignoring the precision loss + if (bXb_cX4 != (int64_t) (sr * sr)) { + // handle precision loss by using binary search + int64_t llsr = ::__double2ll_rd(sr); + // Use the following math to reduce search space. + // Suppose z is the accurate result of sqrt(bXb_cX4) without precision loss + // let d = abs(bXb_cX4 - llsr * llsr), then we have: + // z = sqrt(bXb_cX4) <= sqrt(llsr * llsr + d) <= llsr + sqrt(d) + // z = sqrt(bXb_cX4) >= sqrt(llsr * llsr - d) >= llsr - sqrt(d) + // Hence, it is sufficient to search range [llsr - sqrt(d), llsr + sqrt(d)). + // And the true value of row would also be with in range, + // [res - sqrt(d), res + sqrt(d) + 1) + // as the denominator would only reduce the precision penalty. + int64_t diff = + ::__double2ll_ru(::sqrt(::fabs((double)(bXb_cX4 - llsr * llsr)))); + // l never exceeds (could equal to) the target row index + auto l = res > diff ? res - diff : 0; + // r is always larger than the target row index + auto r = res + diff + 1; + + // binary search for the correct answer + x <<= 1; // the loop always compares with 2x, so do it once here + while (l + 1 < r) { + auto m = (l + r) >> 1; + // for tril: + // b = 2f - 1, sign = 1, hence (2f + m - 1) * m / 2 + // for triu: + // b = -2f - 1, sign = -1, hence (2f - m + 1) * m / 2 + if (sign * (b + m) * m > x) { + r = m; + } else { + l = m; + } + } + res = l; + } + + return res; +} + +// f: the number of elements in the first row of the trapezoid. +// x: the index of the target coordinates ordered by row and then column. +// +// View the tril as a top trapezoid stacked on a bottom rectangle. Assume x +// corresponds to the coordinate (row, col) in the trapezoid, where the row and +// the col both start from 0, then we have: +// +// (f + f + row - 1) * row / 2 <= x [1] +// (f + f + row) * (row + 1) / 2 > x [2] +// +// Therefore, row is the maximum integer satisfying the following inequality: +// +// (row + 2f - 1)row <= 2x +// row^2 + (2f-1)row - 2x <= 0. [3] +// +// Based on inequality [3], we have the following coefficients for formula of +// root: +// a = 1 +// b = 2f - 1 +// c = -2x +// There are two roots, and we should use the largest integer that does not +// exceed the root on the right. 
Intuitively, it is because: +// i) the valid solution range of row is between two roots, as it is <= 0; +// ii) as we count in more rows, the total # of elements should always +// increase, hence so does the left-hand side row^2 + (2f-1)row - 2x. +// Therefore, the valid range of row lies in between the nadir point and +// the larger root on the right. +// Full proof can be derived from inequality [2]. So, we calculate the result +// coordinate as: +// +// row = floor((-b + sqrt(b^2 - 4c)) / 2) +// col = x - (f + f + row - 1) * row / 2 +__device__ +inline void get_coordinate_in_tril_trapezoid( + int64_t f, int64_t x, int64_t & row, int64_t & col) { + f <<= 1; // all statements use 2f, so only calculate it once here. + auto b = f - 1; + auto cX4 = - (x << 3); // 4 * c = 4 * (-2x) = -8x; + row = resolve_root_int(b, cX4, x, 1); + col = x - ((f + row - 1) * row >> 1); +} + +// f: the number of elements in the first row of the bottom trapezoid. +// x: the index of the target coordinates ordered by row and then column. +// +// View the triu as a top rectangle stacked on a bottom trapezoid, where the +// trapezoid is upside down. Assume x corresponds to the coordinate (row, col) +// in the bottom trapezoid, where the row and the col start from 0, then we +// have: +// +// (f + f - row + 1) * row / 2 <= x [1] +// (f + f - row) * (row + 1) / 2 > x [2] +// +// Therefore, row is the maximum integer satisfying the following inequality: +// +// (-row + 2f + 1)row <= 2x +// row^2 - (2f+1)row + 2x >= 0. [3] +// +// Based on inequality [3], we have the following coefficients for formula of +// root: +// a = 1 +// b = -1 - 2f +// c = 2x +// There are two roots, and we should use the largest integer that does not +// exceed the root on the left. Intuitively, it is because: +// i) the valid solution range of row is outside of the two roots, as it is < +// > 0; +// ii) as we count in more rows, the total # of elements should always +// increase, hence so does the left-hand side row^2 - (2f+1)row + 2x. +// Therefore, the valid range of row lies to the left of the smaller root +// on the left. +// Full proof can be derived from inequality [2]. So, we calculate the result +// coordinate as: +// +// row = floor((-b - sqrt(b^2 - 4c)) / 2) +// col = x - (f + f - row + 1) * row / 2 +__device__ +inline void get_coordinate_in_triu_trapezoid( + int64_t f, int64_t x, int64_t & row, int64_t & col) { + f <<= 1; // all statements use 2f, so only calculate it once here. 
+ auto b = -1 - f; + auto cX4 = x << 3; // 4 * c = 4 * (2x) = 8x; + row = resolve_root_int(b, cX4, x, -1); + col = x - ((f - row + 1) * row >> 1) + row; +} + +} // namespace + +template +__global__ +C10_LAUNCH_BOUNDS_1(512) +void tril_indices_kernel(scalar_t * tensor, + int64_t row_offset, + int64_t m_first_row, + int64_t col, + int64_t trapezoid_size, + int64_t tril_size) { + int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + + if (linear_index < tril_size) { + int64_t r, c; + if (linear_index < trapezoid_size) { + // the coordinate is within the top trapezoid + get_coordinate_in_tril_trapezoid(m_first_row, linear_index, r, c); + } else { + // the coordinate falls in the bottom rectangle + auto surplus = linear_index - trapezoid_size; + // add the height of trapezoid: m_last_row (col) - m_first_row + 1 + r = surplus / col + col - m_first_row + 1; + c = surplus % col; + } + r += row_offset; + + tensor[linear_index] = r; + tensor[linear_index + tril_size] = c; + } +} + +// Some Large test cases for the fallback binary search path is disabled by +// default to speed up CI tests and to avoid OOM error. When modifying the +// implementation, please enable them in test/test_cuda.py and make sure they +// pass on your local server. +Tensor tril_indices_zoom( + int64_t row, int64_t col, int64_t offset, std::optional dtype_opt, + std::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); + + auto tril_size = get_tril_size(row, col, offset); + auto tensor = empty_zoom({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); + + if (tril_size > 0) { + auto m_first_row = offset > 0 ? + std::min(col, 1 + offset) : // upper bounded by col + row + offset > 0; // either 0 or 1 + auto trapezoid_row_offset = std::max(0, -offset); + auto rectangle_row_offset = trapezoid_row_offset + col - m_first_row + 1; + int64_t rectangle_size = 0; + if (rectangle_row_offset < row) { + rectangle_size = (row - rectangle_row_offset) * col; + } + + dim3 dim_block = zoom::getApplyBlock(); + dim3 dim_grid; + // using tril_size instead of tensor.numel(), as each thread takes care of + // two elements in the tensor. + TORCH_CHECK( + zoom::getApplyGrid(tril_size, dim_grid, tensor.get_device()), + "unable to get dim grid"); + + AT_DISPATCH_INDEX_TYPES(tensor.scalar_type(), "tril_indices_zoom", [&] { + hipLaunchKernelGGL(( tril_indices_kernel), + dim3(dim_grid), dim3(dim_block), 0, c10::zoom::getCurrentZoomStream(), + tensor.mutable_data_ptr(), + trapezoid_row_offset, + m_first_row, + col, + tril_size - rectangle_size, + tril_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } + + return tensor; +} + +template +__global__ +void triu_indices_kernel(scalar_t * tensor, + int64_t col_offset, + int64_t m_first_row, + int64_t col, + int64_t rectangle_size, + int64_t triu_size) { + int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + + if (linear_index < triu_size) { + int64_t r, c; + if (linear_index < rectangle_size) { + // the coordinate is within the top rectangle + r = linear_index / col; + c = linear_index % col; + } else { + // the coordinate falls in the bottom trapezoid + get_coordinate_in_triu_trapezoid( + m_first_row, linear_index - rectangle_size, r, c); + r += rectangle_size / col; + } + + c += col_offset; + tensor[linear_index] = r; + tensor[linear_index + triu_size] = c; + } +} + +// Some Large test cases for the fallback binary search path is disabled by +// default to speed up CI tests and to avoid OOM error. 
When modifying the +// implementation, please enable them in test/test_cuda.py and make sure they +// pass on your local server. +Tensor triu_indices_zoom( + int64_t row, int64_t col, int64_t offset, std::optional dtype_opt, + std::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); + + auto triu_size = row * col - get_tril_size(row, col, offset - 1); + auto tensor = empty_zoom({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); + + if (triu_size > 0) { + // # of triu elements in the first row + auto m_first_row = offset > 0 ? + std::max(col - offset, 0) : // upper bounded by col + col; + + // size of the top rectangle + int64_t rectangle_size = 0; + if (offset < 0) { + rectangle_size = std::min(row, -offset) * col; + } + + dim3 dim_block = zoom::getApplyBlock(); + dim3 dim_grid; + + // using triu_size instead of tensor.numel(), as each thread takes care of + // two elements in the tensor. + TORCH_CHECK( + zoom::getApplyGrid(triu_size, dim_grid, tensor.get_device()), + "unable to get dim grid"); + + AT_DISPATCH_INDEX_TYPES(tensor.scalar_type(), "triu_indices_zoom", [&] { + hipLaunchKernelGGL(( triu_indices_kernel), + dim3(dim_grid), dim3(dim_block), 0, c10::zoom::getCurrentZoomStream(), + tensor.mutable_data_ptr(), + std::max(0, offset), + m_first_row, + col, + rectangle_size, + triu_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } + + return tensor; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/TensorShape.cu b/aten/src/ATen/native/zoom/TensorShape.cu new file mode 100644 index 00000000000000..5fad25d8a76179 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorShape.cu @@ -0,0 +1,833 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace at::native { + +namespace detail { + +// NOTE [CUDA fast path for split_with_sizes_copy.out] +// split_with_sizes_copy.out for contiguous operands has the following +// properties: +// - Each src split consists of multiple chunks that are separated by a fixed +// stride. The number of chunks and the strides are the same across all src +// splits. +// - Each dst split is the concatenation of the chunks in its corresponding src +// splits. +// - The sizes of chunks vary across splits. +// - A (src, dst) chunk pair is not guaranteed to have the +// same alignment. +// +// The following strategies are employed to optimize for this workload: +// - The entire workload is fused into a single kernel to maximize I/O +// throughput and minimize wave quantization. +// - To account for both small and large chunk sizes, a "jagged grid" is used. +// Each chunk is processed by one or more blocks depending on its size. +// - Within each chunk, the region in which writes can be vectorized is +// identified. Within this region, writes are always vectorized and reads are +// oppurtunistically vectorized. 
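To make the "jagged grid" described in the note above concrete, the sketch below derives a per-split block assignment from per-split chunk sizes by ceil-dividing each chunk by BLOCK_SIZE * BYTES_PER_THREAD, which is how the launcher later in this file counts blocks. The sizes are made up, and the index arrays here are only one way to materialize the mapping that the kernel receives as `block_idx_to_split_idx` and `blocks_cumsums`; gridDim.y then strides over chunks.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t bytes_per_block = 128 * 16;  // BLOCK_SIZE * BYTES_PER_THREAD
  // Bytes of one chunk for each of three hypothetical splits.
  const std::vector<int64_t> split_chunk_sizes = {100, 5000, 64};

  std::vector<int64_t> blocks_cumsums = {0};
  std::vector<int64_t> block_idx_to_split_idx;
  for (size_t s = 0; s < split_chunk_sizes.size(); ++s) {
    // Small chunks get one block, large chunks get several, so both coexist
    // in one launch without rounding every split up to the largest chunk.
    const int64_t blocks =
        (split_chunk_sizes[s] + bytes_per_block - 1) / bytes_per_block;
    blocks_cumsums.push_back(blocks_cumsums.back() + blocks);
    block_idx_to_split_idx.insert(block_idx_to_split_idx.end(), blocks, (int64_t)s);
  }
  // Here the splits receive 1, 3 and 1 blocks, so block_idx_to_split_idx is
  // {0, 1, 1, 1, 2} and the first grid dimension is 5.
  std::printf("blocks in x-dimension: %lld\n", (long long)blocks_cumsums.back());
  return 0;
}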
+static constexpr int64_t BLOCK_SIZE = 128; +static constexpr int64_t BYTES_PER_THREAD = 16; +static constexpr int64_t BYTES_PER_BLOCK = BYTES_PER_THREAD * BLOCK_SIZE; + +static __host__ __device__ inline int64_t div_up(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +template +__device__ inline void stream_load128(uint4& val, const T* addr) { + uint64_t low, high; + low = reinterpret_cast(addr)[0]; + high = reinterpret_cast(addr)[1]; + reinterpret_cast(&val)[0] = low; + reinterpret_cast(&val)[1] = high; +} + +template +__device__ inline void stream_store128(T* addr, const uint4& val) { + uint64_t low, high; + low = reinterpret_cast(&val)[0]; + high = reinterpret_cast(&val)[1]; + reinterpret_cast(addr)[0] = low; + reinterpret_cast(addr)[1] = high; +} + +template +static __device__ inline bool is_aligned(const void* addr) { + return reinterpret_cast(addr) % sizeof(T) == 0; +} + +template +static __device__ inline void load128(uint4& val, const char* addr) { + for (size_t i = 0; i < detail::BYTES_PER_THREAD / sizeof(T); ++i) { + reinterpret_cast(&val)[i] = reinterpret_cast(addr)[i]; + } +} + +template <> +__device__ inline void load128(uint4& val, const char* addr) { + stream_load128(val, addr); +} + +static __device__ inline void load128(uint4& val, const char* addr) { + if (is_aligned(addr)) { + load128(val, addr); + } else if (is_aligned(addr)) { + load128(val, addr); + } else if (is_aligned(addr)) { + load128(val, addr); + } else { + load128(val, addr); + } +} + +static __device__ __inline__ void get_aligned_region( + char* ptr, + const int64_t chunk_size, + const int64_t alignment, + int64_t& align_off, + int64_t& aligned_size) { + const int64_t ptr_val = reinterpret_cast(ptr); + align_off = detail::div_up(ptr_val, alignment) * alignment - ptr_val; + aligned_size = (chunk_size - align_off) / alignment * alignment; +} + +static __device__ __inline__ void copy_chunk( + char* dst, + const char* src, + int64_t chunk_size, + int64_t thread_idx, + int64_t num_threads) { + if (chunk_size < num_threads) { + if (thread_idx < chunk_size) { + dst[thread_idx] = src[thread_idx]; + } + return; + } + + // Identify the region in which writes are guaranteed to be 128-bit aligned + int64_t align_off, aligned_size; + get_aligned_region( + dst, chunk_size, detail::BYTES_PER_THREAD, align_off, aligned_size); + + for (int64_t off = align_off + thread_idx * detail::BYTES_PER_THREAD; + off < align_off + aligned_size; + off += num_threads * detail::BYTES_PER_THREAD) { + uint4 val; + // Oppurtunistically vectorize reads + load128(val, &src[off]); + stream_store128(&dst[off], val); + } + + // Handle unaligned regions + if (thread_idx < align_off && thread_idx < chunk_size) { + dst[thread_idx] = src[thread_idx]; + } + if (align_off + aligned_size + thread_idx < chunk_size) { + dst[align_off + aligned_size + thread_idx] = + src[align_off + aligned_size + thread_idx]; + } +} + +static __global__ void split_with_sizes_copy_out_contiguous_no_cast_kernel( + char** dst_base_addrs, + char** src_base_addrs, + int64_t* split_chunk_sizes, + int64_t* block_idx_to_split_idx, + int64_t* blocks_cumsums, + int64_t src_stride, + int64_t num_chunks) { + const int64_t split_idx = block_idx_to_split_idx[blockIdx.x]; + const int64_t split_blocks = + blocks_cumsums[split_idx + 1] - blocks_cumsums[split_idx]; + const int64_t split_threads = split_blocks * blockDim.x; + const int64_t split_thread_idx = + (blockIdx.x - blocks_cumsums[split_idx]) * blockDim.x + threadIdx.x; + const int64_t split_chunk_size = 
split_chunk_sizes[split_idx]; + + char* dst_base_addr = dst_base_addrs[split_idx]; + char* src_base_addr = src_base_addrs[split_idx]; + + for (int64_t i = blockIdx.y; i < num_chunks; i += gridDim.y) { + copy_chunk( + dst_base_addr + i * split_chunk_size, + src_base_addr + i * src_stride, + split_chunk_size, + split_thread_idx, + split_threads); + } +} + +// Calculate the base addr for each split. +static inline std::vector get_split_base_addrs( + const at::Tensor& tensor, + at::IntArrayRef split_sizes, + int64_t dim) { + const auto* data_ptr = static_cast(tensor.const_data_ptr()); + const auto strides = tensor.strides(); + const auto element_sz = tensor.element_size(); + int64_t off = 0; + std::vector split_base_addrs; + split_base_addrs.reserve(split_sizes.size()); + for (const auto& split_size : split_sizes) { + split_base_addrs.push_back(reinterpret_cast(data_ptr + off)); + off += split_size * strides[dim] * element_sz; + } + return split_base_addrs; +} + +static inline std::vector get_dst_addrs(at::TensorList out) { + std::vector addrs; + addrs.reserve(out.size()); + for (const auto& tensor : out) { + addrs.push_back(reinterpret_cast(tensor.data_ptr())); + } + return addrs; +} + +// Calculate the chunk size for each split in bytes. +static inline std::vector get_split_chunk_sizes( + const at::Tensor& tensor, + at::IntArrayRef split_sizes, + int64_t dim) { + const auto stride = tensor.stride(dim); + const auto element_sz = tensor.element_size(); + std::vector split_chunk_sizes; + split_chunk_sizes.reserve(split_sizes.size()); + for (const auto& split_size : split_sizes) { + split_chunk_sizes.push_back(split_size * stride * element_sz); + } + return split_chunk_sizes; +} + +// Calculate the chunk stride in bytes. This is the same for all splits. +static inline int64_t get_chunk_stride(const at::Tensor& tensor, int64_t dim) { + int64_t stride = 1; + for (int64_t d = dim; d < tensor.dim(); ++d) { + stride *= tensor.sizes()[d]; + } + return stride * tensor.element_size(); +} + +// Calculate the number of chunks. This is the same for all splits. +static inline int64_t get_num_chunks(const at::Tensor& tensor, int64_t dim) { + int64_t num_chunks = tensor.numel(); + for (int64_t d = dim; d < tensor.dim(); ++d) { + num_chunks /= tensor.sizes()[d]; + } + return num_chunks; +} + +// Pack multiple std::vector into a single zoom tensor. 
+std::pair> pack_vecs( + std::vector*> vecs, + const at::Device& device) { + int64_t numel = 0; + for (const auto* vec : vecs) { + numel += vec->size(); + } + + auto packed = at::empty( + {numel}, at::TensorOptions().dtype(at::kLong).pinned_memory(true)); + size_t offset = 0; + for (const auto* vec : vecs) { + memcpy( + packed.data_ptr() + offset, + vec->data(), + sizeof(int64_t) * vec->size()); + offset += vec->size(); + } + packed = packed.to(device, /*non_blocking=*/true); + + std::vector ptrs; + ptrs.reserve(vecs.size()); + offset = 0; + for (const auto* vec : vecs) { + ptrs.push_back(packed.data_ptr() + offset); + offset += vec->size(); + } + return std::make_pair(std::move(packed), std::move(ptrs)); +} + +static inline std::vector get_chunk_cat_out_sizes( + IntArrayRef input_tensor_sizes, + int64_t dim, + int64_t num_chunks, + int64_t chunk_size, + int64_t out_element_size) { + std::vector view_sizes = std::vector( + input_tensor_sizes.begin(), input_tensor_sizes.begin() + dim); + view_sizes.insert( + view_sizes.end(), {num_chunks, chunk_size / out_element_size}); + return view_sizes; +} + +// Copy `max_chunk_size` bytes from `src` to `dst` by `num_threads`, and pad +// zero when `src` size (i.e., actual_chunk_size) is less than `max_chunk_size`. +// Assume elements of src and dst have the same data type. +template +__device__ __inline__ void copy_chunk_with_pad( + dst_t* dst_ptr, + src_t* src_ptr, + int64_t max_chunk_size, + int64_t actual_chunk_size, + int64_t thread_idx, + int64_t num_threads) { + // Supports type cast + if (!std::is_same_v) { + const int64_t max_num_elems = max_chunk_size / sizeof(dst_t); + const int64_t actual_num_elems = actual_chunk_size / sizeof(src_t); + int64_t elem_index = thread_idx; + while (elem_index < actual_num_elems) { + dst_ptr[elem_index] = + static_cast_with_inter_type::apply(src_ptr[elem_index]); + elem_index += num_threads; + } + while (elem_index < max_num_elems) { + dst_ptr[elem_index] = static_cast_with_inter_type::apply(0); + elem_index += num_threads; + } + return; + } + char* dst = reinterpret_cast(dst_ptr); + char* src = reinterpret_cast(src_ptr); + // Fast path when the number of threads is larger than the number of bytes to + // be copied (i.e., max_chunk_size). In this case, each thread only copies 1 + // byte. For 0 <= thread_idx < actual_chunk_size, the thread copies data from + // `src`. For actual_chunk_size <= thread_idx < max_chunk_size, the thread set + // the val=0 for padding. + if (max_chunk_size < num_threads) { + char val = static_cast(0); + if (thread_idx < actual_chunk_size) { + val = src[thread_idx]; + } + if (thread_idx < max_chunk_size) { + dst[thread_idx] = val; + } + return; + } + // Split dst array into three parts: + // [dst, dst+align_off), [dst+align_off, dst+align_end), [dst+align_end, + // dst+max_chunk_size) The second part is aligned with BYTES_PER_THREAD(=16 + // bytes) to enable `stream_store128`. + int64_t align_off, aligned_size; + get_aligned_region( + dst, actual_chunk_size, BYTES_PER_THREAD, align_off, aligned_size); + int64_t align_end = align_off + aligned_size; + for (int64_t i = align_off + thread_idx * BYTES_PER_THREAD; i < align_end; + i += num_threads * BYTES_PER_THREAD) { + uint4 val; + if (is_aligned(src + i)) { + stream_load128(val, src + i); + } else { + for (size_t j = 0; j < BYTES_PER_THREAD; ++j) { + reinterpret_cast(&val)[j] = src[i + j]; + } + } + stream_store128(&dst[i], val); + } + // Copy data for the first part of dst array [dst, dst+align_off). 
+ // Check `thread_idx +static __global__ void chunk_cat_zoom_kernel( + src_t** src, + dst_t* dst, + int64_t* block_idx_to_tensor_idx, + int64_t* tensor_idx_to_start_tensor_bytes, + int64_t* start_block_idx_per_tensor_chunk, + int64_t* actual_tensor_sizes, + int64_t* pad_tensor_chunk_sizes, + int64_t* num_blocks_per_tensor_chunk, + int64_t slice_size, + int64_t chunk_size, + int64_t dst_to_src_ratio) { + const int64_t slice_idx = blockIdx.z; + const int64_t chunk_idx = blockIdx.y; + const int64_t tensor_idx = block_idx_to_tensor_idx[blockIdx.x]; + const int64_t tile_idx = + blockIdx.x - start_block_idx_per_tensor_chunk[tensor_idx]; + // Number of threads for the `tensor_idx`-th tensor chunk. + const int64_t num_threads = + num_blocks_per_tensor_chunk[tensor_idx] * BLOCK_SIZE; + const int64_t thread_idx = tile_idx * BLOCK_SIZE + threadIdx.x; + char* src_addr = reinterpret_cast(src)[tensor_idx] + + slice_idx * actual_tensor_sizes[tensor_idx] + + chunk_idx * pad_tensor_chunk_sizes[tensor_idx] / dst_to_src_ratio; + char* dst_addr = reinterpret_cast(dst) + slice_idx * slice_size + + chunk_idx * chunk_size + tensor_idx_to_start_tensor_bytes[tensor_idx]; + // Compute the actual number of bytes to copy from src. + const int64_t actual_copy_size = ::min( + pad_tensor_chunk_sizes[tensor_idx] / dst_to_src_ratio, + ::max( + (int64_t)0, + actual_tensor_sizes[tensor_idx] - + chunk_idx * pad_tensor_chunk_sizes[tensor_idx] / + dst_to_src_ratio)); + copy_chunk_with_pad( + reinterpret_cast(dst_addr), + reinterpret_cast(src_addr), + pad_tensor_chunk_sizes[tensor_idx], + actual_copy_size, + thread_idx, + num_threads); +} + +bool all_contiguous(TensorList tensors) { + bool contiguous = true; + for (const auto& t : tensors) { + contiguous &= t.is_non_overlapping_and_dense(); + } + return contiguous; +} + +// Get leading dimensions before `dim`-th dimension. +static inline int64_t get_leading_dim(at::IntArrayRef sizes, int64_t dim) { + int64_t leading_dim = 1; + if (dim > 0) { + leading_dim = c10::multiply_integers(sizes.slice(0, dim)); + } + return leading_dim; +} + +// Get trailing dimensions after `dim`-th dimension and padded size along +// `dim`-th dimension. +static inline std::pair get_pad_size( + at::IntArrayRef sizes, + int64_t dim, + int64_t num_chunks) { + int64_t trailing_numel = 1; + if (sizes.size() > (uint64_t)dim + 1) { + trailing_numel = + c10::multiply_integers(sizes.slice(dim + 1, sizes.size() - dim - 1)); + } + int64_t pad_size_along_dim = + detail::div_up(sizes[dim], num_chunks) * num_chunks; + return std::make_pair(pad_size_along_dim, trailing_numel); +} + +// Get the padded chunk size. +static inline int64_t get_chunk_size( + TensorList tensors, + int64_t dim, + int64_t num_chunks, + int64_t elem_size) { + auto num_tensors = tensors.size(); + int64_t chunk_size = 0; + for (const auto i : c10::irange(num_tensors)) { + auto [pad_size_along_dim, trailing_numel] = + get_pad_size(tensors[i].sizes(), dim, num_chunks); + const int64_t pad_tensor_chunk_size = + pad_size_along_dim * trailing_numel * elem_size / num_chunks; + chunk_size += pad_tensor_chunk_size; + } + return chunk_size; +} + +// Get metadata for chunk_cat. 
+std::tuple< + int64_t, + int64_t, + int64_t, + int64_t, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector> +get_chunk_cat_metadata( + TensorList tensors, + int64_t dim, + int64_t num_chunks, + int64_t dst_elem_size, + int64_t src_elem_size) { + TORCH_CHECK( + dst_elem_size % src_elem_size == 0, + "get_chunk_cat_metadata error: only support dst_elem_size % src_elem_size == 0"); + auto num_tensors = tensors.size(); + int64_t leading_dim = get_leading_dim(tensors[0].sizes(), dim); + std::vector pad_tensor_chunk_sizes; + std::vector num_blocks_per_tensor_chunk; + std::vector start_block_idx_per_tensor_chunk{0}; + std::vector actual_tensor_sizes; + std::vector tensor_idx_to_start_tensor_bytes{0}; + std::vector srcs; + pad_tensor_chunk_sizes.reserve(num_tensors); + num_blocks_per_tensor_chunk.reserve(num_tensors); + start_block_idx_per_tensor_chunk.reserve(num_tensors + 1); + actual_tensor_sizes.reserve(num_tensors); + tensor_idx_to_start_tensor_bytes.reserve(num_tensors + 1); + srcs.reserve(num_tensors); + // block_idx_to_tensor_idx cannot be reserved since the number of blocks is + // data dependent + std::vector block_idx_to_tensor_idx; + // Inline computing `chunk_size` to avoid redundant computation + int64_t chunk_size = 0; + for (const auto i : c10::irange(num_tensors)) { + at::Tensor tensor = tensors[i]; + srcs.push_back(reinterpret_cast(tensor.data_ptr())); + auto sizes = tensor.sizes(); + auto [pad_size_along_dim, trailing_numel] = + get_pad_size(sizes, dim, num_chunks); + const int64_t pad_tensor_chunk_size = + pad_size_along_dim * trailing_numel * dst_elem_size / num_chunks; + pad_tensor_chunk_sizes.push_back(pad_tensor_chunk_size); + chunk_size += pad_tensor_chunk_size; + // Number of blocks required to process this tensor chunk. + const int64_t num_blocks = + detail::div_up(pad_tensor_chunk_size, detail::BYTES_PER_BLOCK); + num_blocks_per_tensor_chunk.push_back(num_blocks); + start_block_idx_per_tensor_chunk.push_back( + start_block_idx_per_tensor_chunk.back() + num_blocks); + block_idx_to_tensor_idx.insert( + block_idx_to_tensor_idx.end(), num_blocks, i); + tensor_idx_to_start_tensor_bytes.push_back( + tensor_idx_to_start_tensor_bytes.back() + pad_tensor_chunk_size); + actual_tensor_sizes.push_back(sizes[dim] * trailing_numel * src_elem_size); + } + const int64_t num_blocks_per_chunk = start_block_idx_per_tensor_chunk.back(); + const int64_t slice_size = num_chunks * chunk_size; + return std::make_tuple( + chunk_size, + leading_dim, + num_blocks_per_chunk, + slice_size, + srcs, + block_idx_to_tensor_idx, + tensor_idx_to_start_tensor_bytes, + start_block_idx_per_tensor_chunk, + actual_tensor_sizes, + pad_tensor_chunk_sizes, + num_blocks_per_tensor_chunk); +} + +// See [CUDA kernel for chunk_cat_cuda] +template +void _chunk_cat_out_zoom_contiguous( + TensorList tensors, + int64_t dim, + int64_t num_chunks, + Tensor& out, + int64_t dst_elem_size, + int64_t src_elem_size) { + const auto device = tensors[0].device(); + // `get_chunk_cat_metadata` must return vectors and `pack_vecs` cannot be + // moved into `get_chunk_cat_metadata`. Otherwise `packed` would point to + // vectors allocated inside `get_chunk_cat_metadata` which become out of local + // scope. 
+ auto + [chunk_size, + leading_dim, + num_blocks_per_chunk, + slice_size, + srcs, + block_idx_to_tensor_idx, + tensor_idx_to_start_tensor_bytes, + start_block_idx_per_tensor_chunk, + actual_tensor_sizes, + pad_tensor_chunk_sizes, + num_blocks_per_tensor_chunk] = + get_chunk_cat_metadata( + tensors, dim, num_chunks, dst_elem_size, src_elem_size); + auto packed = pack_vecs( + {&srcs, + &block_idx_to_tensor_idx, + &tensor_idx_to_start_tensor_bytes, + &start_block_idx_per_tensor_chunk, + &actual_tensor_sizes, + &pad_tensor_chunk_sizes, + &num_blocks_per_tensor_chunk}, + device); + std::vector view_sizes = get_chunk_cat_out_sizes( + tensors[0].sizes(), dim, num_chunks, chunk_size, dst_elem_size); + at::native::resize_output(out, view_sizes); + dim3 blocks(num_blocks_per_chunk, num_chunks, leading_dim); + dim3 threads(detail::BLOCK_SIZE, 1, 1); + hipLaunchKernelGGL(( detail::chunk_cat_zoom_kernel), + dim3(blocks), + dim3(threads), + 0, + c10::zoom::getCurrentZoomStream(), + /*srcs=*/reinterpret_cast(packed.second[0]), + reinterpret_cast(out.data_ptr()), + /*block_idx_to_tensor_idx=*/packed.second[1], + /*tensor_idx_to_start_tensor_bytes=*/packed.second[2], + /*start_block_idx_per_tensor_chunk=*/packed.second[3], + /*actual_tensor_sizes=*/packed.second[4], + /*pad_tensor_chunk_sizes=*/packed.second[5], + /*num_blocks_per_tensor_chunk=*/packed.second[6], + slice_size, + chunk_size, + dst_elem_size / src_elem_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +} // namespace detail + +// See [CUDA fast path for split_with_sizes_copy.out] +void split_with_sizes_copy_out_zoom_contiguous_no_cast( + const at::Tensor& self, + at::IntArrayRef split_sizes, + int64_t dim, + at::TensorList out) { + const auto device = self.device(); + const auto src_base_addrs = + detail::get_split_base_addrs(self, split_sizes, dim); + const auto dst_base_addrs = detail::get_dst_addrs(out); + const auto src_stride = detail::get_chunk_stride(self, dim); + const auto split_chunk_sizes = + detail::get_split_chunk_sizes(self, split_sizes, dim); + const auto num_chunks = detail::get_num_chunks(self, dim); + + // Calculate the number of blocks required for the first chunk across all + // splits, assuming each thread only processes BYTES_PER_THREAD bytes. + int64_t num_blocks = 0; + for (const auto& split_chunk_size : split_chunk_sizes) { + num_blocks += detail::div_up( + split_chunk_size, detail::BLOCK_SIZE * detail::BYTES_PER_THREAD); + } + + // Calculate the maximum number of blocks to launch. Only consider + // maxThreadsPerMultiProcessor as a limiting factor as the kernel uses no + // shared memory and little registers. Over-subscribe the SMs to hide I/O + // latency. + const auto num_sms = + at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + const auto max_threads_per_sm = + at::zoom::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor; + const int64_t max_blocks = + num_sms * max_threads_per_sm / detail::BLOCK_SIZE * 2.0; + + // Make each thread process BYTES_PER_THREAD * iter_factor bytes to regulate + // block size. Spread iter_factor evenly between chunks_per_block and + // iters_per_chunk. 
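// Worked example (editorial, hypothetical numbers): with num_blocks = 4096,
// num_chunks = 8 and max_blocks = 1024, iter_factor = div_up(4096 * 8, 1024)
// = 32, chunks_per_block = ceil(sqrt(32)) = 6 (already <= num_chunks), and
// iters_per_chunk = div_up(32, 6) = 6, so each launched block covers 6 chunks
// with up to 6 iterations per chunk, bringing the launched grid down to
// roughly max_blocks while still covering all of the work.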
+ int64_t iter_factor = detail::div_up(num_blocks * num_chunks, max_blocks); + int64_t chunks_per_block = ::ceil(std::sqrt(iter_factor)); + chunks_per_block = ::min(chunks_per_block, num_chunks); + const int64_t iters_per_chunk = detail::div_up(iter_factor, chunks_per_block); + + // Launch a logically jagged grid of shape + // (chunk_size*, num_splits, num_chunks / chunks_per_block) + // backed by a physical grid of shape + // (sum(chunk_size), num_chunks / chunks_per_block). + // A block can find its split_idx via block_idx_to_split_idx. + std::vector block_idx_to_split_idx; + std::vector blocks_cumsums{0}; + block_idx_to_split_idx.reserve(num_blocks); + for (size_t split_idx = 0; split_idx < split_sizes.size(); ++split_idx) { + const auto blocks = detail::div_up( + split_chunk_sizes[split_idx], + detail::BLOCK_SIZE * detail::BYTES_PER_THREAD * iters_per_chunk); + block_idx_to_split_idx.insert( + block_idx_to_split_idx.end(), blocks, split_idx); + blocks_cumsums.push_back(blocks_cumsums.back() + blocks); + } + + dim3 blocks(blocks_cumsums.back(), num_chunks / chunks_per_block, 1); + dim3 threads(detail::BLOCK_SIZE, 1, 1); + + auto [_, ptrs] = detail::pack_vecs( + {&dst_base_addrs, + &src_base_addrs, + &split_chunk_sizes, + &block_idx_to_split_idx, + &blocks_cumsums}, + device); + + hipLaunchKernelGGL(( detail::split_with_sizes_copy_out_contiguous_no_cast_kernel), + dim3(blocks), + dim3(threads), + 0, + c10::zoom::getCurrentZoomStream(), + /*dst_base_addrs=*/reinterpret_cast(ptrs[0]), + /*src_base_addrs=*/reinterpret_cast(ptrs[1]), + /*split_chunk_sizes=*/ptrs[2], + /*block_idx_to_split_idx=*/ptrs[3], + /*blocks_cumsums=*/ptrs[4], + src_stride, + num_chunks); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +void split_with_sizes_copy_out_zoom( + const Tensor& self, + IntArrayRef split_sizes, + int64_t dim, + TensorList out) { + const bool is_capturing = c10::zoom::currentStreamCaptureStatusMayInitCtx() != + c10::zoom::CaptureStatus::None; + bool contiguous_no_cast = self.is_non_overlapping_and_dense(); + for (const auto& t : out) { + contiguous_no_cast &= t.is_non_overlapping_and_dense(); + contiguous_no_cast &= (t.dtype() == self.dtype()); + } + // TODO(yifu): make the fast path work for CUDA graph + if (!is_capturing && contiguous_no_cast) { + // Perform equivalent checks performed by the composite impl + if (dim < 0) { + dim = at::maybe_wrap_dim(dim, self.dim()); + } + TORCH_CHECK( + self.dim() != 0, "split expects at least a 1-dimensional tensor") + + const int64_t dim_size = self.size(dim); + int64_t split_sizes_sum = 0; + for (const auto i : c10::irange(split_sizes.size())) { + TORCH_CHECK( + split_sizes[i] >= 0, + "split_with_sizes expects split_sizes have only non-negative ", + "entries, but got split_sizes=", + split_sizes[i]); + split_sizes_sum += split_sizes[i]; + } + TORCH_CHECK( + split_sizes_sum == dim_size, + "split_with_sizes expects split_sizes to sum exactly to ", + dim_size, + " (input tensor's size at dimension ", + dim, + "), ", + "but got split_sizes=", + split_sizes); + + TORCH_CHECK( + out.size() == split_sizes.size(), + "split_with_sizes_copy_out() expected an out= argument of size ", + split_sizes.size(), + ", got size ", + out.size()); + + auto out_shape = self.sizes().vec(); + for (const auto i : c10::irange(split_sizes.size())) { + out_shape[dim] = split_sizes[i]; + if (resize_output_check(out[i], out_shape)) { + out[i].resize_(out_shape); + } + TORCH_CHECK( + out[i].dtype() == self.dtype(), + "Expected out tensor to have dtype ", + self.dtype(), + ", but got ", + 
out[i].dtype(), + " instead"); + TORCH_CHECK( + out[i].device() == self.device(), + "Expected out tensor to have device ", + self.device(), + ", but got ", + out[i].device(), + " instead"); + } + split_with_sizes_copy_out_zoom_contiguous_no_cast( + self, split_sizes, dim, out); + } else { + at::native::split_with_sizes_copy_out(self, split_sizes, dim, out); + } +} + +Tensor _chunk_cat_zoom(TensorList tensors, int64_t dim, int64_t num_chunks) { + dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks); + if (detail::all_contiguous(tensors)) { + // Return a tensor with the same dtype as input tensors + int64_t elem_size = tensors[0].element_size(); + int64_t chunk_size = + detail::get_chunk_size(tensors, dim, num_chunks, elem_size); + int64_t leading_dim = detail::get_leading_dim(tensors[0].sizes(), dim); + auto view_sizes = detail::get_chunk_cat_out_sizes( + tensors[0].sizes(), dim, num_chunks, chunk_size, elem_size); + Tensor out = + tensors[0] + .new_empty(chunk_size * num_chunks * leading_dim / elem_size) + .view(view_sizes); + // Type-agnostic copy since out and input tensors have the same type. + detail::_chunk_cat_out_zoom_contiguous( + tensors, dim, num_chunks, out, elem_size, elem_size); + return out; + } else { + return at::native::_chunk_cat(tensors, dim, num_chunks); + } +} + +Tensor& _chunk_cat_out_zoom( + TensorList tensors, + int64_t dim, + int64_t num_chunks, + Tensor& out) { + dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks); + TORCH_CHECK( + tensors[0].device() == out.device(), + "_chunk_cat_out_zoom: mismatch between input and out tensor devices"); + bool both_input_output_contiguous = + detail::all_contiguous(tensors) && out.is_non_overlapping_and_dense(); + if (both_input_output_contiguous && + (tensors[0].dtype() == at::ScalarType::BFloat16) && + (out.dtype() == at::ScalarType::Float)) { + // _chunk_cat_out_zoom_contiguous should also support other types, thanks to + // static_cast_with_inter_type. Here, we dispatch to BFloat16 in and float32 + // out since it is the only known use case. + detail::_chunk_cat_out_zoom_contiguous( + tensors, + dim, + num_chunks, + out, + out.element_size(), + tensors[0].element_size()); + } else if ( + both_input_output_contiguous && tensors[0].dtype() == out.dtype()) { + // Type-agnostic copy since out and input tensors have the same type. 
+ detail::_chunk_cat_out_zoom_contiguous( + tensors, + dim, + num_chunks, + out, + out.element_size(), + tensors[0].element_size()); + } else { + at::native::_chunk_cat_out(tensors, dim, num_chunks, out); + } + return out; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/TensorShapeZoom.cpp b/aten/src/ATen/native/zoom/TensorShapeZoom.cpp new file mode 100644 index 00000000000000..b74ac6a36d482a --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorShapeZoom.cpp @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { + +Tensor& set_zoom_(Tensor& result) { + caffe2::TypeMeta dtype = result.dtype(); + Storage storage( + Storage::use_byte_size_t(), + 0, + at::zoom::getZoomDeviceAllocator(), + true); + result.set_(storage, 0, {0}, {}); + TORCH_INTERNAL_ASSERT(dtype == result.dtype()); + return result; +} + +Tensor& set_storage_zoom_(Tensor& result, Storage storage, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) { + checkSetStorage(result, storage, storage_offset, size, stride); + + result.unsafeGetTensorImpl()->set_storage_offset(storage_offset); + at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? + at::OptionalIntArrayRef(stride) : c10::nullopt; + at::native::resize_impl_zoom_(result.unsafeGetTensorImpl(), size, stride_opt); + return result; +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/TensorTransformations.cu b/aten/src/ATen/native/zoom/TensorTransformations.cu new file mode 100644 index 00000000000000..fd84d2cb79a1bc --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorTransformations.cu @@ -0,0 +1,154 @@ +// !!! This is a file automatically generated by hipify!!! 
+#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +#include +#include + +namespace at::native { + +template +C10_LAUNCH_BOUNDS_2(zoom::getApplyBlockSize(), zoom::getApplyBlocksPerSM()) +__global__ void kernel_pointwise_flip_apply2( + const zoom::detail::TensorInfo in_tensor_info, + zoom::detail::TensorInfo out_tensor_info, + IndexType N, + int flip_dim, + IndexType total_dims) { + for (IndexType linear_index = blockIdx.x * blockDim.x + threadIdx.x; linear_index < N; linear_index += gridDim.x * blockDim.x) { + IndexType dst_offset = 0; + if (flip_dim == 0) { + // flip 1st dim + dst_offset = (in_tensor_info.sizes[0] - 1 - linear_index / in_tensor_info.strides[0]) * in_tensor_info.strides[0] + linear_index % in_tensor_info.strides[0]; + } + else { + // flip last dim + IndexType i = total_dims - 1; + dst_offset = linear_index / in_tensor_info.strides[0] * in_tensor_info.strides[0] + (in_tensor_info.sizes[i] - 1 - linear_index % in_tensor_info.strides[0]); + } + out_tensor_info.data[dst_offset] = in_tensor_info.data[linear_index]; + } +} + +template +C10_LAUNCH_BOUNDS_1(zoom::getApplyBlockSize()) +__global__ void flip_zoom_kernel( + scalar_t* in_tensor, + scalar_t* out_tensor, + int64_t N, + int64_t* flip_dims, + int64_t flip_dims_size, + int64_t* strides, + int64_t* strides_contiguous, + int64_t* shape, + int64_t total_dims) { + int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + if (linear_index >= N) { + return; + } + + int64_t cur_indices = linear_index, rem = 0, dst_offset = 0; + for (int64_t i = 0; i < total_dims; i++) { + int64_t temp = cur_indices; + cur_indices = cur_indices / strides_contiguous[i]; + rem = temp - cur_indices * strides_contiguous[i]; + // flip the indices if it is in flip_dims + for (int64_t j = 0; j < flip_dims_size; j++) { + if (i == flip_dims[j]) { + cur_indices = shape[i] - 1 - cur_indices; + } + } + dst_offset += cur_indices * strides[i]; + cur_indices = rem; + } + out_tensor[linear_index] = in_tensor[dst_offset]; +} + +template +C10_LAUNCH_BOUNDS_1(zoom::getApplyBlockSize()) +__global__ void roll_zoom_kernel( + const scalar_t* in_tensor, + scalar_t* out_tensor, + int64_t N, + int64_t roll_dim, + int64_t start, + int64_t size, + int64_t stride, + int64_t total_dims) { + int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + if (linear_index >= N) { + return; + } + // roll dim idx is the index of linear_index along the rolling dimension. + int64_t roll_dim_idx = linear_index % (stride * size) / stride; + // index into the source data to find appropriate value. 
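// Worked example (editorial, hypothetical numbers): rolling a dimension of
// size = 5 by shifts[0] = 2 with stride = 3 gives start = (5 - 2) % 5 = 3.
// An output element with roll_dim_idx = 4 (>= size - start = 2) reads from
// source_idx = linear_index - 2 * 3, i.e. original position 2, while
// roll_dim_idx = 1 reads from source_idx = linear_index + 3 * 3, i.e. the
// wrapped-around original position 4. Both match out[p] = in[(p - 2) mod 5].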
+ int64_t source_idx = 0; + if( roll_dim_idx >= (size - start) ) { + source_idx = linear_index - ((size - start) * stride); + } else { + source_idx = linear_index + (start * stride); + } + out_tensor[linear_index] = in_tensor[source_idx]; +} + +// Roll a tensor along a dimension +Tensor roll_zoom(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { + if (dims.size() != 1 || shifts.size() != 1) { + return roll_common(self, shifts, dims); + } + + auto in_tensor = self; + if(!self.is_contiguous()) { + in_tensor = self.contiguous(); + } + auto out_tensor = at::empty_like(in_tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + if (out_tensor.numel() == 0) { + return out_tensor; + } + const int64_t N = in_tensor.numel(); + const int64_t dim = dims[0]; + const int64_t size = in_tensor.size(dim); + int64_t start = (size - shifts[0]) % size; + // Behavior of % is different in C++ vs Python for negative numbers. This + // corrects the difference. + if( start < 0 ) start = start + size; + + dim3 dim_block = zoom::getApplyBlock(); + dim3 dim_grid; + TORCH_CHECK(zoom::getApplyGrid(N, dim_grid, in_tensor.get_device()), "unable to get dim grid"); + + auto total_dims = in_tensor.dim(); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + at::ScalarType::ComplexHalf, + in_tensor.scalar_type(), "roll_zoom", + [&] { + hipLaunchKernelGGL(( roll_zoom_kernel), dim3(dim_grid), dim3(dim_block), 0, c10::zoom::getCurrentZoomStream(), + in_tensor.const_data_ptr(), out_tensor.mutable_data_ptr(), N, + dim, start, + size, + in_tensor.stride(dim), + total_dims); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + + return out_tensor; +} + +} // namespace at::native diff --git a/aten/src/ATen/zoom/ATenZoomGeneral.h b/aten/src/ATen/zoom/ATenZoomGeneral.h new file mode 100644 index 00000000000000..018bfd860bbaa5 --- /dev/null +++ b/aten/src/ATen/zoom/ATenZoomGeneral.h @@ -0,0 +1,8 @@ +#pragma once + +#include +#include + +#include + +// Use TORCH_ZOOM_API or TORCH_CUDA_CU_API for exports from this folder \ No newline at end of file diff --git a/aten/src/ATen/zoom/ApplyGridUtils.cuh b/aten/src/ATen/zoom/ApplyGridUtils.cuh new file mode 100644 index 00000000000000..0ba58874c8285b --- /dev/null +++ b/aten/src/ATen/zoom/ApplyGridUtils.cuh @@ -0,0 +1,47 @@ +#include + +#include + +namespace at::zoom { + +/** + Computes ceil(a / b) +*/ +template +__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) { + return (a + b - 1) / b; +} + +namespace { + +// Threads per block for our apply kernel +// FIXME: use occupancy calculator instead +constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512; +constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4; + +template +inline bool getApplyGrid(uint64_t totalElements, dim3& grid, c10::DeviceIndex curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) { + if (curDevice == -1) return false; + uint64_t numel_per_thread = static_cast(max_threads_per_block) * static_cast(step); + uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread); + uint64_t maxGridX = at::zoom::getDeviceProperties(curDevice)->maxGridSize[0]; + if (numBlocks > maxGridX) + numBlocks = maxGridX; + grid = dim3(numBlocks); + return true; +} + +constexpr int getApplyBlocksPerSM() { + return AT_APPLY_BLOCKS_PER_SM; +} + +constexpr int getApplyBlockSize() { + return AT_APPLY_THREADS_PER_BLOCK; +} + +inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) { + return dim3(max_threads_per_block); +} + +} // anonymous namespace +} // namespace 
at::zoom diff --git a/aten/src/ATen/zoom/AsmUtils.cuh b/aten/src/ATen/zoom/AsmUtils.cuh new file mode 100644 index 00000000000000..a7d6987be574b6 --- /dev/null +++ b/aten/src/ATen/zoom/AsmUtils.cuh @@ -0,0 +1,85 @@ +#pragma once +#include + +// Collection of direct PTX functions + +namespace at::zoom { + +template +struct Bitfield {}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + unsigned int getBitfield(unsigned int val, int pos, int len) { + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + return (val >> pos) & m; + } + + static __device__ __host__ __forceinline__ + unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; + } +}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + uint64_t getBitfield(uint64_t val, int pos, int len) { + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + return (val >> pos) & m; + } + + static __device__ __host__ __forceinline__ + uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) { + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; + } +}; + +__device__ __forceinline__ int getLaneId() { + return __lane_id(); +} + +__device__ __forceinline__ unsigned long long int getLaneMaskLt() { + const std::uint64_t m = (1ull << getLaneId()) - 1ull; + return m; +} + +__device__ __forceinline__ unsigned long long int getLaneMaskLe() { + std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); + return m; +} + +__device__ __forceinline__ unsigned long long int getLaneMaskGt() { + const std::uint64_t m = getLaneMaskLe(); + return m ? ~m : m; +} + +__device__ __forceinline__ unsigned long long int getLaneMaskGe() { + const std::uint64_t m = getLaneMaskLt(); + return ~m; +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/Atomic.cuh b/aten/src/ATen/zoom/Atomic.cuh new file mode 100644 index 00000000000000..c4e4429cbd0eb9 --- /dev/null +++ b/aten/src/ATen/zoom/Atomic.cuh @@ -0,0 +1,457 @@ +#pragma once + +#include +#include +#include + +#include + +template +struct AtomicFPOp; + +template <> +struct AtomicFPOp { + template + inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) { + unsigned int * address_as_ui = + (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + at::Half hsum; + do { + assumed = old; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + hsum = func(hsum, val); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + return hsum; + } +}; + +template <> +struct AtomicFPOp { + template + inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) { + unsigned int * address_as_ui = + (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + at::BFloat16 bsum; + do { + assumed = old; + bsum.x = (size_t)address & 2 ? 
(old >> 16) : (old & 0xffff); + bsum = func(bsum, val); + old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + return bsum.x; + } +}; + +template <> +struct AtomicFPOp { + template + inline __device__ double operator() (double * address, double val, const func_t& func) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull; + unsigned long long int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, func(val, assumed)); + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); + } +}; + +#define ATOMIC_INTEGER_IMPL(NAME) \ +template \ +struct Atomic##NAME##IntegerImpl; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 3; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = offset * 8; \ + uint32_t old_byte; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_byte = (old >> shift) & 0xff; \ + newval = static_cast(func(val, static_cast(old_byte))); \ + newval = (old & ~(0x000000ff << shift)) | (newval << shift); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 2; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + bool is_32_align = offset; \ + uint32_t old = *address_as_ui; \ + uint32_t old_bytes; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_bytes = is_32_align ? old >> 16 : old & 0xffff; \ + newval = static_cast(func(val, static_cast(old_bytes))); \ + newval = is_32_align ? 
(old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + uint32_t * address_as_ui = (uint32_t *) (address); \ + uint32_t old = *address_as_ui; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + unsigned long long * address_as_ui = (unsigned long long *) (address); \ + unsigned long long old = *address_as_ui; \ + unsigned long long newval; \ + unsigned long long assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; + + +# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \ +static inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \ +Atomic##NAME##IntegerImpl()(address, \ + val, \ + [](DTYPE a, DTYPE b) { \ + return OP; \ + }); \ +} \ + +ATOMIC_INTEGER_IMPL(Add) +GPU_ATOMIC_INTEGER(Add, a || b, bool) + +// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64) +static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) { + AtomicAddIntegerImpl()(address, + val, + [](uint8_t a, uint8_t b) { + return a + b; + }); +} + +static inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) { + AtomicAddIntegerImpl()(address, + val, + [](int8_t a, int8_t b) { + return a + b; + }); +} + +static inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) { + AtomicAddIntegerImpl()(address, + val, + [](int16_t a, int16_t b) { + return a + b; + }); +} + +static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) { + return atomicAdd(address, val); +} + +static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) { + __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +} + +static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half hsum, at::Half val) { + return hsum + val; + }); +} + +static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) { +return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum + val; + }); +} + +/* Note [hip-clang differences to hcc] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * The upcoming hip-clang compiler for ROCm differs from hcc in a few details. + * It exports the __HIP__ macro, we can hence differentiate between hcc and + * hip-clang. In the below, hcc only received support for atomicAdd with double + * typing after work week 18312. hip-clang had support from the first version. + * In general, the code-visible differences between hip-clang and hcc will be + * minimal. 
+ */ + + // // This needs to be defined for the host side pass + // static inline __device__ double atomicAdd(double *address, double val) { } + + +static inline __device__ double gpuAtomicAdd(double *address, double val) { + return atomicAdd(address, val); +} + +static inline __device__ float gpuAtomicAdd(float *address, float val) { + return atomicAdd(address, val); +} + +template +static inline __device__ void gpuAtomicAdd(c10::complex *address, c10::complex val) { + gpuAtomicAdd(&address->real_, val.real_); + gpuAtomicAdd(&address->imag_, val.imag_); +} + +/* Note [gpuAtomicAdd vs atomicAdd] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Some extensions such as torchvision call atomicAdd() + * directly and require non-library provided data type support. Only for these, we + * continue to provide atomicAdd overloads. + */ +static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { + return gpuAtomicAdd(address, val); +} + +static inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) { + return gpuAtomicAdd(address, val); +} + +static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) { + gpuAtomicAdd(address, val); +} + +static inline __device__ void atomicAdd(int8_t *address, int8_t val) { + gpuAtomicAdd(address, val); +} + +static inline __device__ void atomicAdd(int16_t *address, int16_t val) { + gpuAtomicAdd(address, val); +} + +static inline __device__ void atomicAdd(int64_t *address, int64_t val) { + gpuAtomicAdd(address, val); +} + +static inline __device__ void atomicAdd(bool *address, bool val) { + gpuAtomicAdd(address, val); +} + +/* Note [explicitly non-returning atomics] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet(). + * Due to compiler limitations, callers must opt-in to guarantee the optimized instruction. + * This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd, + * therefore we need a new API 'gpuAtomicAddNoReturn'. + */ +template +static inline __device__ void gpuAtomicAddNoReturn(c10::complex *address, c10::complex val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); } + +/* Special case fp32 atomic. */ +static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAdd(address, val); } + + +// Atomic multiplication implementation. 
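// The Half/BFloat16/double overloads below reuse the read-modify-CAS loop
// from AtomicFPOp above. As an editorial illustration (not part of the
// patch), the same pattern written on the host with std::atomic looks as
// follows; `cas_mul_low16` is a hypothetical name used only for this sketch,
// which multiplies the low 16-bit lane of a packed 32-bit word:

#include <atomic>
#include <cstdint>

inline std::uint16_t cas_mul_low16(std::atomic<std::uint32_t>& word,
                                   std::uint16_t val) {
  std::uint32_t old_word = word.load();
  std::uint32_t new_word = 0;
  do {
    // Extract the low lane, apply the op, splice it back, and leave the high
    // lane untouched -- the device code above does the same, selecting the
    // lane via ((size_t)address & 2).
    const auto old_lane = static_cast<std::uint16_t>(old_word & 0xffffu);
    const auto new_lane = static_cast<std::uint16_t>(old_lane * val);
    new_word = (old_word & 0xffff0000u) | new_lane;
    // compare_exchange_weak refreshes old_word on failure, so the loop
    // retries against the latest value, like re-reading after atomicCAS.
  } while (!word.compare_exchange_weak(old_word, new_word));
  return static_cast<std::uint16_t>(old_word & 0xffffu);
}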
+ +ATOMIC_INTEGER_IMPL(Mul) +GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int16_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int32_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int64_t) + +inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return bsum * val; + }); +} + +inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum * val; + }); +} + +inline __device__ double gpuAtomicMul(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(val * __longlong_as_double(assumed)); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. +inline __device__ float gpuAtomicMul (float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(val * + __int_as_float(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} + +// Atomic maximum implementation. + +template +__host__ __device__ T safe_max(T a, T b) { + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm-Developer-Tools/HIP/issues/2209 + T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b)); + return max; +} + +ATOMIC_INTEGER_IMPL(Max) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t) + +inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return safe_max(bsum, val); + }); +} + +inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return safe_max(bsum, val); + }); +} + +inline __device__ double gpuAtomicMax(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(safe_max(val, __longlong_as_double(assumed))); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. +inline __device__ float gpuAtomicMax(float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(safe_max(val, __int_as_float(assumed)))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} + +// Atomic minimum implementation. + +template +__host__ __device__ T safe_min(T a, T b) { + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm-Developer-Tools/HIP/issues/2209 + T min = at::_isnan(a) ? a : (at::_isnan(b) ? 
b : std::min(a, b)); + return min; +} + +ATOMIC_INTEGER_IMPL(Min) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t) + +inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return safe_min(bsum, val); + }); +} + +inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return safe_min(bsum, val); + }); +} + +inline __device__ double gpuAtomicMin(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(safe_min(val, __longlong_as_double(assumed))); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. +inline __device__ float gpuAtomicMin(float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(safe_min(val, __int_as_float(assumed)))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} diff --git a/aten/src/ATen/zoom/CachingHostAllocator.cpp b/aten/src/ATen/zoom/CachingHostAllocator.cpp new file mode 100644 index 00000000000000..77d84838a83d7a --- /dev/null +++ b/aten/src/ATen/zoom/CachingHostAllocator.cpp @@ -0,0 +1,266 @@ +#include "CachingHostAllocator.h" + +#include +#include +#include +#include + +#include +#include + +namespace at::zoom { +namespace { + +// Note: cudaEventCreate when concurrently invoked from multiple threads can be +// very expensive (at least on certain device/driver combinations). Thus, we a) +// serialize event creation at a per-device level, and b) pool the events to +// avoid constantly calling cudaEventCreate/cudaEventDestroy. This results in +// significant improvements in multithreaded workloads with high allocation +// rates. +class EventPool { + public: + using Event = std::unique_ptr< + at::zoom::ZoomEvent, + std::function>; + EventPool() : pools_(c10::zoom::device_count()) {} + + Event get(DeviceIndex device) { + TORCH_INTERNAL_ASSERT(0 <= device); + TORCH_INTERNAL_ASSERT(device < static_cast(pools_.size())); + auto& pool = pools_[device]; + auto destructor = [&pool](at::zoom::ZoomEvent* event) { + std::lock_guard g(pool.mutex_); + pool.event_pool_.push_back(std::unique_ptr(event)); + }; + + // Try to acquire an event from the per-device pool. + { + std::lock_guard g(pool.mutex_); + if (!pool.event_pool_.empty()) { + auto* event = pool.event_pool_.back().release(); + pool.event_pool_.pop_back(); + return Event(event, destructor); + } + } + // otherwise, allocate a new event that will be returned to the pool on + // destruction. 
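// (Editorial note: the recycling works through the custom deleter captured
// above -- rather than destroying the ZoomEvent, it re-inserts it into the
// per-device free list, roughly:
//   [&pool](at::zoom::ZoomEvent* e) {
//     std::lock_guard<std::mutex> g(pool.mutex_);
//     pool.event_pool_.emplace_back(e);
//   }
// so an Event handed out here flows back into pool.event_pool_ once its
// last owner releases it.)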
+ return Event( + std::make_unique(hipEventDisableTiming).release(), + destructor); + } + + void empty_cache() { + for (auto& pool : pools_) { + std::lock_guard g(pool.mutex_); + pool.event_pool_.clear(); + } + } + + private: + struct PerDevicePool { + alignas(64) std::mutex mutex_; + std::vector> event_pool_; + }; + std::vector pools_; +}; + +using Block = HostBlock; + +struct ZoomCachingHostAllocatorImpl + : public CachingHostAllocatorImpl { + private: + void allocate_host_memory(size_t size, void** ptr) override { + // Pinned memory pointers allocated by any device can be directly used by + // any other device, regardless of the current device at the time of + // allocation, since we assume unified addressing. So we grab any existing + // primary context, if available. See pytorch/pytorch#21081. + at::OptionalDeviceGuard device_guard; + auto primary_ctx_device_index = + c10::zoom::getDeviceIndexWithPrimaryContext(); + if (primary_ctx_device_index.has_value()) { + device_guard.reset_device( + at::Device(at::DeviceType::PrivateUse1, *primary_ctx_device_index)); + } + + if (c10::zoom::ZoomCachingAllocator::ZoomAllocatorConfig:: + pinned_use_zoom_host_register()) { + allocWithZoomHostRegister(ptr, size); + } else { + // Use hipHostMalloc for allocating pinned memory (global lock in driver) + C10_ZOOM_CHECK(hipHostMalloc(ptr, size, hipHostMallocDefault)); + } + } + + void free_block(Block* block) override { + if (c10::zoom::ZoomCachingAllocator::ZoomAllocatorConfig:: + pinned_use_zoom_host_register()) { + void* ptr = block->ptr_; + C10_ZOOM_CHECK(hipHostUnregister(ptr)); + free(ptr); + } else { + C10_ZOOM_CHECK(hipHostFree(block->ptr_)); + } + } + + void record_stream( + std::optional>& events, + c10::zoom::ZoomStream stream) override { + auto event = create_event_internal(stream.device_index()); + event->record(stream); + events->push_back(std::move(event)); + } + + bool query_event(EventPool::Event& event) override { + hipError_t err = hipEventQuery(*event); + if (err == hipErrorNotReady) { + (void)hipGetLastError(); // clear CUDA error + return false; + } else if (err != hipSuccess) { + C10_ZOOM_CHECK(err); + } + return true; + } + + EventPool::Event create_event_internal(DeviceIndex idx) { + // Leak the event pool to avoid shutdown issue. + static auto* event_pool = new EventPool(); + return event_pool->get(idx); + } + + TaskThreadPool* getThreadPool() { + static TaskThreadPool* pool = new TaskThreadPool( + c10::zoom::ZoomCachingAllocator::ZoomAllocatorConfig:: + pinned_max_register_threads()); + return pool; + } + + void mapPagesForRegister( + const void* ptr, + size_t size, + size_t i, + size_t numThreads, + size_t pageSize) { + uintptr_t start = (uintptr_t)ptr + (size * i / numThreads); + uintptr_t end = (uintptr_t)start + (size / numThreads); + if (i == (numThreads - 1)) { + end = (uintptr_t)ptr + size; + } + + // pre-fault/map the pages by setting the first byte of the page + uintptr_t alignedStart = + (((uintptr_t)start + pageSize - 1) & ~(pageSize - 1)); + for (uintptr_t p = alignedStart; p < ((uintptr_t)end); p += pageSize) { + memset((void*)p, 0, 1); + } + } + + void registerPages(const void* ptr, size_t size) { + C10_ZOOM_CHECK( + hipHostRegister((void*)ptr, (size_t)size, hipHostRegisterDefault)); + + // If host and device pointer don't match, give a warning and exit + void* devptr; + C10_ZOOM_CHECK(hipHostGetDevicePointer(&devptr, (void*)ptr, 0)); + TORCH_CHECK( + (void*)devptr == (void*)ptr, + "Host and device pointer dont match with hipHostRegister. 
" + "Please dont use this feature by setting " + "PYTORCH_ZOOM_ALLOC_CONF=use_zoom_host_register:False (default)", + ""); + } + + void allocWithZoomHostRegister(void** ptr, size_t roundSize) { + // Here we do regular allocation, pre-fault/map the pages, and then do + // cudaHostRegister with GPU mapping flags to lock the pages, so we + // can minimize the cost for the cuda global lock. + *ptr = malloc(roundSize); + + // Parallelize the mapping/registering of pages to reduce wall time + size_t pageSize = (1 << 12); // 4kB pages + size_t numMapThreads = c10::zoom::ZoomCachingAllocator:: + ZoomAllocatorConfig::pinned_num_register_threads(); + if ((numMapThreads > 1) && (roundSize >= (pageSize * numMapThreads))) { + // parallelize the mapping of pages with a threadpool + auto* pool = getThreadPool(); + std::vector> promises; + std::vector> futures; + promises.reserve(numMapThreads); + futures.reserve(numMapThreads); + + for (size_t i = 0; i < numMapThreads; i++) { + promises.emplace_back(); + futures.push_back(promises[i].get_future()); + auto task = [this, + i, + ptr, + roundSize, + numMapThreads, + pageSize, + &promises]() mutable { + mapPagesForRegister( + *ptr, + roundSize, + i, // thread task-id + numMapThreads, + pageSize); + // set the promise when mapping pages are done + promises[i].set_value(); + }; + pool->run(task); + } + for (auto& future : futures) { + future.wait(); + } + } else { + // Map pages in the same thread + mapPagesForRegister(*ptr, roundSize, 0, 1, pageSize); + } + + // Register the mapped pages using cudaHostRegister + registerPages(*ptr, roundSize); + } +}; + +void raw_local_deleter(void* ptr); + +struct ZoomCachingHostAllocator final + : public CachingHostAllocatorInterface { + at::DataPtr allocate(size_t size) override { + auto ptr_and_ctx = impl_->allocate(size); + return { + ptr_and_ctx.first, + ptr_and_ctx.second, + &raw_local_deleter, + at::DeviceType::CPU}; + } +}; + +ZoomCachingHostAllocator caching_host_allocator; + +static inline ZoomCachingHostAllocator& getZoomCachingHostAllocator() { + return caching_host_allocator; +} + +void raw_local_deleter(void* ptr) { + getZoomCachingHostAllocator().free(ptr); +} + +} // anonymous namespace + +bool CachingHostAllocator_recordEvent( + void* ptr, + void* ctx, + c10::zoom::ZoomStream stream) { + return getZoomCachingHostAllocator().record_event(ptr, ctx, stream); +} + +// Releases cached pinned memory allocations via cudaHostFree +void CachingHostAllocator_emptyCache() { + getZoomCachingHostAllocator().empty_cache(); +} + +at::Allocator* getCachingHostAllocator() { + return &getZoomCachingHostAllocator(); +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/CachingHostAllocator.h b/aten/src/ATen/zoom/CachingHostAllocator.h new file mode 100644 index 00000000000000..f9dfab67591052 --- /dev/null +++ b/aten/src/ATen/zoom/CachingHostAllocator.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::zoom { + +// +// A caching allocator for CUDA host allocations (pinned memory). +// +// This provides a drop-in replacement for THCudaHostAllocator, which re-uses +// freed pinned (page-locked) memory allocations. This avoids device +// synchronizations due to cudaFreeHost calls. +// +// To ensure correct behavior, THCCachingHostAllocator_recordEvent must be +// called anytime a pointer from this allocator is used in a cudaMemcpyAsync +// call between host and device, and passed the corresponding context from the +// allocation. 
This is currently invoked by at::native::copy_kernel_cuda. +// +TORCH_ZOOM_API c10::Allocator* getCachingHostAllocator(); + +// Records an event in the specified stream. The allocation corresponding to the +// input `ptr`/`ctx` will not be re-used until the event has occurred. +TORCH_ZOOM_API bool CachingHostAllocator_recordEvent( + void* ptr, + void* ctx, + c10::zoom::ZoomStream stream); + +// Releases cached pinned memory allocations via cudaHostFree +TORCH_ZOOM_API void CachingHostAllocator_emptyCache(); + +inline TORCH_ZOOM_API at::DataPtr HostAlloc(size_t size) { + return getCachingHostAllocator()->allocate(size); +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/DeviceUtils.cuh b/aten/src/ATen/zoom/DeviceUtils.cuh new file mode 100644 index 00000000000000..951d761d0b8533 --- /dev/null +++ b/aten/src/ATen/zoom/DeviceUtils.cuh @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include + +__device__ __forceinline__ unsigned int ACTIVE_MASK() +{ +// will be ignored anyway + return 0xffffffff; +} + +__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) { + +} + + +__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate) +{ +return __ballot(predicate); +} + + +template +__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ + return __shfl_xor(value, laneMask, width); +} + +template +__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff) +{ + return __shfl(value, srcLane, width); +} + +template +__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ + return __shfl_up(value, delta, width); +} + +template +__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ + return __shfl_down(value, delta, width); +} + +template<> +__device__ __forceinline__ int64_t WARP_SHFL_DOWN(int64_t value, unsigned int delta, int width , unsigned int mask) +{ + //(HIP doesn't support int64_t). 
Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ + int2 a = *reinterpret_cast(&value); + a.x = __shfl_down(a.x, delta); + a.y = __shfl_down(a.y, delta); + return *reinterpret_cast(&a); +} + +template<> +__device__ __forceinline__ c10::Half WARP_SHFL_DOWN(c10::Half value, unsigned int delta, int width, unsigned int mask) +{ + return c10::Half(WARP_SHFL_DOWN(value.x, delta, width, mask), c10::Half::from_bits_t{}); +} + +template +__device__ __forceinline__ c10::complex WARP_SHFL_DOWN(c10::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ + return c10::complex( + __shfl_down(value.real_, delta, width), + __shfl_down(value.imag_, delta, width)); +} + +template +__device__ __forceinline__ T doLdg(const T* p) { + return *p; +} \ No newline at end of file diff --git a/aten/src/ATen/zoom/EmptyTensor.cpp b/aten/src/ATen/zoom/EmptyTensor.cpp new file mode 100644 index 00000000000000..087962f699033e --- /dev/null +++ b/aten/src/ATen/zoom/EmptyTensor.cpp @@ -0,0 +1,71 @@ +#include +#include +#include +#include +#include +#include + +namespace at::detail { + + TensorBase zoom_empty_generic(IntArrayRef size, ScalarType dtype, std::optional device_opt, std::optional memory_format_opt) { + at::globalContext().lazyInitPrivateUse1(); + const auto device = device_or_default(device_opt); + TORCH_INTERNAL_ASSERT(device.is_privateuseone()); + const DeviceGuard device_guard(device); + auto* allocator = at::zoom::getZoomDeviceAllocator(); + constexpr c10::DispatchKeySet zoom_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic( + size, allocator, zoom_dks, dtype, memory_format_opt); + } + + TensorBase zoom_empty_memory_format(IntArrayRef size, ::std::optional dtype_opt, ::std::optional layout_opt, ::std::optional device_opt, ::std::optional pin_memory_opt, ::std::optional memory_format_opt) { + TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout_or_default(layout_opt) == Layout::Strided); + + const auto dtype = dtype_or_default(dtype_opt); + return zoom_empty_generic(size, dtype, device_opt, memory_format_opt); + } + + TensorBase empty_zoom(IntArrayRef size, const TensorOptions &options) { + return zoom_empty_memory_format(size, + optTypeMetaToScalarType(options.dtype_opt()), + options.layout_opt(), + options.device_opt(), + options.pinned_memory_opt(), + options.memory_format_opt()); + } + + + TensorBase zoom_empty_strided_generic(IntArrayRef size, IntArrayRef stride, ScalarType dtype, ::std::optional device_opt) { + at::globalContext().lazyInitPrivateUse1(); + const auto device = device_or_default(device_opt); + TORCH_INTERNAL_ASSERT(device.is_privateuseone()); + const DeviceGuard device_guard(device); + auto* allocator = at::zoom::getZoomDeviceAllocator(); + constexpr c10::DispatchKeySet zoom_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_strided_generic( + size, stride, allocator, zoom_dks, dtype); + } + + TensorBase zoom_empty_strided(IntArrayRef size, IntArrayRef stride, ::std::optional dtype_opt, ::std::optional layout_opt, ::std::optional device_opt, ::std::optional pin_memory_opt){ + TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout_or_default(layout_opt) == Layout::Strided); + + const auto dtype = dtype_or_default(dtype_opt); + return zoom_empty_strided_generic(size, stride, dtype, device_opt); + } + + TensorBase 
empty_strided_zoom( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions &options) { + return zoom_empty_strided( + size, + stride, + optTypeMetaToScalarType(options.dtype_opt()), + options.layout_opt(), + options.device_opt(), + options.pinned_memory_opt()); +} + +} \ No newline at end of file diff --git a/aten/src/ATen/zoom/EmptyTensor.h b/aten/src/ATen/zoom/EmptyTensor.h new file mode 100644 index 00000000000000..59ac131c5b13c6 --- /dev/null +++ b/aten/src/ATen/zoom/EmptyTensor.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at::detail { + + TensorBase zoom_empty_generic(IntArrayRef size, ScalarType dtype, std::optional device, std::optional memory_format); + TensorBase zoom_empty_memory_format(IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty.memory_format(SymInt[] size, *, ScalarTy + TORCH_ZOOM_API TensorBase empty_zoom(IntArrayRef size, const TensorOptions &options); + + TensorBase zoom_empty_strided_generic(IntArrayRef size, IntArrayRef stride, ScalarType dtype, ::std::optional device_opt); + TensorBase zoom_empty_strided(IntArrayRef size, IntArrayRef stride, ::std::optional dtype_opt, ::std::optional layout_opt, ::std::optional device_opt, ::std::optional pin_memory_opt); // {"schema": "aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> TensorBase", "dispatch": "True", "default": "False"} + TORCH_ZOOM_API TensorBase empty_strided_zoom(IntArrayRef size, IntArrayRef stride, const TensorOptions &options); + +} \ No newline at end of file diff --git a/aten/src/ATen/zoom/HIPConfig.h b/aten/src/ATen/zoom/HIPConfig.h new file mode 100644 index 00000000000000..017177b4ed597b --- /dev/null +++ b/aten/src/ATen/zoom/HIPConfig.h @@ -0,0 +1,9 @@ +#define AT_ROCM_ENABLED() true +#define AT_MAGMA_ENABLED() false + +// disabled for now because we're testing on an old hipsparselt +#ifdef HIPSPARSELT_ENABLED +#define AT_HIPSPARSELT_ENABLED() true +#else +#define AT_HIPSPARSELT_ENABLED() false +#endif \ No newline at end of file diff --git a/aten/src/ATen/zoom/HIPGraph.cpp b/aten/src/ATen/zoom/HIPGraph.cpp new file mode 100644 index 00000000000000..49079ed083042f --- /dev/null +++ b/aten/src/ATen/zoom/HIPGraph.cpp @@ -0,0 +1,317 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace at::zoom { + +static bool _hip_graphs_debug = false; +constexpr int kSynchronizeBusyWaitMillis = 10; + +MempoolId_t graph_pool_handle() { + // uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle". + static std::atomic uid{1}; + // Sets just the second value, to distinguish it from MempoolId_ts created from + // cudaStreamGetCaptureInfo id_s in capture_begin. + return {0, uid++}; +} + + +// Get the expected id of a capture sequence so that we can call beginAllocateStreamToPool +// before starting a graph capture +CaptureId_t capture_sequence_id() { + // id starts at 1: + // Ensures uuid count starts at 1. 0 is reserved to mean "not set by cudaStreamGetCaptureInfo". + // (But how do we know GetCaptureInfo never sets id_ to 0? Because that's the current behavior, + // and I asked cuda devs to keep it that way, and they agreed.) 
+ static std::atomic uuid{1}; + return uuid++; +} + +/** + * Note [CUDA Graph Wrapper Class] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Q: Why do we need graph capture and launch bindings in Pytorch? + * Why can't they live in a user extension, for example? + * + * A1: Convenience. + * A2: To ensure valid numerics on replay, some native CUDA ops (like RNG ops with + * CPU statefulness) need cooperation from the capture and replay bindings + * (see Note [CUDA Graph-safe RNG states] in ZoomGeneratorImpl.h). + * + * We can't expect users to know about this cooperation. If users write capture + * bindings naively in an extension, they likely won't interact with the native + * ops properly. Their graphs would yield invalid numerics on replay. + */ + +/** + * Note [Interaction with CUDA graph capture] in ZoomCachingAllocator.cpp + * describes memory management for captures. + */ + +std::atomic HIPGraph::pending_event_queries = 0; + +// Track any outstanding event queries that could happen e.g., in a NCCL watchdog so that they +// can be resolved before the capture begins. Note that event queries are not allowed during a +// graph capture in the default capture mode. +void HIPGraph::inc_pending_event_queries() { + pending_event_queries++; +} + +void HIPGraph::dec_pending_event_queries() { + TORCH_INTERNAL_ASSERT(pending_event_queries > 0, + "Attempted to decrement the number of outstanding events to be queried, but it was <= 0."); + pending_event_queries--; +} + +int HIPGraph::num_pending_event_queries() { + return pending_event_queries; +} + +HIPGraph::HIPGraph() + // CUDAStreams may not be default-constructed. + : capture_stream_(c10::zoom::getCurrentZoomStream()) { +} + +void HIPGraph::register_generator_state( + c10::intrusive_ptr state) { + captured_generator_states_[std::move(state)] = 0; +} + +void HIPGraph::register_generator_state(const at::Generator& generator) { + c10::intrusive_ptr zoom_gen = + dynamic_intrusive_pointer_cast( + generator.getIntrusivePtr()); + zoom_gen->register_graph(this); +} + +void HIPGraph::capture_begin(MempoolId_t pool/*=0*/, hipStreamCaptureMode capture_mode) { + TORCH_CHECK(!has_graph_exec_, + "This HIPGraph instance already owns a captured graph. " + "To capture a new graph, create a new instance."); + + // default generator is always registered + auto* gen = get_generator_or_default( + c10::nullopt, zoom::detail::getDefaultZoomGenerator()); + gen->register_graph(this); + + for (auto& [generator_state, wholegraph_increments] : + captured_generator_states_) { + generator_state->capture_prologue(); + } + + auto stream = c10::zoom::getCurrentZoomStream(); + + TORCH_CHECK(stream != c10::zoom::getDefaultZoomStream(), + "HIP graphs must be captured on a non-default stream. " + "(However, after capture, it's ok to replay them on the " + "default stream.)"); + + capture_stream_ = stream; + capture_dev_ = c10::zoom::current_device(); + + id_ = capture_sequence_id(); + + if (pool.first != 0 || pool.second != 0) { + // Either value being nonzero means the user supplied a pool to share. + // But only one should be nonzero. + // If pool was created by another graph's capture_begin, first should be nonzero. + // If pool was created by graph_pool_handle, second should be nonzero. + TORCH_INTERNAL_ASSERT(!(pool.first && pool.second)); + mempool_id_ = pool; + } else { + // User did not ask us to share a mempool. Use our own id_ as our mempool_id_. + // Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle(). 
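// For example, the first capture in a process typically ends up with
// mempool_id_ == {1, 0}, whereas graph_pool_handle() above hands out ids of
// the form {0, 1}, {0, 2}, ...; either field being nonzero is enough for the
// pool-sharing check earlier in this function to tell the two kinds apart.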
+ mempool_id_ = {id_, 0}; + } + + // Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an + // autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator + // due to the capture status being updated _after_ a capture had already started. + c10::zoom::ZoomCachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](hipStream_t stream) { + hipStreamCaptureStatus status; + CaptureId_t stream_capture_id; + C10_ZOOM_CHECK(hipStreamGetCaptureInfo(stream, &status, &stream_capture_id)); + return status == hipStreamCaptureStatus::hipStreamCaptureStatusActive && stream_capture_id == capture_id_; + }); + + // At this point, any NCCL watchdogs should be aware that we are in capture mode + // and therefore should not enqueue any additional work that could be event-queried. + // We still must wait on any existing work that has not been cleaned up. + while (num_pending_event_queries()) { + TORCH_WARN_ONCE("Waiting for pending NCCL work to finish before starting graph capture."); + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + + // cudaStreamCaptureModeGlobal is the most conservative option to + // prevent potentially unsafe CUDA API calls during capture. See + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + C10_ZOOM_CHECK(hipStreamBeginCapture(capture_stream_, capture_mode)); + + hipStreamCaptureStatus status; + C10_ZOOM_CHECK(hipStreamGetCaptureInfo(stream, &status, &capture_id_)); + TORCH_INTERNAL_ASSERT(status == hipStreamCaptureStatus::hipStreamCaptureStatusActive); + + TORCH_INTERNAL_ASSERT(id_ > 0); +} + +void HIPGraph::capture_end() { + auto stream = c10::zoom::getCurrentZoomStream(); + + TORCH_CHECK(stream == capture_stream_, + "Capture must end on the same stream it began on."); + + C10_ZOOM_CHECK(hipStreamEndCapture(capture_stream_, &graph_)); + + c10::zoom::ZoomCachingAllocator::endAllocateToPool(capture_dev_, mempool_id_); + + TORCH_CHECK(graph_ != NULL, "Invalid capture."); + has_graph_ = true; + + // In typical graph usage some tensors (e.g. the tensors used for graph IO) are not freed + // between replays. + // If Pytorch compiles and runs with a CUDA 11.4+ toolkit, there's a chance the allocator backend + // is cudaMallocAsync. + // cudaMallocAsync is generally graph-safe, but if some tensors are not freed between replays, + // the graph's internal bookkeeping requires that we instantiate with + // cudaGraphInstantiateFlagAutoFreeOnLaunch. 
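As an end-to-end usage sketch of the capture_begin/capture_end/replay sequence implemented here: the tensor x is hypothetical, fill_ is chosen only because this minimal backend ships a fill kernel, and c10::zoom::getStreamFromPool/ZoomStreamGuard are assumed to mirror their CUDA counterparts.

    at::zoom::HIPGraph graph;
    auto opts = at::TensorOptions().dtype(at::kFloat).device(c10::DeviceType::PrivateUse1);
    at::Tensor x = at::empty({8}, opts);
    c10::zoom::ZoomStream side_stream = c10::zoom::getStreamFromPool();
    {
      c10::zoom::ZoomStreamGuard guard(side_stream);  // capture may not use the default stream
      graph.capture_begin();
      x.fill_(1.0f);           // recorded into the graph rather than executed eagerly
      graph.capture_end();
    }
    graph.replay();            // executes the captured fill on the current stream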
See + // cudaGraphLaunch + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597 + // cudaGraphInstantiateWithFlags + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233 + + // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, + // who prefer not to report error message through these arguments moving forward + // (they prefer return value, or errors on api calls internal to the capture) + + C10_ZOOM_CHECK(hipGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); + + + has_graph_exec_ = true; + + for (auto& [generator_state, wholegraph_increments] : + captured_generator_states_) { + wholegraph_increments = generator_state->capture_epilogue(); + } + + size_t numHIPGraphNodes = 0; + C10_ZOOM_CHECK(hipGraphGetNodes(graph_, NULL, &numHIPGraphNodes)); + if (numHIPGraphNodes == 0) { + TORCH_WARN("The HIP Graph is empty. This usually means that the graph was ", + "attempted to be captured on wrong device or stream."); + } + + // check if debug path is set + if (!_hip_graphs_debug) { + // Now that we've instantiated graph_ into graph_exec_, + // we don't need graph_ anymore. + C10_ZOOM_CHECK(hipGraphDestroy(graph_)); + has_graph_ = false; + } else { + TORCH_WARN("DEBUG: TORCH_HIPGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called."); + } +} + +void HIPGraph::replay() { + TORCH_CHECK(has_graph_exec_, + "Called HIPGraph::replay without a preceding successful capture."); + + c10::OptionalDeviceGuard device_guard{capture_stream_.device()}; + + for (auto& [generator_state, wholegraph_increments] : + captured_generator_states_) { + generator_state->replay_prologue(wholegraph_increments); + } + // graph_exec_ may be replayed in any stream. + C10_ZOOM_CHECK(hipGraphLaunch(graph_exec_, c10::zoom::getCurrentZoomStream())); + +// cuda does this sync for certain versions, we're ignoring it here +// int version; +// C10_ZOOM_CHECK(cudaDriverGetVersion(&version)); +// if (version < 11040) { +// // Workaround for bug in libcuda.so that causes replayed graphs with +// // certain topologies to be corrupted (kernels elided, internal syncs +// // ignored) when replayed back to back without a sync in between. +// // The bug is fixed in CUDA 11.4+. +// C10_ZOOM_CHECK(cudaDeviceSynchronize()); +// } +} + +void HIPGraph::enable_debug_mode() { + _hip_graphs_debug = true; +} + +void HIPGraph::debug_dump(const std::string& debug_path) { + if (_hip_graphs_debug) { + TORCH_WARN("DEBUG: calling debug_dump()"); + if (has_graph_) { + TORCH_WARN("DEBUG: calling hipGraphDebugDotPrint() with ", debug_path); + C10_ZOOM_CHECK_WARN(hipGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output + C10_ZOOM_CHECK(hipGraphDestroy(graph_)); + } + } else { + // TODO (Arham): technically false right now, need to add this functionality to the Zoom PyBind module + TORCH_WARN("HIP Graphs debug not enabled, set with torch._C._zoom_enable_graphs_debug_mode"); + } + +} + +void HIPGraph::reset() { + // I'd prefer these checks throw exceptions, not print warnings, + // but the destructor calls reset(), and at least one CI build + // refuses to compile with a throwing destructor. + // + // Instead of calling reset() in the destructor to clean up, I could + // call reset() in the __del__ method of a thin Python wrapper, + // in which case reset would be allowed to throw exceptions. 
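Relatedly, a small sketch of the debug path mentioned above; the dump file name is hypothetical, and enable_debug_mode() has to be called before capture_end() so that graph_ is kept alive for hipGraphDebugDotPrint:

    at::zoom::HIPGraph graph;
    graph.enable_debug_mode();                 // keep graph_ after instantiation instead of destroying it
    // ... capture_begin / captured work / capture_end on a non-default stream ...
    graph.debug_dump("/tmp/zoom_graph.dot");   // writes a DOT file, then frees graph_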
+ // But Stackoverflow does not like user-defined __del__. + // __del__ prevents Graph instances from EVER being garbage collected + // if they participate in a reference cycle. + // And exceptions thrown in __del__ only print a warning anyway. + // + // Calling reset() in the C++ destructor, with warnings instead of exceptions + // if calls fail, is the compromise we chose. + // + // If capture_begin, the capture, or capture_end failed at some point, this HIPGraph, the generator, + // and the allocator could end up in all kinds of weird states depending where failure occurred. + // If the user catches the failure exception in a script, or is running in REPL or (god forbid) + // a Jupyter notebook, I don't see an easy way for reset() to gracefully fix all such possible error states. + if (has_graph_ || has_graph_exec_) { + // notifyCaptureDestroy may throw. How should we handle this? + c10::zoom::ZoomCachingAllocator::releasePool(capture_dev_, mempool_id_); + } + if (has_graph_) { + C10_ZOOM_CHECK_WARN(hipGraphDestroy(graph_)); + has_graph_ = false; + } + if (has_graph_exec_) { + C10_ZOOM_CHECK_WARN(hipGraphExecDestroy(graph_exec_)); + has_graph_exec_ = false; + } +} + +// Returns an id another graph's capture_begin can use to share the same memory pool as this graph. +MempoolId_t HIPGraph::pool() { +TORCH_CHECK(has_graph_exec_, + "Called HIPGraph::pool() without a preceding successful capture."); + return mempool_id_; +} + +HIPGraph::~HIPGraph() { + for (auto& [generator_state, wholegraph_increments] : + captured_generator_states_) { + generator_state->unregister_graph(this); + } + reset(); +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/HIPGraph.h b/aten/src/ATen/zoom/HIPGraph.h new file mode 100644 index 00000000000000..7bea7814fe344c --- /dev/null +++ b/aten/src/ATen/zoom/HIPGraph.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + +struct Generator; +struct ZoomGeneratorImpl; +struct ZoomGeneratorState; + +using MempoolId_t = c10::zoom::MempoolId_t; +using CaptureId_t = c10::zoom::CaptureId_t; + +namespace zoom { + +// Standalone way to get a unique mempool id usable as a pool=... 
argument +// to HIPGraph::capture_begin +TORCH_ZOOM_API MempoolId_t graph_pool_handle(); + +struct TORCH_ZOOM_API HIPGraph { + HIPGraph(); + ~HIPGraph(); + + static void inc_pending_event_queries(); + static void dec_pending_event_queries(); + static int num_pending_event_queries(); + // See Note [Explicit Registration of Generators to the CUDA Graph] + void register_generator_state(c10::intrusive_ptr state); + void register_generator_state(const at::Generator& generator); + void capture_begin( + MempoolId_t pool = {0, 0}, + hipStreamCaptureMode capture_mode = hipStreamCaptureModeGlobal); + void capture_end(); + void replay(); + void reset(); + MempoolId_t pool(); + void enable_debug_mode(); + void debug_dump(const std::string& debug_path); + + protected: + hipGraph_t graph_ = NULL; + hipGraphExec_t graph_exec_ = NULL; + + static std::atomic pending_event_queries; + + // internal states so reset() can do its best cleaning up + // Set to true in capture_end if hipStreamEndCapture succeeded + // Set back to false soon after, when graph_ is consumed by hipGraphInstantiate + // to create graph_exec_, then graph_ is deleted + bool has_graph_ = false; + // Set to true in capture_end if hipGraphInstantiate succeeded + bool has_graph_exec_ = false; + + // uuid of this instance's current capture, used to + // specify the pool. + CaptureId_t id_; + + // the ID assigned by hip during graph capture, + // used to identify when a stream is participating in capture + CaptureId_t capture_id_ = -1; + + // uuid used to request a particular private mempool from CUDACachingAllocator. + // By default, this will be set to {id_, 0}. + // + // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_ + // will be set to the other graph's mempool_id_, and therefore share a mempool with the + // other graph. + // + // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(), + // it will share a mempool with any other captures that used "pool=handle". + // + // Sharing a mempool across graphs saves memory, and it's safe if you + // know you'll replay those graphs in the same order you captured them. + MempoolId_t mempool_id_; + + // Stream on which capture began + c10::zoom::ZoomStream capture_stream_; + + // multiple generator states and their wholegraph_increments in this graph + // that are managed by the CUDA Graph + ska::flat_hash_map, uint64_t> + captured_generator_states_; + + // Device where capture occurred. Right now, for simplicity, we require all ops + // in a capture to run on the same device, but this is a limitation of HIPGraph, + // not CUDA itself. We can straightforwardly modify HIPGraph to support multi-device + // captures if needed. + int capture_dev_; +}; + +} // namespace cuda +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/zoom/HIPGraphsUtils.hpp b/aten/src/ATen/zoom/HIPGraphsUtils.hpp new file mode 100644 index 00000000000000..1f9a227f5e5492 --- /dev/null +++ b/aten/src/ATen/zoom/HIPGraphsUtils.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +// #include " +#include +#include +#include +#include +#include + +// c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten. +// This file adds utils used by aten only. + +namespace at::zoom { + +using CaptureId_t = c10::zoom::CaptureId_t; +using CaptureStatus = c10::zoom::CaptureStatus; + +// Use this version where you don't want to create a CUDA context if none exists. 
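For illustration, the typical way a backend routine uses the helpers below to refuse to run while a capture is active (the routine name is hypothetical):

    void hypothetical_blocking_routine() {
      // Event queries and device-wide syncs are unsafe inside a graph capture.
      at::zoom::assertNotCapturing("hypothetical_blocking_routine, which synchronizes the device,");
      C10_ZOOM_CHECK(hipDeviceSynchronize());
    }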
+inline CaptureStatus currentStreamCaptureStatus() { + // don't create a context if we don't have to + if (c10::zoom::hasPrimaryContext(c10::zoom::current_device())) { + return c10::zoom::currentStreamCaptureStatusMayInitCtx(); + } else { + return CaptureStatus::None; + } +} + +inline void assertNotCapturing(std::string attempt) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK(status == CaptureStatus::None, + attempt, + " during HIP graph capture. If you need this call to be captured, " + "please file an issue. " + "Current hipStreamCaptureStatus: ", + status); +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/HIPUtils.h b/aten/src/ATen/zoom/HIPUtils.h new file mode 100644 index 00000000000000..4461619e00cd96 --- /dev/null +++ b/aten/src/ATen/zoom/HIPUtils.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at::zoom { + +// Check if every tensor in a list of tensors matches the current +// device. +inline bool check_device(ArrayRef ts) { + if (ts.empty()) { + return true; + } + Device curDevice = Device(kPrivateUse1, c10::zoom::current_device()); + for (const Tensor& t : ts) { + if (t.device() != curDevice) return false; + } + return true; +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/NumericLimits.cuh b/aten/src/ATen/zoom/NumericLimits.cuh new file mode 100644 index 00000000000000..8b5d6b5932ee01 --- /dev/null +++ b/aten/src/ATen/zoom/NumericLimits.cuh @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include +#include + +// NumericLimits.cuh is a holder for numeric limits definitions of commonly used +// types. This header is very specific to ROCm HIP and may be removed in the future. +// This header is derived from the legacy THCNumerics.cuh. + +// The lower_bound and upper_bound constants are same as lowest and max for +// integral types, but are -inf and +inf for floating point types. They are +// useful in implementing min, max, etc. + +namespace at { + +template +struct numeric_limits { +}; + +// WARNING: the following at::numeric_limits definitions are there only to support +// HIP compilation for the moment. Use std::numeric_limits if you are not +// compiling for ROCm. +// from @colesbury: "The functions on numeric_limits aren't marked with +// __device__ which is why they don't work with ROCm. CUDA allows them +// because they're constexpr." + +namespace { + // ROCm doesn't like INFINITY too. 
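As a usage sketch, a reduction-style kernel might initialize its accumulator from the specializations that follow; the kernel name and launch configuration are hypothetical.

    template <typename scalar_t>
    __global__ void init_max_reduction_buffer(scalar_t* out, int64_t n) {
      int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
      if (i < n) {
        // lower_bound() is -inf for floating-point types and the minimum value for integral types.
        out[i] = at::numeric_limits<scalar_t>::lower_bound();
      }
    }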
+ constexpr double inf = INFINITY; +} + +template <> +struct numeric_limits { + static inline __host__ __device__ bool lowest() { return false; } + static inline __host__ __device__ bool max() { return true; } + static inline __host__ __device__ bool lower_bound() { return false; } + static inline __host__ __device__ bool upper_bound() { return true; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ uint8_t lowest() { return 0; } + static inline __host__ __device__ uint8_t max() { return UINT8_MAX; } + static inline __host__ __device__ uint8_t lower_bound() { return 0; } + static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int8_t lowest() { return INT8_MIN; } + static inline __host__ __device__ int8_t max() { return INT8_MAX; } + static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; } + static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } + static inline __host__ __device__ int16_t max() { return INT16_MAX; } + static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; } + static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } + static inline __host__ __device__ int32_t max() { return INT32_MAX; } + static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; } + static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } +}; + +template <> +struct numeric_limits { +#ifdef _MSC_VER + static inline __host__ __device__ int64_t lowest() { return _I64_MIN; } + static inline __host__ __device__ int64_t max() { return _I64_MAX; } + static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; } +#else + static inline __host__ __device__ int64_t lowest() { return INT64_MIN; } + static inline __host__ __device__ int64_t max() { return INT64_MAX; } + static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; } +#endif +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); } + static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ float lowest() { return -FLT_MAX; } + static inline 
__host__ __device__ float max() { return FLT_MAX; } + static inline __host__ __device__ float lower_bound() { return -static_cast(inf); } + static inline __host__ __device__ float upper_bound() { return static_cast(inf); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ double lowest() { return -DBL_MAX; } + static inline __host__ __device__ double max() { return DBL_MAX; } + static inline __host__ __device__ double lower_bound() { return -inf; } + static inline __host__ __device__ double upper_bound() { return inf; } +}; + +} // namespace at diff --git a/aten/src/ATen/zoom/PeerToPeerAccess.cpp b/aten/src/ATen/zoom/PeerToPeerAccess.cpp new file mode 100644 index 00000000000000..b5c3b8eda00565 --- /dev/null +++ b/aten/src/ATen/zoom/PeerToPeerAccess.cpp @@ -0,0 +1,59 @@ +#include + +#include +#include +#include +#include + +#include + +namespace at::zoom { + +static std::vector p2pAccessEnabled_; +static int64_t num_devices_ = -1; + +namespace detail { + +void init_p2p_access_cache(int64_t num_devices) { + // p2pAccessEnabled records if p2p copies are allowed between pairs of + // devices. Values include "1" (copy allowed), "0" (copy not allowed), and + // "-1" (unknown). + // Currently the max number of gpus in P2P group is 8, so if there are more + // we enable P2P in groups of 8 + p2pAccessEnabled_.clear(); + p2pAccessEnabled_.resize(num_devices * num_devices, -1); + num_devices_ = num_devices; + + for (const auto i : c10::irange(num_devices)) { + p2pAccessEnabled_[i * num_devices + i] = 1; + } +} + +} // namespace detail + +bool get_p2p_access(int dev, int dev_to_access) { + at::globalContext().lazyInitPrivateUse1(); + + TORCH_CHECK(dev >= 0 || dev < num_devices_, + dev, " is not a device"); + TORCH_CHECK(dev_to_access >= 0 || dev_to_access < num_devices_, + dev_to_access, " is not a device"); + TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized"); + + auto &cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access]; + + if (cache != -1) { + return cache; + } + + int result; + C10_ZOOM_CHECK(hipDeviceCanAccessPeer(&result, dev, dev_to_access)); + cache = result ? 
1 : 0; + if (cache) { + c10::zoom::ZoomCachingAllocator::enablePeerAccess(dev, dev_to_access); + } + + return cache; +} + +} // namespace at::zoom::detail \ No newline at end of file diff --git a/aten/src/ATen/zoom/PeerToPeerAccess.h b/aten/src/ATen/zoom/PeerToPeerAccess.h new file mode 100644 index 00000000000000..b299e48862024c --- /dev/null +++ b/aten/src/ATen/zoom/PeerToPeerAccess.h @@ -0,0 +1,12 @@ +#include +#include +#include + +namespace at::zoom { +namespace detail { +void init_p2p_access_cache(int64_t num_devices); +} + +TORCH_ZOOM_API bool get_p2p_access(int source_dev, int dest_dev); + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/PhiloxHIPState.h b/aten/src/ATen/zoom/PhiloxHIPState.h new file mode 100644 index 00000000000000..58a6bb5199fe58 --- /dev/null +++ b/aten/src/ATen/zoom/PhiloxHIPState.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +#include \ No newline at end of file diff --git a/aten/src/ATen/zoom/PhiloxUtils.hpp b/aten/src/ATen/zoom/PhiloxUtils.hpp new file mode 100644 index 00000000000000..ba2afd230f2c90 --- /dev/null +++ b/aten/src/ATen/zoom/PhiloxUtils.hpp @@ -0,0 +1,4 @@ +#pragma once + +#include +#include \ No newline at end of file diff --git a/aten/src/ATen/zoom/PinnedMemoryAllocator.cpp b/aten/src/ATen/zoom/PinnedMemoryAllocator.cpp new file mode 100644 index 00000000000000..5b9ed21a971a23 --- /dev/null +++ b/aten/src/ATen/zoom/PinnedMemoryAllocator.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +bool is_pinned_zoom(const Tensor& self, std::optional device) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_privateuseone()); + // TODO: unhook this + return detail::getZoomHooks().isPinnedPtr(self.storage().data()); +} + +Tensor _pin_memory_zoom(const Tensor& self, std::optional device) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_privateuseone()); + auto* allocator = at::zoom::getPinnedMemoryAllocator(); + auto storage = Storage( + Storage::use_byte_size_t(), + detail::computeStorageNbytes( + self.sizes(), self.strides(), self.dtype().itemsize()), + allocator, + /*resizable=*/false); + auto tensor = at::cpu::empty({0}, self.options()).set_(storage, 0, self.sizes(), self.strides()); + tensor.copy_(self); + return tensor; +} + + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/zoom/PinnedMemoryAllocator.h b/aten/src/ATen/zoom/PinnedMemoryAllocator.h new file mode 100644 index 00000000000000..2c52bead795996 --- /dev/null +++ b/aten/src/ATen/zoom/PinnedMemoryAllocator.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +namespace at::zoom { + +inline TORCH_ZOOM_API at::Allocator* getPinnedMemoryAllocator() { + return getCachingHostAllocator(); +} +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ScanUtils.cuh b/aten/src/ATen/zoom/ScanUtils.cuh new file mode 100644 index 00000000000000..d1a3558a42a6c1 --- /dev/null +++ b/aten/src/ATen/zoom/ScanUtils.cuh @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include + +// Collection of in-kernel scan / prefix sum utilities + +namespace at::zoom { + +// Inclusive prefix sum for binary vars using intra-warp voting + +// shared memory +template +__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) { + // Within-warp, we use warp voting. 
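To show how these block-level scans are typically consumed, a rough stream-compaction kernel is sketched below; the kernel and buffer names are hypothetical, and blockDim.x is assumed to be a multiple of the warp size with at most 32 warps per block.

    struct AddInts {
      __device__ int operator()(int a, int b) const { return a + b; }
    };

    __global__ void compact_nonzero(const float* in, float* out, int* per_block_count, int n) {
      __shared__ int warp_sums[32];                    // one slot per warp
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      bool keep = (i < n) && (in[i] != 0.0f);
      int pos = 0, carry = 0;
      at::zoom::exclusiveBinaryPrefixScan<int, true>(warp_sums, keep, &pos, &carry, AddInts{});
      if (keep) {
        out[blockIdx.x * blockDim.x + pos] = in[i];    // compacted position within this block
      }
      if (threadIdx.x == 0) {
        per_block_count[blockIdx.x] = carry;           // number of kept elements in this block
      }
    }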
+ unsigned long long int vote = WARP_BALLOT(in); + T index = __popcll(getLaneMaskLe() & vote); + T carry = __popcll(vote); + + int warp = threadIdx.x / C10_WARP_SIZE; + + // Per each warp, write out a value + if (getLaneId() == 0) { + smem[warp] = carry; + } + + __syncthreads(); + + // Sum across warps in one thread. This appears to be faster than a + // warp shuffle scan for CC 3.0+ + if (threadIdx.x == 0) { + int current = 0; + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { + T v = smem[i]; + smem[i] = binop(smem[i], current); + current = binop(current, v); + } + } + + __syncthreads(); + + // load the carry from the preceding warp + if (warp >= 1) { + index = binop(index, smem[warp - 1]); + } + + *out = index; + + if (KillWARDependency) { + __syncthreads(); + } +} + +// Exclusive prefix sum for binary vars using intra-warp voting + +// shared memory +template +__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) { + inclusiveBinaryPrefixScan(smem, in, out, binop); + + // Inclusive to exclusive + *out -= (T) in; + + // The outgoing carry for all threads is the last warp's sum + *carry = smem[at::ceil_div(blockDim.x, C10_WARP_SIZE) - 1]; + + if (KillWARDependency) { + __syncthreads(); + } +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ThrustAllocator.h b/aten/src/ATen/zoom/ThrustAllocator.h new file mode 100644 index 00000000000000..17ba84d64f2222 --- /dev/null +++ b/aten/src/ATen/zoom/ThrustAllocator.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace at::zoom { + +/// Allocator for Thrust to re-route its internal device allocations +/// to the THC allocator +class ThrustAllocator { +public: + typedef char value_type; + + char* allocate(std::ptrdiff_t size) { + return static_cast(c10::zoom::ZoomCachingAllocator::raw_alloc(size)); + } + + void deallocate(char* p, size_t size) { + c10::zoom::ZoomCachingAllocator::raw_delete(p); + } +}; + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomApplyUtils.cuh b/aten/src/ATen/zoom/ZoomApplyUtils.cuh new file mode 100644 index 00000000000000..dcb91d124a11d0 --- /dev/null +++ b/aten/src/ATen/zoom/ZoomApplyUtils.cuh @@ -0,0 +1,537 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// +// This file contains pointwise operation functions and kernels that +// work on both contiguous and non-contiguous tensor arguments of +// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without +// copying or temporary storage. +// + +/* + NOTE [ CUDA_tensor_applyN helpers ] + + The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4) + functions apply a pointwise operator to N tensor(s). + + The calling convention is + + 1. The template arguments should be, sequentially, + - First N typename args specify the scalar types of each of the N tensors. + - (Optional) `int step` arg specifies the number of elements processed + together at the same time. + Default is 1. + - A usually omitted (i.e., inferred) typename arg specifies the type of the + function/functor applied on `N * step` values in each iteration of each + CUDA thread. + 2. The arguments should be, sequentially, + - N tensors + - op: a function/functor that processes `N * step` values at the same time. 
+ - If `step == 1`, it must have signature + `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where + `scalar*_t`s are the first N typename template args, and the inputs + are the `N` values from the `N` tensors retrieved at a common index. + - Otherwise, it must must have signature + void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times + scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times + ..., + scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times + Different from `step == 1` case, it processes `N * step` values taken + from `step` common indices. Moreover, the first input `n` represents the + number of valid indices (it will always have `0 < n <= step`). It will + almost always be `step`, but at the boundary we may not have full `step` + elements and `n` can be a lesser value. + + E.g., if `step == 4` and `N == 2`, `op` could be + + [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4, + scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) { + // Only process u1, ..., un and v1, ..., vn. + // So if `n == 3`, `u4` and `v4` need not to be considered. + } + + In both cases, the references can actually be const, but at least one of + them should be non-const in order to write the output. + - (Optional, but recommended) N TensorArgType args that specify for each + tensor whether `op` reads AND writes ] (i.e., TensorArgType::ReadWrite), + or only reads (i.e., TensorArgType::ReadOnly). + Default is TensorArgType::ReadWrite for first Tensor, and + TensorArgType::ReadOnly for the rest. + + E.g., + + to compute a = b^2 for a and b of same dtype, we can call + + Zoom_tensor_apply2( + a, b, + [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; } + ); + + to work on 2 values at the same time, we can call + + Zoom_tensor_apply2( + a, b, + [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2, + const scalar2 &b_val1, const scalar2 &b_val2) { + // call special vectorized op here, or just do elementwise and enjoy unrolling... + // if n == 1, only process a_val1 and b_val1 + } + ); +*/ + +namespace at::zoom { + +// TODO: combine with TensorArg? So far that's been for debugging, and this is functional... +enum class TensorArgType { ReadWrite, ReadOnly }; + +namespace { + +// Rearrange dimensions for pointwise operations so that strides are in +// decreasing order as much as possible, so that kernels have better memory +// access patterns. +// +// For example, consider a binary operation on two "transposed" 2-dim tensors: +// sizes: 256 512 +// aInfo->strides: 1 256 +// bInfo->strides: 1 256 +// +// Given this, each concurrent memory access inside kernelPointwiseApply2() is +// exactly 256 elements apart, resulting in poor performance. +// +// This function exchanges dimensions so that memory access is contiguous: +// sizes: 512 256 +// aInfo->strides: 256 1 +// bInfo->strides: 256 1 +// +// (Actually, it becomes even better because now collapseDims() can turn each +// input into one contiguous array.) +// +// In general, given M (<=4) TensorInfo's with N dimensions, we can view each +// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange +// strides[i] and [j] if +// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M) +// (exchanging them will benefit input #k), and +// (2) strides[i][k] <= strieds[j][k] for all k +// (exchanging them will not make any input worse). 
+template +inline void rearrangeDims(detail::TensorInfo* aInfo, + detail::TensorInfo* bInfo = nullptr, + detail::TensorInfo* cInfo = nullptr, + detail::TensorInfo* dInfo = nullptr) { + int numInfos = 1; + int dims = aInfo->dims; + IndexType *sizes[4] = { aInfo->sizes, }; + IndexType *strides[4] = { aInfo->strides, }; + + if (bInfo != nullptr) { + ++numInfos; + if (bInfo->dims != dims) return; + sizes[1] = bInfo->sizes; + strides[1] = bInfo->strides; + } + + if (cInfo != nullptr) { + ++numInfos; + if (cInfo->dims != dims) return; + sizes[2] = cInfo->sizes; + strides[2] = cInfo->strides; + } + + if (dInfo != nullptr) { + ++numInfos; + if (dInfo->dims != dims) return; + sizes[3] = dInfo->sizes; + strides[3] = dInfo->strides; + } + + // Bail out if sizes do not match: we are using "deprecated pointwise + // behavior" among tensors of different shapes but same number of elements. + for (int i = 1; i < numInfos; ++i) { + for (int j = 0; j < dims; ++j) { + if (sizes[i][j] != sizes[0][j]) return; + } + } + + for (int i = 0; i < dims - 1; ++i) { + // No need to consider dimensions of size 1. + if (sizes[0][i] == 1) continue; + + for (int j = i + 1; j < dims; ++j) { + if (sizes[0][j] == 1) continue; + + // Compare the relative sizes of strides between dim #i and dim #j. + bool hasIncreasingStrides = false; + bool hasDecreasingStrides = false; + + for (int k = 0; k < numInfos; k++) { + IndexType stride_i = strides[k][i]; + IndexType stride_j = strides[k][j]; + if (stride_i < stride_j) { + hasIncreasingStrides = true; + } else if (stride_i > stride_j) { + hasDecreasingStrides = true; + } + } + + if (hasIncreasingStrides && !hasDecreasingStrides) { + for (int k = 0; k < numInfos; k++) { + IndexType size = sizes[k][i]; + sizes[k][i] = sizes[k][j]; + sizes[k][j] = size; + + IndexType stride = strides[k][i]; + strides[k][i] = strides[k][j]; + strides[k][j] = stride; + } + } + } + } +} + +// The `remaining_steps` argument is used to support Op that operates on +// multiple elements at the same time. Generally, the strategy of ApplyOpN is to +// 1. Initialize `remaining_steps = step`, where `step` is the template arg of +// CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the +// number of elements in bound for this call. It will almost always equal to +// `step` except at boundaries. +// 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in +// bound), and recursively call `ApplyOpN` with `remaining_steps - 1`. +// 3. At `remaining_steps = 0`, +// if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`; +// if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tesor1_valstep, +// tensor2_val1, tensor2_val2, ..., tesor2_valstep, +// ... +// tensorN_val1, tensorN_val2, ..., tesorN_valstep);` +// +// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like. + +template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, int n, + IndexType linearIndex, Offsets... aOffsets) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = sizeof...(Offsets) < n ? + detail::IndexToOffset::get(linearIndex, a) : 0; + + ApplyOp1::apply( + a, op, n, linearIndex + 1, aOffsets..., aOffset + ); +} +}; + +// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). +// We don't need to pass in how many elements need to processed in this case. 
+template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, + int n, IndexType linearIndex, Offset offset) { + op(a.data[offset]); +} +}; + +template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, int n, + IndexType linearIndex, Offsets... offsets) { + op(n, a.data[offsets]...); +} +}; + +template + +C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) + +__global__ void kernelPointwiseApply1(detail::TensorInfo a, + IndexType totalElements, const Op op) { + for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x * step) { + ApplyOp1::apply( + a, op, ::min(step, static_cast(totalElements - linearIndex)), linearIndex); + } +} + + +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int64_t n, IndexType linearIndex, + Offsets... aOffsets, Offsets... bOffsets) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = static_cast(sizeof...(Offsets)) < n ? + detail::IndexToOffset::get(linearIndex, a) : 0; + + // Convert `linearIndex` into an offset of `b` + const IndexType bOffset = static_cast(sizeof...(Offsets)) < n ? + detail::IndexToOffset::get(linearIndex, b) : 0; + + ApplyOp2::apply( + a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset + ); +} +}; + +// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). +// We don't need to pass in how many elements need to processed in this case. +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int /*n*/, IndexType /*linearIndex*/, + Offset aOffset, Offset bOffset) { + op(a.data[aOffset], b.data[bOffset]); +} +}; + +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int n, IndexType linearIndex, + Offsets... aOffsets, Offsets... 
bOffsets) { + op(n, a.data[aOffsets]..., b.data[bOffsets]...); +} +}; + +template + +C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) + +__global__ void +kernelPointwiseApply2(detail::TensorInfo a, + detail::TensorInfo b, + IndexType totalElements, + const Op op) { + for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x * step) { + ApplyOp2::apply( + a, b, op, ::min(step, static_cast(totalElements - linearIndex)), + linearIndex); + } +} + +} // anonymous namespace + +template +inline bool Zoom_tensor_apply2(at::TensorBase a, + at::TensorBase b, + const Op op, + TensorArgType aType = TensorArgType::ReadWrite, + TensorArgType bType = TensorArgType::ReadOnly) { + TORCH_CHECK(a.device().is_privateuseone() && b.device().is_privateuseone(), + "Zoom_tensor_apply2: Expected tensors to have Zoom DeviceType, but got " + "tensors with type ", a.device().type(), " and ", b.device().type()); + int64_t totalElements = a.numel(); + + if (totalElements != b.numel()) { + return false; + } + + if (a.dim() > MAX_TENSORINFO_DIMS || + b.dim() > MAX_TENSORINFO_DIMS) { + return false; + } + + if (a.numel() == 0) { + // Empty tensor; do nothing + return true; + } + const dim3 block = getApplyBlock(max_threads_per_block); + + dim3 grid; + auto curDevice = c10::zoom::current_device(); + if (curDevice == -1) return false; + if (!getApplyGrid(totalElements, grid, curDevice, max_threads_per_block)) { + return false; + } + + /* + Expands readable/writable tensors whose indices may be "overlapped." + This ensures that each element of the tensor is operated on once and only + once. + */ + TensorBase oldA; + TensorBase oldB; + + if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { + // Must perform in contiguous space + oldA = std::exchange(a, a.contiguous()); + } + if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) { + // Must perform in contiguous space + oldB = std::exchange(b, b.contiguous()); + } + + // It is possible that the tensor dimensions are able to be collapsed, + // and thus we can reduce the actual code complexity of the copy by + // exploiting this knowledge statically, since the div/mod is the + // most expensive part of the operation, more so than memory accesses. + // For instance, when copying a non-contiguous to a contiguous tensor + // (or vice versa), the contiguous tensor can be collapsed to one + // dimension, and the loop to translate the linear index to the array + // index can be similarly collapsed. That is what this unrolling is for. 
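Putting the pieces together, a caller might invoke this helper as follows; a and b are hypothetical float tensors on the zoom device with equal numel, and the device lambda assumes the usual extended-lambda support from the HIP compiler:

    bool ok = at::zoom::Zoom_tensor_apply2<float, float>(
        a, b,
        [] __device__ (float& a_val, const float& b_val) { a_val = b_val * b_val; });
    TORCH_CHECK(ok, "tensors were too large or shaped unsuitably for Zoom_tensor_apply2");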
+ +#define HANDLE_CASE(TYPE, A, B) \ + kernelPointwiseApply2 \ + <<>>( \ + aInfo, bInfo, static_cast(totalElements), op); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + +#define HANDLE_B_CASE(TYPE, A, B) { \ + switch (B) { \ + case 1: \ + HANDLE_CASE(TYPE, A, 1); \ + break; \ + case 2: \ + HANDLE_CASE(TYPE, A, 2); \ + break; \ + default: \ + HANDLE_CASE(TYPE, A, -1); \ + break; \ + } \ +} + +#define HANDLE_A_CASE(TYPE, A, B) { \ + switch (A) { \ + case 1: \ + HANDLE_B_CASE(TYPE, 1, B); \ + break; \ + case 2: \ + HANDLE_B_CASE(TYPE, 2, B); \ + break; \ + default: \ + HANDLE_B_CASE(TYPE, -1, B); \ + break; \ + } \ +} + + if (detail::canUse32BitIndexMath(a) && + detail::canUse32BitIndexMath(b)) { + detail::TensorInfo aInfo = + detail::getTensorInfo(a); + + detail::TensorInfo bInfo = + detail::getTensorInfo(b); + rearrangeDims(&aInfo, &bInfo); + aInfo.collapseDims(); + bInfo.collapseDims(); + + HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims); + } else { + detail::TensorInfo aInfo = + detail::getTensorInfo(a); + + detail::TensorInfo bInfo = + detail::getTensorInfo(b); + rearrangeDims(&aInfo, &bInfo); + aInfo.collapseDims(); + bInfo.collapseDims(); + + /* + Only instantiates the all 1D special case and the fallback all nD case for + large (64-bit indexed) tensors to reduce compilation time. + */ + if (aInfo.dims == 1 && bInfo.dims == 1) { + HANDLE_CASE(uint64_t, 1, 1); + } else { + HANDLE_CASE(uint64_t, -1, -1); + } + } +#undef HANDLE_CASE +#undef HANDLE_B_CASE +#undef HANDLE_A_CASE + + if (oldA.defined()) { + at::native::copy_ignoring_overlaps(oldA, a); + } + + if (oldB.defined()) { + at::native::copy_ignoring_overlaps(oldB, b); + } + + return true; +} + +/* Provides default step = 1 to Zoom_tensor_apply2. */ +template +inline bool Zoom_tensor_apply2(const at::TensorBase &a, + const at::TensorBase &b, + const Op op, + TensorArgType aType = TensorArgType::ReadWrite, + TensorArgType bType = TensorArgType::ReadOnly) { + return Zoom_tensor_apply2(a, b, op, aType, bType); +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomContext.cpp b/aten/src/ATen/zoom/ZoomContext.cpp new file mode 100644 index 00000000000000..3182fafed7493f --- /dev/null +++ b/aten/src/ATen/zoom/ZoomContext.cpp @@ -0,0 +1,69 @@ +#include +#include +#include + +// #include +#include +#include +#include + +namespace at::zoom { + +namespace { + +DeviceIndex num_gpus = -1; +c10::once_flag init_flag; +std::deque device_flags; +std::vector device_properties; + +void initZoomContextVectors() { + num_gpus = c10::zoom::device_count(); + device_flags.resize(num_gpus); + device_properties.resize(num_gpus); +} + +void initDeviceProperty(DeviceIndex device_index) { + hipDeviceProp_t device_prop; + C10_ZOOM_CHECK(hipGetDeviceProperties(&device_prop, device_index)); + device_properties[device_index] = device_prop; +} + +} // anonymous namespace + +// We need this function to force the linking against torch_cuda(_cpp) on Windows. +// If you need to modify this function, please specify a new function and apply +// the changes according to https://github.com/pytorch/pytorch/pull/34288. +// Related issue: https://github.com/pytorch/pytorch/issues/31611. 
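For reference, a short sketch of how launch heuristics typically consume the device-info helpers defined below; the 256-thread cap is an arbitrary illustrative choice:

    hipDeviceProp_t* prop = at::zoom::getCurrentDeviceProperties();
    int max_threads = std::min(256, prop->maxThreadsPerBlock);
    int warp = at::zoom::warp_size();                // 32 or 64 depending on the architecture
    int num_warps = at::ceil_div(max_threads, warp);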
+/* Device info */ +int warp_size() { + return getCurrentDeviceProperties()->warpSize; +} + +hipDeviceProp_t* getCurrentDeviceProperties() { + auto device = c10::zoom::current_device(); + return getDeviceProperties(device); +} + +hipDeviceProp_t* getDeviceProperties(c10::DeviceIndex device) { + c10::call_once(init_flag, initZoomContextVectors); + if (device == -1) device = c10::zoom::current_device(); + AT_ASSERT(device >= 0 && device < num_gpus, "device=", device, ", num_gpus=", num_gpus); + c10::call_once(device_flags[device], initDeviceProperty, device); + return &device_properties[device]; +} + +bool canDeviceAccessPeer(c10::DeviceIndex device, c10::DeviceIndex peer_device) { + c10::call_once(init_flag, initZoomContextVectors); + if (device == -1) device = c10::zoom::current_device(); + AT_ASSERT(device >= 0 && device < num_gpus, "device=", device, ", num_gpus=", num_gpus); + AT_ASSERT(peer_device >= 0 && peer_device < num_gpus, "peer_device=", peer_device, ", num_gpus=", num_gpus); + int can_access = 0; + C10_ZOOM_CHECK(hipDeviceCanAccessPeer(&can_access, device, peer_device)); + return can_access != 0; +} + +Allocator* getZoomDeviceAllocator() { + return c10::zoom::ZoomCachingAllocator::get(); +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomContext.h b/aten/src/ATen/zoom/ZoomContext.h new file mode 100644 index 00000000000000..98a36bee8b4fd7 --- /dev/null +++ b/aten/src/ATen/zoom/ZoomContext.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +// Preserved for BC, as many files depend on these includes +#include +#include +#include +#include \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomContextLight.h b/aten/src/ATen/zoom/ZoomContextLight.h new file mode 100644 index 00000000000000..44a82879f05267 --- /dev/null +++ b/aten/src/ATen/zoom/ZoomContextLight.h @@ -0,0 +1,85 @@ +#pragma once +// Light-weight version of ZoomContext.h with fewer transitive includes +#define DISABLE_HIPBLASLT + +#include + +#include +#include +#include + +#include +#include +#include +#ifndef DISABLE_HIPBLASLT +#include +#include +#endif + +namespace c10 { +struct Allocator; +} + +namespace at::zoom { + +/* +A common CUDA interface for ATen. + +This interface is distinct from CUDAHooks, which defines an interface that links +to both CPU-only and CUDA builds. That interface is intended for runtime +dispatch and should be used from files that are included in both CPU-only and +CUDA builds. + +CUDAContext, on the other hand, should be preferred by files only included in +CUDA builds. It is intended to expose CUDA functionality in a consistent +manner. + +This means there is some overlap between the CUDAContext and CUDAHooks, but +the choice of which to use is simple: use CUDAContext when in a CUDA-only file, +use CUDAHooks otherwise. + +Note that CUDAContext simply defines an interface with no associated class. +It is expected that the modules whose functions compose this interface will +manage their own state. There is only a single CUDA context/state. +*/ + +/** + * DEPRECATED: use device_count() instead + */ +inline int64_t getNumGPUs() { + return c10::zoom::device_count(); +} + +/** + * CUDA is available if we compiled with CUDA, and there are one or more + * devices. If we compiled with CUDA but there is a driver problem, etc., + * this function will report CUDA is not available (rather than raise an error.) 
+ */ +inline bool is_available() { + return c10::zoom::device_count() > 0; +} + +TORCH_ZOOM_API hipDeviceProp_t* getCurrentDeviceProperties(); + +TORCH_ZOOM_API int warp_size(); + +TORCH_ZOOM_API hipDeviceProp_t* getDeviceProperties(c10::DeviceIndex device); + +TORCH_ZOOM_API bool canDeviceAccessPeer( + c10::DeviceIndex device, + c10::DeviceIndex peer_device); + +TORCH_ZOOM_API c10::Allocator* getZoomDeviceAllocator(); + +TORCH_ZOOM_API hipsparseHandle_t getCurrentHIPSparseHandle(); +TORCH_ZOOM_API hipblasHandle_t getCurrentHIPBlasHandle(); +#ifndef DISABLE_HIPBLASLT +TORCH_ZOOM_API hipblasLtHandle_t getCurrentHIPBlasLtHandle(); +#endif + + +#if defined(hipsolverVersionMajor) +TORCH_ZOOM_API hipsolverDnHandle_t getCurrentHIPSolverDnHandle(); +#endif + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomDataType.h b/aten/src/ATen/zoom/ZoomDataType.h new file mode 100644 index 00000000000000..41186e419bea1e --- /dev/null +++ b/aten/src/ATen/zoom/ZoomDataType.h @@ -0,0 +1,97 @@ +#pragma once + +#include + +#include +#include +#include + +namespace at::zoom { + +template +hipDataType getHIPDataType() { + TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to hipDataType.") +} + +template<> inline hipDataType getHIPDataType() { + return HIP_R_16F; +} +template<> inline hipDataType getHIPDataType() { + return HIP_R_32F; +} +template<> inline hipDataType getHIPDataType() { + return HIP_R_64F; +} +template<> inline hipDataType getHIPDataType>() { + return HIP_C_16F; +} +template<> inline hipDataType getHIPDataType>() { + return HIP_C_32F; +} +template<> inline hipDataType getHIPDataType>() { + return HIP_C_64F; +} + +template<> inline hipDataType getHIPDataType() { + return HIP_R_8U; +} +template<> inline hipDataType getHIPDataType() { + return HIP_R_8I; +} +template<> inline hipDataType getHIPDataType() { + return HIP_R_32I; +} + +template<> inline hipDataType getHIPDataType() { + return HIP_R_16I; +} +template<> inline hipDataType getHIPDataType() { + return HIP_R_64I; +} +template<> inline hipDataType getHIPDataType() { + return HIP_R_16BF; +} + +inline hipDataType ScalarTypeToHIPDataType(const c10::ScalarType& scalar_type) { + switch (scalar_type) { + case c10::ScalarType::Byte: + return HIP_R_8U; + case c10::ScalarType::Char: + return HIP_R_8I; + case c10::ScalarType::Int: + return HIP_R_32I; + case c10::ScalarType::Half: + return HIP_R_16F; + case c10::ScalarType::Float: + return HIP_R_32F; + case c10::ScalarType::Double: + return HIP_R_64F; + case c10::ScalarType::ComplexHalf: + return HIP_C_16F; + case c10::ScalarType::ComplexFloat: + return HIP_C_32F; + case c10::ScalarType::ComplexDouble: + return HIP_C_64F; + case c10::ScalarType::Short: + return HIP_R_16I; + case c10::ScalarType::Long: + return HIP_R_64I; + case c10::ScalarType::BFloat16: + return HIP_R_16BF; +#if defined(HIP_NEW_TYPE_ENUMS) + case c10::ScalarType::Float8_e4m3fnuz: + return HIP_R_8F_E4M3_FNUZ; + case c10::ScalarType::Float8_e5m2fnuz: + return HIP_R_8F_E5M2_FNUZ; +#else + case c10::ScalarType::Float8_e4m3fnuz: + return static_cast(1000); + case c10::ScalarType::Float8_e5m2fnuz: + return static_cast(1001); +#endif + default: + TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to hipDataType.") + } +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomDevice.h b/aten/src/ATen/zoom/ZoomDevice.h new file mode 100644 index 00000000000000..e7ac4e781cba91 --- /dev/null +++ 
b/aten/src/ATen/zoom/ZoomDevice.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include + +namespace at::zoom { + +inline Device getDeviceFromPtr(void* ptr) { + hipPointerAttribute_t attr{}; + + C10_ZOOM_CHECK(hipPointerGetAttributes(&attr, ptr)); + + return {c10::DeviceType::PrivateUse1, static_cast(attr.device)}; +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomEvent.h b/aten/src/ATen/zoom/ZoomEvent.h new file mode 100644 index 00000000000000..dfb0557e6fba0e --- /dev/null +++ b/aten/src/ATen/zoom/ZoomEvent.h @@ -0,0 +1,213 @@ +#pragma once + +// #include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at::zoom { + +/* +* CUDAEvents are movable not copyable wrappers around CUDA's events. +* +* CUDAEvents are constructed lazily when first recorded unless it is +* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this +* device is acquired from the first recording stream. However, if reconstructed +* from a handle, the device should be explicitly specified; or if ipc_handle() is +* called before the event is ever recorded, it will use the current device. +* Later streams that record the event must match this device. +*/ +struct TORCH_ZOOM_API ZoomEvent { + // Constructors + // Default value for `flags` is specified below - it's cudaEventDisableTiming + ZoomEvent() noexcept = default; + ZoomEvent(unsigned int flags) noexcept : flags_{flags} {} + + ZoomEvent( + DeviceIndex device_index, const hipIpcEventHandle_t* handle) { + device_index_ = device_index; + c10::zoom::ZoomGuard guard(device_index_); + + C10_ZOOM_CHECK(hipIpcOpenEventHandle(&event_, *handle)); + is_created_ = true; + } + + // Note: event destruction done on creating device to avoid creating a + // CUDA context on other devices. + ~ZoomEvent() { + try { + if (is_created_) { + c10::zoom::ZoomGuard guard(device_index_); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion(DeviceType::PrivateUse1, reinterpret_cast(event_)); + } + C10_ZOOM_CHECK(hipEventDestroy(event_)); + } + } catch (...) 
{ /* No throw */ } + } + + ZoomEvent(const ZoomEvent&) = delete; + ZoomEvent& operator=(const ZoomEvent&) = delete; + + ZoomEvent(ZoomEvent&& other) noexcept { moveHelper(std::move(other)); } + ZoomEvent& operator=(ZoomEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator hipEvent_t() const { return event(); } + + // Less than operator (to allow use in sets) + friend bool operator<(const ZoomEvent& left, const ZoomEvent& right) { + return left.event_ < right.event_; + } + + optional device() const { + if (is_created_) { + return at::Device(DeviceType::PrivateUse1, device_index_); + } else { + return {}; + } + } + + bool isCreated() const { return is_created_; } + DeviceIndex device_index() const {return device_index_;} + hipEvent_t event() const { return event_; } + + // Note: hipEventQuery can be safely called from any device + bool query() const { + if (!is_created_) { + return true; + } + + hipError_t err = hipEventQuery(event_); + if (err == hipSuccess) { + return true; + } else if (err != hipErrorNotReady) { + C10_ZOOM_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)hipGetLastError(); + } + + return false; + } + + void record() { record(c10::zoom::getCurrentZoomStream()); } + + void recordOnce(const c10::zoom::ZoomStream& stream) { + if (!was_recorded_) record(stream); + } + + // Note: hipEventRecord must be called on the same device as the event. + void record(const c10::zoom::ZoomStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_, + " does not match recording stream's device ", stream.device_index(), "."); + c10::zoom::ZoomGuard guard(device_index_); + C10_ZOOM_CHECK(hipEventRecord(event_, stream)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record(DeviceType::PrivateUse1, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + was_recorded_ = true; + } + + // Note: hipStreamWaitEvent must be called on the same device as the stream. + // The event has no actual GPU resources associated with it. + void block(const c10::zoom::ZoomStream& stream) { + if (is_created_) { + c10::zoom::ZoomGuard guard(stream.device_index()); + C10_ZOOM_CHECK(hipStreamWaitEvent(stream, event_, 0)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait(DeviceType::PrivateUse1, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + } + } + + // Note: hipEventElapsedTime can be safely called from any device + float elapsed_time(const ZoomEvent& other) const { + TORCH_CHECK(is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + float time_ms = 0; + // We do not strictly have to set the device index to the same as our event, + // but if we don't and the current device is not initialized, it will + // create a new hip context, which will consume a lot of memory. 
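As a usage sketch of this event wrapper, timing a region on the current zoom stream looks like the following; the event flags must allow timing, so the default hipEventDisableTiming is overridden, and the kernels being timed are elided:

    at::zoom::ZoomEvent start(hipEventDefault), stop(hipEventDefault);
    c10::zoom::ZoomStream stream = c10::zoom::getCurrentZoomStream();
    start.record(stream);
    // ... enqueue the kernels to be timed on `stream` ...
    stop.record(stream);
    stop.synchronize();                       // block until the recorded region has completed
    float ms = start.elapsed_time(stop);      // hipEventElapsedTime under the hood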
+ c10::zoom::ZoomGuard guard(device_index_); + // raise hipErrorNotReady if either event is recorded but not yet completed + C10_ZOOM_CHECK(hipEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + // Note: hipEventSynchronize can be safely called from any device + void synchronize() const { + if (is_created_) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization(DeviceType::PrivateUse1, reinterpret_cast(event_)); + } + C10_ZOOM_CHECK(hipEventSynchronize(event_)); + } + } + + // Note: hipIpcGetEventHandle must be called on the same device as the event + void ipc_handle(hipIpcEventHandle_t * handle) { + if (!is_created_) { + // this ZoomEvent object was initially constructed from flags but event_ + // is not created yet. + createEvent(c10::zoom::getCurrentZoomStream().device_index()); + } + c10::zoom::ZoomGuard guard(device_index_); + C10_ZOOM_CHECK(hipIpcGetEventHandle(handle, event_)); + } + +private: + unsigned int flags_ = hipEventDisableTiming; + bool is_created_ = false; + bool was_recorded_ = false; + DeviceIndex device_index_ = -1; + hipEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + device_index_ = device_index; + c10::zoom::ZoomGuard guard(device_index_); + C10_ZOOM_CHECK(hipEventCreateWithFlags(&event_, flags_)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation(DeviceType::PrivateUse1, reinterpret_cast(event_)); + } + is_created_ = true; + } + + void moveHelper(ZoomEvent&& other) { + std::swap(flags_, other.flags_); + std::swap(is_created_, other.is_created_); + std::swap(was_recorded_, other.was_recorded_); + std::swap(device_index_, other.device_index_); + std::swap(event_, other.event_); + } +}; + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomGeneratorImpl.cpp b/aten/src/ATen/zoom/ZoomGeneratorImpl.cpp new file mode 100644 index 00000000000000..d0b9a5a963db95 --- /dev/null +++ b/aten/src/ATen/zoom/ZoomGeneratorImpl.cpp @@ -0,0 +1,512 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace zoom::detail { + +namespace { + +// Ensures we only call cudaGetDeviceCount only once. +static c10::once_flag num_gpu_init_flag; + +// Total number of gpus in the system. +static int64_t num_gpus; + +// Ensures default_gens_zoom is initialized once. +static std::deque zoom_gens_init_flag; + +// Default, global CUDA generators, one per GPU. +static std::vector default_gens_zoom; + +/* + * Populates the global variables related to CUDA generators + * Warning: this function must only be called once! + */ +static void initZoomGenVector() { + num_gpus = c10::zoom::device_count(); + zoom_gens_init_flag.resize(num_gpus); + default_gens_zoom.resize(num_gpus); +} + +} // anonymous namespace + +/** + * PyTorch maintains a collection of default generators that get + * initialized once. The purpose of these default generators is to + * maintain a global running state of the pseudo random number generation, + * when a user does not explicitly mention any generator. + * getDefaultZoomGenerator gets the default generator for a particular + * cuda device. 
+ */ +const Generator& getDefaultZoomGenerator(DeviceIndex device_index) { + c10::call_once(num_gpu_init_flag, initZoomGenVector); + DeviceIndex idx = device_index; + if (idx == -1) { + idx = c10::zoom::current_device(); + } else { + TORCH_CHECK(idx >= 0 && idx < num_gpus); + } + c10::call_once(zoom_gens_init_flag[idx], [&] { + default_gens_zoom[idx] = make_generator(idx); + default_gens_zoom[idx].seed(); + }); + return default_gens_zoom[idx]; +} + +// register to PrivateUse1 +REGISTER_GENERATOR_PRIVATEUSE1(getDefaultZoomGenerator); + +/** + * Utility to create a ZoomGeneratorImpl. Returns a shared_ptr + */ +Generator createZoomGenerator(DeviceIndex device_index) { + c10::call_once(num_gpu_init_flag, initZoomGenVector); + DeviceIndex idx = device_index; + if (idx == -1) { + idx = c10::zoom::current_device(); + } + TORCH_CHECK(idx >= 0 && idx < num_gpus, "The device_index is invalid."); + auto gen = make_generator(idx); + auto zoom_gen = check_generator(gen); + zoom_gen->set_current_seed(default_rng_seed_val); + zoom_gen->set_philox_offset_per_thread(0); + return gen; +} + +} // namespace zoom::detail + +/** + * Creates a clone of this CUDA Generator State. + */ +c10::intrusive_ptr ZoomGeneratorState::clone() { + return make_intrusive( + seed_, philox_offset_per_thread_, offset_intragraph_); +} + +/** + * Function to increase the internal offset based on the specified increment. + */ +void ZoomGeneratorState::increase(uint64_t increment) { + // Rounds increment up to the nearest multiple of 4 to meet alignment + // requirements. + // see Note [Why enforce RNG offset % 4 == 0?] + increment = ((increment + 3) / 4) * 4; + // Handling different behaviors based on whether capturing is active. + if (at::zoom::currentStreamCaptureStatus() != at::zoom::CaptureStatus::None) { + // Ensures that the state is actually capturing. + TORCH_CHECK( + capturing_, + "Attempt to increase offset for a Zoom generator not in capture mode."); + // Ensures the offset is a multiple of 4 + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT( + offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4."); + // Ensures the increment does not cause overflow. + TORCH_INTERNAL_ASSERT( + offset_intragraph_ <= std::numeric_limits::max() - increment, + "Increment causes overflow in the offset value."); + offset_intragraph_ += increment; + } else { + // Checks that the increment is expected outside graph capturing. + TORCH_CHECK( + !capturing_, + "Offset increment outside graph capture encountered unexpectedly."); + // Ensures the offset is a multiple of 4 + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT( + philox_offset_per_thread_ % 4 == 0, + "RNG offset must be a multiple of 4."); + philox_offset_per_thread_ += increment; + } +} + +/** + * Registers this state to a CUDA graph to manage within the graph. + */ +void ZoomGeneratorState::register_graph(zoom::HIPGraph* graph) { + // Ensures that the RNG state is not currently being captured. + at::zoom::assertNotCapturing( + "Cannot register the state during capturing stage."); + + // If this is the first graph to be registered, allocate memory for the seed + // and offset on the GPU. + if (registered_graphs_.empty()) { + auto options = at::TensorOptions().device(DeviceType::PrivateUse1).dtype(at::kLong); + seed_extragraph_ = at::empty({1}, options); + offset_extragraph_ = at::empty({1}, options); + } + + // Insert the graph into the set of registered graphs if it's not already + // registered. 
+ if (registered_graphs_.find(graph) == registered_graphs_.end()) { + registered_graphs_.insert(graph); + } +} + +/** + * Unregisters a CUDA graph from the RNG state. + */ +void ZoomGeneratorState::unregister_graph(zoom::HIPGraph* graph) { + // Ensures that the RNG state is not currently being captured. + at::zoom::assertNotCapturing( + "Cannot unregister the state during capturing stage."); + // Verify the graph was previously registered. + TORCH_CHECK( + registered_graphs_.find(graph) != registered_graphs_.end(), + "The graph should be registered to the state"); + + // Remove the graph from the set of registered graphs. + registered_graphs_.erase(graph); + + // If no more graphs are registered, deallocate the GPU memory for the seed + // and offset. + if (registered_graphs_.empty()) { + seed_extragraph_.reset(); + offset_extragraph_.reset(); + } +} + +/** + * Note [Explicit Registration of Generators to the CUDA Graph] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * + * Ideally, it would be more user-friendly if the state could be exchanged and generators + * could be registered with the CUDA graph implicitly. However, resetting GPU tensors during + * the capture stage causes these reset operations to be recorded within the CUDA graph. + * This behavior is undesirable because we do not want these tensors to be reset during + * the replay stage of the graph. + * + * As of now, there is no available method to perform a CUDA operation during the graph's + * recording phase without having that operation be included in the CUDA graph. + * This limitation necessitates explicit user action to register generators with the graph. + * By requiring users to manually register their generators, we can ensure that state resets + * (capture_prologue) only occur before the graph capture begins, thus avoiding unintended + * resets during the replay of the graph. See https://github.com/pytorch/pytorch/pull/114068. + */ + +/** + * Performs the prologue steps for capturing a CUDA graph state. + * This method is intended to reset graph-related state variables before capturing begins. + */ +void ZoomGeneratorState::capture_prologue() { + capturing_ = true; + offset_intragraph_ = 0; + seed_extragraph_.fill_(int64_t(seed_)); + offset_extragraph_.fill_(int64_t(0)); +} + +/** + * Ends the capturing phase and resets related variables, returning the whole + * graph increment. + */ +uint64_t ZoomGeneratorState::capture_epilogue() { + capturing_ = false; + return offset_intragraph_; +} + +/** + * Prepares the state for replay by setting initial state tensors and applying + * total increment. + */ +void ZoomGeneratorState::replay_prologue(uint64_t wholegraph_increment) { + // Ensures the generator is not in capturing mode. + at::zoom::assertNotCapturing( + "Cannot prepare for replay during capturing stage."); + seed_extragraph_.fill_(int64_t(seed_)); + offset_extragraph_.fill_(int64_t(philox_offset_per_thread_)); + // Applies the total increment achieved during previous captures to update the + // offset. + increase(wholegraph_increment); +} + +/** + * Note [Why enforce RNG offset % 4 == 0?] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Curand philox does allow offsets that aren't a multiple of 4. + * But jit kernels don't use curand, they use a custom "Philox" class (see + * torch/csrc/jit/tensorexpr/cuda_random.h or + * torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu). 
+ * The "Philox" constructor computes offset/4 (a uint64_t division) to locate its + * internal start in its virtual bitstream viewed as 128-bit chunks, then, when called + * in a thread, returns one 32-bit chunk at a time from that start in the bitstream. + * In other words, if the incoming offset is not a multiple of 4, each thread + * might repeat some previously-generated 32-bit values in the bitstream. See + * https://github.com/pytorch/pytorch/pull/50169. + */ + +/** + * ZoomGeneratorImpl class implementation + */ +ZoomGeneratorImpl::ZoomGeneratorImpl(DeviceIndex device_index) + : c10::GeneratorImpl{Device(DeviceType::PrivateUse1, device_index), + DispatchKeySet(c10::DispatchKey::PrivateUse1)} { + at::zoom::assertNotCapturing("Cannot construct a new ZoomGeneratorImpl"); + state_ = make_intrusive(); + no_reset_rnn_state_.clear(); +} + +ZoomGeneratorImpl::ZoomGeneratorImpl( + DeviceIndex device_index, + c10::intrusive_ptr state) + : c10:: + GeneratorImpl{Device(DeviceType::PrivateUse1, device_index), DispatchKeySet(c10::DispatchKey::PrivateUse1)}, + state_(std::move(state)) { + no_reset_rnn_state_.clear(); +} + +/** + * Sets the seed to be used by curandStatePhilox4_32_10 + * Resets the philox_offset_per_thread_ to 0 + * + * See Note [Acquire lock when using random generators] + */ +void ZoomGeneratorImpl::set_current_seed(uint64_t seed) { + at::zoom::assertNotCapturing( + "Cannot call ZoomGeneratorImpl::set_current_seed"); + state_->seed_ = seed; + state_->philox_offset_per_thread_ = 0; + no_reset_rnn_state_.clear(); +} + +/** + * Sets the offset to be used by curandStatePhilox4_32_10 + * + * See Note [Acquire lock when using random generators] + */ +void ZoomGeneratorImpl::set_offset(uint64_t offset) { + at::zoom::assertNotCapturing("Cannot call ZoomGeneratorImpl::set_offset"); + // the set function checks if the offset is a multiple of 4. + set_philox_offset_per_thread(offset); + no_reset_rnn_state_.clear(); +} + +/** + * Gets the current offset of ZoomGeneratorImpl. + */ +uint64_t ZoomGeneratorImpl::get_offset() const { + // Debatable if get_offset() should be allowed in captured regions. + // Conservatively disallow it for now. + at::zoom::assertNotCapturing("Cannot call ZoomGeneratorImpl::get_offset"); + return state_->philox_offset_per_thread_; +} + +/** + * Gets the current seed of ZoomGeneratorImpl. + */ +uint64_t ZoomGeneratorImpl::current_seed() const { + // Debatable if current_seed() should be allowed in captured regions. + // Conservatively disallow it for now. + at::zoom::assertNotCapturing("Cannot call ZoomGeneratorImpl::current_seed"); + return state_->seed_; +} + +/** + * Gets a nondeterministic random number from /dev/urandom or time, + * seeds the CPUGeneratorImpl with it and then returns that number. + * + * FIXME: You can move this function to Generator.cpp if the algorithm + * in getNonDeterministicRandom is unified for both CPU and CUDA + */ +uint64_t ZoomGeneratorImpl::seed() { + at::zoom::assertNotCapturing("Cannot call ZoomGeneratorImpl::seed"); + auto random = c10::detail::getNonDeterministicRandom(true); + this->set_current_seed(random); + return random; +} + +/** + * Gets the current internal state of ZoomGeneratorImpl. The internal + * state is returned as a CPU byte tensor. + */ +c10::intrusive_ptr ZoomGeneratorImpl::get_state() const { + // The RNG state comprises the seed, and an offset used for Philox. 
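+ // Byte layout of the returned CPU byte tensor (16 bytes total), mirroring the
+ // memcpy calls below:
+ //   bytes [0, 8)  -> seed          (uint64_t)
+ //   bytes [8, 16) -> philox offset (stored as int64_t)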
+ static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(int64_t); + static const size_t total_size = seed_size + offset_size; + + auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + auto rng_state = state_tensor.data_ptr(); + auto current_seed = this->current_seed(); + auto offset = static_cast(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic + memcpy(rng_state, ¤t_seed, seed_size); + memcpy(rng_state + seed_size, &offset, offset_size); + + return state_tensor.getIntrusivePtr(); +} + +/** + * Sets the internal state of ZoomGeneratorImpl. The new internal state + * must be a strided CPU byte tensor and have appropriate size. See + * comments of ZoomGeneratorImpl::state for information about the layout + * and size of the internal state. + */ +void ZoomGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + at::zoom::assertNotCapturing( + "Please ensure to utilize the ZoomGeneratorImpl::set_state_index method during capturing."); + static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(int64_t); + static const size_t total_size = seed_size + offset_size; + + detail::check_rng_state(new_state); + + bool no_philox_seed = false; + auto new_state_size = new_state.numel(); + if (new_state_size == total_size - offset_size) { + no_philox_seed = true; + } else { + TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); + } + + uint64_t input_seed = 0; + auto new_rng_state = new_state.data_dtype_initialized(); + memcpy(&input_seed, new_rng_state, seed_size); + this->set_current_seed(input_seed); + int64_t philox_offset = 0; + if (!no_philox_seed) { + memcpy(&philox_offset, new_rng_state + seed_size, offset_size); + } + this->set_philox_offset_per_thread(static_cast(philox_offset)); +} + +/** + * Sets the generator's current state to + * This function allows switching between different registered states of + * the generator. + */ +void ZoomGeneratorImpl::graphsafe_set_state( + const c10::intrusive_ptr& gen) { + c10::intrusive_ptr zoom_gen = + dynamic_intrusive_pointer_cast(gen); + TORCH_CHECK(zoom_gen, "Expected a Zoom Generator"); + state_ = zoom_gen->state_; +} + +/** + * Get the GeneratorImpl that point to current state_ + */ +c10::intrusive_ptr ZoomGeneratorImpl::graphsafe_get_state() + const { + auto gen = make_intrusive(device().index(), state_); + return gen; +} + +/** + * Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10 + * + * See Note [Acquire lock when using random generators] + */ +void ZoomGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) { + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4"); + state_->philox_offset_per_thread_ = offset; +} + +/** + * Gets the current philox_offset_per_thread_ of ZoomGeneratorImpl. + */ +uint64_t ZoomGeneratorImpl::philox_offset_per_thread() const { + return state_->philox_offset_per_thread_; +} + +/** + * Registers this state to a CUDA graph to manage within the graph. + */ +void ZoomGeneratorImpl::register_graph(zoom::HIPGraph* graph) { + graph->register_generator_state(state_); + state_->register_graph(graph); +} + +/** + * Unregisters a CUDA graph from the RNG state. 
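+ * Delegates to the underlying ZoomGeneratorState, which removes the graph from
+ * its registered set and frees the extragraph seed/offset tensors once no
+ * graphs remain registered.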
+ */ +void ZoomGeneratorImpl::unregister_graph(zoom::HIPGraph* graph) { + state_->unregister_graph(graph); +} + +/** + * Gets the seed and philox offset value to be used in + * curandStatePhilox4_32_10, in an opaque PhiloxHIPState that's safe + * and can be used non-divergently in callers whether CUDA graph + * capture is underway or not. See + * Note [CUDA Graph-safe RNG states] + * + * Each kernel using philox has to sensibly increment offset + * for future users of philox. So it gets the "old" value for + * itself (before add), and tells subsequent users which offset + * they should use, since only the kernel knows how many randoms + * it intends to generate. + * + * Increment should be at least the number of curand() random numbers used in + * each thread. It is the user's responsibility to make sure the increment + * for philox is never smaller than the number of curand() calls. Increment + * value > the number of curand() calls won't harm but anything less would mean + * that you would be reusing random values from previous calls. + * + * See Note [Acquire lock when using random generators] + */ +PhiloxHIPState ZoomGeneratorImpl::philox_hip_state(uint64_t increment) { + if (at::zoom::currentStreamCaptureStatus() != at::zoom::CaptureStatus::None) { + uint32_t offset = state_->offset_intragraph_; + state_->increase(increment); + return PhiloxHIPState( + state_->seed_extragraph_.data_ptr(), + state_->offset_extragraph_.data_ptr(), + offset); + } else { + uint64_t offset = state_->philox_offset_per_thread_; + state_->increase(increment); + return PhiloxHIPState(state_->seed_, offset); + } +} + +/** + * Temporarily accommodates call sites that use philox_engine_inputs. + * Allows incremental refactor of call sites to use philox_hip_state. + */ +std::pair ZoomGeneratorImpl::philox_engine_inputs( + uint64_t increment) { + at::zoom::assertNotCapturing( + "Refactor this op to use ZoomGeneratorImpl::philox_hip_state. Cannot call ZoomGeneratorImpl::philox_engine_inputs"); + uint64_t offset = state_->philox_offset_per_thread_; + state_->increase(increment); + return std::make_pair(state_->seed_, offset); +} + +/* + * Gets the DeviceType of ZoomGeneratorImpl. + * Used for type checking during run time. + */ +DeviceType ZoomGeneratorImpl::device_type() { + return DeviceType::PrivateUse1; +} + +/** + * Public clone method implementation + * + * See Note [Acquire lock when using random generators] + */ +std::shared_ptr ZoomGeneratorImpl::clone() const { + return std::shared_ptr(this->clone_impl()); +} + +/** + * Private clone method implementation + * + * See Note [Acquire lock when using random generators] + */ +ZoomGeneratorImpl* ZoomGeneratorImpl::clone_impl() const { + at::zoom::assertNotCapturing("Cannot call ZoomGeneratorImpl::clone_impl"); + auto gen = new ZoomGeneratorImpl(this->device().index(), state_->clone()); + return gen; +} + +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomGeneratorImpl.h b/aten/src/ATen/zoom/ZoomGeneratorImpl.h new file mode 100644 index 00000000000000..106432ec428fa2 --- /dev/null +++ b/aten/src/ATen/zoom/ZoomGeneratorImpl.h @@ -0,0 +1,181 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +namespace at { + +namespace zoom { +struct HIPGraph; +} + +/** + * Note [CUDA Graph-safe RNG states] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * + * Strategy: + * ~~~~~~~~~ + * (It helps to look at + * cuda/detail/PhiloxCudaStateRaw.cuh and + * cuda/detail/UnpackRaw.cuh + * while you read this.) 
+ * + * A CUDA graph containing multiple RNG ops behaves like a + * single giant kernel from the perspective of ops external + * to the graph. During graph capture, logic in ZoomGeneratorImpl + * records the total of all offset increments that occur in the + * graphed region, and records the final total as the offset for + * the entire graph. + * + * When the graph reruns, the logic that reruns it + * increments this device's CUDA generator's offset + * by that total. + * + * Meanwhile, within the graph, at capture time, instead of + * populating PhiloxCudaStates with the uint64_t offset pulled + * directly from the global state, PhiloxHIPState uses a pointer + * to a one-element stream-local int64_t device tensor + * holding an initial offset value, and a uint64_t holding an + * intra-graph offset. (The intra-graph offset starts from zero + * when capture begins.) In each consumer kernel, + * at::zoom::philox::unpack computes the offset to use for this kernel + * as intra-graph offset + *initial offset. + * + * When the graph reruns, the logic that reruns it first + * fill_s the initial offset tensor with this device's + * CUDA generator's current offset. + * + * The control flow above ensures graphed execution is bitwise + * identical to eager execution as long as RNG ops are enqueued + * from a single thread, even if RNG ops and graphs containing + * RNG ops are enqueued and run simultaneously on multiple streams. + * + * Usage: + * ~~~~~~ + * PhiloxHIPState in this file, and unpack() in + * cuda/CUDAGraphsUtils.cuh allow non-divergent use of + * ZoomGeneratorImpl whether graph capture is underway or not. + * + * Each PhiloxHIPState instance should be used for one and only one + * consumer kernel. + * + * Example (see e.g. native/cuda/Dropout.cu): + * + * #include + * #include + * + * __global__ void kernel(..., PhiloxHIPState philox_args) { + * auto seeds = at::zoom::philox::unpack(philox_args); + * IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + * curandStatePhilox4_32_10_t state; + * curand_init(std::get<0>(seeds), // seed + * idx, // per-thread subsequence + * std::get<1>(seeds), // offset in subsequence + * &state); + * ... + * } + * + * host_caller(...) { + * PhiloxHIPState rng_engine_inputs; + * { + * // See Note [Acquire lock when using random generators] + * std::lock_guard lock(gen->mutex_); + * + * // gen could be HostState or DevState here! No divergent code needed! 
+ * rng_engine_inputs = gen->philox_hip_state(offset_increment); + * } + * kernel<<<...>>>(..., rng_engine_inputs); + * } + * + */ + +struct ZoomGeneratorState : public c10::intrusive_ptr_target { + uint64_t seed_; + uint64_t philox_offset_per_thread_; + uint32_t offset_intragraph_; + bool capturing_{}; + std::unordered_set registered_graphs_; + at::TensorBase seed_extragraph_{}; + at::TensorBase offset_extragraph_{}; + + ZoomGeneratorState( + uint64_t seed = default_rng_seed_val, + uint64_t philox_offset_per_thread = 0, + uint32_t offset_intragraph = 0) + : seed_(seed), + philox_offset_per_thread_(philox_offset_per_thread), + offset_intragraph_(offset_intragraph) {} + + void increase(uint64_t increment); + + void register_graph(zoom::HIPGraph* graph); + void unregister_graph(zoom::HIPGraph* graph); + + void capture_prologue(); + // capture_epilogue returns the wholegraph_increment + uint64_t capture_epilogue(); + void replay_prologue(uint64_t wholegraph_increment); + c10::intrusive_ptr clone(); +}; + +struct TORCH_ZOOM_API ZoomGeneratorImpl : public c10::GeneratorImpl { + // Constructors + ZoomGeneratorImpl(DeviceIndex device_index = -1); + ZoomGeneratorImpl( + DeviceIndex device_index, + c10::intrusive_ptr state_); + ~ZoomGeneratorImpl() override = default; + + // ZoomGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + void graphsafe_set_state( + const c10::intrusive_ptr& state) override; + c10::intrusive_ptr graphsafe_get_state() const override; + + void set_philox_offset_per_thread(uint64_t offset); + uint64_t philox_offset_per_thread() const; + + void register_graph(zoom::HIPGraph* graph); + void unregister_graph(zoom::HIPGraph* graph); + + // Generates a PhiloxHIPState with a specified increment, and increment + // current state + PhiloxHIPState philox_hip_state(uint64_t increment); + + bool reset_rnn_state() { + return !no_reset_rnn_state_.test_and_set(); + } + + // Temporarily accommodates call sites that use philox_engine_inputs. + // Allows incremental refactor of call sites to use philox_hip_state. + std::pair philox_engine_inputs(uint64_t increment); + + static c10::DeviceType device_type(); + + private: + ZoomGeneratorImpl* clone_impl() const override; + + c10::intrusive_ptr state_; + std::atomic_flag no_reset_rnn_state_; +}; + +namespace zoom::detail { + +TORCH_ZOOM_API const Generator& getDefaultZoomGenerator( + DeviceIndex device_index = -1); +TORCH_ZOOM_API Generator createZoomGenerator(DeviceIndex device_index = -1); + +} // namespace zoom::detail +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/zoom/cub-RadixSortKeys.cu b/aten/src/ATen/zoom/cub-RadixSortKeys.cu new file mode 100644 index 00000000000000..a18326a1daacee --- /dev/null +++ b/aten/src/ATen/zoom/cub-RadixSortKeys.cu @@ -0,0 +1,59 @@ +// !!! This is a file automatically generated by hipify!!! 
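+// This translation unit holds the explicit instantiations of radix_sort_keys
+// declared in cub.h. A call-site sketch (hypothetical caller code; the device
+// pointers and element count are assumed to exist already):
+//   at::zoom::hipcub::radix_sort_keys<int64_t>(
+//       keys_in, keys_out, n,
+//       /*descending=*/false, /*begin_bit=*/0, /*end_bit=*/sizeof(int64_t) * 8);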
+#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at::zoom::hipcub { + +template +void radix_sort_keys( + const key_t* keys_in, + key_t* keys_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::hip_type::type; + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + HIPCUB_WRAPPER( + NO_ROCM(at_zoom_detail)::hipcub::DeviceRadixSort::SortKeysDescending, + keys_in_, + keys_out_, + n, + begin_bit, + end_bit, + c10::zoom::getCurrentZoomStream()); + } else { + HIPCUB_WRAPPER( + NO_ROCM(at_zoom_detail)::hipcub::DeviceRadixSort::SortKeys, + keys_in_, + keys_out_, + n, + begin_bit, + end_bit, + c10::zoom::getCurrentZoomStream()); + } +} + +#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \ + template void radix_sort_keys( \ + const scalar_t* keys_in, \ + scalar_t* keys_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES) +AT_INSTATIATE_CUB_TEMPLATES(uint16_t, UInt16) +AT_INSTATIATE_CUB_TEMPLATES(uint32_t, UInt32) +AT_INSTATIATE_CUB_TEMPLATES(uint64_t, UInt64) + +} // namespace at::zoom::hipcub diff --git a/aten/src/ATen/zoom/cub-RadixSortPairs.cu b/aten/src/ATen/zoom/cub-RadixSortPairs.cu new file mode 100644 index 00000000000000..ef81eb365f1c9b --- /dev/null +++ b/aten/src/ATen/zoom/cub-RadixSortPairs.cu @@ -0,0 +1,86 @@ +// !!! This is a file automatically generated by hipify!!! +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at::zoom::hipcub::detail { + +template +void radix_sort_pairs_impl( + const key_t* keys_in, + key_t* keys_out, + const OpaqueType* values_in, + OpaqueType* values_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::hip_type::type; + + auto allocator = c10::zoom::ZoomCachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(n * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + HIPCUB_WRAPPER( + NO_ROCM(at_zoom_detail)::hipcub::DeviceRadixSort::SortPairsDescending, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::zoom::getCurrentZoomStream()); + } else { + HIPCUB_WRAPPER( + NO_ROCM(at_zoom_detail)::hipcub::DeviceRadixSort::SortPairs, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::zoom::getCurrentZoomStream()); + } +} + +#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \ + template void radix_sort_pairs_impl( \ + const key_t* keys_in, \ + key_t* keys_out, \ + const OpaqueType* values_in, \ + OpaqueType* values_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) +AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) +AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) + +#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ + 
AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) +AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8) +AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8) +AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8) +AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) + +} // namespace at::zoom::hipcub::detail diff --git a/aten/src/ATen/zoom/cub.cu b/aten/src/ATen/zoom/cub.cu new file mode 100644 index 00000000000000..f00caf3675f20a --- /dev/null +++ b/aten/src/ATen/zoom/cub.cu @@ -0,0 +1,51 @@ +// !!! This is a file automatically generated by hipify!!! +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at::zoom::hipcub { + +namespace { +template +struct SumOp { + __device__ scalar_t operator () (scalar_t a, scalar_t b) const { + return a + b; + } +}; +} + +template +void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t num_items) { + using NO_ROCM(at_zoom_detail)::hipcub::Sum; + inclusive_scan(input, output, Sum{}, num_items); +} + +template void inclusive_sum_truncating(const int32_t *input, int32_t *output, int64_t num_items); +template void inclusive_sum_truncating(const int64_t *input, int64_t *output, int64_t num_items); +template void inclusive_sum_truncating(const int32_t *input, int64_t *output, int64_t num_items); + +template +void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t num_items) { + using scalar_t = std::common_type_t; + exclusive_scan(input, output, SumOp{}, scalar_t(0), num_items); +} + +template void exclusive_sum_in_common_type(const int32_t *input, int32_t *output, int64_t num_items); +template void exclusive_sum_in_common_type(const int64_t *input, int64_t *output, int64_t num_items); + +namespace { +struct CountMaskOp { + __device__ int64_t operator() (const uint8_t &x) const { + return x != 0; + } +}; +} + +void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n) { + CountMaskOp op{}; + auto iter = NO_ROCM(at_zoom_detail)::hipcub::TransformInputIterator< + bool, decltype(op), decltype(mask)>(mask, op); + exclusive_scan(iter, output_idx, SumOp{}, int64_t{0}, n); +} + +} // namespace at::zoom::hipcub diff --git a/aten/src/ATen/zoom/cub.cuh b/aten/src/ATen/zoom/cub.cuh new file mode 100644 index 00000000000000..331f98301eca4e --- /dev/null +++ b/aten/src/ATen/zoom/cub.cuh @@ -0,0 +1,284 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#pragma once +#include + +#include +#include +#include +#include + +#include + +#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE() + +#include + +#else + +// include cub in a safe manner, see: +// https://github.com/pytorch/pytorch/pull/55292 +#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings +#undef CUB_NS_PREFIX +#undef CUB_NS_QUALIFIER +#define CUB_NS_PREFIX namespace at_zoom_detail { +#define CUB_NS_POSTFIX } +#define CUB_NS_QUALIFIER ::at_zoom_detail::hipcub +#include +#undef CUB_NS_POSTFIX +#undef CUB_NS_PREFIX +#undef CUB_NS_QUALIFIER + +#endif + +#include +#include +#include + +// handle the temporary storage and 'twice' calls for cub API +#define HIPCUB_WRAPPER(func, ...) 
do { \ + size_t temp_storage_bytes = 0; \ + func(nullptr, temp_storage_bytes, __VA_ARGS__); \ + auto& caching_allocator = *::c10::zoom::ZoomCachingAllocator::get(); \ + auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \ + func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \ + C10_ZOOM_CHECK(hipGetLastError()); \ +} while (false) + +#define NO_ROCM(x) +#define ROCM_HIPCUB(x) ::hipcub + + +// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 + +template <> +struct ROCM_HIPCUB(cub)::FpLimits +{ + static __host__ __device__ __forceinline__ c10::BFloat16 Max() { + unsigned short max_word = 0x7F7F; + return reinterpret_cast(max_word); + } + + static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() { + unsigned short lowest_word = 0xFF7F; + return reinterpret_cast(lowest_word); + } +}; + +template <> +struct ROCM_HIPCUB(cub)::NumericTraits: + ROCM_HIPCUB(cub)::BaseTraits {}; + + + +namespace at::zoom::hipcub { + +namespace detail { + +template +struct hip_type { + using type = T; +}; +template<> +struct hip_type { + using type = __half; +}; + +template<> +struct hip_type { + using type = hip_bfloat16; +}; + + +} // namespace detail + +template +inline void segmented_sort_pairs( + const key_t *keys_in, key_t *keys_out, + const value_t *values_in, value_t *values_out, + int64_t num_elements, int64_t num_segments, + OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets, + bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8 +) { + TORCH_CHECK(num_elements <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + TORCH_CHECK(num_segments <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::hip_type::type; + + auto allocator = c10::zoom::ZoomCachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(num_elements * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_ *keys_in_ = reinterpret_cast(keys_in); + key_t_ *keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + HIPCUB_WRAPPER(NO_ROCM(at_zoom_detail)::hipcub::DeviceSegmentedRadixSort::SortPairsDescending, + keys_in_, keys_out_, values_in, values_out, + num_elements, num_segments, begin_offsets, end_offsets, + begin_bit, end_bit, c10::zoom::getCurrentZoomStream()); + } else { + HIPCUB_WRAPPER(NO_ROCM(at_zoom_detail)::hipcub::DeviceSegmentedRadixSort::SortPairs, + keys_in_, keys_out_, values_in, values_out, + num_elements, num_segments, begin_offsets, end_offsets, + begin_bit, end_bit, c10::zoom::getCurrentZoomStream()); + } +} + +#if CUB_SUPPORTS_UNIQUE_BY_KEY() +template +inline void unique_by_key( + KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, + KeysOutputIteratorT keys_out, ValuesOutputIteratorT values_out, + NumSelectedIteratorT num_selected, int64_t num_input_items) +{ + // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed. 
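+ // Until that is available: DeviceSelect::UniqueByKey always needs a real
+ // output range for the keys, so when the caller passes keys_out == nullptr we
+ // allocate a scratch key buffer from the caching allocator and discard it
+ // once the call returns.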
+ constexpr bool null_keys_out = std::is_same::value; + using KeyT = typename std::iterator_traits::value_type; + using RealKeysOutputIteratorT = typename std::conditional::type; + RealKeysOutputIteratorT keys_out_; + auto allocator = c10::zoom::ZoomCachingAllocator::get(); + c10::DataPtr keys_out_owner; + if constexpr (null_keys_out) { + keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT)); + keys_out_ = static_cast(keys_out_owner.get()); + } else { + keys_out_ = keys_out; + } + HIPCUB_WRAPPER(NO_ROCM(at_zoom_detail)::hipcub::DeviceSelect::UniqueByKey, + keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::zoom::getCurrentZoomStream()); +} +#endif + +namespace impl { + +template +C10_LAUNCH_BOUNDS_1(1) +__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){ + // NOTE: out here not the final scan output, but an intermediate of the accumulation type. + using acc_t = typename std::iterator_traits::value_type; + *out = scan_op(static_cast(*a), static_cast(*b)); +} + +#if !CUB_SUPPORTS_FUTURE_VALUE() +template +struct chained_iterator { + using iterator_category = std::random_access_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = ValueT; + using pointer = ValueT*; + using reference = ValueT&; + + InputIteratorT iter; + ValueT *first; + difference_type offset = 0; + + __device__ ValueT operator[](difference_type i) { + i += offset; + if (i == 0) { + return *first; + } else { + return ValueT(iter[i - 1]); + } + } + __device__ chained_iterator operator+(difference_type i) { + return chained_iterator{iter, first, i}; + } + __device__ ValueT operator*() { + return (*this)[0]; + } +}; +#endif + +// even though cub is supposed to support tensors with int_max elements, in reality it doesn't, +// so split at int_max/2 +constexpr int max_cub_size = std::numeric_limits::max() / 2 + 1; // 2**30 +} + +// non synchronizing cub call +// even though cub is supposed to support tensors with int_max elements, in reality it doesn't, +// so split at int_max/2 +template +inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) { + //For ROCm, use hipCUB chained iterators + HIPCUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan, + input, + output, + scan_op, + num_items, + c10::zoom::getCurrentZoomStream()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) { + //For ROCm, use hipCUB chained iterators + HIPCUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan, + input, + output, + scan_op, + init_value, + num_items, + c10::zoom::getCurrentZoomStream()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + +} + +#if CUB_SUPPORTS_SCAN_BY_KEY() + +template +inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub InclusiveSumByKey does not support more than INT_MAX elements"); + HIPCUB_WRAPPER(at_zoom_detail::hipcub::DeviceScan::InclusiveSumByKey, + keys, input, output, num_items, at_zoom_detail::hipcub::Equality(), c10::zoom::getCurrentZoomStream()); +} + +template +inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + 
"cub InclusiveSumByKey does not support more than INT_MAX elements"); + HIPCUB_WRAPPER(at_zoom_detail::hipcub::DeviceScan::InclusiveScanByKey, + keys, input, output, scan_op, num_items, at_zoom_detail::hipcub::Equality(), c10::zoom::getCurrentZoomStream()); +} + +#endif + +template +void unique(InputIteratorT input, OutputIteratorT output, + NumSelectedIteratorT num_selected_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub unique does not support more than INT_MAX elements"); + HIPCUB_WRAPPER(NO_ROCM(at_zoom_detail)::hipcub::DeviceSelect::Unique, + input, output, num_selected_out, num_items, c10::zoom::getCurrentZoomStream()); +} + +template +void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out, + LengthOutputIteratorT length_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub run_length_encode does not support more than INT_MAX elements"); + HIPCUB_WRAPPER( + NO_ROCM(at_zoom_detail)::hipcub::DeviceRunLengthEncode::Encode, + input, output, counts_out, length_out, num_items, + c10::zoom::getCurrentZoomStream()); +} + +template +void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub reduce does not support more than INT_MAX elements"); + HIPCUB_WRAPPER( + NO_ROCM(at_zoom_detail)::hipcub::DeviceReduce::Reduce, + input, output, num_items, op, init, + c10::zoom::getCurrentZoomStream()); + +} + +} // namespace at::zoom::hipcub diff --git a/aten/src/ATen/zoom/cub.h b/aten/src/ATen/zoom/cub.h new file mode 100644 index 00000000000000..c38b12526cfc6a --- /dev/null +++ b/aten/src/ATen/zoom/cub.h @@ -0,0 +1,88 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once +#include +#include +#include + +// NOTE: These templates are intentionally not defined in this header, +// which aviods re-compiling them for each translation unit. If you get +// a link error, you need to add an explicit instantiation for your +// types in cub.cu + +namespace at::zoom::hipcub { + +inline int get_num_bits(uint64_t max_key) { + int num_bits = 1; + while (max_key > 1) { + max_key >>= 1; + num_bits++; + } + return num_bits; +} + +namespace detail { + +// radix_sort_pairs doesn't interact with value_t other than to copy +// the data, so we can save template instantiations by reinterpreting +// it as an opaque type. +template struct alignas(N) OpaqueType { char data[N]; }; + +template +void radix_sort_pairs_impl( + const key_t *keys_in, key_t *keys_out, + const OpaqueType *values_in, OpaqueType *values_out, + int64_t n, bool descending, int64_t begin_bit, int64_t end_bit); + +} // namespace detail + +template +void radix_sort_pairs( + const key_t *keys_in, key_t *keys_out, + const value_t *values_in, value_t *values_out, + int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) { + static_assert(std::is_trivially_copyable::value || + AT_ROCM_ENABLED(), // ROCm incorrectly fails this check for vector types + "radix_sort_pairs value type must be trivially copyable"); + // Make value type opaque, so all inputs of a certain size use the same template instantiation + using opaque_t = detail::OpaqueType; + static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0, + "This size of value_t is not instantiated. 
Please instantiate it in cub.cu" + " and modify this check."); + static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned"); + detail::radix_sort_pairs_impl( + keys_in, keys_out, + reinterpret_cast(values_in), + reinterpret_cast(values_out), + n, descending, begin_bit, end_bit); +} + +template +void radix_sort_keys( + const key_t *keys_in, key_t *keys_out, + int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8); + +// NOTE: Intermediate sums will be truncated to input_t precision +template +void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n); + +template +void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) { + return inclusive_sum_truncating(input, output, n); +} + +// NOTE: Sums are done is common_type +template +void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n); + +template +void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) { + return exclusive_sum_in_common_type(input, output, n); +} + +void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n); +inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) { + return mask_exclusive_sum( + reinterpret_cast(mask), output_idx, n); +} + +} // namespace at::zoom::hipcub diff --git a/aten/src/ATen/zoom/cub_definitions.cuh b/aten/src/ATen/zoom/cub_definitions.cuh new file mode 100644 index 00000000000000..c199557279519d --- /dev/null +++ b/aten/src/ATen/zoom/cub_definitions.cuh @@ -0,0 +1,27 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once + +#define CUB_VERSION 0 + +#define CUB_SUPPORTS_NV_BFLOAT16() false + +// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: +// https://github.com/NVIDIA/cub/pull/326 +// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake +// starting from CUDA 11.5 +#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE) +#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true +#else +#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false +#endif + + +#define CUB_SUPPORTS_UNIQUE_BY_KEY() false + + +#define CUB_SUPPORTS_SCAN_BY_KEY() 0 + + + +#define CUB_SUPPORTS_FUTURE_VALUE() false + diff --git a/aten/src/ATen/zoom/detail/DeviceThreadHandles.h b/aten/src/ATen/zoom/detail/DeviceThreadHandles.h new file mode 100644 index 00000000000000..1b7ba32607499c --- /dev/null +++ b/aten/src/ATen/zoom/detail/DeviceThreadHandles.h @@ -0,0 +1,151 @@ +// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states. +// These handles are tied to device, and these libraries requires/recommends not to +// share handles across host threads. +// +// These libraries recommend using one handle per host thread. We may not want to do +// this because threads are relatively light-weight, but creating and destroying +// handles is expensive (destroying the handle causes synchronizations). DataParallel, +// for example, creates new threads for each forward pass. +// +// This file implements a handle pool mechanism. The handle pool returns handles on +// demand as threads request them. If all existing handles in the pool are in use, +// it creates a new one. As threads terminate, they release handles back into the pool. +// In this way, the handle pool never creates more handles than the high-water mark of +// active threads, so it's efficient with DataParallel. 
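+//
+// A usage sketch (hypothetical handle type and create/destroy functions, for
+// illustration only):
+//   using BlasPool = DeviceThreadHandlePool<hipblasHandle_t,
+//                                           createBlasHandle,
+//                                           destroyBlasHandle>;
+//   static auto pool = std::make_shared<BlasPool>();
+//   // one PoolWindow per thread; its destructor returns handles to the pool
+//   thread_local std::unique_ptr<BlasPool::PoolWindow> window(pool->newPoolWindow());
+//   hipblasHandle_t handle = window->reserve(device_index);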
+ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace at::zoom { namespace { + +template +struct DeviceThreadHandlePool : public std::enable_shared_from_this> { + + struct Handle { + Handle_t handle; + Handle(bool create = false) : handle(nullptr) + { + if(create) Create(&handle); + } + // std::vector.emplace() and push_back() may route through temporaries and call + // copy/move constructors along the way. If this is the case, we don't want + // the destructors of temporaries to call cudnnDestroy on the handle. + // We can achieve safety (for the narrow case of stashing within std::vectors) + // by making Handle moveable but not copyable, and transferring handle ownership + // to the latest constructed object. This is not a substitute for full-blown + // reference counting, but reference counting may be overkill here. + // Another alternative is to wrap the saved Handles in unique_ptrs, i.e., + // unordered_map>> created_handles; + Handle(const Handle& rhs) = delete; + // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom + Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); } + // operator= takes argument by value + Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; } + ~Handle() { + if(handle) Destroy(handle); + } + }; + + std::mutex mutex; + + // Handles are lazily created as different threads request them, + // but are never destroyed until the end of the process. + // The maximum number of handles this process will create for each device is equal + // to the high-water mark of the number of concurrently active threads that request + // handles for that device. + // When threads terminate, they release their handles back into the pool for reuse. + // Otherwise, new handles would be created every time new threads were spawned, + // resulting in poor performance for Python modules that repeatedly or frequently + // spawned new sets of threads (like DataParallel, which creates a new set of threads + // for each forward pass). + // + // To prevent potential deadlocks, we explicitly choose not to cap the number + // of handles that are created per device. + // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device, + // only 4 can make forward progress at any time. The other 4 will not release their + // handles until they exit, so the fifth cannot make progress until then. This is + // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an + // intermediate point (ie, before any of them have exited). We have no way to anticipate + // or enforce that user threads will not attempt such intermediate synchronization. + // The only way to ensure safety is to avoid imposing a cap on the number of handles. + std::unordered_map> created_handles; + std::unordered_map> available_handles; + + // PoolWindow lazily creates and caches the handles that a particular thread is using, + // so in the common case handle access doesn't incur either handle creation or a mutex lock. + class PoolWindow + { + public: + PoolWindow(std::shared_ptr parent): weak_parent(std::move(parent)) {} + ~PoolWindow(){ release(); } + + Handle_t reserve(int device) + { + // If this thread already has a handle for this device, return it + if(my_handles.find(device) != my_handles.end()) + return my_handles[device]; + + // otherwise, either grab a handle from the pool if one is available, + // or if not, create a new one. 
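+ // Pin the pool through the weak_ptr before touching shared state; the
+ // TORCH_CHECK below guards against a window outliving its pool during
+ // program termination.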
+ auto parent = weak_parent.lock(); + TORCH_CHECK(parent, "Cannot create handle during program termination"); + std::lock_guard guard(parent->mutex); + + if(parent->available_handles[device].size() > 0) + { + my_handles[device] = parent->available_handles[device].back(); + parent->available_handles[device].pop_back(); + } + else + { + // In local testing, I do observe that emplace_back sometimes routes through temporaries + // that incur move-constructor and destructor calls. See comments in Handle above. + parent->created_handles[device].emplace_back(true /*create*/); + my_handles[device] = parent->created_handles[device].back().handle; + } + + return my_handles[device]; + } + + private: + // Stores the per-device handles currently owned by this thread + std::unordered_map my_handles; + + std::weak_ptr weak_parent; + + // Called by the destructor. Releases this thread's handles back into the pool. + void release() { + if(my_handles.size() > 0) { + auto parent = weak_parent.lock(); + if (!parent) { + // If this thread exits after atexit handlers have completed, the + // cuda context itself may be invalid, so we must leak the handles. + return; + } + + std::lock_guard guard(parent->mutex); + for(auto d_h : my_handles) + parent->available_handles[d_h.first].push_back(d_h.second); + } + } + }; + + // Warning: + // If you want to change this function, be aware that this function will be called + // by multiple threads and there is no mutex guarding the call of this function, so + // make sure your implementation is thread-safe. + PoolWindow *newPoolWindow() { + // The returned pointer will be owned by a thread local variable + // so that different threads does not share the same PoolWindow. + return new PoolWindow(this->shared_from_this()); + } +}; + +}} // namespace at::zoom::detail:: diff --git a/aten/src/ATen/zoom/detail/IndexUtils.cu b/aten/src/ATen/zoom/detail/IndexUtils.cu new file mode 100644 index 00000000000000..7e643871c6031d --- /dev/null +++ b/aten/src/ATen/zoom/detail/IndexUtils.cu @@ -0,0 +1,75 @@ +#include +#include + +namespace at { +namespace zoom { +namespace detail { + +struct SizeAndStride { + int64_t size; + int64_t stride; +}; + +/* + A comparator that will sort SizeAndStride structs by stride, + in ascending order. + */ + int compareSizeAndStride(const void* a, const void* b) { + const SizeAndStride* aS = (const SizeAndStride*) a; + const SizeAndStride* bS = (const SizeAndStride*) b; + + if (aS->stride < bS->stride) return -1; + if (aS->stride == bS->stride) return 0; + return 1; +} + +/* +Returns false if there is no possibility that the tensor +has "overlapping" indices and true otherwise. +"Overlapping" indices are two+ valid indices that specify +the same offset within the tensor. +The function does this by checking for a sufficient but not +necessary condition of no overlap. In particular, that +that there exists an ordering of the tensor's dimensions +that is nicely "nested," with each dimension contained +within the next one. +*/ +bool maybeOverlappingIndices(const TensorBase& t) { + /* Extract size/stride arrays; only consider size >1 dims. */ + std::vector info(t.dim()); + int dims = t.dim(); + int nonSize1Dims = 0; + for (int i = 0; i < dims; ++i) { + int64_t size = t.size(i); + if (size > 1) { + info[nonSize1Dims].size = size; + info[nonSize1Dims].stride = t.stride(i); + + if (info[nonSize1Dims].stride < 1) { + return true; + } + + ++nonSize1Dims; + } + } + + // Short-circuits if tensor is a single element. 
+ if (nonSize1Dims == 0) { + return false; + } + + /* Ascending order (innermost dimension in sorted view is at [0]) */ + qsort(info.data(), nonSize1Dims, sizeof(SizeAndStride), compareSizeAndStride); + + for (int i = 0; i < (nonSize1Dims - 1); ++i) { + if (((info[i].size - 1) * info[i].stride) >= info[i + 1].stride) { + return true; + } + } + + return false; +} + +} // detail +} // zoom +} // at diff --git a/aten/src/ATen/zoom/detail/IndexUtils.cuh b/aten/src/ATen/zoom/detail/IndexUtils.cuh new file mode 100644 index 00000000000000..a3739645b6b427 --- /dev/null +++ b/aten/src/ATen/zoom/detail/IndexUtils.cuh @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +namespace at::zoom::detail { + +bool maybeOverlappingIndices(const at::TensorBase &t); +using at::native::canUse32BitIndexMath; + +template +TensorInfo +getTensorInfo(const at::TensorBase &t) { + IndexType sz[MAX_TENSORINFO_DIMS]; + IndexType st[MAX_TENSORINFO_DIMS]; + + int dims = t.dim(); + for (int i = 0; i < dims; ++i) { + sz[i] = t.size(i); + st[i] = t.stride(i); + } + + scalar* data_ptr = nullptr; + + if constexpr (std::is_const::value) { + data_ptr = t.const_data_ptr(); + } else { + data_ptr = t.mutable_data_ptr(); + } + + return TensorInfo( + data_ptr, dims, sz, st); +} + +} // namespace at::zoom::detail diff --git a/aten/src/ATen/zoom/detail/KernelUtils.h b/aten/src/ATen/zoom/detail/KernelUtils.h new file mode 100644 index 00000000000000..ad0e5cbe9cc2f4 --- /dev/null +++ b/aten/src/ATen/zoom/detail/KernelUtils.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +namespace at::zoom::detail { + +// CUDA: grid stride looping +// +// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment. +// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final +// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be +// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no +// further iterations and the overflowed value in i=_i_n_d_e_x is not used. +#define HIP_KERNEL_LOOP_TYPE(i, n, index_type) \ + int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \ + for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x) + +#define HIP_KERNEL_LOOP(i, n) HIP_KERNEL_LOOP_TYPE(i, n, int) + + +// Use 1024 threads per block, which requires cuda sm_2x or above +constexpr int HIP_NUM_THREADS = 1024; + +// CUDA: number of blocks for threads. +inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=HIP_NUM_THREADS) { + TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N); + constexpr int64_t max_int = std::numeric_limits::max(); + + // Round up division for positive number that cannot cause integer overflow + auto block_num = (N - 1) / max_threads_per_block + 1; + TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on HIP device"); + + return static_cast(block_num); +} + +} // namespace at::zoom::detail \ No newline at end of file diff --git a/aten/src/ATen/zoom/detail/PhiloxHIPStateRaw.hpp b/aten/src/ATen/zoom/detail/PhiloxHIPStateRaw.hpp new file mode 100644 index 00000000000000..252cc3c9013537 --- /dev/null +++ b/aten/src/ATen/zoom/detail/PhiloxHIPStateRaw.hpp @@ -0,0 +1,43 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +// Stores RNG state values. 
Passed as a kernel argument. +// See Note [CUDA Graph-safe RNG states]. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +namespace at { + +struct PhiloxHIPState { + PhiloxHIPState() = default; + // Called if graph capture is not underway + PhiloxHIPState(uint64_t seed, + uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + // Called if graph capture is underway + PhiloxHIPState(int64_t* seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + // Public members, directly accessible by at::zoom::philox::unpack. + // If we made them private with getters/setters, the getters/setters + // would have to be __device__, and we can't declare __device__ in ATen. + union Payload { + uint64_t val; + int64_t* ptr; + }; + + Payload seed_; + Payload offset_; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/zoom/detail/TensorInfo.cuh b/aten/src/ATen/zoom/detail/TensorInfo.cuh new file mode 100644 index 00000000000000..54debad5979827 --- /dev/null +++ b/aten/src/ATen/zoom/detail/TensorInfo.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include + +namespace at::zoom::detail { + +#define MAX_TENSORINFO_DIMS 25 + +// CUDA kernel argument that defines tensor layout +template +struct TensorInfo { + TensorInfo(); + TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]); + + // Set the size of the given dimension to 1, as if it were a + // reduction dim (allows you to calculate offsets of the reduction + // slice) + void reduceDim(int dim); + + // See note on [collapse dims]. 
+ int collapseDims(const int excludeDim = -1); + + // Contiguous tensors of more than one dimension are collapsed down + // to one tensor + __host__ __device__ inline bool isContiguous() const { + return (dims == 1 && strides[0] == 1); + } + + T* data; + IndexType sizes[MAX_TENSORINFO_DIMS]; + IndexType strides[MAX_TENSORINFO_DIMS]; + int dims; +}; + +template +TensorInfo::TensorInfo() { + data = nullptr; + dims = 0; +} + +template +TensorInfo::TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]) { + data = p; + dims = dim; + TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "Zoom tensors cannot have more than 25 dimensions"); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } +} + +template +void +TensorInfo::reduceDim(int dim) { + TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1"); + sizes[dim] = 1; +} + +template +int +TensorInfo::collapseDims(const int excludeDim) { + auto result = at::collapse_dims(sizes, strides, dims, excludeDim); + dims = std::get<1>(result); + return std::get<0>(result); +} + +// Translate a linear index for the apply to a T* offset; +// specialized on `Dims` to reduce nvcc compilation time +template +struct IndexToOffset { + static __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + // Uses static dims + for (int i = Dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +// Uses dynamic (runtime) instead of static (compiletime) dims +template +struct IndexToOffset { + static inline __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + for (int i = info.dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +} // namespace at::zoom::detail \ No newline at end of file diff --git a/aten/src/ATen/zoom/detail/UnpackRaw.hpp b/aten/src/ATen/zoom/detail/UnpackRaw.hpp new file mode 100644 index 00000000000000..5a5172e73f3cb2 --- /dev/null +++ b/aten/src/ATen/zoom/detail/UnpackRaw.hpp @@ -0,0 +1,28 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +namespace at::zoom::philox { + +// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether +// that instance was created with graph capture underway or not. +// See Note [CUDA Graph-safe RNG states]. +// +// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen. +// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable. +// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +__host__ __device__ __forceinline__ std::tuple +unpack(at::PhiloxHIPState arg) { + if (arg.captured_) { + // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". 
+ // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. + // For most threads' reads it will hit in cache, so it shouldn't hurt performance. + return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +} // namespace at::zoom::philox \ No newline at end of file diff --git a/aten/src/ATen/zoom/detail/ZoomHooks.cpp b/aten/src/ATen/zoom/detail/ZoomHooks.cpp new file mode 100644 index 00000000000000..828ef6993c45b7 --- /dev/null +++ b/aten/src/ATen/zoom/detail/ZoomHooks.cpp @@ -0,0 +1,273 @@ +#include +#include +#include +#include +#include +// #include +#include +#include +#include +#include +#include +#include +// #include +#include +#include +#include +#include + +// #if AT_CUDNN_ENABLED() +// #include +// #endif + +// #if AT_MAGMA_ENABLED() +// #include +// #endif + +// #if defined(USE_ROCM) +// #include +// #endif + +#include +#include +#include +#include +#include +#include + +namespace c10::zoom::_internal { +void setHasPrimaryContext(bool (*func)(DeviceIndex)); +} + +// defined in Aten/zoom/HIPblasHandlePool.cpp +namespace at::zoom { + bool getHIPBlasAtomicsEnabled(); +} + +namespace at::zoom::detail { + +const at::zoom::HIPRTC& hiprtc(); +DeviceIndex current_device(); + +// static void (*magma_init_fn)() = nullptr; + +// void set_magma_init_fn(void (*fn)()) { +// magma_init_fn = fn; +// } + +namespace { +bool _hasPrimaryContext(DeviceIndex device_index) { + TORCH_CHECK(device_index >= 0 && device_index < c10::zoom::device_count(), + "hasPrimaryContext expects a valid device index, but got device_index=", device_index); + unsigned int ctx_flags; + // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird + // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero. + int ctx_is_active = 0; +// AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active)); + hipDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active); + return ctx_is_active == 1; +} + +// Register hasPrimaryContext back to c10::zoom +struct _Initializer { + _Initializer() { + c10::zoom::_internal::setHasPrimaryContext(_hasPrimaryContext); + } + ~_Initializer() { + c10::zoom::_internal::setHasPrimaryContext(nullptr); + } +} initializer; +} // anonymous namespace + +// Sets the CUDA_MODULE_LOADING environment variable +// if it's not set by the user. +void maybe_set_zoom_module_loading(const std::string &def_value) { + auto value = std::getenv("ZOOM_MODULE_LOADING"); + if (!value) { +#ifdef _WIN32 + auto env_var = "ZOOM_MODULE_LOADING=" + def_value; + _putenv(env_var.c_str()); +#else + setenv("ZOOM_MODULE_LOADING", def_value.c_str(), 1); +#endif + } +} + +// NB: deleter is dynamic, because we need it to live in a separate +// compilation unit (alt is to have another method in hooks, but +// let's not if we don't need to!) +void ZoomHooks::initZoom() const { + C10_LOG_API_USAGE_ONCE("aten.init.zoom"); + // Force the update to enable unit testing. This code get executed before unit tests + // have a chance to enable vitals. 
+ at::vitals::VitalsAPI.setVital("ZOOM", "used", "true", /* force = */ true); + + maybe_set_zoom_module_loading("LAZY"); + const auto num_devices = c10::zoom::device_count_ensure_non_zero(); + c10::zoom::ZoomCachingAllocator::init(num_devices); + at::zoom::detail::init_p2p_access_cache(num_devices); +} + +void ZoomHooks::initPrivateUse1() const { + initZoom(); +} + +const Generator& ZoomHooks::getDefaultZoomGenerator(DeviceIndex device_index) const { + return at::zoom::detail::getDefaultZoomGenerator(device_index); +} + +Device ZoomHooks::getDeviceFromPtr(void* data) const { + return at::zoom::getDeviceFromPtr(data); +} + +bool ZoomHooks::isPinnedPtr(const void* data) const { + // First check if driver is broken/missing, in which case PyTorch CPU + // functionalities should still work, we should report `false` here. + if (!at::zoom::is_available()) { + return false; + } + // cudaPointerGetAttributes grabs context on the current device, so we set + // device to one that already has context, if exists. + at::OptionalDeviceGuard device_guard; + auto primary_ctx_device_index = c10::zoom::getDeviceIndexWithPrimaryContext(); + if (primary_ctx_device_index.has_value()) { + device_guard.reset_device(at::Device(at::DeviceType::PrivateUse1, *primary_ctx_device_index)); + } + hipPointerAttribute_t attr; + // We do not believe that CUDA needs mutable access to the data + // here. + hipError_t err = hipPointerGetAttributes(&attr, data); + // HIP throws hipErrorUnknown here + if (err != hipSuccess) { + (void)hipGetLastError(); // clear HIP error + return false; + } + return attr.type == hipMemoryTypeHost; +} + +bool ZoomHooks::hasROCM() const { + return at::zoom::is_available(); +} + +// rocBLAS is deterministic if atomic operations are disabled +// for details on when rocBLAS is guaranteed to be bitwise deterministic see below: +// https://github.com/ROCm/rocBLAS/issues/1459#issuecomment-2272082035 +bool ZoomHooks::checkHIPBlasDeterministic() const { + return !at::zoom::getHIPBlasAtomicsEnabled(); +} + +// #if defined(USE_DIRECT_NVRTC) || defined(USE_DIRECT_HIPRTC) + static std::pair, at::zoom::HIPRTC*> load_hiprtc() { + return std::make_pair(nullptr, at::zoom::load_hiprtc()); + } +// #else +// static std::pair, at::zoom::HIPRTC*> load_hiprtc() { +// #if defined(_WIN32) +// std::string libcaffe2_hiprtc = "caffe2_hiprtc.dll"; +// #elif defined(__APPLE__) +// std::string libcaffe2_hiprtc = "libcaffe2_hiprtc.dylib"; +// #else +// std::string libcaffe2_hiprtc = "libcaffe2_hiprtc.so"; +// #endif +// std::unique_ptr libhiprtc_stub( +// new at::DynamicLibrary(libcaffe2_hiprtc.c_str())); +// auto fn = (at::zoom::HIPRTC * (*)()) libhiprtc_stub->sym("load_hiprtc"); +// return std::make_pair(std::move(libhiprtc_stub), fn()); +// } +// #endif + +const at::zoom::HIPRTC& hiprtc() { + // must hold onto DynamicLibrary otherwise it will unload + static auto handle = load_hiprtc(); + return *handle.second; +} + +const at::zoom::HIPRTC& ZoomHooks::hiprtc() const { + return at::zoom::detail::hiprtc(); +} + +DeviceIndex current_device() { + c10::DeviceIndex device = 0; + hipError_t err = c10::zoom::GetDevice(&device); + if (err == hipSuccess) { + return device; + } + return -1; +} + +DeviceIndex ZoomHooks::current_device() const { + return at::zoom::detail::current_device(); +} + +bool ZoomHooks::hasPrimaryContext(DeviceIndex device_index) const { + return _hasPrimaryContext(device_index); +} + +Allocator* ZoomHooks::getPinnedMemoryAllocator() const { + return at::zoom::getPinnedMemoryAllocator(); +} + +Allocator* 
ZoomHooks::getZoomDeviceAllocator() const { + return at::zoom::getZoomDeviceAllocator(); +} + +std::string ZoomHooks::showConfig() const { + std::ostringstream oss; + + int runtimeVersion; + hipRuntimeGetVersion(&runtimeVersion); + + auto printHIPStyleVersion = [&](int v) { + + // HIP_VERSION value format was changed after ROCm v4.2 to include the patch number + if(v < 500) { + // If major=xx, minor=yy then format -> xxyy + oss << (v / 100) << "." << (v % 10); + } + else { + // If major=xx, minor=yy & patch=zzzzz then format -> xxyyzzzzz + oss << (v / 10000000) << "." << (v / 100000 % 100) << "." << (v % 100000); + } + + }; + + + oss << " - HIP Runtime "; + + printHIPStyleVersion(runtimeVersion); + oss << "\n"; + + return oss.str(); +} + +int ZoomHooks::getNumGPUs() const { + auto cnt = c10::zoom::device_count(); + std::cout << "numgpu: " << cnt << std::endl; + return cnt; +} + +void ZoomHooks::deviceSynchronize(DeviceIndex device_index) const { + at::DeviceGuard device_guard(at::Device(at::DeviceType::PrivateUse1, device_index)); + c10::zoom::device_synchronize(); +} + +// // Sigh, the registry doesn't support namespaces :( +// using at::zoomHooksRegistry; +// using at::RegistererCUDAHooksRegistry; + +// REGISTER_CUDA_HOOKS(ZoomHooks); + +using at::PrivateUse1HooksRegistry; +using at::RegistererPrivateUse1HooksRegistry; +REGISTER_PRIVATEUSE1_HOOKS(ZoomHooks); + +static ZoomHooks* zoom_hooks_impl = nullptr; +void register_zoom_hooks() { + if(zoom_hooks_impl == nullptr){ + zoom_hooks_impl = new ZoomHooks({}); + RegisterPrivateUse1HooksInterface(zoom_hooks_impl); + } +} + + +} // namespace at::zoom::detail \ No newline at end of file diff --git a/aten/src/ATen/zoom/detail/ZoomHooks.h b/aten/src/ATen/zoom/detail/ZoomHooks.h new file mode 100644 index 00000000000000..51cabb8bde377f --- /dev/null +++ b/aten/src/ATen/zoom/detail/ZoomHooks.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#include +#include + +// TODO: No need to have this whole header, we can just put it all in +// the cpp file + +namespace at::zoom::detail { + + +// The real implementation of ZoomHooksInterface +struct ZoomHooks : public ZoomHooksInterface { + ZoomHooks(ZoomHooksArgs) {} + void initZoom() const override; + void initPrivateUse1() const override; + Device getDeviceFromPtr(void* data) const override; + bool isPinnedPtr(const void* data) const override; + const Generator& getDefaultZoomGenerator(DeviceIndex device_index = -1) const override; + bool hasROCM() const override; + bool checkHIPBlasDeterministic() const override; + const at::zoom::HIPRTC& hiprtc() const override; + DeviceIndex current_device() const override; + bool hasPrimaryContext(DeviceIndex device_index) const override; + Allocator* getZoomDeviceAllocator() const override; + Allocator* getPinnedMemoryAllocator() const override; + std::string showConfig() const override; + int getNumGPUs() const override; + void deviceSynchronize(DeviceIndex device_index) const override; +}; + +void register_zoom_hooks(); + +} // at::zoom::detail \ No newline at end of file diff --git a/aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.cpp b/aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.cpp new file mode 100644 index 00000000000000..d8f0c36d000a7c --- /dev/null +++ b/aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.cpp @@ -0,0 +1,13 @@ +#include +#include + +namespace at { namespace zoom { + +HIPRTC* load_hiprtc() { + auto self = new HIPRTC(); +#define CREATE_ASSIGN(name) self->name = name; + AT_FORALL_HIPRTC(CREATE_ASSIGN) + return self; +} + +}} // at::zoom diff --git 
a/aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.h b/aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.h new file mode 100644 index 00000000000000..bc3de47142f1e7 --- /dev/null +++ b/aten/src/ATen/zoom/hiprtc_stub/ATenHIPRTC.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include + +namespace at { namespace zoom { + + +// NOTE [ USE OF NVRTC AND DRIVER API ] +// +// ATen does not directly link to either libnvrtc or libcuda because they +// require libcuda to be installed, yet we want our GPU build to work on CPU +// machines as long as CUDA is not initialized. +// +// Normal CUDA code in torch uses the cuda runtime libraries which can be +// installed even if the driver is not installed, but sometimes we specifically +// need to use the driver API (e.g., to load JIT compiled code). +// To accomplish this, we lazily link libcaffe2_nvrtc which provides a struct +// at::zoom::HIPRTC that contains function pointers to all of the apis we need. +// +// IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY. +// INSTEAD USE, e.g. +// detail::getZoomHooks().nvrtc().cuLoadModule(...) +// or +// globalContext().getNVRTC().cuLoadModule(...) +// +// If a function is missing add it to the list in ATen/cuda/nvrtc_stub/ATenNVRTC.h +// and edit ATen/cuda/detail/LazyNVRTC.cpp accordingly (e.g., via one of the stub +// macros). + + +// NOTE [ ATen NVRTC Stub and HIP ] +// +// ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both +// NVRTC and driver APIs. While the former is not yet supported for HIP, the +// later is supported and needed (e.g., in CUDAHooks::getDeviceWithPrimaryContext() +// used by tensor.pin_memory()). +// +// The macro below strips out certain unsupported operations on HIP from the full +// list above. +// +// HIP doesn't have +// cuGetErrorString (maps to non-functional hipGetErrorString___) +// +// HIP from ROCm 3.5 on renamed hipOccupancyMaxActiveBlocksPerMultiprocessor +// to hipModuleOccupancyMaxActiveBlocksPerMultiprocessor. 
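To make the stub's X-macro plumbing concrete, here is a minimal sketch of the same function-pointer-table pattern that load_hiprtc() in ATenHIPRTC.cpp (above) and the AT_FORALL_HIPRTC list (below) implement. The two-entry API list and the names FORALL_EXAMPLE_APIS, ExampleRTC and load_example_rtc are illustrative stand-ins, not part of this patch:

    #include <hip/hip_runtime.h>

    // A reduced stand-in for AT_FORALL_HIPRTC: each API reachable through the
    // stub is listed once and expanded twice.
    #define FORALL_EXAMPLE_APIS(_) \
      _(hipModuleLoadData)         \
      _(hipModuleUnload)

    // Expansion 1 (header side): a struct holding one function pointer per API.
    extern "C" typedef struct ExampleRTC {
    #define CREATE_MEMBER(name) decltype(&name) name;
      FORALL_EXAMPLE_APIS(CREATE_MEMBER)
    #undef CREATE_MEMBER
    } ExampleRTC;

    // Expansion 2 (.cpp side): populate each pointer with the real symbol, so
    // callers reach the driver-style APIs through the table instead of linking
    // them directly.
    extern "C" ExampleRTC* load_example_rtc() {
      auto* self = new ExampleRTC();
    #define CREATE_ASSIGN(name) self->name = name;
      FORALL_EXAMPLE_APIS(CREATE_ASSIGN)
    #undef CREATE_ASSIGN
      return self;
    }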
+// #if TORCH_HIP_VERSION < 305 +// #define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR hipOccupancyMaxActiveBlocksPerMultiprocessor +// #else +// #define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR hipModuleOccupancyMaxActiveBlocksPerMultiprocessor +// #endif + +#define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + +#define AT_FORALL_HIPRTC(_) \ + _(hiprtcVersion) \ + _(hiprtcCreateProgram) \ + _(hiprtcAddNameExpression) \ + _(hiprtcDestroyProgram) \ + _(hiprtcGetCodeSize) \ + _(hiprtcGetCode) \ + _(hipModuleLoadData) \ + _(hipModuleGetFunction) \ + _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \ + _(hiprtcGetErrorString) \ + _(hiprtcGetProgramLogSize) \ + _(hiprtcGetProgramLog) \ + _(hipModuleLaunchKernel) \ + _(hiprtcCompileProgram) \ + _(hipCtxGetCurrent) \ + _(hiprtcGetLoweredName) \ + _(hipModuleUnload) \ + _(hipDevicePrimaryCtxGetState) + + + +extern "C" typedef struct HIPRTC { +#define CREATE_MEMBER(name) decltype(&name) name; + AT_FORALL_HIPRTC(CREATE_MEMBER) +#undef CREATE_MEMBER +} HIPRTC; + +extern "C" TORCH_ZOOM_API HIPRTC* load_hiprtc(); +}} // at::zoom diff --git a/aten/src/ATen/zoom/jit/HIPJitLoops.cuh b/aten/src/ATen/zoom/jit/HIPJitLoops.cuh new file mode 100644 index 00000000000000..01154c0b568173 --- /dev/null +++ b/aten/src/ATen/zoom/jit/HIPJitLoops.cuh @@ -0,0 +1,292 @@ +#pragma once +#include + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Loops.cuh" + +#include +#include +#include + +#include +#include +#include +#include + +namespace at { +namespace native { + +template +constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence seq) { + constexpr auto size = seq.size(); + (void)t; // warning : unused parameter when tuple is empty. + return std::array{static_cast(&std::get(t))...}; +} + +// Helper function convert tuple to std::array +// for passing the arguments to CUDA Kernel +// NOTE: We capture tuple by reference, +// so the pointers in returned array are only valid +// till tuple is alive. 
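+// For example, for a std::tuple<int, float> t, the helper returns a
+// std::array<void*, 2> whose entries point at the int and float members of t,
+// which is the form the HIP module-launch API expects for kernel arguments.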
+template +constexpr auto tuple_to_array(std::tuple& extra_args) { + constexpr auto tuple_size = sizeof...(Args); + return tuple_to_array_helper(extra_args, std::make_index_sequence{}); +} + +struct JittedVecKernelCache { + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + at::zoom::jit::hiprtcFunction vec1; + at::zoom::jit::hiprtcFunction vec2; + at::zoom::jit::hiprtcFunction vec4; +}; + +struct JittedKernelVariantCache { + JittedVecKernelCache vec; + at::zoom::jit::hiprtcFunction noncontiguous; + at::zoom::jit::hiprtcFunction dynamic_contiguous; + at::zoom::jit::hiprtcFunction dynamic_noncontiguous; +}; + +inline c10::SmallBuffer pack_kernel_args( + std::initializer_list args, + c10::ArrayRef extra_args) { + c10::SmallBuffer ret(args.size() + extra_args.size()); + std::copy(args.begin(), args.end(), ret.data()); + std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size()); + return ret; +} + +template +void launch_jitted_unrolled_kernel( + std::mutex &jiterator_mutex, + at::zoom::jit::hiprtcFunction &fn_cache, + const at::zoom::jit::KernelDescriptor &desc, + int64_t N, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s, + bool contiguous, + at::zoom::jit::BinaryFuncVariant scalar_pos, + void* scalar_val, + c10::ArrayRef extra_args) { + + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + //casting result to int is always safe, intermediate is int64 and won't overflow + const uint32_t grid = (N + block_work_size() - 1) / block_work_size(); + + if (!fn_cache.function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_cache.function) { + constexpr bool dynamic_casting = !std::is_same() || + !std::is_same(); + auto code = at::zoom::jit::generate_code( + desc, contiguous, dynamic_casting, scalar_pos); + fn_cache = at::zoom::jit::jit_pwise_function(code, desc.name); + } + } + + auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::zoom::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u}, + {num_threads(), 1u, 1u}); +} + +template +void launch_jitted_vectorized_kernel( + std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache, + const at::zoom::jit::KernelDescriptor &desc, int64_t N, array_t data, + at::zoom::jit::BinaryFuncVariant scalar_pos, + void *scalar_val, c10::ArrayRef extra_args) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + // N is still int64_t for the computation, but it's always safe to cast result to int + const uint32_t grid = (N + block_work_size() - 1) / block_work_size(); + const int vec_size = at::zoom::jit::can_vectorize_up_to( + desc, c10::ArrayRef(data.data, data.size())); + + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + // fn_ptr is set to the appropriate function based on the vec size and GPU used + at::zoom::jit::hiprtcFunction* fn_ptr; + if (vec_size == 4) { + fn_ptr = &fn_cache.vec4; + } else if (vec_size == 2) { + fn_ptr = &fn_cache.vec2; + } else if (vec_size ==1) { + fn_ptr = &fn_cache.vec1; + } else { + TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel"); + } + + bool vectorized = vec_size > 1; + + if (!fn_ptr->function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_ptr->function) { // cache miss! 
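+      // Double-checked locking: the unlocked check above keeps the hot path
+      // cheap, while this second check under jiterator_mutex ensures that only
+      // one of any threads that raced on a cold cache compiles the kernel.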
+ + // Generates program + auto code = at::zoom::jit::generate_code( + desc, /*contiguous=*/true, /*dynamic_casting=*/false, + scalar_pos, vectorized, vec_size); + std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name; + + // Acquires the program + *fn_ptr = at::zoom::jit::jit_pwise_function(code, kernel_name); + } + } + + if (vectorized) { + auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args); + at::zoom::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); + } else { +// NVCC complains about unused variables l and s. +// It should be false positive in most cases, so we suppress the warnings. +#pragma nv_diagnostic push +#pragma nv_diag_suppress 177 + auto ic = TrivialOffsetCalculator(); + auto oc = TrivialOffsetCalculator<1>(); + auto l = memory::LoadWithoutCast(); + auto s = memory::StoreWithoutCast(); + + auto args = pack_kernel_args( + {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::zoom::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); +#pragma nv_diagnostic pop + } +} + +template +void jitted_gpu_kernel_generic( + std::mutex &jiterator_mutex, + JittedKernelVariantCache &cache, + const at::zoom::jit::KernelDescriptor &desc, + at::zoom::jit::BinaryFuncVariant scalar_pos, + c10::ArrayRef extra_args, + TensorIteratorBase& iter, + const bool dynamic_casting, + void *scalar_val) { + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + constexpr int ntensors = arity + 1; + at::detail::Array data; + for (auto i : c10::irange(ntensors)) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + bool contiguous = iter.is_contiguous(); + + // Decides which of 4 kernel types to launch + // Variations are: + // - Case 1: no dynamic casting and contiguous + // - Case 2: no dynamic casting and noncontiguous + // - Case 3: dynamic casting and contiguous + // - Case 4: dynamic casting and noncontiguous + // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl + + if (!dynamic_casting) { + if (contiguous) { + // Case 1: no dynamic casting and contiguous + launch_jitted_vectorized_kernel( + jiterator_mutex, cache.vec, desc, + numel, data, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 2: no dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.noncontiguous, desc, numel, data, + input_offset_calculator, output_offset_calculator, loader, + storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Cases 3 and 4 are handled below + // Both require construction of a storer (this asserts 1 output) and one or more loaders + + // Creates store cast to output (the zeroth tensor in TensorIterator) + auto storer = memory::StoreWithCast<1>(iter); + + // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors) + auto loader = memory::LoadWithCast(iter); + + if (contiguous) { + // Case 3: dynamic casting and contiguous + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + 
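+    // Contiguous data only needs trivial offset calculators; the cast-aware
+    // loader and storer constructed above handle the dtype conversions.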
launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 4: dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); +} + +// NOTE: static to reduce chances of name collision. +template < + char const* name, + typename result_type, + typename f_inputs_type, + int arity, + at::zoom::jit::BinaryFuncVariant scalar_pos = + at::zoom::jit::BinaryFuncVariant::NoScalar, + typename... ExtraArgs> +static void jitted_gpu_kernel_impl( + TensorIteratorBase& iter, + const std::string &f, + const bool dynamic_casting, + at::opmath_type scalar_val, + std::tuple extra_args) { + + // TODO: Memory use can probably be optimized by re-using kernels across GPUs with + // the same compute capability + static std::mutex jiterator_mutex; + static std::vector device_caches(c10::zoom::device_count()); + + constexpr int nInputs = arity; + constexpr int nOutputs = 1; // TODO: Support more than 1 output + static const auto desc = at::zoom::jit::make_kernel_descriptor< + result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs); + + auto &cache = device_caches[iter.device().index()]; + auto extra_args_array = tuple_to_array(extra_args); + return jitted_gpu_kernel_generic( + jiterator_mutex, + cache, + desc, + scalar_pos, + extra_args_array, + iter, + dynamic_casting, + &scalar_val + ); +} + +}} // at::native diff --git a/aten/src/ATen/zoom/jit/HIPLoops.cuh b/aten/src/ATen/zoom/jit/HIPLoops.cuh new file mode 100644 index 00000000000000..85cdd5211e7006 --- /dev/null +++ b/aten/src/ATen/zoom/jit/HIPLoops.cuh @@ -0,0 +1,333 @@ +#pragma once + +// This file provides two functions to help write GPU elementwise kernels: +// +// gpu_kernel(TensorIterator iter, ) +// gpu_kernel_with_scalars(TensorIterator iter, ) +// +// The gpu_kernel_with_scalars generates specializations that support a +// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar +// is lifted to a kernel parameter instead of copying to device memory. +// This should be used in conjunction with TensorIterator::allow_cpu_scalars_, +// which is the default for TensorIterator::binary_op. Otherwise, all inputs +// and the output must be on the GPU. 
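+// (In this Zoom port "on the GPU" concretely means the PrivateUse1 device: the
+// gpu_kernel entry points in Loops.cuh assert iter.device(arg).is_privateuseone()
+// for every operand.)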
+// +// For example, to write a reciprocal kernel for GPU float Tensors: +// +// gpu_kernel(iter, []GPU_LAMBDA(float a) { +// return 1.0f / a; +// }); +// +// To write a multiplication kernel for GPU float Tensors where one argument +// may be a CPU scalar: +// +// gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) { +// return a * b; +// }); +// +// See BinaryOpsKernel.cu for the complete implementation +// + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// #ifdef __NVCC__ +// #define ASSERT_HOST_DEVICE_LAMBDA(type) \ +// static_assert( \ +// __nv_is_extended_host_device_lambda_closure_type(type), \ +// #type " must be a __host__ __device__ lambda") +// #else +#define ASSERT_HOST_DEVICE_LAMBDA(type) +// #endif + +namespace at { +namespace native { + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + using traits = function_traits; + int remaining = N - block_work_size() * blockIdx.x; + + if (remaining < block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized(data)); + } +} + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + int remaining = N - block_work_size() * blockIdx.x; + auto policy = memory::policies:: + unroll( + data, remaining, ic, oc, l, s); + elementwise_kernel_helper(f, policy); +} + +// this function assume trivial 1d and no dynamic casting +template +static inline void launch_vectorized_kernel( + int64_t N, + const func_t& f, + array_t data) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = c10::zoom::getCurrentZoomStream(); + int vec_size = memory::can_vectorize_up_to(data); + + switch (vec_size) { + case 4: + vectorized_elementwise_kernel<4, func_t, array_t> + <<>>(N, f, data); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_elementwise_kernel<2, func_t, array_t> + <<>>(N, f, data); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + case 1: { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + unrolled_elementwise_kernel + <<>>( + N, f, data, input_calc, output_calc, loader, storer); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + 
typename storer_t> +static inline void launch_unrolled_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = c10::zoom::getCurrentZoomStream(); + unrolled_elementwise_kernel + <<>>(N, f, data, ic, oc, l, s); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel(int N, func_t f) { + int tid = threadIdx.x; + int nv = nt * vt; + int idx = nv * blockIdx.x + tid; +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void launch_legacy_kernel(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = c10::zoom::getCurrentZoomStream(); + elementwise_kernel<<>>(N, f); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, i, Indices{}); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::fetch_and_cast::type>( + dtypes[I], data[I] + i * strides[I])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, dtypes, i, Indices{}); +} + +template +void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { + return launch_vectorized_kernel(numel, f, data); + } + auto offset_calc = ::make_offset_calculator(iter); + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 
2 : 4; + launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data.data[1], &offsets.data[1], 1); + }); +} + +template +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { + if (!needs_dynamic_casting::check(iter)) { + return gpu_kernel_impl_nocast(iter, f); + } + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { + at::detail::Array dtypes; + auto inner_strides = iter.get_inner_strides(); + at::detail::Array strides; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + strides[i] = inner_strides[i]; + } + launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) { + void* out = data[0] + strides[0] * idx; + arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx); + c10::cast_and_store(dtypes[0], out, result); + }); + } else { + at::detail::Array dtypes; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + } + auto offset_calc = ::make_offset_calculator(iter); + launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1); + c10::cast_and_store(dtypes[0], out, result); + }); + } +} + +} // namespace native +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/IntegerDivider.cuh b/aten/src/ATen/zoom/jit/IntegerDivider.cuh new file mode 100644 index 00000000000000..2e0d34df31e02e --- /dev/null +++ b/aten/src/ATen/zoom/jit/IntegerDivider.cuh @@ -0,0 +1,126 @@ +#pragma once + +#include +#include + +// insurance for now, torch only defines this macro if you're compiling with cuda or +// following traditional ROCm build +#define C10_HOST_DEVICE __host__ __device__ + +namespace at::zoom::detail { + +// A utility class to implement integer division by multiplication, given a fixed +// divisor. +// +// WARNING: The fast divider algorithm is only implemented for unsigned int; +// otherwise we default to plain integer division. For unsigned int, +// we further assume that the dividend is at most INT32_MAX. Thus, +// IntDivider must NOT be used for general integer division. +// +// This reduced range is enough for our purpose, and it allows us to +// slightly simplify the computation. +// +// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1< 0), we can find a "magic number" m (2^N +// <= m < 2^(N+1)) and shift s such that: +// +// \floor(n / d) = \floor((m * n) / 2^(N+s)). +// +// Given such m and s, the integer division can be then implemented as: +// +// let m' = m - 2^N // 0 <= m' < 2^N +// +// fast_integer_division(n): +// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned +// // integer. Then take the higher N bits. +// t = (m' * n) >> N +// +// // Here we use the fact that n is less than 2^(N-1): otherwise the value +// // of (t + n) may not fit in an N-bit integer. 
+// return (t + n) >> s +// +// Finding such a magic number is surprisingly easy: +// +// s = \ceil(\log_2 d) +// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic. +// +// See also: +// - Division by Invariant Integers Using Multiplication, +// Torbjörn Granlund and Peter L. Montgomery, 1994. +// +// - http://www.hackersdelight.org/magic.htm +// +// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html + +// Result of div/mod operation stored together. +template +struct DivMod { + Value div, mod; + + C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { } +}; + +// Base case: we only have an implementation for uint32_t for now. For +// everything else, we use plain division. +template +struct IntDivider { + IntDivider() = default; + IntDivider(Value d) : divisor(d) { } + + C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; } + C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; } + C10_HOST_DEVICE inline DivMod divmod(Value n) const { + return DivMod(n / divisor, n % divisor); + } + + Value divisor; +}; + +// Implement fast integer division. +template <> +struct IntDivider { + static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int."); + + IntDivider() = default; + + IntDivider(unsigned int d) : divisor(d) { + assert(divisor >= 1 && divisor <= INT32_MAX); + + // TODO: gcc/clang has __builtin_clz() but it's not portable. + for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + m1 = magic; + assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits. + } + + C10_HOST_DEVICE inline unsigned int div(unsigned int n) const { +#if defined(__HIP_DEVICE_COMPILE__) + // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and + // 'm1'. + unsigned int t = __umulhi(n, m1); + return (t + n) >> shift; +#else + // Using uint64_t so that the addition does not overflow. + uint64_t t = ((uint64_t) n * m1) >> 32; + return (t + n) >> shift; +#endif + } + + C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const { + return n - div(n) * divisor; + } + + C10_HOST_DEVICE inline DivMod divmod(unsigned int n) const { + unsigned int q = div(n); + return DivMod(q, n - q * divisor); + } + + unsigned int divisor; // d above. + unsigned int m1; // Magic number: m' above. + unsigned int shift; // Shift amounts. +}; + +} // namespace at::zoom::detail \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/JitLoops.cuh b/aten/src/ATen/zoom/jit/JitLoops.cuh new file mode 100644 index 00000000000000..8cd5ac713856cb --- /dev/null +++ b/aten/src/ATen/zoom/jit/JitLoops.cuh @@ -0,0 +1,182 @@ +#pragma once + +#include + + +#include +#include +#include + +#include + +#include + +namespace at { +namespace native { + +/* Note [Jiterator] +The "jiterator" simply just-in-time compiles the same kernels that +Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time, +build size, and initial CUDA context size. + +By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels. +This behavior is controlled with two environment variables: + - USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use + - PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels + +The jiterator currently has some limitations, however. 
It cannot: + - handle math on complex datatypes + - handle kernels with scalar parameters + +These improvements will likely come soon. + +For examples of how to use the jiterator see the i1 and gcd kernel +implementations, which pass jittable strings implementing their +operations instead of the typical CUDA functors. + +To pass a runtime argument (similar to lambda captures in non-JIT kernels), +we need to pass to additional arguments to `jitted_gpu_kernel` by value. +Currently only primitive C++ types used for computation are valid. +The order of these extra arguments should be same as the order they appear +in kernel's function signature. (look at polygamma for example) + +NOTE: One big restriction being that these arguments should be after the +arguments provided by TensorIterator. Eg. While capturing `n`, where +`scalar_t x` and `scalar_t y` are provided by TensorIterator, +* foo(scalar_t x, scalar_t y, int n) works! +* foo(int n, scalar_t x, scalar_y) doesn't work +* foo(scalar_t x, int n, scalar_y) doesn't work + +*/ + +// Entrypoint for jitted GPU kernels. +// Only handles elementwise unary and binary kernels with a +// common dtype and a single output. +// NOTE: this assumes the op's iterator has a common_dtype. +// NOTE: We use std::tuple instead of parameter pack +// for `extra_args` due to following +// bug on older versions of clang +// https://bugs.llvm.org/show_bug.cgi?id=23029 +template < + char const* name, + typename return_type, + typename f_inputs_type, + int arity, + typename... Args> +void jitted_gpu_kernel( + TensorIteratorBase& iter, + const std::string& f, + at::zoom::jit::BinaryFuncVariant scalar_pos = + at::zoom::jit::BinaryFuncVariant::NoScalar, + at::opmath_type scalar_val = 0, + std::tuple extra_args = std::make_tuple()) { + // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel + // Maybe it could be refactored? + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_privateuseone(), + "argument ", arg, ": expected a Zoom device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + jitted_gpu_kernel( + sub_iter, f, scalar_pos, scalar_val, extra_args); + } + + return; + } + + // Computes if dynamic casting is needed + // Dynamic casting is needed if an input's dtype differs from the common dtype + // or if the result dtype differs from the output's dtype + // Note: this is intentionally divergent from calling needs_dynamic_casting, + // which is more general and inspects a lambda to determine if dynamic + // casting is needed. + bool needs_dynamic_casting = false; + + // Checks output + const ScalarType return_scalar_type = c10::CppTypeToScalarType::value; + const auto dtype0 = iter.dtype(0); + if (dtype0 != return_scalar_type) { + needs_dynamic_casting = true; + } + + // Checks input(s) + const ScalarType inputs_scalar_type = c10::CppTypeToScalarType::value; + for (auto i = decltype(arity){1}; i < (arity + 1); ++i) { + const auto dtypei = iter.dtype(i); + if (dtypei != inputs_scalar_type) { + needs_dynamic_casting = true; + break; + } + } + if (scalar_pos == at::zoom::jit::BinaryFuncVariant::NoScalar) { + // NOTE: With `scalar_pos=NoScalar`,`scalar_val` is not used + // for computation in the generated code and hence we pass a dummy + // value of `0`. 
+ jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::zoom::jit::BinaryFuncVariant::NoScalar>( + iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args); + } else if (scalar_pos == at::zoom::jit::BinaryFuncVariant::RhsScalar) { + jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::zoom::jit::BinaryFuncVariant::RhsScalar>( + iter, + f, + needs_dynamic_casting, + scalar_val, + extra_args); + + } else { + jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::zoom::jit::BinaryFuncVariant::LhsScalar>( + iter, + f, + needs_dynamic_casting, + scalar_val, + extra_args); + } +} + +// TODO: support runtime state capture similar to `jitted_gpu_kernel`. +template +void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + //currently jiterator only handles binary functions where both inputs are of the same type (f_inputs_type) + using opmath_t = at::opmath_type; + if (iter.is_cpu_scalar(1)) { + auto scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + jitted_gpu_kernel(iter, f, at::zoom::jit::BinaryFuncVariant::LhsScalar, scalar_val); + } else if (iter.is_cpu_scalar(2)) { + auto scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + jitted_gpu_kernel(iter, f, at::zoom::jit::BinaryFuncVariant::RhsScalar, scalar_val); + } else { + jitted_gpu_kernel(iter, f); + } +} + +}} // at::native diff --git a/aten/src/ATen/zoom/jit/Loops.cuh b/aten/src/ATen/zoom/jit/Loops.cuh new file mode 100644 index 00000000000000..cc6d2845506939 --- /dev/null +++ b/aten/src/ATen/zoom/jit/Loops.cuh @@ -0,0 +1,325 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + + +namespace at { namespace native { + +template +static OffsetCalculator make_input_offset_calculator(const TensorIteratorBase& iter) { + // array size can not be 0, this happens when N == 0 + constexpr int array_size = std::max(N, 1); + TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs()); + std::array strides; + int64_t element_sizes[array_size]; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i + iter.noutputs()).data(); + element_sizes[i] = iter.element_size(i + iter.noutputs()); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +static OffsetCalculator make_output_offset_calculator(const TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs()); + std::array strides; + int64_t element_sizes[num_outputs]; + for (int i = 0; i < num_outputs; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { + using traits = function_traits; + using return_t = typename traits::result_type; + using args_t = 
typename traits::ArgsTuple; + + int idx = blockIdx.x; + + return_t results[thread_work_size()]; + args_t args[thread_work_size()]; + + // load + policy.load(args, idx); + + // compute + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (policy.check_inbounds(i)) { + results[i] = c10::guts::apply(f, args[i]); + } + } + + // store + policy.store(results, idx); +} + +}} // namespace at::native + +#include "HIPLoops.cuh" + +namespace at:: native { + +template +void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_privateuseone(), + "argument ", arg, ": expected a Zoom device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_nocast(sub_iter, f); + } + return; + } + + gpu_kernel_impl_nocast(iter, f); +} + +template +void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_privateuseone(), + "argument ", arg, ": expected a Zoom device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel(sub_iter, f); + } + return; + } + + gpu_kernel_impl(iter, f); +} + +template +struct AUnaryFunctor { + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + __device__ return_t operator()(arg2_t b) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {} + private: + func_t f; + opmath_arg1_t a; +}; + +template +struct BUnaryFunctor { + using traits = function_traits; + using opmath_arg2_t = typename traits::template arg<1>::type; + __device__ return_t operator()(arg1_t a) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {} + private: + func_t f; + opmath_arg2_t b; +}; + +// Though seemingly noop, this inserts casts from arg1_t to func_t's type +// (which may be higher precision), as well as casts to return_t +template +struct BinaryFunctor { + __device__ return_t operator()(arg1_t a, arg2_t b) const { + return f(a, b); + } + BinaryFunctor(func_t f_): f(f_) {} + private: + func_t f; +}; + +// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which +// accepts inputs at higher precision (typically opmath_t), but then +// ensure that we load from memory at the correct precision (scalar_t) +// to avoid expensive loads. For the whole sordid story see +// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302 +template +void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + using opmath_arg2_t = typename traits::template arg<1>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + + if (iter.is_cpu_scalar(1)) { + AUnaryFunctor af(f, iter.scalar_value(1)); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. 
This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + gpu_kernel(iter, af); + } else if (iter.is_cpu_scalar(2)) { + BUnaryFunctor bf(f, iter.scalar_value(2)); + iter.remove_operand(2); + gpu_kernel(iter, bf); + } else { + gpu_kernel(iter, BinaryFunctor(f)); + } +} + +template +void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + // Use symmetric property of the functor to reduce number of kernels, + // requires f(a, b) == f(b, a) + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg_t = typename traits::template arg<0>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + static_assert(std::is_same::type>::value, + "f is not symmetric"); + + OptionalDeviceGuard device_guard; + opmath_arg_t scalar_val{}; + + if (iter.is_cpu_scalar(1)) { + scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + device_guard.reset_device(iter.device(1)); + } else if (iter.is_cpu_scalar(2)) { + scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + } + + if (iter.ninputs() == 2) { + gpu_kernel(iter, BinaryFunctor(f)); + } else { + AUnaryFunctor unary_f(f, scalar_val); + gpu_kernel(iter, unary_f); + } +} + +// Legacy variant that assumes that func_t has the correct types +// that we expect to load from memory +template +void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; + using return_t = typename traits::result_type; + opmath_gpu_kernel_with_scalars(iter, f); +} + +namespace { // functions for `gpu_kernel_multiple_outputs`. + +// check the return type is `thrust::tuple`, not `std::tuple`. 
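For orientation, a caller of the scalar-aware entry points defined above (before the anonymous namespace for multi-output support) would look much like the existing CUDA elementwise kernels. The following is a sketch only; the function name mul_kernel_zoom, the dispatch macro choice, and the include paths are assumptions rather than part of this patch:

    #include <ATen/Dispatch.h>
    #include <ATen/TensorIterator.h>
    #include <ATen/zoom/jit/Loops.cuh>  // path as added by this patch

    namespace at::native {

    // Hypothetical elementwise multiply built on gpu_kernel_with_scalars.
    void mul_kernel_zoom(TensorIteratorBase& iter) {
      AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "mul_zoom", [&]() {
        // One of the two inputs may be a CPU scalar; the helper lifts it into a
        // kernel parameter instead of copying it to device memory.
        // GPU_LAMBDA is assumed to expand to __host__ __device__ as in the CUDA
        // Loops.cuh headers.
        gpu_kernel_with_scalars(
            iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a * b; });
      });
    }

    } // namespace at::native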
+template struct is_tuple: std::false_type {}; + +template struct is_tuple>: std::true_type {}; + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) { + int remaining = N - block_work_size() * blockIdx.x; + elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll(data, remaining, ic, oc)); +} + +template +static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = c10::zoom::getCurrentZoomStream(); + unrolled_elementwise_kernel_for_multi_outputs<<>>(N, f, data, ic, oc); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using output_t = typename traits::result_type; + static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); + constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_inputs = traits::arity; + constexpr int ntensors = num_outputs + num_inputs; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + if (iter.is_contiguous()) { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator(); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } else { + auto input_calc = make_input_offset_calculator(iter); + auto output_calc = make_output_offset_calculator(iter); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } +} +} // namespace + +template +void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) { + ASSERT_HOST_DEVICE_LAMBDA(func_t); + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT(iter.device(arg).is_privateuseone()); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_multiple_outputs(sub_iter, f); + } + return; + } + + gpu_kernel_multiple_outputs_impl(iter, f); +} + +} //namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/MemoryAccess.cuh b/aten/src/ATen/zoom/jit/MemoryAccess.cuh new file mode 100644 index 00000000000000..4b182724166cbb --- /dev/null +++ b/aten/src/ATen/zoom/jit/MemoryAccess.cuh @@ -0,0 +1,395 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +// make sure this is defined +#define C10_HOST_DEVICE __host__ __device__ + +// References: +// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/ + +namespace at { namespace native { namespace memory { + +namespace detail { + +// What does the `static_unroll` do? +// +// We want to do something like: +// +// using args_t = typename traits::ArgsTuple; +// args_t args; +// #pragma unroll +// for (int i = 0; i < traits::arity; i++) { +// std::get(args) = .... 
+// } +// +// but unfortunately the above code does not work because +// the template argument has to be a compile time constant +// so `static_unroll` is created to simulate `#pragma unroll` +// using template metaprogramming. + +template typename func, int end, int current=0> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args&&... args) { + func::apply(std::forward(args)...); + static_unroll::with_args(args...); + } +}; + +template typename func, int end> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args... args) {} +}; + +// helper structs to be used with static_unroll to load arguments +// one by one + +template +struct vectorized_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, int idx) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + auto ptr = reinterpret_cast(self.data[arg_index + 1]) + block_work_size() * idx; + auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get(args[thread_unroll_idx]); }; + self.load_single_arg(args_accessor, ptr); + } +}; + +template +struct unroll_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + std::get(args[j]) = loader.template load(self.data[arg_index + num_outputs], offset[arg_index], arg_index); + } +}; + +template +struct multi_outputs_store_helper { + template + C10_HOST_DEVICE static void apply( + at::detail::Array data, + at::detail::Array offsets, + thrust::tuple ret) { + using T = typename thrust::tuple_element>::type; + T *to = reinterpret_cast(data[current]) + offsets[current]; + *to = thrust::get(ret); + } +}; + +} // namespace detail + +struct LoadWithoutCast { + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + return c10::load(reinterpret_cast(base_ptr) + offset); + } +}; + +template +struct LoadWithCast { + using array_t = at::detail::Array(N, 1)>; + using size_array_t = at::detail::Array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + LoadWithCast(const TensorIteratorBase& iter) { + ZOOM_KERNEL_ASSERT(iter.ninputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + this->dtypes[i] = iter.dtype(i + iter.noutputs()); + element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs())); + } + } + + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + void *ptr = base_ptr + element_sizes[arg] * offset; + return c10::fetch_and_cast(dtypes[arg], ptr); + } +}; + +struct StoreWithoutCast { + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + *(reinterpret_cast(base_ptr) + offset) = value; + } +}; + +template +struct StoreWithCast { + using array_t = at::detail::Array(N, 1)>; + using size_array_t = at::detail::Array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + StoreWithCast(const TensorIteratorBase& iter) { + ZOOM_KERNEL_ASSERT(iter.noutputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + this->dtypes[i] = iter.dtype(i); + element_sizes[i] = c10::elementSize(iter.dtype(i)); + } + } + + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { 
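+    // Scale the offset by this output's element size to get a byte address,
+    // then cast `value` from scalar_t to the tensor's dtype as it is stored.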
+ void *ptr = base_ptr + element_sizes[arg] * offset; + c10::cast_and_store(dtypes[arg], ptr, value); + } +}; + +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; +}; + +template +__device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { + using vec_t = aligned_vector; + auto *from = reinterpret_cast(base_ptr); + return from[offset]; +} + +template +__device__ aligned_vector load_vector(const bool *base_ptr, uint32_t offset) { + // See NOTE [Loading boolean values] + auto tmp = load_vector(reinterpret_cast(base_ptr), offset); + aligned_vector ret; + for (int i = 0; i < vec_size; ++i) { + ret.val[i] = bool(tmp.val[i]); + } + return ret; +} + +namespace policies { + +// Assumption: +// all tensors are contiguous, that is: stride == sizeof(type) for all tensors +template +struct unroll { + + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + loader_t loader; + storer_t storer; + + __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s): + data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining); + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size::value; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args(*this, args, offset, loader, i, num_outputs); + thread_idx += num_threads(); + } + } + + template + __device__ inline void store(scalar_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + int offset = output_offset_calculator.get(linear_idx)[0]; + storer.store(from[i], data[0], offset); + thread_idx += num_threads(); + } + } +}; + +// Assumption: +// all tensors are contiguous, that is: stride == sizeof(type) for all tensors +// Note: +// Functions in vectorized policy does not do boundary check. It assumes the whole block +// has its job to do. So the reminders should be handled by the caller manually. +template // vec_size: number of scalars, can be 1, 2, or 4. 
+struct vectorized { + + static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size"); + static constexpr int loop_size = thread_work_size() / vec_size; + + data_t data; + + __device__ vectorized(data_t data) : data(data) {} + + __device__ inline constexpr bool check_inbounds(int thread_work_elem) { + return true; + } + + template + __device__ inline void load_single_arg(accessor_t to, scalar_t *from) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + auto v = load_vector(from, index); + #pragma unroll + for (int j = 0; j < vec_size; j++) { + to(vec_size * i + j) = v.val[j]; + } + } + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size::value; + detail::static_unroll::with_args(*this, args, idx); + } + + template + __device__ inline void store(scalar_t *from, int idx) { + using vec_t = aligned_vector; + scalar_t *to = reinterpret_cast(data[0]) + block_work_size() * idx; + vec_t *to_ = reinterpret_cast(to); + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + vec_t v; + for (int j = 0; j < vec_size; j++) { + v.val[j] = from[vec_size * i + j]; + } + to_[index] = v; + } + } +}; + +template +struct multi_outputs_unroll { + //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct + //we don't use inheritance because of compiler bug in cuda 10.2+ + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + LoadWithoutCast loader; + StoreWithoutCast storer; + + __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc): + data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining); + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size::value; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args(*this, args, offset, loader, i, num_outputs); + thread_idx += num_threads(); + } + } + + + template + __device__ inline void store(return_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= this->remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offsets = this->output_offset_calculator.get(linear_idx); + memory::detail::static_unroll::with_args(this->data, offsets, from[i]); + thread_idx += num_threads(); + } + } +}; + +} // namespace policies + +// This is only used in host, but we will wrap this into some templates +// which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE +// in order to compile +template +inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec2_alignment = std::alignment_of>::value; + constexpr int vec4_alignment = std::alignment_of>::value; + if (address % vec4_alignment == 0) { + return 4; + } else if 
(address % vec2_alignment == 0) { + return 2; + } + return 1; +} + +template +inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) { + return can_vectorize_up_to(static_cast(pointer)); +} + +template +struct can_vectorize_up_to_helper { + template + static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) { + using arg_t = typename traits::template arg::type; + // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + result = std::min(result, can_vectorize_up_to(pointers[i + 1])); + } +}; + +template +inline int can_vectorize_up_to(array_t pointers) { + using traits = function_traits; + using return_t = typename traits::result_type; + constexpr int arity = traits::arity; + int result = can_vectorize_up_to(pointers[0]); + // We need to get the type for each argument of `func_t`, this can only + // be done at compile time. + detail::static_unroll::with_args(result, pointers, traits()); + return result; +} + +}}} // namespace at::native::memory \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/OffsetCalculator.cuh b/aten/src/ATen/zoom/jit/OffsetCalculator.cuh new file mode 100644 index 00000000000000..618d30a23f5dd0 --- /dev/null +++ b/aten/src/ATen/zoom/jit/OffsetCalculator.cuh @@ -0,0 +1,115 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// If element_sizes is nullptr, then the strides will be in bytes, otherwise +// the strides will be in # of elements. +// Operands that share the same shape, but may have different strides. +// OffsetCalculator iterates the tensor in a column-major order + +constexpr int MAX_DIMS = 16; + +template +struct OffsetCalculator { + // We allow having negative strides to implement some operations like torch.flip + using stride_t = std::conditional_t, + index_t>; + // The offset for each argument. Wrapper around fixed-size array. + // On CUDA, zero sized array is not allowed, so when we are handling nullary + // operators, we need to create a size 1 offset to avoid compiler failure. + // This size 1 offset is just a placeholder, and we will not use it. + using offset_type = at::detail::Array(NARGS, 1)>; + + // if element_sizes is nullptr, then the strides will be in bytes, otherwise + // the strides will be in # of elements. + OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) { + TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims"); + for (int i=0; i < dims; i++){ + sizes_[i] = at::zoom::detail::IntDivider(sizes[i]); + for (int arg = 0; arg < NARGS; arg++) { + int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]); + strides_[i][arg] = strides[arg][i] / element_size; + } + } + } + + C10_HOST_DEVICE offset_type get(index_t linear_idx) const { + offset_type offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] = 0; + } + + #pragma unroll + for (int dim = 0; dim < MAX_DIMS; ++dim) { + if (dim == dims) { + break; + } + auto divmod = sizes_[dim].divmod(linear_idx); + linear_idx = divmod.div; + + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] += divmod.mod * strides_[dim][arg]; + } + + } + return offsets; + } + + int dims; + at::zoom::detail::IntDivider sizes_[MAX_DIMS]; + stride_t strides_[MAX_DIMS][std::max(NARGS, 1)]; +}; + +template +struct TrivialOffsetCalculator { + // The offset for each argument. Wrapper around fixed-size array. 
+ // The offsets are in # of elements, not in bytes. + // On CUDA, zero sized array is not allowed, so when we are handling nullary + // operators, we need to create a size 1 offset to avoid compiler failure. + // This size 1 offset is just a placeholder, and we will not use it. + using offset_type = at::detail::Array(NARGS, 1)>; + + C10_HOST_DEVICE offset_type get(index_t linear_idx) const { + offset_type offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] = linear_idx; + } + return offsets; + } +}; + +// Make an OffsetCalculator with byte offsets +template +static OffsetCalculator make_offset_calculator(const at::TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(N <= iter.ntensors()); + std::array strides; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i).data(); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data()); +} + +// Make an OffsetCalculator with element offsets +template +static OffsetCalculator make_element_offset_calculator( + const at::TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(N <= iter.ntensors()); + std::array strides; + std::array element_sizes; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator( + iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data()); +} \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/jit_utils.cpp b/aten/src/ATen/zoom/jit/jit_utils.cpp new file mode 100644 index 00000000000000..16f6b3807260d9 --- /dev/null +++ b/aten/src/ATen/zoom/jit/jit_utils.cpp @@ -0,0 +1,1752 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include // istreambuf_iterator +#include +#include + +// TODO: C++17 has the filesystem header, which may replace these +#ifdef _WIN32 + // On Windows, the POSIX implementations are considered deprecated. We simply map to the newer variant. + #include + #include + #include + #define access _access + #define getpid _getpid + #define R_OK 4 + #define W_OK 2 + #define F_OK 0 +#else + #include + #include // mkdir + #include +#endif + + +namespace at::zoom::jit { + +const std::string jit_preamble = R"ESCAPE( +#pragma clang force_cuda_host_device begin +)ESCAPE"; +const std::string jit_epilogue = R"ESCAPE( +#pragma clang force_cuda_host_device end +)ESCAPE"; + + +const std::string jit_common_types = R"ESCAPE( + #ifdef __HIPCC__ + #define ERROR_UNSUPPORTED_CAST ; + // corresponds to aten/src/ATen/native/cuda/thread_constants.h + #define CUDA_OR_ROCM_NUM_THREADS 256 + // corresponds to aten/src/ATen/cuda/detail/OffsetCalculator.cuh + #define MAX_DIMS 16 + #ifndef __forceinline__ + #define __forceinline__ inline __attribute__((always_inline)) + #endif + #else + //TODO use _assert_fail, because assert is disabled in non-debug builds + #define ERROR_UNSUPPORTED_CAST assert(false); + #define CUDA_OR_ROCM_NUM_THREADS 128 + #define MAX_DIMS 25 + #endif + #define POS_INFINITY __int_as_float(0x7f800000) + #define INFINITY POS_INFINITY + #define NEG_INFINITY __int_as_float(0xff800000) + #define NAN __int_as_float(0x7fffffff) + + typedef long long int int64_t; + typedef unsigned int uint32_t; + typedef signed char int8_t; + typedef unsigned char uint8_t; // NOTE: this MUST be "unsigned char"! 
"char" is equivalent to "signed char" + typedef short int16_t; + static_assert(sizeof(int64_t) == 8, "expected size does not match"); + static_assert(sizeof(uint32_t) == 4, "expected size does not match"); + static_assert(sizeof(int8_t) == 1, "expected size does not match"); + constexpr int num_threads = CUDA_OR_ROCM_NUM_THREADS; + constexpr int thread_work_size = 4; // TODO: make template substitution once we decide where those vars live + constexpr int block_work_size = thread_work_size * num_threads; + + ${traits_string} + ${cmath_string} + + // NB: Order matters for this macro; it is relied upon in + // _promoteTypesLookup and the serialization format. + // Note, some types have ctype as void because we don't support them in codegen + #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(at::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(std::complex, ComplexHalf) /* 8 */ \ + _(std::complex, ComplexFloat) /* 9 */ \ + _(std::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(void, QInt8) /* 12 */ \ + _(void, QUInt8) /* 13 */ \ + _(void, QInt32) /* 14 */ \ + _(at::BFloat16, BFloat16) /* 15 */ \ + + #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(at::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(std::complex, ComplexHalf) \ + _(std::complex, ComplexFloat) \ + _(std::complex, ComplexDouble) \ + _(bool, Bool) \ + _(at::BFloat16, BFloat16) + + + enum class ScalarType : int8_t { + #define DEFINE_ENUM(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM) + #undef DEFINE_ENUM + Undefined, + NumOptions + }; + + template + struct Array { + T data[size]; + + __device__ T operator[](int i) const { + return data[i]; + } + __device__ T& operator[](int i) { + return data[i]; + } + Array() = default; + Array(const Array&) = default; + Array& operator=(const Array&) = default; + __device__ Array(T x) { + for (int i = 0; i < size; i++) { + data[i] = x; + } + } + }; + + ${half_string} + ${bfloat16_string} + ${complex_body_string} + ${complex_half_body_string} + ${complex_math_string} + + +)ESCAPE"; + +//we need to include half, bfloat16 and complex strings to all kernels with half arguments and to all kernels with type casting +//regardless of whether they have half arguments (because fetch_and_cast and cast_and_store loop over all types) +const std::string jiterator_half_support_literal = R"ESCAPE( +namespace at { +struct alignas(2) Half { + unsigned short x; + + Half() = default; + inline __host__ __device__ Half(float value){ +#ifdef __HIPCC__ + x = __half_as_short(__float2half(value)); +#else + asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(x) : "f"(value)); +#endif + } + inline __host__ __device__ Half(const __half& value) { + x = *reinterpret_cast(&value); + } + inline __host__ __device__ operator __half() const { + return *reinterpret_cast(&x); + } + inline __host__ __device__ operator float() const{ +#ifdef __HIPCC__ + return __half2float(*reinterpret_cast(&x)); +#else + float val; + asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x)); // do we need const cast here? 
+ //asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(x))); + return val; +#endif + } +}; + + /// Arithmetic + + inline __host__ __device__ Half operator+(const Half& a, const Half& b) { + return static_cast(a) + static_cast(b); + } + + inline __host__ __device__ Half operator-(const Half& a, const Half& b) { + return static_cast(a) - static_cast(b); + } + + inline __host__ __device__ Half operator*(const Half& a, const Half& b) { + return static_cast(a) * static_cast(b); + } + + inline __host__ __device__ Half operator/(const Half& a, const Half& b) + { + return static_cast(a) / static_cast(b); + } + + inline __host__ __device__ Half operator-(const Half& a) { + #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + defined(__HIP_DEVICE_COMPILE__) + return __hneg(a); + #elif defined(__SYCL_DEVICE_ONLY__) + return -c10::bit_cast(a); + #else + return -static_cast(a); + #endif + } + + inline __host__ __device__ Half& operator+=(Half& a, const Half& b) { + a = a + b; + return a; + } + + inline __host__ __device__ Half& operator-=(Half& a, const Half& b) { + a = a - b; + return a; + } + + inline __host__ __device__ Half& operator*=(Half& a, const Half& b) { + a = a * b; + return a; + } + + inline __host__ __device__ Half& operator/=(Half& a, const Half& b) { + a = a / b; + return a; + } + +} + + +)ESCAPE"; + +const std::string jiterator_bfloat16_support_literal = R"ESCAPE( +namespace at { +struct alignas(2) BFloat16 { + unsigned short x; + + __device__ unsigned short __internal_float2bfloat16( + const float f, + unsigned int& sign, + unsigned int& remainder) { + unsigned int x; + + x = __float_as_uint(f); + + if ((x & 0x7fffffffU) > 0x7f800000U) { + sign = 0U; + remainder = 0U; + return static_cast(0x7fffU); + } + sign = x >> 31; + remainder = x << 16; + return static_cast(x >> 16); + } + + + BFloat16() = default; + inline __host__ __device__ BFloat16(float value){ + #if __CUDA_ARCH__ >= 800 + asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(x) : "f"(value)); + )ESCAPE" + R"ESCAPE( + #else + unsigned int sign; + unsigned int remainder; + x = __internal_float2bfloat16(value, sign, remainder); + if ((remainder > 0x80000000U) || + ((remainder == 0x80000000U) && ((x & 0x1U) != 0U))) { + x++; + } + #endif + } + + inline __host__ __device__ operator float() const{ +#ifdef __HIPCC__ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + return u.fp32; +#else + float val; + asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(x)); //do we need const cast here? 
+ return val; +#endif + } + +}; + + /// Arithmetic + + inline __host__ __device__ BFloat16 + operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); + } + + inline __host__ __device__ BFloat16 + operator-(const BFloat16& a, const BFloat16& b) { + return static_cast(a) - static_cast(b); + } + + inline __host__ __device__ BFloat16 + operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); + } + + inline __host__ __device__ BFloat16 operator/(const BFloat16& a, const BFloat16& b) { + return static_cast(a) / static_cast(b); + } + + inline __host__ __device__ BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); + } + + inline __host__ __device__ BFloat16& operator+=(BFloat16& a, const BFloat16& b) { + a = a + b; + return a; + } + + inline __host__ __device__ BFloat16& operator-=(BFloat16& a, const BFloat16& b) { + a = a - b; + return a; + } + + inline __host__ __device__ BFloat16& operator*=(BFloat16& a, const BFloat16& b) { + a = a * b; + return a; + } + + inline __host__ __device__ BFloat16& operator/=(BFloat16& a, const BFloat16& b) { + a = a / b; + return a; + } + + inline __host__ __device__ BFloat16& operator|(BFloat16& a, const BFloat16& b) { + a.x = a.x | b.x; + return a; + } + + inline __host__ __device__ BFloat16& operator^(BFloat16& a, const BFloat16& b) { + a.x = a.x ^ b.x; + return a; + } + + inline __host__ __device__ BFloat16& operator&(BFloat16& a, const BFloat16& b) { + a.x = a.x & b.x; + return a; + } + +} +)ESCAPE"; + +// From c10/util/Load.h +const std::string load_support_literal = R"ESCAPE( + + namespace c10 { + template + struct LoadImpl { + __device__ static T apply(const void *src) { + return *reinterpret_cast(src); + } + }; + + template <> + struct LoadImpl { + __device__ static bool apply(const void *src) { + static_assert(sizeof(bool) == sizeof(char), ""); + return LoadImpl::apply(src); + } + }; + + template + __device__ T load(const void *src) { + return LoadImpl::apply(src); + } + + template + __device__ scalar_t load(const scalar_t *src) { + return LoadImpl::apply(src); + } + } // namespace c10 + +)ESCAPE"; + +// copy-pasted from c10/util/TypeCast.h and c10/core/DynamicCast.h +const std::string dynamic_cast_support_literal = R"ESCAPE( + + template + struct is_complex : public std::false_type {}; + + template + struct is_complex> : public std::true_type {}; + + template + struct needs_real { + constexpr static bool value = + (is_complex::value && !is_complex::value); + }; + + template + struct maybe_real { + static inline src_t apply(src_t src) { + return src; + } + }; + + template + struct maybe_real { + static inline decltype(auto) apply(src_t src) { + return src.real(); + } + }; + + template + struct static_cast_with_inter_type { + static inline dest_t apply( + src_t src) { + constexpr bool real = needs_real::value; + return static_cast(maybe_real::apply(src)); + } + }; + + template + struct static_cast_with_inter_type { + static inline uint8_t apply( + src_t src) { + constexpr bool real = needs_real::value; + return static_cast( + static_cast(maybe_real::apply(src))); + } + }; + + template <> + struct static_cast_with_inter_type, at::BFloat16> { + static inline std::complex apply(at::BFloat16 src) { + return static_cast>(float{src}); + } + }; + + template <> + struct static_cast_with_inter_type, at::Half> { + static inline std::complex apply(at::Half src) { + return static_cast>(float{src}); + } + }; + + template <> + struct static_cast_with_inter_type< + std::complex, + 
std::complex> { + static inline std::complex apply(std::complex src) { + return static_cast>(static_cast>(src)); + } + }; + + // Fetch a value with dynamic type src_type from ptr, and cast it to static type dest_t. + #define FETCH_AND_CAST_CASE(type, scalartype) \ + case ScalarType::scalartype: \ + return static_cast_with_inter_type::apply(c10::load(ptr)); + template + __device__ inline dest_t fetch_and_cast(const ScalarType src_type, const void *ptr) { + switch (src_type) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(FETCH_AND_CAST_CASE) + default: + ERROR_UNSUPPORTED_CAST + } + return dest_t(0); // just to avoid compiler warning + } + + // Cast a value with static type src_t into dynamic dest_type, and store it to ptr. + #define CAST_AND_STORE_CASE(type, scalartype) \ + case ScalarType::scalartype: \ + *(type*)ptr = static_cast_with_inter_type::apply(value); \ + return; + template + __device__ inline void cast_and_store(const ScalarType dest_type, void *ptr, src_t value) { + switch (dest_type) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(CAST_AND_STORE_CASE) + default:; + } + ERROR_UNSUPPORTED_CAST + } + + template + struct LoadWithCast { + using array_t = Array; + using size_array_t = Array; + + array_t dtypes; + size_array_t element_sizes; + template + __device__ scalar_t load(char* base_ptr, uint32_t offset, int arg) { + void* ptr = base_ptr + element_sizes[arg] * offset; + return fetch_and_cast(dtypes[arg], ptr); + } + }; + + template + struct StoreWithCast { + using array_t = Array; + using size_array_t = Array; + + array_t dtypes; + size_array_t element_sizes; + + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + void *ptr = base_ptr + element_sizes[arg] * offset; + cast_and_store(dtypes[arg], ptr, value); + } + }; + +)ESCAPE"; + +const std::string no_dynamic_cast_support_literal = R"ESCAPE( + + struct LoadWithoutCast { + template + __device__ scalar_t load(char* base_ptr, uint32_t offset, int arg=0) { + return c10::load(reinterpret_cast(base_ptr) + offset); + } + }; + + struct StoreWithoutCast { + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg=0) { + *(reinterpret_cast(base_ptr) + offset) = value; + } + }; + +)ESCAPE"; + +const std::string offset_calc_template = R"ESCAPE( + template + struct DivMod { + T div; + T mod; + + __device__ DivMod(T _div, T _mod) { + div = _div; + mod = _mod; + } + }; + + // + struct IntDivider { + IntDivider() = default; + + __device__ inline unsigned int div(unsigned int n) const { + unsigned int t = __umulhi(n, m1); + return (t + n) >> shift; + } + + __device__ inline unsigned int mod(unsigned int n) const { + return n - div(n) * divisor; + } + + __device__ inline DivMod divmod(unsigned int n) const { + unsigned int q = div(n); + return DivMod(q, n - q * divisor); + } + + unsigned int divisor; // d above. + unsigned int m1; // Magic number: m' above. + unsigned int shift; // Shift amounts. + }; + + template + struct TrivialOffsetCalculator { + // The offset for each argument. Wrapper around fixed-size array. + // The offsets are in # of elements, not in bytes. 
+ Array<${index_type}, NARGS> get(${index_type} linear_idx) const { + Array<${index_type}, NARGS> offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] = linear_idx; + } + return offsets; + } + }; + + template + struct OffsetCalculator { + OffsetCalculator() = default; + __device__ __forceinline__ Array<${index_type}, NARGS> get(${index_type} linear_idx) const { + Array<${index_type}, NARGS> offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; ++arg) { + offsets[arg] = 0; + } + + #pragma unroll + for (int dim = 0; dim < MAX_DIMS; ++dim) { + if (dim == dims) { + break; + } + + auto divmod = sizes_[dim].divmod(linear_idx); + linear_idx = divmod.div; + + #pragma unroll + for (int arg = 0; arg < NARGS; ++arg) { + offsets[arg] += divmod.mod * strides_[dim][arg]; + } + //printf("offset calc thread dim size stride offset %d %d %d %d %d %d %d %d\n", + //threadIdx.x, dim, sizes_[dim].divisor, strides_[dim][0], offsets[0], linear_idx, divmod.div, divmod.mod); + } + return offsets; + } + + int dims; + IntDivider sizes_[MAX_DIMS]; + // NOTE: this approach will not support nInputs == 0 + ${index_type} strides_[MAX_DIMS][NARGS]; + }; + + +)ESCAPE"; + +const std::string jit_code_template = R"ESCAPE( + + ${load_support} + ${dynamic_casting_string} + + + ${functor} + + // TODO: setup grid-stride loop + extern "C" __global__ + void ${name}_kernel( + const int numel, + Array data, //[${nInputs}+${nOutputs}], + ${offset_calculator}<${nInputs}> input_calculator, + ${offset_calculator}<${nOutputs}> output_calculator, + ${loader} l, + ${storer} s, + ${compute_type} scalar_val${extra_params}) { + ${declare_load_arrays} + ${declare_store_arrays} + + int idx = blockIdx.x; + + int remaining = numel - block_work_size * idx; + int thread_idx = threadIdx.x; + + #pragma unroll + for (int j = 0; j < thread_work_size; j++){ + if (thread_idx >= remaining) { + break; + } + + int linear_idx = thread_idx + block_work_size * idx; + auto input_offsets = input_calculator.get(linear_idx); + ${load_inputs} + // printf( + // "thread %d a %f offsets %d\n", threadIdx.x, arg0[j], input_offsets[0]); + thread_idx += num_threads; + } + + #pragma unroll + for (int j = 0; j < thread_work_size; j++) { + if ((threadIdx.x + j*num_threads) < remaining) { + ${call_functor} + } + } + + thread_idx = threadIdx.x; + #pragma unroll + for (int j = 0; j < thread_work_size; j++){ + if (thread_idx >= remaining) { + break; + } + //TODO maybe think about unifying offset calculators and reuse + //offsets computed in the load loop + int linear_idx = thread_idx + block_work_size * idx; + auto output_offsets = output_calculator.get(linear_idx); + //printf("output thread %d offset %d\n", threadIdx.x, output_offsets[0]); + ${store_outputs} + thread_idx += num_threads; + } + } +)ESCAPE"; + +const std::string jit_vectorized_code_template = R"ESCAPE( + + ${load_support} + + template + __device__ __inline__ scalar_t load(char* base_ptr, uint32_t offset) { + return c10::load(reinterpret_cast(base_ptr) + offset); + } + + template + __device__ __inline__ void store(scalar_t value, char *base_ptr, uint32_t offset) { + *(reinterpret_cast(base_ptr) + offset) = value; + } + + // aligned vector generates vectorized load/store on CUDA + template + struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; + }; + + template + __device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { + using vec_t = aligned_vector; + auto *from = reinterpret_cast(base_ptr); + return from[offset]; + } + + 
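The `aligned_vector` wrapper in this template is what turns per-element accesses into one wide memory transaction per thread, and the `bool` overload that follows goes through `char` because a raw byte in memory may not be a valid `bool` representation. A minimal standalone sketch of the same idea, assuming HIP; the names `aligned_vector_demo` and `copy_vec4` are illustrative and do not appear in the patch:

    #include <hip/hip_runtime.h>

    template <typename scalar_t, int vec_size>
    struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_demo {
      scalar_t val[vec_size];
    };

    // Each thread moves four contiguous floats; with 16-byte alignment the
    // compiler can emit a single 128-bit load and store per thread.
    __global__ void copy_vec4(const float* in, float* out, int n_vec) {
      using vec4 = aligned_vector_demo<float, 4>;
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n_vec) {
        const vec4* src = reinterpret_cast<const vec4*>(in);
        vec4* dst = reinterpret_cast<vec4*>(out);
        dst[i] = src[i];
      }
    }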
template + __device__ aligned_vector load_vector(const bool *base_ptr, uint32_t offset) { + // See NOTE [Loading boolean values] + auto tmp = load_vector(reinterpret_cast(base_ptr), offset); + aligned_vector ret; + for (int i = 0; i < vec_size; ++i) { + ret.val[i] = bool(tmp.val[i]); + } + return ret; + } + + ${functor} + + // TODO: setup grid-stride loop + + extern "C" __global__ + void ${name}_vectorized${vec_size}_kernel( + const int N, + Array data, + ${compute_type} scalar_val${extra_params}) //[${nInputs}+${nOutputs}], + { + constexpr int vec_size = ${vec_size}; + using scalar_t = ${scalar_type}; + int remaining = N - block_work_size * blockIdx.x; + int thread_idx = threadIdx.x; + int idx = blockIdx.x; + ${declare_load_arrays} + ${declare_store_arrays} + + if (remaining < block_work_size) { + #pragma unroll + for (int j = 0; j < thread_work_size; j++){ + if (thread_idx >= remaining) { + break; + } + int linear_idx = thread_idx + block_work_size * idx; + ${load_unrolled_inputs} + thread_idx += num_threads; + } + #pragma unroll + for (int j = 0; j < thread_work_size; j++) { + if ((threadIdx.x + j*num_threads) < remaining) { + ${call_functor} + } + } + thread_idx = threadIdx.x; + #pragma unroll + for (int j = 0; j < thread_work_size; j++) { + if (thread_idx >= remaining) { + break; + } + int linear_idx = thread_idx + block_work_size * idx; + ${store_unrolled_outputs} + thread_idx += num_threads; + } + } else { + static constexpr int loop_size = thread_work_size / vec_size; + //actual loading + ${vector_inputs} + #pragma unroll + for (int i = 0; i; + ${vector_outputs} + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i + __device__ T zero_init() { + return T(0); + } + + template <> + __device__ hipFloatComplex zero_init() { + return make_hipFloatComplex(0.0f, 0.0f); + } + + template <> + __device__ hipDoubleComplex zero_init() { + return make_hipDoubleComplex(0.0, 0.0); + } + + // kernels can use scalar_t as a template type in their implementation + using scalar_t = ${scalar_t}; + ${kernel} + +)ESCAPE"; + +static void replace_all(std::string& s, const std::string& to_replace, const std::string& replace_with) { + std::ostringstream oss; + std::size_t pos = 0; + std::size_t prev_pos = pos; + + while (true) { + prev_pos = pos; + pos = s.find(to_replace, pos); + if (pos == std::string::npos) + break; + oss << s.substr(prev_pos, pos - prev_pos); + oss << replace_with; + pos += to_replace.size(); + } + + oss << s.substr(prev_pos); + s = oss.str(); +} + +// hipify replaces certain device math functions, e.g., std::max -> ::max +// See torch/utils/hipify/cuda_to_hip_mappings.py. +// Replace them back. Search for " ::" to avoid duplicate replacements. 
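The helper defined next undoes those hipify rewrites inside user-supplied functor strings before they are JIT-compiled. A tiny self-contained illustration of the leading-space search-and-replace convention; `replace_all_demo` and the `main` driver are illustrative only, not taken from the patch:

    #include <iostream>
    #include <string>

    // Replace every occurrence of `from` with `to` in `s` (simple variant).
    static void replace_all_demo(std::string& s, const std::string& from, const std::string& to) {
      std::size_t pos = 0;
      while ((pos = s.find(from, pos)) != std::string::npos) {
        s.replace(pos, from.size(), to);
        pos += to.size();
      }
    }

    int main() {
      std::string functor = "T r = ::max(a, b) + ::min(a, b);";
      // Searching for " ::max" (with the leading space) cannot match a token
      // that is already "std::max", which prevents double rewrites.
      replace_all_demo(functor, " ::max", " std::max");
      replace_all_demo(functor, " ::min", " std::min");
      std::cout << functor << "\n";  // prints: T r = std::max(a, b) + std::min(a, b);
    }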
+static std::string unhipify_math_functions(const std::string &original) { + static std::vector> mappings = { + {" std::max", " ::max"}, + {" std::min", " ::min"}, + {" std::ceil", " ::ceil"}, + {" std::floor", " ::floor"}, + {" std::exp", " ::exp"}, + {" std::log", " ::log"}, + {" std::pow", " ::pow"}, + {" std::fabs", " ::fabs"}, + {" std::fmod", " ::fmod"}, + {" std::remainder", " ::remainder"}, + {" std::frexp", " ::frexp"} + }; + std::string ret = original; + for (const auto& mapping : mappings) { + replace_all(ret, mapping.second, mapping.first); + } + return ret; +} + +// The following is copied from fused_kernel.cpp +// TODO: refactor codegenOutputQuery into its own file +// that can be included by both files +// See NOTE [ USE OF NVRTC AND DRIVER API ] +const at::zoom::HIPRTC& hiprtc() { + return at::globalContext().getHIPRTC(); +} + +// query codegen output arch and target +// TODO refactor so this function is usable both from jit and from aten +void codegenOutputQuery( + const hipDeviceProp_t* const prop, + int& hip_major, + int& hip_minor, + int& hiprtc_major, + int& hiprtc_minor, + bool& compile_to_sass) { + ZOOM_HIPRTC_CHECK(hiprtc().hiprtcVersion(&hiprtc_major, &hiprtc_minor)); + hip_major = prop->major; + hip_minor = prop->minor; + compile_to_sass = false; +} + +// TODO: another copy paste from jit, refactor so it's usable from both +// TODO: try making the CUcontext thread local to see if that improves performance - why is this slow? +void initializeZoomContext() { + // lazily construct context if non-existing yet; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + hipCtx_t pctx = nullptr; + HIP_DRIVER_CHECK(at::globalContext().getHIPRTC().hipCtxGetCurrent(&pctx)); + if (!pctx) { + std::unique_lock hipFreeMutexLock( + *(c10::zoom::getFreeMutex())); + hipFree(nullptr); + } +} + +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + bool vectorized, + int vec_size, + bool return_by_ref) { + c10::SmallVector extra_args_typenames(desc.extra_args_types.size()); + for (auto i : c10::irange(extra_args_typenames.size())) { + extra_args_typenames[i] = typeName(desc.extra_args_types[i]); + } + + return generate_code( + desc.nInputs, + desc.nOutputs, + desc.f, + desc.name, + typeName(desc.f_inputs_type), + typeName(toOpMathType(desc.f_inputs_type)), + typeName(desc.result_type), + contiguous, + dynamic_casting, + scalar_pos, + extra_args_typenames, + vectorized, + vec_size, + return_by_ref); +} + +//FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh +#define THREAD_WORK_SIZE 4 +constexpr int thread_work_size = THREAD_WORK_SIZE; + +std::string generate_code( + int nInputs, + int nOutputs, + const std::string& func_, + const std::string& name, + const std::string& f_inputs_type, + const std::string& compute_type, + const std::string& result_type, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + c10::SmallVector& extra_args_typenames, + bool vectorized, + int vec_size, + bool return_by_ref) { + std::string func = func_; + at::jit::TemplateEnv env; + + env.s("index_type", "unsigned int"); + env.s("nInputs", std::to_string(nInputs)); + env.s("nOutputs", std::to_string(nOutputs)); + env.s("scalar_type", f_inputs_type); + env.s("compute_type", compute_type); + env.s("functor", func); + env.s("name", name); + env.s("cmath_string", get_cmath_string()); + + // Generate 
`extra_params` for function signature + // and `extra_args` for computation call if + // extra arguments to capture runtime state are passed. + // (look at polygamma for example). + std::string extra_params = ""; + std::string extra_args = ""; + for (size_t i = 0; i < extra_args_typenames.size(); i++) { + auto type = std::string(extra_args_typenames[i]); + auto name = "extra_arg_" + std::string(to_string(i)); + extra_params += "," + type + " " + name; + extra_args += ", " + name; + } + env.s("extra_params", extra_params); + env.s("extra_args", extra_args); + + std::stringstream declare_load_arrays; + for (int i = 0; i < nInputs; i++) { + // TODO these arrays are potentially of the different types, use function + // traits to determine the types + declare_load_arrays << f_inputs_type << " arg" << std::to_string(i) + << "[" << std::to_string(thread_work_size) << "];\n"; + } + env.s("declare_load_arrays", declare_load_arrays.str()); + + std::stringstream declare_store_arrays; + for (int i = 0; i < nOutputs; i++) { + declare_store_arrays << result_type << " out" << std::to_string(i) + << "[" << std::to_string(thread_work_size) << "];\n"; + } + env.s("declare_store_arrays", declare_store_arrays.str()); + + std::stringstream functor_args; + if (scalar_pos == BinaryFuncVariant::NoScalar) { + for (int i = 0; i < nInputs - 1; i++) { + functor_args << "arg" << std::to_string(i) << "[j], "; + } + functor_args << "arg" << std::to_string(nInputs - 1) << "[j]"; + } else if (scalar_pos == BinaryFuncVariant::LhsScalar) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nInputs == 1); + functor_args << "scalar_val, arg0[j]"; + } else { //RhsScalar + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nInputs == 1); + functor_args << "arg0[j], scalar_val"; + } + env.s("args", functor_args.str()); + + std::string call_functor_template; + if (return_by_ref) { // return one or more outputs by reference + bool need_temp_out = (compute_type != result_type); + std::stringstream functor_outs; + if (need_temp_out) { + for (int i = 0; i < nOutputs - 1; i++) { + functor_outs << "temp_out" << std::to_string(i) << ", "; + } + functor_outs << "temp_out" << std::to_string(nOutputs - 1); + } else { + for (int i = 0; i < nOutputs - 1; i++) { + functor_outs << "out" << std::to_string(i) << "[j], "; + } + functor_outs << "out" << std::to_string(nOutputs - 1) << "[j]"; + } + env.s("functor_outs", functor_outs.str()); + + if (need_temp_out) { + call_functor_template += "${compute_type} ${functor_outs};\n"; + } + + call_functor_template += "${name}<${compute_type}>(${args} ${extra_args}, ${functor_outs});\n"; + + if (need_temp_out) { + for (int i = 0; i < nOutputs; i++) { + auto i_string = std::to_string(i); + call_functor_template += "out" +i_string + "[j] = temp_out" + i_string + ";\n"; + } + } + + } else { // return by value for single output functor + call_functor_template = "out0[j] = ${name}<${compute_type}>(${args} ${extra_args});"; + } + env.s("call_functor", at::jit::CodeTemplate(call_functor_template).format(env)); + + if (f_inputs_type == "at::Half" || result_type == "at::Half" || + f_inputs_type == "std::complex" || + result_type == "std::complex" || dynamic_casting) { + // complex depends on complex and Half dtypes. 
+ env.s("half_string", jiterator_half_support_literal); + } else { + env.s("half_string", ""); + } + if (f_inputs_type == "at::BFloat16" || result_type == "at::BFloat16" || dynamic_casting) { + env.s("bfloat16_string", jiterator_bfloat16_support_literal); + } else { + env.s("bfloat16_string", ""); + } + // the definition of complex math functions is only needed when the compute type is complex + // but the definition of std::complex is needed for dynamic casting even if the compute type is not complex + if (f_inputs_type == "std::complex" || result_type == "std::complex" || + f_inputs_type == "std::complex" || result_type == "std::complex" || + f_inputs_type == "std::complex" || result_type == "std::complex") { + // complex depends on complex and Half dtypes. + env.s("traits_string", get_traits_string()); + env.s("complex_body_string", get_complex_body_string()); + env.s("complex_math_string", get_complex_math_string()); + + // unhipify math functions, but only if std::complex is used. + func = unhipify_math_functions(func); + env.s("functor", func); + + } else if (dynamic_casting) { + env.s("traits_string", get_traits_string()); + env.s("complex_body_string", get_complex_body_string()); + env.s("complex_math_string", ""); + } else { + env.s("traits_string", ""); + env.s("complex_body_string", ""); + env.s("complex_math_string", ""); + } + if (f_inputs_type == "std::complex" || + result_type == "std::complex" || dynamic_casting) { + // dynamic_casting requires the definition of all types + // include complex + // Look at the definition of `StoreWithCast` and `LoadWithCast`. + env.s("complex_half_body_string", get_complex_half_body_string()); + } else { + env.s("complex_half_body_string", ""); + } + + env.s("load_support", load_support_literal); + + if (!vectorized) { + if (!dynamic_casting) { + env.s("loader", "LoadWithoutCast"); + env.s("storer", "StoreWithoutCast"); + env.s("dynamic_casting_string", no_dynamic_cast_support_literal); + } else { + env.s("loader", std::string("LoadWithCast<" + std::to_string(nInputs) + ">")); + env.s("storer", std::string("StoreWithCast<" + std::to_string(nOutputs) + ">")); + env.s("dynamic_casting_string", dynamic_cast_support_literal); + } + + if (contiguous) { + env.s("offset_calculator", "TrivialOffsetCalculator"); + } else { + env.s("offset_calculator", "OffsetCalculator"); + } + + std::stringstream load_inputs; + for (int i = 0; i < nInputs; i++) { + auto i_string = std::to_string(i); + load_inputs << "arg" << i_string << "[j] = l.load<" << f_inputs_type + << ">(data[" << std::to_string(i + nOutputs) + << "], input_offsets[" << i_string << "], " << i_string + << ");\n"; + } + env.s("load_inputs", load_inputs.str()); + + std::stringstream store_outputs; + for (int i = 0; i < nOutputs; i++) { + auto i_string = std::to_string(i); + store_outputs << "s.store<" << result_type + << ">(out" << i_string << "[j], data[" << i_string + << "], output_offsets[" << i_string << "], " << i_string + << ");\n"; + } + env.s("store_outputs", store_outputs.str()); + + static auto hip_template = at::jit::CodeTemplate( + jit_preamble + jit_common_types + offset_calc_template + jit_code_template + jit_epilogue); + const auto code = hip_template.format(env); + return code; + } + + // vectorized case + env.s("vec_size", std::to_string(vec_size)); + env.s("result_type", result_type); + + std::stringstream vector_inputs; + for (const auto i : c10::irange(nInputs)){ + auto i_string = std::to_string(i); + vector_inputs << "auto * input" << i_string << + " = reinterpret_cast(data[" 
<< i_string << "+" << nOutputs << "])" << + " + block_work_size * idx;\n"; + } + env.s("vector_inputs", vector_inputs.str()); + + std::stringstream vector_outputs; + for (const auto i : c10::irange(nOutputs)){ + auto i_string = std::to_string(i); + vector_outputs << "vec_t_output* to_" << i_string << + " = reinterpret_cast(data[" << i_string << "])" << + " + block_work_size / vec_size * idx;\n"; + } + env.s("vector_outputs", vector_outputs.str()); + + std::stringstream load_vectorized_inputs; + for (const auto i : c10::irange(nInputs)) { + auto i_string = std::to_string(i); + load_vectorized_inputs << "const auto vec" << i_string << " = load_vector(" + << "input" << i_string << ", thread_idx);\n"; + load_vectorized_inputs << "#pragma unroll\n"; + load_vectorized_inputs << "for (int j=0; j < vec_size; j++){\n"; + load_vectorized_inputs << " arg" << i_string << "[vec_size * i + j] = vec" << i_string << ".val[j];\n"; + load_vectorized_inputs << "}\n"; + } + env.s("load_vectorized_inputs", load_vectorized_inputs.str()); + + std::stringstream store_vectorized_outputs; + for (const auto i : c10::irange(nOutputs)) { + auto i_string = std::to_string(i); + store_vectorized_outputs << "#pragma unroll\n"; + store_vectorized_outputs << "for (int j=0; j(data[" << std::to_string(i + nOutputs) << "], linear_idx);\n"; + } + env.s("load_unrolled_inputs", load_unrolled_inputs.str()); + + std::stringstream store_unrolled_outputs; + for (const auto i : c10::irange(nOutputs)) { + auto i_string = std::to_string(i); + store_unrolled_outputs << "store<" << result_type << ">(out" << i_string + << "[j], data[" << i_string << "], linear_idx);\n"; + } + env.s("store_unrolled_outputs", store_unrolled_outputs.str()); + + static auto hip_template = at::jit::CodeTemplate( + jit_preamble + jit_common_types + jit_vectorized_code_template + jit_epilogue); + const auto code = hip_template.format(env); + return code; +} + +std::string zoom_generate_code( + const KernelDescriptor &desc, + bool dynamic_casting + ) { + c10::SmallVector extra_args_typenames(desc.extra_args_types.size()); + for (auto i : c10::irange(extra_args_typenames.size())) { + extra_args_typenames[i] = typeName(desc.extra_args_types[i]); + } + + return zoom_generate_code( + desc.nInputs, + desc.nOutputs, + desc.f, + desc.name, + typeName(desc.f_inputs_type), + typeName(toOpMathType(desc.f_inputs_type)), + typeName(desc.result_type), + dynamic_casting, + extra_args_typenames + ); +} + +std::string zoom_generate_code( + int nInputs, + int nOutputs, + const std::string& func_, + const std::string& name, + const std::string& f_inputs_type, + const std::string& compute_type, + const std::string& result_type, + bool dynamic_casting, + c10::SmallVector& extra_args_typenames +) { + std::string func = func_; + at::jit::TemplateEnv env; + + env.s("index_type", "unsigned int"); + env.s("nInputs", std::to_string(nInputs)); + env.s("nOutputs", std::to_string(nOutputs)); + env.s("scalar_t", f_inputs_type); + // std::complex and hipComplex have the same memory layout so we can readily + // replace these with one another, and this makes writing kernels much easier. 
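The substitution below leans on `std::complex<float>` and `hipFloatComplex` being layout-compatible (two contiguous floats, likewise for the double-precision pair). A small host-side sketch of that claim, assuming the ROCm `hip/hip_complex.h` header is available; the `static_assert`s are illustrative and not part of the patch:

    #include <complex>
    #include <hip/hip_complex.h>  // hipFloatComplex, hipDoubleComplex

    // Both pairs are two contiguous floats / doubles, so a buffer written
    // through one type can be read through the other.
    static_assert(sizeof(std::complex<float>) == sizeof(hipFloatComplex),
                  "layout mismatch for single-precision complex");
    static_assert(sizeof(std::complex<double>) == sizeof(hipDoubleComplex),
                  "layout mismatch for double-precision complex");

    // e.g. a tensor's std::complex<float> data pointer can be handed to a
    // kernel whose scalar_t was rewritten to hipFloatComplex:
    //   auto* z = reinterpret_cast<hipFloatComplex*>(complex_float_ptr);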
+ if(f_inputs_type == "std::complex") { + env.s("scalar_t", "hipFloatComplex"); + } + else if(f_inputs_type == "std::complex") { + env.s("scalar_t", "hipDoubleComplex"); + } + env.s("compute_type", compute_type); + env.s("kernel", func); + env.s("name", name); + env.s("cmath_string", get_cmath_string()); + + // Generate `extra_params` for function signature + // and `extra_args` for computation call if + // extra arguments to capture runtime state are passed. + // (look at polygamma for example). + std::string extra_params = ""; + std::string extra_args = ""; + for (size_t i = 0; i < extra_args_typenames.size(); i++) { + auto type = std::string(extra_args_typenames[i]); + auto name = "extra_arg_" + std::string(to_string(i)); + extra_params += "," + type + " " + name; + extra_args += ", " + name; + } + env.s("extra_params", extra_params); + env.s("extra_args", extra_args); + + if (f_inputs_type == "at::Half" || result_type == "at::Half" || + f_inputs_type == "std::complex" || + result_type == "std::complex" || dynamic_casting) { + // complex depends on complex and Half dtypes. + env.s("half_string", jiterator_half_support_literal); + } else { + env.s("half_string", ""); + } + if (f_inputs_type == "at::BFloat16" || result_type == "at::BFloat16" || dynamic_casting) { + env.s("bfloat16_string", jiterator_bfloat16_support_literal); + } else { + env.s("bfloat16_string", ""); + } + // the definition of complex math functions is only needed when the compute type is complex + // but the definition of std::complex is needed for dynamic casting even if the compute type is not complex + if (f_inputs_type == "std::complex" || result_type == "std::complex" || + f_inputs_type == "std::complex" || result_type == "std::complex" || + f_inputs_type == "std::complex" || result_type == "std::complex") { + // complex depends on complex and Half dtypes. + env.s("traits_string", get_traits_string()); + env.s("complex_body_string", get_complex_body_string()); + env.s("complex_math_string", get_complex_math_string()); + + // unhipify math functions, but only if std::complex is used. + func = unhipify_math_functions(func); + env.s("functor", func); + + } else if (dynamic_casting) { + env.s("traits_string", get_traits_string()); + env.s("complex_body_string", get_complex_body_string()); + env.s("complex_math_string", ""); + } else { + env.s("traits_string", ""); + env.s("complex_body_string", ""); + env.s("complex_math_string", ""); + } + if (f_inputs_type == "std::complex" || + result_type == "std::complex" || dynamic_casting) { + // dynamic_casting requires the definition of all types + // include complex + // Look at the definition of `StoreWithCast` and `LoadWithCast`. 
+ env.s("complex_half_body_string", get_complex_half_body_string()); + } else { + env.s("complex_half_body_string", ""); + } + + env.s("load_support", load_support_literal); + if (!dynamic_casting) { + env.s("loader", "LoadWithoutCast"); + env.s("storer", "StoreWithoutCast"); + env.s("dynamic_casting_string", no_dynamic_cast_support_literal); + } else { + env.s("loader", std::string("LoadWithCast<" + std::to_string(nInputs) + ">")); + env.s("storer", std::string("StoreWithCast<" + std::to_string(nOutputs) + ">")); + env.s("dynamic_casting_string", dynamic_cast_support_literal); + } + + static auto hip_template = at::jit::CodeTemplate( + jit_preamble + jit_common_types + offset_calc_template + zoom_jit_code_template + jit_epilogue); + const auto code = hip_template.format(env); + return code; + +} + +// Creates directories recursively +bool _r_mkdir(const std::string& dir) { + // Check if current dir exists + const char* p_dir = dir.c_str(); + const bool dir_exists = (access(p_dir, F_OK) == 0); + if (dir_exists) { + return true; + } + + // Try to create current directory +#ifdef _WIN32 + int ret = _mkdir(dir.c_str()); +#else + int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + // Success + if (ret == 0) { + return true; + } + + // Find folder separator and check if we are at the top + auto pos = dir.find_last_of("/\\"); + if (pos == std::string::npos) { + return false; + } + + // Try to create parent directory + if (!(_r_mkdir(dir.substr(0, pos)))) { + return false; + } + + // Try to create complete path again +#ifdef _WIN32 + ret = _mkdir(dir.c_str()); +#else + ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + return ret == 0; +} + +// Creates directories recursively assuming that base exists +bool r_mkdir_with_base(std::string& base, std::string& dir){ + const char* p_base = base.c_str(); + const bool base_exists = (access(p_base, F_OK) == 0); + if (!base_exists) { + return false; + } + + // remove trailing '/' or '\\' + if ((base[base.size()-1]=='/') || base[base.size()-1]=='\\') { + base.pop_back(); + } + if ((dir[dir.size()-1]=='/') || dir[dir.size()-1]=='\\') { + dir.pop_back(); + } + + return _r_mkdir(base+dir); + +} + +std::string load_code_template(const std::string& path) { + std::ifstream ifs{path}; + std::string s{ + std::istreambuf_iterator(ifs), + std::istreambuf_iterator()}; + return s; +} + +std::string generate_reduction_code( + const KernelDescriptor &desc, + int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen) { + TORCH_INTERNAL_ASSERT(desc.nInputs == 1); + TORCH_INTERNAL_ASSERT(desc.extra_args_types.size() == 0); + + return generate_reduction_code( + desc.nOutputs, + desc.f, + desc.name, + vt0, + typeName(desc.f_inputs_type), + typeName(toOpMathType(desc.f_inputs_type)), + typeName(desc.result_type), + contiguous, + vectorized, + vec_size, + max_threads_codegen + ); +} + +std::string generate_reduction_code( + int nOutputs, + const std::string& func_, + const std::string& name, + const int vt0, + const std::string& f_inputs_type, + const std::string& reduction_accum_type, + const std::string& result_type, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen) { + std::string func = func_; + at::jit::TemplateEnv env; + env.s("index_type", "unsigned int"); + env.s("scalar_type", f_inputs_type); + env.s("result_type", result_type); + env.s("reduction_accum_type", reduction_accum_type); + env.s("vt0", std::to_string(vt0)); + env.s("name", name); + env.s("max_threads_lb", 
std::to_string(max_threads_codegen)); + // reductions don't support dynamic casting, so the only way to get nonstandard types + // is through input + if (f_inputs_type == "at::Half" || f_inputs_type == "std::complex") { + // complex depends on complex and Half dtypes. + env.s("half_string", jiterator_half_support_literal); + } else { + env.s("half_string", ""); + } + if (f_inputs_type == "at::BFloat16") { + env.s("bfloat16_string", jiterator_bfloat16_support_literal); + } else { + env.s("bfloat16_string", ""); + } + if (f_inputs_type == "std::complex" || + f_inputs_type == "std::complex" || + f_inputs_type == "std::complex" ) { + // complex depends on complex and Half dtypes. + env.s("traits_string", get_traits_string()); + env.s("complex_body_string", get_complex_body_string()); + env.s("complex_math_string", get_complex_math_string()); + env.s("complex", std::to_string(1)); + // unhipify math functions, but only if std::complex is used. + func = unhipify_math_functions(func); + } else { + env.s("traits_string", ""); + env.s("complex_body_string", ""); + env.s("complex_math_string", ""); + env.s("complex", std::to_string(0)); + } + if (f_inputs_type == "std::complex") { + env.s("complex_half_body_string", get_complex_half_body_string()); + } else { + env.s("complex_half_body_string", ""); + } + env.s("cmath_string", get_cmath_string()); + env.s("functor", func); + env.s("output_vec_size", std::to_string(vec_size)); + static auto hip_template = at::jit::CodeTemplate( + jit_preamble + jit_common_types + offset_calc_template + get_reduction_template() + jit_epilogue); + const auto code = hip_template.format(env); + return code; +} + +// Acquires (possibly creating) the kernel cache directory +std::optional get_cache_dir() { + // If the environment variable USE_TORCH_KERNEL_CACHE is set to "0" then no persistent cache is used + const char* uptkc = std::getenv("USE_PYTORCH_KERNEL_CACHE"); + const bool use_kernel_cache = (uptkc == nullptr) ? true : std::strcmp(uptkc, "0"); + + if (!use_kernel_cache) { + return {}; + } + + // Cache path comes from PYTORCH_KERNEL_CACHE_PATH, then TEMP (Windows) or XDG_CACHE_HOME (Linux), then HOME environment variables + std::string cache_dir; + char* ptkcp = std::getenv("PYTORCH_KERNEL_CACHE_PATH"); + // Create kernel_cache_dir if needed as we do not want to create the base directory passed by the user + std::string kernels_cache_dir = ""; + if (ptkcp != nullptr) { + cache_dir = std::string(ptkcp); + } else { +#ifdef _WIN32 + ptkcp = std::getenv("TEMP"); +#else + // USES XDG_CACHE_HOME if it's set + ptkcp = std::getenv("XDG_CACHE_HOME"); +#endif + if (ptkcp != nullptr) { + kernels_cache_dir = "/torch/kernels"; + cache_dir = std::string(ptkcp) + kernels_cache_dir; + } else { + // Falls back to HOME/.cache + ptkcp = std::getenv("HOME"); + if (ptkcp == nullptr) { + TORCH_WARN_ONCE("No PYTORCH_KERNEL_CACHE_PATH or HOME environment variable set!", + " This disables kernel caching."); + return {}; + } else { + kernels_cache_dir = "/.cache/torch/kernels"; + cache_dir = std::string(ptkcp) + kernels_cache_dir; + } + } + } + + // Creates the cache directory if it does not exist + const char* p_cache_dir = cache_dir.c_str(); + const bool cache_dir_exists = (access(p_cache_dir, F_OK) == 0); + if (!cache_dir_exists) { + std::string s_ptkcp = std::string(ptkcp); + if (!r_mkdir_with_base(s_ptkcp, kernels_cache_dir)) { + TORCH_WARN_ONCE("Specified kernel cache directory could not be created! 
This disables kernel caching.", + " Specified directory is ", cache_dir, ".", + " This warning will appear only once per process."); + return {}; + } + } + + // Checks that the cache directory is readable and writable + const bool cache_dir_readable = (access(p_cache_dir, R_OK) == 0); + if (!cache_dir_readable) { + TORCH_WARN_ONCE("Specified kernel cache directory is not readable! This disables kernel caching.", + " Specified directory is ", cache_dir, ".", + " This warning will appear only once per process."); + return {}; + } + + const bool cache_dir_writable = (access(p_cache_dir, W_OK) == 0); + if (!cache_dir_writable) { + TORCH_WARN_ONCE("Specified kernel cache directory is not writable! This disables kernel caching.", + " Specified directory is ", cache_dir, ".", + " This warning will appear only once per process."); + return {}; + } + + return cache_dir; +} + +// Compiles the kernel, or acquires if from the cache if caching +hiprtcFunction jit_pwise_function( + const std::string& code, + const std::string& kernel_name) { + initializeZoomContext(); + // Acquires CUDA and nvrtc versions and whether we're compiling to ptx or SASS + const hipDeviceProp_t* prop = at::zoom::getCurrentDeviceProperties(); + int hip_major = 0, hip_minor = 0, hiprtc_major = 0, hiprtc_minor = 0; + bool compile_to_sass = false; + at::zoom::jit::codegenOutputQuery( + prop, hip_major, hip_minor, hiprtc_major, hiprtc_minor, compile_to_sass); + + // Objects used whether loading from the cache or jit compiling + const auto& hiprtc = at::globalContext().getHIPRTC(); + hiprtcFunction compiled_kernel_; + std::string name = kernel_name + "_kernel"; + + static const std::optional cache_dir = get_cache_dir(); + + std::string file_path; + if (cache_dir.has_value()) { + printf("Attempting to read from kernel cache...\n"); + // Attemps to read from the cache. + // Cubin name is _arch._nvrtc.___ + // Note that the SHA1 hash used in the file name is NOT the SHA1 hash of the file's contents, + // because we hash on the CUDA code, but we save the compiled ptx or sass + + // Acquires SHA1 hash + c10::sha1 sha1_hash{code}; + const auto hash_code = sha1_hash.str(); + + // Constructs file path by appending constructed cubin name to cache path + std::stringstream ss; + ss << *cache_dir << "/"; + ss << kernel_name; + ss << "_arch" << prop->gcnArchName; + ss << "_hiprtc" << hiprtc_major << "." << hiprtc_minor; + ss << (compile_to_sass ? 
"_sass" : "_ptx"); + ss << "_" << code.length(); + ss << "_" << hash_code; + file_path = ss.str(); + + std::ifstream readin{file_path, std::ios::in | std::ifstream::binary}; + if (readin.fail()) { + // NOTE: this does not warn because the file might not exist + // TODO: consider if this should explicitly check for the file's existence or not to throw + // an informative warning + readin.close(); + } else { + printf("loading module from cache\n"); + // TODO: try passing the "mapped" file directly to cuModuleLoadCall instead of using an intermediate buffer + std::vector buffer(std::istreambuf_iterator(readin), {}); + HIP_DRIVER_CHECK(hiprtc.hipModuleLoadData(&(compiled_kernel_.module), buffer.data())); + printf("funcload\n"); + HIP_DRIVER_CHECK( + hiprtc.hipModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str())); + readin.close(); + printf("finmodload\n"); + return compiled_kernel_; + } + } + + // Just-in-time compiles the program + + // Creates the NVRTC program + hiprtcProgram program; + ZOOM_HIPRTC_CHECK(hiprtc.hiprtcCreateProgram( + &program, code.c_str(), nullptr, 0, nullptr, nullptr)); + + std::vector args = {"--std=c++17", "-ggdb", "-O0"}; + + #undef NDEBUG + #ifndef NDEBUG + // Add line info to generated kernels + args.push_back("-lineinfo"); + #else + // Avoid excessive register usage from assertion + args.push_back("-DNDEBUG"); + #endif + + const auto compilation_result = + hiprtc.hiprtcCompileProgram(program, args.size(), args.data()); + + // Throws an error on compilation failure + if (compilation_result != HIPRTC_SUCCESS) { + size_t logsize; + ZOOM_HIPRTC_CHECK(hiprtc.hiprtcGetProgramLogSize(program, &logsize)); + std::string log(logsize, '\0'); + ZOOM_HIPRTC_CHECK(hiprtc.hiprtcGetProgramLog(program, &log[0])); + throw std::runtime_error(code + log); + } + + size_t ptx_size = 0; + std::vector ptx; + + const auto getSize = hiprtc.hiprtcGetCodeSize; + const auto getFunc = hiprtc.hiprtcGetCode; + + + ZOOM_HIPRTC_CHECK(getSize(program, &ptx_size)); + ptx.resize(ptx_size); + ZOOM_HIPRTC_CHECK(getFunc(program, ptx.data())); + + printf("modload2\n"); + HIP_DRIVER_CHECK(hiprtc.hipModuleLoadData(&(compiled_kernel_.module), ptx.data())); + printf("funcload2\n"); + HIP_DRIVER_CHECK( + hiprtc.hipModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str())); + // TODO: use guards to avoid leaking + printf("flend\n"); + ZOOM_HIPRTC_CHECK(hiprtc.hiprtcDestroyProgram(&program)); + + if (cache_dir.has_value()) { + // Writes the program to the cache if caching + // NOTE: Actually writes to a per-process temporary file to avoid multi-process contention. + // The temporary file is then renamed to the actual file. + // If the actual file already exists then the rename may fail or replace the actual file, + // the behavior is implementation-specific. + // Files replaced through this process should remain extant if they are being read because + // of UNIX filesystem properties, but this behavior is unverified and may require + // additional review in the future. + // TODO: In C++17 we should be able to use the filesystem header. 
+ const auto pid = getpid(); + std::stringstream tmp_file_path_ss; + tmp_file_path_ss << file_path << "_tmp_" << pid; + const std::string tmp_file_path = tmp_file_path_ss.str(); + std::ofstream hipbin(tmp_file_path, std::ios::out | std::ofstream::binary); + if (hipbin.fail()) { + TORCH_WARN_ONCE("Failed to write temporarily kernel cache file!", + " File path was ", tmp_file_path, ".", + " This warning will only appear once per process."); + } else { + std::copy(ptx.begin(), ptx.end(), std::ostreambuf_iterator(hipbin)); + if (std::rename(tmp_file_path.c_str(), file_path.c_str()) != 0) { + // Removes tmp file if the rename failed + std::remove(tmp_file_path.c_str()); + } + } + hipbin.close(); + } + + return compiled_kernel_; +} + +// TODO: may need/want to initialize CUDA context here (refactor into nvrtc call) +void launch_jitted_pwise_function( + hiprtcFunction function, + void* args[], + const dim3 nBlocks, + const dim3 kBlockSize, + const int smem) { + initializeZoomContext(); + const auto& hiprtc = at::globalContext().getHIPRTC(); + // Launches kernel on current stream + auto stream = c10::zoom::getCurrentZoomStream(); + stream.synchronize(); + HIP_DRIVER_CHECK(hiprtc.hipModuleLaunchKernel( + function.function, + nBlocks.x, + nBlocks.y, + nBlocks.z, + kBlockSize.x, + kBlockSize.y, + kBlockSize.z, + smem, + stream, + args, + nullptr)); +} + +} // at::zoom::jit \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/jit_utils.h b/aten/src/ATen/zoom/jit/jit_utils.h new file mode 100644 index 00000000000000..1115906144c724 --- /dev/null +++ b/aten/src/ATen/zoom/jit/jit_utils.h @@ -0,0 +1,230 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace at { namespace zoom { namespace jit { + +enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar}; + +struct hiprtcFunction { + hipModule_t module = hipModule_t(); + hipFunction_t function = nullptr; +}; + +struct KernelDescriptor { + std::string name; + std::string f; + c10::ScalarType f_inputs_type; + c10::ScalarType result_type; + c10::SmallVector extra_args_types; + int nInputs, nOutputs; +}; + +// Helper function to return a vector +// corresponding to the type of the arguments in parameter pack. +template +c10::SmallVector get_extra_args_types() { + return {c10::CppTypeToScalarType::value ...}; +} + +template < + typename result_type, + typename f_inputs_type, + typename... 
ExtraArgs> +KernelDescriptor make_kernel_descriptor( + std::string name, + std::string f, + int nInputs, + int nOutputs) { + KernelDescriptor ret; + ret.name = std::move(name); + ret.f = std::move(f); + ret.f_inputs_type = c10::CppTypeToScalarType::value; + ret.result_type = c10::CppTypeToScalarType::value; + ret.extra_args_types = get_extra_args_types(); + ret.nInputs = nInputs; + ret.nOutputs = nOutputs; + return ret; +} + +inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { + auto ip = reinterpret_cast(pointer); + if (ip % (4 * default_alignment) == 0) { + return 4; + } + if (ip % (2 * default_alignment) == 0) { + return 2; + } + return 1; +} + +inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef pointers) { + TORCH_INTERNAL_ASSERT(desc.nOutputs == 1); + TORCH_INTERNAL_ASSERT(static_cast(pointers.size()) == 1 + desc.nInputs); + + // Deals with output + auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize(); + int result = can_vectorize_up_to(result_size, pointers[0]); + + // Incorporates input(s) + auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + for (auto i : c10::irange(1, pointers.size())) { + result = std::min(result, can_vectorize_up_to(input_size, pointers[i])); + } + + return result; +} + +std::string generate_code( + int nInputs, + int nOutputs, + const std::string& func, + const std::string& name, + const std::string& f_input_type, + const std::string& compute_type, + const std::string& result_type, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + c10::SmallVector& extra_args_typenames, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string zoom_generate_code( + const KernelDescriptor &desc, + bool dynamic_casting = false); + +std::string zoom_generate_code( + int nInputs, + int nOutputs, + const std::string& func_, + const std::string& name, + const std::string& f_inputs_type, + const std::string& compute_type, + const std::string& result_type, + bool dynamic_casting, + c10::SmallVector& extra_args_typenames); + +std::string generate_reduction_code( + int nOutputs, + const std::string& func, + const std::string& name, + const int vt0, + const std::string& f_inputs_type, + const std::string& reduction_accum_type, + const std::string& result_type, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +std::string generate_reduction_code( + const KernelDescriptor &desc, + const int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +hiprtcFunction jit_pwise_function( + const std::string& code, + const std::string& kernel_name); + +void launch_jitted_pwise_function( + hiprtcFunction function, + void* args[], + const dim3 nBlocks, + const dim3 kBlockSize, + const int smem=0); + +template +struct delayed_false : std::false_type { +}; + +// Defines type names +// NOTE: General case is instantiated only for invalid types. +// All the valid types have specialization using the TYPE_NAME_FN +// macro below. 
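+// For example, typeName<float>() returns "float" through the TYPE_NAME_FN specializations
+// below, the complex specializations return the "std::complex<T>" spelling that the JIT
+// expects, and instantiating typeName<T>() for an unsupported T trips the delayed_false
+// static_assert above instead of silently returning "void".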
+template +inline std::string typeName() { + // we can't use static_assert(false) directly as the + // program will be not compiled even if the template is not + // instantiated, so we use `delayed_false` + // to make sure compiler doesn't eagerly raise + // fail this assertion. + static_assert(delayed_false::value, "invalid type for jiterator"); + return "void"; +} + +#define TYPE_NAME_FN(ctype, name) \ +template <> inline std::string typeName(){ \ + return std::string(#ctype); \ +} + +AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN) +#undef TYPE_NAME_FN +// JIT uses std::complex directly, because nvRTC compile programs +// with -default-device, so there is no such issue like: +// "std::sin(complex) is __host__ only" +template <> inline std::string typeName(){ + return "bool"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName(){ + return "at::Half"; +} +template <> inline std::string typeName(){ + return "at::BFloat16"; +} +template <> inline std::string typeName(){ + return "at::Float8_e5m2"; +} +template <> inline std::string typeName(){ + return "at::Float8_e4m3fn"; +} +template <> inline std::string typeName() { + return "at::Float8_e5m2fnuz"; +} +template <> inline std::string typeName() { + return "at::Float8_e4m3fnuz"; +} + +#define TYPE_NAME_CASE(ctype, scalartype) \ + case ScalarType::scalartype: return typeName(); +inline std::string typeName(ScalarType t) { + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE) + default: + TORCH_CHECK(false, "invalid type for jiterator"); + } +} +#undef TYPE_NAME_CASE + +TORCH_ZOOM_API void initializeZoomContext(); + +}}} // namespace at::zoom::jit \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/llvm_jit_strings.cpp b/aten/src/ATen/zoom/jit/llvm_jit_strings.cpp new file mode 100644 index 00000000000000..4e8a4ddacce065 --- /dev/null +++ b/aten/src/ATen/zoom/jit/llvm_jit_strings.cpp @@ -0,0 +1,1444 @@ +// This is copy-pasted (with modification) from the following llvm file: +// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/complex +// +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + + +namespace at::zoom { + +// copy-pasted from some llvm files: +// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/type_traits +// - https://github.com/llvm/llvm-project/blob/main/clang/test/Headers/Inputs/include/type_traits + +// hiprtc already includes some traits, so this removes duplicate definitions of +// integral_constant, is_same, is_integral, enable_if, is_floating_point, is_arithmetic. +// Copied from aten/src/ATen/cuda/llvm_basic.cpp, then modified as above. 
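+// These strings are never compiled into libtorch itself; they are spliced verbatim into the
+// source handed to hiprtc by the jiterator (see jit_utils.cpp), e.g. roughly:
+//   std::string src = get_traits_string() + get_complex_body_string() + functor_src;
+//   auto fn = at::zoom::jit::jit_pwise_function(src, "my_kernel");  // illustrative only
+// so anything defined here must be self-contained and acceptable to the device compiler.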
+const std::string traits = R"ESCAPE( +namespace std { + +template +_Tp&& __declval(int); +template +_Tp __declval(long); +template +decltype(__declval<_Tp>(0)) declval() noexcept; + +template struct remove_const {typedef _Tp type;}; +template struct remove_const {typedef _Tp type;}; +template using remove_const_t = typename remove_const<_Tp>::type; + +template struct remove_volatile {typedef _Tp type;}; +template struct remove_volatile {typedef _Tp type;}; +template using remove_volatile_t = typename remove_volatile<_Tp>::type; + +template struct remove_cv +{typedef typename remove_volatile::type>::type type;}; +template using remove_cv_t = typename remove_cv<_Tp>::type; + +template struct __libcpp_is_floating_point : public false_type {}; +template <> struct __libcpp_is_floating_point : public true_type {}; +template <> struct __libcpp_is_floating_point : public true_type {}; +template <> struct __libcpp_is_floating_point : public true_type {}; + +template +inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value; + +template +struct __numeric_type +{ + static void __test(...); + static float __test(float); + static double __test(char); + static double __test(int); + static double __test(unsigned); + static double __test(long); + static double __test(unsigned long); + static double __test(long long); + static double __test(unsigned long long); + static double __test(double); + static long double __test(long double); + + typedef decltype(__test(declval<_Tp>())) type; + static const bool value = !is_same::value; +}; + +template <> +struct __numeric_type +{ + static const bool value = true; +}; + +// __promote + +template ::value && + __numeric_type<_A2>::value && + __numeric_type<_A3>::value> +class __promote_imp +{ +public: + static const bool value = false; +}; + +template +class __promote_imp<_A1, _A2, _A3, true> +{ +private: + typedef typename __promote_imp<_A1>::type __type1; + typedef typename __promote_imp<_A2>::type __type2; + typedef typename __promote_imp<_A3>::type __type3; +public: + typedef decltype(__type1() + __type2() + __type3()) type; + static const bool value = true; +}; + +template +class __promote_imp<_A1, _A2, void, true> +{ +private: + typedef typename __promote_imp<_A1>::type __type1; + typedef typename __promote_imp<_A2>::type __type2; +public: + typedef decltype(__type1() + __type2()) type; + static const bool value = true; +}; + +template +class __promote_imp<_A1, void, void, true> +{ +public: + typedef typename __numeric_type<_A1>::type type; + static const bool value = true; +}; + +template +class __promote : public __promote_imp<_A1, _A2, _A3> {}; + +} // namespace std +)ESCAPE"; + +const std::string &get_traits_string() { + return traits; +} + +// This is copy-pasted from the following llvm file: +// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/cmath +const std::string cmath = R"ESCAPE( + +namespace std { + +using ::signbit; +using ::isfinite; +using ::isinf; +using ::isnan; + +using ::abs; + +using ::acos; +using ::acosf; +using ::asin; +using ::asinf; +using ::atan; +using ::atanf; +using ::atan2; +using ::atan2f; +using ::ceil; +using ::ceilf; +using ::cos; +using ::cosf; +using ::cosh; +using ::coshf; + +using ::exp; +using ::expf; + +using ::fabs; +using ::fabsf; +using ::floor; +using ::floorf; + +using ::fmod; +using ::fmodf; + +using ::frexp; +using ::frexpf; +using ::ldexp; +using ::ldexpf; + +using ::log; +using ::logf; + +using ::log10; +using ::log10f; +using ::modf; +using ::modff; + +using ::pow; +using ::powf; + 
+using ::sin; +using ::sinf; +using ::sinh; +using ::sinhf; + +using ::sqrt; +using ::sqrtf; +using ::tan; +using ::tanf; + +using ::tanh; +using ::tanhf; + +using ::acosh; +using ::acoshf; +using ::asinh; +using ::asinhf; +using ::atanh; +using ::atanhf; +using ::cbrt; +using ::cbrtf; + +using ::copysign; +using ::copysignf; + +using ::erf; +using ::erff; +using ::erfc; +using ::erfcf; +using ::exp2; +using ::exp2f; +using ::expm1; +using ::expm1f; +using ::fdim; +using ::fdimf; +using ::fmaf; +using ::fma; +using ::fmax; +using ::fmaxf; +using ::fmin; +using ::fminf; +using ::hypot; +using ::hypotf; +using ::ilogb; +using ::ilogbf; +using ::lgamma; +using ::lgammaf; +using ::llrint; +using ::llrintf; +using ::llround; +using ::llroundf; +using ::log1p; +using ::log1pf; +using ::log2; +using ::log2f; +using ::logb; +using ::logbf; +using ::lrint; +using ::lrintf; +using ::lround; +using ::lroundf; + +using ::nan; +using ::nanf; + +using ::nearbyint; +using ::nearbyintf; +using ::nextafter; +using ::nextafterf; +using ::remainder; +using ::remainderf; +using ::remquo; +using ::remquof; +using ::rint; +using ::rintf; +using ::round; +using ::roundf; +using ::scalbln; +using ::scalblnf; +using ::scalbn; +using ::scalbnf; +using ::tgamma; +using ::tgammaf; +using ::trunc; +using ::truncf; + +} // namespace std + +)ESCAPE"; + +const std::string &get_cmath_string() { + return cmath; +} + + +const std::string complex_body = R"ESCAPE( + +namespace std { + +template class complex; + +template complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w); +template complex<_Tp> operator/(const complex<_Tp>& __x, const complex<_Tp>& __y); + +template +class complex +{ +public: + typedef _Tp value_type; +private: + value_type __re_; + value_type __im_; +public: + constexpr + complex(const value_type& __re = value_type(), const value_type& __im = value_type()) + : __re_(__re), __im_(__im) {} + template constexpr + complex(const complex<_Xp>& __c) + : __re_(__c.real()), __im_(__c.imag()) {} + + constexpr value_type real() const {return __re_;} + constexpr value_type imag() const {return __im_;} + + void real(value_type __re) {__re_ = __re;} + void imag(value_type __im) {__im_ = __im;} + + constexpr operator bool() const { + return real() || imag(); + } + + complex& operator= (const value_type& __re) + {__re_ = __re; __im_ = value_type(); return *this;} + complex& operator+=(const value_type& __re) {__re_ += __re; return *this;} + complex& operator-=(const value_type& __re) {__re_ -= __re; return *this;} + complex& operator*=(const value_type& __re) {__re_ *= __re; __im_ *= __re; return *this;} + complex& operator/=(const value_type& __re) {__re_ /= __re; __im_ /= __re; return *this;} + + template complex& operator= (const complex<_Xp>& __c) + { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + template complex& operator+=(const complex<_Xp>& __c) + { + __re_ += __c.real(); + __im_ += __c.imag(); + return *this; + } + template complex& operator-=(const complex<_Xp>& __c) + { + __re_ -= __c.real(); + __im_ -= __c.imag(); + return *this; + } + template complex& operator*=(const complex<_Xp>& __c) + { + *this = *this * complex(__c.real(), __c.imag()); + return *this; + } + template complex& operator/=(const complex<_Xp>& __c) + { + *this = *this / complex(__c.real(), __c.imag()); + return *this; + } +}; + +template<> class complex; + +template<> +class complex +{ + float __re_; + float __im_; +public: + typedef float value_type; + + constexpr complex(float __re = 0.0f, float __im 
= 0.0f) + : __re_(__re), __im_(__im) {} + + explicit constexpr complex(const complex& __c); + + constexpr float real() const {return __re_;} + constexpr float imag() const {return __im_;} + + void real(value_type __re) {__re_ = __re;} + void imag(value_type __im) {__im_ = __im;} + + constexpr operator bool() const { + return real() || imag(); + } + + complex& operator= (float __re) + {__re_ = __re; __im_ = value_type(); return *this;} + complex& operator+=(float __re) {__re_ += __re; return *this;} + complex& operator-=(float __re) {__re_ -= __re; return *this;} + complex& operator*=(float __re) {__re_ *= __re; __im_ *= __re; return *this;} + complex& operator/=(float __re) {__re_ /= __re; __im_ /= __re; return *this;} + + template complex& operator= (const complex<_Xp>& __c) + { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + template complex& operator+=(const complex<_Xp>& __c) + { + __re_ += __c.real(); + __im_ += __c.imag(); + return *this; + } + template complex& operator-=(const complex<_Xp>& __c) + { + __re_ -= __c.real(); + __im_ -= __c.imag(); + return *this; + } + template complex& operator*=(const complex<_Xp>& __c) + { + *this = *this * complex(__c.real(), __c.imag()); + return *this; + } + template complex& operator/=(const complex<_Xp>& __c) + { + *this = *this / complex(__c.real(), __c.imag()); + return *this; + } +}; + +template<> +class complex +{ + double __re_; + double __im_; +public: + typedef double value_type; + + constexpr complex(double __re = 0.0, double __im = 0.0) + : __re_(__re), __im_(__im) {} + + constexpr complex(const complex& __c); + + constexpr double real() const {return __re_;} + constexpr double imag() const {return __im_;} + + void real(value_type __re) {__re_ = __re;} + void imag(value_type __im) {__im_ = __im;} + + constexpr operator bool() const { + return real() || imag(); + } + + complex& operator= (double __re) + {__re_ = __re; __im_ = value_type(); return *this;} + complex& operator+=(double __re) {__re_ += __re; return *this;} + complex& operator-=(double __re) {__re_ -= __re; return *this;} + complex& operator*=(double __re) {__re_ *= __re; __im_ *= __re; return *this;} + complex& operator/=(double __re) {__re_ /= __re; __im_ /= __re; return *this;} + + template complex& operator= (const complex<_Xp>& __c) + { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + template complex& operator+=(const complex<_Xp>& __c) + { + __re_ += __c.real(); + __im_ += __c.imag(); + return *this; + } + template complex& operator-=(const complex<_Xp>& __c) + { + __re_ -= __c.real(); + __im_ -= __c.imag(); + return *this; + } + template complex& operator*=(const complex<_Xp>& __c) + { + *this = *this * complex(__c.real(), __c.imag()); + return *this; + } + template complex& operator/=(const complex<_Xp>& __c) + { + *this = *this / complex(__c.real(), __c.imag()); + return *this; + } +}; + +inline +constexpr +complex::complex(const complex& __c) + : __re_(__c.real()), __im_(__c.imag()) {} + +inline +constexpr +complex::complex(const complex& __c) + : __re_(__c.real()), __im_(__c.imag()) {} + + +// 26.3.6 operators: + +template +inline +complex<_Tp> +operator+(const complex<_Tp>& __x, const complex<_Tp>& __y) +{ + complex<_Tp> __t(__x); + __t += __y; + return __t; +} + +template +inline +complex<_Tp> +operator+(const complex<_Tp>& __x, const _Tp& __y) +{ + complex<_Tp> __t(__x); + __t += __y; + return __t; +} + +template +inline +complex<_Tp> +operator+(const _Tp& __x, const complex<_Tp>& __y) +{ + complex<_Tp> __t(__y); + __t += 
__x; + return __t; +} + +template +inline +complex<_Tp> +operator-(const complex<_Tp>& __x, const complex<_Tp>& __y) +{ + complex<_Tp> __t(__x); + __t -= __y; + return __t; +} + +template +inline +complex<_Tp> +operator-(const complex<_Tp>& __x, const _Tp& __y) +{ + complex<_Tp> __t(__x); + __t -= __y; + return __t; +} + +template +inline +complex<_Tp> +operator-(const _Tp& __x, const complex<_Tp>& __y) +{ + complex<_Tp> __t(-__y); + __t += __x; + return __t; +} + +template +complex<_Tp> +operator*(const complex<_Tp>& __z, const complex<_Tp>& __w) +{ + _Tp __a = __z.real(); + _Tp __b = __z.imag(); + _Tp __c = __w.real(); + _Tp __d = __w.imag(); + _Tp __ac = __a * __c; + _Tp __bd = __b * __d; + _Tp __ad = __a * __d; + _Tp __bc = __b * __c; + _Tp __x = __ac - __bd; + _Tp __y = __ad + __bc; + if (isnan(__x) && isnan(__y)) + { + bool __recalc = false; + if (isinf(__a) || isinf(__b)) + { + __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a); + __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b); + if (isnan(__c)) + __c = copysign(_Tp(0), __c); + if (isnan(__d)) + __d = copysign(_Tp(0), __d); + __recalc = true; + } + if (isinf(__c) || isinf(__d)) + { + __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c); + __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d); + if (isnan(__a)) + __a = copysign(_Tp(0), __a); + if (isnan(__b)) + __b = copysign(_Tp(0), __b); + __recalc = true; + } + if (!__recalc && (isinf(__ac) || isinf(__bd) || + isinf(__ad) || isinf(__bc))) + { + if (isnan(__a)) + __a = copysign(_Tp(0), __a); + if (isnan(__b)) + __b = copysign(_Tp(0), __b); + if (isnan(__c)) + __c = copysign(_Tp(0), __c); + if (isnan(__d)) + __d = copysign(_Tp(0), __d); + __recalc = true; + } + if (__recalc) + { + __x = _Tp(INFINITY) * (__a * __c - __b * __d); + __y = _Tp(INFINITY) * (__a * __d + __b * __c); + } + } + return complex<_Tp>(__x, __y); +} + +template +inline +complex<_Tp> +operator*(const complex<_Tp>& __x, const _Tp& __y) +{ + complex<_Tp> __t(__x); + __t *= __y; + return __t; +} + +template +inline +complex<_Tp> +operator*(const _Tp& __x, const complex<_Tp>& __y) +{ + complex<_Tp> __t(__y); + __t *= __x; + return __t; +} + +template +complex<_Tp> +operator/(const complex<_Tp>& __z, const complex<_Tp>& __w) +{ + int __ilogbw = 0; + _Tp __a = __z.real(); + _Tp __b = __z.imag(); + _Tp __c = __w.real(); + _Tp __d = __w.imag(); + _Tp __logbw = logb(fmax(fabs(__c), fabs(__d))); + if (isfinite(__logbw)) + { + __ilogbw = static_cast(__logbw); + __c = scalbn(__c, -__ilogbw); + __d = scalbn(__d, -__ilogbw); + } + _Tp __denom = __c * __c + __d * __d; + _Tp __x = scalbn((__a * __c + __b * __d) / __denom, -__ilogbw); + _Tp __y = scalbn((__b * __c - __a * __d) / __denom, -__ilogbw); + if (isnan(__x) && isnan(__y)) + { + if ((__denom == _Tp(0)) && (!isnan(__a) || !isnan(__b))) + { + __x = copysign(_Tp(INFINITY), __c) * __a; + __y = copysign(_Tp(INFINITY), __c) * __b; + } + else if ((isinf(__a) || isinf(__b)) && isfinite(__c) && isfinite(__d)) + { + __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a); + __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b); + __x = _Tp(INFINITY) * (__a * __c + __b * __d); + __y = _Tp(INFINITY) * (__b * __c - __a * __d); + } + else if (isinf(__logbw) && __logbw > _Tp(0) && isfinite(__a) && isfinite(__b)) + { + __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c); + __d = copysign(isinf(__d) ? 
_Tp(1) : _Tp(0), __d); + __x = _Tp(0) * (__a * __c + __b * __d); + __y = _Tp(0) * (__b * __c - __a * __d); + } + } + return complex<_Tp>(__x, __y); +} + +template +inline +complex<_Tp> +operator/(const complex<_Tp>& __x, const _Tp& __y) +{ + return complex<_Tp>(__x.real() / __y, __x.imag() / __y); +} + +template +inline +complex<_Tp> +operator/(const _Tp& __x, const complex<_Tp>& __y) +{ + complex<_Tp> __t(__x); + __t /= __y; + return __t; +} + +template +inline +complex<_Tp> +operator+(const complex<_Tp>& __x) +{ + return __x; +} + +template +inline +complex<_Tp> +operator-(const complex<_Tp>& __x) +{ + return complex<_Tp>(-__x.real(), -__x.imag()); +} + +template +inline constexpr +bool +operator==(const complex<_Tp>& __x, const complex<_Tp>& __y) +{ + return __x.real() == __y.real() && __x.imag() == __y.imag(); +} + +template +inline constexpr +bool +operator==(const complex<_Tp>& __x, const _Tp& __y) +{ + return __x.real() == __y && __x.imag() == 0; +} + +template +inline constexpr +bool +operator==(const _Tp& __x, const complex<_Tp>& __y) +{ + return __x == __y.real() && 0 == __y.imag(); +} + +template +inline constexpr +bool +operator!=(const complex<_Tp>& __x, const complex<_Tp>& __y) +{ + return !(__x == __y); +} + +template +inline constexpr +bool +operator!=(const complex<_Tp>& __x, const _Tp& __y) +{ + return !(__x == __y); +} + +template +inline constexpr +bool +operator!=(const _Tp& __x, const complex<_Tp>& __y) +{ + return !(__x == __y); +} + +template +inline constexpr +bool +operator&&(const complex<_Tp>& __x, const complex<_Tp>& __y) +{ + return bool(__x) && bool(__y); +} + +template +inline constexpr +bool +isnan(const complex<_Tp>& __x) +{ + return isnan(__x.real()) || isnan(__x.imag()); +} + +template +inline constexpr +bool +operator||(const complex<_Tp>& __x, const complex<_Tp>& __y) +{ + return bool(__x) || bool(__y); +} + +// 26.3.7 values: + +template ::value, + bool = is_floating_point<_Tp>::value + > +struct __libcpp_complex_overload_traits {}; + +// Integral Types +template +struct __libcpp_complex_overload_traits<_Tp, true, false> +{ + typedef double _ValueType; + typedef complex _ComplexType; +}; + +// Floating point types +template +struct __libcpp_complex_overload_traits<_Tp, false, true> +{ + typedef _Tp _ValueType; + typedef complex<_Tp> _ComplexType; +}; + +// real + +template +inline constexpr +_Tp +real(const complex<_Tp>& __c) +{ + return __c.real(); +} + +template +inline constexpr +typename __libcpp_complex_overload_traits<_Tp>::_ValueType +real(_Tp __re) +{ + return __re; +} + +// imag + +template +inline constexpr +_Tp +imag(const complex<_Tp>& __c) +{ + return __c.imag(); +} + +template +inline constexpr +typename __libcpp_complex_overload_traits<_Tp>::_ValueType +imag(_Tp) +{ + return 0; +} + +// abs + +template +inline +_Tp +abs(const complex<_Tp>& __c) +{ + return hypot(__c.real(), __c.imag()); +} + +// arg + +template +inline +_Tp +arg(const complex<_Tp>& __c) +{ + return atan2(__c.imag(), __c.real()); +} + +template +inline +typename enable_if +< + is_integral<_Tp>::value || is_same<_Tp, double>::value, + double +>::type +arg(_Tp __re) +{ + return atan2(0., __re); +} + +template +inline +typename enable_if< + is_same<_Tp, float>::value, + float +>::type +arg(_Tp __re) +{ + return atan2f(0.F, __re); +} + +} + +)ESCAPE"; + +const std::string complex_half_body = R"ESCAPE( +namespace std { +template <> +struct alignas(2) complex { + at::Half real_; + at::Half imag_; + + // Constructors + complex() = default; + + // implicit casting to and from 
`complex`. + // NOTE: computation of `complex` will occur in `complex` + __host__ __device__ inline complex(const std::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + inline __host__ __device__ operator std::complex() const { + return {real_, imag_}; + } + + at::Half real() const {return real_;} + at::Half imag() const {return imag_;} + +}; +} +)ESCAPE"; + + +const std::string &get_complex_body_string() { + return complex_body; +} + +const std::string &get_complex_half_body_string() { + return complex_half_body; +} + +const std::string complex_math = R"ESCAPE( + +namespace std { + +// norm + +template +inline +_Tp +norm(const complex<_Tp>& __c) +{ + if (isinf(__c.real())) + return abs(__c.real()); + if (isinf(__c.imag())) + return abs(__c.imag()); + return __c.real() * __c.real() + __c.imag() * __c.imag(); +} + +template +inline +typename __libcpp_complex_overload_traits<_Tp>::_ValueType +norm(_Tp __re) +{ + typedef typename __libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; + return static_cast<_ValueType>(__re) * __re; +} + +// conj + +template +inline +complex<_Tp> +conj(const complex<_Tp>& __c) +{ + return complex<_Tp>(__c.real(), -__c.imag()); +} + +template +inline +typename __libcpp_complex_overload_traits<_Tp>::_ComplexType +conj(_Tp __re) +{ + typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + return _ComplexType(__re); +} + + + +// proj + +template +inline +complex<_Tp> +proj(const complex<_Tp>& __c) +{ + complex<_Tp> __r = __c; + if (isinf(__c.real()) || isinf(__c.imag())) + __r = complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag())); + return __r; +} + +template +inline +typename enable_if +< + is_floating_point<_Tp>::value, + typename __libcpp_complex_overload_traits<_Tp>::_ComplexType +>::type +proj(_Tp __re) +{ + if (isinf(__re)) + __re = abs(__re); + return complex<_Tp>(__re); +} + +template +inline +typename enable_if +< + is_integral<_Tp>::value, + typename __libcpp_complex_overload_traits<_Tp>::_ComplexType +>::type +proj(_Tp __re) +{ + typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + return _ComplexType(__re); +} + +// polar + +template +complex<_Tp> +polar(const _Tp& __rho, const _Tp& __theta = _Tp()) +{ + if (isnan(__rho) || signbit(__rho)) + return complex<_Tp>(_Tp(NAN), _Tp(NAN)); + if (isnan(__theta)) + { + if (isinf(__rho)) + return complex<_Tp>(__rho, __theta); + return complex<_Tp>(__theta, __theta); + } + if (isinf(__theta)) + { + if (isinf(__rho)) + return complex<_Tp>(__rho, _Tp(NAN)); + return complex<_Tp>(_Tp(NAN), _Tp(NAN)); + } + _Tp __x = __rho * cos(__theta); + if (isnan(__x)) + __x = 0; + _Tp __y = __rho * sin(__theta); + if (isnan(__y)) + __y = 0; + return complex<_Tp>(__x, __y); +} + +// log + +template +inline +complex<_Tp> +log(const complex<_Tp>& __x) +{ + return complex<_Tp>(log(abs(__x)), arg(__x)); +} + +// log10 + +template +inline +complex<_Tp> +log10(const complex<_Tp>& __x) +{ + return log(__x) / log(_Tp(10)); +} + +// log2 + +template +inline +complex<_Tp> +log2(const complex<_Tp>& __x) +{ + return log(__x) / log(_Tp(2)); +} + +// sqrt + +template +complex<_Tp> +sqrt(const complex<_Tp>& __x) +{ + if (isinf(__x.imag())) + return complex<_Tp>(_Tp(INFINITY), __x.imag()); + if (isinf(__x.real())) + { + if (__x.real() > _Tp(0)) + return complex<_Tp>(__x.real(), isnan(__x.imag()) ? __x.imag() : copysign(_Tp(0), __x.imag())); + return complex<_Tp>(isnan(__x.imag()) ? 
__x.imag() : _Tp(0), copysign(__x.real(), __x.imag())); + } + return polar(sqrt(abs(__x)), arg(__x) / _Tp(2)); +} + +// exp + +template +complex<_Tp> +exp(const complex<_Tp>& __x) +{ + _Tp __i = __x.imag(); + if (__i == 0) { + return complex<_Tp>(exp(__x.real()), copysign(_Tp(0), __x.imag())); + } + if (isinf(__x.real())) + { + if (__x.real() < _Tp(0)) + { + if (!isfinite(__i)) + __i = _Tp(1); + } + else if (__i == 0 || !isfinite(__i)) + { + if (isinf(__i)) + __i = _Tp(NAN); + return complex<_Tp>(__x.real(), __i); + } + } + _Tp __e = exp(__x.real()); + return complex<_Tp>(__e * cos(__i), __e * sin(__i)); +} + +// pow + +template +inline +complex<_Tp> +pow(const complex<_Tp>& __x, const complex<_Tp>& __y) +{ + return exp(__y * log(__x)); +} + +template +inline +complex::type> +pow(const complex<_Tp>& __x, const complex<_Up>& __y) +{ + typedef complex::type> result_type; + return std::pow(result_type(__x), result_type(__y)); +} + +template +inline +typename enable_if +< + is_arithmetic<_Up>::value, + complex::type> +>::type +pow(const complex<_Tp>& __x, const _Up& __y) +{ + typedef complex::type> result_type; + return std::pow(result_type(__x), result_type(__y)); +} + +template +inline +typename enable_if +< + is_arithmetic<_Tp>::value, + complex::type> +>::type +pow(const _Tp& __x, const complex<_Up>& __y) +{ + typedef complex::type> result_type; + return std::pow(result_type(__x), result_type(__y)); +} + +// __sqr, computes pow(x, 2) + +template +inline +complex<_Tp> +__sqr(const complex<_Tp>& __x) +{ + return complex<_Tp>((__x.real() - __x.imag()) * (__x.real() + __x.imag()), + _Tp(2) * __x.real() * __x.imag()); +} + +// asinh + +template +complex<_Tp> +asinh(const complex<_Tp>& __x) +{ + const _Tp __pi(atan2(+0., -0.)); + if (isinf(__x.real())) + { + if (isnan(__x.imag())) + return __x; + if (isinf(__x.imag())) + return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); + return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); + } + if (isnan(__x.real())) + { + if (isinf(__x.imag())) + return complex<_Tp>(__x.imag(), __x.real()); + if (__x.imag() == 0) + return __x; + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.imag())) + return complex<_Tp>(copysign(__x.imag(), __x.real()), copysign(__pi/_Tp(2), __x.imag())); + complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1))); + return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); +} + +// acosh + +template +complex<_Tp> +acosh(const complex<_Tp>& __x) +{ + const _Tp __pi(atan2(+0., -0.)); + if (isinf(__x.real())) + { + if (isnan(__x.imag())) + return complex<_Tp>(abs(__x.real()), __x.imag()); + if (isinf(__x.imag())) + { + if (__x.real() > 0) + return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); + else + return complex<_Tp>(-__x.real(), copysign(__pi * _Tp(0.75), __x.imag())); + } + if (__x.real() < 0) + return complex<_Tp>(-__x.real(), copysign(__pi, __x.imag())); + return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); + } + if (isnan(__x.real())) + { + if (isinf(__x.imag())) + return complex<_Tp>(abs(__x.imag()), __x.real()); + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.imag())) + return complex<_Tp>(abs(__x.imag()), copysign(__pi/_Tp(2), __x.imag())); + complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); + return complex<_Tp>(copysign(__z.real(), _Tp(0)), copysign(__z.imag(), __x.imag())); +} + +// atanh + +template +complex<_Tp> +atanh(const complex<_Tp>& __x) +{ + const _Tp __pi(atan2(+0., -0.)); + if 
(isinf(__x.imag())) + { + return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag())); + } + if (isnan(__x.imag())) + { + if (isinf(__x.real()) || __x.real() == 0) + return complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag()); + return complex<_Tp>(__x.imag(), __x.imag()); + } + if (isnan(__x.real())) + { + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.real())) + { + return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag())); + } + if (abs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0)) + { + return complex<_Tp>(copysign(_Tp(INFINITY), __x.real()), copysign(_Tp(0), __x.imag())); + } + complex<_Tp> __z = log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2); + return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); +} + +// sinh + +template +complex<_Tp> +sinh(const complex<_Tp>& __x) +{ + if (isinf(__x.real()) && !isfinite(__x.imag())) + return complex<_Tp>(__x.real(), _Tp(NAN)); + if (__x.real() == 0 && !isfinite(__x.imag())) + return complex<_Tp>(__x.real(), _Tp(NAN)); + if (__x.imag() == 0 && !isfinite(__x.real())) + return __x; + return complex<_Tp>(sinh(__x.real()) * cos(__x.imag()), cosh(__x.real()) * sin(__x.imag())); +} + +// cosh + +template +complex<_Tp> +cosh(const complex<_Tp>& __x) +{ + if (isinf(__x.real()) && !isfinite(__x.imag())) + return complex<_Tp>(abs(__x.real()), _Tp(NAN)); + if (__x.real() == 0 && !isfinite(__x.imag())) + return complex<_Tp>(_Tp(NAN), __x.real()); + if (__x.real() == 0 && __x.imag() == 0) + return complex<_Tp>(_Tp(1), __x.imag()); + if (__x.imag() == 0 && !isfinite(__x.real())) + return complex<_Tp>(abs(__x.real()), __x.imag()); + return complex<_Tp>(cosh(__x.real()) * cos(__x.imag()), sinh(__x.real()) * sin(__x.imag())); +} + +// tanh + +template +complex<_Tp> +tanh(const complex<_Tp>& __x) +{ + if (isinf(__x.real())) + { + if (!isfinite(__x.imag())) + return complex<_Tp>(copysign(_Tp(1), __x.real()), _Tp(0)); + return complex<_Tp>(copysign(_Tp(1), __x.real()), copysign(_Tp(0), sin(_Tp(2) * __x.imag()))); + } + if (isnan(__x.real()) && __x.imag() == 0) + return __x; + _Tp __2r(_Tp(2) * __x.real()); + _Tp __2i(_Tp(2) * __x.imag()); + _Tp __d(cosh(__2r) + cos(__2i)); + _Tp __2rsh(sinh(__2r)); + if (isinf(__2rsh) && isinf(__d)) + return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1), + __2i > _Tp(0) ? _Tp(0) : _Tp(-0.)); + return complex<_Tp>(__2rsh/__d, sin(__2i)/__d); +} + +// asin + +template +complex<_Tp> +asin(const complex<_Tp>& __x) +{ + complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// acos + +template +complex<_Tp> +acos(const complex<_Tp>& __x) +{ + const _Tp __pi(atan2(+0., -0.)); + if (isinf(__x.real())) + { + if (isnan(__x.imag())) + return complex<_Tp>(__x.imag(), __x.real()); + if (isinf(__x.imag())) + { + if (__x.real() < _Tp(0)) + return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag()); + return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag()); + } + if (__x.real() < _Tp(0)) + return complex<_Tp>(__pi, signbit(__x.imag()) ? -__x.real() : __x.real()); + return complex<_Tp>(_Tp(0), signbit(__x.imag()) ? 
__x.real() : -__x.real()); + } + if (isnan(__x.real())) + { + if (isinf(__x.imag())) + return complex<_Tp>(__x.real(), -__x.imag()); + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.imag())) + return complex<_Tp>(__pi/_Tp(2), -__x.imag()); + if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) + return complex<_Tp>(__pi/_Tp(2), -__x.imag()); + complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); + if (signbit(__x.imag())) + return complex<_Tp>(abs(__z.imag()), abs(__z.real())); + return complex<_Tp>(abs(__z.imag()), -abs(__z.real())); +} + +// atan + +template +complex<_Tp> +atan(const complex<_Tp>& __x) +{ + complex<_Tp> __z = atanh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// sin + +template +complex<_Tp> +sin(const complex<_Tp>& __x) +{ + complex<_Tp> __z = sinh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// cos + +template +inline +complex<_Tp> +cos(const complex<_Tp>& __x) +{ + return cosh(complex<_Tp>(-__x.imag(), __x.real())); +} + +// tan + +template +complex<_Tp> +tan(const complex<_Tp>& __x) +{ + complex<_Tp> __z = tanh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// Literal suffix for complex number literals [complex.literals] +inline namespace literals +{ + inline namespace complex_literals + { + constexpr complex operator""i(long double __im) + { + return { 0.0, static_cast(__im) }; + } + + constexpr complex operator""i(unsigned long long __im) + { + return { 0.0, static_cast(__im) }; + } + + + constexpr complex operator""if(long double __im) + { + return { 0.0f, static_cast(__im) }; + } + + constexpr complex operator""if(unsigned long long __im) + { + return { 0.0f, static_cast(__im) }; + } + } // namespace complex_literals +} // namespace literals + +} // namespace std + +)ESCAPE"; + +const std::string &get_complex_math_string() { + return complex_math; +} + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/llvm_jit_strings.h b/aten/src/ATen/zoom/jit/llvm_jit_strings.h new file mode 100644 index 00000000000000..3d71ff866c47fa --- /dev/null +++ b/aten/src/ATen/zoom/jit/llvm_jit_strings.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace at::zoom { + +TORCH_ZOOM_API const std::string &get_traits_string(); +TORCH_ZOOM_API const std::string &get_cmath_string(); +TORCH_ZOOM_API const std::string &get_complex_body_string(); +TORCH_ZOOM_API const std::string &get_complex_half_body_string(); +TORCH_ZOOM_API const std::string &get_complex_math_string(); + +} // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/macros.h b/aten/src/ATen/zoom/jit/macros.h new file mode 100644 index 00000000000000..2402b0a0d52ba5 --- /dev/null +++ b/aten/src/ATen/zoom/jit/macros.h @@ -0,0 +1,4 @@ +#include + +#define AT_USE_JITERATOR() true +#define jiterator_stringify(...) std::string(#__VA_ARGS__); \ No newline at end of file diff --git a/aten/src/ATen/zoom/jit/thread_constants.h b/aten/src/ATen/zoom/jit/thread_constants.h new file mode 100644 index 00000000000000..0df30f8d5a45e8 --- /dev/null +++ b/aten/src/ATen/zoom/jit/thread_constants.h @@ -0,0 +1,16 @@ +#pragma once +#include + +// Marks a lambda as executable on both the host and device. The __host__ +// attribute is important so that we can access static type information from +// the host, even if the function is typically only executed on the device. 
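+// Illustrative use: a functor declared as
+//   auto op = [] GPU_LAMBDA (float a, float b) { return a + b; };
+// can have decltype(op(1.f, 1.f)) inspected in host code, even though the lambda body is
+// only ever executed inside a kernel.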
+#ifndef GPU_LAMBDA +#define GPU_LAMBDA __host__ __device__ +#endif + +constexpr int num_threads() { + return 256; +} + +constexpr int thread_work_size() { return 4; } +constexpr int block_work_size() { return thread_work_size() * num_threads(); } \ No newline at end of file diff --git a/build_variables.bzl b/build_variables.bzl index 3f16f9b847c1cc..6bd1898db6310b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -773,6 +773,19 @@ libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [ "torch/csrc/cuda/Tensor.cpp", ] +libtorch_python_zoom_sources = [ + "torch/csrc/zoom/Module.cpp", + "torch/csrc/zoom/Event.cpp", + "torch/csrc/zoom/python_comm.cpp", + "torch/csrc/zoom/Stream.cpp", + "torch/csrc/zoom/Graph.cpp", + "torch/csrc/zoom/utils.cpp", + "torch/csrc/zoom/ZoomPluggableAllocator.cpp", + "torch/csrc/zoom/comm.cpp", + "torch/csrc/zoom/memory_snapshot.cpp", + "torch/csrc/zoom/shared/hiprt.cpp", +] + libtorch_python_xpu_sources = [ "torch/csrc/xpu/Event.cpp", "torch/csrc/xpu/Module.cpp", @@ -952,6 +965,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): aten_cpu_non_globed_sources = [ "aten/src/ATen/detail/CUDAHooksInterface.cpp", "aten/src/ATen/detail/HIPHooksInterface.cpp", + "aten/src/ATen/detail/ZoomHooksInterface.cpp", "aten/src/ATen/detail/MPSHooksInterface.cpp", "aten/src/ATen/detail/MAIAHooksInterface.cpp", "aten/src/ATen/detail/PrivateUse1HooksInterface.cpp", @@ -970,6 +984,7 @@ aten_cpu_non_globed_headers = [ "aten/src/ATen/detail/CUDAHooksInterface.h", "aten/src/ATen/detail/MPSHooksInterface.h", "aten/src/ATen/detail/HIPHooksInterface.h", + "aten/src/ATen/detail/ZoomHooksInterface.h", "aten/src/ATen/detail/MAIAHooksInterface.h", "aten/src/ATen/detail/PrivateUse1HooksInterface.h", "aten/src/ATen/detail/XPUHooksInterface.h", diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 1f742f4c17683d..c8f74102099a87 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -140,6 +140,10 @@ if(USE_ROCM) add_subdirectory(hip) endif() +if(USE_ZOOM) + add_subdirectory(zoom) +endif() + if(USE_XPU) add_subdirectory(xpu) endif() diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp index 491c85b081e885..e855f870a759b8 100644 --- a/c10/core/Allocator.cpp +++ b/c10/core/Allocator.cpp @@ -41,6 +41,17 @@ C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES]; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) C10_API uint8_t allocator_priority[at::COMPILE_TIME_MAX_DEVICE_TYPES] = {0}; +/* + (Arham) This holds functor that enables getting the PU1 allocator from a function rather than statically registering + a pointer to a static global variable, which is useful when we want to create a global allocator that is thread safe + (e.g. using std::atomic). 
See the usage below in GetAllocator and REGISTER_PU1_ALLOCATOR in Allocator.h +*/ +C10_API at::Allocator* (*getPrivateUse1Allocator)() = nullptr; + +void SetPrivateUse1GetAllocator(at::Allocator* (*getAllocatorFunc)()) { + getPrivateUse1Allocator = getAllocatorFunc; +} + void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) { if (priority >= allocator_priority[static_cast(t)]) { allocator_array[static_cast(t)] = alloc; @@ -49,6 +60,10 @@ void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) { } at::Allocator* GetAllocator(const at::DeviceType& t) { + // if registered, use the functor registration for the PU1 allocator, else use the traditional static registration + if(t == DeviceType::PrivateUse1 && getPrivateUse1Allocator != nullptr) { + return getPrivateUse1Allocator(); + } auto* alloc = allocator_array[static_cast(t)]; TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alloc, "Allocator for ", t, " is not set."); return alloc; diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h index 412412557a0d11..936c19469af905 100644 --- a/c10/core/Allocator.h +++ b/c10/core/Allocator.h @@ -255,6 +255,15 @@ struct C10_API InefficientStdFunctionContext { C10_API void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0); C10_API Allocator* GetAllocator(const DeviceType& t); +// set a functor that can retrieve the PrivateUse1 Allocator at will +C10_API void SetPrivateUse1GetAllocator(at::Allocator* (*getAllocatorFunc)()); + +struct PrivateUse1AllocatorRegisterer { + explicit PrivateUse1AllocatorRegisterer(at::Allocator* (*getAllocatorFunc)()) { + SetPrivateUse1GetAllocator(getAllocatorFunc); + } +}; + template struct AllocatorRegisterer { explicit AllocatorRegisterer(Allocator* alloc) { @@ -267,6 +276,11 @@ struct AllocatorRegisterer { static c10::AllocatorRegisterer g_allocator_d(f); \ } +#define REGISTER_PU1_ALLOCATOR(f) \ + namespace { \ + static PrivateUse1AllocatorRegisterer g_allocator_d(f); \ + } + // An interface for reporting thread local memory usage // per device struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase { diff --git a/c10/macros/Export.h b/c10/macros/Export.h index cb68060ed8129d..84438c6eead37d 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -140,8 +140,10 @@ #if defined(TORCH_HIP_BUILD_MAIN_LIB) #define TORCH_HIP_API C10_EXPORT +#define TORCH_ZOOM_API C10_EXPORT #else #define TORCH_HIP_API C10_IMPORT +#define TORCH_ZOOM_API C10_EXPORT #endif #if defined(TORCH_XPU_BUILD_MAIN_LIB) diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index f28e526a0431a9..a704e55142f52e 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -310,7 +310,7 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #define C10_HIP_HOST_DEVICE #endif -#if defined(USE_ROCM) +#if defined(USE_ROCM) || defined(USE_ZOOM) #define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h) #else #define C10_WARP_SIZE 32 diff --git a/c10/util/generic_math.h b/c10/util/generic_math.h index adfdbfd9955c00..579aee2e2d83e0 100644 --- a/c10/util/generic_math.h +++ b/c10/util/generic_math.h @@ -8,7 +8,11 @@ #include #define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign #elif defined(__HIPCC__) -#include + #ifdef USE_ZOOM + #include + #else + #include + #endif #define C10_COMPAT_COPYSIGN c10::hip::compat::copysign #else #include diff --git a/c10/zoom/CMakeLists.txt b/c10/zoom/CMakeLists.txt new file mode 100644 index 00000000000000..f055a8d824cddf --- /dev/null +++ b/c10/zoom/CMakeLists.txt @@ -0,0 +1,60 @@ 
+include(../../cmake/public/utils.cmake) + +# ---[ Configure macro file. +set(C10_ZOOM_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in +# configure_file( +# ${CMAKE_CURRENT_LIST_DIR}/impl/hip_cmake_macros.h.in +# ${CMAKE_BINARY_DIR}/c10/hip/impl/hip_cmake_macros.h) + +# NB: All previous cu files are renamed into cc files. This isn't tested at the +# moment. +file(GLOB C10_ZOOM_SRCS + *.cpp + *.cu + impl/*.cpp + impl/*.cu + ) + +# Mark the cc files as HIP files, so we call the compiler. (They have to be +# suffixed with cc, because the hcc compiler won't accept them otherwise.) +file(GLOB __c10_zoom_srcs_cpp *.cu impl/*.cu) +set_source_files_properties(${__c10_zoom_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) +set_source_files_properties(${C10_ZOOM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + +file(GLOB_RECURSE C10_ZOOM_HEADERS *.h) +hip_add_library(c10_zoom ${C10_ZOOM_SRCS} ${C10_ZOOM_HEADERS}) + +# Propagate HIP_CXX_FLAGS that were set from Dependencies.cmake +target_compile_options(c10_zoom PRIVATE ${HIP_CXX_FLAGS}) + +# caffe2_hip adds a bunch of dependencies like rocsparse, but c10/hip is supposed to be +# minimal. I'm not sure if we need hip_hcc or not; for now leave it out + +# If building shared library, set dllimport/dllexport proper. +target_compile_options(c10_zoom PRIVATE "-DC10_ZOOM_BUILD_MAIN_LIB") +# Enable hidden visibility if compiler supports it. +if(${COMPILER_SUPPORTS_HIDDEN_VISIBILITY}) + target_compile_options(c10_zoom PRIVATE "-fvisibility=hidden") +endif() + +# ---[ Dependency of c10_zoom +target_link_libraries(c10_zoom PUBLIC c10) + +target_link_libraries(c10_zoom PUBLIC ${PYTORCH_HIP_LIBRARIES}) + +target_include_directories( + c10_zoom PUBLIC + $ + $ + $ + $) + +# add_subdirectory(test) + +# ---[ Installation +install(TARGETS c10_zoom EXPORT Caffe2Targets DESTINATION lib) +install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + DESTINATION include + FILES_MATCHING PATTERN "*.h") +# install(FILES ${CMAKE_BINARY_DIR}/c10/hip/impl/hip_cmake_macros.h +# DESTINATION include/c10/hip/impl) diff --git a/c10/zoom/HIPGraphsC10Utils.h b/c10/zoom/HIPGraphsC10Utils.h new file mode 100644 index 00000000000000..9e423df8bd0250 --- /dev/null +++ b/c10/zoom/HIPGraphsC10Utils.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include + +// CUDA Graphs utils used by c10 and aten. +// aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only. + +namespace c10::zoom { + +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by CUDAGraph::capture_begin. +// second is set if the instance is created by at::zoom::graph_pool_handle. +using MempoolId_t = std::pair; + +// RAII guard for "hipStreamCaptureMode", a thread-local value +// that controls the error-checking strictness of a capture. +struct ZoomStreamCaptureModeGuard { + ZoomStreamCaptureModeGuard(hipStreamCaptureMode desired) + : strictness_(desired) { + C10_ZOOM_CHECK(hipThreadExchangeStreamCaptureMode(&strictness_)); + } + ~ZoomStreamCaptureModeGuard() { + C10_ZOOM_CHECK_WARN(hipThreadExchangeStreamCaptureMode(&strictness_)); + } + + private: + hipStreamCaptureMode strictness_; +}; + +// Protects against enum hipStreamCaptureStatus implementation changes. +// Some compilers seem not to like static_assert without the messages. 
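+// Typical (assumed) usage of ZoomStreamCaptureModeGuard above: relax capture checking around
+// work that is known to be capture-safe, e.g.
+//   {
+//     ZoomStreamCaptureModeGuard guard(hipStreamCaptureModeRelaxed);
+//     // ... calls that would otherwise invalidate a stricter ongoing capture ...
+//   }
+// The previous thread-local capture mode is restored when the guard goes out of scope.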
+static_assert( + int(hipStreamCaptureStatus::hipStreamCaptureStatusNone) == 0, + "unexpected int(hipStreamCaptureStatusNone) value"); +static_assert( + int(hipStreamCaptureStatus::hipStreamCaptureStatusActive) == 1, + "unexpected int(hipStreamCaptureStatusActive) value"); +static_assert( + int(hipStreamCaptureStatus::hipStreamCaptureStatusInvalidated) == 2, + "unexpected int(hipStreamCaptureStatusInvalidated) value"); + +enum class CaptureStatus : int { + None = int(hipStreamCaptureStatus::hipStreamCaptureStatusNone), + Active = int(hipStreamCaptureStatus::hipStreamCaptureStatusActive), + Invalidated = int(hipStreamCaptureStatus::hipStreamCaptureStatusInvalidated) +}; + +inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { + switch (status) { + case CaptureStatus::None: + os << "hipStreamCaptureStatusNone"; + break; + case CaptureStatus::Active: + os << "hipStreamCaptureStatusActive"; + break; + case CaptureStatus::Invalidated: + os << "hipStreamCaptureStatusInvalidated"; + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Unknown HIP graph CaptureStatus", int(status)); + } + return os; +} + +// Use this version where you're sure a HIP context exists already. +inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { + hipStreamCaptureStatus is_capturing{hipStreamCaptureStatusNone}; + C10_ZOOM_CHECK( + hipStreamIsCapturing(c10::zoom::getCurrentZoomStream(), &is_capturing)); + return CaptureStatus(is_capturing); +} + +} // namespace c10::zoom \ No newline at end of file diff --git a/c10/zoom/HIPMathCompat.h b/c10/zoom/HIPMathCompat.h new file mode 100644 index 00000000000000..12c08d2a8a13b4 --- /dev/null +++ b/c10/zoom/HIPMathCompat.h @@ -0,0 +1,152 @@ +#pragma once + +/* This file defines math functions compatible across different gpu + * platforms (currently CUDA and HIP). + */ +#if defined(__CUDACC__) || defined(__HIPCC__) + +#include +#include + +#ifdef __HIPCC__ +#define __MATH_FUNCTIONS_DECL__ inline C10_DEVICE +#else /* __HIPCC__ */ +#ifdef __CUDACC_RTC__ +#define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE +#else /* __CUDACC_RTC__ */ +#define __MATH_FUNCTIONS_DECL__ static inline C10_HOST_DEVICE +#endif /* __CUDACC_RTC__ */ +#endif /* __HIPCC__ */ + +namespace c10::hip::compat { + +__MATH_FUNCTIONS_DECL__ float abs(float x) { + return ::fabsf(x); +} +__MATH_FUNCTIONS_DECL__ double abs(double x) { + return ::fabs(x); +} + +__MATH_FUNCTIONS_DECL__ float exp(float x) { + return ::expf(x); +} +__MATH_FUNCTIONS_DECL__ double exp(double x) { + return ::exp(x); +} + +__MATH_FUNCTIONS_DECL__ float ceil(float x) { + return ::ceilf(x); +} +__MATH_FUNCTIONS_DECL__ double ceil(double x) { + return ::ceil(x); +} + +__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + return ::copysignf(x, y); +#else + // std::copysign gets ICE/Segfaults with gcc 7.5/8 on arm64 + // (e.g. 
Jetson), see PyTorch PR #51834 + // This host function needs to be here for the compiler but is never used + TORCH_INTERNAL_ASSERT( + false, "HIPMathCompat copysign should not run on the CPU"); +#endif +} +__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + return ::copysign(x, y); +#else + // see above + TORCH_INTERNAL_ASSERT( + false, "HIPMathCompat copysign should not run on the CPU"); +#endif +} + +__MATH_FUNCTIONS_DECL__ float floor(float x) { + return ::floorf(x); +} +__MATH_FUNCTIONS_DECL__ double floor(double x) { + return ::floor(x); +} + +__MATH_FUNCTIONS_DECL__ float log(float x) { + return ::logf(x); +} +__MATH_FUNCTIONS_DECL__ double log(double x) { + return ::log(x); +} + +__MATH_FUNCTIONS_DECL__ float log1p(float x) { + return ::log1pf(x); +} + +__MATH_FUNCTIONS_DECL__ double log1p(double x) { + return ::log1p(x); +} + +__MATH_FUNCTIONS_DECL__ float max(float x, float y) { + return ::fmaxf(x, y); +} +__MATH_FUNCTIONS_DECL__ double max(double x, double y) { + return ::fmax(x, y); +} + +__MATH_FUNCTIONS_DECL__ float min(float x, float y) { + return ::fminf(x, y); +} +__MATH_FUNCTIONS_DECL__ double min(double x, double y) { + return ::fmin(x, y); +} + +__MATH_FUNCTIONS_DECL__ float pow(float x, float y) { + return ::powf(x, y); +} +__MATH_FUNCTIONS_DECL__ double pow(double x, double y) { + return ::pow(x, y); +} + +__MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) { + return ::sincosf(x, sptr, cptr); +} +__MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) { + return ::sincos(x, sptr, cptr); +} + +__MATH_FUNCTIONS_DECL__ float sqrt(float x) { + return ::sqrtf(x); +} +__MATH_FUNCTIONS_DECL__ double sqrt(double x) { + return ::sqrt(x); +} + +__MATH_FUNCTIONS_DECL__ float rsqrt(float x) { + return ::rsqrtf(x); +} +__MATH_FUNCTIONS_DECL__ double rsqrt(double x) { + return ::rsqrt(x); +} + +__MATH_FUNCTIONS_DECL__ float tan(float x) { + return ::tanf(x); +} +__MATH_FUNCTIONS_DECL__ double tan(double x) { + return ::tan(x); +} + +__MATH_FUNCTIONS_DECL__ float tanh(float x) { + return ::tanhf(x); +} +__MATH_FUNCTIONS_DECL__ double tanh(double x) { + return ::tanh(x); +} + +__MATH_FUNCTIONS_DECL__ float normcdf(float x) { + return ::normcdff(x); +} +__MATH_FUNCTIONS_DECL__ double normcdf(double x) { + return ::normcdf(x); +} + +} // namespace c10::hip::compat + +#endif \ No newline at end of file diff --git a/c10/zoom/ZoomAllocatorConfig.cpp b/c10/zoom/ZoomAllocatorConfig.cpp new file mode 100644 index 00000000000000..7ff6e6955e98c3 --- /dev/null +++ b/c10/zoom/ZoomAllocatorConfig.cpp @@ -0,0 +1,350 @@ +#include +#include +#include + +namespace c10::zoom::ZoomCachingAllocator { + +constexpr size_t kRoundUpPowerOfTwoIntervals = 16; + +ZoomAllocatorConfig::ZoomAllocatorConfig() + : m_max_split_size(std::numeric_limits::max()), + m_garbage_collection_threshold(0), + m_pinned_num_register_threads(1), + m_expandable_segments(false), + m_release_lock_on_hipMalloc(false), + m_pinned_use_zoom_host_register(false), + m_last_allocator_settings("") { + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); +} + +size_t ZoomAllocatorConfig::roundup_power2_divisions(size_t size) { + size_t log_size = (63 - llvm::countLeadingZeros(size)); + + // Our intervals start at 1MB and end at 64GB + const size_t interval_start = + 63 - llvm::countLeadingZeros(static_cast(1048576)); + const size_t interval_end = + 63 - llvm::countLeadingZeros(static_cast(68719476736)); + TORCH_CHECK( + 
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals), + "kRoundUpPowerOfTwoIntervals mismatch"); + + int index = static_cast(log_size) - static_cast(interval_start); + + index = std::max(0, index); + index = std::min(index, static_cast(kRoundUpPowerOfTwoIntervals) - 1); + return instance().m_roundup_power2_divisions[index]; +} + +void ZoomAllocatorConfig::lexArgs( + const char* env, + std::vector& config) { + std::vector buf; + + size_t env_length = strlen(env); + for (size_t i = 0; i < env_length; i++) { + if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') { + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + buf.clear(); + } + config.emplace_back(1, env[i]); + } else if (env[i] != ' ') { + buf.emplace_back(static_cast(env[i])); + } + } + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + } +} + +void ZoomAllocatorConfig::consumeToken( + const std::vector& config, + size_t i, + const char c) { + TORCH_CHECK( + i < config.size() && config[i] == std::string(1, c), + "Error parsing CachingAllocator settings, expected ", + c, + ""); +} + +size_t ZoomAllocatorConfig::parseMaxSplitSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_split_size_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_split_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_split_size_mb value", ""); + } + return i; +} + +size_t ZoomAllocatorConfig::parseGarbageCollectionThreshold( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + double val1 = stod(config[i]); + TORCH_CHECK( + val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", ""); + TORCH_CHECK( + val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", ""); + m_garbage_collection_threshold = val1; + } else { + TORCH_CHECK( + false, "Error, expecting garbage_collection_threshold value", ""); + } + return i; +} + +size_t ZoomAllocatorConfig::parseRoundUpPower2Divisions( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + bool first_value = true; + + if (++i < config.size()) { + if (std::string_view(config[i]) == "[") { + size_t last_index = 0; + while (++i < config.size() && std::string_view(config[i]) != "]") { + const std::string& val1 = config[i]; + size_t val2 = 0; + + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + val2 = stoi(config[i]); + } else { + TORCH_CHECK( + false, "Error parsing roundup_power2_divisions value", ""); + } + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "For roundups, the divisons has to be power of 2 ", + ""); + + if (std::string_view(val1) == ">") { + std::fill( + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + last_index)), + m_roundup_power2_divisions.end(), + val2); + } else { + size_t val1_long = stoul(val1); + TORCH_CHECK( + llvm::isPowerOf2_64(val1_long), + "For roundups, the intervals have to be power of 2 ", + ""); + + size_t index = 63 - llvm::countLeadingZeros(val1_long); + index = std::max((size_t)0, index); + index = std::min(index, m_roundup_power2_divisions.size() - 1); + + if (first_value) { + std::fill( + m_roundup_power2_divisions.begin(), + std::next( + 
m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + index)), + val2); + first_value = false; + } + if (index < m_roundup_power2_divisions.size()) { + m_roundup_power2_divisions[index] = val2; + } + last_index = index; + } + + if (std::string_view(config[i + 1]) != "]") { + consumeToken(config, ++i, ','); + } + } + } else { // Keep this for backwards compatibility + size_t val1 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val1), + "For roundups, the divisons has to be power of 2 ", + ""); + std::fill( + m_roundup_power2_divisions.begin(), + m_roundup_power2_divisions.end(), + val1); + } + } else { + TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", ""); + } + return i; +} + +size_t ZoomAllocatorConfig::parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_zoomMallocAsync) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + ((config[i] == "native") || (config[i] == "zoomMallocAsync")), + "Unknown allocator backend, " + "options are native and zoomMallocAsync"); + + // HIP supports hipMallocAsync and does not need to check versions unlike CUDA + used_zoomMallocAsync = (config[i] == "zoomMallocAsync"); + + TORCH_INTERNAL_ASSERT( + config[i] == get()->name(), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time"); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); + } + return i; +} + +void ZoomAllocatorConfig::parseArgs(const char* env) { + // If empty, set the default values + m_max_split_size = std::numeric_limits::max(); + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); + m_garbage_collection_threshold = 0; + bool used_zoomMallocAsync = false; + bool used_native_specific_option = false; + + if (env == nullptr) { + return; + } + { + std::lock_guard lock(m_last_allocator_settings_mutex); + m_last_allocator_settings = env; + } + + std::vector config; + lexArgs(env, config); + + for (size_t i = 0; i < config.size(); i++) { + std::string_view config_item_view(config[i]); + if (config_item_view == "max_split_size_mb") { + i = parseMaxSplitSize(config, i); + used_native_specific_option = true; + } else if (config_item_view == "garbage_collection_threshold") { + i = parseGarbageCollectionThreshold(config, i); + used_native_specific_option = true; + } else if (config_item_view == "roundup_power2_divisions") { + i = parseRoundUpPower2Divisions(config, i); + used_native_specific_option = true; + } else if (config_item_view == "backend") { + i = parseAllocatorConfig(config, i, used_zoomMallocAsync); + } else if (config_item_view == "expandable_segments") { + used_native_specific_option = true; + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for expandable_segments"); + config_item_view = config[i]; + m_expandable_segments = (config_item_view == "True"); + } else if ( + // ROCm build's hipify step will change "cuda" to "hip", but for ease of + // use, accept both. We must break up the string to prevent hipify here. 
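+        // As an illustration (values arbitrary), both spellings select the
+        // same option handled just below:
+        //   PYTORCH_ZOOM_ALLOC_CONF="release_lock_on_hipMalloc:True"
+        //   PYTORCH_ZOOM_ALLOC_CONF="release_lock_on_cudamalloc:True"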
+ config_item_view == "release_lock_on_hipMalloc" || + config_item_view == + "release_lock_on_c" + "udamalloc") { + used_native_specific_option = true; + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for release_lock_on_hipMalloc"); + config_item_view = config[i]; + m_release_lock_on_hipMalloc = (config_item_view == "True"); + } else if ( + // ROCm build's hipify step will change "cuda" to "hip", but for ease of + // use, accept both. We must break up the string to prevent hipify here. + config_item_view == "pinned_use_hip_host_register" || + config_item_view == + "pinned_use_c" + "uda_host_register") { + i = parsePinnedUseZoomHostRegister(config, i); + used_native_specific_option = true; + } else if (config_item_view == "pinned_num_register_threads") { + i = parsePinnedNumRegisterThreads(config, i); + used_native_specific_option = true; + } else { + TORCH_CHECK( + false, "Unrecognized CachingAllocator option: ", config_item_view); + } + + if (i + 1 < config.size()) { + consumeToken(config, ++i, ','); + } + } + + if (used_zoomMallocAsync && used_native_specific_option) { + TORCH_WARN( + "backend:zoomMallocAsync ignores max_split_size_mb," + "roundup_power2_divisions, and garbage_collect_threshold."); + } +} + +size_t ZoomAllocatorConfig::parsePinnedUseZoomHostRegister( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_zoom_host_register"); + m_pinned_use_zoom_host_register = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_zoom_host_register value", ""); + } + return i; +} + +size_t ZoomAllocatorConfig::parsePinnedNumRegisterThreads( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + size_t val2 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "Number of register threads has to be power of 2 ", + ""); + auto maxThreads = ZoomAllocatorConfig::pinned_max_register_threads(); + TORCH_CHECK( + val2 <= maxThreads, + "Number of register threads should be less than or equal to " + + std::to_string(maxThreads), + ""); + m_pinned_num_register_threads = val2; + } else { + TORCH_CHECK( + false, "Error, expecting pinned_num_register_threads value", ""); + } + return i; +} + +// General caching allocator utilities +void setAllocatorSettings(const std::string& env) { + ZoomCachingAllocator::ZoomAllocatorConfig::instance().parseArgs(env.c_str()); +} + +} // namespace c10::zoom::ZoomCachingAllocator \ No newline at end of file diff --git a/c10/zoom/ZoomAllocatorConfig.h b/c10/zoom/ZoomAllocatorConfig.h new file mode 100644 index 00000000000000..86a2d5a6e10c4c --- /dev/null +++ b/c10/zoom/ZoomAllocatorConfig.h @@ -0,0 +1,128 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace c10::zoom::ZoomCachingAllocator { + +// Environment config parser +class ZoomAllocatorConfig { + public: + static size_t max_split_size() { + return instance().m_max_split_size; + } + static double garbage_collection_threshold() { + return instance().m_garbage_collection_threshold; + } + + static bool expandable_segments() { + // for now, we don't support expanable segments + if (instance().m_expandable_segments) { + 
TORCH_WARN_ONCE("expandable_segments not supported on this platform") + } + return false; +// #ifndef PYTORCH_C10_DRIVER_API_SUPPORTED +// if (instance().m_expandable_segments) { +// TORCH_WARN_ONCE("expandable_segments not supported on this platform") +// } +// return false; +// #else +// return instance().m_expandable_segments; +// #endif + } + + static bool release_lock_on_hipMalloc() { + return instance().m_release_lock_on_hipMalloc; + } + + /** Pinned memory allocator settings */ + static bool pinned_use_zoom_host_register() { + return instance().m_pinned_use_zoom_host_register; + } + + static size_t pinned_num_register_threads() { + return instance().m_pinned_num_register_threads; + } + + static size_t pinned_max_register_threads() { + // Based on the benchmark results, we see better allocation performance + // with 8 threads. However on future systems, we may need more threads + // and limiting this to 128 threads. + return 128; + } + + // This is used to round-up allocation size to nearest power of 2 divisions. + // More description below in function roundup_power2_next_division + // As ane example, if we want 4 divisions between 2's power, this can be done + // using env variable: PYTORCH_ZOOM_ALLOC_CONF=roundup_power2_divisions:4 + static size_t roundup_power2_divisions(size_t size); + + static std::vector roundup_power2_divisions() { + return instance().m_roundup_power2_divisions; + } + + static std::string last_allocator_settings() { + std::lock_guard lock( + instance().m_last_allocator_settings_mutex); + return instance().m_last_allocator_settings; + } + + static ZoomAllocatorConfig& instance() { + static ZoomAllocatorConfig* s_instance = ([]() { + auto inst = new ZoomAllocatorConfig(); + const char* env = getenv("PYTORCH_ZOOM_ALLOC_CONF"); + inst->parseArgs(env); + return inst; + })(); + return *s_instance; + } + + void parseArgs(const char* env); + + private: + ZoomAllocatorConfig(); + + static void lexArgs(const char* env, std::vector& config); + static void consumeToken( + const std::vector& config, + size_t i, + const char c); + size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseGarbageCollectionThreshold( + const std::vector& config, + size_t i); + size_t parseRoundUpPower2Divisions( + const std::vector& config, + size_t i); + size_t parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_zoomMallocAsync); + size_t parsePinnedUseZoomHostRegister( + const std::vector& config, + size_t i); + size_t parsePinnedNumRegisterThreads( + const std::vector& config, + size_t i); + + std::atomic m_max_split_size; + std::vector m_roundup_power2_divisions; + std::atomic m_garbage_collection_threshold; + std::atomic m_pinned_num_register_threads; + std::atomic m_expandable_segments; + std::atomic m_release_lock_on_hipMalloc; + std::atomic m_pinned_use_zoom_host_register; + std::string m_last_allocator_settings; + std::mutex m_last_allocator_settings_mutex; +}; + +// General caching allocator utilities +void setAllocatorSettings(const std::string& env); + +} // namespace c10::zoom::ZoomCachingAllocator \ No newline at end of file diff --git a/c10/zoom/ZoomCachingAllocator.cpp b/c10/zoom/ZoomCachingAllocator.cpp new file mode 100644 index 00000000000000..c28541f862c3f7 --- /dev/null +++ b/c10/zoom/ZoomCachingAllocator.cpp @@ -0,0 +1,3104 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// #if !defined(USE_ROCM) && 
defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +// #include +// #include +// #include +// #endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TORCH_SDT_DEFINE_SEMAPHORE(malloc) +TORCH_SDT_DEFINE_SEMAPHORE(free) + +namespace c10 { + +C10_DEFINE_REGISTRY(FreeZoomMemoryCallbacksRegistry, FreeMemoryCallback); + +namespace zoom::ZoomCachingAllocator { + +// Included here as this is externally used in ZoomAllocatorConfig +const size_t kLargeBuffer = + 20971520; // "large" allocations may be packed in 20 MiB blocks + +namespace Native { + +// +// Yet another caching allocator for HIP device allocations. +// +// - Allocations are associated with a stream. Once freed, blocks can be +// re-allocated on the same stream, but not on any other stream. +// - The allocator attempts to find the smallest cached block that will fit the +// requested size. If the block is larger than the requested size, it may be +// split. If no block is found, the allocator will delegate to hipMalloc. +// - If the hipMalloc fails, the allocator will attempt to free one cached +// block of sufficient size that is not split and retry the allocation. +// If this also fails, the allocator will attempt to free all cached blocks +// that are not split and retry the allocation. +// - Large (>1MB) and small allocations are stored in separate pools. +// Small requests are packed into 2MB buffers. Large requests will use the +// smallest available free block or allocate a new block using hipMalloc. +// - To reduce fragmentation, requests between 1MB and 10MB will allocate and +// split a 20MB block, if no free block of sufficient size is available. +// - To further reduce fragmentation, blocks >= max_split_size are not allowed +// to be split. These oversize cached blocks will still satisfy requests +// within 1MB of the oversize cached block size. +// +// With this allocator, allocations and frees should logically be considered +// "usages" of the memory segment associated with streams, just like kernel +// launches. The programmer must insert the proper synchronization if memory +// segments are used from multiple streams. +// +// The library provides a recordStream() function to help insert the correct +// synchronization when allocations are used on multiple streams. This will +// ensure that the block is not reused before each recorded stream completes +// work. +// + +/** + * Note [Interaction with HIP graph capture] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Graph capture performs a dry run of a region of execution, freezing all HIP + * work (and virtual addresses used during that work) into a "graph." The graph + * may be "replayed" like a single giant kernel, with greatly reduced CPU + * overhead as well as modestly improved GPU performance. + * + * Because capture bakes in memory addresses, the memory used during capture + * must be available for the graph to use during replay. DeviceCachingAllocator + * assigns and frees memory eagerly and dynamically, so if we're not careful + * about managing graphs' memory, at replay time those memory addresses could be + * used by other tensors. + * + * To guarantee a graph's baked in addresses are safe to reuse in replay, + * DeviceAllocator satisfies allocations from a graph-private memory pool during + * capture, and doesn't begin hipFreeing those addresses until the graph is + * destroyed. + * + * Within the private pool, allocations are freed and reassigned as usual during + * capture. 
Memory regions will be used in a consistent order during replay. So + * a private pool doesn't use memory more wastefully than the default pools + * during capture, but it does reserve its high-water mark of used memory away + * from the default pools as long as the capture(s) it served survive + * (regardless whether those captures are idle or replaying). + * + * CUDAGraph's requests for private pools are mediated by + * DeviceAllocator::notifyCaptureBegin, + * notifyCaptureAboutToEnd, + * notifyCaptureEnded, + * notifyCaptureDestroy. + */ + +constexpr size_t kMinBlockSize = + 512; // all sizes are rounded to at least 512 bytes +constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB +constexpr size_t kSmallBuffer = + 2097152; // "small" allocations are packed in 2 MiB blocks +constexpr size_t kMinLargeAlloc = + 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer +constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB + +namespace { + +using stream_set = ska::flat_hash_set; + +using StatTypes = std::array(StatType::NUM_TYPES)>; + +void increase_stat(Stat& stat, size_t amount) { + stat.current += static_cast(amount); + stat.peak = std::max(stat.current, stat.peak); + stat.allocated += static_cast(amount); +} + +void decrease_stat(Stat& stat, size_t amount) { + stat.current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stat.current >= 0, + "Negative tracked stat in HIP allocator (likely logic error)."); + stat.freed += static_cast(amount); +} + +void reset_accumulated_stat(Stat& stat) { + stat.allocated = 0; + stat.freed = 0; +} + +void reset_peak_stat(Stat& stat) { + stat.peak = stat.current; +} + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + +void decrease_stat_array( + StatArray& stat_array, + size_t amount, + const StatTypes& stat_types) { + for_each_selected_stat_type( + stat_types, [&stat_array, amount](size_t stat_type) { + decrease_stat(stat_array[stat_type], amount); + }); +} + +struct Block; +struct PrivatePool; +typedef bool (*Comparison)(const Block*, const Block*); +static bool BlockComparatorSize(const Block* a, const Block* b); +static bool BlockComparatorAddress(const Block* a, const Block* b); + +struct BlockPool { + BlockPool(bool small, PrivatePool* private_pool = nullptr) + : blocks(BlockComparatorSize), + unmapped(BlockComparatorAddress), + is_small(small), + owner_PrivatePool(private_pool) {} + + // Do not insert a Block to blocks directly; use insert_into_blocks(), + // instead. + std::set blocks; + std::set unmapped; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const bool is_small; + PrivatePool* owner_PrivatePool; + int64_t get_free_blocks_call_count{0}; + + // Add a Block into blocks set with updating gc counter. + std::pair::iterator, bool> insert_into_blocks( + Block* block); +}; + +struct ExpandableSegment; + +struct Block { + c10::DeviceIndex device; // gpu + hipStream_t stream; // allocation stream + stream_set stream_uses; // streams on which the block was used + size_t size; // block size in bytes + size_t requested_size; // memory originally requested + BlockPool* pool{nullptr}; // owning memory pool + void* ptr{nullptr}; // memory address + bool allocated{false}; // in-use flag + bool mapped{true}; // is the virtual address range this Block references + // backed by physical pages. 
Always true when + // expandable_segment_ is null. When false + // This Block will be aligned to the segment size + // of its expandable_segment_. + Block* prev{nullptr}; // prev block if split from a larger allocation + Block* next{nullptr}; // next block if split from a larger allocation + int event_count{0}; // number of outstanding HIP events + int64_t gc_count_base{0}; // get_free_blocks_call_count when Block is inserted + std::shared_ptr context_when_allocated; + // only set for the first block in the segment (when prev == null) + // this records the frame information when hipMalloc was called + // whereas context_when_allocated records the last time we handed this + // memory out from our cache. + std::shared_ptr context_when_segment_allocated; + + ExpandableSegment* expandable_segment_{nullptr}; + + Block( + c10::DeviceIndex device, + hipStream_t stream, + size_t size, + BlockPool* pool, + void* ptr) + : device(device), + stream(stream), + stream_uses(), + size(size), + requested_size(0), + pool(pool), + ptr(ptr) {} + + // constructor for search key + Block(c10::DeviceIndex device, hipStream_t stream, size_t size) + : device(device), + stream(stream), + stream_uses(), + size(size), + requested_size(0) {} + + size_t gc_count() { + TORCH_INTERNAL_ASSERT(pool); + return static_cast(pool->get_free_blocks_call_count - gc_count_base); + } + + bool is_split() const { + return (prev != nullptr) || (next != nullptr); + } + void splice(Block* before, Block* after) { + if (before) { + TORCH_INTERNAL_ASSERT(before->next == after); + before->next = this; + } + prev = before; + if (after) { + TORCH_INTERNAL_ASSERT(after->prev == before); + after->prev = this; + } + next = after; + } +}; + +std::pair::iterator, bool> BlockPool:: + insert_into_blocks(Block* block) { + block->gc_count_base = get_free_blocks_call_count; + return blocks.insert(block); +} + +struct SegmentRange { + char* ptr; + size_t size; + SegmentRange(void* p, size_t s) : ptr(static_cast(p)), size(s) {} +}; + +// For now we don't support expandable segments +struct ExpandableSegment { + ExpandableSegment( + c10::DeviceIndex device, + hipStream_t stream, + size_t size, + const std::vector& peers) { + TORCH_INTERNAL_ASSERT(false, "expandable segment not supported"); + } + SegmentRange map(SegmentRange range) { + return SegmentRange(nullptr, 0); + } + SegmentRange unmap(SegmentRange range) { + return SegmentRange(nullptr, 0); + } + char* ptr() const { + return nullptr; + } + size_t size() const { + return 0; + } + void addPeer(c10::DeviceIndex device) {} +}; + +// BlockState, BlockPoolState, and PrivatePoolState contain the information +// needed to reconstruct a private pool to a previous state. 
See note +// [Checkpointing PrivatePoolState] +struct BlockState { + c10::DeviceIndex device = 0; + hipStream_t stream = nullptr; + stream_set stream_uses = {}; + size_t size = 0; + void* ptr = nullptr; + bool allocated = false; + int64_t gc_count_base = 0; + // maintain invariant that event_count == 0 ; + // history will be left alone in checkpoint + + BlockState(Block* block); +}; + +struct SegmentState { + std::vector blocks; + bool is_small = false; + + SegmentState(Block* head); +}; + +struct PrivatePoolState : AllocatorState { + // omitting use_count, and hipMalloc_count as they remain the same + MempoolId_t owner_id = {0, 0}; + + std::vector segments; + + PrivatePoolState( + MempoolId_t pool_id, + const std::vector& private_pool_head_blocks); +}; + +struct RestoreResult { + std::vector allocations_freed; + std::vector allocations_created; +}; + +static bool BlockComparatorSize(const Block* a, const Block* b) { + if (a->stream != b->stream) { + return (uintptr_t)a->stream < (uintptr_t)b->stream; + } + if (a->size != b->size) { + return a->size < b->size; + } + return (uintptr_t)a->ptr < (uintptr_t)b->ptr; +} +static bool BlockComparatorAddress(const Block* a, const Block* b) { + if (a->stream != b->stream) { + return (uintptr_t)a->stream < (uintptr_t)b->stream; + } + return (uintptr_t)a->ptr < (uintptr_t)b->ptr; +} + +struct AllocParams { + AllocParams( + c10::DeviceIndex device, + size_t size, + hipStream_t stream, + BlockPool* pool, + size_t alloc_size, + DeviceStats& stats) + : search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {} + + c10::DeviceIndex device() const { + return search_key.device; + } + hipStream_t stream() const { + return search_key.stream; + } + size_t size() const { + return search_key.size; + } + + Block search_key; + BlockPool* pool; + size_t alloc_size; + Block* block{nullptr}; + StatTypes stat_types = {false}; + hipError_t err{hipSuccess}; +}; + +// Note: cudaEventCreate when concurrently invoked from multiple threads can be +// very expensive (at least on certain device/driver combinations). Thus, we a) +// serialize event creation at a per-device level, and b) pool the events to +// avoid constantly calling cudaEventCreate/cudaEventDestroy. This results in +// significant improvements in multithreaded workloads with high allocation +// rates. +class EventPool { + public: + using Event = std::unique_ptr>; + // TODO: Explicit device count + EventPool() : pools_(c10::zoom::device_count()) {} + + Event get(c10::DeviceIndex device) { + TORCH_INTERNAL_ASSERT(0 <= device); + TORCH_INTERNAL_ASSERT(device < static_cast(pools_.size())); + auto& pool = pools_[device]; + auto destructor = [&pool](hipEvent_t* event) { + std::lock_guard g(pool.mutex_); + pool.event_pool_.push_back(std::unique_ptr(event)); + }; + + // Try to acquire an event from the per-device pool. + { + std::lock_guard g(pool.mutex_); + if (!pool.event_pool_.empty()) { + auto* event = pool.event_pool_.back().release(); + pool.event_pool_.pop_back(); + return Event(event, destructor); + } + } + // otherwise, allocate a new event that will be returned to the pool on + // destruction. 
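+    // The recycling trick used here, sketched in plain C++ without HIP (the
+    // names and types below are illustrative only): a unique_ptr whose custom
+    // deleter hands the object back to the pool instead of destroying it.
+    //
+    //   std::vector<std::unique_ptr<int>> spare;              // stands in for event_pool_
+    //   auto recycle = [&spare](int* p) { spare.emplace_back(p); };
+    //   {
+    //     std::unique_ptr<int, decltype(recycle)> pooled(new int(0), recycle);
+    //   } // `pooled` leaves scope: the int returns to `spare`, it is not deleted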
+ auto new_ptr = std::make_unique(); + C10_ZOOM_CHECK( + hipEventCreateWithFlags(new_ptr.get(), hipEventDisableTiming)); + + return Event(new_ptr.release(), destructor); + } + + void empty_cache() { + for (auto& pool : pools_) { + std::lock_guard g(pool.mutex_); + pool.event_pool_.clear(); + } + } + + private: + struct PerDevicePool { + alignas(64) std::mutex mutex_; + std::vector> event_pool_; + }; + std::vector pools_; +}; + +// HIP graphs helper +struct PrivatePool { + PrivatePool() + : large_blocks(/*small=*/false, this), + small_blocks(/*small=*/true, this) {} + PrivatePool(const PrivatePool&) = delete; + PrivatePool(PrivatePool&&) = delete; + PrivatePool& operator=(const PrivatePool&) = delete; + // Number of live graphs using this pool + int use_count{1}; + // Number of unfreed hipMallocs made for this pool. When use_count and + // hipMalloc_count drop to zero, we can delete this PrivatePool from + // graph_pools. + int hipMalloc_count{0}; + // Instead of maintaining private BlockPools here, I could stuff all blocks + // (private or no) into the top-level large_blocks and small_blocks, and + // distinguish private blocks by adding a "pool id" check above the stream + // check in BlockComparator. BlockComparator is performance- critical though, + // I'd rather not add more logic to it. + BlockPool large_blocks; + BlockPool small_blocks; +}; + +BlockState::BlockState(Block* block) + : stream(block->stream), + stream_uses(block->stream_uses), + size(block->size), + ptr(block->ptr), + allocated(block->allocated), + gc_count_base(block->gc_count_base) { + TORCH_CHECK( + block->event_count == 0, + "Events should have synchronized when checkpointing block"); +}; + +SegmentState::SegmentState(Block* head) { + TORCH_INTERNAL_ASSERT(head->prev == nullptr && head->pool != nullptr); + is_small = head->pool->is_small; + + for (Block* curr = head; curr != nullptr; curr = curr->next) { + blocks.emplace_back(curr); + } +} + +PrivatePoolState::PrivatePoolState( + MempoolId_t pool_id, + const std::vector& private_pool_head_blocks) + : owner_id(std::move(pool_id)) { + for (Block* head : private_pool_head_blocks) { + segments.emplace_back(head); + } +} + +struct MempoolIdHash { + std::size_t operator()(const MempoolId_t& mempool_id) const noexcept { + return mempool_id.first != 0 ? mempool_id.first : mempool_id.second; + } +}; + +hipError_t hipMallocMaybeCapturing(void** p, size_t size) { + if (c10::zoom::currentStreamCaptureStatusMayInitCtx() == + c10::zoom::CaptureStatus::None) { + return C10_ZOOM_ERROR_HANDLED(hipMalloc(p, size)); + } else { + // It's ok to capture hipMallocs, as long as we never hipFree those + // addresses before replay. + // Capturing hipMalloc behaves nicely: it gives the graph new VA, + // but is ignored (won't leakily allocate new memory) in replays. + c10::zoom::ZoomStreamCaptureModeGuard g{hipStreamCaptureModeRelaxed}; + return C10_ZOOM_ERROR_HANDLED(hipMalloc(p, size)); + } +} + +} // anonymous namespace +} // namespace Native + +static std::string reportProcessMemoryInfo(c10::DeviceIndex device) { + return ""; +} + +namespace Native { + +class DeviceCachingAllocator { + private: + // lock around all operations + mutable std::recursive_mutex mutex; + + // device statistics + DeviceStats stats; + + // unallocated cached blocks larger than 1 MB + BlockPool large_blocks; + + // unallocated cached blocks 1 MB or smaller + BlockPool small_blocks; + + // allocated or in use by a stream. 
Holds all active allocations, + // whether they came from graph_pools or one of the BlockPools above. + ska::flat_hash_set active_blocks; + + // captures_underway tracks if we are diverting some + // allocations to a specific pool. + // Most of the time it's empty, in which case malloc can avoid calling + // hipStreamGetCaptureInfo in the hot path. + std::vector>> + captures_underway; + + // See free() for this thing's purpose + std::vector needs_events_deferred_until_no_capture; + // outstanding hip events + ska::flat_hash_map< + zoom::ZoomStream, + std::deque>> + hip_events; + + // record used memory. + size_t total_allocated_memory = 0; + + size_t allowed_memory_maximum = 0; + + // all live expandable segments + std::vector expandable_segments_; + std::vector devices_with_peer_access_; + + bool set_fraction = false; + + bool record_history = false; + + std::atomic context_recorder_; + size_t alloc_trace_next = 0; + RecordContext record_context_ = RecordContext::NEVER; + size_t alloc_trace_max_entries_ = 1; + std::vector* + alloc_trace; // pointer because we need to intentionally leak this on + // deallocation it can hold references to Python state which + // will already be destroyed when we are in exit handlers + + // Members specific to HIP graphs + + // Private pools for HIP graphs + ska::flat_hash_map, MempoolIdHash> + graph_pools; + // Pools no longer referenced by any graph. Their BlockPools are eligible for + // free_blocks. Can't be a vector or deque because we might erase entries in + // any order. Could be an std::list, but we don't care much, access and + // insert/erase are rare. + ska::flat_hash_map + graph_pools_freeable; + + // XXX - maybe we should generalize and have multiple events + std::vector oom_observers_; + + std::vector trace_trackers_; + + public: + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + DeviceCachingAllocator() + : large_blocks(/*small=*/false), + small_blocks(/*small=*/true), + alloc_trace(new std::vector()) { + stats.max_split_size = + static_cast(ZoomAllocatorConfig::max_split_size()); + context_recorder_.store(nullptr); + } + + void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when) { + std::unique_lock lock(mutex); + TORCH_CHECK(when == RecordContext::NEVER || context_recorder); + record_history = enabled; + context_recorder_.store(record_history ? context_recorder : nullptr); + alloc_trace_max_entries_ = std::max(size_t(1), alloc_trace_max_entries); + record_context_ = enabled ? 
when : RecordContext::NEVER; + if (!enabled) { + alloc_trace_next = 0; + alloc_trace->clear(); + } + } + + bool isHistoryEnabled() { + return record_history; + } + + bool checkPoolLiveAllocations( + MempoolId_t mempool_id, + const std::unordered_set& expected_live_allocations) { + std::unique_lock lock(mutex); + + PrivatePool* pool = nullptr; + auto pool_it = graph_pools.find(mempool_id); + TORCH_CHECK(pool_it != graph_pools.end(), "Could not find pool of id"); + pool = pool_it->second.get(); + + TORCH_INTERNAL_ASSERT(pool != nullptr); + + size_t allocated_pool_blocks = 0; + + for (Block* b : active_blocks) { + TORCH_INTERNAL_ASSERT(b != nullptr); + TORCH_INTERNAL_ASSERT(b->pool != nullptr); + if (b->allocated && b->pool->owner_PrivatePool == pool) { + if (!expected_live_allocations.count(b->ptr)) { + return false; + } + + allocated_pool_blocks += 1; + } + } + + return allocated_pool_blocks == expected_live_allocations.size(); + } + + void attachOutOfMemoryObserver(OutOfMemoryObserver observer) { + oom_observers_.emplace_back(std::move(observer)); + } + + void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { + std::unique_lock lock(mutex); + trace_trackers_.emplace_back(std::move(tracker)); + } + + // Must be called outside of `mutex` or deadlocks are possible with Python + std::shared_ptr maybeGatherContext(RecordContext level) { + if (record_context_ < level) { + return nullptr; + } + return context_recorder_.load()(); + } + + // All public methods (except the above) acquire the allocator mutex. + // Thus, do not call a public method from another public method. + + Block* malloc( + c10::DeviceIndex device, + size_t orig_size, + hipStream_t stream) { + // done outside the lock because we don't know what locks the recorder needs + // to have... + auto context = maybeGatherContext(RecordContext::STATE); + + std::unique_lock lock(mutex); + + if (C10_LIKELY(captures_underway.empty())) { + // Processes end-of-life events for outstanding allocations used on + // multiple streams (checks if their GPU-side uses are complete and + // recycles their memory if so) + // + // Q. Why skip process_events if a capture might be underway? + // A. process_events involves hipEventQueries, illegal during HIP graph + // capture. + // Dumb simple solution: defer reclaiming these allocations until after + // capture. Cross-stream memory use is uncommon, so the deferral's + // effect on memory use during capture should be small. + process_events(context); + } + size_t size = round_size(orig_size); + auto& pool = get_pool(size, stream); + const size_t alloc_size = get_allocation_size(size); + AllocParams params(device, size, stream, &pool, alloc_size, stats); + params.stat_types = get_stat_types_for_pool(pool); + + // First, try to get a block from the existing pool. + bool block_found = + // Search pool + get_free_block(params) + // Trigger callbacks and retry search + || (trigger_free_memory_callbacks(params) && get_free_block(params)); + + // Can't reuse an existing block; try to get a new one. + if (!block_found) { + // Do garbage collection if the flag is set. + if (C10_UNLIKELY( + set_fraction && + ZoomAllocatorConfig::garbage_collection_threshold() > 0.0)) { + garbage_collect_cached_blocks(context); + } + // Attempt allocate + // WARNING: alloc_block may release the allocator lock when calling + // hipMalloc. So far this function has not modified allocator state, but + // keep in mind that any observed allocator state may change across calls + // to alloc_block since it may release the lock. 
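+      // Rough shape of the fallback chain below (short-circuit ||, each step
+      // runs only if the previous one failed to produce a block):
+      //   1. alloc_block(params, false, ...)  - try a fresh hipMalloc
+      //   2. release_available_cached_blocks, then alloc_block(params, false, ...)
+      //   3. release_cached_blocks, then alloc_block(params, true, ...)
+      //      (step 3 is skipped while a graph capture is underway)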
+ block_found = alloc_block(params, false, context, lock) + // Free enough available cached blocks to satisfy alloc and retry + // alloc. + || (release_available_cached_blocks(params, context) && + alloc_block(params, false, context, lock)) + // Free all non-split cached blocks and retry alloc. + || (C10_LIKELY(captures_underway.empty()) && + release_cached_blocks(context) && + alloc_block(params, true, context, lock)); + } + + if (!block_found) { + // For any error code other than hipErrorMemoryAllocation, + // alloc_block should have thrown an exception already. + TORCH_INTERNAL_ASSERT(params.err == hipErrorMemoryAllocation); + + size_t device_free = 0; + size_t device_total = 0; + C10_ZOOM_CHECK(hipMemGetInfo(&device_free, &device_total)); + std::string allowed_info; + + if (set_fraction) { + allowed_info = format_size(allowed_memory_maximum) + " allowed; "; + } + + std::string proc_info = reportProcessMemoryInfo(device); + + record_trace( + TraceEntry::OOM, + device_free, + params.size(), + params.stream(), + params.device(), + std::move(context)); + stats.num_ooms += 1; + + c10::reportOutOfMemoryToProfiler( + static_cast(size), + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current, + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current, + c10::Device(c10::DeviceType::PrivateUse1, device)); + + auto allocated_bytes = + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current; + auto reserved_bytes = + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current; + auto observers_local = oom_observers_; + + size_t allocated_in_private_pools = 0; + auto get_size_block = [](const BlockPool& pool) { + size_t res = 0; + for (const auto& block : pool.blocks) { + res += block->size; + } + return res; + }; + for (const auto& p : graph_pools) { + allocated_in_private_pools += get_size_block(p.second->large_blocks); + allocated_in_private_pools += get_size_block(p.second->small_blocks); + } + + std::string private_pool_msg; + + if (allocated_in_private_pools > 0) { + private_pool_msg = "with " + format_size(allocated_in_private_pools) + + " allocated in private pools (e.g., HIP Graphs), "; + } + + // Make sure we do not have the device lock before calling our + // observers which might need hold the GIL + // It is safe to release at this point because will no longer + // be reading any allocator state. + + lock.unlock(); + + for (const auto& obs : observers_local) { + obs(device, + alloc_size, + set_fraction ? allowed_memory_maximum : device_total, + device_free); + } + + // "total capacity": total global memory on GPU + // "allowed": memory is allowed to use, which set by fraction. + // "already allocated": memory allocated by the program using the + // caching allocator + // "free": free memory as reported by the HIP API + // "cached": memory held by the allocator but not used by the program + // + // The "allocated" amount does not include memory allocated outside + // of the caching allocator, such as memory allocated by other programs + // or memory held by the driver. + // + // The sum of "allocated" + "free" + "cached" may be less than the + // total capacity due to memory held by the driver and usage by other + // programs. + // + // Note that at this point free_cached_blocks has already returned all + // possible "cached" memory to the driver. The only remaining "cached" + // memory is split from a larger block that is partially in-use. + TORCH_CHECK_WITH( + OutOfMemoryError, + false, + "HIP out of memory. 
Tried to allocate ", + format_size(alloc_size), + ". GPU ", + static_cast(device), + " has a total capacity of ", + format_size(device_total), + " of which ", + format_size(device_free), + " is free. ", + proc_info, + "Of the allocated memory ", + format_size(allocated_bytes + allocated_in_private_pools), + " is allocated by PyTorch, ", + private_pool_msg, + "and ", + format_size( + reserved_bytes - allocated_bytes - allocated_in_private_pools), + " is reserved by PyTorch but unallocated.", + " If reserved but unallocated memory is large try setting", + " PYTORCH_ZOOM_ALLOC_CONF=expandable_segments:True to avoid" + " fragmentation. See documentation for Memory Management " + " (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"); + } + + bool split_remainder = should_split(params.block, params.size()); + return alloc_found_block( + params, orig_size, std::move(context), split_remainder); + } + + Block* alloc_found_block( + const AllocParams& params, + size_t orig_size, + std::shared_ptr context, + bool split_remainder) { + auto size = params.size(); + auto device = params.device(); + auto pool = params.pool; + auto stream = params.stream(); + + TORCH_INTERNAL_ASSERT( + params.err == hipSuccess && params.block != nullptr && + params.block->ptr != nullptr); + Block* block = params.block; + Block* remaining = nullptr; + + const bool already_split = block->is_split(); + if (split_remainder) { + remaining = block; + + block = new Block(device, stream, size, pool, block->ptr); + block->expandable_segment_ = remaining->expandable_segment_; + block->prev = remaining->prev; + if (block->prev) { + block->prev->next = block; + } + block->next = remaining; + + remaining->prev = block; + remaining->ptr = static_cast(remaining->ptr) + size; + remaining->size -= size; + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + bool inserted = pool->insert_into_blocks(remaining).second; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); + + if (already_split && !block->expandable_segment_) { + // An already-split inactive block is being shrunk by size bytes. + decrease_stat_array( + stats.inactive_split_bytes, block->size, params.stat_types); + } else if (!block->expandable_segment_) { + // A new split inactive block is being created from a previously unsplit + // block, size remaining->size bytes. 
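+      // Worked example (sizes illustrative): a 5 MiB request served from a
+      // fresh 20 MiB kLargeBuffer segment ends up here with
+      //   block:     5 MiB, about to be handed out, and
+      //   remaining: 15 MiB, inactive, counted below as a new inactive split.
+      // A later request that fits in 15 MiB on the same stream can then be
+      // carved out of `remaining` instead of calling hipMalloc again.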
+ for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { + increase_stat(stats.inactive_split_bytes[stat_type], remaining->size); + increase_stat(stats.inactive_split[stat_type], 1); + }); + } + + } else if (already_split && !block->expandable_segment_) { + // An already-split block is becoming active + for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { + decrease_stat(stats.inactive_split_bytes[stat_type], block->size); + decrease_stat(stats.inactive_split[stat_type], 1); + }); + } + + block->allocated = true; + block->requested_size = orig_size; + + block->context_when_allocated = std::move(context); + record_trace( + TraceEntry::ALLOC, + int64_t(block->ptr), + orig_size, + block->stream, + block->device, + block->context_when_allocated); + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + bool inserted = active_blocks.insert(block).second; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); + + for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { + increase_stat(stats.allocation[stat_type], 1); + increase_stat(stats.allocated_bytes[stat_type], block->size); + increase_stat(stats.active[stat_type], 1); + increase_stat(stats.active_bytes[stat_type], block->size); + increase_stat(stats.requested_bytes[stat_type], block->requested_size); + }); + if (block->size >= ZoomAllocatorConfig::max_split_size()) + increase_stat(stats.oversize_allocations, 1); + + c10::reportMemoryUsageToProfiler( + block->ptr, + static_cast(block->size), + stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current, + stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current, + c10::Device(c10::DeviceType::PrivateUse1, device)); + + return block; + } + + void free(Block* block) { + std::shared_ptr context = + maybeGatherContext(RecordContext::ALL); + std::lock_guard lock(mutex); + + block->allocated = false; + + // following logic might modifying underlaying Block, causing the size + // changed. We store ahead for reporting + auto orig_block_ptr = block->ptr; + auto orig_block_size = block->size; + + StatTypes stat_types = get_stat_types_for_pool(*block->pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + decrease_stat(stats.allocation[stat_type], 1); + decrease_stat(stats.allocated_bytes[stat_type], block->size); + }); + + record_trace( + TraceEntry::FREE_REQUESTED, + int64_t(block->ptr), + block->requested_size, + block->stream, + block->device, + context ? context : block->context_when_allocated); + + if (block->size >= ZoomAllocatorConfig::max_split_size()) + decrease_stat(stats.oversize_allocations, 1); + + if (!block->stream_uses.empty()) { + if (C10_UNLIKELY(!captures_underway.empty())) { + // It's forbidden to hipEventQuery an event recorded during HIP graph + // capture. 
We conservatively defer recording end-of-life events until + // the next call to process_events() (which won't happen until no + // captures are underway) + needs_events_deferred_until_no_capture.push_back(block); + } else { + insert_events(block); + } + } else { + free_block(block, context); + } + + c10::reportMemoryUsageToProfiler( + orig_block_ptr, + -static_cast(orig_block_size), + stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current, + stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current, + c10::Device(c10::DeviceType::PrivateUse1, block->device)); + } + + void* getBaseAllocation(Block* block, size_t* outSize) { + std::lock_guard lock(mutex); + TORCH_CHECK( + !block->expandable_segment_, + "Tensors allocated with expandable_segments:True cannot be shared between processes. Consider using expandable_segments:False in data loading workers via torch.cuda.memory._set_allocator_settings('expandable_segments:False')"); + while (block->prev) { + block = block->prev; + } + void* basePtr = block->ptr; + if (outSize) { + size_t size = 0; + while (block) { + size += block->size; + block = block->next; + } + *outSize = size; + } + return basePtr; + } + + void recordStream(Block* block, zoom::ZoomStream stream) { + std::lock_guard lock(mutex); + if (stream.stream() == block->stream) { + // ignore uses on the allocation stream, since those don't require any + // special synchronization + return; + } + block->stream_uses.insert(stream); + } + + /** set memory fraction to limit maximum allocated memory **/ + void setMemoryFraction(double fraction) { + size_t device_free = 0; + size_t device_total = 0; + C10_ZOOM_CHECK(hipMemGetInfo(&device_free, &device_total)); + allowed_memory_maximum = + static_cast(fraction * static_cast(device_total)); + set_fraction = true; + } + + /** returns cached blocks to the system allocator **/ + void emptyCache() { + auto context = maybeGatherContext(RecordContext::ALL); + std::lock_guard lock(mutex); + release_cached_blocks(context); + } + + /** Retrieves size of largest unused block held by the memory cache **/ + void cacheInfo(size_t* largest) { + std::lock_guard lock(mutex); + if (*largest == + 0) { // make an initial guess if a zero *largest is passed in + size_t tmp_bytes = 0; + C10_ZOOM_CHECK(hipMemGetInfo( + largest, // Use free memory as an optimistic initial guess of *largest + &tmp_bytes)); + } + cache_info_aux(large_blocks, largest); + cache_info_aux(small_blocks, largest); + for (const auto& gp : graph_pools) { + cache_info_aux(gp.second->large_blocks, largest); + cache_info_aux(gp.second->small_blocks, largest); + } + } + + /** Returns a copy of the memory allocator stats **/ + DeviceStats getStats() { + std::lock_guard lock(mutex); + return stats; + } + + /** Resets the historical accumulation stats for the device **/ + void resetAccumulatedStats() { + std::lock_guard lock(mutex); + + for (const auto statType : + c10::irange(static_cast(StatType::NUM_TYPES))) { + reset_accumulated_stat(stats.allocation[statType]); + reset_accumulated_stat(stats.segment[statType]); + reset_accumulated_stat(stats.active[statType]); + reset_accumulated_stat(stats.inactive_split[statType]); + reset_accumulated_stat(stats.allocated_bytes[statType]); + reset_accumulated_stat(stats.reserved_bytes[statType]); + reset_accumulated_stat(stats.active_bytes[statType]); + reset_accumulated_stat(stats.inactive_split_bytes[statType]); + reset_accumulated_stat(stats.requested_bytes[statType]); + } + + stats.num_alloc_retries = 0; + stats.num_ooms = 0; + 
stats.num_sync_all_streams = 0; + stats.num_device_alloc = 0; + stats.num_device_free = 0; + reset_accumulated_stat(stats.oversize_allocations); + reset_accumulated_stat(stats.oversize_segments); + } + + /** Resets the historical peak stats for the device **/ + void resetPeakStats() { + std::lock_guard lock(mutex); + + for (const auto statType : + c10::irange(static_cast(StatType::NUM_TYPES))) { + reset_peak_stat(stats.allocation[statType]); + reset_peak_stat(stats.segment[statType]); + reset_peak_stat(stats.active[statType]); + reset_peak_stat(stats.inactive_split[statType]); + reset_peak_stat(stats.allocated_bytes[statType]); + reset_peak_stat(stats.reserved_bytes[statType]); + reset_peak_stat(stats.active_bytes[statType]); + reset_peak_stat(stats.inactive_split_bytes[statType]); + reset_peak_stat(stats.requested_bytes[statType]); + } + reset_peak_stat(stats.oversize_allocations); + reset_peak_stat(stats.oversize_segments); + } + + /* Checkpoint the state of a private pool necessary to return it to its + * current state */ + std::unique_ptr getCheckpointState(MempoolId_t id) { + std::lock_guard lock(mutex); + + auto pool = graph_pools.find(id); + if (pool != graph_pools.end()) { + auto private_pool_head_blocks = + get_private_pool_head_blocks(pool->second.get()); + return std::make_unique(id, private_pool_head_blocks); + } else if (graph_pools_freeable.count(id)) { + TORCH_CHECK(false, "Not expected to checkpoint freeable graph"); + } else { + TORCH_CHECK(false, "Could not find pool of id"); + } + } + + void freeBlocksAllocatedToPool(PrivatePool* private_pool, RestoreResult& rr) { + auto pool_blocks = get_private_pool_head_blocks(private_pool); + + std::vector head_blocks; + for (Block* block : pool_blocks) { + if (block->prev == nullptr) { + head_blocks.push_back(block); + } + } + + for (Block* block : head_blocks) { + Block* curr = block; + + while (curr) { + // When we free a block, its pointer should never change + // only its adjacent blocks, so free, then look at pointer + if (curr->allocated) { + TORCH_CHECK( + curr->event_count == 0, + "Events should have synchronized when setting checkpointed block"); + rr.allocations_freed.push_back(curr->ptr); + free(curr); + TORCH_CHECK(!curr->allocated) + } + curr = curr->next; + } + } + + for (Block* b : get_private_pool_head_blocks(private_pool)) { + Block* curr = b; + while (curr) { + TORCH_CHECK(!curr->allocated); + curr = curr->next; + } + } + } + + // checkpoint the state of an allocation that may have been + // split into multiple blocks + void setSegmentStateToCheckpoint( + Block* block, + SegmentState& segment, + const std::shared_ptr& context, + RestoreResult& rr) { + Block* curr_block = block; + Block* last_block = block; + + TORCH_INTERNAL_ASSERT(block->pool); + BlockPool& pool = *block->pool; + const auto segment_len = segment.blocks.size(); + + // allocate all blocks in the segment + for (size_t i = 0; i < segment_len; ++i) { + auto& block_state = segment.blocks.at(i); + AllocParams params( + block_state.device, + block_state.size, + block_state.stream, + &pool, + block_state.size, + stats); + pool.blocks.erase(curr_block); + params.block = curr_block; + params.stat_types = get_stat_types_for_pool(pool); + + // splitting a block depends on `max_split_size`, which may have changed + // between whe checkpoint was taken and now, so we make sure to recreate + // the behavior from the checkpoint. 
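+      // Concretely: a segment checkpointed as three blocks [A][B][C] is
+      // re-created by allocating A and B with split=true and C (the last
+      // block) with split=false, reproducing the original layout even if
+      // the current max_split_size setting would forbid splitting today.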
+ bool split = (i + 1) < segment.blocks.size(); + + // curr_block will become next pointer if it is split, so reassign with + // the returned value + curr_block = alloc_found_block(params, block_state.size, context, split); + + TORCH_CHECK(curr_block->ptr == block_state.ptr); + TORCH_CHECK(curr_block->size == block_state.size); + + last_block = curr_block; + curr_block = curr_block->next; + + TORCH_CHECK((curr_block != nullptr) == ((i + 1) < (segment_len))); + } + + while (last_block->prev) { + last_block = last_block->prev; + } + + // free blocks that are not allocated in the checkpoint + curr_block = last_block; + + for (size_t i = 0; i < segment_len; ++i, curr_block = curr_block->next) { + auto& block_state = segment.blocks.at(i); + TORCH_INTERNAL_ASSERT(curr_block != nullptr); + + if (block_state.allocated) { + rr.allocations_created.push_back(curr_block); + continue; + } + + free(curr_block); + + TORCH_CHECK(curr_block->ptr == block_state.ptr); + TORCH_CHECK(curr_block->allocated == block_state.allocated); + TORCH_CHECK(curr_block->size == block_state.size); + } + } + + /** + * Note [Checkpointing PrivatePoolState] + * + * Refer above to Note [Interaction with HIP graph capture]. Allocations made + * during graph capture are made from a separate private pool. During graph + * capture allocations behave as usual. During graph replay the allocator + * state does not change even as new tensors are created. The private pool + * will not free its blocks to the main caching allocator until cuda graph use + * is finished to prevent an allocation from eager clobbering the memory from + * a live but unaccounted for tensor that was created during replay. + * + * `make_graphed_callables`, a series of separate callables chained in + * successive cuda graphs, can share a memory pool because after a cuda graph + * recording the allocations in the shared private pool exactly reflect the + * tensors that are allocated. + * + * We would like to extend callable chaining to support a graphed callable + * tree. In this scenario, we have a tree of callable chains which will be + * captured with cuda graphs. In the diagram below, we have a tree with four + * callables, A, B, C, and D. Suppose we have captured, and subsequently + * replayed, A, B, and C. Then on a new invocation, we replay A and B, but + * would now like to record D. At this point the private pool will not reflect + * any of the live tensors created during graph replay. Allocations made + * during a new recording with the pool could overwrite those live tensors. + * + * In order to record a new graph capture after replaying prior callables in + * the tree, we need the allocator to reflect the state of the live tensors. + * We checkpoint the state of the private pool after each recording, and then + * reapply it when we are starting a new recording chain. Additionally, we + * must free the allocations for any tensors that died between the end of our + * previous graph replaying and our new recording. All of the allocated + * segments that existed in the checkpointed state must still exist in the + * pool. There may also exist new allocated blocks. + * (TODO : link note [live tensors between iterations] when it exists). For + * every block that is currently allocated but no allocated in the snapshot, + * we will return a pointer to their block. + *. 
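+ * Roughly, in terms of the entry points defined further down in this file:
+ *   - after recording B, save    chk = getCheckpointState(pool_id)
+ *   - replay A and B as usual
+ *   - before recording D, call   setCheckpointPoolState(*chk), which frees
+ *     whatever the pool currently holds and re-creates the allocations that
+ *     were live at the checkpoint (the returned RestoreResult reports what
+ *     was freed and what was re-created so the caller can fix up its tensors).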
+ * + * + * ---------------> A ---------------> B ---------------> C + * | + * | + * | + * | + * ╰ ---------------> D + */ + RestoreResult setCheckpointPoolState(PrivatePoolState& pps) { + // To reset the caching allocator state we will + // - Free all the blocks currently allocated to the pool (see [live tensors + // between iterations]) + // - Allocate all the blocks in a checkpointed segment, whether they are + // live or not + // - Free the blocks in a checkpointed segment which are not live + // This could be optimized, but it nicely reuses exiting apis, and this + // is not on the hot path. + + // following `done outside the lock because we don't know what locks the + // recorder needs to have...` + + std::shared_ptr context = + maybeGatherContext(RecordContext::STATE); + + std::lock_guard lock(mutex); + + RestoreResult rr; + + TORCH_CHECK( + !graph_pools_freeable.count(pps.owner_id), + "Not expected to checkpoint freeable graph"); + + auto pool = graph_pools.find(pps.owner_id); + TORCH_CHECK(pool != graph_pools.end(), "Could not find private pool id"); + + PrivatePool* private_pool = pool->second.get(); + + freeBlocksAllocatedToPool(private_pool, rr); + + std::unordered_map ptrs_to_blocks; + // at this point, all of the blocks should be free, so they will all be in + // the block set + for (Block* block : private_pool->small_blocks.blocks) { + ptrs_to_blocks[block->ptr] = block; + } + for (Block* block : private_pool->large_blocks.blocks) { + ptrs_to_blocks[block->ptr] = block; + } + + for (auto& segment : pps.segments) { + auto ptr = segment.blocks.at(0).ptr; + TORCH_CHECK(ptrs_to_blocks.count(ptr), " could not find ", ptr) + auto block = ptrs_to_blocks[ptr]; + + setSegmentStateToCheckpoint(block, segment, context, rr); + } + return rr; + } + + /** Dump a complete snapshot of the memory held by the allocator. Potentially + * VERY expensive. 
**/ + std::vector snapshot() { + std::lock_guard lock(mutex); + + std::unordered_map pool_to_id; + pool_to_id.reserve(graph_pools.size() + graph_pools_freeable.size()); + for (const auto& pair : graph_pools) { + pool_to_id[pair.second.get()] = pair.first; + } + for (const auto& pair : graph_pools_freeable) { + pool_to_id[pair.second] = pair.first; + } + + size_t total_active = 0; + std::vector result; + const auto all_blocks = get_all_blocks(); + + for (const Block* const head_block : all_blocks) { + // For expandable segments, we report one segment for each contiguous + // mapped range of memory + if (head_block->prev && head_block->prev->mapped) { + continue; + } + result.emplace_back(); + SegmentInfo& segment_info = result.back(); + segment_info.device = head_block->device; + segment_info.address = reinterpret_cast(head_block->ptr); + segment_info.stream = head_block->stream; + segment_info.is_large = (!head_block->pool->is_small); + segment_info.is_expandable = head_block->expandable_segment_; + segment_info.context_when_allocated = + head_block->context_when_segment_allocated; + auto mempool_id = pool_to_id.find(head_block->pool->owner_PrivatePool); + if (mempool_id != pool_to_id.end()) { + segment_info.owner_private_pool_id = mempool_id->second; + } + + const Block* block = head_block; + while (block != nullptr && block->mapped) { + segment_info.blocks.emplace_back(); + BlockInfo& block_info = segment_info.blocks.back(); + + block_info.size = block->size; + block_info.requested_size = block->requested_size; + block_info.allocated = block->allocated; + block_info.active = block->allocated || (block->event_count > 0) || + !block->stream_uses.empty(); + + segment_info.total_size += block_info.size; + if (block_info.allocated) { + segment_info.allocated_size += block_info.size; + } + if (block_info.active) { + segment_info.active_size += block_info.size; + segment_info.requested_size += block_info.requested_size; + } + block_info.context_when_allocated = block->context_when_allocated; + block = block->next; + } + total_active += segment_info.active_size; + } + + std::sort( + result.begin(), + result.end(), + [](const SegmentInfo& a, const SegmentInfo& b) { + return a.address < b.address; + }); + + record_trace(TraceEntry::SNAPSHOT, 0, total_active, nullptr, 0, nullptr); + return result; + } + + std::vector trace( + const std::function& tsc_to_us) { + std::lock_guard lock(mutex); + std::vector result; + result.reserve(alloc_trace->size()); + result.insert( + result.end(), + alloc_trace->begin() + + static_cast::difference_type>( + alloc_trace_next), + alloc_trace->end()); + result.insert( + result.end(), + alloc_trace->begin(), + alloc_trace->begin() + + static_cast::difference_type>( + alloc_trace_next)); + + // Convert all the timestamps from tsc to epoch time in microseconds. + for (auto& te : result) { + te.time_.t_ = tsc_to_us(te.time_.approx_t_); + } + return result; + } + + // This function takes the size and number of divisions argument and rounds + // up the size argument for the nearest power-of-2 division. + // For example, if we need to round-up 1200 and number of divisions is 4, + // the size 1200 lies between 1024 and 2048 and if we do 4 divisions between + // them, the values are 1024, 1280, 1536, and 1792. So the function will + // return 1280 as the nearest ceiling of power-2 divison. 
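+  //
+  // For concreteness, a sketch of the intermediate values for the example
+  // above (assuming llvm::PowerOf2Floor and llvm::countLeadingZeros behave
+  // as documented):
+  //   roundup_power2_next_division(1200, 4)
+  //     power2_floor     = 1024                         // largest power of 2 <= 1200
+  //     power2_divison   = 1024 >> (63 - 61) = 256      // one of 4 equal steps toward 2048
+  //     round_size_floor = 1200 & ~(256 - 1) = 1024
+  //     result           = 1024 + 256 = 1280            // smallest division boundary >= 1200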
+ static size_t roundup_power2_next_division(size_t size, size_t divisions) { + if (C10_UNLIKELY(size <= 4 || divisions <= 1)) { + return size; + } + if (llvm::isPowerOf2_64(size)) { + return size; + } + + // divide the space between these 2's power into equal divisions + // If division is zero, return the power-of-2 ceiling. + size_t power2_floor = llvm::PowerOf2Floor(size); + size_t power2_divison = + power2_floor >> (63 - llvm::countLeadingZeros(divisions)); + if (C10_UNLIKELY(power2_divison == 0)) { + return (power2_floor << 1); + } + size_t round_size_floor = size & (~(power2_divison - 1)); + return (round_size_floor == size) ? size + : round_size_floor + power2_divison; + } + + static size_t round_size(size_t size) { + if (size < kMinBlockSize) { + return kMinBlockSize; + } else { + auto divisions = ZoomAllocatorConfig::roundup_power2_divisions(size); + if (divisions > 0 && size > (kMinBlockSize * divisions)) { + return roundup_power2_next_division(size, divisions); + } else { + return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); + } + } + } + + // See Note [Interaction with HIP graph capture] + + // Called by CUDAGraph::capture_begin + void beginAllocateToPool( + MempoolId_t mempool_id, + std::function filter) { + std::lock_guard lock(mutex); + auto it = graph_pools.find(mempool_id); + if (it == graph_pools.end()) { + // mempool_id does not reference an existing pool. Make a new pool for + // this capture. + graph_pools.emplace(mempool_id, std::make_unique()); + } else { + // mempool_id references an existing pool, which the current capture will + // share. Check this pool is live (at least one other capture already + // references it). + TORCH_INTERNAL_ASSERT(it->second->use_count > 0); + it->second->use_count++; + } + for (auto it2 = captures_underway.begin(); it2 != captures_underway.end(); + ++it2) { + TORCH_CHECK( + it2->first != mempool_id, + "beginAllocateToPool: already recording to mempool_id"); + } + captures_underway.emplace_back(mempool_id, std::move(filter)); + } + + // Called by CUDAGraph::capture_end + void endAllocateToPool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + for (auto it = captures_underway.begin(); it != captures_underway.end(); + ++it) { + if (it->first == mempool_id) { + captures_underway.erase(it); + return; + } + } + TORCH_CHECK( + false, "endAllocatePool: not currently recording to mempool_id"); + } + + // Called by CUDAGraph::reset + void releasePool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + // The instantiated cudaGraphExec_t has been destroyed. We can't blindly + // delete and hipFree the mempool its capture used, because + // 1. other graph(s) might share the same pool + // 2. the user might still hold references to output tensors allocated + // during capture. + // To handle 1 and 2, we track the number of graphs using this particular + // mempool. When the count reaches 0, we tell free_cached_blocks it may now + // hipFree blocks from this graph's pool when it discovers they're unused + // (unsplit). + auto it = graph_pools.find(mempool_id); + TORCH_INTERNAL_ASSERT(it != graph_pools.end()); + auto uc = --(it->second->use_count); + TORCH_INTERNAL_ASSERT(uc >= 0); + if (uc == 0) { + // Allows free_cached_blocks to begin hipFreeing this pool's memory, + // and makes sure this pool wasn't somehow made freeable already. 
+ // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + bool inserted = + graph_pools_freeable.insert({mempool_id, it->second.get()}).second; + TORCH_INTERNAL_ASSERT(inserted); + } + } + + void addPeerAccess(c10::DeviceIndex dev_to_access) { + if (std::find( + devices_with_peer_access_.begin(), + devices_with_peer_access_.end(), + dev_to_access) != devices_with_peer_access_.end()) { + return; + } + devices_with_peer_access_.push_back(dev_to_access); + for (auto& es : expandable_segments_) { + es->addPeer(dev_to_access); + } + } + + bool hasAllocatedExpandableSegments() const { + return !expandable_segments_.empty(); + } + + private: + // All private methods do not acquire the allocator mutex. + + std::vector get_all_blocks() const { + std::vector blocks; + blocks.insert( + blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end()); + blocks.insert( + blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end()); + for (const auto& gp : graph_pools) { + blocks.insert( + blocks.end(), + gp.second->small_blocks.blocks.begin(), + gp.second->small_blocks.blocks.end()); + blocks.insert( + blocks.end(), + gp.second->large_blocks.blocks.begin(), + gp.second->large_blocks.blocks.end()); + } + blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end()); + return blocks; + } + + std::vector get_private_pool_head_blocks(PrivatePool* pool) const { + std::vector blocks; + for (Block* b : active_blocks) { + if ((b->pool == &pool->small_blocks || b->pool == &pool->large_blocks) && + b->prev == nullptr) { + blocks.push_back(b); + } + } + + for (Block* b : pool->small_blocks.blocks) { + if (b->prev == nullptr) { + blocks.push_back(b); + } + } + for (Block* b : pool->large_blocks.blocks) { + if (b->prev == nullptr) { + blocks.push_back(b); + } + } + + return blocks; + } + + // returns the smallest possible address in any segment + // where there is enough free address space to fit size + // may be composed of free and unmapped segments + Block* find_expandable_block( + c10::DeviceIndex device, + hipStream_t stream, + BlockPool* pool, + size_t size) { + Block key(device, stream, 0); + + auto allocatable = [](Block* b) { + return b && !b->allocated && b->event_count == 0 && + b->stream_uses.empty(); + }; + auto has_available_address_space = [&](Block* b) { + size_t bytes = 0; + while (bytes < size && allocatable(b)) { + bytes += b->size; + b = b->next; + } + return bytes >= size; + }; + for (auto it = pool->unmapped.lower_bound(&key); + it != pool->unmapped.end() && (*it)->stream == stream; + ++it) { + Block* c = *it; + // we found the lowest address of an unmapped segment + // but there might be a free segment we can also use + // right before it + if (allocatable(c->prev)) { + c = c->prev; + } + if (has_available_address_space(c)) { + return c; + } + } + auto segment_size = pool->is_small ? 
kSmallBuffer : kLargeBuffer; + expandable_segments_.emplace_back(new ExpandableSegment( + device, stream, segment_size, devices_with_peer_access_)); + + ExpandableSegment* es = expandable_segments_.back(); + Block* candidate = new Block(device, stream, es->size(), pool, es->ptr()); + candidate->mapped = false; + candidate->expandable_segment_ = es; + pool->unmapped.insert(candidate); + return candidate; + } + + bool map_block( + Block* to_map, + size_t size, + const std::shared_ptr& ctx) { + TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size); + TORCH_INTERNAL_ASSERT( + !to_map->context_when_allocated); // unmapped blocks should not keep + // history + auto mapped_range = + to_map->expandable_segment_->map(SegmentRange{to_map->ptr, size}); + // failed to map the memory + if (mapped_range.size == 0) { + return false; + } + TORCH_INTERNAL_ASSERT( + mapped_range.ptr == to_map->ptr && mapped_range.size >= size); + + BlockPool& pool = *to_map->pool; + pool.unmapped.erase(to_map); + to_map->mapped = true; + + if (mapped_range.size < to_map->size) { + // to_map -> remaining -> to_map->next(?) + Block* remaining = new Block( + to_map->device, + to_map->stream, + to_map->size - mapped_range.size, + &pool, + static_cast(to_map->ptr) + mapped_range.size); + remaining->mapped = false; + remaining->expandable_segment_ = to_map->expandable_segment_; + remaining->splice(to_map, to_map->next); + pool.unmapped.insert(remaining); + to_map->size = mapped_range.size; + } + + try_merge_blocks(to_map, to_map->prev, pool); + try_merge_blocks(to_map, to_map->next, pool); + + pool.insert_into_blocks(to_map); + + // update statistics + total_allocated_memory += mapped_range.size; + StatTypes stat_types = get_stat_types_for_pool(*to_map->pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); + }); + + stats.num_device_alloc++; + record_trace( + TraceEntry::SEGMENT_MAP, + int64_t(mapped_range.ptr), + mapped_range.size, + to_map->stream, + to_map->device, + ctx); + if (!to_map->prev && !to_map->context_when_segment_allocated) { + to_map->context_when_segment_allocated = ctx; + } + + return true; + } + + Block* try_allocate_expandable_block( + c10::DeviceIndex device, + hipStream_t stream, + BlockPool* pool, + size_t size, + const std::shared_ptr& ctx) { + Block* candidate = find_expandable_block(device, stream, pool, size); + // Candidate is now a list free/unmapped blocks with at least size room: + // unmapped -> null + // unmapped -> free -> * + // free -> unmapped -> * + + if (!candidate->mapped && + !map_block(candidate, std::min(candidate->size, size), ctx)) { + return nullptr; + } + TORCH_INTERNAL_ASSERT(candidate->mapped); + + while (candidate->size < size) { + // invariant: free -> unmapped -> * + // map_block will map some of unmapped and merge with free + auto remaining = size - candidate->size; + auto new_candidate = candidate->next; + if (!map_block( + new_candidate, std::min(remaining, candidate->next->size), ctx)) { + return nullptr; + } + candidate = new_candidate; + } + pool->blocks.erase(candidate); + return candidate; + } + + /** moves a block into a pool of cached free blocks */ + void free_block( + Block* block, + const std::shared_ptr& context) { + TORCH_INTERNAL_ASSERT( + !block->allocated && block->event_count == 0 && + block->stream_uses.empty()); + + record_trace( + TraceEntry::FREE_COMPLETED, + int64_t(block->ptr), + block->requested_size, + block->stream, + block->device, + context ? 
context : block->context_when_allocated); + + block->context_when_allocated = nullptr; + size_t original_block_size = block->size; + size_t requested_size = block->requested_size; + + auto& pool = *block->pool; + int64_t net_change_inactive_split_blocks = 0; + int64_t net_change_inactive_split_size = 0; + + const std::array merge_candidates = {block->prev, block->next}; + for (Block* merge_candidate : merge_candidates) { + const auto subsumed_size = try_merge_blocks(block, merge_candidate, pool); + if (subsumed_size > 0) { + net_change_inactive_split_blocks -= 1; + net_change_inactive_split_size -= static_cast(subsumed_size); + } + } + + active_blocks.erase(block); + // Makes sure the Block* isn't already present in the pool we're freeing it + // back into. + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + bool inserted = pool.insert_into_blocks(block).second; + TORCH_INTERNAL_ASSERT(inserted); + + if (block->is_split()) { + net_change_inactive_split_blocks += 1; + net_change_inactive_split_size += static_cast(block->size); + } + + StatTypes stat_types = get_stat_types_for_pool(pool); + + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + // inactive_split tries to capture the idea that blocks + // cannot be freed when requested, but fully free pages + // of expandable blocks can always be freed. + // The logic to track this as statistic is pretty involved, + // so we simply just exclude expandable segments from + // inactive_split + if (!block->expandable_segment_) { + if (net_change_inactive_split_blocks > 0) { + increase_stat( + stats.inactive_split[stat_type], + static_cast(net_change_inactive_split_blocks)); + } else if (net_change_inactive_split_blocks < 0) { + decrease_stat( + stats.inactive_split[stat_type], + static_cast(-net_change_inactive_split_blocks)); + } + if (net_change_inactive_split_size > 0) { + increase_stat( + stats.inactive_split_bytes[stat_type], + static_cast(net_change_inactive_split_size)); + } else if (net_change_inactive_split_size < 0) { + decrease_stat( + stats.inactive_split_bytes[stat_type], + static_cast(-net_change_inactive_split_size)); + } + } + decrease_stat(stats.active[stat_type], 1); + decrease_stat(stats.active_bytes[stat_type], original_block_size); + decrease_stat(stats.requested_bytes[stat_type], requested_size); + }); + } + + /** combine previously split blocks. returns the size of the subsumed block, + * or 0 on failure. */ + size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { + if (!src || src->allocated || src->event_count > 0 || + !src->stream_uses.empty() || dst->mapped != src->mapped) { + return 0; + } + + AT_ASSERT(dst->is_split() && src->is_split()); + + if (dst->prev == src) { // [src dst] + dst->ptr = src->ptr; + dst->prev = src->prev; + if (dst->prev) { + dst->prev->next = dst; + } + dst->context_when_segment_allocated = + std::move(src->context_when_segment_allocated); + } else { // [dest src] + dst->next = src->next; + if (dst->next) { + dst->next->prev = dst; + } + } + const size_t subsumed_size = src->size; + dst->size += subsumed_size; + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto erased = + src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1); + delete src; + + return subsumed_size; + } + + BlockPool& get_pool(size_t size, hipStream_t stream) { + // captures_underway is a conservative guess that the current stream may be + // capturing. 
It's only non-empty if some thread has begun and not yet ended + // a capture, so it's usually 0, and we can short-circuit + // hipStreamCaptureStatus (which does a TLS lookup). + if (C10_UNLIKELY(!captures_underway.empty())) { + for (auto& entry : captures_underway) { + if (entry.second(stream)) { + auto it1 = graph_pools.find(entry.first); + TORCH_INTERNAL_ASSERT(it1 != graph_pools.end()); + if (size <= kSmallSize) { + return it1->second->small_blocks; + } else { + return it1->second->large_blocks; + } + } + } + } + if (size <= kSmallSize) { + return small_blocks; + } else { + return large_blocks; + } + } + + StatTypes get_stat_types_for_pool(const BlockPool& pool) { + StatTypes stat_types = {false}; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast( + pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true; + return stat_types; + } + + bool should_split(const Block* block, size_t size) { + size_t remaining = block->size - size; + if (block->pool->is_small || ZoomAllocatorConfig::expandable_segments()) { + return remaining >= kMinBlockSize; + } else { + return (size < ZoomAllocatorConfig::max_split_size()) && + (remaining > kSmallSize); + } + } + + static size_t get_allocation_size(size_t size) { + if (size <= kSmallSize) { + return kSmallBuffer; + } else if (size < kMinLargeAlloc) { + return kLargeBuffer; + } else { + return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); + } + } + + bool get_free_block(AllocParams& p) { + BlockPool& pool = *p.pool; + + if (C10_UNLIKELY( + set_fraction && + ZoomAllocatorConfig::garbage_collection_threshold() > 0.0)) { + // Track block reuse interval only when garbage collection is enabled. + ++pool.get_free_blocks_call_count; + } + auto it = pool.blocks.lower_bound(&p.search_key); + if (it == pool.blocks.end() || (*it)->stream != p.stream()) + return false; + + if ((*it)->expandable_segment_) { + if (ZoomAllocatorConfig::expandable_segments()) { + // if we are allocated to the part of the block that is expandable + // for the purposes of "best fit" we consider its size to be the size it + // can expand to, not the size it currently is. This means that we + // sometimes have to search for blocks with bigger 'size' before + // choosing this segment. + auto expandable_size = [](Block* b) { + return b->size + (b->next && !b->next->mapped ? b->next->size : 0); + }; + auto next = it; + next++; + while ((*it)->expandable_segment_ && next != pool.blocks.end() && + (*next)->stream == p.stream() && + expandable_size(*next) < expandable_size(*it)) { + it = next++; + } + } else { + // Rarely expandable segments has been turned off after we have + // already allocated some blocks as expandable. For instance, + // since we cannot share expandable memory via IPC, someone might + // temporarily disable it. 
+        // In this case we need to honor this request
+        // by only finding non-expandable blocks
+        do {
+          it++;
+        } while (it != pool.blocks.end() && (*it)->expandable_segment_ &&
+                 (*it)->stream == p.stream());
+        if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
+          return false;
+        }
+      }
+    }
+
+    // Do not return an oversized block for a large request
+    if ((p.size() < ZoomAllocatorConfig::max_split_size()) &&
+        ((*it)->size >= ZoomAllocatorConfig::max_split_size()))
+      return false;
+    // Allow oversized block size to be rounded up but within a limit
+    if ((p.size() >= ZoomAllocatorConfig::max_split_size()) &&
+        ((*it)->size >= p.size() + kLargeBuffer))
+      return false;
+    p.block = *it;
+    pool.blocks.erase(it);
+    return true;
+  }
+
+  bool trigger_free_memory_callbacks(AllocParams& p) {
+    bool freed_memory = false;
+    for (const auto& name : FreeZoomMemoryCallbacksRegistry()->Keys()) {
+      freed_memory |=
+          FreeZoomMemoryCallbacksRegistry()->Create(name)->Execute();
+    }
+    return freed_memory;
+  }
+
+  void garbage_collect_cached_blocks(
+      const std::shared_ptr<GatheredContext>& context) {
+    // Free unused cached blocks to reclaim GPU memory.
+    // Unlike release_cached_blocks(), this does not enforce synchronization
+    // and therefore should have less overhead.
+
+    size_t gc_threshold = static_cast<size_t>(
+        ZoomAllocatorConfig::garbage_collection_threshold() *
+        static_cast<double>(allowed_memory_maximum));
+    // No need to trigger GC yet
+    if (total_allocated_memory <= gc_threshold) {
+      return;
+    }
+    const auto target_size = total_allocated_memory - gc_threshold;
+    size_t gc_reclaimed = 0;
+
+    // Calculate the total age of the free-able blocks. We'll use it later to
+    // get "avg age" threshold.
+    size_t total_age = 0;
+    int freeable_block_count = 0;
+    for (auto& b : large_blocks.blocks) {
+      if (!b->is_split()) {
+        total_age += b->gc_count();
+        ++freeable_block_count;
+      }
+    }
+    // No free-able blocks?
+    if (freeable_block_count == 0) {
+      return;
+    }
+
+    // Repeat GC until we reach reclaim > target size.
+    bool block_freed = true;
+    while (gc_reclaimed < target_size && block_freed == true &&
+           freeable_block_count > 0) {
+      // Free blocks exceeding this age threshold first.
+      double age_threshold =
+          static_cast<double>(total_age) / freeable_block_count;
+      // Stop iteration if we can no longer free a block.
+      block_freed = false;
+
+      // Free blocks of > avg age. Don't stop upon reaching the target_size,
+      // we don't want this GC to be triggered frequently.
+      auto it = large_blocks.blocks.begin();
+      while (it != large_blocks.blocks.end()) {
+        Block* block = *it;
+        ++it;
+        if (!block->is_split() &&
+            static_cast<double>(block->gc_count()) >= age_threshold) {
+          block_freed = true;
+          gc_reclaimed += block->size;
+          total_age -= block->gc_count(); // Decrement the age
+          freeable_block_count--; // One less block that can be freed
+          release_block(block, context);
+        }
+      }
+    }
+  }
+
+  // This function assumes that the global lock has been taken while calling
+  // into this function. We do a hipMalloc sync call in this function which
+  // can be expensive while holding the lock. Hence, we pass in the lock to the
+  // function to temporarily release the lock before the hipMalloc call and
+  // acquire it back again after the call so that other threads don't get
+  // blocked.
+  bool alloc_block(
+      AllocParams& p,
+      bool isRetry,
+      const std::shared_ptr<GatheredContext>& ctx,
+      std::unique_lock<std::recursive_mutex>& lock) {
+    // Defensively checks for preexisting HIP error state.
+ C10_ZOOM_CHECK(hipGetLastError()); + + size_t size = p.alloc_size; + void* ptr = nullptr; + + if (isRetry) { + stats.num_alloc_retries += 1; + } + + if (set_fraction && + total_allocated_memory + size > allowed_memory_maximum) { + p.err = hipErrorMemoryAllocation; + return false; + } else if ( + ZoomAllocatorConfig::expandable_segments() && + // our checkpointing logic for private pools doesn't support + // the expandable_segments_ structure yet + !p.pool->owner_PrivatePool) { + p.block = try_allocate_expandable_block( + p.device(), p.stream(), p.pool, p.size(), ctx); + if (p.block) { + p.err = hipSuccess; + } else { + p.err = hipErrorMemoryAllocation; + } + return bool(p.block); + } else { + if (ZoomAllocatorConfig::release_lock_on_hipMalloc()) { + // At scope exit, acquire the lock again. This provides safety against + // any potential exceptions in the hipMallocMaybeCapturing function. + auto sg = c10::make_scope_exit([&]() { lock.lock(); }); + lock.unlock(); + p.err = hipMallocMaybeCapturing(&ptr, size); + } else { + p.err = hipMallocMaybeCapturing(&ptr, size); + } + if (ZoomAllocatorConfig::release_lock_on_hipMalloc()) { + TORCH_CHECK( + lock.owns_lock(), "Failed to acquire lock after hipMalloc"); + } + + if (p.err != hipSuccess) { + if (p.err == hipErrorMemoryAllocation) { + // If this is the first attempt (!isRetry), we can forgive and clear + // HIP's internal error state. + // + // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH + // will take over to throw a helpful exception. The user can choose + // to catch the exception, free some stuff in their script, and + // attempt the allocation again. In this case, we can also forgive and + // clear HIP's internal error state. + (void)hipGetLastError(); + } else { + // If the error's unrelated to memory allocation, we should throw + // immediately. + C10_ZOOM_CHECK(p.err); + } + return false; + } + } + + if (p.pool->owner_PrivatePool) { + // The block is for a HIP graph's PrivatePool. + p.pool->owner_PrivatePool->hipMalloc_count++; + } + + total_allocated_memory += size; + p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); + for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { + increase_stat(stats.segment[stat_type], 1); + increase_stat(stats.reserved_bytes[stat_type], size); + }); + if (size >= ZoomAllocatorConfig::max_split_size()) + increase_stat(stats.oversize_segments, 1); + + // p.block came from new, not hipMalloc. It should not be nullptr here. + TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); + stats.num_device_alloc++; + record_trace( + TraceEntry::SEGMENT_ALLOC, + int64_t(p.block->ptr), + p.block->size, + p.stream(), + p.device(), + ctx); + p.block->context_when_segment_allocated = ctx; + return true; + } + + /** Free one or more oversize blocks to the system allocator. But only enough + * **/ + /** to satisfy the target size **/ + bool release_available_cached_blocks( + const AllocParams& p, + const std::shared_ptr& context) { + if (ZoomAllocatorConfig::max_split_size() == + std::numeric_limits::max()) + return false; + BlockPool& pool = *p.pool; + + // because of std::unique_ptr, block cannot be trivially copied + // Use constructor for search key. + Block key(p.search_key.device, p.search_key.stream, p.search_key.size); + key.size = (key.size < ZoomAllocatorConfig::max_split_size()) + ? 
ZoomAllocatorConfig::max_split_size() + : key.size; + auto it = pool.blocks.lower_bound(&key); + if (it == pool.blocks.end() || (*it)->stream != p.stream()) { + // No single block is large enough; free multiple oversize blocks, + // starting with the largest + if (it == pool.blocks.begin()) + return false; + size_t totalReleased = 0; + --it; // Back up one item. Now on the largest block for the correct + // stream + while ((totalReleased < key.size) && + ((*it)->size >= ZoomAllocatorConfig::max_split_size()) && + ((*it)->stream == p.stream())) { + auto cur = it; + totalReleased += (*it)->size; + if (it != pool.blocks.begin()) { + --it; + release_block(*cur, context); + } else { + release_block(*cur, context); + break; + } + } + if (totalReleased < key.size) + return false; + } else { + release_block(*it, context); + } + return true; + } + + bool release_cached_blocks(const std::shared_ptr& context) { + // First ensure that all blocks that can't currently be allocated due to + // outstanding events are returned to the pool. + synchronize_and_free_events(context); + + // Free all non-split cached blocks to system allocator + release_blocks(large_blocks, context); + release_blocks(small_blocks, context); + + for (auto it = graph_pools_freeable.begin(); + it != graph_pools_freeable.end();) { + // See notifyCaptureDestroy for the strategy here. + TORCH_INTERNAL_ASSERT(it->second->use_count == 0); + release_blocks(it->second->small_blocks, context); + release_blocks(it->second->large_blocks, context); + if (it->second->hipMalloc_count == 0) { + auto erase_count = graph_pools.erase(it->first); + TORCH_INTERNAL_ASSERT(erase_count == 1); + it = graph_pools_freeable.erase(it); + } else { + ++it; + } + } + + return true; + } + + void release_expandable_segment(Block* block) { + TORCH_INTERNAL_ASSERT( + block->size == block->expandable_segment_->size(), + "block disagrees with segment"); + TORCH_INTERNAL_ASSERT(!block->mapped); + auto it = std::find( + expandable_segments_.begin(), + expandable_segments_.end(), + block->expandable_segment_); + TORCH_INTERNAL_ASSERT(it != expandable_segments_.end()); + expandable_segments_.erase(it); + block->pool->unmapped.erase(block); + delete block->expandable_segment_; + delete block; + } + + void release_block( + Block* block, + const std::shared_ptr& context) { + TORCH_INTERNAL_ASSERT(!block->expandable_segment_); + stats.num_device_free++; + record_trace( + TraceEntry::SEGMENT_FREE, + int64_t(block->ptr), + block->size, + block->stream, + block->device, + context ? context : block->context_when_segment_allocated); + + C10_ZOOM_CHECK(hipFree((void*)block->ptr)); + total_allocated_memory -= block->size; + + auto* pool = block->pool; + if (pool->owner_PrivatePool) { + // The hipFreed block belonged to a HIP graph's PrivatePool. 
+ TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->hipMalloc_count > 0); + pool->owner_PrivatePool->hipMalloc_count--; + } + + StatTypes stat_types = get_stat_types_for_pool(*pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + decrease_stat(stats.segment[stat_type], 1); + decrease_stat(stats.reserved_bytes[stat_type], block->size); + }); + + if (block->size >= ZoomAllocatorConfig::max_split_size()) + decrease_stat(stats.oversize_segments, 1); + pool->blocks.erase(block); + delete block; + } + + void unmap_block( + Block* block, + const std::shared_ptr& context) { + auto unmapped = block->expandable_segment_->unmap( + SegmentRange{block->ptr, block->size}); + if (unmapped.size == 0) { + return; + } + block->pool->blocks.erase(block); + + ptrdiff_t before_size = + static_cast(unmapped.ptr) - static_cast(block->ptr); + if (before_size > 0) { + // prev? -> before_free -> block + Block* before_free = new Block( + block->device, block->stream, before_size, block->pool, block->ptr); + before_free->expandable_segment_ = block->expandable_segment_; + before_free->splice(block->prev, block); + block->pool->insert_into_blocks(before_free); + } + + auto after_size = block->size - (before_size + unmapped.size); + if (after_size > 0) { + // block -> after_free -> next? + Block* after_free = new Block( + block->device, + block->stream, + after_size, + block->pool, + static_cast(unmapped.ptr) + unmapped.size); + after_free->expandable_segment_ = block->expandable_segment_; + after_free->splice(block, block->next); + block->pool->insert_into_blocks(after_free); + } + + block->ptr = unmapped.ptr; + block->size = unmapped.size; + block->mapped = false; + + try_merge_blocks(block, block->prev, *block->pool); + try_merge_blocks(block, block->next, *block->pool); + block->pool->unmapped.insert(block); + + // update statistics + total_allocated_memory -= unmapped.size; + StatTypes stat_types = get_stat_types_for_pool(*block->pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); + }); + + stats.num_device_free++; + record_trace( + TraceEntry::SEGMENT_UNMAP, + int64_t(unmapped.ptr), + unmapped.size, + block->stream, + block->device, + context ? context : block->context_when_segment_allocated); + } + void release_blocks( + BlockPool& pool, + const std::shared_ptr& context) { + std::vector to_unmap; + // Frees all non-split blocks + auto it = pool.blocks.begin(); + while (it != pool.blocks.end()) { + Block* block = *it; + ++it; + if (block->expandable_segment_) { + // unmapping will mutate the free pool + // so just gather what needs to be freed + // to avoid invalidating the iterator + to_unmap.push_back(block); + } else if (!block->prev && !block->next) { + release_block(block, context); + } + } + for (Block* block : to_unmap) { + unmap_block(block, context); + if (!block->prev && !block->next) { + release_expandable_segment(block); + } + } + } + + EventPool::Event create_event_internal(c10::DeviceIndex idx) { + // Leak the event pool to avoid shutdown issues. + static auto* event_pool = new EventPool(); + return event_pool->get(idx); + } + + void synchronize_and_free_events( + const std::shared_ptr& context) { + // Synchronize on outstanding events and then free associated blocks. + stats.num_sync_all_streams++; + + // This function syncs, so capture should not be underway. Might as well + // make sure capture-deferred end of life events get processed too. 
+ TORCH_INTERNAL_ASSERT(captures_underway.empty()); + insert_events_deferred_until_no_capture(); + + for (auto& st : hip_events) { + for (auto& e : st.second) { + EventPool::Event event = std::move(e.first); + Block* block = e.second; + + C10_ZOOM_CHECK(hipEventSynchronize(*event)); + + block->event_count--; + if (block->event_count == 0) { + free_block(block, context); + } + } + } + + hip_events.clear(); + } + + void insert_events(Block* block) { + c10::DeviceIndex prev_device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&prev_device)); + + stream_set streams(std::move(block->stream_uses)); + AT_ASSERT(block->stream_uses.empty()); + for (auto& stream : streams) { + C10_ZOOM_CHECK(c10::zoom::SetDevice(stream.device_index())); + + EventPool::Event event = create_event_internal(stream.device_index()); + C10_ZOOM_CHECK(hipEventRecord(*event, stream.stream())); + + block->event_count++; + hip_events[stream].emplace_back(std::move(event), block); + } + + C10_ZOOM_CHECK(c10::zoom::MaybeSetDevice(prev_device)); + } + + void insert_events_deferred_until_no_capture() { + if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) { + for (auto* block : needs_events_deferred_until_no_capture) { + TORCH_INTERNAL_ASSERT(!block->stream_uses.empty()); + insert_events(block); + } + needs_events_deferred_until_no_capture.clear(); + } + } + + void process_events(const std::shared_ptr& context) { + insert_events_deferred_until_no_capture(); + + // Process outstanding hipEvents. Events that are completed are + // removed from the queue, and the 'event_count' for the + // corresponding allocation is decremented. We maintain a separate + // list of events per stream to avoid head-of-line delays if one + // or more streams has long-running operations. + + // Iterate over different streams. + for (auto it = hip_events.begin(); it != hip_events.end();) { + // Iterate over this stream's (event, block) pairs. + while (!it->second.empty()) { + auto& e = it->second.front(); + EventPool::Event event = std::move(e.first); + Block* block = e.second; + + hipError_t err = C10_ZOOM_ERROR_HANDLED(hipEventQuery(*event)); + if (err == hipErrorNotReady) { + // ignore and clear the error if not ready + (void)hipGetLastError(); + // Return the ownership of the Event (unique ptr) + e.first = std::move(event); + break; + } else if (err != hipSuccess) { + C10_ZOOM_CHECK(err); + } + + block->event_count--; + if (block->event_count == 0) { + free_block(block, context); + } + it->second.pop_front(); + } + + if (it->second.empty()) { + it = hip_events.erase(it); + } else { + it++; + } + } + } + + // Iterates over sizes of all memory blocks for given device in given pool + void cache_info_aux(const BlockPool& pool, size_t* largest) { + for (const auto& block : pool.blocks) { + const auto blocksize = block->size; + if (blocksize > *largest) { + *largest = blocksize; + } + } + } + + void record_trace( + TraceEntry::Action action, + size_t addr, + size_t size, + hipStream_t stream, + c10::DeviceIndex device, + std::shared_ptr context) { + if (!record_history && trace_trackers_.empty()) + return; + + auto te = TraceEntry( + action, + device, + addr, + size, + stream, + getApproximateTime(), + record_context_ >= RecordContext::ALLOC ? 
std::move(context) : nullptr); + + // Callbacks should not include any Pytorch call + for (const auto& cb : trace_trackers_) { + cb(te); + } + + if (record_history) { + if (alloc_trace->size() < alloc_trace_max_entries_) { + alloc_trace->emplace_back(te); + } else { + (*alloc_trace)[alloc_trace_next++] = te; + if (alloc_trace_next == alloc_trace_max_entries_) { + alloc_trace_next = 0; + } + } + } + } +}; + +// Returns whether to force all allocations to bypass the caching allocator and +// go straight to hipMalloc. This setting is useful when debugging GPU memory +// errors, since the caching allocator foils cuda-memcheck. +bool forceUncachedAllocator() { + static bool force_uncached = + getenv("PYTORCH_NO_ZOOM_MEMORY_CACHING") != nullptr; + return force_uncached; +} + +static void uncached_delete(void* ptr) { + if (TORCH_SDT_IS_ENABLED(free)) { + TORCH_SDT_WITH_SEMAPHORE(free, ptr); + } + + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_memory_deallocation( + c10::DeviceType::PrivateUse1, reinterpret_cast(ptr)); + } + C10_ZOOM_CHECK(hipFree(ptr)); +} + +void local_raw_delete(void* ptr); + +class NativeCachingAllocator : public ZoomAllocator { + private: + // Shard allocation region to have independent mutexes to reduce contention. + static constexpr size_t kNumMutexShard = 67; + + // TODO: use std::hardware_destructive_interference_size once available + struct alignas(64) AlignedMutex { + std::mutex m; + }; + + std::array mutex; + + // allocated blocks by device pointer + std::array, kNumMutexShard> + allocated_blocks; + + static size_t get_mutex_shard_id(void* ptr) { + return twang_mix64((size_t)ptr) % kNumMutexShard; + } + + void add_allocated_block(Block* block) { + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + const auto mutex_shard_id = get_mutex_shard_id(block->ptr); + std::lock_guard lock(mutex[mutex_shard_id].m); + allocated_blocks[mutex_shard_id][block->ptr] = block; + } + + c10::ApproximateClockToUnixTimeConverter clock_converter; + + public: + std::vector> device_allocator; + + Block* get_allocated_block(void* ptr, bool remove = false) { + const auto mutex_shard_id = get_mutex_shard_id(ptr); + std::lock_guard lock(mutex[mutex_shard_id].m); + auto it = allocated_blocks[mutex_shard_id].find(ptr); + if (it == allocated_blocks[mutex_shard_id].end()) { + return nullptr; + } + Block* block = it->second; + if (remove) { + allocated_blocks[mutex_shard_id].erase(it); + } + return block; + } + + void init(int device_count) override { + const auto size = static_cast(device_allocator.size()); + if (size < device_count) { + device_allocator.resize(device_count); + for (const auto i : c10::irange(size, device_count)) { + device_allocator[i] = std::make_unique(); + } + } + } + + bool initialized() override { + return !device_allocator.empty(); + } + + /** allocates a block which is safe to use from the provided stream */ + void malloc( + void** devPtr, + c10::DeviceIndex device, + size_t size, + hipStream_t stream) { + TORCH_INTERNAL_ASSERT( + 0 <= device && static_cast(device) < device_allocator.size(), + "Allocator not initialized for device ", + device, + ": did you call init?"); + Block* block = device_allocator[device]->malloc(device, size, stream); + add_allocated_block(block); + *devPtr = (void*)block->ptr; + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_memory_allocation( + c10::DeviceType::PrivateUse1, 
reinterpret_cast(*devPtr)); + } + } + + void free(void* ptr) { + if (!ptr) { + return; + } + Block* block = get_allocated_block(ptr, true /* remove */); + if (!block) { + TORCH_CHECK(false, "invalid device pointer: ", ptr); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_memory_deallocation( + c10::DeviceType::PrivateUse1, reinterpret_cast(block->ptr)); + } + device_allocator[block->device]->free(block); + } + + void setMemoryFraction(double fraction, c10::DeviceIndex device) override { + TORCH_INTERNAL_ASSERT( + 0 <= device && static_cast(device) < device_allocator.size(), + "Allocator not initialized for device ", + device, + ": did you call init?"); + TORCH_INTERNAL_ASSERT( + 0 <= fraction && fraction <= 1, + "invalid fraction:", + fraction, + ". Please set within (0, 1)."); + C10_ZOOM_CHECK(c10::zoom::SetDevice(device)); + device_allocator[device]->setMemoryFraction(fraction); + } + + void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when) override { + for (auto& allocator : device_allocator) { + allocator->recordHistory( + enabled, context_recorder, alloc_trace_max_entries, when); + } + } + + bool isHistoryEnabled() override { + c10::DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + return device_allocator[device]->isHistoryEnabled(); + } + + bool checkPoolLiveAllocations( + c10::DeviceIndex device, + MempoolId_t mempool_id, + const std::unordered_set& expected_live_allocations) override { + return device_allocator[device]->checkPoolLiveAllocations( + mempool_id, expected_live_allocations); + } + + void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override { + for (auto& allocator : device_allocator) { + allocator->attachOutOfMemoryObserver(observer); + } + } + + void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override { + for (auto& allocator : device_allocator) { + allocator->attachAllocatorTraceTracker(tracker); + } + } + + void emptyCache() override { + for (auto& da : device_allocator) + da->emptyCache(); + } + + void* getBaseAllocation(void* ptr, size_t* outSize) override { + Block* block = get_allocated_block(ptr); + if (!block) { + TORCH_CHECK(false, "invalid device pointer: ", ptr); + } + return device_allocator[block->device]->getBaseAllocation(block, outSize); + } + + void recordStream(const DataPtr& ptr, zoom::ZoomStream stream) override { + // Empty tensor's storage().data() might be a null ptr. As there is no + // blocks associated with those tensors, it is fine to do nothing here. + if (!ptr.get()) { + return; + } + + // If a tensor is not allocated by this instance, simply skip + // This usually happens when HIP tensors are shared across processes, + // we have implemented reference counting based sharing mechanism to + // guarantee tensors won't be accidentally freed by one process while + // they are still being used in another + if (ptr.get_deleter() != &local_raw_delete) + return; + + Block* block = get_allocated_block(ptr.get()); + // block must not be null reaching here + TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found"); + device_allocator[block->device]->recordStream(block, stream); + } + + SnapshotInfo snapshot() override { + // Set-up converter to convert timestamps from tsc to microseconds. 
+ auto tsc_to_ns = clock_converter.makeConverter(); + auto tsc_to_us = [=](approx_time_t t_approx) { + return tsc_to_ns(t_approx) / 1000; + }; + + SnapshotInfo result; + for (auto& da : device_allocator) { + result.device_traces.emplace_back(da->trace(tsc_to_us)); + auto snap = da->snapshot(); + result.segments.insert(result.segments.end(), snap.begin(), snap.end()); + } + + auto& md = result.config_metadata; + md.garbage_collection_threshold = + ZoomAllocatorConfig::garbage_collection_threshold(); + md.max_split_size = ZoomAllocatorConfig::max_split_size(); + md.pinned_num_register_threads = + ZoomAllocatorConfig::pinned_num_register_threads(); + md.expandable_segments = ZoomAllocatorConfig::expandable_segments(); + md.release_lock_on_malloc = + ZoomAllocatorConfig::release_lock_on_hipMalloc(); + md.pinned_use_host_register = + ZoomAllocatorConfig::pinned_use_zoom_host_register(); + md.last_allocator_settings = ZoomAllocatorConfig::last_allocator_settings(); + md.roundup_power2_divisions = + ZoomAllocatorConfig::roundup_power2_divisions(); + + return result; + } + + std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) override { + return device_allocator[device]->getCheckpointState(id); + } + + /** + * @brief Checkpoint the private pool state identified in `as` to its prior + * state + * + * @param device - device of the pool to manipulate + * @param as - allocator state + * @param stale_live_storages - storages of tensors which are currently + * allocated but which will be not be allocated after the checkpoint is set. + * For these storages we will remove their deleter function. + * @return CheckpointDelta - Freed Pointers and DataPtrs that contain deleter + * functions for all allocated blocks in the new checkpoint state. + */ + CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr as) override { + std::shared_ptr pps = + std::dynamic_pointer_cast(as); + + TORCH_CHECK(pps, "Expected PrivatePoolState"); + + auto rr = device_allocator[device]->setCheckpointPoolState(*pps); + + CheckpointDelta cpd; + for (void* ptr : rr.allocations_freed) { + get_allocated_block(ptr, /*remove*/ true); + cpd.ptrs_freed.push_back(ptr); + } + for (Block* block : rr.allocations_created) { + add_allocated_block(block); + cpd.dataptrs_allocd.emplace_back( + block->ptr, + block->ptr, + &local_raw_delete, + Device(DeviceType::PrivateUse1, device)); + } + + return cpd; + } + + DataPtr allocate(size_t size) override { + constexpr size_t one_exa_bytes = 1152921504606846976ULL; + TORCH_CHECK_WITH( + OutOfMemoryError, + size < one_exa_bytes, + "HIP out of memory. Tried to allocate more than 1EB memory."); + c10::DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + void* devPtr = nullptr; + void (*deleteFunc)(void*) = &local_raw_delete; + ZoomStream stream = zoom::getCurrentZoomStream(device); + + if (forceUncachedAllocator()) { + deleteFunc = &uncached_delete; + + // Deliberately don't use hipMallocMaybeCapturing here, to force an error + // if someone tries to use forceUncachedAllocator while capturing. 
+ C10_ZOOM_CHECK(hipMalloc(&devPtr, size)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_memory_allocation( + c10::DeviceType::PrivateUse1, reinterpret_cast(devPtr)); + } + } else { + if (size != 0) { + this->malloc(&devPtr, device, size, stream); + } + } + + if (size && TORCH_SDT_IS_ENABLED(malloc)) { + TORCH_SDT_WITH_SEMAPHORE(malloc, devPtr, device, size, stream.id()); + } + + return {devPtr, devPtr, deleteFunc, Device(DeviceType::PrivateUse1, device)}; + } + DeleterFnPtr raw_deleter() const override { + if (forceUncachedAllocator()) { + return &uncached_delete; + } else { + return &local_raw_delete; + } + } + void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override { + device_allocator[device]->cacheInfo(largestBlock); + } + void assertValidDevice(c10::DeviceIndex device) { + const auto device_num = device_allocator.size(); + TORCH_CHECK( + 0 <= device && device < static_cast(device_num), + "Invalid device argument ", + device, + ": did you call init?"); + } + + DeviceStats getDeviceStats(c10::DeviceIndex device) override { + assertValidDevice(device); + return device_allocator[device]->getStats(); + } + + void resetAccumulatedStats(c10::DeviceIndex device) override { + assertValidDevice(device); + device_allocator[device]->resetAccumulatedStats(); + } + + void resetPeakStats(c10::DeviceIndex device) override { + assertValidDevice(device); + device_allocator[device]->resetPeakStats(); + } + // HIPGraph interactions + void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) override { + assertValidDevice(device); + device_allocator[device]->beginAllocateToPool( + std::move(mempool_id), std::move(filter)); + } + + void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) + override { + assertValidDevice(device); + device_allocator[device]->endAllocateToPool(mempool_id); + } + + void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override { + assertValidDevice(device); + device_allocator[device]->releasePool(std::move(mempool_id)); + } + + void* raw_alloc(size_t nbytes) override { + if (nbytes == 0) { + return nullptr; + } + c10::DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + void* r = nullptr; + malloc(&r, device, nbytes, zoom::getCurrentZoomStream(device)); + return r; + } + + void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override { + if (nbytes == 0) { + return nullptr; + } + c10::DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + void* r = nullptr; + malloc(&r, device, nbytes, stream); + return r; + } + + void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) + override { + c10::zoom::ZoomGuard device_guard(dev); + hipError_t err = hipDeviceEnablePeerAccess(dev_to_access, 0); + if (err == hipErrorPeerAccessAlreadyEnabled) { + // ignore and clear the error if access was already enabled + (void)hipGetLastError(); + } else { + C10_ZOOM_CHECK(err); + } + device_allocator[dev_to_access]->addPeerAccess(dev); + } + + hipError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + hipStream_t stream, + bool p2p_enabled) override { + if (p2p_enabled || // memcpy ok because memory is mapped in both devices + srcDevice == dstDevice || // memcpy ok on a single device + // memcpy ok because both dst and src must have come from hipMalloc + 
(!device_allocator[dstDevice]->hasAllocatedExpandableSegments() && + !device_allocator[srcDevice]->hasAllocatedExpandableSegments())) { + return hipMemcpyAsync(dst, src, count, hipMemcpyDeviceToDevice, stream); + } + // when p2p is not enabled, only hipMemcpyPeerAsync correctly handles + // memory not allocated via hipMalloc + return hipMemcpyPeerAsync(dst, dstDevice, src, srcDevice, count, stream); + } + + void raw_delete(void* ptr) override { + this->free(ptr); + } + + // In HIP IPC, sender sends a tensor to receiver, getIpcDevPtr + // is called by the receiving process to map the HIP memory from the sending + // process into its own address space. + // + // HIP IPC only allows sharing a big memory block associated with a + // hipIpcMemHandle_t and it can be opened only **once** per context per + // process. There can be multiple types of storage in the same IPC mem block, + // so we must cache the device ptr to construct typed storage as it comes. + // + // ipcMemHandle_to_devptr maps a hipIpcMemHandle_t to a device pointer in the + // process that can be used to access the memory block in the sender process. + // It only saves a weak_ptr of the device pointer in the map, the shared_ptr + // will be used to reconstruct all storages in this hipMalloc allocation. And + // it will deleted in cudaIpcCloseMemHandle when its reference count is 0. + // + std::mutex IpcMutex; + ska::flat_hash_map> ipcMemHandle_to_devptr; + std::shared_ptr getIpcDevPtr(std::string handle) override { + std::lock_guard lock(IpcMutex); + + auto iter = ipcMemHandle_to_devptr.find(handle); + if (iter != ipcMemHandle_to_devptr.end()) { + auto devptr = iter->second.lock(); + if (devptr) + return devptr; + } + // This ipcMemHandle hasn't been opened, or already expired, open it to + // enable IPC access to that mem block. + void* dev = nullptr; + auto ipc_handle = + reinterpret_cast(handle.c_str()); + C10_ZOOM_CHECK(hipIpcOpenMemHandle( + &dev, *ipc_handle, hipIpcMemLazyEnablePeerAccess)); + // devPtr has to be deleted in same device when created. + c10::DeviceIndex curr_device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&curr_device)); + auto sp = + std::shared_ptr(dev, [handle, curr_device, this](void* ptr) { + zoom::ZoomGuard device_guard(curr_device); + std::lock_guard deleter_lock(IpcMutex); + C10_ZOOM_CHECK(hipIpcCloseMemHandle(ptr)); + ipcMemHandle_to_devptr.erase(handle); + }); + std::weak_ptr wp = sp; + // To eliminate an additional search, we can use insert(). + // It doesn't overwrite when key already exists(ptr expired). + // But in the deleter for sp we erased the entry, + // this should be safe to do now. 
+ ipcMemHandle_to_devptr.insert(iter, {handle, wp}); + + return sp; + } + std::string name() override { + return "native"; + } + void copy_data(void* dest, const void* src, std::size_t count) const final { + C10_ZOOM_CHECK( + hipMemcpy(dest, src, count, hipMemcpyKind::hipMemcpyDeviceToDevice)); + } +}; + +NativeCachingAllocator allocator; + +void local_raw_delete(void* ptr) { + if (TORCH_SDT_IS_ENABLED(free)) { + TORCH_SDT_WITH_SEMAPHORE(free, ptr); + } + + allocator.free(ptr); +} + +} // namespace Native +// Size pretty-printer +std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} + +namespace ZoomMallocAsync { +// If this is put in its own header file, it gets incorrectly renamed in HIPify. +ZoomAllocator* allocator(); + +} // namespace ZoomMallocAsync + +struct BackendStaticInitializer { + // Parses env for backend at load time, duplicating some logic from + // ZoomAllocatorConfig. ZoomAllocatorConfig double-checks it later (at + // runtime). Defers verbose exceptions and error checks, including Cuda + // version checks, to ZoomAllocatorConfig's runtime doublecheck. If this + // works, maybe we should move all of ZoomAllocatorConfig here? + ZoomAllocator* parseEnvForBackend() { + const char* val = getenv("PYTORCH_ZOOM_ALLOC_CONF"); + if (val != nullptr) { + const std::string config(val); + + std::regex exp("[\\s,]+"); + std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); + std::sregex_token_iterator end; + std::vector options(it, end); + + for (auto option : options) { + std::regex exp2("[:]+"); + std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); + std::sregex_token_iterator end2; + std::vector kv(it2, end2); + if (kv.size() >= 2) { + if (kv[0] == "backend") { + if (kv[1] == "hipMallocAsync") + return ZoomMallocAsync::allocator(); + if (kv[1] == "native") + return &Native::allocator; + } + } + } + } + return &Native::allocator; + } + + BackendStaticInitializer() { + auto r = parseEnvForBackend(); + allocator.store(r); + } +}; + +std::atomic allocator; +BackendStaticInitializer backend_static_initializer; + +Allocator* getZoomAllocator() { + return c10::zoom::ZoomCachingAllocator::get(); +} + +REGISTER_PU1_ALLOCATOR(getZoomAllocator); +// namespace { +// static PrivateUse1AllocatorRegisterer g_allocator_d(getZoomAllocator); +// } + +} // namespace zoom::ZoomCachingAllocator + +} // namespace c10 \ No newline at end of file diff --git a/c10/zoom/ZoomCachingAllocator.h b/c10/zoom/ZoomCachingAllocator.h new file mode 100644 index 00000000000000..5311a04726d9fb --- /dev/null +++ b/c10/zoom/ZoomCachingAllocator.h @@ -0,0 +1,480 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// Caching allocator will execute every registered callback if it unable to find +// block inside of already allocated area. +class FreeMemoryCallback { + public: + virtual ~FreeMemoryCallback() = default; + virtual bool Execute() = 0; +}; + +C10_DECLARE_REGISTRY(FreeZoomMemoryCallbacksRegistry, FreeMemoryCallback); +#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) 
\ + C10_REGISTER_CLASS(FreeZoomMemoryCallbacksRegistry, name, __VA_ARGS__); +} // namespace c10 + // +// TODO: Turn this into an honest to goodness class. I briefly attempted to do +// this, but it was a bit irritating to figure out how to also correctly +// apply pimpl pattern so I didn't have to leak any internal implementation +// details in the header (ZoomCachingAllocator could be made a pimpl, but +// you also need to appropriately define a class which is a subclass +// of Allocator. Not impossible, but required a bit more surgery than +// I wanted to do at the time.) +// +// Why is this using a namespace rather than old-style THCCachingAllocator_ +// prefix? Mostly because it made the HIPify rules easier to write; _ is +// not counted as a word boundary, so you would otherwise have to list each +// of these functions. + +namespace c10::zoom::ZoomCachingAllocator { + +extern const size_t kLargeBuffer; + +struct Stat { + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +typedef std::array(StatType::NUM_TYPES)> StatArray; + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from cudaMalloc(). + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via cudaFree) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to CUDA malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of CUDA allocation calls. This includes both cuMemMap + // and cudaMalloc. + int64_t num_device_alloc = 0; + + // COUNT: total number of CUDA free calls. This includes both cuMemUnmap + // and cudaFree. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + +typedef std::shared_ptr (*CreateContextFn)(); + +// Struct containing info of an allocation block (i.e. a fractional part of a +// cudaMalloc).. +struct BlockInfo { + size_t size = 0; + size_t requested_size = 0; + int32_t gc_counter = 0; + bool allocated = false; + bool active = false; + std::shared_ptr + context_when_allocated; // per-watcher context +}; + +// Struct containing info of a memory segment (i.e. one contiguous cudaMalloc). 
+struct SegmentInfo { + c10::DeviceIndex device = 0; + size_t address = 0; + size_t total_size = 0; + size_t requested_size = 0; // unrounded, actually requested size + size_t allocated_size = 0; + size_t active_size = 0; + hipStream_t stream = nullptr; + bool is_large = false; + bool is_expandable = false; + MempoolId_t owner_private_pool_id = {0, 0}; + std::vector blocks; + std::shared_ptr context_when_allocated; +}; + +struct AllocatorState { + virtual ~AllocatorState() = default; +}; + +union trace_time_ { + time_t t_; + approx_time_t approx_t_; +}; + +struct TraceEntry { + enum Action { + ALLOC, // API made to the caching allocator for new memory + FREE_REQUESTED, // API call made to the caching allocator to free memory + FREE_COMPLETED, // The allocator might have to delay a free because + // it is still in use on another stream via record_stream + // This event is generated when a free actually completes. + SEGMENT_ALLOC, // a call to cudaMalloc to get more memory from the OS + SEGMENT_FREE, // a call to cudaFree to return memory to the OS (e.g. to + // defragment or empty_caches) + SEGMENT_MAP, // a call to cuMemMap (used with expandable_segments) + SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments) + SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace + // events + OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free + // bytes reported by cuda) + }; + TraceEntry( + Action action, + c10::DeviceIndex device, + size_t addr, + size_t size, + hipStream_t stream, + approx_time_t time, + std::shared_ptr context = nullptr) + : action_(action), + device_(device), + addr_(addr), + context_(std::move(context)), + stream_(stream), + size_(size) { + time_.approx_t_ = time; + } + Action action_; + c10::DeviceIndex device_; + size_t addr_; // for OOM, this is the amount of free bytes reported by cuda + std::shared_ptr context_; + hipStream_t stream_{}; + size_t size_; + trace_time_ time_{}; +}; + +struct AllocatorConfigInfo { + double garbage_collection_threshold; + size_t max_split_size; + size_t pinned_num_register_threads; + bool expandable_segments; + bool release_lock_on_malloc; + bool pinned_use_host_register; + std::string last_allocator_settings; + std::vector roundup_power2_divisions; +}; + +struct SnapshotInfo { + std::vector segments; + std::vector> device_traces; + AllocatorConfigInfo config_metadata; +}; + +// returns the pointers freed in the pool +// and the pointers allocated. 
Note: a pointer +// may appear in both freed and allocated +struct CheckpointDelta { + std::vector ptrs_freed; + std::vector dataptrs_allocd; +}; + +enum struct RecordContext { + NEVER = 0, + STATE = 1, // only keep stacks for active allocations + ALLOC = 2, // additionally keep stacks for allocations in the trace history + ALL = 3, // additionally record stacks for when something is freed +}; + +// Size pretty-printer +std::string format_size(uint64_t size); + +using OutOfMemoryObserver = std::function; + +using AllocatorTraceTracker = std::function; + +class ZoomAllocator : public Allocator { + public: + virtual void* raw_alloc(size_t nbytes) = 0; + virtual void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) = 0; + virtual void raw_delete(void* ptr) = 0; + virtual void init(int device_count) = 0; + virtual bool initialized() = 0; + virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; + virtual void emptyCache() = 0; + virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; + virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; + virtual void recordStream(const DataPtr&, ZoomStream stream) = 0; + virtual DeviceStats getDeviceStats(c10::DeviceIndex device) = 0; + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + virtual void resetPeakStats(c10::DeviceIndex device) = 0; + virtual SnapshotInfo snapshot() = 0; + virtual void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) = 0; + virtual void endAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id) = 0; + virtual void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) = 0; + // returns true if the allocated blocks are equal to expected live allocations + virtual bool checkPoolLiveAllocations( + c10::DeviceIndex device, + MempoolId_t mempool_id, + const std::unordered_set& expected_live_allocations) { + TORCH_CHECK( + false, + name(), + " does not yet support checkPoolLiveAllocations. " + "If you need it, please file an issue describing your use case."); + } + virtual std::shared_ptr getIpcDevPtr(std::string handle) = 0; + virtual bool isHistoryEnabled() { + TORCH_CHECK( + false, + name(), + " does not yet support recordHistory. " + "If you need it, please file an issue describing your use case."); + } + virtual void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when) = 0; + virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0; + + // Attached AllocatorTraceTracker callbacks will be called while the + // per-device allocator lock is held. Any additional locks taken from within + // the callback must be proven to always have the lock order that never + // triggers a deadlock. In particular, Python's GIL may be held when + // calling the allocator so it is unsafe to try to acquire the GIL in this + // callback. + virtual void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) = 0; + + virtual void enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access) = 0; + + // memory not allocated from cudaMalloc cannot be copied + // across devices using cudaMemcpyAsync if peer to peer access is disabled. 
+ // instead it requires cudaMemcpyAsyncPeer + // with P2P Enabled, all combinations work + // with P2P Disabled: + // cudaMalloc cudaMallocAsync/cuMemMap + // cudaMemcpyAsyncPeer works works + // cudaMemcpyAsync works error + + // This function performs chooses to use the Peer version of + // memcpy if required based on where the allocated put dst/src. + virtual hipError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + hipStream_t stream, + bool p2p_enabled) = 0; + virtual std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) = 0; + virtual CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) = 0; + virtual std::string name() = 0; +}; + +// Allocator object, statically initialized +// See BackendInitializer in ZoomCachingAllocator.cpp. +// Atomic loads on x86 are just normal loads, +// (atomic stores are different), so reading this value +// is no different than loading a pointer. +extern std::atomic allocator; + +inline ZoomAllocator* get() { + return allocator.load(); +} + +// Called directly by clients. +inline void* raw_alloc(size_t nbytes) { + return get()->raw_alloc(nbytes); +} + +inline void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) { + return get()->raw_alloc_with_stream(nbytes, stream); +} + +inline void raw_delete(void* ptr) { + return get()->raw_delete(ptr); +} + +inline void init(int device_count) { + return get()->init(device_count); +} + +inline void setMemoryFraction(double fraction, c10::DeviceIndex device) { + return get()->setMemoryFraction(fraction, device); +} + +inline void emptyCache() { + return get()->emptyCache(); +} + +inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) { + return get()->cacheInfo(device, largestBlock); +} + +inline void* getBaseAllocation(void* ptr, size_t* size) { + return get()->getBaseAllocation(ptr, size); +} + +inline void recordStream(const DataPtr& dataPtr, ZoomStream stream) { + return get()->recordStream(dataPtr, stream); +} + +inline DeviceStats getDeviceStats(c10::DeviceIndex device) { + return get()->getDeviceStats(device); +} + +inline void resetAccumulatedStats(c10::DeviceIndex device) { + return get()->resetAccumulatedStats(device); +} + +inline void resetPeakStats(c10::DeviceIndex device) { + return get()->resetPeakStats(device); +} + +inline SnapshotInfo snapshot() { + return get()->snapshot(); +} + +inline std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) { + return get()->getCheckpointState(device, id); +} + +inline CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) { + return get()->setCheckpointPoolState(device, std::move(pps)); +} + +// CUDAGraph interactions +inline void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + get()->beginAllocateToPool(device, mempool_id, std::move(filter)); +} + +inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + get()->endAllocateToPool(device, mempool_id); +} + +inline void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when) { + return get()->recordHistory( + enabled, context_recorder, alloc_trace_max_entries, when); +} + +inline bool isHistoryEnabled() { + return get()->isHistoryEnabled(); +} + +inline bool checkPoolLiveAllocations( + c10::DeviceIndex device, + MempoolId_t mempool_id, + const std::unordered_set& 
expected_live_allocations) { + return get()->checkPoolLiveAllocations( + device, mempool_id, expected_live_allocations); +} + +inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) { + return get()->attachOutOfMemoryObserver(std::move(observer)); +} + +inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { + return get()->attachAllocatorTraceTracker(std::move(tracker)); +} + +inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return get()->releasePool(device, mempool_id); +} +// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE +inline std::shared_ptr getIpcDevPtr(std::string handle) { + return get()->getIpcDevPtr(std::move(handle)); +} + +inline std::string name() { + return get()->name(); +} + +inline hipError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + hipStream_t stream, + bool p2p_enabled) { + return get()->memcpyAsync( + dst, dstDevice, src, srcDevice, count, stream, p2p_enabled); +} + +inline void enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access) { + return get()->enablePeerAccess(dev, dev_to_access); +} + +} // namespace c10::zoom::ZoomCachingAllocator \ No newline at end of file diff --git a/c10/zoom/ZoomDeviceAssertionHost.cpp b/c10/zoom/ZoomDeviceAssertionHost.cpp new file mode 100644 index 00000000000000..a3b2207df73a4b --- /dev/null +++ b/c10/zoom/ZoomDeviceAssertionHost.cpp @@ -0,0 +1,344 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#ifdef TORCH_USE_ZOOM_DSA +#include +#include +#endif + +#define C10_ZOOM_CHECK_WO_DSA(EXPR) \ + do { \ + const hipError_t __err = EXPR; \ + c10::zoom::c10_zoom_check_implementation( \ + static_cast(__err), \ + __FILE__, \ + __func__, /* Line number data type not well-defined between \ + compilers, so we perform an explicit cast */ \ + static_cast(__LINE__), \ + false); \ + } while (0) + +namespace c10::zoom { + +namespace { + +#ifdef TORCH_USE_ZOOM_DSA +/// Get current device id +/// We need our own implementation of this function to prevent +/// an infinite initialization loop for ZoomKernelLaunchRegistry +int dsa_get_device_id() { + c10::DeviceIndex device = -1; + C10_ZOOM_CHECK_WO_DSA(c10::zoom::GetDevice(&device)); + return device; +} + +/// Get a device's compute capability - note that this dangerously assumes +/// that if one CUDA GPU supports device-side assertions they all do. This is +/// probably fine since the latest CUDA GPU that doesn't support UVM is the +/// K80 released 2014-11-17. 
Mixing that GPU with a newer one is likely to be +/// rare enough that the defensive +/// We need our own implementation of this function to prevent +/// an infinite initialization loop for ZoomKernelLaunchRegistry +int dsa_get_device_compute_capability(const int device_num) { + int compute_capability = -1; + C10_ZOOM_CHECK_WO_DSA(hipDeviceGetAttribute( + &compute_capability, hipDevAttrComputeCapabilityMajor, device_num)); + return compute_capability; +} +#endif + +/// Get the number of HIP devices +/// We need our own implementation of this function to prevent +/// an infinite initialization loop for ZoomKernelLaunchRegistry +int dsa_get_device_count() { + int device_count = -1; + C10_ZOOM_CHECK_WO_DSA(c10::zoom::GetDeviceCount(&device_count)); + return device_count; +} + +bool dsa_check_if_all_devices_support_managed_memory() { +// It looks as though this'll work best on CUDA GPUs with Pascal +// architectures or newer, per +// https://developer.nvidia.com/blog/unified-memory-cuda-beginners/ +#ifdef TORCH_USE_ZOOM_DSA + for (const auto i : c10::irange(dsa_get_device_count())) { + if (dsa_get_device_compute_capability(i) < 6) { + return false; + } + } + return true; +#else + return false; +#endif +} + +bool env_flag_set(const char* env_var_name) { + const char* const env_string = std::getenv(env_var_name); + return (env_string == nullptr) ? false : std::strcmp(env_string, "0"); +} + +/// Deleter for UVM/managed memory pointers +void uvm_deleter(DeviceAssertionsData* uvm_assertions_ptr) { + // Ignore error in destructor + if (uvm_assertions_ptr) { + C10_ZOOM_IGNORE_ERROR(hipFree(uvm_assertions_ptr)); + } +} + +} // namespace + +/// Check that kernels ran correctly by checking the message buffer. BLOCKING. +std::string c10_retrieve_device_side_assertion_info() { +#ifdef TORCH_USE_ZOOM_DSA + const auto& launch_registry = ZoomKernelLaunchRegistry::get_singleton_ref(); + if (!launch_registry.enabled_at_runtime) { + return "Device-side assertion tracking was not enabled by user."; + } else if (!launch_registry.do_all_devices_support_managed_memory) { + return "Device-side assertions disabled because not all devices support managed memory."; + } + + // Hack that saves a lot of challenging sync logic. + // The GPU increments the number of errors it's observed and the CPU can see + // that happening immediately which means we can make it here before the GPU + // is done writing information about those errors to memory. + // A short pause gives it time to finish. Since something's gone wrong, this + // pause shouldn't affect perf. + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // The snapshot causes a brief block. That's okay because this function only + // executes if something's gone wrong such that speed is no longer a priority. + const auto launch_data = launch_registry.snapshot(); + const auto& assertion_data = launch_data.first; + const auto& launch_infos = launch_data.second; + + std::stringstream oss; + + oss << "Looking for device-side assertion failure information...\n"; + + // Loop over each device that could be managed by the process + for (const auto device_num : c10::irange(assertion_data.size())) { + const auto& assertion_data_for_device = assertion_data.at(device_num); + + // Did anything fail? 
+ const auto failures_found = std::min( + assertion_data_for_device.assertion_count, + C10_ZOOM_DSA_ASSERTION_COUNT); + if (failures_found == 0) { + continue; + } + + // Something failed, let's talk about that + oss << failures_found + << " HIP device-side assertion failures were found on GPU #" + << device_num << "!" << std::endl; + if (assertion_data_for_device.assertion_count > + C10_ZOOM_DSA_ASSERTION_COUNT) { + oss << "But at least " << assertion_data_for_device.assertion_count + << " assertion failures occurred on the device" << std::endl; + oss << "Adjust `C10_ZOOM_DSA_ASSERTION_COUNT` if you need more assertion failure info" + << std::endl; + } + + for (const auto i : c10::irange(failures_found)) { + const auto& self = assertion_data_for_device.assertions[i]; + const auto& launch_info = launch_infos[self.caller % launch_infos.size()]; + oss << "Assertion failure " << i << std::endl; + oss << " GPU assertion failure message = " << self.assertion_msg + << std::endl; + oss << " File containing assertion = " << self.filename << ":" + << self.line_number << std::endl; + oss << " Device function containing assertion = " << self.function_name + << std::endl; + oss << " Thread ID that failed assertion = [" << self.thread_id[0] << "," + << self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl; + oss << " Block ID that failed assertion = [" << self.block_id[0] << "," + << self.block_id[1] << "," << self.block_id[2] << "]" << std::endl; + if (launch_info.generation_number == self.caller) { + oss << " File containing kernel launch = " + << launch_info.launch_filename << ":" << launch_info.launch_linenum + << std::endl; + oss << " Function containing kernel launch = " + << launch_info.launch_function << std::endl; + oss << " Name of kernel launched that led to failure = " + << launch_info.kernel_name << std::endl; + oss << " Device that launched kernel = " << launch_info.device + << std::endl; + oss << " Stream kernel was launched on = " << launch_info.stream + << std::endl; + oss << " Backtrace of kernel launch site = "; + if (launch_registry.gather_launch_stacktrace) { + oss << "Launch stacktracing disabled." << std::endl; + } else { + oss << "\n" << launch_info.launch_stacktrace << std::endl; + } + } else { + oss << " CPU launch site info: Unavailable, the circular queue wrapped around. Increase `ZoomKernelLaunchRegistry::max_size`." 
+ << std::endl; + } + } + } + return oss.str(); +#else + return "Compile with `TORCH_USE_ZOOM_DSA` to enable device-side assertions.\n"; +#endif +} + +ZoomKernelLaunchRegistry::ZoomKernelLaunchRegistry() + : do_all_devices_support_managed_memory( + dsa_check_if_all_devices_support_managed_memory()), + gather_launch_stacktrace(check_env_for_enable_launch_stacktracing()), + enabled_at_runtime(check_env_for_dsa_enabled()) { + for (C10_UNUSED const auto _ : c10::irange(dsa_get_device_count())) { + uvm_assertions.emplace_back(nullptr, uvm_deleter); + } + + kernel_launches.resize(max_kernel_launches); +} + +bool ZoomKernelLaunchRegistry::check_env_for_enable_launch_stacktracing() + const { + return env_flag_set("PYTORCH_ZOOM_DSA_STACKTRACING"); +} + +bool ZoomKernelLaunchRegistry::check_env_for_dsa_enabled() const { + return env_flag_set("PYTORCH_USE_ZOOM_DSA"); +} + +uint32_t ZoomKernelLaunchRegistry::insert( + const char* launch_filename, + const char* launch_function, + const uint32_t launch_linenum, + const char* kernel_name, + const int32_t stream_id) { +#ifdef TORCH_USE_ZOOM_DSA + if (!enabled_at_runtime) { + return 0; + } + + const auto backtrace = gather_launch_stacktrace ? c10::get_backtrace() : ""; + + const std::lock_guard lock(read_write_mutex); + + const auto my_gen_number = generation_number++; + // TODO: It would probably be good to get a stack trace here so that + // we can better indicate which launch caused the failure. + kernel_launches[my_gen_number % max_kernel_launches] = { + launch_filename, + launch_function, + launch_linenum, + backtrace, + kernel_name, + dsa_get_device_id(), + stream_id, + my_gen_number}; + return my_gen_number; +#else + return 0; +#endif +} + +std::pair, std::vector> +ZoomKernelLaunchRegistry::snapshot() const { + // This is likely to be the longest-lasting hold on the mutex, but + // we only expect it to be called in cases where we're already failing + // and speed is no longer important + const std::lock_guard lock(read_write_mutex); + + std::vector device_assertions_data; + for (const auto& x : uvm_assertions) { + if (x) { + device_assertions_data.push_back(*x); + } else { + device_assertions_data.emplace_back(); + } + } + + return std::make_pair(device_assertions_data, kernel_launches); +} + +DeviceAssertionsData* ZoomKernelLaunchRegistry:: + get_uvm_assertions_ptr_for_current_device() { +#ifdef TORCH_USE_ZOOM_DSA + if (!enabled_at_runtime) { + return nullptr; + } + + const auto device_num = dsa_get_device_id(); + + // If we've already set up this GPU with managed memory, return a pointer to + // the managed memory. This is a lock-free quick-return path. + if (uvm_assertions.at(device_num)) { + return uvm_assertions.at(device_num).get(); + } + + // Need a lock here so there's not race-condition on creating the new device + // assertions buffer + const std::lock_guard lock(gpu_alloc_mutex); + + // If we've already set up this GPU with managed memory, return a pointer to + // the managed memory. 
This locked path ensures that the device memory is + // allocated only once + if (uvm_assertions.at(device_num)) { + return uvm_assertions.at(device_num).get(); + } + + // Otherwise, set up the GPU to be able to use the device-side assertion + // system + DeviceAssertionsData* uvm_assertions_ptr = nullptr; + + C10_ZOOM_CHECK_WO_DSA( + hipMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData))); + + C10_ZOOM_CHECK_WO_DSA(hipMemAdvise( + uvm_assertions_ptr, + sizeof(DeviceAssertionsData), + hipMemAdviseSetPreferredLocation, + hipCpuDeviceId)); + + // GPU will establish direct mapping of data in CPU memory, no page faults + // will be generated + C10_ZOOM_CHECK_WO_DSA(hipMemAdvise( + uvm_assertions_ptr, + sizeof(DeviceAssertionsData), + hipMemAdviseSetAccessedBy, + hipCpuDeviceId)); + + // Initialize the memory from the CPU; otherwise, pages may have to be created + // on demand. We think that UVM documentation indicates that first access may + // not honor preferred location, which would be bad, if true, because we want + // this memory on the host so we can access it post-assertion. Initializing + // this on the CPU helps ensure that that's where the memory will live. + *uvm_assertions_ptr = DeviceAssertionsData(); + + // Ownership and lifetime management of `uvm_assertions_ptr` now passes to the + // uvm_assertions unique_ptr vector + uvm_assertions.at(device_num).reset(uvm_assertions_ptr); + + return uvm_assertions_ptr; +#else + return nullptr; +#endif +} + +ZoomKernelLaunchRegistry& ZoomKernelLaunchRegistry::get_singleton_ref() { + static ZoomKernelLaunchRegistry launch_registry; + return launch_registry; +} + +bool ZoomKernelLaunchRegistry::has_failed() const { + for (const auto& x : uvm_assertions) { + if (x && x->assertion_count > 0) { + return true; + } + } + return false; +} + +} // namespace c10::zoom \ No newline at end of file diff --git a/c10/zoom/ZoomDeviceAssertionHost.h b/c10/zoom/ZoomDeviceAssertionHost.h new file mode 100644 index 00000000000000..867c2e626a1370 --- /dev/null +++ b/c10/zoom/ZoomDeviceAssertionHost.h @@ -0,0 +1,164 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#ifdef USE_ZOOM +#define TORCH_USE_ZOOM_DSA +#endif + +/// Number of assertion failure messages we can store. If this is too small +/// threads will fail silently. +constexpr int C10_ZOOM_DSA_ASSERTION_COUNT = 10; +constexpr int C10_ZOOM_DSA_MAX_STR_LEN = 512; + +namespace c10::zoom { + +/// Holds information about any device-side assertions that fail. +/// Held in managed memory and access by both the CPU and the GPU. +struct DeviceAssertionData { + /// Stringification of the assertion + // NOLINTNEXTLINE(*-c-arrays) + char assertion_msg[C10_ZOOM_DSA_MAX_STR_LEN]{}; + /// File the assertion was in + // NOLINTNEXTLINE(*-c-arrays) + char filename[C10_ZOOM_DSA_MAX_STR_LEN]{}; + /// Name of the function the assertion was in + // NOLINTNEXTLINE(*-c-arrays) + char function_name[C10_ZOOM_DSA_MAX_STR_LEN]{}; + /// Line number the assertion was at + int line_number{}; + /// Number uniquely identifying the kernel launch that triggered the assertion + uint32_t caller{}; + /// block_id of the thread that failed the assertion + // NOLINTNEXTLINE(*-c-arrays) + int32_t block_id[3]{}; + /// third_id of the thread that failed the assertion + // NOLINTNEXTLINE(*-c-arrays) + int32_t thread_id[3]{}; +}; + +/// Used to hold assertions generated by the device +/// Held in managed memory and access by both the CPU and the GPU. 
+struct DeviceAssertionsData { + /// Total number of assertions found; a subset of thse will be recorded + /// in `assertions` + int32_t assertion_count{}; + /// An array of assertions that will be written to in a race-free manner + // NOLINTNEXTLINE(*-c-arrays) + DeviceAssertionData assertions[C10_ZOOM_DSA_ASSERTION_COUNT]{}; +}; + +/// Use to hold info about kernel launches so that we can run kernels +/// asynchronously and still associate launches with device-side +/// assertion failures +struct ZoomKernelLaunchInfo { + /// Filename of the code where the kernel was launched from + const char* launch_filename; + /// Function from which the kernel was launched + const char* launch_function; + /// Line number of where the code was launched from + uint32_t launch_linenum; + /// Backtrace of where the kernel was launched from, only populated if + /// ZoomKernelLaunchRegistry::gather_launch_stacktrace is True + std::string launch_stacktrace; + /// Kernel that was launched + const char* kernel_name; + /// Device the kernel was launched on + int device; + /// Stream the kernel was launched on + int32_t stream; + /// A number that uniquely identifies the kernel launch + uint64_t generation_number; +}; + +/// Circular buffer used to hold information about kernel launches +/// this is later used to reconstruct how a device-side kernel assertion failure +/// occurred ZoomKernelLaunchRegistry is used as a singleton +class C10_ZOOM_API ZoomKernelLaunchRegistry { + private: + /// Assume that this is the max number of kernel launches that might ever be + /// enqueued across all streams on a single device + static constexpr int max_kernel_launches = 1024; + /// How many kernel launch infos we've inserted. Used to ensure that circular + /// queue doesn't provide false information by always increasing, but also to + /// mark where we are inserting into the queue +#ifdef TORCH_USE_ZOOM_DSA + uint64_t generation_number = 0; +#endif + /// Shared mutex between writer and accessor to ensure multi-threaded safety. + mutable std::mutex read_write_mutex; + /// Used to ensure prevent race conditions in GPU memory allocation + mutable std::mutex gpu_alloc_mutex; + /// Pointer to managed memory keeping track of device-side assertions. There + /// is one entry for each possible device the process might work with. Unused + /// entries are nullptrs. We could also use an unordered_set here, but this + /// vector design will be faster and the wasted memory is small since we + /// expect the number of GPUs per node will always be small + std::vector< + std::unique_ptr> + uvm_assertions; + /// A single circular buffer holds information about every kernel launch the + /// process makes across all devices. + std::vector kernel_launches; + bool check_env_for_enable_launch_stacktracing() const; + bool check_env_for_dsa_enabled() const; + + public: + ZoomKernelLaunchRegistry(); + /// Register a new kernel launch and obtain a generation number back to be + /// passed to the kernel + uint32_t insert( + const char* launch_filename, + const char* launch_function, + const uint32_t launch_linenum, + const char* kernel_name, + const int32_t stream_id); + /// Get copies of the kernel launch registry and each device's assertion + /// failure buffer so they can be inspected without raising race conditions + std:: + pair, std::vector> + snapshot() const; + /// Get a pointer to the current device's assertion failure buffer. If no such + /// buffer exists then one is created. 
This means that the first kernel launch + /// made on each device will be slightly slower because memory allocations are + /// required + DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device(); + /// Gets the global singleton of the registry + static ZoomKernelLaunchRegistry& get_singleton_ref(); + /// If not all devices support DSA, we disable it + const bool do_all_devices_support_managed_memory = false; + /// Whether or not to gather stack traces when launching kernels + bool gather_launch_stacktrace = false; + /// Whether or not host-side DSA is enabled or disabled at run-time + /// Note: Device-side code cannot be enabled/disabled at run-time + bool enabled_at_runtime = false; + /// Whether or not a device has indicated a failure + bool has_failed() const; +#ifdef TORCH_USE_ZOOM_DSA + const bool enabled_at_compile_time = true; +#else + const bool enabled_at_compile_time = false; +#endif +}; + +std::string c10_retrieve_device_side_assertion_info(); + +} // namespace c10::zoom + +// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH +// requires the same input arguments. We introduce the following macro to +// standardize these. +#define TORCH_DSA_KERNEL_ARGS \ + [[maybe_unused]] c10::zoom::DeviceAssertionsData *const assertions_data, \ + [[maybe_unused]] uint32_t assertion_caller_id + +// This macro can be used to pass the DSA arguments onward to another +// function +#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id \ No newline at end of file diff --git a/c10/zoom/ZoomException.cpp b/c10/zoom/ZoomException.cpp new file mode 100644 index 00000000000000..fe752478067c3b --- /dev/null +++ b/c10/zoom/ZoomException.cpp @@ -0,0 +1,88 @@ +#include +#include + +#include + +namespace c10::zoom { + +void c10_zoom_check_implementation( + const int32_t err, + const char* filename, + const char* function_name, + const int line_number, + const bool include_device_assertions) { + const auto hip_error = static_cast(err); + const auto hip_kernel_failure = include_device_assertions + ? 
c10::zoom::ZoomKernelLaunchRegistry::get_singleton_ref().has_failed()
+      : false;
+
+  if (C10_LIKELY(hip_error == hipSuccess && !hip_kernel_failure)) {
+    return;
+  }
+
+  auto error_unused C10_UNUSED = hipGetLastError();
+  (void)error_unused;
+
+  std::string check_message;
+#ifndef STRIP_ERROR_MESSAGES
+  check_message.append("Zoom error: ");
+  check_message.append(hipGetErrorString(hip_error));
+  // checks if HIP_LAUNCH_BLOCKING in HIP, unimplemented here for now
+  check_message.append(c10::zoom::get_hip_check_suffix());
+  check_message.append("\n");
+  if (include_device_assertions) {
+    check_message.append(c10_retrieve_device_side_assertion_info());
+  } else {
+    check_message.append(
+        "Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.");
+  }
+#endif
+
+  TORCH_CHECK(false, check_message);
+}
+
+} // namespace c10::zoom
+
+
+namespace at::zoom {
+  namespace blas {
+    const char* _hipblasGetErrorEnum(hipblasStatus_t error) {
+      if (error == HIPBLAS_STATUS_SUCCESS) {
+        return "HIPBLAS_STATUS_SUCCESS";
+      }
+      if (error == HIPBLAS_STATUS_NOT_INITIALIZED) {
+        return "HIPBLAS_STATUS_NOT_INITIALIZED";
+      }
+      if (error == HIPBLAS_STATUS_ALLOC_FAILED) {
+        return "HIPBLAS_STATUS_ALLOC_FAILED";
+      }
+      if (error == HIPBLAS_STATUS_INVALID_VALUE) {
+        return "HIPBLAS_STATUS_INVALID_VALUE";
+      }
+      if (error == HIPBLAS_STATUS_ARCH_MISMATCH) {
+        return "HIPBLAS_STATUS_ARCH_MISMATCH";
+      }
+      if (error == HIPBLAS_STATUS_MAPPING_ERROR) {
+        return "HIPBLAS_STATUS_MAPPING_ERROR";
+      }
+      if (error == HIPBLAS_STATUS_EXECUTION_FAILED) {
+        return "HIPBLAS_STATUS_EXECUTION_FAILED";
+      }
+      if (error == HIPBLAS_STATUS_INTERNAL_ERROR) {
+        return "HIPBLAS_STATUS_INTERNAL_ERROR";
+      }
+      if (error == HIPBLAS_STATUS_NOT_SUPPORTED) {
+        return "HIPBLAS_STATUS_NOT_SUPPORTED";
+      }
+#ifdef HIPBLAS_STATUS_LICENSE_ERROR
+      if (error == HIPBLAS_STATUS_LICENSE_ERROR) {
+        return "HIPBLAS_STATUS_LICENSE_ERROR";
+      }
+#endif
+      return "";
+    }
+  } // namespace blas
+} // namespace at::zoom
\ No newline at end of file
diff --git a/c10/zoom/ZoomException.h b/c10/zoom/ZoomException.h
new file mode 100644
index 00000000000000..aaeb140f0d76e3
--- /dev/null
+++ b/c10/zoom/ZoomException.h
@@ -0,0 +1,185 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// Note [CHECK macro]
+// ~~~~~~~~~~~~~~~~~~
+// This is a macro so that AT_ERROR can get accurate __LINE__
+// and __FILE__ information. We could split this into a short
+// macro and a function implementation if we pass along __LINE__
+// and __FILE__, but no one has found this worth doing.
+
+// Used to denote errors from the CUDA framework.
+// This needs to be declared here instead of in util/Exception.h for proper
+// conversion during hipify.
+namespace c10 { +class ZoomError : public c10::Error { + using Error::Error; +}; +} // namespace c10 + +#define C10_ZOOM_CHECK(EXPR) \ + do { \ + const hipError_t __err = EXPR; \ + c10::zoom::c10_zoom_check_implementation( \ + static_cast(__err), \ + __FILE__, \ + __func__, /* Line number data type not well-defined between \ + compilers, so we perform an explicit cast */ \ + static_cast(__LINE__), \ + true); \ + } while (0) + +#define C10_ZOOM_CHECK_WARN(EXPR) \ + do { \ + const hipError_t __err = EXPR; \ + if (C10_UNLIKELY(__err != hipSuccess)) { \ + auto error_unused C10_UNUSED = hipGetLastError(); \ + (void)error_unused; \ + TORCH_WARN("ZOOM warning: ", hipGetErrorString(__err)); \ + } \ + } while (0) + +// Indicates that a CUDA error is handled in a non-standard way +#define C10_ZOOM_ERROR_HANDLED(EXPR) EXPR + +// Intentionally ignore a CUDA error +#define C10_ZOOM_IGNORE_ERROR(EXPR) \ + do { \ + const hipError_t __err = EXPR; \ + if (C10_UNLIKELY(__err != hipSuccess)) { \ + hipError_t error_unused C10_UNUSED = hipGetLastError(); \ + (void)error_unused; \ + } \ + } while (0) + +// Clear the last CUDA error +#define C10_ZOOM_CLEAR_ERROR() \ + do { \ + hipError_t error_unused C10_UNUSED = hipGetLastError(); \ + (void)error_unused; \ + } while (0) + +// This should be used directly after every kernel launch to ensure +// the launch happened correctly and provide an early, close-to-source +// diagnostic if it didn't. +#define C10_ZOOM_KERNEL_LAUNCH_CHECK() C10_ZOOM_CHECK(hipGetLastError()) + +/// Launches a HIP kernel appending to it all the information need to handle +/// device-side assertion failures. Checks that the launch was successful. +#define TORCH_DSA_KERNEL_LAUNCH( \ + kernel, blocks, threads, shared_mem, stream, ...) \ + do { \ + auto& launch_registry = \ + c10::zoom::ZoomKernelLaunchRegistry::get_singleton_ref(); \ + kernel<<>>( \ + __VA_ARGS__, \ + launch_registry.get_uvm_assertions_ptr_for_current_device(), \ + launch_registry.insert( \ + __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + } while (0) + +#define HIP_DRIVER_CHECK(EXPR) \ + do { \ + hipError_t __err = EXPR; \ + if (__err != hipSuccess) { \ + AT_ERROR("HIP driver error: ", static_cast(__err), "\nErrorName: ", hipGetErrorName(__err), "\nCause: ", hipGetErrorString(__err)); \ + } \ + } while (0) + +#define ZOOM_HIPRTC_CHECK(EXPR) \ + do { \ + hiprtcResult __err = EXPR; \ + if (__err != HIPRTC_SUCCESS) { \ + if (static_cast(__err) != 7) { \ + AT_ERROR("HIPRTC error: ", hiprtcGetErrorString(__err)); \ + } else { \ + AT_ERROR("HIPRTC error: HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \ + } \ + } \ + } while (0) + + +#define ZOOM_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } + + +namespace at::zoom::blas { + const char* _hipblasGetErrorEnum(hipblasStatus_t error); +} + + +#define TORCH_HIPBLAS_CHECK(EXPR) \ +do { \ + hipblasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \ + "HIP error: ", \ + at::zoom::blas::_hipblasGetErrorEnum(__err), \ + " when calling `" #EXPR "`"); \ +} while (0) + +#define TORCH_WARN_DISABLE_HIPBLASLT TORCH_WARN_ONCE("hipblasLt temporarily disabled in Zoom backend, using hipblas instead") +#define TORCH_CHECK_DISABLE_HIPBLAS_LT TORCH_CHECK(false, "Error: hipblasLt routine called, but hipblasLt is disabled in the Zoom backend") + +const char *hipsparseGetErrorString(hipsparseStatus_t status); + +#define TORCH_HIPSPARSE_CHECK(EXPR) \ + do { \ 
+    hipsparseStatus_t __err = EXPR;                             \
+    TORCH_CHECK(__err == HIPSPARSE_STATUS_SUCCESS,              \
+                "HIP error: ",                                  \
+                hipsparseGetErrorString(__err),                 \
+                " when calling `" #EXPR "`");                   \
+  } while (0)
+
+#ifdef hipsolverVersionMajor
+
+namespace at::zoom::solver {
+C10_EXPORT const char* hipsolverGetErrorMessage(hipsolverStatus_t status);
+
+constexpr const char* _hipsolver_backend_suggestion = \
+    "If you keep seeing this error, you may use " \
+    "`torch.backends.zoom.preferred_linalg_library()` to try " \
+    "linear algebra operators with other supported backends. " \
+    "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
+
+} // namespace at::zoom::solver
+
+#define TORCH_HIPSOLVER_CHECK(EXPR)                                 \
+  do {                                                              \
+    hipsolverStatus_t __err = EXPR;                                 \
+    TORCH_CHECK(                                                    \
+        __err == HIPSOLVER_STATUS_SUCCESS,                          \
+        "hipsolver error: ",                                        \
+        at::zoom::solver::hipsolverGetErrorMessage(__err),          \
+        ", when calling `" #EXPR "`. ",                             \
+        at::zoom::solver::_hipsolver_backend_suggestion);           \
+  } while (0)
+
+#else
+#define TORCH_HIPSOLVER_CHECK(EXPR) EXPR
+#endif
+
+namespace c10::zoom {
+
+/// In the event of a HIP failure, formats a nice error message about that
+/// failure and also checks for device-side assertion failures
+void c10_zoom_check_implementation(
+    const int32_t err,
+    const char* filename,
+    const char* function_name,
+    const int line_number,
+    const bool include_device_assertions);
+
+} // namespace c10::zoom
diff --git a/c10/zoom/ZoomFunctions.cpp b/c10/zoom/ZoomFunctions.cpp
new file mode 100644
index 00000000000000..8169ec38d89e97
--- /dev/null
+++ b/c10/zoom/ZoomFunctions.cpp
@@ -0,0 +1,294 @@
+#include
+#include
+
+#include
+
+namespace c10::zoom {
+
+namespace {
+// returns -1 on failure
+int32_t driver_version() {
+  int driver_version = -1;
+  C10_ZOOM_IGNORE_ERROR(hipDriverGetVersion(&driver_version));
+  return driver_version;
+}
+
+int device_count_impl(bool fail_if_no_driver) {
+  int count = 0;
+  auto err = C10_ZOOM_ERROR_HANDLED(c10::zoom::GetDeviceCount(&count));
+  if (err == hipSuccess) {
+    return count;
+  }
+  // Clear out the error state, so we don't spuriously trigger someone else.
+  // (This shouldn't really matter, since we won't be running very much CUDA
+  // code in this regime.)
+  hipError_t last_err C10_UNUSED = hipGetLastError();
+  switch (err) {
+    case hipErrorNoDevice:
+      // Zero devices is ok here
+      count = 0;
+      break;
+    case hipErrorInsufficientDriver: {
+      auto version = driver_version();
+      if (version <= 0) {
+        if (!fail_if_no_driver) {
+          // No hip driver means no devices
+          count = 0;
+          break;
+        }
+        TORCH_CHECK(
+            false,
+            "Found no ROCm driver on your system. Please check that you "
+            "have an AMD GPU and installed a driver from "
+            "https://rocm.docs.amd.com/projects/install-on-linux/en/develop/tutorial/quick-start.html#rocm-install-quick");
+      } else {
+        TORCH_CHECK(
+            false,
+            "The ROCm driver on your system is too old (found version ",
+            version,
+            "). Please update your GPU driver by downloading and installing "
+            "a new version from the URL: "
+            "https://rocm.docs.amd.com/projects/install-on-linux/en/develop/tutorial/quick-start.html#rocm-install-quick");
+      }
+    } break;
+    case hipErrorInitializationError:
+      TORCH_CHECK(
+          false,
+          "ROCm driver initialization failed, you might not "
+          "have a ROCm GPU.");
+      break;
+    case hipErrorUnknown:
+      TORCH_CHECK(
+          false,
+          "ZOOM unknown error - this may be due to an "
+          "incorrectly set up environment, e.g. changing env "
+          "variable ZOOM_VISIBLE_DEVICES after program start. 
" + "Setting the available devices to be zero."); + break; +#if C10_ASAN_ENABLED + case hipErrorMemoryAllocation: + // In ASAN mode, we know that a hipErrorMemoryAllocation error will + // pop up if compiled with hipcc (clang-hip is fine) + TORCH_CHECK( + false, + "Got 'out of memory' error while trying to initialize ZOOM. " + "ZOOM with hipcc does not work well with ASAN and it's probably " + "the reason. We will simply shut down HIP support. If you " + "would like to use GPUs, turn off ASAN."); + break; +#endif // C10_ASAN_ENABLED + default: + TORCH_CHECK( + false, + "Unexpected error from hipGetDeviceCount(). Did you run " + "some hip functions before calling NumZoomDevices() " + "that might have already set an error? Error ", + err, + ": ", + hipGetErrorString(err)); + } + return count; +} +} // namespace + +DeviceIndex device_count() noexcept { + // initialize number of devices only once + static int count = []() { + try { + auto result = device_count_impl(/*fail_if_no_driver=*/false); + TORCH_INTERNAL_ASSERT( + result <= std::numeric_limits::max(), + "Too many ROCm devices, DeviceIndex overflowed"); + return result; + } catch (const c10::Error& ex) { + // We don't want to fail, but still log the warning + // msg() returns the message without the stack trace + TORCH_WARN("ZOOM initialization: ", ex.msg()); + return 0; + } + }(); + return static_cast(count); +} + +DeviceIndex device_count_ensure_non_zero() { + // Call the implementation every time to throw the exception + int count = device_count_impl(/*fail_if_no_driver=*/true); + // Zero gpus doesn't produce a warning in `device_count` but we fail here + TORCH_CHECK(count, "No ROCm GPUs are available"); + TORCH_INTERNAL_ASSERT( + count <= std::numeric_limits::max(), + "Too many ROCm devices, DeviceIndex overflowed"); + return static_cast(count); +} + +DeviceIndex current_device() { + DeviceIndex cur_device = -1; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&cur_device)); + return cur_device; +} + +void set_device(DeviceIndex device) { + C10_ZOOM_CHECK(c10::zoom::SetDevice(device)); +} + +void device_synchronize() { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_device_synchronization(c10::DeviceType::PrivateUse1); + } + C10_ZOOM_CHECK(hipDeviceSynchronize()); +} + +// this function has to be called from callers performing cuda synchronizing +// operations, to raise proper error or warning +void warn_or_error_on_sync() { + if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_ERROR) { + TORCH_CHECK(false, "called a synchronizing HIP operation"); + } else if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_WARN) { + TORCH_WARN("called a synchronizing HIP operation"); + } +} + +std::optional getDeviceIndexWithPrimaryContext() { + // check current device first + auto current_device_index = current_device(); + if (current_device_index >= 0) { + if (hasPrimaryContext(current_device_index)) { + return current_device_index; + } + } + for (const auto device_index : c10::irange(c10::zoom::device_count())) { + if (device_index == current_device_index) + continue; + if (hasPrimaryContext(device_index)) { + return device_index; + } + } + return c10::nullopt; +} + +namespace _internal { +bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) { + TORCH_CHECK(false, "Should never been called - did you remember to lazyInitPrivateUse1()?"); +} +bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext; + +// Private api to be called from 
CUDAHooks.cpp +void setHasPrimaryContext(bool (*func)(DeviceIndex)) { + hasPrimaryContext = func ? func : dummyHasPrimaryContext; +} +} // namespace _internal + +bool hasPrimaryContext(DeviceIndex device_index) { + return _internal::hasPrimaryContext(device_index); +} + +// Wrappers for raw CUDA device management functions +hipError_t GetDeviceCount(int* dev_count) { + return hipGetDeviceCount(dev_count); +} + +// This is a codepath for CUDA 12 that comes with a critical change in behavior +// of `cudaSetDevice`. Unlike to previous CUDA versions that allocate context +// lazily CUDA 12.x eagerly allocates primary context the moment `cudaSetDevice` +// is called. This can lead to dramatic consequences and pollute the device +// memory in distributed runs. To avoid unnecessary context creation a new +// function called `MaybeSetDevice` was introduced. This function is to be +// called in device guard destructor and at the exit of torch.cuda.device +// context manager. The behavior of `MaybeSetDevice` is quite simple, it calls +// to `cudaSetDevice` if context already exist or if context was not allocated +// on targeted device it simply saves the device index. This way we can keep +// PyTorch backward compatible for applications like this: +// +// ``` +// import torch +// x = torch.empty(1, device=“cuda:1”) # no CUDA context on cuda:0 after this +// call y = torch.empty(1, device=“cuda”) # CUDA context is created on cuda:0 +// ``` + +thread_local DeviceIndex targetDeviceIndex = -1; + +hipError_t GetDevice(DeviceIndex* device) { + if (targetDeviceIndex >= 0) { + *device = targetDeviceIndex; + return hipSuccess; + } + int tmp_device = -1; + auto err = hipGetDevice(&tmp_device); + if (err == hipSuccess) { + TORCH_INTERNAL_ASSERT( + tmp_device >= 0 && + tmp_device <= std::numeric_limits::max(), + "hipGetDevice returns invalid device ", + tmp_device); + *device = static_cast(tmp_device); + } + return err; +} + +hipError_t SetDevice(DeviceIndex device) { + TORCH_CHECK(device >= 0, "device id must be positive!", device); + targetDeviceIndex = -1; + int cur_device = -1; + C10_ZOOM_CHECK(hipGetDevice(&cur_device)); + if (device == cur_device) { + return hipSuccess; + } + return hipSetDevice(device); +} + +hipError_t MaybeSetDevice(DeviceIndex device) { + if (hasPrimaryContext(device)) { + return c10::zoom::SetDevice(device); + } + targetDeviceIndex = device; + return hipSuccess; +} + +// This function always initializes the CUDA context +// on to_device +DeviceIndex ExchangeDevice(DeviceIndex to_device) { + auto cur_device = targetDeviceIndex; + targetDeviceIndex = -1; + if (cur_device < 0) { + int tmp_device = -1; + C10_ZOOM_CHECK(hipGetDevice(&tmp_device)); + cur_device = static_cast(tmp_device); + if (to_device == cur_device) { + return cur_device; + } + } + C10_ZOOM_CHECK(hipSetDevice(to_device)); + return cur_device; +} + +// This function does not initialize the CUDA context +// on to_device if it does not already exist +DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) { + int tmp_cur_device = -1; + C10_ZOOM_CHECK(hipGetDevice(&tmp_cur_device)); + TORCH_INTERNAL_ASSERT( + tmp_cur_device >= 0 && + tmp_cur_device <= std::numeric_limits::max(), + "hipGetDevice returns invalid device ", + tmp_cur_device); + auto cur_device = static_cast(tmp_cur_device); + if (to_device == tmp_cur_device) { + return cur_device; + } + if (hasPrimaryContext(to_device)) { + C10_ZOOM_CHECK(hipSetDevice(to_device)); + } else { + targetDeviceIndex = to_device; + } + return cur_device; +} + +void SetTargetDevice() { 
+  if (targetDeviceIndex >= 0) {
+    C10_ZOOM_CHECK(c10::zoom::SetDevice(targetDeviceIndex));
+  }
+}
+
+
+} // namespace c10::zoom
\ No newline at end of file
diff --git a/c10/zoom/ZoomFunctions.h b/c10/zoom/ZoomFunctions.h
new file mode 100644
index 00000000000000..eb5ffa640967cd
--- /dev/null
+++ b/c10/zoom/ZoomFunctions.h
@@ -0,0 +1,112 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace c10::zoom {
+
+// NB: In the past, we were inconsistent about whether or not this reported
+// an error if there were driver problems or not. Based on experience
+// interacting with users, it seems that people basically ~never want this
+// function to fail; it should just return zero if things are not working.
+// Oblige them.
+// It still might log a warning for the user the first time it's invoked.
+C10_ZOOM_API DeviceIndex device_count() noexcept;
+
+// Version of device_count that throws if no devices are detected
+C10_ZOOM_API DeviceIndex device_count_ensure_non_zero();
+
+C10_ZOOM_API DeviceIndex current_device();
+
+C10_ZOOM_API void set_device(DeviceIndex device);
+
+C10_ZOOM_API void device_synchronize();
+
+C10_ZOOM_API void warn_or_error_on_sync();
+
+// Raw CUDA device management functions
+C10_ZOOM_API hipError_t GetDeviceCount(int* dev_count);
+
+C10_ZOOM_API hipError_t GetDevice(DeviceIndex* device);
+
+C10_ZOOM_API hipError_t SetDevice(DeviceIndex device);
+
+C10_ZOOM_API hipError_t MaybeSetDevice(DeviceIndex device);
+
+C10_ZOOM_API DeviceIndex ExchangeDevice(DeviceIndex device);
+
+C10_ZOOM_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);
+
+C10_ZOOM_API void SetTargetDevice();
+
+enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
+
+// this is a holder for c10 global state (similar to at GlobalContext)
+// currently it's used to store cuda synchronization warning state,
+// but can be expanded to hold other related global state, e.g.
to +// record stream usage +class WarningState { + public: + void set_sync_debug_mode(SyncDebugMode l) { + sync_debug_mode = l; + } + + SyncDebugMode get_sync_debug_mode() { + return sync_debug_mode; + } + + private: + SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED; +}; + +C10_ZOOM_API __inline__ WarningState& warning_state() { + static WarningState warning_state_; + return warning_state_; +} +// the subsequent functions are defined in the header because for performance +// reasons we want them to be inline +C10_ZOOM_API void __inline__ memcpy_and_sync( + void* dst, + const void* src, + int64_t nbytes, + hipMemcpyKind kind, + hipStream_t stream) { + if (C10_UNLIKELY( + warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { + warn_or_error_on_sync(); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_stream_synchronization( + c10::DeviceType::PrivateUse1, reinterpret_cast(stream)); + } + + #if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301) + C10_ZOOM_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); + #else + C10_ZOOM_CHECK(hipMemcpyAsync(dst, src, nbytes, kind, stream)); + C10_ZOOM_CHECK(hipStreamSynchronize(stream)); + #endif + +} + +C10_ZOOM_API void __inline__ stream_synchronize(hipStream_t stream) { + if (C10_UNLIKELY( + warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { + warn_or_error_on_sync(); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_stream_synchronization( + c10::DeviceType::PrivateUse1, reinterpret_cast(stream)); + } + C10_ZOOM_CHECK(hipStreamSynchronize(stream)); +} + +C10_ZOOM_API bool hasPrimaryContext(DeviceIndex device_index); +C10_ZOOM_API std::optional getDeviceIndexWithPrimaryContext(); + +} // namespace c10::zoom \ No newline at end of file diff --git a/c10/zoom/ZoomGuard.h b/c10/zoom/ZoomGuard.h new file mode 100644 index 00000000000000..1f549e9cd59bf2 --- /dev/null +++ b/c10/zoom/ZoomGuard.h @@ -0,0 +1,301 @@ +#pragma once + +#include +#include +#include +// #include +#include + +namespace c10::zoom { + +// This code is kind of boilerplatey. See Note [Whither the DeviceGuard +// boilerplate] + +/// A variant of DeviceGuard that is specialized for HIP. It accepts +/// integer indices (interpreting them as HIP devices) and is a little +/// more efficient than DeviceGuard (it compiles to straight line +/// hipSetDevice/hipGetDevice calls); however, it can only be used +/// from code that links against HIP directly. +struct ZoomGuard { + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit ZoomGuard() = delete; + + /// Set the current HIP device to the passed device index. + explicit ZoomGuard(DeviceIndex device_index) : guard_(device_index) {} + + /// Sets the current HIP device to the passed device. Errors if the passed + /// device is not a HIP device. + explicit ZoomGuard(Device device) : guard_(device) {} + + // Copy is not allowed + ZoomGuard(const ZoomGuard&) = delete; + ZoomGuard& operator=(const ZoomGuard&) = delete; + + // Move is not allowed (there is no uninitialized state) + ZoomGuard(ZoomGuard&& other) = delete; + ZoomGuard& operator=(ZoomGuard&& other) = delete; + + /// Sets the HIP device to the given device. Errors if the given device + /// is not a HIP device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the HIP device to the given device. 
Errors if the given device + /// is not a HIP device. (This method is provided for uniformity with + /// DeviceGuard). + void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the HIP device to the given device index. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set upon construction of the guard + Device original_device() const { + return guard_.original_device(); + } + + /// Returns the last device that was set via `set_device`, if any, otherwise + /// the device passed during construction. + Device current_device() const { + return guard_.current_device(); + } + + private: + /// The guard for the current device. + c10::impl::InlineDeviceGuard guard_; +}; + +/// A variant of OptionalDeviceGuard that is specialized for HIP. See +/// ZoomGuard for when you can use this. +struct OptionalZoomGuard { + /// Create an uninitialized OptionalZoomGuard. + explicit OptionalZoomGuard() : guard_() {} + + /// Set the current HIP device to the passed Device, if it is not nullopt. + explicit OptionalZoomGuard(optional device_opt) + : guard_(device_opt) {} + + /// Set the current HIP device to the passed device index, if it is not + /// nullopt + explicit OptionalZoomGuard(optional device_index_opt) + : guard_(device_index_opt) {} + + // Copy is not allowed + OptionalZoomGuard(const OptionalZoomGuard&) = delete; + OptionalZoomGuard& operator=(const OptionalZoomGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalZoomGuard(OptionalZoomGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalZoomGuard& operator=(OptionalZoomGuard&& other) = delete; + + /// Sets the HIP device to the given device, initializing the guard if it + /// is not already initialized. Errors if the given device is not a HIP + /// device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the HIP device to the given device, initializing the guard if it is + /// not already initialized. Errors if the given device is not a HIP device. + /// (This method is provided for uniformity with OptionalDeviceGuard). + void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the HIP device to the given device index, initializing the guard if + /// it is not already initialized. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set immediately prior to initialization of the + /// guard, or nullopt if the guard is uninitialized. + optional original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device, if the guard is initialized, + /// or nullopt if the guard is uninitialized. + optional current_device() const { + return guard_.current_device(); + } + + /// Restore the original HIP device, resetting this guard to uninitialized + /// state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + +/// A variant of StreamGuard that is specialized for HIP. See ZoomGuard +/// for when you can use this. 
+struct ZoomStreamGuard { + /// No default constructor, see Note [Omitted default constructor from RAII] + explicit ZoomStreamGuard() = delete; + + /// Set the current HIP device to the device associated with the passed + /// stream, and set the current HIP stream on that device to the passed + /// stream. Errors if the Stream is not a HIP stream. + explicit ZoomStreamGuard(Stream stream) : guard_(stream) {} + + /// Copy is disallowed + ZoomStreamGuard(const ZoomStreamGuard&) = delete; + ZoomStreamGuard& operator=(const ZoomStreamGuard&) = delete; + + /// Move is disallowed, as ZoomStreamGuard does not have an uninitialized + /// state, which is required for moves on types with nontrivial destructors. + ZoomStreamGuard(ZoomStreamGuard&& other) = delete; + ZoomStreamGuard& operator=(ZoomStreamGuard&& other) = delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Errors if the stream passed is not a HIP stream. + /// + /// NOTE: this implementation may skip some stream/device setting if + /// it can prove that it is unnecessary. + /// + /// WARNING: reset_stream does NOT preserve previously set streams on + /// different devices. If you need to set streams on multiple devices + /// on HIP, use ZoomMultiStreamGuard instead. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the HIP stream that was set at the time the guard was + /// constructed. + ZoomStream original_stream() const { + return ZoomStream(ZoomStream::UNCHECKED, guard_.original_stream()); + } + + /// Returns the most recent HIP stream that was set using this device guard, + /// either from construction, or via set_stream. + ZoomStream current_stream() const { + return ZoomStream(ZoomStream::UNCHECKED, guard_.current_stream()); + } + + /// Returns the most recent HIP device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return guard_.current_device(); + } + + /// Returns the HIP device that was set at the most recent reset_stream(), + /// or otherwise the device at construction time. + Device original_device() const { + return guard_.original_device(); + } + + private: + c10::impl::InlineStreamGuard guard_; +}; + +/// A variant of OptionalStreamGuard that is specialized for HIP. See +/// ZoomGuard for when you can use this. +struct OptionalZoomStreamGuard { + /// Create an uninitialized guard. + explicit OptionalZoomStreamGuard() : guard_() {} + + /// Set the current HIP device to the device associated with the passed + /// stream, and set the current HIP stream on that device to the passed + /// stream. Errors if the Stream is not a HIP stream. + explicit OptionalZoomStreamGuard(Stream stream) : guard_(stream) {} + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream, + /// if the passed stream is not nullopt. 
+ explicit OptionalZoomStreamGuard(optional stream_opt) + : guard_(stream_opt) {} + + /// Copy is disallowed + OptionalZoomStreamGuard(const OptionalZoomStreamGuard&) = delete; + OptionalZoomStreamGuard& operator=(const OptionalZoomStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalZoomStreamGuard(OptionalZoomStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalZoomStreamGuard& operator=(OptionalZoomStreamGuard&& other) = delete; + + /// Resets the currently set HIP stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Initializes the guard if it was not previously initialized. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the HIP stream that was set at the time the guard was most + /// recently initialized, or nullopt if the guard is uninitialized. + optional original_stream() const { + auto r = guard_.original_stream(); + if (r.has_value()) { + return make_optional(ZoomStream(ZoomStream::UNCHECKED, r.value())); + } else { + return nullopt; + } + } + + /// Returns the most recent HIP stream that was set using this stream guard, + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + optional current_stream() const { + auto r = guard_.current_stream(); + if (r.has_value()) { + return make_optional(ZoomStream(ZoomStream::UNCHECKED, r.value())); + } else { + return nullopt; + } + } + + /// Restore the original HIP device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalStreamGuard guard_; +}; + +/// A variant of MultiStreamGuard that is specialized for HIP. +struct ZoomMultiStreamGuard { + explicit ZoomMultiStreamGuard(ArrayRef streams) + : guard_(unwrapStreams(streams)) {} + + /// Copy is disallowed + ZoomMultiStreamGuard(const ZoomMultiStreamGuard&) = delete; + ZoomMultiStreamGuard& operator=(const ZoomMultiStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + ZoomMultiStreamGuard(ZoomMultiStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + ZoomMultiStreamGuard& operator=(ZoomMultiStreamGuard&& other) = delete; + + private: + c10::impl::InlineMultiStreamGuard guard_; + + static std::vector unwrapStreams(ArrayRef zoomStreams) { + std::vector streams; + streams.reserve(zoomStreams.size()); + for (const ZoomStream& zoomStream : zoomStreams) { + streams.push_back(zoomStream); + } + return streams; + } +}; + +} // namespace c10::zoom \ No newline at end of file diff --git a/c10/zoom/ZoomMacros.h b/c10/zoom/ZoomMacros.h new file mode 100644 index 00000000000000..21492ccd4847dc --- /dev/null +++ b/c10/zoom/ZoomMacros.h @@ -0,0 +1,41 @@ +#pragma once + +// See c10/macros/Export.h for a detailed explanation of what the function +// of these macros are. We need one set of macros for every separate library +// we build. 
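The export/import macros defined just below follow the usual pattern described in c10/macros/Export.h. As a hedged illustration (zoom_symbol_example is hypothetical and not part of the patch), a public symbol of libc10_zoom would be declared and defined roughly like this:

// In a public c10/zoom header:
#include <c10/zoom/ZoomMacros.h>

C10_ZOOM_API void zoom_symbol_example();

// In a .cpp compiled into libc10_zoom (built with -DC10_ZOOM_BUILD_MAIN_LIB,
// so C10_ZOOM_API expands to the export attribute):
void zoom_symbol_example() {}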
+
+#ifdef _WIN32
+#if defined(C10_HIP_BUILD_SHARED_LIBS)
+#define C10_ZOOM_EXPORT __declspec(dllexport)
+#define C10_ZOOM_IMPORT __declspec(dllimport)
+#else
+#define C10_ZOOM_EXPORT
+#define C10_ZOOM_IMPORT
+#endif
+#else // _WIN32
+#if defined(__GNUC__)
+#define C10_ZOOM_EXPORT __attribute__((__visibility__("default")))
+#else // defined(__GNUC__)
+#define C10_ZOOM_EXPORT
+#endif // defined(__GNUC__)
+#define C10_ZOOM_IMPORT C10_ZOOM_EXPORT
+#endif // _WIN32
+
+// This one is being used by libc10_zoom.so
+#ifdef C10_ZOOM_BUILD_MAIN_LIB
+#define C10_ZOOM_API C10_ZOOM_EXPORT
+#else
+#define C10_ZOOM_API C10_ZOOM_IMPORT
+#endif
+
+/**
+ * The maximum number of GPUs that we recognize. Increasing this beyond the
+ * initial limit of 16 broke Caffe2 testing, hence the ifdef guards.
+ * This value cannot be more than 128 because our DeviceIndex is a uint8_t.
+ */
+#ifdef FBCODE_CAFFE2
+// fbcode depends on this value being 16
+#define C10_COMPILE_TIME_MAX_GPUS 16
+#else
+#define C10_COMPILE_TIME_MAX_GPUS 120
+#endif
\ No newline at end of file
diff --git a/c10/zoom/ZoomMallocAsyncAllocator.cpp b/c10/zoom/ZoomMallocAsyncAllocator.cpp
new file mode 100644
index 00000000000000..938be05125e8f5
--- /dev/null
+++ b/c10/zoom/ZoomMallocAsyncAllocator.cpp
@@ -0,0 +1,899 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace c10::zoom::ZoomCachingAllocator::ZoomMallocAsync {
+
+// CUDA device allocator that uses hipMallocAsync to implement
+// the same interface as ZoomCachingAllocator.cpp.
+
+// Designed to be safe for CUDA graph capture.
+// Interactions with CUDA graph capture are mediated by
+// notifyCaptureBegin
+// notifyCaptureAboutToEnd
+// notifyCaptureEnded
+// notifyCaptureDestroy
+
+// Implementation details, not declared in ZoomCachingAllocator.h
+namespace {
+
+// General helpers
+
+struct UsageStream {
+  hipStream_t stream;
+  c10::DeviceIndex device;
+  UsageStream() = default;
+  UsageStream(hipStream_t s, c10::DeviceIndex d) : stream(s), device(d) {}
+  UsageStream(const UsageStream& us) = default;
+  UsageStream(const UsageStream&& us) noexcept
+      : stream(us.stream), device(us.device) {}
+  UsageStream& operator=(UsageStream other) {
+    stream = other.stream;
+    device = other.device;
+    return *this;
+  }
+};
+
+bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
+  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
+}
+
+struct UsageStreamHash {
+  size_t operator()(const UsageStream& us) const noexcept {
+    return std::hash{}(us.stream) + size_t(us.device);
+  }
+};
+
+struct PtrUsage {
+  // recorded_streams holds side usage streams added by record_stream calls.
+  // In other words, it does NOT include the original creation stream.
+  ska::flat_hash_set recorded_streams;
+  UsageStream creation_stream{};
+  uint64_t size;
+  bool captured;
+  PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
+};
+
+int device_count = 0;
+// these don't need to be c10::once_flags as in CUDAGeneratorImpl.cpp
+// because they'll only be flipped by functions that have locked the mutex.
+std::vector devs_initialized_flags;
+std::vector dummy_unifying_free_streams;
+
+// Possible micro-optimization:
+// Some accesses to ptr_info are read-only.
+// We could let those be concurrent with a shared_mutex and
+// have concurrent calls take a shared_lock.
+// Keeping it simple with an ordinary mutex for now.
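The closing comment above mentions a possible reader/writer-lock micro-optimization for read-only ptr_info lookups. A minimal sketch of that alternative, for illustration only; the patch itself keeps a single std::mutex, and `sizes` below is just a stand-in for ptr_info:

#include <cstddef>
#include <shared_mutex>
#include <unordered_map>

std::shared_mutex info_mutex;             // would replace general_mutex
std::unordered_map<void*, size_t> sizes;  // stand-in for ptr_info

size_t lookup_size(void* p) {             // read-only path: shared lock
  std::shared_lock<std::shared_mutex> lk(info_mutex);
  auto it = sizes.find(p);
  return it == sizes.end() ? 0 : it->second;
}

void record_size(void* p, size_t s) {     // mutating path: exclusive lock
  std::unique_lock<std::shared_mutex> lk(info_mutex);
  sizes[p] = s;
}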
+std::mutex general_mutex; + +/** + * Note [Avoid freeing uncaptured ptrs during CUDA graph capture] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * During CUDA graph capture, it's illegal to call hipFreeAsync + * on a pointer that came from a non-captured hipMallocAsync. + * Unfortunately, Python being what it is, it's impossible to be + * sure no uncaptured tensor will ever have its destructor called + * in a capturing region. + * We avoid errors by + * 1. remembering if allocated pointers were captured or uncaptured + * 2. during capture, if we detect an attempt to free an uncaptured + * allocation on a capturing stream, don't free it immediately, + * just remember it and defer its hipFreeAsync call to after + * the end of capture (specifically, to notifyCaptureEnded). + */ + +using PtrInfo = ska::flat_hash_map; +PtrInfo ptr_info; +std::vector ungraphed_ptrs_defer_free_until_no_capture; + +// These two help setMemoryFraction limit the amount of memory +// used by PyTorch in particular (as opposed to other libraries +// in the same process that might be sharing the same hipMemPool_t). +std::vector pytorch_used_bytes; +std::vector pytorch_memory_limits; + +// Graph-specific helpers + +/** + * Note [Avoid dangling free streams during CUDA graph capture] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * During capture, all stream dependencies must branch out from + * the stream on which capture began and rejoin this initial stream + * before capture ends. + * The user rigs desired forking and joining with event waits. + * But it's hard to be sure when tensor destructors get called relative + * to the final joins. + * For example, suppose a user + * forks work stream B from initial capture stream A + * creates a tensor T in B + * joins by syncing A with B + * ends capture. + * All well and good, right? Maybe not: maybe T went out of scope + * and its destructor got called AFTER the rejoin, leaving the graph with + * "unjoined work": a dangling hipFreeAsync node in stream B. + * Ensuring that all tensor destructors for all side stream tensors + * are called before side streams rejoin the main stream is + * difficult. The user might have to add a bunch of explicit + * "del"s at the right spots in code that was fine for ordinary + * eager execution. + * Fortunately, we can spare the user this burden: + * during capture, we remember _all_ free streams, + * and manually rejoin them with the capture stream during + * notifyCaptureAboutToEnd. + * This approach is heavy-handed, but hopefully capture only needs to + * happen once, so we don't mind being heavy-handed. + * + * TODO: If, someday, we augment the graph bindings to support recapture + * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#whole-graph-update + * (eg, as a way to accommodate dynamic params) we should think more + * carefully about the CPU overhead of remembering and rejoining + * all free streams during capture. Maybe it's not a big deal. 
+ */ +std::unordered_set capture_free_streams; +bool capture_underway = false; + +// Implementation functions + +// Assumes the caller holds general_mutex +inline void lazy_init_device(c10::DeviceIndex device) { + if (!devs_initialized_flags[device]) { + ZoomGuard g(device); + + // See "Retaining memory in the pool" here: + // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/ + hipMemPool_t mempool = nullptr; + C10_ZOOM_CHECK(hipDeviceGetDefaultMemPool(&mempool, device)); + uint64_t threshold = UINT64_MAX; + C10_ZOOM_CHECK(hipMemPoolSetAttribute( + mempool, hipMemPoolAttrReleaseThreshold, &threshold)); + + // I think all these are on by default, but I want to enable them + // explicitly to ensure awareness. + int enable = 1; + C10_ZOOM_CHECK(hipMemPoolSetAttribute( + mempool, hipMemPoolReuseFollowEventDependencies, &enable)); + C10_ZOOM_CHECK(hipMemPoolSetAttribute( + mempool, hipMemPoolReuseAllowOpportunistic, &enable)); + C10_ZOOM_CHECK(hipMemPoolSetAttribute( + mempool, hipMemPoolReuseAllowInternalDependencies, &enable)); + + // Grabs a stream from the current device to use as the "unifier" free + // stream for allocations that end up used on multiple streams. + const auto dufs = getStreamFromPool(); + dummy_unifying_free_streams[device] = + UsageStream(dufs.stream(), dufs.device_index()); + + pytorch_used_bytes[device] = 0; + pytorch_memory_limits[device] = UINT64_MAX; + + devs_initialized_flags[device] = true; + } +} + +inline void sync_raw(hipStream_t dependency, hipStream_t dependent) { + // ZoomCachingAllocator.cpp uses raw hip events, as do we. + hipEvent_t event = nullptr; + C10_ZOOM_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming)); + C10_ZOOM_CHECK(hipEventRecord(event, dependency)); + C10_ZOOM_CHECK(hipStreamWaitEvent(dependent, event, 0)); + C10_ZOOM_CHECK(hipEventDestroy(event)); +} + +// Assumes the caller holds general_mutex +inline void free_impl(PtrInfo::iterator& it) { + // Possible micro-optimization: If we did a value-copy here, we could move + // ptr_info.erase(it) up here and drop the lock immediately. + const auto& recorded_streams = it->second.recorded_streams; + const auto& creation_stream = it->second.creation_stream; + + // If the usage stream is a null (default) stream, + // hipFreeAsync infers the device from the ambient context, + // so we need to set the right ambient context. + ZoomGuard g(creation_stream.device); + + if (recorded_streams.empty()) { + // ptr was only used on one stream, which must have been + // the original allocation stream. + // Frees ptr in the original allocation stream. + + C10_ZOOM_CHECK(hipFreeAsync(it->first, creation_stream.stream)); + + if (C10_UNLIKELY(capture_underway)) { + // See Note [Avoid dangling free streams during CUDA graph capture] + capture_free_streams.insert(creation_stream); + } + } else { + // ptr was used on many streams. We don't know which was the most recent. + // There could even have been multiple most recent usage streams acting + // on different regions of the memory. + // But hipFreeAsync only accepts a single most recent usage stream. + // We can still safely free ptr with a trick: + // Use a dummy "unifying stream", sync the unifying stream with all of + // ptr's usage streams, and pass the dummy stream to hipFreeAsync. + + // Retrieves the dummy "unifier" stream from the device + // on which the pointer was originally allocated. 
+ auto dummy_unifying_free_stream = + dummy_unifying_free_streams[creation_stream.device]; + TORCH_INTERNAL_ASSERT( + dummy_unifying_free_stream.device == creation_stream.device); + + // we're already on creation_stream.device, no need to re-guard + sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream); + + // The number of usage streams is typically small (low single digits) + for (const auto& recorded_stream : recorded_streams) { + // Logic here accommodates the chance some of the usage streams were on + // other devices, which is possible if some usage kernels accessed the + // memory via p2p. + + // hipEventRecord requires that the input event and stream are on the + // same device. + ZoomGuard g_usage(recorded_stream.device); + + sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream); + } + + // Frees ptr in the dummy "unifier" stream. + C10_ZOOM_CHECK(hipFreeAsync(it->first, dummy_unifying_free_stream.stream)); + // At this point, unless dummy_unifying_free_stream happens to alias some + // future user stream, the allocation is only available for "opportunistic" + // reuse, ie, if the CPU sees dummy_unifying_free_stream has reached the + // point that all events recorded on all usage streams have resolved from + // the CPU's perspective. In theory, we could remove the need for the driver + // to do this tracking by e.g. replacing + // hipStreamWaitEvent(dummy_unifying_free_stream.stream, event); + // with + // hipStreamWaitEvent(creation_stream.stream, event); + // then hipFreeAsyncing straight back into creation_stream.stream, + // but this forces a potentially false dependency of creation_stream.stream + // on all the recorded_streams. + + if (C10_UNLIKELY(capture_underway)) { + // See Note [Avoid dangling free streams during CUDA graph capture] + capture_free_streams.emplace( + dummy_unifying_free_stream.stream, dummy_unifying_free_stream.device); + } + } + + pytorch_used_bytes[creation_stream.device] -= it->second.size; + + ptr_info.erase(it); +} + +void freeAsync(void* ptr) { + std::lock_guard lk(general_mutex); + + auto err = hipGetLastError(); + C10_ZOOM_CHECK(err); + auto it = ptr_info.find(ptr); + TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info"); + + if (C10_UNLIKELY(capture_underway)) { + if (!it->second.captured) { + TORCH_WARN_ONCE( + "freeAsync() was called on an uncaptured allocation during graph capture " + "(address = ", + ptr, + "). This may be benign, for example, a Python tensor in the capture " + "might happen to shadow (use the same name as) an unrelated temporary " + "tensor from somewhere before capture, pushing the earlier tensor " + "out of scope. " + "However, if the tensor we're freeing here IS used by the capture, " + "freeing it is an error, and may cause illegal memory accesses or " + "memory corruption during graph replay."); + // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture] + // Remembers the raw pointer, not the iterator. + // This forces notifyCaptureEnded to do another lookup, + // but avoids the risk the iterator might be invalidated + // between now and then. 
+ ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr); + return; + } + } else if (C10_UNLIKELY(it->second.captured)) { + TORCH_WARN( + "Attempting uncaptured free of a captured allocation with address ", + ptr, + "\nThis is technically allowed, but may indicate you are losing " + "the last user-visible tensor through which the allocation can " + "be accessed, so you'll have no way to view the data after " + "future replays of the owning graph."); + } + + free_impl(it); +} + +// Symmetric with NativeCachingAllocator::malloc for now, +// although I don't think we absolutely need the symmetry. +void mallocAsync( + void** devPtr, + c10::DeviceIndex device, + size_t size, + hipStream_t stream) { + TORCH_INTERNAL_ASSERT( + 0 <= device && device < device_count, + "Invalid device index ", + device, + ": did you call init?"); + + // If stream is a null (default) stream, + // hipMallocAsync infers the device from the ambient context, + // so we need to set the right ambient context. + ZoomGuard g(device); + + std::lock_guard lk(general_mutex); + + if (!capture_underway && + !ungraphed_ptrs_defer_free_until_no_capture.empty()) { + // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture] + for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) { + auto it = ptr_info.find(ptr); + TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info"); + free_impl(it); + } + + ungraphed_ptrs_defer_free_until_no_capture.clear(); + } + + lazy_init_device(device); + + // Defensively checks for preexisting CUDA error state. + auto err = hipGetLastError(); + C10_ZOOM_CHECK(err); + + // TODO: Could we avoid calling hipMallocAsync while holding general_mutex, + // perhaps by letting lazy_init_device use separate once_flags or an internal + // static initializer? + if (pytorch_used_bytes[device] + size > pytorch_memory_limits[device]) { + err = hipErrorMemoryAllocation; + } else { + err = hipMallocAsync(devPtr, size, stream); + } + + if (err == hipErrorMemoryAllocation) { + // Clears CUDA's internal error state so the user, if desired, can catch the + // OOM exception, free some stuff on the script side, and retry the + // allocation. This aligns with the behavior of alloc_block in + // ZoomCachingAllocator.cpp. + (void)hipGetLastError(); // clear CUDA error + size_t device_free = 0; + size_t device_total = 0; + C10_ZOOM_CHECK(hipMemGetInfo(&device_free, &device_total)); + TORCH_CHECK_WITH( + OutOfMemoryError, + false, + "Allocation on device ", + device, + " would exceed allowed memory. (out of memory)", + "\nCurrently allocated : ", + format_size(pytorch_used_bytes[device]), + "\nRequested : ", + format_size(size), + "\nDevice limit : ", + format_size(device_total), + "\nFree (according to CUDA): ", + format_size(device_free), + "\nPyTorch limit (set by user-supplied memory fraction)" + "\n : ", + format_size(pytorch_memory_limits[device])); + } else { + C10_ZOOM_CHECK(err); + } + + auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway)); + TORCH_INTERNAL_ASSERT( + inserted.second, + "address returned by hipMallocAsync already exists " + "in ptr_info"); + + inserted.first->second.creation_stream = {stream, device}; + + pytorch_used_bytes[device] += size; +} + +} // anonymous namespace + +void local_raw_delete(void* ptr); + +// Same pattern as ZoomCachingAllocator.cpp. 
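Before the allocator struct below, a minimal standalone sketch of the HIP stream-ordered allocation pattern that mallocAsync/freeAsync above build on. It assumes a ROCm version that provides hipMallocAsync/hipFreeAsync and reduces error handling to asserts for brevity:

#include <hip/hip_runtime.h>
#include <cassert>

void stream_ordered_roundtrip() {
  hipStream_t stream = nullptr;
  assert(hipStreamCreate(&stream) == hipSuccess);

  void* p = nullptr;
  // The allocation becomes usable once prior work on `stream` completes.
  assert(hipMallocAsync(&p, 1 << 20, stream) == hipSuccess);
  // ... enqueue kernels that use `p` on `stream` here ...
  // The free is also ordered on `stream`; the memory returns to the pool
  // once the preceding work on `stream` has finished.
  assert(hipFreeAsync(p, stream) == hipSuccess);

  assert(hipStreamSynchronize(stream) == hipSuccess);
  assert(hipStreamDestroy(stream) == hipSuccess);
}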
+struct ZoomMallocAsyncAllocator : public ZoomAllocator { + DataPtr allocate(size_t size) override { + constexpr size_t one_exa_bytes = 1152921504606846976ULL; + TORCH_CHECK_WITH( + OutOfMemoryError, + size < one_exa_bytes, + "HIP out of memory. Tried to allocate more than 1EB memory."); + c10::DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + void* r = nullptr; + if (size != 0) { + mallocAsync(&r, device, size, zoom::getCurrentZoomStream(device)); + } + return {r, r, &local_raw_delete, Device(DeviceType::PrivateUse1, device)}; + } + DeleterFnPtr raw_deleter() const override { + return &local_raw_delete; + } + + // This function should not issue any context-creating calls, + // just set up for later calls to init per-device pools based + // on the current device each later call sees. + void init(int dev_count) override { + static bool called = [](int dev_count) { + ; + // Are there external guarantees init will be called before + // any of the allocator's other functions? + // std::lock_guard lk(general_mutex); + device_count = dev_count; + devs_initialized_flags.resize(dev_count, false); + dummy_unifying_free_streams.resize(dev_count); + pytorch_used_bytes.resize(dev_count); + pytorch_memory_limits.resize(dev_count); + return true; + }(dev_count); + (void)called; + } + + bool initialized() override { + return !devs_initialized_flags.empty(); + } + + static inline void assertValidDevice(c10::DeviceIndex device) { + TORCH_CHECK( + 0 <= device && device < device_count, "Invalid device argument."); + } + + void setMemoryFraction(double fraction, c10::DeviceIndex device) override { + TORCH_INTERNAL_ASSERT( + 0 <= fraction && fraction <= 1, + "invalid fraction:", + fraction, + ". Please set within (0, 1)."); + + std::lock_guard lk(general_mutex); + assertValidDevice(device); + ZoomGuard g(device); + // Should setMemoryFraction be allowed to trigger a full device context and + // pool-creating lazy_init_device, or should we simply assert this device is + // already initialized, ie + // TORCH_CHECK(devs_initialized_flags[device], ...)? + lazy_init_device(device); + + size_t device_free = 0; + size_t device_total = 0; + C10_ZOOM_CHECK(hipMemGetInfo(&device_free, &device_total)); + pytorch_memory_limits[device] = + static_cast(fraction * static_cast(device_total)); + + // Alternative: Instead of a manual hard limit, we could use + // hipMemPoolSetAttribute(mempool, hipMemPoolAttrReleaseThreshold, + // &threshold); This is a soft hint: The driver allows the pool's reserved + // memory to spike above threshold in regions of high hipMallocAsync + // demand, but opportunistically trims reserved memory back to threshold + // when the memory in use is < threshold. I don't like this because it + // introduces performance nondeterminism. + } + + void emptyCache() override { + std::lock_guard lk(general_mutex); + + for (int dev = 0; dev < device_count; dev++) { + if (devs_initialized_flags[dev]) { + ZoomGuard g(static_cast(dev)); + + hipMemPool_t mempool = nullptr; + hipDeviceGetDefaultMemPool(&mempool, dev); + hipDeviceSynchronize(); + hipMemPoolTrimTo(mempool, 0); + } + } + } + + void cacheInfo(c10::DeviceIndex device, size_t* maxWorkspaceGuess) override { + // The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp. + // Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a reasonable + // maximum workspace size to use for an upcoming cudnnFind call. 
+ // + // The native allocator's cacheInfo chooses to return the size of its + // largest unused block (which is the largest allocation the native + // allocator can service immediately and asynchronously without a + // hipMalloc. + // + // Here, we use a different heuristic: figure out the max usable workspace + // size with a bit of educated trial and error. It's ok to be + // perf-inefficient because cacheInfo is a prelude to cudnnFind. + // + // The algo cache then stores the best-performing algo with workspace <= + // maxWorkspaceGuess. Later calls with the same param set hit in cache and + // try to allocate the same workspace. If, in one of those future calls, + // workspace allocation fails (ie because less ambient memory is available), + // the bindings rerun cudnnFind, including calling cacheInfo again + // beforehand to estimate a new (smaller) largest-available workspace. Over + // a few such calls, the cache should settle to the algo with a workspace + // size that's small enough to succeed every time (for that param set). + // + // So the strategy here is to return a rough, largeish guess and let the + // bindings retry to trim as needed over time. + // + // The only caveat is, even if a workspace is allocated without OOM errors + // now and in future calls, it's hard to be sure those later error-free + // hipMallocAsyncs are fast and come straight from the pool (ie, + // hipMallocAsync didn't need to reserve more memory from the system). + // Hopefully, after repeated workspace requests, the pool's reserved memory + // also stabilizes to a point where they all come straight from the pool. + std::lock_guard lk(general_mutex); + assertValidDevice(device); + ZoomGuard g(device); + lazy_init_device(device); + + size_t free_upper_bound = 0; + size_t device_total = 0; + C10_ZOOM_CHECK(hipMemGetInfo(&free_upper_bound, &device_total)); + TORCH_INTERNAL_ASSERT( + free_upper_bound + pytorch_used_bytes[device] <= device_total); + size_t guess = std::min( + free_upper_bound, + pytorch_memory_limits[device] - pytorch_used_bytes[device]); + auto stream = c10::zoom::getCurrentZoomStream(); + void* dummy = nullptr; + + // Defensively checks for preexisting CUDA error state. + auto err = hipGetLastError(); + C10_ZOOM_CHECK(err); + + while (true) { + // Duplicates some logic from mallocAsync to work with the error state + // directly instead of repeatedly catching an exception thrown by + // mallocAsync. + if (pytorch_used_bytes[device] + guess > pytorch_memory_limits[device]) { + err = hipErrorMemoryAllocation; + } else { + err = hipMallocAsync(&dummy, guess, stream); + } + + if (err == hipSuccess) { + hipFreeAsync(dummy, stream); + *maxWorkspaceGuess = guess; + return; + } else if (err == hipErrorMemoryAllocation) { + (void)hipGetLastError(); // clear CUDA error + guess >>= 1; // quick and dirty: try half the size next iteration + } else { + C10_ZOOM_CHECK(err); + } + } + } + + void* getBaseAllocation(void* ptr, size_t* size) override { + std::lock_guard lk(general_mutex); + + auto it = ptr_info.find(ptr); + TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info"); + + if (size) { + *size = it->second.size; + } + + return ptr; + } + + void recordStream(const DataPtr& ptr, zoom::ZoomStream stream) override { + std::lock_guard lk(general_mutex); + auto ptr_val = ptr.get(); + // Empty tensor's storage().data() might be a null ptr. As there is no + // blocks associated with those tensors, it is fine to do nothing here. 
+ if (!ptr_val) { + return; + } + + // The pointer should exist in the map already. + auto it = ptr_info.find(ptr_val); + TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info"); + + UsageStream to_record{stream.stream(), stream.device_index()}; + if (to_record == it->second.creation_stream) { + TORCH_WARN_ONCE( + "Called record_stream on tensor whose original creation stream " + "matches the recorded stream. This is unnecessary and has no effect."); + } else { + it->second.recorded_streams.insert(to_record); + } + } + + std::shared_ptr getIpcDevPtr(std::string handle) override { + TORCH_CHECK( + false, + "hipMallocAsync does not yet support getIpcDevPtr. " + "If you need it, please file an issue describing your use case."); + } + + void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when) override { + TORCH_CHECK( + false, + "hipMallocAsync does not yet support recordHistory. " + "If you need it, please file an issue describing your use case."); + } + + void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override { + TORCH_CHECK( + false, + "hipMallocAsync does not yet support attachOutOfMemoryObserver. " + "If you need it, please file an issue describing your use case."); + } + + void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override { + TORCH_CHECK( + false, + "hipMallocAsync does not yet support attachAllocatorTraceTracker. " + "If you need it, please file an issue describing your use case."); + } + + std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) override { + TORCH_CHECK( + false, + "hipMallocAsync does not yet support getCheckpointState. " + "If you need it, please file an issue describing your use case."); + } + + CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) override { + TORCH_CHECK( + false, + "hipMallocAsync does not yet support setCheckpointPoolState. " + "If you need it, please file an issue describing your use case."); + } + + // Collects stats for device. + // If device hasn't been used yet, returns 0s without creating a context. + DeviceStats getDeviceStats(c10::DeviceIndex device) override { + assertValidDevice(device); + + // Memory currently reserved by the mempool + uint64_t reserved_mem_current = 0; + // High-water mark of memory reserved by the mempool since last reset + uint64_t reserved_mem_peak = 0; + // Memory currently in use by the mempool + uint64_t used_mem_current = 0; + // High-water mark of memory + uint64_t used_mem_peak = 0; + + std::lock_guard lk(general_mutex); + + if (devs_initialized_flags[device]) { + ZoomGuard g(device); + + hipMemPool_t mempool = nullptr; + C10_ZOOM_CHECK(hipDeviceGetDefaultMemPool(&mempool, device)); + C10_ZOOM_CHECK(hipMemPoolGetAttribute( + mempool, hipMemPoolAttrReservedMemCurrent, &reserved_mem_current)); + + C10_ZOOM_CHECK(hipMemPoolGetAttribute( + mempool, hipMemPoolAttrReservedMemHigh, &reserved_mem_peak)); + + C10_ZOOM_CHECK(hipMemPoolGetAttribute( + mempool, hipMemPoolAttrUsedMemCurrent, &used_mem_current)); + + C10_ZOOM_CHECK(hipMemPoolGetAttribute( + mempool, hipMemPoolAttrUsedMemHigh, &used_mem_peak)); + } + + // Many stat types are specific to the native allocator. We leave these + // untouched. Their "struct Stat"s will contain zeroed values. + DeviceStats stats; + + // In the native allocator: + // allocated_bytes is the total bytes of blocks that have been malloc()ed + // and not yet free()d. 
+ // active_bytes is the total bytes of blocks that have been malloc()ed but + // not yet released back into a free pool. In other words, it includes all + // allocated_bytes, as well as the bytes of "limbo state" blocks had have + // already been free()ed but not yet free_block()ed back into a pool due to + // outstanding stream_uses. + // + // Here, in the hipMallocAsync allocator: + // We simply ask the driver's opinion about active memory. + // We don't bother distinguishing between allocated_bytes and active_bytes. + stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current = + static_cast(used_mem_current); + stats.allocated_bytes[static_cast(StatType::AGGREGATE)].peak = + static_cast(used_mem_peak); + stats.active_bytes[static_cast(StatType::AGGREGATE)].current = + static_cast(used_mem_current); + stats.active_bytes[static_cast(StatType::AGGREGATE)].peak = + static_cast(used_mem_peak); + stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current = + static_cast(reserved_mem_current); + stats.reserved_bytes[static_cast(StatType::AGGREGATE)].peak = + static_cast(reserved_mem_peak); + + return stats; + } + + void resetAccumulatedStats(c10::DeviceIndex device) override { + assertValidDevice(device); + TORCH_WARN_ONCE( + "For backend:hipMallocAsync, resetAccumulatedStats has no effect."); + } + + void resetPeakStats(c10::DeviceIndex device) override { + assertValidDevice(device); + + ZoomGuard g(device); + hipMemPool_t mempool = nullptr; + C10_ZOOM_CHECK(hipDeviceGetDefaultMemPool(&mempool, device)); + // Using zero as the reset value is the method recommended by Cuda driver + // team. Vivek Kini says: + // "Resetting to zero (which is the only valid value when setting + // ReservedMemHigh) resets it to ReservedMemCurrent inside the driver + // (same goes for UsedMemHigh/UsedMemCurrent)" + uint64_t zero = 0; + C10_ZOOM_CHECK(hipMemPoolSetAttribute( + mempool, hipMemPoolAttrReservedMemHigh, &zero)); + C10_ZOOM_CHECK( + hipMemPoolSetAttribute(mempool, hipMemPoolAttrUsedMemHigh, &zero)); + } + + SnapshotInfo snapshot() override { + TORCH_CHECK( + false, + "Calling snapshot with backend:hipMallocAsync is not meaningful. " + "(For backend:native, snapshot returns a detailed summary of all " + "blocks tracked by the allocator, but the hipMallocAsync backend " + "does not track individual blocks.)"); + // Alternative: TORCH_WARN + return {}; + } + + // CUDAGraph interactions + void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function) override { + std::lock_guard lk(general_mutex); + + TORCH_INTERNAL_ASSERT(capture_free_streams.empty()); + TORCH_CHECK( + !capture_underway, + "Only one capture at a time is allowed in a process.") + capture_underway = true; + } + + void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) + override { + assertValidDevice(device); + + std::lock_guard lk(general_mutex); + + TORCH_CHECK( + capture_underway, + "hipMallocAsync::notifyCaptureAboutToEnd called, " + "but hipMallocAsync::capture_underway is false."); + + auto capture_stream = zoom::getCurrentZoomStream(device); + + // See Note [Avoid dangling free streams during CUDA graph capture] + for (const auto& free_stream : capture_free_streams) { + // hipEventRecord requires that the input event and stream are on the + // same device. + ZoomGuard g(free_stream.device); + + // ZoomCachingAllocator.cpp uses raw hip events, as do we. 
+ hipEvent_t event = nullptr; + C10_ZOOM_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming)); + C10_ZOOM_CHECK(hipEventRecord(event, free_stream.stream)); + C10_ZOOM_CHECK(hipStreamWaitEvent(capture_stream.stream(), event, 0)); + C10_ZOOM_CHECK(hipEventDestroy(event)); + } + + capture_free_streams.clear(); + TORCH_CHECK( + capture_underway, + "hipMallocAsync::notifyCaptureEnded called, " + "but hipMallocAsync::capture_underway is false."); + capture_underway = false; + } + + void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override { + // Q: Do we need to do anything special here, like clear long-lived + // pointers created during the original capture (for example, + // tensors intended as the graph's I/O surface) that might still + // be resident in ptr_info? + // A: I don't think so. + // Those allocations survived capture because the user held + // explicit tensor references to them, + // Those tensors' destructors will call freeAsync() on each pointer + // when the user is done with them. + // The freeAsync()s will probably incur + // TORCH_WARN("Attempting uncaptured free of a captured allocation..." + // but stale ptrs will not permanently leak into ptr_info. + } + + void* raw_alloc(size_t nbytes) override { + if (nbytes == 0) { + return nullptr; + } + c10::DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + void* r = nullptr; + mallocAsync(&r, device, nbytes, zoom::getCurrentZoomStream(device)); + return r; + } + + void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override { + if (nbytes == 0) { + return nullptr; + } + c10::DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + void* r = nullptr; + mallocAsync(&r, device, nbytes, stream); + return r; + } + void raw_delete(void* ptr) override { + freeAsync(ptr); + } + void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) + override { + // Double-checks allocator backend hasn't changed, which would definitely be + // an error. hipMallocAsync pools are unaffected by + // hipDeviceEnablePeerAccess. We need pool-specific enablement. 
See + // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/ + c10::zoom::ZoomGuard device_guard(dev); + hipMemPool_t mempool = nullptr; + C10_ZOOM_CHECK(hipDeviceGetDefaultMemPool(&mempool, dev_to_access)); + hipMemAccessDesc desc = {}; + desc.location.type = hipMemLocationTypeDevice; + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + desc.location.id = dev; + desc.flags = hipMemAccessFlagsProtReadWrite; + C10_ZOOM_CHECK(hipMemPoolSetAccess(mempool, &desc, 1 /* numDescs */)); + } + hipError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + hipStream_t stream, + bool p2p_enabled) override { + if (p2p_enabled || dstDevice == srcDevice) { + return hipMemcpyAsync(dst, src, count, hipMemcpyDeviceToDevice, stream); + } else { + return hipMemcpyPeerAsync(dst, dstDevice, src, srcDevice, count, stream); + } + } + std::string name() override { + return "hipMallocAsync"; + } + void copy_data(void* dest, const void* src, std::size_t count) const final { + C10_ZOOM_CHECK( + hipMemcpy(dest, src, count, hipMemcpyKind::hipMemcpyDeviceToDevice)); + } +}; + +ZoomMallocAsyncAllocator device_allocator; + +void local_raw_delete(void* ptr) { + freeAsync(ptr); +} +ZoomAllocator* allocator() { + return &device_allocator; +} + + +} // namespace c10::zoom::ZoomCachingAllocator::ZoomMallocAsync \ No newline at end of file diff --git a/c10/zoom/ZoomMiscFunctions.cpp b/c10/zoom/ZoomMiscFunctions.cpp new file mode 100644 index 00000000000000..cba225e314a0f7 --- /dev/null +++ b/c10/zoom/ZoomMiscFunctions.cpp @@ -0,0 +1,23 @@ +#include +#include + +namespace c10::zoom { + +const char* get_hip_check_suffix() noexcept { + static char* device_blocking_flag = getenv("HIP_LAUNCH_BLOCKING"); + static bool blocking_enabled = + (device_blocking_flag && atoi(device_blocking_flag)); + if (blocking_enabled) { + return ""; + } else { + return "\nHIP kernel errors might be asynchronously reported at some" + " other API call, so the stacktrace below might be incorrect." 
+ "\nFor debugging consider passing HIP_LAUNCH_BLOCKING=1"; + } +} +std::mutex* getFreeMutex() { + static std::mutex hip_free_mutex; + return &hip_free_mutex; +} + +} // namespace c10::zoom \ No newline at end of file diff --git a/c10/zoom/ZoomMiscFunctions.h b/c10/zoom/ZoomMiscFunctions.h new file mode 100644 index 00000000000000..8031194734d5e1 --- /dev/null +++ b/c10/zoom/ZoomMiscFunctions.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace c10::zoom { +const char* get_hip_check_suffix() noexcept; +std::mutex* getFreeMutex(); +} // namespace c10::zoom \ No newline at end of file diff --git a/c10/zoom/ZoomStream.cpp b/c10/zoom/ZoomStream.cpp new file mode 100644 index 00000000000000..4dac263d78db45 --- /dev/null +++ b/c10/zoom/ZoomStream.cpp @@ -0,0 +1,375 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define C10_ZOOM_COMPILE_TIME_MAX_GPUS 16 + +namespace c10::zoom { + +namespace { + +// Global stream state and constants +static c10::once_flag init_flag; +static DeviceIndex num_gpus = -1; +static constexpr int kStreamsPerPoolBits = 5; +static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; +static constexpr unsigned int kDefaultFlags = hipStreamNonBlocking; +static constexpr int kStreamTypeBits = 4; + +static int max_stream_priorities; + +// Non-default streams +// Note: the number of CUDA devices is determined at run time, +// and the low and high priority pools are lazily initialized +// when the first stream is requested for a device. +// The device flags track the initialization of each device, while +// the low and high priority counters track, for each device, the next stream +// in the pool to be returned when a stream is requested (round-robin fashion +// , see the note in ZoomStream.h). +// The streams are "leaked": they are created but never destroyed because the +// destruction of global variables could happen after the CUDA runtime has +// already been destroyed and thus invoking ZoomStreamDestroy could lead to a +// crash. It's likely an issue in CUDA, but to be safe - let's just "forget" +// the destruction. + +static std::array< + std::array, C10_ZOOM_COMPILE_TIME_MAX_GPUS>, + c10::zoom::max_compile_time_stream_priorities> + priority_counters; + +static std::array< + std::array< + std::array, + C10_ZOOM_COMPILE_TIME_MAX_GPUS>, + c10::zoom::max_compile_time_stream_priorities> + streams; + +static c10::once_flag + stream_flags[c10::zoom::max_compile_time_stream_priorities] + [C10_ZOOM_COMPILE_TIME_MAX_GPUS][kStreamsPerPool]; + + +// Note [HIP Lazy Streams] +// ~~~~~~~~~~~~~~~~~~~~~~~ +// For ROCm/HIP, each stream is lazily initialized rather than creating all +// streams when the first stream is requested. HIP streams are not as +// lightweight as CUDA streams; the pooling strategy can affect performance. +// Rather than changing the pooling implementation, ROCm/HIP will lazy init +// each stream when it is first requested. + +// Note [StreamId assignment] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~ +// How do we assign stream IDs? 
+// +// -- 54 bits -- -- 5 bits ----- -- 4 bits -- --1 bit -- +// zeros stream id index StreamIdType Ext/native stream +// ignored for ext ignored for ext +// for external stream, StreamID is a hipStream_t pointer +// this means that last bit will always be 0 +// so when constructing StreamId for a native stream we set last bit to 1 +// to distinguish between native and external streams +// +// +// We are obligated to treat the stream ID 0 as the default stream, per the +// invariant specified in c10::Stream, so this is one exception to +// "last bit = 1 for native streams". However, all other numbers are entirely +// an internal implementation detail, we reserve the right to renumber streams +// however we like. +// +// Note that it is really important that the MSB is zero; StreamId is a +// *signed* integer, and unsigned to signed conversion outside of the +// bounds of signed integer representation is undefined behavior. You +// could work around this with something like +// https://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior +// but it seems a bit overkill for this. +// +// Also, external managed stream pointers (hipStream_t) can be directly stored +// in the Id field so in this case, we need to check the stream alignment. + +class StreamIdType { + // StreamIdType encodes whether this stream is DEFAULT, EXTernal or + // for all other native streams, the stream priority (higher value is higher + // priority) + private: + uint8_t stream_type; + + public: + static const uint8_t DEFAULT = 0x0; + static const uint8_t EXT = 0xF; + + public: + StreamIdType(const uint8_t _stream_type) : stream_type(_stream_type) {} + + bool isExt() const { + return EXT == stream_type; + } + + bool isDefault() const { + return DEFAULT == stream_type; + } + + uint8_t getStreamType() const { + return stream_type; + } +}; + +std::ostream& operator<<(std::ostream& stream, StreamIdType s) { + if (s.isDefault()) { + stream << "DEFAULT"; + } else if (s.isExt()) { + stream << "EXT"; + } else { + stream << "PRIORITY " << int(s.getStreamType()); + } + return stream; +} + +// StreamId is 64-bit, so we can just rely on regular promotion rules. +// We rely on streamIdIndex and streamIdType being non-negative; +// see Note [Hazard when concatenating signed integers] + +static inline StreamIdType streamIdType(StreamId s) { + // Externally allocated streams have their id being the ZoomStream_ptr + // so the last bit will be 0 + if ((!(s & 1)) && s) { + return StreamIdType(StreamIdType::EXT); + } + // last bit is external/internal stream, the mask should start from second + // rightmost bit + int mask_for_type = (1 << kStreamTypeBits) - 1; + auto val = (s >> 1) & mask_for_type; + TORCH_INTERNAL_ASSERT(val || !(s & 1), "invalid StreamId", s); + return StreamIdType(val); +} + +static inline size_t streamIdIndex(StreamId s) { + return static_cast( + (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1)); +} + +StreamId makeStreamId(StreamIdType st, size_t si) { + if (st.isDefault()) { + return static_cast(0); + } + return (static_cast(si) << (kStreamTypeBits + 1)) | + static_cast(st.getStreamType() << 1) | 1; +} + +// Thread-local current streams +// NOLINTNEXTLINE(*-arrays) +static thread_local std::unique_ptr current_streams = nullptr; + +// Populates global values. +// Warning: this function must only be called once! 
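A self-contained worked example of the StreamId bit layout described in the note above, placed here ahead of the stream-state initialization that follows. It re-implements the packing locally with the same constants so it can be compiled on its own; the specific priority/index values are arbitrary:

#include <cassert>
#include <cstdint>

int main() {
  using StreamId = int64_t;
  constexpr int kStreamTypeBits = 4;
  constexpr int kStreamsPerPoolBits = 5;

  // Same packing as makeStreamId above: bit 0 marks a native stream,
  // bits 1..4 hold the StreamIdType, bits 5..9 hold the pool index.
  auto make = [](uint8_t type, uint64_t idx) -> StreamId {
    return (static_cast<StreamId>(idx) << (kStreamTypeBits + 1)) |
        static_cast<StreamId>(type << 1) | 1;
  };

  StreamId s = make(/*priority type=*/1, /*pool index=*/3);  // == 99
  assert(((s >> 1) & ((1 << kStreamTypeBits) - 1)) == 1);    // type decodes to 1
  assert(((s >> (kStreamTypeBits + 1)) &
          ((1 << kStreamsPerPoolBits) - 1)) == 3);           // index decodes to 3
  assert((s & 1) == 1);                                      // native-stream marker
  return 0;
}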
+static void initGlobalStreamState() {
+  num_gpus = device_count();
+  // Check if the number of GPUs matches the expected compile-time max number
+  // of GPUs.
+  TORCH_CHECK(
+      num_gpus <= C10_ZOOM_COMPILE_TIME_MAX_GPUS,
+      "Number of ROCm devices on the machine is larger than the compiled "
+      "max number of gpus expected (",
+      C10_ZOOM_COMPILE_TIME_MAX_GPUS,
+      "). Increase that and recompile.");
+  // Note [HIP stream priorities]
+  // HIP stream priorities are 1=low, 0=default, -1=high, which differs from
+  // CUDA, which is 0=default, -1=high, -2=higher, etc.
+  // Clamp leastPriority to 0 for HIP.
+  int leastPriority = 0, greatestPriority = -1;
+  C10_ZOOM_CHECK(
+      hipDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
+  // HIP reports leastPriority == 1; clamp it to 0 as described in
+  // Note [HIP stream priorities] above.
+  leastPriority = 0;
+
+  // greatestPriority is negative
+  auto range = leastPriority - greatestPriority + 1;
+  max_stream_priorities = range >= c10::zoom::max_compile_time_stream_priorities
+      ? c10::zoom::max_compile_time_stream_priorities
+      : range;
+}
+
+// Init a single HIP stream
+// See Note [HIP Lazy Streams]
+static void initSingleStream(int p, DeviceIndex device_index, int i) {
+  auto& stream = streams[p][device_index][i];
+  auto pri = -p; // lower number is higher priority
+
+  C10_ZOOM_CHECK(hipStreamCreateWithPriority(&stream, kDefaultFlags, pri));
+  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+  if (C10_UNLIKELY(interp)) {
+    (*interp)->trace_gpu_stream_creation(
+        c10::DeviceType::PrivateUse1, reinterpret_cast<uintptr_t>(stream));
+    priority_counters[p][device_index] = 0;
+  }
+}
+
+// Creates the low and high priority stream pools for the specified device
+// Warning: only call once per device!
+static void initDeviceStreamState(DeviceIndex device_index) {
+  // Switches to the requested device so streams are properly associated
+  // with it.
+ ZoomGuard device_guard{device_index}; + for (const auto i : c10::irange(kStreamsPerPool)) { + for (const auto p : c10::irange(max_stream_priorities)) { + initSingleStream(p, device_index, i); + } + } +} + +// Init front-end to ensure initialization only occurs once +static void initZoomStreamsOnce() { + // Inits default streams (once, globally) + c10::call_once(init_flag, initGlobalStreamState); + + if (current_streams) { + return; + } + + // Inits current streams (thread local) to default streams + // NOLINTNEXTLINE(*-arrays) + current_streams = std::make_unique(num_gpus); + for (const auto i : c10::irange(num_gpus)) { + current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0); + } +} + +// Helper to verify the GPU index is valid +static inline void check_gpu(DeviceIndex device_index) { + TORCH_INTERNAL_ASSERT(device_index >= 0 && device_index < num_gpus); +} + +// Helper to determine the index of the stream to return +// Note: Streams are returned round-robin (see note in ZoomStream.h) +static uint32_t get_idx(std::atomic& counter) { + auto raw_idx = counter++; + return raw_idx % kStreamsPerPool; +} + +ZoomStream ZoomStreamForId(DeviceIndex device_index, StreamId stream_id) { + return ZoomStream( + ZoomStream::UNCHECKED, + Stream( + Stream::UNSAFE, + c10::Device(DeviceType::PrivateUse1, device_index), + stream_id)); +} + +} // anonymous namespace + +// See Note [StreamId assignment] +hipStream_t ZoomStream::stream() const { + c10::DeviceIndex device_index = stream_.device_index(); + StreamId stream_id = stream_.id(); + StreamIdType st = streamIdType(stream_id); + size_t si = streamIdIndex(stream_id); + if (st.isDefault()) { + TORCH_INTERNAL_ASSERT( + si == 0, + "Unrecognized stream ", + stream_, + " (I think this should be the default stream, but I got a non-zero index ", + si, + ").", + " Did you manufacture the StreamId yourself? Don't do that; use the", + " official API like c10::zoom::getStreamFromPool() to get a new stream."); + return nullptr; + } else if (st.isExt()) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + return reinterpret_cast(stream_id); + } else { + auto streamType = st.getStreamType(); + TORCH_INTERNAL_ASSERT( + streamType >= 1 && streamType <= max_stream_priorities, + "Unrecognized stream ", + stream_, + " (I didn't recognize the stream type, ", + st, + " with the value ", + streamType, + ")"); + + // See Note [HIP Lazy Streams] + c10::call_once( + stream_flags[st.getStreamType() - 1][device_index][si], + initSingleStream, + st.getStreamType() - 1, + device_index, + si); + + return streams[st.getStreamType() - 1][device_index][si]; + } +} + +// Returns a stream from the requested pool +// Note: when called the first time on a device, this will create the +// stream pools for that device. 
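A brief usage sketch of the pool accessors, shown here before getStreamFromPool is defined below. Editorial illustration only; it assumes this patch's c10/zoom/ZoomStream.h is available and uses device index 0 as an arbitrary example:

#include <c10/zoom/ZoomStream.h>

void run_on_side_stream() {
  // Round-robin stream from the high-priority pool of device 0.
  c10::zoom::ZoomStream s =
      c10::zoom::getStreamFromPool(/*isHighPriority=*/true, /*device=*/0);

  // Make it the current stream so subsequent work is issued on `s`.
  c10::zoom::setCurrentZoomStream(s);

  // ... launch kernels / copies here ...

  // Block the host until everything queued on `s` has finished.
  s.synchronize();
}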
+ZoomStream getStreamFromPool(const int priority, DeviceIndex device_index) { + initZoomStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + c10::zoom::SetTargetDevice(); + } + TORCH_CHECK( + priority <= 0, + "Expected hip stream priority to be less than or equal to 0, got ", + priority); + check_gpu(device_index); + + auto pri_idx = -priority; + pri_idx = + std::min(pri_idx, max_stream_priorities - 1); // pri_idx is zero-based + const auto idx = get_idx(priority_counters[pri_idx][device_index]); + StreamIdType id_type = StreamIdType(pri_idx + 1); + return ZoomStreamForId(device_index, makeStreamId(id_type, idx)); +} + +ZoomStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) { + initZoomStreamsOnce(); + int priority = isHighPriority ? -max_stream_priorities + 1 : 0; + return getStreamFromPool(priority, device); +} + +ZoomStream getStreamFromExternal( + hipStream_t ext_stream, + DeviceIndex device_index) { + // The stream pointer will be the actual id + return ZoomStreamForId(device_index, reinterpret_cast(ext_stream)); +} + +ZoomStream getDefaultZoomStream(DeviceIndex device_index) { + initZoomStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + c10::zoom::SetTargetDevice(); + } + check_gpu(device_index); + return ZoomStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0)); +} + +ZoomStream getCurrentZoomStream(DeviceIndex device_index) { + initZoomStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + c10::zoom::SetTargetDevice(); + } + check_gpu(device_index); + return ZoomStreamForId(device_index, current_streams[device_index]); +} + +void setCurrentZoomStream(ZoomStream stream) { + initZoomStreamsOnce(); + current_streams[stream.device_index()] = stream.id(); +} + +std::ostream& operator<<(std::ostream& stream, const ZoomStream& s) { + return stream << s.unwrap(); +} + +} // namespace c10::zoom \ No newline at end of file diff --git a/c10/zoom/ZoomStream.h b/c10/zoom/ZoomStream.h new file mode 100644 index 00000000000000..04041318cf1781 --- /dev/null +++ b/c10/zoom/ZoomStream.h @@ -0,0 +1,221 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace c10::zoom { + +static constexpr int max_compile_time_stream_priorities = 4; + +// Value object representing a CUDA stream. This is just a wrapper +// around c10::Stream, but it comes with a little extra CUDA-specific +// functionality (conversion to hipStream_t), and a guarantee that +// the wrapped c10::Stream really is a CUDA stream. +class ZoomStream { + public: + enum Unchecked { UNCHECKED }; + + /// Construct a ZoomStream from a Stream. This construction is checked, + /// and will raise an error if the Stream is not, in fact, a CUDA stream. + explicit ZoomStream(Stream stream) : stream_(stream) { + TORCH_CHECK(stream_.device_type() == DeviceType::PrivateUse1); + } + + /// Construct a ZoomStream from a Stream with no error checking. + /// This constructor uses the "named" constructor idiom, and can + /// be invoked as: ZoomStream(ZoomStream::UNCHECKED, stream) + explicit ZoomStream(Unchecked, Stream stream) : stream_(stream) {} + + bool operator==(const ZoomStream& other) const noexcept { + return unwrap() == other.unwrap(); + } + + bool operator!=(const ZoomStream& other) const noexcept { + return unwrap() != other.unwrap(); + } + + /// Implicit conversion to hipStream_t. 
+ operator hipStream_t() const { + return stream(); + } + + /// Implicit conversion to Stream (a.k.a., forget that the stream is a + /// CUDA stream). + operator Stream() const { + return unwrap(); + } + + /// Used to avoid baking in device type explicitly to Python-side API. + DeviceType device_type() const { + return DeviceType::PrivateUse1; + } + + /// Get the CUDA device index that this stream is associated with. + DeviceIndex device_index() const { + return stream_.device_index(); + } + + /// Get the full Device that this stream is associated with. The Device + /// is guaranteed to be a CUDA device. + Device device() const { + return Device(DeviceType::PrivateUse1, device_index()); + } + + /// Return the stream ID corresponding to this particular stream. + StreamId id() const { + return stream_.id(); + } + + bool query() const { + DeviceGuard guard{stream_.device()}; + hipError_t err = C10_ZOOM_ERROR_HANDLED(hipStreamQuery(stream())); + + if (err == hipSuccess) { + return true; + } else if (err != hipErrorNotReady) { + C10_ZOOM_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)hipGetLastError(); + } + + return false; + } + + void synchronize() const { + DeviceGuard guard{stream_.device()}; + c10::zoom::stream_synchronize(stream()); + } + + int priority() const { + DeviceGuard guard{stream_.device()}; + int priority = 0; + C10_ZOOM_CHECK(hipStreamGetPriority(stream(), &priority)); + return priority; + } + + /// Explicit conversion to hipStream_t. + hipStream_t stream() const; + + /// Explicit conversion to Stream. + Stream unwrap() const { + return stream_; + } + + /// Reversibly pack a ZoomStream into a struct representation. + /// Previously the stream's data was packed into a single int64_t, + /// as it was assumed the fields would not require more than + /// 64 bits of storage in total. + /// See https://github.com/pytorch/pytorch/issues/75854 + /// for more information regarding newer platforms that may violate + /// this assumption. + /// + /// The ZoomStream can be unpacked using unpack(). + struct c10::StreamData3 pack3() const { + return stream_.pack3(); + } + + // Unpack a ZoomStream from the 3 fields generated by pack(). + static ZoomStream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + return ZoomStream(Stream::unpack3(stream_id, device_index, device_type)); + } + + static std::tuple priority_range() { + // Note: this returns the range of priority **supported by PyTorch**, not + // the range of priority **supported by CUDA**. The former is a subset of + // the latter. + int least_priority = 0, greatest_priority = 0; + C10_ZOOM_CHECK( + hipDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); + + // See Note [HIP stream priorities] + TORCH_INTERNAL_ASSERT( + least_priority == 1, "Unexpected HIP stream priority range"); + least_priority = 0; + + TORCH_INTERNAL_ASSERT( + greatest_priority <= -1, "Unexpected HIP stream priority range"); + greatest_priority = std::max( + -c10::zoom::max_compile_time_stream_priorities + 1, greatest_priority); + return std::make_tuple(least_priority, greatest_priority); + } + + // Deleted for now; use CUDAEvent::block instead + // void synchronize_with(const CUDAEvent& event) const; + + private: + Stream stream_; +}; + +/** + * Get a new stream from the CUDA stream pool. You can think of this + * as "creating" a new stream, but no such creation actually happens; + * instead, streams are preallocated from the pool and returned in a + * round-robin fashion. 
+ * + * You can request a stream from the high priority pool by setting + * isHighPriority to true, or a stream for a specific device by setting device + * (defaulting to the current CUDA stream.) + */ +ZoomStream +getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); +// no default priority to disambiguate overloads +ZoomStream +getStreamFromPool(const int priority, DeviceIndex device = -1); + +/** + * Get a ZoomStream from a externally allocated one. + * + * This is mainly for interoperability with different libraries where we + * want to operate on a non-torch allocated stream for data exchange or similar + * purposes + */ +ZoomStream +getStreamFromExternal(hipStream_t ext_stream, DeviceIndex device_index); + +/** + * Get the default CUDA stream, for the passed CUDA device, or for the + * current device if no device index is passed. The default stream is + * where most computation occurs when you aren't explicitly using + * streams. + */ +ZoomStream getDefaultZoomStream(DeviceIndex device_index = -1); + +/** + * Get the current CUDA stream, for the passed CUDA device, or for the + * current device if no device index is passed. The current CUDA stream + * will usually be the default CUDA stream for the device, but it may + * be different if someone called 'setCurrentZoomStream' or used 'StreamGuard' + * or 'ZoomStreamGuard'. + */ +ZoomStream getCurrentZoomStream(DeviceIndex device_index = -1); + +/** + * Set the current stream on the device of the passed in stream to be + * the passed in stream. Yes, you read that right: this function + * has *nothing* to do with the current device: it toggles the current + * stream of the device of the passed stream. + * + * Confused? Avoid using this function; prefer using 'ZoomStreamGuard' instead + * (which will switch both your current device and current stream in the way you + * expect, and reset it back to its original state afterwards). 
+ */ +void setCurrentZoomStream(ZoomStream stream); + +std::ostream& operator<<(std::ostream& stream, const ZoomStream& s); + +} // namespace c10::zoom + +namespace std { +template <> +struct hash { + size_t operator()(c10::zoom::ZoomStream s) const noexcept { + return std::hash{}(s.unwrap()); + } +}; +} // namespace std \ No newline at end of file diff --git a/c10/zoom/impl/ZoomGuardImpl.cpp b/c10/zoom/impl/ZoomGuardImpl.cpp new file mode 100644 index 00000000000000..0327253b26d1f0 --- /dev/null +++ b/c10/zoom/impl/ZoomGuardImpl.cpp @@ -0,0 +1,7 @@ +#include + +namespace c10::zoom::impl { + +C10_REGISTER_GUARD_IMPL(PrivateUse1, ZoomGuardImpl); + +} // namespace c10::zoom::impl \ No newline at end of file diff --git a/c10/zoom/impl/ZoomGuardImpl.h b/c10/zoom/impl/ZoomGuardImpl.h new file mode 100644 index 00000000000000..49f0813cf24884 --- /dev/null +++ b/c10/zoom/impl/ZoomGuardImpl.h @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10::zoom::impl { + +struct ZoomGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = DeviceType::PrivateUse1; + + ZoomGuardImpl() = default; + explicit ZoomGuardImpl(DeviceType t) { + TORCH_INTERNAL_ASSERT(t == DeviceType::PrivateUse1); + } + DeviceType type() const override { + return DeviceType::PrivateUse1; + } + Device exchangeDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_privateuseone()); + auto old_device_index = c10::zoom::ExchangeDevice(d.index()); + return Device(DeviceType::PrivateUse1, old_device_index); + } + Device getDevice() const override { + DeviceIndex device = 0; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + return Device(DeviceType::PrivateUse1, device); + } + std::optional uncheckedGetDevice() const noexcept { + DeviceIndex device{-1}; + const auto err = C10_ZOOM_ERROR_HANDLED(c10::zoom::GetDevice(&device)); + C10_ZOOM_CHECK_WARN(err); + if (err != hipSuccess) { + return c10::nullopt; + } + return Device(DeviceType::PrivateUse1, device); + } + void setDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_privateuseone()); + C10_ZOOM_CHECK(c10::zoom::SetDevice(d.index())); + } + void uncheckedSetDevice(Device d) const noexcept override { + C10_ZOOM_CHECK_WARN(c10::zoom::MaybeSetDevice(d.index())); + } + Stream getStream(Device d) const noexcept override { + return getCurrentZoomStream(d.index()).unwrap(); + } + Stream getDefaultStream(Device d) const override { + return getDefaultZoomStream(d.index()); + } + Stream getNewStream(Device d, int priority = 0) const override { + return getStreamFromPool(priority, d.index()); + } + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) + const override { + return getStreamFromPool(isHighPriority, d.index()); + } + // NB: These do NOT set the current device + Stream exchangeStream(Stream s) const noexcept override { + ZoomStream cs(s); + auto old_stream = getCurrentZoomStream(s.device().index()); + setCurrentZoomStream(cs); + return old_stream.unwrap(); + } + DeviceIndex deviceCount() const noexcept override { + return device_count(); + } + + // Event-related functions + void createEvent(hipEvent_t* zoom_event, const EventFlag flag) const { + // Maps PyTorch's Event::Flag to HIP flag + auto hip_flag = hipEventDefault; + switch (flag) { + case EventFlag::PYTORCH_DEFAULT: + hip_flag = hipEventDisableTiming; + break; + case EventFlag::BACKEND_DEFAULT: + 
hip_flag = hipEventDefault; + break; + default: + TORCH_CHECK(false, "HIP event received unknown flag"); + } + + C10_ZOOM_CHECK(hipEventCreateWithFlags(zoom_event, hip_flag)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation( + c10::DeviceType::PrivateUse1, reinterpret_cast(zoom_event)); + } + } + + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { + if (!event) + return; + auto zoom_event = static_cast(event); + DeviceIndex orig_device{-1}; + C10_ZOOM_CHECK_WARN(c10::zoom::GetDevice(&orig_device)); + C10_ZOOM_CHECK_WARN(c10::zoom::SetDevice(device_index)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + c10::DeviceType::PrivateUse1, reinterpret_cast(zoom_event)); + } + C10_ZOOM_CHECK_WARN(hipEventDestroy(zoom_event)); + C10_ZOOM_CHECK_WARN(c10::zoom::SetDevice(orig_device)); + } + + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK( + device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); + + hipEvent_t zoom_event = static_cast(*event); + ZoomStream zoom_stream{stream}; + + // Moves to stream's device to record + const auto orig_device = getDevice(); + setDevice(stream.device()); + + // Creates the event (lazily) + if (!zoom_event) + createEvent(&zoom_event, flag); + C10_ZOOM_CHECK(hipEventRecord(zoom_event, zoom_stream)); + // Makes the void* point to the (possibly just allocated) HIP event + *event = zoom_event; + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + c10::DeviceType::PrivateUse1, + reinterpret_cast(zoom_event), + reinterpret_cast(zoom_stream.stream())); + } + + // Resets device + setDevice(orig_device); + } + + void block(void* event, const Stream& stream) const override { + if (!event) + return; + hipEvent_t zoom_event = static_cast(event); + ZoomStream zoom_stream{stream}; + const auto orig_device = getDevice(); + setDevice(stream.device()); + C10_ZOOM_CHECK(hipStreamWaitEvent( + zoom_stream, + zoom_event, + /*flags (must be zero)=*/0)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + c10::DeviceType::PrivateUse1, + reinterpret_cast(zoom_event), + reinterpret_cast(zoom_stream.stream())); + } + setDevice(orig_device); + } + + // May be called from any device + bool queryEvent(void* event) const override { + if (!event) + return true; + hipEvent_t zoom_event = static_cast(event); + // Note: hipEventQuery can be safely called from any device + const hipError_t err = C10_ZOOM_ERROR_HANDLED(hipEventQuery(zoom_event)); + if (err != hipErrorNotReady) { + C10_ZOOM_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)hipGetLastError(); + } + return (err == hipSuccess); + } + + // Stream-related functions + bool queryStream(const Stream& stream) const override { + ZoomStream zoom_stream{stream}; + return zoom_stream.query(); + } + + void synchronizeStream(const Stream& stream) const override { + ZoomStream zoom_stream{stream}; + zoom_stream.synchronize(); + } + + void synchronizeEvent(void* event) const override { + if 
(!event) + return; + hipEvent_t zoom_event = static_cast(event); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization( + c10::DeviceType::PrivateUse1, reinterpret_cast(zoom_event)); + } + // Note: hipEventSynchronize can be safely called from any device + C10_ZOOM_CHECK(hipEventSynchronize(zoom_event)); + } + + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { + ZoomStream zoom_stream{stream}; + ZoomCachingAllocator::recordStream(data_ptr, zoom_stream); + } + + double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) + const override { + TORCH_CHECK( + event1 && event2, + "Both events must be recorded before calculating elapsed time."); + // Even though zoomEventElapsedTime can be safely called from any device, if + // the current device is not initialized, it will create a new zoom context, + // which will consume a lot of memory. + DeviceIndex orig_device{-1}; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&orig_device)); + C10_ZOOM_CHECK(c10::zoom::SetDevice(device_index)); + hipEvent_t zoom_event1 = static_cast(event1); + hipEvent_t zoom_event2 = static_cast(event2); + float time_ms = 0; + // raise hipErrorNotReady if either event is recorded but not yet completed + C10_ZOOM_CHECK(hipEventElapsedTime(&time_ms, zoom_event1, zoom_event2)); + C10_ZOOM_CHECK(c10::zoom::SetDevice(orig_device)); + return static_cast(time_ms); + } +}; + +} // namespace c10::zoom::impl \ No newline at end of file diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 369bb9b106a0db..1a43c7d53aa9fb 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -71,6 +71,7 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND Caffe2_GPU_CU_SRCS ${ATen_CUDA_CU_SRCS}) list(APPEND Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY}) list(APPEND Caffe2_HIP_SRCS ${ATen_HIP_SRCS}) + list(APPEND Caffe2_ZOOM_SRCS ${ATen_ZOOM_SRCS}) list(APPEND Caffe2_MPS_SRCS ${ATen_MPS_SRCS}) list(APPEND Caffe2_XPU_SRCS ${ATen_XPU_SRCS}) list(APPEND Caffe2_HIP_SRCS ${ATen_HIP_SRCS_W_SORT_BY_KEY}) @@ -84,13 +85,16 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND Caffe2_CPU_INCLUDE ${ATen_CPU_INCLUDE}) list(APPEND Caffe2_GPU_INCLUDE ${ATen_CUDA_INCLUDE}) list(APPEND Caffe2_HIP_INCLUDE ${ATen_HIP_INCLUDE}) + list(APPEND Caffe2_ZOOM_INCLUDE ${ATen_ZOOM_INCLUDE}) list(APPEND Caffe2_XPU_INCLUDE ${ATen_XPU_INCLUDE}) list(APPEND Caffe2_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE}) list(APPEND Caffe2_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS}) list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS}) list(APPEND Caffe2_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS}) + list(APPEND Caffe2_ZOOM_DEPENDENCY_LIBS ${ATen_ZOOM_DEPENDENCY_LIBS}) list(APPEND Caffe2_DEPENDENCY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE}) set(Caffe2_CUDA_DEPENDENCY_LIBS ${Caffe2_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) + set(Caffe2_ZOOM_DEPENDENCY_LIBS ${Caffe2_ZOOM_DEPENDENCY_LIBS} PARENT_SCOPE) endif() # ---[ Caffe2 build @@ -128,6 +132,7 @@ if(CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_GPU_CU_SRCS CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_HIP_SRCS CAFFE2_ALLOWLISTED_FILES) + caffe2_do_allowlist(Caffe2_ZOOM_SRCS CAFFE2_ALLOWLISTED_FILES) endif() if(PRINT_CMAKE_DEBUG_INFO) @@ -181,6 +186,11 @@ if(PRINT_CMAKE_DEBUG_INFO) message(STATUS " " ${tmp}) endforeach() + message(STATUS "ZOOM sources: ") + foreach(tmp 
${Caffe2_ZOOM_SRCS}) + message(STATUS " " ${tmp}) + endforeach() + message(STATUS "MPS sources: ") foreach(tmp ${Caffe2_MPS_SRCS}) message(STATUS " " ${tmp}) @@ -594,6 +604,10 @@ if(USE_CUDA OR USE_ROCM) append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS) endif() +# if (USE_ZOOM) +# append_filelist("libtorch_zoom_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS) +# endif() + if(USE_CUDA) list(APPEND Caffe2_GPU_CU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) @@ -675,6 +689,26 @@ if(USE_ROCM) install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() +# if(USE_ZOOM) +# list(APPEND Caffe2_ZOOM_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) +# if(USE_NCCL) +# list(APPEND Caffe2_ZOOM_SRCS +# ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) +# endif() +# if(USE_DISTRIBUTED) +# append_filelist("libtorch_zoom_distributed_base_sources" Caffe2_ZOOM_SRCS) +# if(NOT WIN32) +# append_filelist("libtorch_zoom_distributed_extra_sources" Caffe2_ZOOM_SRCS) +# endif() +# endif() + # See NOTE [ ATen NVRTC Stub and HIP ] +# hip_add_library(caffe2_hiprtc SHARED ${ATen_HIPRTC_STUB_SRCS}) +# target_link_libraries(caffe2_hiprtc ${PYTORCH_HIP_LIBRARIES} ${ROCM_HIPRTC_LIB}) +# target_include_directories(caffe2_hiprtc PRIVATE ${CMAKE_BINARY_DIR} ${ROCM_SOURCE_DIR}/include) +# target_compile_definitions(caffe2_hiprtc PRIVATE USE_ROCM __HIP_PLATFORM_AMD__) +# install(TARGETS caffe2_hiprtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") +# endif() + if(NOT NO_API AND NOT BUILD_LITE_INTERPRETER) list(APPEND TORCH_SRCS ${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp @@ -920,6 +954,11 @@ if(USE_ROCM) set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) endif() +if(USE_ZOOM) + filter_list(__caffe2_zoom_hip_srcs_cpp Caffe2_ZOOM_SRCS "\\.(cu|hip)$") + set_source_files_properties(${_caffe2_zoom_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) +endif() + # Compile exposed libraries. if(USE_ROCM) set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) @@ -941,6 +980,15 @@ if(USE_ROCM) target_precompile_headers(torch_hip PRIVATE "$<$:ATen/core/ATen_pch.h>") endif() +elseif(USE_ZOOM) + ADD_DEFINITIONS(-DUSE_ZOOM) + set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) + # list(APPEND Caffe2_ZOOM_SRCS ${GENERATED_CXX_TORCH_CUDA}) + + # TODO(Arham): disentangle this and build caffe2_hiprtc instead + hip_add_library(torch_zoom ${Caffe2_ZOOM_SRCS} ${ATen_HIPRTC_STUB_SRCS}) + set(CUDA_LINK_LIBRARIES_KEYWORD) + torch_compile_options(torch_zoom) # see cmake/public/utils.cmake elseif(USE_CUDA) set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) list(APPEND Caffe2_GPU_SRCS ${GENERATED_CXX_TORCH_CUDA}) @@ -1348,6 +1396,39 @@ if(USE_ROCM) endif() endif() +if(USE_ZOOM) + target_compile_definitions(torch_zoom PRIVATE + USE_ZOOM + __HIP_PLATFORM_AMD__ + ) + # NB: Massive hack. torch/csrc/jit/codegen/fuser/codegen.cpp includes + # torch/csrc/jit/codegen/fuser/cuda/resource_strings.h which changes the + # strings depending on if you're __HIP_PLATFORM_AMD__ or not. + # But that file is in torch_cpu! So, against all odds, this macro + # has to be set on torch_cpu too. 
I also added it to torch for + # better luck + target_compile_definitions(torch_cpu PRIVATE + USE_ZOOM + __HIP_PLATFORM_AMD__ + ) + target_compile_definitions(torch PRIVATE + USE_ZOOM + __HIP_PLATFORM_AMD__ + ) + + if(NOT ROCM_SOURCE_DIR) + set(ROCM_SOURCE_DIR "$ENV{ROCM_SOURCE_DIR}") + endif() + if($ROCM_SOURCE_DIR STREQUAL "") + set(ROCM_SOURCE_DIR "/opt/rocm") + endif() + message(INFO "caffe2 ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}") + target_include_directories(torch_zoom PRIVATE + ${ROCM_SOURCE_DIR}/include + ${ROCM_SOURCE_DIR}/hcc/include + ) +endif() + if(BUILD_LITE_INTERPRETER) target_compile_definitions(torch_cpu PRIVATE BUILD_LITE_INTERPRETER) # Enable template selective build only when SELECTED_OP_LIST is provided. @@ -1453,6 +1534,8 @@ if(USE_CUDA) target_compile_definitions(torch_cuda PRIVATE TORCH_CUDA_BUILD_MAIN_LIB) elseif(USE_ROCM) target_compile_definitions(torch_hip PRIVATE TORCH_HIP_BUILD_MAIN_LIB) +elseif(USE_ZOOM) + target_compile_definitions(torch_zoom PRIVATE TORCH_HIP_BUILD_MAIN_LIB) endif() if(USE_XPU) @@ -1546,6 +1629,8 @@ if(USE_CUDA) caffe2_interface_library(torch_cuda torch_cuda_library) elseif(USE_ROCM) caffe2_interface_library(torch_hip torch_hip_library) +elseif(USE_ZOOM) + caffe2_interface_library(torch_zoom torch_zoom_library) elseif(USE_XPU) caffe2_interface_library(torch_xpu torch_xpu_library) endif() @@ -1558,6 +1643,8 @@ if(USE_CUDA) install(TARGETS torch_cuda torch_cuda_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") elseif(USE_ROCM) install(TARGETS torch_hip torch_hip_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") +elseif(USE_ZOOM) + install(TARGETS torch_zoom torch_zoom_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") elseif(USE_XPU) install(TARGETS torch_xpu torch_xpu_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() @@ -1570,6 +1657,8 @@ if(USE_CUDA) target_link_libraries(torch PUBLIC torch_cuda_library) elseif(USE_ROCM) target_link_libraries(torch PUBLIC torch_hip_library) +elseif(USE_ZOOM) + target_link_libraries(torch PUBLIC torch_zoom_library) endif() if(USE_XPU) @@ -1715,6 +1804,43 @@ if(USE_ROCM) target_include_directories(torch_hip INTERFACE $) endif() +# ---[ Caffe2 ZOOM HIP sources. +if(USE_ZOOM) + # Call again since Caffe2_ZOOM_INCLUDE is extended with ATen include dirs. + # Get Compile Definitions from the directory (FindHIP.cmake bug) + get_directory_property(MY_DEFINITIONS COMPILE_DEFINITIONS) + if(MY_DEFINITIONS) + foreach(_item ${MY_DEFINITIONS}) + list(APPEND HIP_CLANG_FLAGS "-D${_item}") + endforeach() + endif() + + # Call again since Caffe2_ZOOM_INCLUDE is extended with ATen include dirs. + hip_include_directories(${Caffe2_ZOOM_INCLUDE}) + # Since PyTorch files contain HIP headers, these flags are required for the necessary definitions to be added. + target_compile_options(torch_zoom PUBLIC ${HIP_CXX_FLAGS}) # experiment + target_link_libraries(torch_zoom PUBLIC c10_zoom) + # target_link_libraries(torch_zoom PUBLIC c10) + + # this is where lib amdhip64 is actually linked (e.g. HIP symbols) + # should be included in c10_zoom + # target_link_libraries(torch_zoom PUBLIC ${PYTORCH_HIP_LIBRARIES}) + if(NOT INTERN_BUILD_MOBILE) + # TODO: Cut this over to ATEN_HIP_FILES_GEN_LIB. At the moment, we + # only generate CUDA files + # NB: This dependency must be PRIVATE, because we don't install + # ATEN_CUDA_FILES_GEN_LIB (it's a synthetic target just to get the + # correct dependency from generated files.) 
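+    # For Zoom the generated-files link below is left commented out for now;
+    # torch_zoom already picks up the ATen sources through Caffe2_ZOOM_SRCS
+    # (see the hip_add_library(torch_zoom ...) call above).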
+ #target_link_libraries(torch_zoom PRIVATE ATEN_ZOOM_FILES_GEN_LIB) + endif() + target_link_libraries(torch_zoom PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}) + target_link_libraries(torch_zoom PRIVATE ${Caffe2_ZOOM_DEPENDENCY_LIBS}) + + # Since PyTorch files contain HIP headers, this is also needed to capture the includes. + target_include_directories(torch_zoom PRIVATE ${Caffe2_ZOOM_INCLUDE}) + target_include_directories(torch_zoom INTERFACE $) +endif() + if(BUILD_STATIC_RUNTIME_BENCHMARK) add_subdirectory(${TORCH_ROOT}/benchmarks/static_runtime ${PROJECT_BINARY_DIR}/bin) add_executable(static_runtime_bench "${STATIC_RUNTIME_BENCHMARK_SRCS}") diff --git a/cmake/Caffe2Config.cmake.in b/cmake/Caffe2Config.cmake.in index c23b3990aff8a9..67771a17548e80 100644 --- a/cmake/Caffe2Config.cmake.in +++ b/cmake/Caffe2Config.cmake.in @@ -74,6 +74,10 @@ if (@USE_ROCM@) include("${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake") endif() +if (@USE_ZOOM@) + include("${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake") +endif() + if(@USE_CUDA@) # The file public/cuda.cmake exclusively uses CAFFE2_USE_*. # If Caffe2 was compiled with the libraries below, they must diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index b478f3cc2e1b08..f022db009f4673 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -239,6 +239,9 @@ if(INTERN_BUILD_ATEN_OPS) add_library(ATEN_CUDA_FILES_GEN_LIB INTERFACE) add_dependencies(ATEN_CPU_FILES_GEN_LIB ATEN_CPU_FILES_GEN_TARGET) add_dependencies(ATEN_CUDA_FILES_GEN_LIB ATEN_CUDA_FILES_GEN_TARGET) + + message(cuda_gen_headers="${cuda_generated_headers}") + message(cuda_gen_sources="${cuda_generated_sources}") if(USE_PER_OPERATOR_HEADERS) target_compile_definitions(ATEN_CPU_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a7e38ee73bcce5..e29c89479f9dad 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1155,7 +1155,7 @@ if(USE_CUDNN) endif() # ---[ HIP -if(USE_ROCM) +if(USE_ROCM OR USE_ZOOM) # This prevents linking in the libtinfo from /opt/conda/lib which conflicts with ROCm libtinfo. # Currently only active for Ubuntu 20.04 and greater versions. 
if(UNIX AND EXISTS "/etc/os-release") @@ -1184,7 +1184,12 @@ if(USE_ROCM) include(${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake) if(PYTORCH_FOUND_HIP) message(INFO "Compiling with HIP for AMD.") - caffe2_update_option(USE_ROCM ON) + if(USE_ROCM) + caffe2_update_option(USE_ROCM ON) + endif() + if(USE_ZOOM) + caffe2_update_option(USE_ZOOM ON) + endif() if(USE_NCCL AND NOT USE_SYSTEM_NCCL) message(INFO "Forcing USE_SYSTEM_NCCL to ON since it's required by using RCCL") @@ -1251,7 +1256,10 @@ if(USE_ROCM) message(STATUS "Disabling Kernel Assert for ROCm") endif() - include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake) + if(USE_ROCM) + include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake) + endif() + if(USE_CUDA) caffe2_update_option(USE_MEM_EFF_ATTENTION OFF) endif() diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index de64370b37a26f..8b67accd6254ac 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -6,7 +6,10 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch") ExternalProject_Add(aotriton_external GIT_REPOSITORY https://github.com/ROCm/aotriton.git - GIT_TAG 24a3fe9cb57e5cda3c923df29743f9767194cc27 + # Note (Arham): I changed this commit because the one in nod-ai was old and had some errors, + # in upstream pytorch this commit tag is determined by some CI actions that would be useful to copy + # in order to keep this working + GIT_TAG 04b5df8c8123f90cba3ede7e971e6fbc6040d506 SOURCE_DIR ${__AOTRITON_SOURCE_DIR} BINARY_DIR ${__AOTRITON_BUILD_DIR} PREFIX ${__AOTRITON_INSTALL_DIR} diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 09af98d0bc0666..31f6cbf8838852 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -180,6 +180,8 @@ function(caffe2_print_configuration_summary) message(STATUS " Private Dependencies : ${Caffe2_DEPENDENCY_LIBS}") message(STATUS " Public CUDA Deps. : ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}") message(STATUS " Private CUDA Deps. : ${Caffe2_CUDA_DEPENDENCY_LIBS}") + message(STATUS " Public ZOOM Deps. : ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}") + message(STATUS " Private ZOOM Deps. : ${Caffe2_ZOOM_DEPENDENCY_LIBS}") # coreml message(STATUS " USE_COREML_DELEGATE : ${USE_COREML_DELEGATE}") message(STATUS " BUILD_LAZY_TS_BACKEND : ${BUILD_LAZY_TS_BACKEND}") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index fa39156031ff36..107a6fbc15dac5 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -1,17 +1,20 @@ set(PYTORCH_FOUND_HIP FALSE) if(NOT DEFINED ENV{ROCM_PATH}) + message (WARNING "ROCM_PATH undefined, using ROCM_PATH=/opt/rocm") set(ROCM_PATH /opt/rocm) else() set(ROCM_PATH $ENV{ROCM_PATH}) endif() if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS}) + message (WARNING "ROCM_INCLUDE_DIRS undefined, using ROCM_INCLUDE_DIRS=$ROCM_PATH/include") set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include) else() set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS}) endif() if(NOT EXISTS ${ROCM_PATH}) + message(WARNING "$ROCM_PATH does not exist, failed to load HIP") return() endif() @@ -39,6 +42,7 @@ endmacro() # Find the HIP Package find_package_and_print_version(HIP 1.0) +message("HIP FOUND? 
-> " ${HIP_FOUND}) if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 8f879a8ecc783e..60b6038f7bb9be 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -146,6 +146,17 @@ if(USE_ROCM) list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB}) endif() +if(USE_ZOOM) + append_filelist("libtorch_python_zoom_sources" TORCH_PYTHON_SRCS) + # list(APPEND TORCH_PYTHON_SRCS ${GENERATED_THNN_CXX_CUDA}) + + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS + USE_ZOOM + __HIP_PLATFORM_AMD__ + ) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB}) +endif() + if(USE_XPU) include(${TORCH_ROOT}/cmake/public/xpu.cmake) append_filelist("libtorch_python_xpu_sources" TORCH_PYTHON_SRCS) @@ -342,6 +353,11 @@ if(USE_ROCM) set_source_files_properties(${TORCH_SRC_DIR}/csrc/cuda/Module.cpp PROPERTIES COMPILE_FLAGS "-DCUDA_ARCH_FLAGS=\"${PYTORCH_ROCM_ARCH_readable}\"") endif() +if(USE_ZOOM) + string(REPLACE ";" " " PYTORCH_ROCM_ARCH_readable "${PYTORCH_ROCM_ARCH}") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/zoom/Module.cpp PROPERTIES COMPILE_FLAGS "-DROCM_ARCH_FLAGS=\"${PYTORCH_ROCM_ARCH_readable}\"") +endif() + target_compile_definitions(torch_python PRIVATE "-DTHP_BUILD_MAIN_LIB") target_link_libraries(torch_python PRIVATE torch_library ${TORCH_PYTHON_LINK_LIBRARIES}) diff --git a/torch/__init__.py b/torch/__init__.py index 1dc0e9c8287fd4..263d9a28bc7802 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1580,6 +1580,7 @@ def _assert(condition, message): # the public API. The "regular" import lines are there solely for the runtime # side effect of adding to the imported module's members for other users. from torch import cuda as cuda +from torch import zoom as zoom from torch import cpu as cpu from torch import mps as mps from torch import xpu as xpu diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 040fbc825becdf..459da2348e2cac 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1459,7 +1459,7 @@ def _addmm_activation( ): out = addmm(self, mat1, mat2, beta, alpha) if use_gelu: - if self.is_cuda: + if self.is_cuda or self.is_zoom: return aten.gelu(out, approximate="tanh") else: return aten.gelu(out) @@ -2608,7 +2608,7 @@ def _index_copy( def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: min = torch.minimum(self.new_zeros(()), self) z = torch.exp(-torch.abs(self)) - if self.is_cuda: + if (self.is_cuda or self.is_zoom): buffer = self.new_zeros((0,)) else: buffer = z @@ -2853,7 +2853,7 @@ def _upsample_nearest( # following "heuristic: only use channels_last path when it's faster than the contiguous path" n_channels = input.shape[1] - if input.device.type == "cuda" and n_channels < 4: + if (input.device.type == "cuda" or input.device.type == "zoom") and n_channels < 4: memory_format = torch.contiguous_format result = result.contiguous(memory_format=memory_format) @@ -3686,7 +3686,7 @@ def get_values(inp_size, out_size, scales, nsqueeze): memory_format = utils.suggest_memory_format(input) # following "heuristic: only use channels_last path when it's faster than the contiguous path" - if input.device.type == "cuda" and n_channels < 16: + if (input.device.type == "cuda" or input.device.type == "zoom") and n_channels < 16: memory_format = torch.contiguous_format assert isinstance(result, torch.Tensor) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 9ff9131435f4cf..308fc9abe994e1 100644 --- a/torch/csrc/Module.cpp +++ 
b/torch/csrc/Module.cpp @@ -109,6 +109,10 @@ #endif #endif +#ifdef USE_ZOOM +#include +#endif + #ifdef USE_DISTRIBUTED #ifdef USE_C10D #include @@ -1528,6 +1532,13 @@ void initModule(PyObject* module); } // namespace torch::cuda #endif +#ifdef USE_ZOOM +PyMethodDef* THCPModule_methods(); +namespace torch::zoom { +void initModule(PyObject* module); +} // namespace torch::zoom +#endif + #ifdef USE_XPU PyMethodDef* THXPModule_methods(); void THXPStream_init(PyObject* module); @@ -1596,6 +1607,9 @@ PyObject* initModule() { #ifdef USE_CUDA THPUtils_addPyMethodDefs(methods, THCPModule_methods()); #endif +#ifdef USE_ZOOM + THPUtils_addPyMethodDefs(methods, THCPModule_methods()); +#endif #ifdef USE_XPU THPUtils_addPyMethodDefs(methods, THXPModule_methods()); #endif @@ -1659,6 +1673,9 @@ PyObject* initModule() { #ifdef USE_CUDA torch::cuda::initModule(module); #endif +#ifdef USE_ZOOM + torch::zoom::initModule(module); +#endif #ifdef USE_XPU torch::xpu::initModule(module); #endif @@ -1677,6 +1694,16 @@ PyObject* initModule() { THCPGraph_init(module); #endif +#ifdef USE_ZOOM + // This will only initialise base classes and attach them to library namespace + // They won't be ready for real usage until importing cuda module, that will + // complete the process (but it defines Python classes before calling back + // into C, so these lines have to execute first).. + THCPStream_init(module); + THCPEvent_init(module); + THCPGraph_init(module); +#endif + #ifdef USE_XPU THXPStream_init(module); THXPEvent_init(module); @@ -1697,7 +1724,7 @@ PyObject* initModule() { return ret == 0; }; -#if defined(USE_CUDNN) || defined(USE_ROCM) +#if defined(USE_CUDNN) || (defined(USE_ROCM) && !defined(USE_ZOOM)) PyObject* has_cudnn = Py_True; #else PyObject* has_cudnn = Py_False; @@ -2067,6 +2094,12 @@ Call this whenever a new thread is created in order to propagate values from PyObject* has_cuda = Py_False; #endif +#ifdef USE_ZOOM + PyObject* has_zoom = Py_True; +#else + PyObject* has_zoom = Py_False; +#endif + #ifdef USE_MPS PyObject* has_mps = Py_True; #else @@ -2080,6 +2113,7 @@ Call this whenever a new thread is created in order to propagate values from #endif ASSERT_TRUE(set_module_attr("_has_cuda", has_cuda)); + ASSERT_TRUE(set_module_attr("_has_zoom", has_zoom)); ASSERT_TRUE( set_module_attr("_has_magma", at::hasMAGMA() ? Py_True : Py_False)); ASSERT_TRUE(set_module_attr("_has_mps", has_mps)); diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index fdcafd6cd70910..09bb02cabeaf48 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -454,10 +454,11 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { at::Device self_device = self_.device(); Variable value; // TODO: This qint special case looks very suspicious... 
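+  // Zoom (PrivateUse1) reuses the CUDA branch below: the Python value is
+  // turned into a CPU tensor first (note the at::Device(kCPU) argument) and
+  // the copy onto the device happens later in the indexing path, instead of
+  // constructing the value tensor directly on the accelerator.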
+ // TODO(Arham): exchange keys if (isQIntType(self_.scalar_type())) { value = valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU)); - } else if (self_device.is_cuda()) { + } else if (self_device.is_cuda() || self_device.is_privateuseone()) { value = valueToTensor(self_.options(), py_value, at::Device(kCPU)); } else { value = valueToTensor(self_.options(), py_value, self_device); diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 8d18180ed91955..b1fb4f80ae59c6 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -35,6 +35,7 @@ struct PyTensorType { THPDtype* dtype; THPLayout* layout; bool is_cuda; + bool is_zoom; bool is_xpu; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) char name[64]; @@ -130,6 +131,15 @@ static PyObject* Tensor_is_cuda(PyTensorType* self, void* unused) { } } +static PyObject* Tensor_is_zoom(PyTensorType* self, void* unused) { + if (self->is_zoom) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + + static PyObject* Tensor_is_xpu(PyTensorType* self, void* unused) { if (self->is_xpu) { Py_RETURN_TRUE; @@ -166,6 +176,7 @@ static struct PyGetSetDef metaclass_properties[] = { {"dtype", (getter)Tensor_dtype, nullptr, nullptr, nullptr}, {"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr}, {"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr}, + {"is_zoom", (getter)Tensor_is_zoom, nullptr, nullptr, nullptr}, {"is_xpu", (getter)Tensor_is_xpu, nullptr, nullptr, nullptr}, {"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr}, {"is_sparse_csr", (getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr}, @@ -247,6 +258,9 @@ static void set_type( type_obj.dtype = (THPDtype*)Py_NewRef(torch::getTHPDtype(scalarType)); type_obj.is_cuda = (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA); + // TODO(Arham): exchange keys + type_obj.is_zoom = + (backend == at::Backend::PrivateUse1 || backend == at::Backend::SparsePrivateUse1); type_obj.is_xpu = (backend == at::Backend::XPU || backend == at::Backend::SparseXPU); } diff --git a/torch/csrc/zoom/Event.cpp b/torch/csrc/zoom/Event.cpp new file mode 100644 index 00000000000000..f07f6e2954c0e3 --- /dev/null +++ b/torch/csrc/zoom/Event.cpp @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +PyObject* THCPEventClass = nullptr; + +static PyObject* THCPEvent_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + unsigned char enable_timing = 0; + unsigned char blocking = 0; + unsigned char interprocess = 0; + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + constexpr const char* kwlist[] = { + "enable_timing", "blocking", "interprocess", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwargs, + "|bbb", + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(kwlist), + &enable_timing, + &blocking, + &interprocess)) { + return nullptr; + } + + THPObjectPtr ptr(type->tp_alloc(type, 0)); + if (!ptr) { + return nullptr; + } + + THCPEvent* self = (THCPEvent*)ptr.get(); + unsigned int flags = (blocking ? hipEventBlockingSync : hipEventDefault) | + (enable_timing ? hipEventDefault : hipEventDisableTiming) | + (interprocess ? 
hipEventInterprocess : hipEventDefault); + + new (&self->zoom_event) at::zoom::ZoomEvent(flags); + + return (PyObject*)ptr.release(); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_from_ipc_handle( + PyObject* _type, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + auto type = (PyTypeObject*)_type; + + static torch::PythonArgParser parser({ + "from_ipc_handle(Device device, std::string ipc_handle)", + }); + torch::ParsedArgs<2> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + at::Device device = r.device(0); + std::string handle_string = r.string(1); + + TORCH_CHECK( + handle_string.size() == sizeof(hipIpcEventHandle_t), + "hipIpcEventHandle_t expects byte-like object of size ", + sizeof(hipIpcEventHandle_t), + ", but got ", + handle_string.size()); + TORCH_CHECK( + device.type() == at::kPrivateUse1, + "Event can only be created on " + "Zoom devices, but got device type ", + device.type()) + + THPObjectPtr ptr(type->tp_alloc(type, 0)); + if (!ptr) { + return nullptr; + } + THCPEvent* self = (THCPEvent*)ptr.get(); + + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + hipIpcEventHandle_t handle; + std::memcpy(&handle, handle_string.c_str(), handle_string.size()); + new (&self->zoom_event) at::zoom::ZoomEvent(device.index(), &handle); + + return (PyObject*)ptr.release(); + END_HANDLE_TH_ERRORS +} + +static void THCPEvent_dealloc(THCPEvent* self) { + { + pybind11::gil_scoped_release no_gil{}; + self->zoom_event.~ZoomEvent(); + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static PyObject* THCPEvent_get_zoom_event(THCPEvent* self, void* unused) { + HANDLE_TH_ERRORS + return PyLong_FromVoidPtr(self->zoom_event.event()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_get_device(THCPEvent* self, void* unused) { + HANDLE_TH_ERRORS + at::optional device = self->zoom_event.device(); + if (!device) { + Py_RETURN_NONE; + } + return THPDevice_New(device.value()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_record(PyObject* _self, PyObject* _stream) { + HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; + auto stream = (THCPStream*)_stream; + self->zoom_event.record(stream->zoom_stream); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_wait(PyObject* _self, PyObject* _stream) { + HANDLE_TH_ERRORS { + auto self = (THCPEvent*)_self; + auto stream = (THCPStream*)_stream; + pybind11::gil_scoped_release no_gil{}; + self->zoom_event.block(stream->zoom_stream); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_query(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; + return PyBool_FromLong(self->zoom_event.query()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_elapsed_time(PyObject* _self, PyObject* _other) { + HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; + auto other = (THCPEvent*)_other; + return PyFloat_FromDouble(self->zoom_event.elapsed_time(other->zoom_event)); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_synchronize(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS { + auto self = (THCPEvent*)_self; + pybind11::gil_scoped_release no_gil{}; + self->zoom_event.synchronize(); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPEvent_ipc_handle(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + hipIpcEventHandle_t handle; + self->zoom_event.ipc_handle(&handle); + return PyBytes_FromStringAndSize((const 
char*)&handle, sizeof(handle)); + END_HANDLE_TH_ERRORS +} + +// NOLINTNEXTLINE(*c-arrays*, *global-variables) +static struct PyGetSetDef THCPEvent_properties[] = { + {"device", (getter)THCPEvent_get_device, nullptr, nullptr, nullptr}, + {"zoom_event", (getter)THCPEvent_get_zoom_event, nullptr, nullptr, nullptr}, + {nullptr}}; + +// NOLINTNEXTLINE(*c-arrays*, *global-variables) +static PyMethodDef THCPEvent_methods[] = { + {(char*)"from_ipc_handle", + castPyCFunctionWithKeywords(THCPEvent_from_ipc_handle), + METH_CLASS | METH_VARARGS | METH_KEYWORDS, + nullptr}, + {(char*)"record", THCPEvent_record, METH_O, nullptr}, + {(char*)"wait", THCPEvent_wait, METH_O, nullptr}, + {(char*)"query", THCPEvent_query, METH_NOARGS, nullptr}, + {(char*)"elapsed_time", THCPEvent_elapsed_time, METH_O, nullptr}, + {(char*)"synchronize", THCPEvent_synchronize, METH_NOARGS, nullptr}, + {(char*)"ipc_handle", THCPEvent_ipc_handle, METH_NOARGS, nullptr}, + {nullptr}}; + +PyTypeObject THCPEventType = { + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._ZoomEventBase", /* tp_name */ + sizeof(THCPEvent), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THCPEvent_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THCPEvent_methods, /* tp_methods */ + nullptr, /* tp_members */ + THCPEvent_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THCPEvent_pynew, /* tp_new */ +}; + +void THCPEvent_init(PyObject* module) { + THCPEventClass = (PyObject*)&THCPEventType; + if (PyType_Ready(&THCPEventType) < 0) { + throw python_error(); + } + Py_INCREF(&THCPEventType); + if (PyModule_AddObject(module, "_ZoomEventBase", (PyObject*)&THCPEventType) < + 0) { + throw python_error(); + } +} diff --git a/torch/csrc/zoom/Event.h b/torch/csrc/zoom/Event.h new file mode 100644 index 00000000000000..6f10c28f86f84d --- /dev/null +++ b/torch/csrc/zoom/Event.h @@ -0,0 +1,18 @@ +#ifndef THCP_EVENT_INC +#define THCP_EVENT_INC + +#include +#include + +struct THCPEvent { + PyObject_HEAD at::zoom::ZoomEvent zoom_event; +}; +extern PyObject* THCPEventClass; + +void THCPEvent_init(PyObject* module); + +inline bool THCPEvent_Check(PyObject* obj) { + return THCPEventClass && PyObject_IsInstance(obj, THCPEventClass); +} + +#endif // THCP_EVENT_INC diff --git a/torch/csrc/zoom/Graph.cpp b/torch/csrc/zoom/Graph.cpp new file mode 100644 index 00000000000000..4d95f871f5af11 --- /dev/null +++ b/torch/csrc/zoom/Graph.cpp @@ -0,0 +1,91 @@ +#include + +#include + +#include +#include + +#include +#include + +// Cargo culted partially from csrc/distributed/c10d/init.cpp +// and partially from csrc/zoom/Stream.cpp. +// THCPStream_init is also declared at global scope. 
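+//
+// Rough shape of the capture/replay flow these bindings expose, on the C++
+// side (a sketch: only HIPGraph, capture_begin, capture_end, and replay are
+// taken from this file; the stream setup around it is assumed):
+//
+//   at::zoom::HIPGraph graph;
+//   graph.capture_begin({0, 0}, hipStreamCaptureModeGlobal);
+//   // ... enqueue work on the capturing stream ...
+//   graph.capture_end();
+//   graph.replay();   // re-launches the captured work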
+ +// Because THCPGraph_init is forward declared in the only consumer +// (csrc/Module.cpp) I don't think we need a Graph.h. + +template +using shared_ptr_class_ = py::class_>; + +void THCPGraph_init(PyObject* module) { + // Pybind11 patch notes say "py::module_" is more up-to-date syntax, + // but CI linter and some builds prefer "module". + auto torch_C_m = py::handle(module).cast(); + + torch_C_m.def("_graph_pool_handle", &::at::zoom::graph_pool_handle); + + shared_ptr_class_<::at::zoom::HIPGraph>(torch_C_m, "_HIPGraph") + .def(py::init<>()) + .def( + "capture_begin", + [](::at::zoom::HIPGraph& self, + std::optional pool_opt, + std::string capture_error_mode) { + hipStreamCaptureMode capture_mode; + c10::zoom::MempoolId_t pool = pool_opt.has_value() + ? pool_opt.value() + : c10::zoom::MempoolId_t{0, 0}; + if (capture_error_mode == "global") { + capture_mode = hipStreamCaptureModeGlobal; + } else if (capture_error_mode == "thread_local") { + capture_mode = hipStreamCaptureModeThreadLocal; + } else if (capture_error_mode == "relaxed") { + capture_mode = hipStreamCaptureModeRelaxed; + } else { + TORCH_CHECK( + false, + "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ", + capture_error_mode); + } + return self.capture_begin(pool, capture_mode); + }, + py::arg("pool"), + py::arg("capture_error_mode"), + py::call_guard()) + .def( + "capture_end", + torch::wrap_pybind_function_no_gil(&at::zoom::HIPGraph::capture_end)) + .def( + "register_generator_state", + [](::at::zoom::HIPGraph& self, py::handle raw_generator) { + auto generator = THPGenerator_Unwrap(raw_generator.ptr()); + // We've unwrapped Python object to C++ object, + // so we could release GIL before calling into C++ + py::gil_scoped_release release; + return self.register_generator_state(generator); + }, + py::arg("generator")) + .def( + "replay", + torch::wrap_pybind_function_no_gil(&at::zoom::HIPGraph::replay)) + .def( + "reset", + torch::wrap_pybind_function_no_gil(&at::zoom::HIPGraph::reset)) + .def( + "pool", + torch::wrap_pybind_function_no_gil(&at::zoom::HIPGraph::pool)) + .def( + "debug_dump", + torch::wrap_pybind_function_no_gil( + &::at::zoom::HIPGraph::debug_dump)) + .def( + "enable_debug_mode", + torch::wrap_pybind_function_no_gil( + &::at::zoom::HIPGraph::enable_debug_mode)) + .def( + "debug_dump", + torch::wrap_pybind_function_no_gil( + &::at::zoom::HIPGraph::debug_dump), + py::arg("debug_path")); +} diff --git a/torch/csrc/zoom/Module.cpp b/torch/csrc/zoom/Module.cpp new file mode 100644 index 00000000000000..7a0470fad0613e --- /dev/null +++ b/torch/csrc/zoom/Module.cpp @@ -0,0 +1,1533 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +// #include +#include +// #include +#include +#include +#include +#include +#include +#include + +#include +#include + +// #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef WIN32 +#include +#endif + +using namespace torch; + +static bool in_bad_fork = false; // True for children forked after zoom init + +#ifndef WIN32 +// Called in the forked child if zoom has already been initialized +static void forked_child() { + in_bad_fork = true; + torch::utils::set_requires_device_init(at::kPrivateUse1, true); +} +#endif + +// Should be called before the first zoom call. 
+// Note: This is distinct from initExtension because a stub zoom implementation +// has some working functions (e.g. device_count) but cannot fully initialize. +static void poison_fork() { +#ifndef WIN32 + static c10::once_flag flag; + c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); +#endif +} + +//////////////////////////////////////////////////////////////////////////////// +// Zoom management methods +//////////////////////////////////////////////////////////////////////////////// + +PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice"); + auto device = THPUtils_unpackLong(arg); + + torch::utils::device_lazy_init(at::kPrivateUse1); + c10::zoom::set_device(static_cast(device)); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_exchangeDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); + auto device_index = THPUtils_unpackDeviceIndex(arg); + if (device_index < 0) { + return THPUtils_packInt32(-1); + } + + torch::utils::device_lazy_init(at::kPrivateUse1); + auto current_device = c10::zoom::ExchangeDevice(device_index); + + return THPUtils_packDeviceIndex(current_device); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_maybeExchangeDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); + auto device_index = THPUtils_unpackDeviceIndex(arg); + if (device_index < 0) { + return THPUtils_packInt32(-1); + } + + torch::utils::device_lazy_init(at::kPrivateUse1); + auto current_device = c10::zoom::MaybeExchangeDevice(device_index); + + return THPUtils_packDeviceIndex(current_device); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getDevice_wrap(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + torch::utils::device_lazy_init(at::kPrivateUse1); + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + auto device = static_cast(c10::zoom::current_device()); + return THPUtils_packInt32(device); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + PyObject* arg1 = nullptr; + PyObject* arg2 = nullptr; + if (!PyArg_ParseTuple(args, "OO", &arg1, &arg2)) { + THPUtils_invalidArguments( + args, + nullptr, + "can_device_peer_access", + 1, + "(int device, int peer_device);"); + return nullptr; + } + TORCH_CHECK( + THPUtils_checkLong(arg1), "invalid argument to canDeviceAccessPeer"); + TORCH_CHECK( + THPUtils_checkLong(arg2), "invalid argument to canDeviceAccessPeer"); + int64_t device = THPUtils_unpackLong(arg1); + int64_t peer_device = THPUtils_unpackLong(arg2); + + torch::utils::device_lazy_init(at::kPrivateUse1); + auto can_access = at::zoom::canDeviceAccessPeer(device, peer_device); + return PyBool_FromLong(can_access); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + poison_fork(); + return THPUtils_packUInt64(c10::zoom::device_count()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + poison_fork(); +#ifdef ROCM_ARCH_FLAGS + static const char* flags = C10_STRINGIZE(ROCM_ARCH_FLAGS); + return THPUtils_packString(flags); +#else + Py_RETURN_NONE; +#endif + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) { + 
HANDLE_TH_ERRORS + return PyBool_FromLong(in_bad_fork); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getCurrentStream_wrap( + PyObject* /* unused */, + PyObject* device_index) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(device_index), "invalid argument to getCurrentStream"); + auto c10_device_index = THPUtils_unpackDeviceIndex(device_index); + auto stream = c10::zoom::getCurrentZoomStream(c10_device_index); + PyObject* output_tuple = PyTuple_New(3); + PyTuple_SetItem( + output_tuple, 0, THPUtils_packInt64(static_cast(stream.id()))); + PyTuple_SetItem( + output_tuple, 1, THPUtils_packDeviceIndex(stream.device_index())); + PyTuple_SetItem( + output_tuple, + 2, + THPUtils_packInt64(static_cast(stream.device_type()))); + return output_tuple; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getCurrentStream_raw( + PyObject* /* unused */, + PyObject* device_index) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(device_index), "invalid argument to getCurrentStream"); + auto c10_device_index = THPUtils_unpackDeviceIndex(device_index); + return PyLong_FromVoidPtr( + c10::zoom::getCurrentZoomStream(c10_device_index).stream()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getDefaultStream_wrap( + PyObject* /* unused */, + PyObject* device_index) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(device_index), "invalid argument to getDefaultStream"); + auto c10_device_index = THPUtils_unpackDeviceIndex(device_index); + auto stream = c10::zoom::getDefaultZoomStream(c10_device_index); + PyObject* output_tuple = PyTuple_New(3); + PyTuple_SetItem( + output_tuple, 0, THPUtils_packInt64(static_cast(stream.id()))); + PyTuple_SetItem( + output_tuple, 1, THPUtils_packDeviceIndex(stream.device_index())); + PyTuple_SetItem( + output_tuple, + 2, + THPUtils_packInt64(static_cast(stream.device_type()))); + return output_tuple; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_setStream_wrap( + PyObject* self, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + int64_t stream_id = 0; + int64_t device_index = 0; + int64_t device_type = 0; + + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + constexpr const char* kwlist[] = { + "stream_id", "device_index", "device_type", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwargs, + "|LLL", + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(kwlist), + &stream_id, + &device_index, + &device_type)) { + } + + auto stream = c10::zoom::ZoomStream::unpack3( + stream_id, + static_cast(device_index), + static_cast(device_type)); + + auto device = c10::zoom::current_device(); + if (device != stream.device_index()) { + c10::zoom::set_device(stream.device_index()); + } + c10::zoom::setCurrentZoomStream(stream); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getCompiledVersion(PyObject* self, PyObject* noargs) { + return THPUtils_packInt64((int64_t)ROCM_VERSION); +} + +PyObject* THCPModule_zoomHostAllocator(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS + c10::Allocator* allocator = at::zoom::getCachingHostAllocator(); + return PyLong_FromVoidPtr(allocator); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_zoomCachingAllocator_raw_alloc( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* size_o = nullptr; + PyObject* stream_o = nullptr; + if (!PyArg_ParseTuple(args, "OO", &size_o, &stream_o)) { + THPUtils_invalidArguments( + args, + nullptr, + "caching_allocator_alloc", + 1, + "(ssize_t size, intptr_t stream);"); + return 
nullptr; + } + auto size = PyLong_AsSsize_t(size_o); + hipStream_t stream = static_cast(PyLong_AsVoidPtr(stream_o)); + void* mem = nullptr; + { + pybind11::gil_scoped_release no_gil; + mem = c10::zoom::ZoomCachingAllocator::raw_alloc_with_stream(size, stream); + } + return PyLong_FromVoidPtr(mem); + END_HANDLE_TH_ERRORS +} + +// Unpack a PyObject to at::Scalar, throw an exception if it fails +at::Scalar as_scalar(PyObject* arg) { + // Zero-dim tensors are converted to Scalars as-is. Note this doesn't + // currently handle most NumPy scalar types except np.float64. + if (THPVariable_Check(arg)) { + return THPVariable_Unpack(arg).item(); + } + + if (THPUtils_checkLong(arg)) { + return at::Scalar(static_cast(THPUtils_unpackLong(arg))); + } + + if (PyBool_Check(arg)) { + return at::Scalar(THPUtils_unpackBool(arg)); + } + + if (PyComplex_Check(arg)) { + return at::Scalar(THPUtils_unpackComplexDouble(arg)); + } + return at::Scalar(THPUtils_unpackDouble(arg)); +} + +// Entrypoint for the callable created by torch.zoom.jiterator +// See jiterator.py for more details +// PyObject* THCPModule_zoomJiteratorCompileAndLaunchKernel( +// PyObject* _unused, +// PyObject* args) { +// HANDLE_TH_ERRORS + +// PyObject* code_string_o = nullptr; +// PyObject* kernel_name_o = nullptr; +// PyObject* return_by_ref_o = nullptr; +// PyObject* num_outputs_o = nullptr; +// PyObject* tensors_o = nullptr; +// PyObject* kwargs_o = nullptr; +// if (!PyArg_ParseTuple( +// args, +// "OOOOO|O", +// &code_string_o, +// &kernel_name_o, +// &return_by_ref_o, +// &num_outputs_o, +// &tensors_o, +// &kwargs_o)) { +// return nullptr; +// } + +// const std::string code_string = THPUtils_unpackString(code_string_o); +// const std::string kernel_name = THPUtils_unpackString(kernel_name_o); +// const bool return_by_ref = THPUtils_unpackBool(return_by_ref_o); +// const int num_outputs = static_cast(THPUtils_unpackLong(num_outputs_o)); + +// TORCH_CHECK( +// PyTuple_Check(tensors_o), +// "tensors argument is expected to " +// "be a tuple, but got ", +// THPUtils_typename(tensors_o)); +// Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors_o); + +// c10::SmallVector tensors; +// for (const auto i : c10::irange(num_tensors)) { +// PyObject* _tensor = PyTuple_GET_ITEM(tensors_o, i); +// TORCH_CHECK( +// THPVariable_Check(_tensor), +// i, +// " of input tensors tuple is not a Tensor"); + +// tensors.emplace_back(THPVariable_Unpack(_tensor)); +// } + +// c10::SmallVector extra_args; +// PyObject* key = nullptr; +// PyObject* value = nullptr; +// Py_ssize_t pos = 0; +// while (PyDict_Next(kwargs_o, &pos, &key, &value)) { +// extra_args.emplace_back(as_scalar(value)); +// } + +// c10::SmallVector outputs = at::zoom::CompileAndLaunchKernel( +// code_string, +// kernel_name, +// num_outputs, +// tensors, +// extra_args, +// return_by_ref); + +// if (num_outputs == 1) { +// return THPVariable_Wrap(outputs[0]); +// } else { +// PyObject* output_tuple = PyTuple_New(num_outputs); +// for (int i = 0; i < num_outputs; ++i) { +// PyTuple_SetItem(output_tuple, i, THPVariable_Wrap(outputs[i])); +// } +// return output_tuple; +// } + +// END_HANDLE_TH_ERRORS +// } + +PyObject* THCPModule_zoomCachingAllocator_raw_delete( + PyObject* _unused, + PyObject* obj) { + HANDLE_TH_ERRORS + void* mem_ptr = PyLong_AsVoidPtr(obj); + { + pybind11::gil_scoped_release no_gil; + c10::zoom::ZoomCachingAllocator::raw_delete(mem_ptr); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_zoomCachingAllocator_set_allocator_settings( + PyObject* _unused, + 
PyObject* env) {
+  HANDLE_TH_ERRORS
+  c10::zoom::ZoomCachingAllocator::setAllocatorSettings(
+      THPUtils_unpackString(env));
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  return THPUtils_packString(c10::zoom::ZoomCachingAllocator::name());
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject* THCPModule_zoomSynchronize(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS {
+    pybind11::gil_scoped_release no_gil;
+    c10::zoom::device_synchronize();
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// PyObject* THCPModule_zoomIPCCollect(PyObject* _unused, PyObject* noargs) {
+//   HANDLE_TH_ERRORS
+//   torch::zoomIPCCollect();
+//   Py_RETURN_NONE;
+//   END_HANDLE_TH_ERRORS
+// }
+
+// PyObject* THCPModule_zoomSleep(PyObject* _unused, PyObject* cycles) {
+//   HANDLE_TH_ERRORS
+//   TORCH_CHECK(
+//       THPUtils_checkLong(cycles), "torch.zoom._sleep(): expected 'int'");
+//   int64_t unpacked_cycles = THPUtils_unpackLong(cycles);
+//   {
+//     pybind11::gil_scoped_release no_gil;
+//     at::zoom::sleep(unpacked_cycles);
+//   }
+//   Py_RETURN_NONE;
+//   END_HANDLE_TH_ERRORS
+// }
+
+// We need to ensure that a thread will NEVER lose the GIL as long as it holds
+// the zoom mutex. Otherwise another thread might be scheduled and try to e.g.
+// allocate a new tensor, which will cause a deadlock. It's enough to have a
+// single global, because it can only be set once (zoomMutex is not recursive)
+// by the thread that owns the mutex (obviously there can be only one such
+// thread).
+static PyGILState_STATE zoomMutexGILState;
+
+PyObject* THCPModule_zoomLockMutex(PyObject* module, PyObject* noargs) {
+  auto mutex = c10::zoom::getFreeMutex();
+  // This has to be a busy loop because we **absolutely need to** hold the GIL,
+  // otherwise it's a recipe for a deadlock (if we let other Python threads
+  // run while we have the zoomMutex, but not the GIL, they might try to e.g.
+  // free a Zoom tensor and acquire the zoomMutex without giving up the GIL,
+  // because it happens deep within THC).
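+  // The loop below polls the mutex: each failed try_lock drops the GIL for a
+  // short sleep so other Python threads can make progress, then retries with
+  // the GIL re-acquired. Once the mutex is held, PyGILState_Ensure records the
+  // GIL state that THCPModule_zoomUnlockMutex later releases.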
+ while (true) { + if (mutex->try_lock()) + break; + { + pybind11::gil_scoped_release no_gil; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + } + + zoomMutexGILState = PyGILState_Ensure(); + Py_RETURN_NONE; +} + +PyObject* THCPModule_zoomUnlockMutex(PyObject* module, PyObject* noargs) { + auto mutex = c10::zoom::getFreeMutex(); + PyGILState_Release(zoomMutexGILState); + mutex->unlock(); + Py_RETURN_NONE; +} + +PyObject* THCPModule_hasPrimaryContext(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), "invalid argument to has_primary_context"); + auto device_index = THPUtils_unpackDeviceIndex(arg); + if (c10::zoom::hasPrimaryContext(device_index)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_setMemoryFraction(PyObject* _unused, PyObject* args) { + HANDLE_TH_ERRORS + PyObject* fraction_o = nullptr; + PyObject* device_o = nullptr; + if (!PyArg_ParseTuple(args, "OO", &fraction_o, &device_o)) { + THPUtils_invalidArguments( + args, + nullptr, + "set_memory_fraction", + 1, + "(double fraction, int device);"); + return nullptr; + } + double fraction = PyFloat_AsDouble(fraction_o); + auto device_index = THPUtils_unpackDeviceIndex(device_o); + + c10::zoom::ZoomCachingAllocator::setMemoryFraction(fraction, device_index); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + +PyObject* THCPModule_emptyCache(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS + c10::zoom::ZoomCachingAllocator::emptyCache(); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + +PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + + using c10::zoom::ZoomCachingAllocator::DeviceStats; + using c10::zoom::ZoomCachingAllocator::Stat; + using c10::zoom::ZoomCachingAllocator::StatArray; + using c10::zoom::ZoomCachingAllocator::StatType; + + const auto statToDict = [](const Stat& stat) { + py::dict dict; + + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto statArrayToDict = [=](const StatArray& statArray) { + const std::array(StatType::NUM_TYPES)> + statTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(statTypeNames.size())) { + dict[statTypeNames[i]] = statToDict(statArray[i]); + } + return dict; + }; + + const DeviceStats stats = + c10::zoom::ZoomCachingAllocator::getDeviceStats(device_index); + + py::dict result; + result["num_alloc_retries"] = stats.num_alloc_retries; + result["num_ooms"] = stats.num_ooms; + result["max_split_size"] = stats.max_split_size; + result["num_sync_all_streams"] = stats.num_sync_all_streams; + result["num_device_alloc"] = stats.num_device_alloc; + result["num_device_free"] = stats.num_device_free; + result["allocation"] = statArrayToDict(stats.allocation); + result["segment"] = statArrayToDict(stats.segment); + result["active"] = statArrayToDict(stats.active); + result["inactive_split"] = statArrayToDict(stats.inactive_split); + result["allocated_bytes"] = statArrayToDict(stats.allocated_bytes); + result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes); + result["active_bytes"] = statArrayToDict(stats.active_bytes); + result["inactive_split_bytes"] = statArrayToDict(stats.inactive_split_bytes); + result["requested_bytes"] = 
statArrayToDict(stats.requested_bytes); + result["oversize_allocations"] = statToDict(stats.oversize_allocations); + result["oversize_segments"] = statToDict(stats.oversize_segments); + + return result.release().ptr(); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_resetAccumulatedMemoryStats( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "invalid argument to reset_accumulated_memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + c10::zoom::ZoomCachingAllocator::resetAccumulatedStats(device_index); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + +PyObject* THCPModule_resetPeakMemoryStats(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + c10::zoom::ZoomCachingAllocator::resetPeakStats(device_index); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + +CapturedTraceback* getFromContext( + const std::shared_ptr& x) { + if (CapturedTraceback* sc = dynamic_cast(x.get())) { + return sc; + } + TORCH_CHECK( + false, + "attempting to gather stack context from the wrong StackContext type."); +} + +PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS + + using c10::zoom::ZoomCachingAllocator::BlockInfo; + using c10::zoom::ZoomCachingAllocator::SegmentInfo; + + py::str device_s = "device"; + py::str address_s = "address"; + py::str total_size_s = "total_size"; + py::str allocated_size_s = "allocated_size"; + py::str active_size_s = "active_size"; + py::str requested_size_s = "requested_size"; + py::str stream_s = "stream"; + py::str segment_type_s = "segment_type"; + py::str segment_pool_id = "segment_pool_id"; + py::str large_s = "large"; + py::str small_s = "small"; + py::str size_s = "size"; + py::str state_s = "state"; + py::str active_allocated_s = "active_allocated"; + py::str active_pending_free_s = "active_pending_free"; + py::str inactive_s = "inactive"; + py::str addr_s = "addr"; + py::str cpp_frames_s = "cpp_frames"; + py::str blocks_s = "blocks"; + py::str is_expandable_s = "is_expandable"; + py::str frames_s = "frames"; + py::str time_us_s = "time_us"; + + py::list empty_frames; + std::vector to_gather_frames; + std::vector to_gather_dest; + + auto add_frame_key = [&](const py::dict& d, + const std::shared_ptr& ctx) { + if (ctx) { + auto sc = getFromContext(ctx); + to_gather_frames.emplace_back(sc); + to_gather_dest.emplace_back(d); + } else { + d[frames_s] = empty_frames; + } + }; + + const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) { + py::dict segmentDict; + segmentDict[device_s] = segmentInfo.device; + segmentDict[address_s] = segmentInfo.address; + segmentDict[total_size_s] = segmentInfo.total_size; + segmentDict[allocated_size_s] = segmentInfo.allocated_size; + segmentDict[active_size_s] = segmentInfo.active_size; + segmentDict[requested_size_s] = segmentInfo.requested_size; + // we want the python objects to pickle easily so use an int to + // represent the stream rather than a torch.zoom.stream object + segmentDict[stream_s] = int64_t(segmentInfo.stream); + segmentDict[segment_type_s] = (segmentInfo.is_large ? 
large_s : small_s); + segmentDict[segment_pool_id] = segmentInfo.owner_private_pool_id; + segmentDict[is_expandable_s] = segmentInfo.is_expandable; + add_frame_key(segmentDict, segmentInfo.context_when_allocated); + + auto address = segmentInfo.address; + py::list blocks; + for (const auto& blockInfo : segmentInfo.blocks) { + py::dict blockDict; + blockDict[address_s] = address; + blockDict[size_s] = blockInfo.size; + blockDict[requested_size_s] = blockInfo.requested_size; + blockDict[state_s] = + (blockInfo.allocated + ? active_allocated_s + : (blockInfo.active ? active_pending_free_s : inactive_s)); + add_frame_key(blockDict, blockInfo.context_when_allocated); + blocks.append(blockDict); + address += blockInfo.size; + } + segmentDict[blocks_s] = blocks; + + return segmentDict; + }; + + auto snapshot = c10::zoom::ZoomCachingAllocator::snapshot(); + + py::list segments; + + for (const auto& segmentInfo : snapshot.segments) { + segments.append(segmentInfoToDict(segmentInfo)); + } + + py::list traces; + py::str action_s = "action"; + py::str alloc_s = "alloc"; + py::str free_requested_s = "free_requested"; + py::str free_completed_s = "free_completed"; + py::str segment_alloc_s = "segment_alloc"; + py::str segment_free_s = "segment_free"; + py::str segment_map_s = "segment_map"; + py::str segment_unmap_s = "segment_unmap"; + + py::str snapshot_s = "snapshot"; + py::str oom_s = "oom"; + py::str device_free_s = "device_free"; + + using namespace c10::zoom::ZoomCachingAllocator; + + auto action_to_str = [&](TraceEntry::Action action) { + switch (action) { + case TraceEntry::ALLOC: + return alloc_s; + case TraceEntry::FREE_REQUESTED: + return free_requested_s; + case TraceEntry::FREE_COMPLETED: + return free_completed_s; + case TraceEntry::SEGMENT_ALLOC: + return segment_alloc_s; + case TraceEntry::SEGMENT_FREE: + return segment_free_s; + case TraceEntry::OOM: + return oom_s; + case TraceEntry::SNAPSHOT: + return snapshot_s; + case TraceEntry::SEGMENT_UNMAP: + return segment_unmap_s; + case TraceEntry::SEGMENT_MAP: + return segment_map_s; + } + throw std::runtime_error("unreachable"); + }; + + for (const auto& traceInfo : snapshot.device_traces) { + py::list trace; + for (const auto& te : traceInfo) { + py::dict trace_entry; + if (te.context_) { + // without further compression frames can get really large on dump + auto sc = getFromContext(te.context_); + to_gather_frames.emplace_back(sc); + to_gather_dest.emplace_back(trace_entry); + } + trace_entry[action_s] = action_to_str(te.action_); + trace_entry[TraceEntry::OOM == te.action_ ? 
device_free_s : addr_s] = + te.addr_; + trace_entry[size_s] = te.size_; + trace_entry[stream_s] = int64_t(te.stream_); + trace_entry[time_us_s] = te.time_.t_; + trace.append(trace_entry); + } + traces.append(trace); + } + + py::dict allocator_settings; + py::str last_allocator_settings_s = "PYTORCH_ZOOM_ALLOC_CONF"; + py::str max_split_size_s = "max_split_size"; + py::str garbage_collection_threshold_s = "garbage_collection_threshold"; + py::str expandable_segments_s = "expandable_segments"; + py::str pinned_num_register_threads_s = "pinned_num_register_threads"; + py::str release_lock_on_malloc_s = "release_lock_on_hipMalloc"; + py::str pinned_use_host_register_s = "pinned_use_zoom_host_register"; + py::str roundup_power2_divisions_s = "roundup_power2_divisions"; + + allocator_settings[last_allocator_settings_s] = + snapshot.config_metadata.last_allocator_settings; + allocator_settings[max_split_size_s] = + int64_t(snapshot.config_metadata.max_split_size); + allocator_settings[garbage_collection_threshold_s] = + snapshot.config_metadata.garbage_collection_threshold; + allocator_settings[expandable_segments_s] = + snapshot.config_metadata.expandable_segments; + allocator_settings[pinned_num_register_threads_s] = + int64_t(snapshot.config_metadata.pinned_num_register_threads); + allocator_settings[release_lock_on_malloc_s] = + snapshot.config_metadata.release_lock_on_malloc; + allocator_settings[pinned_use_host_register_s] = + snapshot.config_metadata.pinned_use_host_register; + unsigned int roundup_key = 1; + py::dict roundup_settings; + for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) { + py::str roundup_key_s = std::to_string(roundup_key); + roundup_settings[roundup_key_s] = int64_t(v); + roundup_key *= 2; + } + allocator_settings[roundup_power2_divisions_s] = roundup_settings; + + py::dict result; + result["segments"] = segments; + result["device_traces"] = traces; + result["allocator_settings"] = allocator_settings; + + auto frames = py_symbolize(to_gather_frames); + for (auto i : c10::irange(frames.size())) { + to_gather_dest.at(i)[frames_s] = frames.at(i); + } + + return result.release().ptr(); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_attachOutOfMemoryObserver( + PyObject* _unused, + PyObject* observer) { + HANDLE_TH_ERRORS + Py_XINCREF(observer); + auto obs = [observer]( + int64_t device, + int64_t alloc, + int64_t device_allocated, + int64_t device_free) { + py::gil_scoped_acquire g; + PyObject* result = PyObject_CallFunction( + observer, "LLLL", device, alloc, device_allocated, device_free); + if (!result) { + throw py::error_already_set(); + } + Py_XDECREF(result); + }; + at::globalContext().lazyInitPrivateUse1(); + c10::zoom::ZoomCachingAllocator::attachOutOfMemoryObserver(std::move(obs)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_zoomSetSyncDebugMode(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_WARN_ONCE( + "Synchronization debug mode is a prototype feature and does not yet detect all " + "synchronizing operations"); + TORCH_CHECK( + THPUtils_checkLong(arg), "invalid argument to set_sync_debug_mode"); + int64_t debug_mode = THPUtils_unpackLong(arg); + TORCH_CHECK( + debug_mode >= 0 && debug_mode <= 2, + "invalid value of debug_mode, expected one of 0,1,2"); + c10::zoom::SyncDebugMode l = c10::zoom::SyncDebugMode::L_DISABLED; + switch (debug_mode) { + case 0: + l = c10::zoom::SyncDebugMode::L_DISABLED; + break; + case 1: + l = c10::zoom::SyncDebugMode::L_WARN; + break; + case 2: + l = 
c10::zoom::SyncDebugMode::L_ERROR; + break; + default: + break; // can't happen + } + c10::zoom::warning_state().set_sync_debug_mode(l); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_zoomGetSyncDebugMode(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + auto debug_mode = c10::zoom::warning_state().get_sync_debug_mode(); + switch (debug_mode) { + case c10::zoom::SyncDebugMode::L_DISABLED: + return THPUtils_packInt32(0); + case c10::zoom::SyncDebugMode::L_WARN: + return THPUtils_packInt32(1); + case c10::zoom::SyncDebugMode::L_ERROR: + return THPUtils_packInt32(2); + default: + return THPUtils_packInt32(-1); // can't happen + } + END_HANDLE_TH_ERRORS +} + +//////////////////////////////////////////////////////////////////////////////// +// Zoom module initialization +//////////////////////////////////////////////////////////////////////////////// + +static void registerZoomDeviceProperties(PyObject* module) { + // Add the _ZoomDeviceProperties class (wrapping hipDeviceProp_t) to torch._C + auto m = py::handle(module).cast<py::module>(); + py::class_<hipDeviceProp_t>(m, "_ZoomDeviceProperties") + .def_readonly("name", &hipDeviceProp_t::name) + .def_readonly("major", &hipDeviceProp_t::major) + .def_readonly("minor", &hipDeviceProp_t::minor) + .def_readonly("is_multi_gpu_board", &hipDeviceProp_t::isMultiGpuBoard) + .def_readonly("is_integrated", &hipDeviceProp_t::integrated) + .def_readonly( + "multi_processor_count", &hipDeviceProp_t::multiProcessorCount) + .def_readonly("total_memory", &hipDeviceProp_t::totalGlobalMem) + .def_readonly( + "max_threads_per_multi_processor", + &hipDeviceProp_t::maxThreadsPerMultiProcessor) + .def_readonly( + "gcnArchName", + &hipDeviceProp_t::gcnArchName + ) + .def("__repr__", [](const hipDeviceProp_t& prop) { + std::ostringstream stream; + stream << "_ZoomDeviceProperties(name='" << prop.name + << "', major=" << prop.major << ", minor=" << prop.minor + << ", gcnArchName='" << prop.gcnArchName << "'" + << ", total_memory=" << prop.totalGlobalMem / (1024ull * 1024) + << "MB, multi_processor_count=" << prop.multiProcessorCount + << ")"; + return stream.str(); + }); + + // m.def( + // "_zoom_record_memory_history_legacy", + // static_cast( + // torch::zoom::_record_memory_history)); + + // m.def( + // "_zoom_record_memory_history", + // static_cast, + // std::optional, + // const std::string&, + // size_t)>(torch::zoom::_record_memory_history)); + + m.def("_zoom_isHistoryEnabled", []() { + return c10::zoom::ZoomCachingAllocator::isHistoryEnabled(); + }); + + // m.def("_zoom_get_conv_benchmark_empty_cache", []() { + // return at::native::_cudnn_get_conv_benchmark_empty_cache(); + // }); + + // m.def("_cudnn_set_conv_benchmark_empty_cache", [](bool enable) { + // return at::native::_cudnn_set_conv_benchmark_empty_cache(enable); + // }); +} + +// We choose to ignore certain blocks that are currently allocated +// when we set the pool to its checkpoint. For those blocks, we need +// to swap out the deleter function of their corresponding blocks +// so that a deallocation is not triggered when they die.
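THCPModule_memoryStats above returns a nested dict: per-field stats ("allocated_bytes", "reserved_bytes", ...) keyed first by pool ("all", "small_pool", "large_pool") and then by counter ("current", "peak", "allocated", "freed"). A rough consumption sketch, assuming the binding ends up on torch._C like its CUDA counterpart:

import torch

stats = torch._C._zoom_memoryStats(0)  # stats for device index 0
current_allocated = stats["allocated_bytes"]["all"]["current"]
peak_reserved = stats["reserved_bytes"]["large_pool"]["peak"]
print(current_allocated, peak_reserved, stats["num_ooms"])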
+void removeStorageDeleterFns( + const std::vector& stale_live_storages, + std::unordered_set definitely_stale_pointers) { + for (c10::StorageImpl* stale_storage : stale_live_storages) { + auto ptr = stale_storage->data_ptr().get(); + auto allocated_pointer = definitely_stale_pointers.find(ptr); + TORCH_CHECK(allocated_pointer != definitely_stale_pointers.end()); + auto t = c10::zoom::ZoomCachingAllocator::get(); + bool succeeded = stale_storage->mutable_data_ptr().compare_exchange_deleter( + t->raw_deleter(), &c10::detail::deleteNothing); + + TORCH_CHECK( + succeeded, + "Unexpected deleter function on storage, could not swap function"); + } +} + +void addStorageDeleterFns( + std::vector& storages_to_add_deleters_to, + c10::zoom::ZoomCachingAllocator::CheckpointDelta& delta) { + std::unordered_map storages; + for (auto& storage : storages_to_add_deleters_to) { + storages[storage->data_ptr().get()] = storage; + } + + for (auto& data_ptr : delta.dataptrs_allocd) { + auto storage_pair = storages.find(data_ptr.get()); + if (storage_pair != storages.end()) { + auto ctx = storage_pair->second->data_ptr().get_context(); + TORCH_CHECK(ctx == nullptr, " Not expecting deleter function"); + storage_pair->second->set_data_ptr_noswap(std::move(data_ptr)); + } else { + data_ptr.release_context(); + } + } +} + +static void registerZoomPluggableAllocator(PyObject* module) { + auto m = py::handle(module).cast(); + + // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + c10::zoom::ZoomCachingAllocator::ZoomAllocator, + std::shared_ptr>( + m, "_zoom_ZoomAllocator"); + m.def("_zoom_getAllocator", []() { + return py::cast(torch::zoom::ZoomPluggableAllocator::getCurrentAllocator()); + }); + + m.def( + "_zoom_changeCurrentAllocator", + [](const std::shared_ptr& + allocator) { + torch::zoom::ZoomPluggableAllocator::changeCurrentAllocator(allocator); + }); + py::class_< + torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator, + c10::zoom::ZoomCachingAllocator::ZoomAllocator, + std::shared_ptr< + torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator>>( + m, "_ZoomPluggableAllocator") + .def( + "set_init_fn", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(func_ptr); + self.set_init_fn(func); + }) + .def( + "set_reset_fn", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(func_ptr); + self.set_reset_fn(func); + }) + .def( + "set_memory_fraction_fn", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(double, int); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(func_ptr); + self.set_memory_fraction_fn(func); + }) + .def( + "set_base_alloc_fn", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void*(void*, size_t*); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(func_ptr); + self.set_base_alloc_fn(func); + }) + .def( + "set_record_stream_fn", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(void*, hipStream_t); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + 
reinterpret_cast(func_ptr); + self.set_record_stream_fn(func); + }) + .def( + "set_begin_allocate_to_pool", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void( + int, c10::zoom::MempoolId_t, std::function); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(func_ptr); + self.set_begin_allocate_to_pool(func); + }) + .def( + "set_end_allocate_to_pool_fn", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int, c10::zoom::MempoolId_t); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(func_ptr); + self.set_end_allocate_to_pool_fn(func); + }) + .def( + "set_release_pool", + [](torch::zoom::ZoomPluggableAllocator::ZoomPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int, c10::zoom::MempoolId_t); + std::function func = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(func_ptr); + self.set_release_pool(func); + }); + m.def("_zoom_customAllocator", [](uint64_t malloc_ptr, uint64_t free_ptr) { + using MallocFuncType = void*(size_t, int, hipStream_t); + using FreeFuncType = void(void*, size_t, int, hipStream_t); + std::function malloc_fn = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(malloc_ptr); + std::function free_fn = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(free_ptr); + return torch::zoom::ZoomPluggableAllocator::createCustomAllocator( + malloc_fn, free_fn); + }); + + // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + c10::zoom::ZoomCachingAllocator::AllocatorState, + std::shared_ptr>( + m, "_zoom_ZoomAllocator_AllocatorState"); + + m.def( + "_zoom_getCheckpointState", + [](c10::DeviceIndex device, c10::zoom::MempoolId_t id) { + return c10::zoom::ZoomCachingAllocator::getCheckpointState(device, id); + }); + + m.def("_free_And_Remove_DeleterFn", [](size_t storage_impl_ptr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; + auto alloc = c10::zoom::ZoomCachingAllocator::get(); + auto data_ptr = storage_impl->data_ptr().get(); + bool succeeded = storage_impl->mutable_data_ptr().compare_exchange_deleter( + alloc->raw_deleter(), c10::detail::deleteNothing); + TORCH_CHECK(succeeded, "Expected standard deleter"); + c10::zoom::ZoomCachingAllocator::raw_delete(data_ptr); + }); + + m.def( + "_set_storage_access_error_msg", [](const at::Tensor& t, std::string s) { + t.unsafeGetTensorImpl() + ->release_storage_and_set_meta_custom_data_ptr_error_msg_(s); + }); + + m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; + auto alloc = c10::zoom::ZoomCachingAllocator::get(); + return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter()); + }); + + m.def("_set_cached_tensors_enabled", [](bool enabled) { + at::caching::set_cached_tensors_enabled(enabled); + }); + + m.def("_add_cached_tensor", [](const at::Tensor& t) { + at::caching::add_cached_tensor(t); + }); + + m.def("_remove_cached_tensor", [](const at::Tensor& t) { + at::caching::remove_cached_tensor(t); + }); + + m.def("_is_cached_tensor", [](const at::Tensor& t) { + return at::caching::is_cached_tensor(t); + }); + + m.def("_storage_Use_Count", [](size_t storage_impl_ptr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + c10::StorageImpl* storage_impl = 
(c10::StorageImpl*)storage_impl_ptr; + return c10::raw::weak_intrusive_ptr::use_count(storage_impl); + }); + + m.def( + "_tensors_data_ptrs_at_indices_equal", + [](py::list& tensors, py::list& data_ptrs, py::list& indices) { + for (size_t i = 0, end = indices.size(); i < end; ++i) { + auto index = indices[i].cast(); + auto t = tensors[index].cast(); + auto data_ptr = data_ptrs[index].cast(); + if (reinterpret_cast(t.data_ptr()) != data_ptr) { + return false; + } + } + return true; + }); + + m.def( + "_construct_Zoom_Tensor_From_Storage_And_Metadata", + [](py::dict& metadata, c10::Storage s) { + auto dtype_arg = metadata["dtype"].ptr(); + auto meta = scalarTypeToTypeMeta(toScalarType(dtype_arg)); + + constexpr c10::DispatchKeySet zoom_dks(c10::DispatchKey::PrivateUse1); + at::Tensor tensor = at::detail::make_tensor_base( + std::move(s), zoom_dks, meta); + + tensor.unsafeGetTensorImpl()->set_sizes_and_strides( + metadata["size"].cast>(), + metadata["stride"].cast>()); + tensor.unsafeGetTensorImpl()->set_storage_offset( + metadata["storage_offset"].cast()); + return tensor; + }); + + m.def( + "_zoom_beginAllocateCurrentStreamToPool", + [](c10::DeviceIndex device, c10::zoom::MempoolId_t mempool_id) { + auto stream = c10::zoom::getCurrentZoomStream(device); + TORCH_CHECK(stream, "Expected stream capture to be under way"); + c10::zoom::ZoomCachingAllocator::beginAllocateToPool( + device, mempool_id, [stream](hipStream_t target) { + return target == stream; + }); + }); + + m.def( + "_zoom_endAllocateCurrentStreamToPool", + [](c10::DeviceIndex device, c10::zoom::MempoolId_t mempool_id) { + c10::zoom::ZoomCachingAllocator::endAllocateToPool(device, mempool_id); + }); + + m.def( + "_zoom_releasePool", + [](c10::DeviceIndex device, c10::zoom::MempoolId_t mempool_id) { + c10::zoom::ZoomCachingAllocator::releasePool(device, mempool_id); + }); + + m.def( + "_zoom_checkPoolLiveAllocations", + [](c10::DeviceIndex device, + c10::zoom::MempoolId_t mempool_id, + const py::set& expected_live_allocations) { + std::unordered_set allocations; + allocations.reserve(expected_live_allocations.size()); + for (auto& elem : expected_live_allocations) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + allocations.insert(reinterpret_cast(py::cast(elem))); + } + return c10::zoom::ZoomCachingAllocator::checkPoolLiveAllocations( + device, mempool_id, allocations); + }); + + m.def( + "_zoom_setCheckpointPoolState", + [](c10::DeviceIndex device, + std::shared_ptr pps, + const std::vector& stale_storages_ptr, + const std::vector& storages_to_add_deleters_to_ptr = {}) { + std::unordered_set ptr_set; + // iterate on std::vector for determinism + std::vector ptrs; + for (size_t ptr_int : stale_storages_ptr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + c10::StorageImpl* ptr = (c10::StorageImpl*)ptr_int; + if (!ptr_set.count(ptr)) { + ptrs.push_back(ptr); + ptr_set.insert(ptr); + } + } + auto delta = c10::zoom::ZoomCachingAllocator::setCheckpointPoolState( + device, std::move(pps)); + auto& freed_pointers = delta.ptrs_freed; + + std::unordered_set allocd_set; + for (auto& data_ptr : delta.dataptrs_allocd) { + allocd_set.insert(data_ptr.get()); + } + std::unordered_set freed_pointer_set; + size_t definite_freed_count = 0; + for (void* ptr : freed_pointers) { + if (!allocd_set.count(ptr)) { + definite_freed_count += 1; + } + freed_pointer_set.insert((ptr)); + } + // that block has already been freed, + // so even those this will error, so too will the allocator + // when the corresponding tensor dies because there is no + // 
live tensor corresponding to it + TORCH_CHECK( + ptr_set.size() >= definite_freed_count, + "Any stale tensors which are being manually freed" + " must be passed to set checkpoint"); + + removeStorageDeleterFns(ptrs, freed_pointer_set); + std::vector storages_to_add_deleters_to; + storages_to_add_deleters_to.reserve( + storages_to_add_deleters_to_ptr.size()); + for (size_t ptr_int : storages_to_add_deleters_to_ptr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + storages_to_add_deleters_to.push_back((c10::StorageImpl*)ptr_int); + } + + addStorageDeleterFns(storages_to_add_deleters_to, delta); + }); +} + +static void bindGetDeviceProperties(PyObject* module) { + // Add method to torch.zoom + auto m = py::handle(module).cast(); + m.def( + "_get_device_properties", + [](c10::DeviceIndex device) -> hipDeviceProp_t* { + return at::zoom::getDeviceProperties(device); + }, + py::return_value_policy::reference); +} + +// Callback for python part. Used for additional initialization of python +// classes +static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { +#if C10_ASAN_ENABLED + TORCH_WARN( + "torch.zoom: your pytorch binary has address sanitizer (asan) built in, " + "asan is currently not compatible with torch.zoom module, " + "you might get unexpected behavior (eg. out of memory, crash, etc.), " + "please rebuild pytorch without asan if you need to use this module"); +#endif + HANDLE_TH_ERRORS + TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level + poison_fork(); + at::globalContext().lazyInitPrivateUse1(); + + auto m = THPObjectPtr(PyImport_ImportModule("torch.zoom")); + if (!m) + throw python_error(); + + auto set_module_attr = [&](const char* name, PyObject* v) { + // PyObject_SetAttrString doesn't steal reference. So no need to incref. + if (PyObject_SetAttrString(m, name, v) < 0) { + throw python_error(); + } + }; + + auto num_gpus = c10::zoom::device_count(); + auto default_zoom_generators = PyTuple_New(static_cast(num_gpus)); + for (const auto i : c10::irange(num_gpus)) { + auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator( + at::zoom::detail::getDefaultZoomGenerator(i)); + // This reference is meant to be given away, so no need to incref here. + PyTuple_SetItem(default_zoom_generators, i, (PyObject*)cast_gen); + } + set_module_attr("default_generators", default_zoom_generators); + bindGetDeviceProperties(m); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_getCurrentBlasHandle_wrap( + PyObject* self, + PyObject* noargs) { + HANDLE_TH_ERRORS + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + hipblasHandle_t handle = at::zoom::getCurrentHIPBlasHandle(); + return PyLong_FromVoidPtr(handle); + END_HANDLE_TH_ERRORS +} + + +// PyObject* THCPModule_rocm_is_backward_pass( +// PyObject* _unused, +// PyObject* noargs) { +// HANDLE_TH_ERRORS +// #if USE_ROCM +// if (at::ROCmBackwardPassGuard::is_backward_pass()) { +// Py_RETURN_TRUE; +// } else { +// Py_RETURN_FALSE; +// } +// #else +// Py_RETURN_FALSE; +// #endif +// END_HANDLE_TH_ERRORS +// } + +static PyObject* THCPModule_isCurrentStreamCapturing_wrap( + PyObject* self, + PyObject* noargs) { + HANDLE_TH_ERRORS + // If there's no zoom context, at::zoom::currentStreamCaptureStatus returns + // CaptureStatus::None without initializing a context. 
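The _zoom_customAllocator and _zoom_changeCurrentAllocator bindings registered in registerZoomPluggableAllocator above take raw function-pointer addresses, so the natural Python entry point is ctypes. A hypothetical sketch; "alloc.so", "my_malloc" and "my_free" are placeholders for a user-built library matching the MallocFuncType/FreeFuncType signatures shown in the binding, and the torch._C location is an assumption:

import ctypes
import torch

# Placeholder library exporting:
#   void* my_malloc(size_t size, int device, hipStream_t stream);
#   void  my_free(void* ptr, size_t size, int device, hipStream_t stream);
lib = ctypes.CDLL("alloc.so")
malloc_ptr = ctypes.cast(lib.my_malloc, ctypes.c_void_p).value
free_ptr = ctypes.cast(lib.my_free, ctypes.c_void_p).value

allocator = torch._C._zoom_customAllocator(malloc_ptr, free_ptr)
# Must run before the default caching allocator is initialized;
# changeCurrentAllocator refuses to swap an already-initialized allocator.
torch._C._zoom_changeCurrentAllocator(allocator)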
+ if (at::zoom::currentStreamCaptureStatus() == at::zoom::CaptureStatus::None) { + Py_RETURN_FALSE; + } else { + Py_RETURN_TRUE; + } + END_HANDLE_TH_ERRORS +} + +// NOLINTNEXTLINE(*-c-arrays*, *-global-variables) +static struct PyMethodDef _THCPModule_methods[] = { + {"_zoom_init", THCPModule_initExtension, METH_NOARGS, nullptr}, + {"_zoom_setDevice", THCPModule_setDevice_wrap, METH_O, nullptr}, + {"_zoom_exchangeDevice", THCPModule_exchangeDevice, METH_O, nullptr}, + {"_zoom_maybeExchangeDevice", + THCPModule_maybeExchangeDevice, + METH_O, + nullptr}, + {"_zoom_getDevice", THCPModule_getDevice_wrap, METH_NOARGS, nullptr}, + {"_zoom_getDeviceCount", + THCPModule_getDeviceCount_wrap, + METH_NOARGS, + nullptr}, + {"_zoom_canDeviceAccessPeer", + THCPModule_canDeviceAccessPeer_wrap, + METH_VARARGS, + nullptr}, + {"_zoom_getArchFlags", THCPModule_getArchFlags, METH_NOARGS, nullptr}, + {"_zoom_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr}, + {"_zoom_getCurrentStream", + THCPModule_getCurrentStream_wrap, + METH_O, + nullptr}, + {"_zoom_getCurrentRawStream", + THCPModule_getCurrentStream_raw, + METH_O, + nullptr}, + {"_zoom_getDefaultStream", + THCPModule_getDefaultStream_wrap, + METH_O, + nullptr}, + {"_zoom_getCurrentBlasHandle", + THCPModule_getCurrentBlasHandle_wrap, + METH_NOARGS, + nullptr}, + {"_zoom_isCurrentStreamCapturing", + THCPModule_isCurrentStreamCapturing_wrap, + METH_NOARGS, + nullptr}, + {"_zoom_setStream", + castPyCFunctionWithKeywords(THCPModule_setStream_wrap), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_zoom_getCompiledVersion", + THCPModule_getCompiledVersion, + METH_NOARGS, + nullptr}, + {"_zoom_hasPrimaryContext", THCPModule_hasPrimaryContext, METH_O, nullptr}, + {"_zoom_setMemoryFraction", + THCPModule_setMemoryFraction, + METH_VARARGS, + nullptr}, + {"_zoom_emptyCache", THCPModule_emptyCache, METH_NOARGS, nullptr}, + {"_zoom_memoryStats", THCPModule_memoryStats, METH_O, nullptr}, + {"_zoom_resetAccumulatedMemoryStats", + THCPModule_resetAccumulatedMemoryStats, + METH_O, + nullptr}, + {"_zoom_resetPeakMemoryStats", + THCPModule_resetPeakMemoryStats, + METH_O, + nullptr}, + {"_zoom_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr}, + {"_zoom_attach_out_of_memory_observer", + THCPModule_attachOutOfMemoryObserver, + METH_O, + nullptr}, + {"_zoom_zoomHostAllocator", + THCPModule_zoomHostAllocator, + METH_NOARGS, + nullptr}, + {"_zoom_zoomCachingAllocator_raw_alloc", + THCPModule_zoomCachingAllocator_raw_alloc, + METH_VARARGS, + nullptr}, + {"_zoom_zoomCachingAllocator_raw_delete", + THCPModule_zoomCachingAllocator_raw_delete, + METH_O, + nullptr}, + {"_zoom_zoomCachingAllocator_set_allocator_settings", + THCPModule_zoomCachingAllocator_set_allocator_settings, + METH_O, + nullptr}, + {"_zoom_getAllocatorBackend", + THCPModule_getAllocatorBackend, + METH_NOARGS, + nullptr}, + {"_zoom_synchronize", THCPModule_zoomSynchronize, METH_NOARGS, nullptr}, + // {"_zoom_ipc_collect", THCPModule_zoomIPCCollect, METH_NOARGS, nullptr}, + // {"_zoom_sleep", THCPModule_zoomSleep, METH_O, nullptr}, + {"_zoom_lock_mutex", THCPModule_zoomLockMutex, METH_NOARGS, nullptr}, + {"_zoom_unlock_mutex", THCPModule_zoomUnlockMutex, METH_NOARGS, nullptr}, + {"_zoom_set_sync_debug_mode", + THCPModule_zoomSetSyncDebugMode, + METH_O, + nullptr}, + {"_zoom_get_sync_debug_mode", + THCPModule_zoomGetSyncDebugMode, + METH_NOARGS, + nullptr}, + // {"_zoom_jiterator_compile_and_launch_kernel", + // THCPModule_zoomJiteratorCompileAndLaunchKernel, + // METH_VARARGS, + // nullptr}, 
+ // {"_rocm_is_backward_pass", + // THCPModule_rocm_is_backward_pass, + // METH_NOARGS, + // nullptr}, + {nullptr}}; + +PyMethodDef* THCPModule_methods() { + return _THCPModule_methods; +} + +namespace torch::zoom { + +namespace shared { + +void initHiprtBindings(PyObject* module); +// void initNvtxBindings(PyObject* module); +// #if defined(USE_CUDNN) || defined(USE_ROCM) +// void initCudnnBindings(PyObject* module); +// #endif + +} // namespace shared + +void initModule(PyObject* module) { +// python::initCommMethods(module); +// // As weird as it seems, this file is also compiled for ROCm, +// // so this condition might not always be true... + shared::initHiprtBindings(module); +// shared::initNvtxBindings(module); +// #if defined(USE_CUDNN) || defined(USE_ROCM) +// shared::initCudnnBindings(module); +// #endif + registerZoomDeviceProperties(module); + registerZoomPluggableAllocator(module); +} + +} // namespace torch::zoom diff --git a/torch/csrc/zoom/Module.h b/torch/csrc/zoom/Module.h new file mode 100644 index 00000000000000..2553dad7c616a8 --- /dev/null +++ b/torch/csrc/zoom/Module.h @@ -0,0 +1,11 @@ +#ifndef THCP_ZOOM_MODULE_INC +#define THCP_ZOOM_MODULE_INC + +PyObject* THCPModule_getDevice_wrap(PyObject* self); +PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg); +PyObject* THCPModule_getDeviceName_wrap(PyObject* self, PyObject* arg); +PyObject* THCPModule_getDriverVersion(PyObject* self); +PyObject* THCPModule_isDriverSufficient(PyObject* self); +PyObject* THCPModule_getCurrentBlasHandle_wrap(PyObject* self); + +#endif diff --git a/torch/csrc/zoom/Stream.cpp b/torch/csrc/zoom/Stream.cpp new file mode 100644 index 00000000000000..bd14fed218e431 --- /dev/null +++ b/torch/csrc/zoom/Stream.cpp @@ -0,0 +1,216 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +PyObject* THCPStreamClass = nullptr; + +static PyObject* THCPStream_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + + const auto current_device = c10::zoom::current_device(); + + int priority = 0; + int64_t stream_id = 0; + int64_t device_index = 0; + int64_t device_type = 0; + uint64_t stream_ptr = 0; + + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + constexpr const char* kwlist[] = { + "priority", + "stream_id", + "device_index", + "device_type", + "stream_ptr", + nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwargs, + "|iLLLK", + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(kwlist), + &priority, + &stream_id, + &device_index, + &device_type, + &stream_ptr)) { + return nullptr; + } + + THPObjectPtr ptr(type->tp_alloc(type, 0)); + if (!ptr) { + return nullptr; + } + + if (stream_ptr) { + TORCH_CHECK( + priority == 0, "Priority was explicitly set for a external stream") + } + c10::zoom::ZoomStream stream = (stream_id || device_index || device_type) + ? c10::zoom::ZoomStream::unpack3( + stream_id, + static_cast(device_index), + static_cast(device_type)) + : stream_ptr ? 
c10::zoom::getStreamFromExternal( + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(stream_ptr), + current_device) + : c10::zoom::getStreamFromPool(priority); + + THCPStream* self = (THCPStream*)ptr.get(); + self->stream_id = static_cast(stream.id()); + self->device_index = static_cast(stream.device_index()); + self->device_type = static_cast(stream.device_type()); + new (&self->zoom_stream) c10::zoom::ZoomStream(stream); + + return (PyObject*)ptr.release(); + END_HANDLE_TH_ERRORS +} + +static void THCPStream_dealloc(THCPStream* self) { + self->zoom_stream.~ZoomStream(); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static PyObject* THCPStream_get_device(THCPStream* self, void* unused) { + HANDLE_TH_ERRORS + return THPDevice_New(self->zoom_stream.device()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPStream_get_zoom_stream(THCPStream* self, void* unused) { + HANDLE_TH_ERRORS + return PyLong_FromVoidPtr(self->zoom_stream.stream()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPStream_get_priority(THCPStream* self, void* unused) { + HANDLE_TH_ERRORS + return THPUtils_packInt64(self->zoom_stream.priority()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPStream_priority_range( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + auto [least_priority, greatest_priority] = + c10::zoom::ZoomStream::priority_range(); + return Py_BuildValue("(ii)", least_priority, greatest_priority); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPStream_query(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS + auto self = (THCPStream*)_self; + return PyBool_FromLong(self->zoom_stream.query()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPStream_synchronize(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS { + pybind11::gil_scoped_release no_gil; + auto self = (THCPStream*)_self; + self->zoom_stream.synchronize(); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPStream_eq(PyObject* _self, PyObject* _other) { + HANDLE_TH_ERRORS + auto self = (THCPStream*)_self; + auto other = (THCPStream*)_other; + return PyBool_FromLong(self->zoom_stream == other->zoom_stream); + END_HANDLE_TH_ERRORS +} + +// NOLINTNEXTLINE(*-c-arrays*, *-global-variables) +static struct PyMemberDef THCPStream_members[] = {{nullptr}}; + +// NOLINTNEXTLINE(*-c-arrays*, *-global-variables) +static struct PyGetSetDef THCPStream_properties[] = { + {"zoom_stream", + (getter)THCPStream_get_zoom_stream, + nullptr, + nullptr, + nullptr}, + {"priority", (getter)THCPStream_get_priority, nullptr, nullptr, nullptr}, + {nullptr}}; + +// NOLINTNEXTLINE(*-c-arrays*, *-global-variables) +static PyMethodDef THCPStream_methods[] = { + {"query", THCPStream_query, METH_NOARGS, nullptr}, + {"synchronize", THCPStream_synchronize, METH_NOARGS, nullptr}, + {"priority_range", + THCPStream_priority_range, + METH_STATIC | METH_NOARGS, + nullptr}, + {"__eq__", THCPStream_eq, METH_O, nullptr}, + {nullptr}}; + +PyTypeObject THCPStreamType = { + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._ZoomStreamBase", /* tp_name */ + sizeof(THCPStream), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THCPStream_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + 
nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THCPStream_methods, /* tp_methods */ + THCPStream_members, /* tp_members */ + THCPStream_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THCPStream_pynew, /* tp_new */ +}; + +void THCPStream_init(PyObject* module) { + Py_INCREF(THPStreamClass); + THCPStreamType.tp_base = THPStreamClass; + THCPStreamClass = (PyObject*)&THCPStreamType; + if (PyType_Ready(&THCPStreamType) < 0) { + throw python_error(); + } + Py_INCREF(&THCPStreamType); + if (PyModule_AddObject( + module, "_ZoomStreamBase", (PyObject*)&THCPStreamType) < 0) { + throw python_error(); + } +} diff --git a/torch/csrc/zoom/Stream.h b/torch/csrc/zoom/Stream.h new file mode 100644 index 00000000000000..3799abcbe09df9 --- /dev/null +++ b/torch/csrc/zoom/Stream.h @@ -0,0 +1,20 @@ +#ifndef THCP_STREAM_INC +#define THCP_STREAM_INC + +#include +#include +#include + +// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +struct THCPStream : THPStream { + c10::zoom::ZoomStream zoom_stream; +}; +extern PyObject* THCPStreamClass; + +void THCPStream_init(PyObject* module); + +inline bool THCPStream_Check(PyObject* obj) { + return THCPStreamClass && PyObject_IsInstance(obj, THCPStreamClass); +} + +#endif // THCP_STREAM_INC diff --git a/torch/csrc/zoom/THCP.h b/torch/csrc/zoom/THCP.h new file mode 100644 index 00000000000000..c66359b3364908 --- /dev/null +++ b/torch/csrc/zoom/THCP.h @@ -0,0 +1,10 @@ +#ifndef THCP_H +#define THCP_H + +#include +#include +#include +#include +#include + +#endif diff --git a/torch/csrc/zoom/Tensor.cpp b/torch/csrc/zoom/Tensor.cpp new file mode 100644 index 00000000000000..ea97dfea288d78 --- /dev/null +++ b/torch/csrc/zoom/Tensor.cpp @@ -0,0 +1,15 @@ +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif + +// Order of these includes matters, which should be fixed. +// clang-format off +#include +#include + +#include + +#include +#include +#include +// clang-format on diff --git a/torch/csrc/zoom/ZoomPluggableAllocator.cpp b/torch/csrc/zoom/ZoomPluggableAllocator.cpp new file mode 100644 index 00000000000000..c6d31e4ec1ce1d --- /dev/null +++ b/torch/csrc/zoom/ZoomPluggableAllocator.cpp @@ -0,0 +1,373 @@ +#include +#include +#include +#include +#include + +#include + +namespace torch::zoom::ZoomPluggableAllocator { + +int device_count = 0; + +void custom_raw_deleter(void* ptr); + +_AllocationMetadata::_AllocationMetadata() + : size(0), device_idx(-1), stream{} {} + +_AllocationMetadata::_AllocationMetadata( + size_t size, + c10::DeviceIndex device_idx, + hipStream_t stream) + : size(size), device_idx(device_idx), stream(stream) {} + +// This is a fast API to just register allocators +// based on function pointers (ie. 
external .so libraries) +// This avoids having to link against libtorch for C++ based custom allocators +// And also use this from python +ZoomPluggableAllocator::ZoomPluggableAllocator( + std::function alloc_fn, + std::function free_fn) + : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {} + +ZoomPluggableAllocator::ZoomPluggableAllocator(ZoomPluggableAllocator& other) + : alloc_fn_(other.alloc_fn_), + free_fn_(other.free_fn_), + init_fn_(other.init_fn_), + reset_fn_(other.reset_fn_), + memory_fraction_fn_(other.memory_fraction_fn_), + base_alloc_fn_(other.base_alloc_fn_), + record_stream_fn_(other.record_stream_fn_), + begin_allocate_to_pool_fn_(other.begin_allocate_to_pool_fn_), + end_allocate_to_pool_fn_(other.end_allocate_to_pool_fn_), + relase_pool_fn_(other.relase_pool_fn_) {} + +void ZoomPluggableAllocator::set_init_fn(std::function init_fn) { + init_fn_ = std::move(init_fn); +} + +void ZoomPluggableAllocator::set_reset_fn(std::function reset_fn) { + reset_fn_ = std::move(reset_fn); +} + +void ZoomPluggableAllocator::set_memory_fraction_fn( + std::function memory_fraction_fn) { + memory_fraction_fn_ = std::move(memory_fraction_fn); +} + +void ZoomPluggableAllocator::set_base_alloc_fn( + std::function base_alloc_fn) { + base_alloc_fn_ = std::move(base_alloc_fn); +} + +void ZoomPluggableAllocator::set_record_stream_fn( + std::function record_stream_fn) { + record_stream_fn_ = std::move(record_stream_fn); +} + +void ZoomPluggableAllocator::set_begin_allocate_to_pool( + std::function< + void(int, c10::zoom::MempoolId_t, std::function)> + capture_begin_fn) { + begin_allocate_to_pool_fn_ = std::move(capture_begin_fn); +} + +void ZoomPluggableAllocator::set_end_allocate_to_pool_fn( + std::function capture_about_to_end_fn) { + end_allocate_to_pool_fn_ = std::move(capture_about_to_end_fn); +} + +void ZoomPluggableAllocator::set_release_pool( + std::function capture_destroy_fn) { + relase_pool_fn_ = std::move(capture_destroy_fn); +} + +void* ZoomPluggableAllocator::malloc( + size_t size, + c10::DeviceIndex device, + hipStream_t stream) { + void* r = alloc_fn_(size, device, stream); + { + const std::lock_guard lock(allocator_mutex_); + allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream)); + } + return r; +} + +c10::DataPtr ZoomPluggableAllocator::allocate(size_t size) { + c10::DeviceIndex device = -1; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + hipStream_t stream = c10::zoom::getCurrentZoomStream(device); + void* r = this->malloc(size, device, stream); + c10::DataPtr data_ptr = { + r, r, raw_deleter(), c10::Device(c10::DeviceType::PrivateUse1, device)}; + return data_ptr; +} + +c10::DeleterFnPtr ZoomPluggableAllocator::raw_deleter() const { + return &custom_raw_deleter; +} + +void* ZoomPluggableAllocator::raw_alloc(size_t nbytes) { + c10::DeviceIndex device = -1; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + hipStream_t stream = c10::zoom::getCurrentZoomStream(device); + return malloc(nbytes, device, stream); +} + +void* ZoomPluggableAllocator::raw_alloc_with_stream( + size_t nbytes, + hipStream_t stream) { + c10::DeviceIndex device = -1; + C10_ZOOM_CHECK(c10::zoom::GetDevice(&device)); + return malloc(nbytes, device, stream); +} + +void ZoomPluggableAllocator::raw_delete(void* ptr) { + hipStream_t stream{}; + c10::DeviceIndex device_idx = -1; + size_t size = 0; + { + const std::lock_guard lock(allocator_mutex_); + TORCH_CHECK( + allocation_metadata_.count(ptr), + "Trying to free a pointer not allocated here"); + _AllocationMetadata& metadata = 
allocation_metadata_[ptr]; + size = metadata.size; + device_idx = metadata.device_idx; + stream = metadata.stream; + allocation_metadata_.erase(ptr); + } + free_fn_(ptr, size, device_idx, stream); +} + +void ZoomPluggableAllocator::init(int device_count) { + if (init_fn_) { + init_fn_(device_count); + } + initialized_ = true; +} + +bool ZoomPluggableAllocator::initialized() { + return initialized_; +} + +void ZoomPluggableAllocator::setMemoryFraction( + double fraction, + c10::DeviceIndex device) { + if (memory_fraction_fn_) { + memory_fraction_fn_(fraction, device); + } +} + +void ZoomPluggableAllocator::emptyCache() { + if (reset_fn_) { + return reset_fn_(); + } +} + +void ZoomPluggableAllocator::cacheInfo( + c10::DeviceIndex device, + size_t* largestBlock) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support cacheInfo. " + "If you need it, please file an issue describing your use case."); +} + +void* ZoomPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) { + if (base_alloc_fn_) { + return base_alloc_fn_(ptr, size); + } else { + return ptr; + } +} + +void ZoomPluggableAllocator::recordStream( + const c10::DataPtr& ptr, + streamType stream) { + if (record_stream_fn_) { + record_stream_fn_(ptr.get(), stream); + } +} + +c10::zoom::ZoomCachingAllocator::DeviceStats ZoomPluggableAllocator:: + getDeviceStats(c10::DeviceIndex device) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support getDeviceStats. " + "If you need it, please file an issue describing your use case."); +} + +void ZoomPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support resetAccumulatedStats. " + "If you need it, please file an issue describing your use case."); +} + +void ZoomPluggableAllocator::resetPeakStats(c10::DeviceIndex device) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support resetPeakStats. " + "If you need it, please file an issue describing your use case."); +} + +c10::zoom::ZoomCachingAllocator::SnapshotInfo ZoomPluggableAllocator:: + snapshot() { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support snapshot. " + "If you need it, please file an issue describing your use case."); +} + +std::shared_ptr ZoomPluggableAllocator::getIpcDevPtr(std::string handle) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support getIpcDevPtr. " + "If you need it, please file an issue describing your use case."); +} + +// HIPGraph interactions +void ZoomPluggableAllocator::beginAllocateToPool( + c10::DeviceIndex device, + c10::zoom::MempoolId_t mempool_id, + std::function filter) { + if (begin_allocate_to_pool_fn_) { + begin_allocate_to_pool_fn_(device, mempool_id, std::move(filter)); + } +} + +void ZoomPluggableAllocator::endAllocateToPool( + c10::DeviceIndex device, + c10::zoom::MempoolId_t mempool_id) { + if (end_allocate_to_pool_fn_) { + end_allocate_to_pool_fn_(device, mempool_id); + } +} + +void ZoomPluggableAllocator::releasePool( + c10::DeviceIndex device, + c10::zoom::MempoolId_t mempool_id) { + if (relase_pool_fn_) { + relase_pool_fn_(device, mempool_id); + } +} + +void ZoomPluggableAllocator::recordHistory( + bool enabled, + c10::zoom::ZoomCachingAllocator::CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + c10::zoom::ZoomCachingAllocator::RecordContext when) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support recordHistory. 
" + "If you need it, please file an issue describing your use case."); +} + +void ZoomPluggableAllocator::attachOutOfMemoryObserver( + c10::zoom::ZoomCachingAllocator::OutOfMemoryObserver observer) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support attachOutOfMemoryObserver. " + "If you need it, please file an issue describing your use case."); +} + +void ZoomPluggableAllocator::attachAllocatorTraceTracker( + c10::zoom::ZoomCachingAllocator::AllocatorTraceTracker tracker) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not support attachAllocatorTraceTracker. " + "attachAllocatorTraceTracker is only used inside Pytorch."); +} + +std::shared_ptr +ZoomPluggableAllocator::getCheckpointState( + c10::DeviceIndex device, + c10::zoom::MempoolId_t id) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support getCheckpointState. " + "If you need it, please file an issue describing your use case."); +} + +c10::zoom::ZoomCachingAllocator::CheckpointDelta ZoomPluggableAllocator:: + setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) { + TORCH_CHECK( + false, + "ZoomPluggableAllocator does not yet support setCheckpointPoolState. " + "If you need it, please file an issue describing your use case."); +} + +void ZoomPluggableAllocator::enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access) { + c10::zoom::ZoomGuard device_guard(dev); + hipError_t err = hipDeviceEnablePeerAccess(dev_to_access, 0); + if (err == hipErrorPeerAccessAlreadyEnabled) { + // ignore and clear the error if access was already enabled + (void)hipGetLastError(); + } else { + C10_ZOOM_CHECK(err); + } +} + +hipError_t ZoomPluggableAllocator::memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + hipStream_t stream, + bool p2p_enabled) { + return hipMemcpyAsync(dst, src, count, hipMemcpyDeviceToDevice, stream); +} + +std::string ZoomPluggableAllocator::name() { + return "pluggable"; +} + +void ZoomPluggableAllocator::copy_data( + void* dest, + const void* src, + std::size_t count) const { + C10_ZOOM_CHECK( + hipMemcpy(dest, src, count, hipMemcpyKind::hipMemcpyDeviceToDevice)); +} + +std::shared_ptr + current_custom_allocator; + +std::shared_ptr +getCurrentAllocator() { + return current_custom_allocator; +} + +// TODO: add more functions in the argument +std::shared_ptr +createCustomAllocator( + std::function alloc_fn, + std::function free_fn) { + std::shared_ptr allocator( + new ZoomPluggableAllocator(std::move(alloc_fn), std::move(free_fn))); + allocator->init(device_count); + return allocator; +} + +void changeCurrentAllocator( + const std::shared_ptr& + allocator) { + TORCH_CHECK( + !c10::zoom::ZoomCachingAllocator::allocator.load()->initialized(), + "Can't swap an already initialized allocator"); + c10::zoom::ZoomCachingAllocator::allocator.store(allocator.get()); + current_custom_allocator = allocator; +} + +void custom_raw_deleter(void* ptr) { + current_custom_allocator->raw_delete(ptr); +} + +} // namespace torch::zoom::ZoomPluggableAllocator diff --git a/torch/csrc/zoom/ZoomPluggableAllocator.h b/torch/csrc/zoom/ZoomPluggableAllocator.h new file mode 100644 index 00000000000000..b2baf8671191c6 --- /dev/null +++ b/torch/csrc/zoom/ZoomPluggableAllocator.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include + +namespace torch::zoom::ZoomPluggableAllocator { +using streamType = c10::zoom::ZoomStream; + +std::shared_ptr +getCurrentAllocator(); +std::shared_ptr 
+createCustomAllocator( + std::function alloc_fn, + std::function free_fn); +void changeCurrentAllocator( + const std::shared_ptr& + allocator); + +struct _AllocationMetadata { + _AllocationMetadata(); + _AllocationMetadata( + size_t size, + c10::DeviceIndex device_idx, + hipStream_t stream); + size_t size; + c10::DeviceIndex device_idx; + hipStream_t stream; +}; + +struct ZoomPluggableAllocator + : public c10::zoom::ZoomCachingAllocator::ZoomAllocator { + ZoomPluggableAllocator( + std::function alloc_fn, + std::function free_fn); + + ZoomPluggableAllocator(ZoomPluggableAllocator& other); + + void set_init_fn(std::function init_fn); + + void set_reset_fn(std::function reset_fn); + + void set_memory_fraction_fn( + std::function memory_fraction_fn); + + void set_base_alloc_fn(std::function base_alloc_fn); + + void set_record_stream_fn( + std::function record_stream_fn); + + void set_begin_allocate_to_pool( + std::function< + void(int, c10::zoom::MempoolId_t, std::function)> + capture_begin_fn); + + void set_end_allocate_to_pool_fn( + std::function capture_about_to_end_fn); + + void set_release_pool( + std::function capture_destroy_fn); + + void* malloc(size_t size, c10::DeviceIndex device, hipStream_t stream); + + c10::DataPtr allocate(size_t size) override; + c10::DeleterFnPtr raw_deleter() const override; + + void* raw_alloc(size_t nbytes) override; + void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override; + void raw_delete(void* ptr) override; + void init(int device_count) override; + bool initialized() override; + void setMemoryFraction(double fraction, c10::DeviceIndex device) override; + void emptyCache() override; + void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override; + void* getBaseAllocation(void* ptr, size_t* size) override; + + void recordStream(const c10::DataPtr&, streamType stream) override; + + c10::zoom::ZoomCachingAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) override; + void resetAccumulatedStats(c10::DeviceIndex device) override; + void resetPeakStats(c10::DeviceIndex device) override; + c10::zoom::ZoomCachingAllocator::SnapshotInfo snapshot() override; + void beginAllocateToPool( + c10::DeviceIndex device, + c10::zoom::MempoolId_t mempool_id, + std::function) override; + void endAllocateToPool( + c10::DeviceIndex device, + c10::zoom::MempoolId_t mempool_id) override; + void releasePool(c10::DeviceIndex device, c10::zoom::MempoolId_t mempool_id) + override; + std::shared_ptr getIpcDevPtr(std::string handle) override; + void recordHistory( + bool enabled, + c10::zoom::ZoomCachingAllocator::CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + c10::zoom::ZoomCachingAllocator::RecordContext when) override; + void attachOutOfMemoryObserver( + c10::zoom::ZoomCachingAllocator::OutOfMemoryObserver observer) override; + void attachAllocatorTraceTracker( + c10::zoom::ZoomCachingAllocator::AllocatorTraceTracker tracker) override; + std::shared_ptr + getCheckpointState(c10::DeviceIndex device, c10::zoom::MempoolId_t id) + override; + c10::zoom::ZoomCachingAllocator::CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) + override; + void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) + override; + hipError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + hipStream_t stream, + bool p2p_enabled) override; + std::string name() override; + void copy_data(void* dest, const void* src, std::size_t count) const 
final; + + protected: + std::function alloc_fn_; + std::function free_fn_; + std::function init_fn_; + std::function reset_fn_; + std::function memory_fraction_fn_; + std::function base_alloc_fn_; + std::function record_stream_fn_; + std::function< + void(int, c10::zoom::MempoolId_t, std::function)> + begin_allocate_to_pool_fn_; + std::function end_allocate_to_pool_fn_; + std::function relase_pool_fn_; + std::mutex allocator_mutex_; + // We do the bookeeping here in order to simplify custom allocators + std::unordered_map allocation_metadata_; + + bool initialized_ = false; +}; +} // namespace torch::zoom::ZoomPluggableAllocator diff --git a/torch/csrc/zoom/comm.cpp b/torch/csrc/zoom/comm.cpp new file mode 100644 index 00000000000000..d66450ab97a8b0 --- /dev/null +++ b/torch/csrc/zoom/comm.cpp @@ -0,0 +1,508 @@ +#include + +#include +#include + +#ifdef USE_NCCL +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace torch::zoom { +using namespace at; +using namespace torch::autograd; + +// Some operations can be performed more efficiently if we're handling tensors +// of a single type only. Adding this logic directly in the loop makes it a bit +// ugly, so here's a helper for it. +struct unique_type_checker { + void show(size_t type_id) { + if (!unique) { + return; + } + if (!type_id_) { + type_id_ = type_id; + } + + unique = type_id_.value() == type_id; + } + + std::optional type_id_; + bool unique = true; +}; + +// ***************** Broadcast ******************* +// +// Broadcast a source tensor (CPU or Zoom) to a list of Zoom devices, or Zoom +// tensors on one or more devices. + +// no checks +static inline std::vector& _broadcast_out_impl( + const Tensor& tensor, + std::vector& out_tensors) { +#ifdef USE_NCCL + std::vector nccl_list; + nccl_list.reserve(out_tensors.size() + 1); + nccl_list.emplace_back(tensor); + for (auto& out_tensor : out_tensors) { + nccl_list.emplace_back(out_tensor); + } + if (nccl::is_available(nccl_list)) { + nccl::broadcast(nccl_list); + } else { +#else + { +#endif + for (auto& out_tensor : out_tensors) { + out_tensor.copy_(tensor, /*non_blocking=*/true); + } + } + return out_tensors; +} + +std::vector& broadcast_out( + const Tensor& tensor, + std::vector& out_tensors) { + for (const auto i : c10::irange(out_tensors.size())) { + TORCH_CHECK( + out_tensors[i].is_privateuseone(), + "Expected all output tensors to be Zoom tensors, but output tensor at index ", + i, + " has device '", + out_tensors[i].device(), + "'"); + TORCH_CHECK( + out_tensors[i].sizes() == tensor.sizes(), + "Expected all output tensors to have same shape as the source tensor ", + tensor.sizes(), + ", but output tensor at index ", + i, + " has shape ", + out_tensors[i].sizes()); + } + return _broadcast_out_impl(tensor, out_tensors); +} + +std::vector broadcast(const Tensor& tensor, IntArrayRef devices) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::vector diff_device_dst_tensors; + diff_device_dst_tensors.reserve(devices.size()); + for (auto device : devices) { + TORCH_CHECK( + device >= 0, "Expected non-negative device index, but got ", device); + if (device != tensor.get_device()) { + diff_device_dst_tensors.emplace_back(at::empty( + tensor.sizes(), + tensor.options().device(at::Device( + DeviceType::PrivateUse1, + static_cast(device))))); // preserve memory format + } + } + _broadcast_out_impl(tensor, diff_device_dst_tensors); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::vector dst_tensors; 
+ dst_tensors.reserve(devices.size()); + auto it = diff_device_dst_tensors.begin(); + for (auto device : devices) { + // NOLINTNEXTLINE(bugprone-branch-clone) + if (device != tensor.get_device()) { + dst_tensors.emplace_back(*it++); + } else { + dst_tensors.emplace_back(tensor); + } + } + TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end()); + return dst_tensors; +} + +// NOTE [ Version Counter in comm.*_coalesced ] +// +// broadcast_coalesced +// ~~~~~~~~~~~~~~~~~~~ +// +// In broadcast_coalesced, multiple variables may be coalesced into a single +// large one, broadcast to other devices, and the get split according to the +// original shapes. +// +// When splitting, the view operations will make all Variables broadcast +// together to share a single version counter, because they are all views of the +// large Variable. However, that large Variable is immediately discarded and all +// these Variables do not share storage at all. +// +// For example, when two buffers are broadcast together in `DataParallel` and +// one of them is modified in-place during `forward` but the other is needed in +// backward, autograd engine will complain. +// +// We thus re-wrap these Variables after broadcasting (i.e., effectively doing +// what is equivalent to .data in Python), and give them individual version +// counters. +// +// NB: Just calling detach() on the variables is not sufficient +// +// NB: For `device[0]` in broadcast_coalesced, the input Variables are always +// returned as-is, so **do not** re-wrap them. +// +// reduce_add_coalesced +// ~~~~~~~~~~~~~~~~~~~~ +// +// Similarly for reduce_add_coalesced, when the output are newly created +// Variables. +tensor_list2d broadcast_coalesced( + TensorList tensors, + IntArrayRef devices, + size_t buffer_size) { + TORCH_CHECK( + std::all_of( + tensors.begin(), + tensors.end(), + [&](const at::Tensor& t) { return t.get_device() == devices[0]; }), + "All tensors must be on devices[0]: ", + devices[0]); +#ifdef USE_NCCL + buffer_size = std::min(torch::zoom::nccl::get_max_count(), buffer_size); +#endif + + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tensor_list2d outputs(devices.size()); + outputs[0] = tensors.vec(); + for (auto& o : outputs) + o.reserve(tensors.size()); + + unique_type_checker type_checker; + c10::zoom::ZoomGuard device_guard(static_cast(devices[0])); + for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) { + auto type_id = chunk.type_id(); + type_checker.show(type_id); + std::vector results; + if (chunk.options().is_sparse()) { + auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors); + auto broadcast_indices = broadcast(flat_tuple.first, devices); + auto broadcast_values = broadcast(flat_tuple.second, devices); + results.reserve(devices.size()); + for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) { + device_guard.set_index(static_cast(devices[i])); + auto& device_outputs = outputs[i]; + auto& inds = broadcast_indices[i]; + auto& vals = broadcast_values[i]; + for (const auto& var : torch::utils::unflatten_sparse_tensors( + inds, vals, chunk.tensors)) { + // See NOTE [ Version Counter in comm.*_coalesced ] + device_outputs.emplace_back(make_variable(var.tensor_data(), false)); + } + } + } else { + auto results = broadcast( + torch::utils::flatten_dense_tensors(chunk.tensors), devices); + for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) { + device_guard.set_index(static_cast(devices[i])); + auto& device_outputs = outputs[i]; + for (auto& var : + 
torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) { + // See NOTE [ Version Counter in comm.*_coalesced ] + device_outputs.emplace_back(make_variable(var.tensor_data(), false)); + } + } + } + } + + // If we only saw a single tensor type, then we can skip expensive reordering + if (!type_checker.unique) { + for (auto& o : outputs) + torch::utils::reorder_tensors_like(o, tensors); + } + return outputs; +} + +// ***************** Scatter ******************* +// +// Scatter a source tensor (CPU or Zoom) to a list of Zoom tensors on one or +// more devices. + +std::vector& scatter_out( + const at::Tensor& tensor, + std::vector& out_tensors, + int64_t dim, + const std::optional>>& + streams) { + TORCH_CHECK( + !out_tensors.empty(), + "Expected at least one output tensor to scatter to"); + dim = at::maybe_wrap_dim(dim, tensor); + int64_t total_size = 0; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::vector chunk_sizes; + chunk_sizes.reserve(out_tensors.size()); + for (const auto i : c10::irange(out_tensors.size())) { + TORCH_CHECK( + out_tensors[i].is_privateuseone(), + "Expected all output tensors to be Zoom tensors, but output tensor at index ", + i, + " has device '", + out_tensors[i].device(), + "'"); + auto out_sizes = out_tensors[i].sizes().vec(); + bool same_ndim = out_sizes.size() == static_cast(tensor.dim()); + if (same_ndim) { + total_size += out_sizes[dim]; + chunk_sizes.emplace_back(out_sizes[dim]); + out_sizes[dim] = tensor.size(dim); + } + TORCH_CHECK( + same_ndim && out_sizes == tensor.sizes(), + "Output tensor at index ", + i, + " has incorrect shape: ", + out_tensors[i].sizes(), + ". Expected same " + "shape except for scatter dim ", + dim, + " as the source tensor: ", + at::IntArrayRef(tensor.sizes())); + } + TORCH_CHECK( + total_size == tensor.size(dim), + "Total size for output tensors along scatter dim ", + dim, + " does not match " + "the source tensor size at dim ", + dim, + ". Expected ", + tensor.size(dim), + ", but got total size ", + total_size); + + auto chunks = + tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim); + c10::zoom::OptionalZoomStreamGuard zoom_guard; + for (const auto i : c10::irange(chunks.size())) { + if (i < (streams ? streams->size() : 0U) && (*streams)[i]) { + const auto device_index = + static_cast(out_tensors[i].get_device()); + TORCH_CHECK( + (*streams)[i]->device_index() == device_index, + "Expected the device associated with the stream at index ", + i, + " (was ", + (*streams)[i]->device_index(), + ") ", + "to match the device supplied at that index ", + "(expected ", + device_index, + ")"); + zoom_guard.reset_stream(*(*streams)[i]); + } + // NB: We don't detect the case where `out_tensor` is already the correct + // view of `tensor` since that would be nontrivial and involve checking + // ptr, offset, and strides. So `scatter_out(src, src.chunk(...))` does + // more copying than `scatter(src)`. 
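+    // The copy below runs on the stream selected above (or on the destination
+    // device's current stream when no stream was supplied). With
+    // non_blocking=true it may return before the data is ready, so callers
+    // that consume the outputs right away need to synchronize those streams.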
+ out_tensors[i].copy_(chunks[i], /*non_blocking=*/true); + } + return out_tensors; +} + +std::vector scatter( + const at::Tensor& tensor, + at::IntArrayRef devices, + const std::optional>& chunk_sizes, + int64_t dim, + const std::optional>>& + streams) { + TORCH_CHECK(!devices.empty(), "Expected at least one device to scatter to"); + if (chunk_sizes.has_value()) { + TORCH_CHECK( + chunk_sizes->size() == devices.size(), + "Expected devices and chunk_sizes to be of same length, but got " + "len(devices) = ", + devices.size(), + " and len(chunk_sizes) = ", + chunk_sizes->size()); + } + dim = at::maybe_wrap_dim(dim, tensor); + std::vector chunks = chunk_sizes + ? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim) + : tensor.chunk( + /*chunks=*/static_cast(devices.size()), /*dim=*/dim); + c10::zoom::OptionalZoomStreamGuard zoom_guard; + for (const auto i : c10::irange(chunks.size())) { + const auto device_index = static_cast(devices[i]); + if (device_index != tensor.get_device()) { + if (i < (streams ? streams->size() : 0U) && (*streams)[i]) { + TORCH_CHECK( + (*streams)[i]->device_index() == device_index, + "Expected the device associated with the stream at index ", + i, + " (was ", + (*streams)[i]->device_index(), + ") ", + "to match the device supplied at that index ", + "(expected ", + device_index, + ")"); + zoom_guard.reset_stream(*(*streams)[i]); + } + TORCH_CHECK( + device_index >= 0, + "Expected non-negative device index, but got ", + device_index); + chunks[i] = chunks[i].to( + {DeviceType::PrivateUse1, device_index}, + /*non_blocking=*/true, + /*copy=*/false, + /*memory_format=*/at::MemoryFormat::Preserve); + } + } + return chunks; +} + +// ***************** Gather ******************* +// +// Gather a list of Zoom tensors on one or more devices to a target tensor or +// device, either CPU or Zoom. 
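+//
+// Illustrative usage (assumes two Zoom tensors t0 and t1, each of shape [2, 4]):
+//   at::Tensor out = torch::zoom::gather({t0, t1}, /*dim=*/0,
+//                                        /*destination_index=*/0);
+//   // out has shape [4, 4] and lives on Zoom device 0; passing -1 places the
+//   // result on the CPU, and std::nullopt uses the current Zoom device.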
+ +// no checks +static inline at::Tensor& _gather_out_impl( + at::TensorList tensors, + at::Tensor& out_tensor, + int64_t dim) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::vector chunk_sizes; + chunk_sizes.reserve(tensors.size()); + for (auto& tensor : tensors) { + chunk_sizes.emplace_back(tensor.size(dim)); + } + auto chunks = + out_tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim); + for (const auto i : c10::irange(tensors.size())) { + chunks[i].copy_(tensors[i], /*non_blocking=*/out_tensor.is_privateuseone()); + } + return out_tensor; +} + +at::Tensor& gather_out( + at::TensorList tensors, + at::Tensor& out_tensor, + int64_t dim) { + TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from"); + int64_t total_size = 0; + auto& first = tensors.front(); + const auto first_size = first.sizes(); + dim = at::maybe_wrap_dim(dim, first); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::vector expected_size(first_size.begin(), first_size.end()); + for (const auto i : c10::irange(tensors.size())) { + const auto& tensor = tensors[i]; + TORCH_CHECK( + tensor.is_privateuseone(), + "Expected all input tensors to be Zoom tensors, but " + "tensor at index ", + i, + " has device '", + tensor.device(), + "'"); + TORCH_CHECK( + tensor.ndimension() == static_cast(expected_size.size()), + "Expected all input tensors to have the same number of dimensions, but ", + "tensor at index ", + i, + "has ", + tensor.ndimension(), + " dimensions, (expected ", + expected_size.size(), + ")"); + expected_size[dim] = tensor.size(dim); + for (const auto dimension : c10::irange(expected_size.size())) { + TORCH_CHECK( + expected_size[dimension] == tensor.size(dimension), + "Input tensor at index ", + i, + " has invalid shape ", + tensor.sizes(), + ", but expected ", + at::IntArrayRef(expected_size)); + } + total_size += tensor.size(dim); + } + expected_size[dim] = total_size; + TORCH_CHECK( + out_tensor.sizes() == expected_size, + "Expected out tensor to have shape ", + at::IntArrayRef(expected_size), + ", but got ", + out_tensor.sizes()) + + return _gather_out_impl(tensors, out_tensor, dim); +} + +at::Tensor gather( + at::TensorList tensors, + int64_t dim, + std::optional destination_index) { + TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from"); + int64_t total_size = 0; + auto& first = tensors.front(); + const auto first_size = first.sizes(); + dim = at::maybe_wrap_dim(dim, first); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::vector expected_size(first_size.begin(), first_size.end()); + auto memory_format = first.suggest_memory_format(); + for (const auto i : c10::irange(tensors.size())) { + const auto& tensor = tensors[i]; + TORCH_CHECK( + tensor.is_privateuseone(), + "Expected all input tensors to be Zoom tensors, but " + "tensor at index ", + i, + " has device ", + tensor.device()); + TORCH_CHECK( + tensor.ndimension() == static_cast(expected_size.size()), + "Expected all input tensors to have the same number of dimensions, but ", + "tensor at index ", + i, + "has ", + tensor.ndimension(), + " dimensions, (expected ", + expected_size.size(), + ")"); + expected_size[dim] = tensor.size(dim); + for (const auto dimension : c10::irange(expected_size.size())) { + TORCH_CHECK( + expected_size[dimension] == tensor.size(dimension), + "Input tensor at index ", + i, + " has invalid shape ", + tensor.sizes(), + ", but expected ", + at::IntArrayRef(expected_size)); + } + total_size += tensor.size(dim); + if (memory_format != 
MemoryFormat::Contiguous && + tensor.suggest_memory_format() != memory_format) { + memory_format = MemoryFormat::Contiguous; + } + } + expected_size[dim] = total_size; + at::Device device(DeviceType::CPU); + if (!destination_index || *destination_index != -1) { + device = at::Device( + DeviceType::PrivateUse1, + destination_index ? static_cast(*destination_index) + : DeviceIndex(-1)); + } + + at::Tensor result = + at::empty(expected_size, first.options().device(device), memory_format); + return _gather_out_impl(tensors, result, dim); +} + +} // namespace torch::zoom diff --git a/torch/csrc/zoom/comm.h b/torch/csrc/zoom/comm.h new file mode 100644 index 00000000000000..27229ef3169f0a --- /dev/null +++ b/torch/csrc/zoom/comm.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace torch::zoom { + +using tensor_list2d = std::vector>; + +TORCH_ZOOM_API std::vector& broadcast_out( + const at::Tensor& tensor, + std::vector& out_tensors); +TORCH_ZOOM_API std::vector broadcast( + const at::Tensor& tensor, + at::IntArrayRef devices); +TORCH_ZOOM_API tensor_list2d broadcast_coalesced( + at::TensorList tensors, + at::IntArrayRef devices, + size_t buffer_size); + +TORCH_ZOOM_API std::vector& scatter_out( + const at::Tensor& tensor, + std::vector& out_tensors, + int64_t dim = 0, + const std::optional>>& + streams = c10::nullopt); + +TORCH_ZOOM_API std::vector scatter( + const at::Tensor& tensor, + at::IntArrayRef devices, + const std::optional>& chunk_sizes = c10::nullopt, + int64_t dim = 0, + const std::optional>>& + streams = c10::nullopt); + +TORCH_ZOOM_API at::Tensor& gather_out( + at::TensorList tensors, + at::Tensor& out_tensor, + int64_t dim); + +TORCH_ZOOM_API at::Tensor gather( + at::TensorList tensors, + int64_t dim, + std::optional destination_index); + +} // namespace torch::zoom diff --git a/torch/csrc/zoom/device_set.h b/torch/csrc/zoom/device_set.h new file mode 100644 index 00000000000000..14226ef2e1c92f --- /dev/null +++ b/torch/csrc/zoom/device_set.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include +#include + +namespace torch { + +using device_set = std::bitset; + +} // namespace torch diff --git a/torch/csrc/zoom/memory_snapshot.cpp b/torch/csrc/zoom/memory_snapshot.cpp new file mode 100644 index 00000000000000..79ed6cc9d19ced --- /dev/null +++ b/torch/csrc/zoom/memory_snapshot.cpp @@ -0,0 +1,376 @@ +#include +#include +#include +#include +#include +#include + +namespace torch::zoom { + +using c10::Dict; +using c10::IValue; +using torch::jit::Pickler; + +using c10::zoom::ZoomCachingAllocator::SegmentInfo; + +namespace { +std::string write_pickle(const IValue& v) { + std::vector result; + { + auto writer = [&](const char* data, size_t size) { + result.insert(result.end(), data, data + size); + }; + Pickler pickler(writer, nullptr, nullptr, nullptr, nullptr, false); + pickler.protocol(); + pickler.pushIValue(v); + pickler.stop(); + } + return std::string(result.begin(), result.end()); +} +Dict new_dict() { + return Dict(c10::AnyType::get(), c10::AnyType::get()); +} +c10::List new_list() { + return List(c10::AnyType::get()); +} + +std::vector ivalue_symbolize( + std::vector& to_symbolize) { + // we dedup repeated to_symbolize objects to prevent + // creating a bunch of duplicated frame objects + std::unordered_map cached_frames; + std::vector unique_frames; + for (const auto& sc : to_symbolize) { + auto it = cached_frames.find(sc); + if (it == cached_frames.end()) { + cached_frames.insert({sc, unique_frames.size()}); + 
unique_frames.push_back(sc); + } + } + auto s = symbolize(unique_frames); + + IValue line_s = "line"; + IValue name_s = "name"; + IValue filename_s = "filename"; + std::vector all_frames; + for (const auto& f : s.all_frames) { + auto d = new_dict(); + d.insert(name_s, f.funcname); + d.insert(filename_s, f.filename); + d.insert(line_s, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } + + std::vector py_unique_frames; + for (const auto& t : s.tracebacks) { + auto l = new_list(); + for (const auto& e : t) { + l.push_back(all_frames.at(e)); + } + py_unique_frames.emplace_back(std::move(l)); + } + + std::vector result; + result.reserve(to_symbolize.size()); + for (const auto& sc : to_symbolize) { + result.push_back(py_unique_frames.at(cached_frames.at(sc))); + } + return result; +} + +std::shared_ptr gather() { + return CapturedTraceback::gather(true, true, false); +} + +std::shared_ptr gather_with_cpp() { + return CapturedTraceback::gather(true, true, true); +} + +CapturedTraceback* getFromContext( + const std::shared_ptr& x) { + if (CapturedTraceback* sc = dynamic_cast(x.get())) { + return sc; + } + TORCH_CHECK( + false, + "attempting to gather stack context from the wrong StackContext type."); +} + +} // namespace + +void _record_memory_history( + bool enabled, + bool record_context, + int64_t trace_alloc_max_entries, + bool trace_alloc_record_context, + bool record_cpp_context) { + c10::zoom::ZoomCachingAllocator::CreateContextFn recorder = gather; + if (enabled && record_cpp_context) { + recorder = gather_with_cpp; + // warm up C++ stack unwinding + unwind::unwind(); + } + auto when = c10::zoom::ZoomCachingAllocator::RecordContext::NEVER; + if (trace_alloc_record_context) { + when = c10::zoom::ZoomCachingAllocator::RecordContext::ALLOC; + } else if (record_context) { + when = c10::zoom::ZoomCachingAllocator::RecordContext::STATE; + } + at::globalContext().lazyInitPrivateUse1(); + c10::zoom::ZoomCachingAllocator::recordHistory( + enabled, recorder, trace_alloc_max_entries, when); +} + +static void checkOptionIn( + const std::string& option, + std::initializer_list valid, + const char* error) { + TORCH_CHECK( + valid.end() != std::find(valid.begin(), valid.end(), option), error); +} + +void _record_memory_history( + std::optional enabled, + std::optional context, + const std::string& stacks, + size_t max_entries) { + if (enabled) { + checkOptionIn( + *enabled, + {"state", "all"}, + "expected state to be 'state', 'all', or None"); + } + if (context) { + checkOptionIn( + *context, + {"state", "alloc", "all"}, + "expected context to be 'state', 'alloc', 'all', or None"); + } + checkOptionIn( + stacks, {"python", "all"}, "expected stacks to be 'python', or 'all'"); + + c10::zoom::ZoomCachingAllocator::CreateContextFn recorder = gather; + if (enabled && stacks == "all") { + recorder = gather_with_cpp; + // warm up C++ stack unwinding + unwind::unwind(); + } + max_entries = (enabled && *enabled == "all") ? 
max_entries : 1; + auto when = c10::zoom::ZoomCachingAllocator::RecordContext::NEVER; + if (context) { + if (context == "all") { + when = c10::zoom::ZoomCachingAllocator::RecordContext::ALL; + } else if (context == "alloc") { + when = c10::zoom::ZoomCachingAllocator::RecordContext::ALLOC; + } else if (context == "state") { + when = c10::zoom::ZoomCachingAllocator::RecordContext::STATE; + } + } + at::globalContext().lazyInitPrivateUse1(); + c10::zoom::ZoomCachingAllocator::recordHistory( + enabled.has_value(), recorder, max_entries, when); +} + +std::string _memory_snapshot_pickled() { + IValue device_s = "device"; + IValue address_s = "address"; + IValue total_size_s = "total_size"; + IValue allocated_size_s = "allocated_size"; + IValue active_size_s = "active_size"; + IValue requested_size_s = "requested_size"; + IValue stream_s = "stream"; + IValue segment_type_s = "segment_type"; + IValue segment_pool_id = "segment_pool_id"; + IValue large_s = "large"; + IValue small_s = "small"; + IValue size_s = "size"; + IValue state_s = "state"; + IValue active_allocated_s = "active_allocated"; + IValue active_pending_free_s = "active_pending_free"; + IValue inactive_s = "inactive"; + IValue addr_s = "addr"; + IValue filename_s = "filename"; + IValue name_s = "name"; + IValue line_s = "line"; + IValue frames_s = "frames"; + IValue blocks_s = "blocks"; + IValue is_expandable_s = "is_expandable"; + IValue time_us_s = "time_us"; + + auto empty_frames = new_list(); + + std::vector frame_tracebacks; + std::vector> frame_dict; + + auto add_frame_key = [&](const c10::Dict& d, + const std::shared_ptr& ctx) { + if (ctx) { + frame_tracebacks.push_back(getFromContext(ctx)); + frame_dict.push_back(d); + } else { + d.insert(frames_s, empty_frames); + } + }; + + const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) { + auto segmentDict = new_dict(); + segmentDict.insert(device_s, segmentInfo.device); + segmentDict.insert(address_s, static_cast(segmentInfo.address)); + segmentDict.insert( + total_size_s, static_cast(segmentInfo.total_size)); + segmentDict.insert( + allocated_size_s, static_cast(segmentInfo.allocated_size)); + segmentDict.insert( + active_size_s, static_cast(segmentInfo.active_size)); + segmentDict.insert( + requested_size_s, static_cast(segmentInfo.requested_size)); + segmentDict.insert(stream_s, int64_t(segmentInfo.stream)); + segmentDict.insert( + segment_type_s, (segmentInfo.is_large ? large_s : small_s)); + segmentDict.insert( + segment_pool_id, + std::tuple(segmentInfo.owner_private_pool_id)); + segmentDict.insert(is_expandable_s, segmentInfo.is_expandable); + + add_frame_key(segmentDict, segmentInfo.context_when_allocated); + + auto address = segmentInfo.address; + auto blocks = new_list(); + for (const auto& blockInfo : segmentInfo.blocks) { + auto blockDict = new_dict(); + blockDict.insert(address_s, static_cast(address)); + blockDict.insert(size_s, static_cast(blockInfo.size)); + blockDict.insert( + requested_size_s, static_cast(blockInfo.requested_size)); + blockDict.insert( + state_s, + (blockInfo.allocated + ? active_allocated_s + : (blockInfo.active ? 
active_pending_free_s : inactive_s))); + add_frame_key(blockDict, blockInfo.context_when_allocated); + address += blockInfo.size; + blocks.push_back(blockDict); + } + segmentDict.insert(blocks_s, blocks); + + return segmentDict; + }; + + auto snapshot = c10::zoom::ZoomCachingAllocator::snapshot(); + + auto segments = new_list(); + for (const auto& segmentInfo : snapshot.segments) { + segments.push_back(segmentInfoToDict(segmentInfo)); + } + + auto traces = new_list(); + IValue action_s = "action"; + IValue alloc_s = "alloc"; + IValue free_requested_s = "free_requested"; + IValue free_completed_s = "free_completed"; + IValue segment_alloc_s = "segment_alloc"; + IValue segment_free_s = "segment_free"; + IValue segment_map_s = "segment_map"; + IValue segment_unmap_s = "segment_unmap"; + IValue snapshot_s = "snapshot"; + IValue oom_s = "oom"; + IValue device_free_s = "device_free"; + + using namespace c10::zoom::ZoomCachingAllocator; + + auto action_to_str = [&](TraceEntry::Action action) { + switch (action) { + case TraceEntry::ALLOC: + return alloc_s; + case TraceEntry::FREE_REQUESTED: + return free_requested_s; + case TraceEntry::FREE_COMPLETED: + return free_completed_s; + case TraceEntry::SEGMENT_ALLOC: + return segment_alloc_s; + case TraceEntry::SEGMENT_FREE: + return segment_free_s; + case TraceEntry::OOM: + return oom_s; + case TraceEntry::SNAPSHOT: + return snapshot_s; + case TraceEntry::SEGMENT_UNMAP: + return segment_unmap_s; + case TraceEntry::SEGMENT_MAP: + return segment_map_s; + } + throw std::runtime_error("unreachable"); + }; + + for (const auto& traceInfo : snapshot.device_traces) { + auto trace = new_list(); + for (const auto& te : traceInfo) { + auto trace_entry = new_dict(); + trace_entry.insert(action_s, action_to_str(te.action_)); + trace_entry.insert( + TraceEntry::OOM == te.action_ ? 
device_free_s : addr_s, + static_cast(te.addr_)); + trace_entry.insert(size_s, (int64_t)te.size_); + trace_entry.insert(stream_s, int64_t(te.stream_)); + if (te.context_) { + auto sc = getFromContext(te.context_); + frame_tracebacks.push_back(sc); + frame_dict.push_back(trace_entry); + } + trace_entry.insert(time_us_s, te.time_.t_); + trace.push_back(trace_entry); + } + traces.push_back(trace); + } + + auto allocator_settings = new_dict(); + IValue last_allocator_settings_s = "PYTORCH_ZOOM_ALLOC_CONF"; + IValue max_split_size_s = "max_split_size"; + IValue garbage_collection_threshold_s = "garbage_collection_threshold"; + IValue expandable_segments_s = "expandable_segments"; + IValue pinned_num_register_threads_s = "pinned_num_register_threads"; + IValue release_lock_on_malloc_s = "release_lock_on_hipmalloc"; + IValue pinned_use_host_register_s = "pinned_use_zoom_host_register"; + IValue roundup_power2_divisions_s = "roundup_power2_divisions"; + + allocator_settings.insert( + last_allocator_settings_s, + snapshot.config_metadata.last_allocator_settings); + allocator_settings.insert( + max_split_size_s, int64_t(snapshot.config_metadata.max_split_size)); + allocator_settings.insert( + garbage_collection_threshold_s, + snapshot.config_metadata.garbage_collection_threshold); + allocator_settings.insert( + expandable_segments_s, snapshot.config_metadata.expandable_segments); + allocator_settings.insert( + pinned_num_register_threads_s, + int64_t(snapshot.config_metadata.pinned_num_register_threads)); + allocator_settings.insert( + release_lock_on_malloc_s, + snapshot.config_metadata.release_lock_on_malloc); + allocator_settings.insert( + pinned_use_host_register_s, + snapshot.config_metadata.pinned_use_host_register); + unsigned int roundup_key = 1; + auto roundup_settings = new_dict(); + for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) { + IValue roundup_key_s = std::to_string(roundup_key); + roundup_settings.insert(roundup_key_s, int64_t(v)); + roundup_key *= 2; + } + allocator_settings.insert(roundup_power2_divisions_s, roundup_settings); + + auto result = new_dict(); + result.insert("segments", segments); + result.insert("device_traces", traces); + result.insert("allocator_settings", allocator_settings); + + auto frames = ivalue_symbolize(frame_tracebacks); + for (auto i : c10::irange(frames.size())) { + frame_dict.at(i).insert(frames_s, frames.at(i)); + } + + return write_pickle(result); +} +} // namespace torch::zoom diff --git a/torch/csrc/zoom/memory_snapshot.h b/torch/csrc/zoom/memory_snapshot.h new file mode 100644 index 00000000000000..bacf3cf0ebafb9 --- /dev/null +++ b/torch/csrc/zoom/memory_snapshot.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::zoom { + +// C++-only versions of these, for python use +// those defined in zoom/Module.cpp which also record python state. 
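+//
+// Illustrative C++-side usage (no Python frames are recorded here):
+//   torch::zoom::_record_memory_history("all", "all", "all",
+//                                       /*max_entries=*/100000);
+//   ... run code that allocates Zoom memory ...
+//   std::string snapshot = torch::zoom::_memory_snapshot_pickled();
+//   // `snapshot` is a pickled dict with "segments", "device_traces" and
+//   // "allocator_settings" entries, as assembled in memory_snapshot.cpp.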
+TORCH_ZOOM_API void _record_memory_history( + bool enabled, + bool record_context = true, + int64_t trace_alloc_max_entries = 1, + bool trace_alloc_record_context = false, + bool record_cpp_context = false); + +TORCH_ZOOM_API void _record_memory_history( + std::optional enabled = "all", + std::optional context = "all", + const std::string& stacks = "all", + size_t max_entries = SIZE_MAX); + +TORCH_ZOOM_API std::string _memory_snapshot_pickled(); + +} // namespace torch::zoom diff --git a/torch/csrc/zoom/python_comm.cpp b/torch/csrc/zoom/python_comm.cpp new file mode 100644 index 00000000000000..07e84b914a07b2 --- /dev/null +++ b/torch/csrc/zoom/python_comm.cpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include + +namespace torch::zoom::python { +void initCommMethods(PyObject* module) { + auto m = py::cast(module); + m.def( + "_broadcast_coalesced", + [](std::vector& tensors, + const std::vector& devices, + size_t buffer_size) { + return broadcast_coalesced(tensors, devices, buffer_size); + }, + py::arg("tensors"), + py::arg("devices"), + py::arg("buffer_size"), + py::call_guard()) + .def( + "_broadcast", + [](at::Tensor& tensor, std::vector devices) { + return broadcast(tensor, devices); + }, + py::call_guard(), + py::arg("tensor"), + py::arg("devices")) + .def( + "_broadcast_out", + [](at::Tensor& tensor, std::vector& out_tensors) { + return broadcast_out(tensor, out_tensors); + }, + py::call_guard(), + py::arg("tensor"), + py::arg("out")) + .def( + "_scatter", + [](at::Tensor& tensor, + std::vector& devices, + std::optional> chunk_sizes, + int64_t dim, + std::optional py_streams) { + std::optional>> + streams; + if (py_streams) { + py::handle handle = *py_streams; + streams = THPUtils_PySequence_to_ZoomStreamList(handle.ptr()); + } + // Note: We're holding the GIL up to here. + pybind11::gil_scoped_release no_gil; + return scatter(tensor, devices, chunk_sizes, dim, streams); + }, + py::arg("tensor"), + py::arg("devices"), + py::arg("chunk_sizes"), + py::arg("dim"), + py::arg("streams")) + .def( + "_scatter_out", + [](at::Tensor& tensor, + std::vector& out_tensors, + int64_t dim, + std::optional py_streams) { + std::optional>> + streams; + if (py_streams) { + py::handle handle = *py_streams; + streams = THPUtils_PySequence_to_ZoomStreamList(handle.ptr()); + } + // Note: We're holding the GIL up to here. 
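+           // Converting `py_streams` above touches Python objects and must
+           // happen while the GIL is held; the scatter itself only issues
+           // device work, so release the GIL before dispatching to avoid
+           // blocking other Python threads for the duration of the copies.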
+ pybind11::gil_scoped_release no_gil; + return scatter_out(tensor, out_tensors, dim, streams); + }, + py::arg("tensor"), + py::arg("out"), + py::arg("dim"), + py::arg("streams")) + .def( + "_gather", + [](std::vector& tensors, + int64_t dim, + std::optional destination_index) { + return gather(tensors, dim, destination_index); + }, + py::arg("tensors"), + py::arg("dim"), + py::arg("destination_index"), + py::call_guard()) + .def( + "_gather_out", + [](std::vector& tensors, + at::Tensor& out_tensor, + int64_t dim) { return gather_out(tensors, out_tensor, dim); }, + py::arg("tensors"), + py::arg("out"), + py::arg("dim"), + py::call_guard()); +} +} // namespace torch::zoom::python diff --git a/torch/csrc/zoom/python_comm.h b/torch/csrc/zoom/python_comm.h new file mode 100644 index 00000000000000..de5af273adc0a3 --- /dev/null +++ b/torch/csrc/zoom/python_comm.h @@ -0,0 +1,7 @@ +#pragma once + +namespace torch::zoom::python { + +void initCommMethods(PyObject* module); + +} // namespace torch::zoom::python diff --git a/torch/csrc/zoom/shared/hiprt.cpp b/torch/csrc/zoom/shared/hiprt.cpp new file mode 100644 index 00000000000000..823806750f4eef --- /dev/null +++ b/torch/csrc/zoom/shared/hiprt.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include + +namespace torch::zoom::shared { + +namespace { +hipError_t hipReturnSuccess() { + return hipSuccess; +} +} // namespace + +void initHiprtBindings(PyObject* module) { + auto m = py::handle(module).cast(); + + auto hiprt = m.def_submodule("_hiprt", "hip runtime bindings"); + + py::enum_( + hiprt, + "hip" + "Error") + .value("success", hipSuccess); + + hiprt.def( + "hip" + "GetErrorString", + hipGetErrorString); + hiprt.def( + "hip" + "ProfilerStart", + hipReturnSuccess + ); + hiprt.def( + "hip" + "ProfilerStop", + hipReturnSuccess + ); + hiprt.def( + "hip" + "HostRegister", + [](uintptr_t ptr, size_t size, unsigned int flags) -> hipError_t { + return C10_ZOOM_ERROR_HANDLED( + hipHostRegister((void*)ptr, size, flags)); + }); + hiprt.def( + "hip" + "HostUnregister", + [](uintptr_t ptr) -> hipError_t { + return C10_ZOOM_ERROR_HANDLED(hipHostUnregister((void*)ptr)); + }); + hiprt.def( + "hip" + "StreamCreate", + [](uintptr_t ptr) -> hipError_t { + return C10_ZOOM_ERROR_HANDLED(hipStreamCreate((hipStream_t*)ptr)); + }); + hiprt.def( + "hip" + "StreamDestroy", + [](uintptr_t ptr) -> hipError_t { + return C10_ZOOM_ERROR_HANDLED(hipStreamDestroy((hipStream_t)ptr)); + }); + hiprt.def( + "hip" + "MemGetInfo", + [](c10::DeviceIndex device) -> std::pair { + c10::zoom::ZoomGuard guard(device); + size_t device_free = 0; + size_t device_total = 0; + C10_ZOOM_CHECK(hipMemGetInfo(&device_free, &device_total)); + return {device_free, device_total}; + }); +} + +} // namespace torch::zoom::shared diff --git a/torch/csrc/zoom/utils.cpp b/torch/csrc/zoom/utils.cpp new file mode 100644 index 00000000000000..e04d93ae4b99d2 --- /dev/null +++ b/torch/csrc/zoom/utils.cpp @@ -0,0 +1,41 @@ +#include +#include +#include +#include + +// NB: It's a list of *optional* ZoomStream; when nullopt, that means to use +// whatever the current stream of the device the input is associated with was. 
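+// For example (illustrative), a Python argument such as
+//   [torch.zoom.Stream(device=0), None, torch.zoom.Stream(device=1)]
+// yields {stream_0, nullopt, stream_1}; the None entry means the op keeps the
+// current stream of whichever device that particular input lives on.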
+std::vector> +THPUtils_PySequence_to_ZoomStreamList(PyObject* obj) { + if (!PySequence_Check(obj)) { + throw std::runtime_error( + "Expected a sequence in THPUtils_PySequence_to_ZoomStreamList"); + } + THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr)); + if (seq.get() == nullptr) { + throw std::runtime_error( + "expected PySequence, but got " + std::string(THPUtils_typename(obj))); + } + + std::vector> streams; + Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); + for (Py_ssize_t i = 0; i < length; i++) { + PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i); + + if (PyObject_IsInstance(stream, THCPStreamClass)) { + // Spicy hot reinterpret cast!! + streams.emplace_back(c10::zoom::ZoomStream::unpack3( + (reinterpret_cast(stream))->stream_id, + (reinterpret_cast(stream))->device_index, + static_cast( + (reinterpret_cast(stream))->device_type))); + } else if (stream == Py_None) { + streams.emplace_back(); + } else { + // NOLINTNEXTLINE(bugprone-throw-keyword-missing) + std::runtime_error( + "Unknown data type found in stream list. Need torch.cuda.Stream or None"); + } + } + return streams; +} \ No newline at end of file diff --git a/torch/csrc/zoom/utils.h b/torch/csrc/zoom/utils.h new file mode 100644 index 00000000000000..39b9c17b60459c --- /dev/null +++ b/torch/csrc/zoom/utils.h @@ -0,0 +1,4 @@ +#include + +std::vector> +THPUtils_PySequence_to_ZoomStreamList(PyObject* obj); \ No newline at end of file diff --git a/torch/nn/functional.py b/torch/nn/functional.py index a1d2a846e75e02..844bd5ebc30f00 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -4073,7 +4073,7 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona # Two levels are necessary to prevent TorchScript from touching # are_deterministic_algorithms_enabled. if not torch.jit.is_scripting(): - if torch.are_deterministic_algorithms_enabled() and input.is_cuda: + if torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_zoom): # Use slow decomp whose backward will be in terms of index_put # importlib is required because the import cannot be top level # (cycle) and cannot be nested (TS doesn't support) @@ -4528,7 +4528,7 @@ def pad(input: Tensor, pad: List[int], mode: str = "constant", value: Optional[f return handle_torch_function( torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value) if not torch.jit.is_scripting(): - if torch.are_deterministic_algorithms_enabled() and input.is_cuda: + if torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_zoom): if mode == 'replicate': # Use slow decomp whose backward will be in terms of index_put. 
# importlib is required because the import cannot be top level diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 07caa0ac3eee35..26594a33ca9d2e 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1083,6 +1083,16 @@ def _has_sufficient_memory(device, size): device = 'cuda:0' return torch.cuda.memory.mem_get_info(device)[0] >= size + if torch.device(device).type == 'zoom': + if not torch.zoom.is_available(): + return False + gc.collect() + torch.zoom.empty_cache() + # torch.zoom.mem_get_info, aka hipMemGetInfo, returns a tuple of (free memory, total memory) of a GPU + if device == 'zoom': + device = 'zoom:0' + return torch.zoom.memory.mem_get_info(device)[0] >= size + if device == 'xla': raise unittest.SkipTest('TODO: Memory availability checks for XLA?') @@ -1318,6 +1328,12 @@ class dtypesIfCUDA(dtypes): def __init__(self, *args): super().__init__(*args, device_type='cuda') +# Overrides specified dtypes on Zoom. +class dtypesIfZoom(dtypes): + + def __init__(self, *args): + super().__init__(*args, device_type='zoom') + class dtypesIfMPS(dtypes): def __init__(self, *args): @@ -1335,6 +1351,8 @@ def onlyCPU(fn): def onlyCUDA(fn): return onlyOn('cuda')(fn) +def onlyZOOM(fn): + return onlyOn('zoom')(fn) def onlyMPS(fn): return onlyOn('mps')(fn) @@ -1362,6 +1380,17 @@ def only_fn(self, *args, **kwargs): return only_fn +def onlyCUDAAndZOOM(fn): + @wraps(fn) + def only_fn(self, *args, **kwargs): + if self.device_type not in ('cuda', 'privateuseone'): + reason = f"onlyCUDAAndZOOM: doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + def disablecuDNN(fn): @wraps(fn) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index af5dcf35b4a377..93988d025f3779 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -69,6 +69,7 @@ import torch.backends.mps import torch.backends.xnnpack import torch.cuda +import torch.zoom from torch import Tensor from torch._C import ScriptDict, ScriptList # type: ignore[attr-defined] from torch._utils_internal import get_writable_path @@ -1234,6 +1235,7 @@ def TemporaryDirectoryName(suffix=None): TEST_MPS = torch.backends.mps.is_available() TEST_XPU = torch.xpu.is_available() TEST_CUDA = torch.cuda.is_available() +TEST_ZOOM = torch.zoom.is_available() custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) custom_device_is_available = hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available() TEST_PRIVATEUSE1 = True if custom_device_is_available else False @@ -1596,6 +1598,21 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +def skipIfZoom(func=None, *, msg="test doesn't currently work on the ROCm stack"): + def dec_fn(fn): + reason = f"skipIfZoom: {msg}" + + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_ZOOM: # noqa: F821 + raise unittest.SkipTest(reason) + else: + return fn(*args, **kwargs) + return wrapper + if func: + return dec_fn(func) + return dec_fn + # Skips a test on CUDA if ROCm is available and its version is lower than requested. 
def skipIfRocmVersionLessThan(version=None): def dec_fn(fn): @@ -1698,6 +1715,17 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): torch.cuda.set_sync_debug_mode(self.debug_mode_restore) +class ZoomSyncGuard: + def __init__(self, sync_debug_mode): + self.mode = sync_debug_mode + + def __enter__(self): + self.debug_mode_restore = torch.zoom.get_sync_debug_mode() + torch.zoom.set_sync_debug_mode(self.mode) + + def __exit__(self, exception_type, exception_value, traceback): + torch.zoom.set_sync_debug_mode(self.debug_mode_restore) + # Context manager for setting torch.__future__.set_swap_module_params_on_conversion # and automatically resetting it to its original value class SwapTensorsGuard: diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 0b17a4af0eaca3..8bdbde217683ba 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -1344,7 +1344,7 @@ def supported_dtypes(self, device_type): if device_type == "privateuse1": device_type = torch._C._get_privateuse1_backend_name() device_type = torch.device(device_type).type - if device_type == "cuda": + if device_type == "cuda" or device_type == "zoom": return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA return self.dtypes @@ -1356,7 +1356,7 @@ def supported_backward_dtypes(self, device_type): device_type = torch._C._get_privateuse1_backend_name() device_type = torch.device(device_type).type backward_dtypes = None - if device_type == "cuda": + if device_type == "cuda" or device_type == "zoom": backward_dtypes = ( self.backward_dtypesIfROCM if TEST_WITH_ROCM diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 960aa1e79d7395..50a753225e565e 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1078,10 +1078,15 @@ def CUDAExtension(name, sources, *args, **kwargs): libraries.append('torch_cpu') libraries.append('torch_python') if IS_HIP_EXTENSION: + print("IS_HIP_EXTENSION") libraries.append('amdhip64') - libraries.append('c10_hip') - libraries.append('torch_hip') + libraries.append('rocblas') + libraries.append('hipblas') + # (Arham): commented out for zoom development + # libraries.append('c10_hip') + # libraries.append('torch_hip') else: + print("LOADING CUDA") libraries.append('cudart') libraries.append('c10_cuda') libraries.append('torch_cuda') @@ -1089,7 +1094,8 @@ def CUDAExtension(name, sources, *args, **kwargs): include_dirs = kwargs.get('include_dirs', []) - if IS_HIP_EXTENSION: + # (Arham): disable hipify + if False and IS_HIP_EXTENSION: build_dir = os.getcwd() hipify_result = hipify_python.hipify( project_directory=build_dir, @@ -1690,7 +1696,8 @@ def _jit_compile(name, try: if version != old_version: with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx: - if IS_HIP_EXTENSION and (with_cuda or with_cudnn): + # (Arham): to disable hipifying for testing the zoom extension + if False and IS_HIP_EXTENSION and (with_cuda or with_cudnn): hipify_result = hipify_python.hipify( project_directory=build_directory, output_directory=build_directory, @@ -1866,11 +1873,12 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): else: extra_ldflags.append(f'-L{TORCH_LIB_PATH}') extra_ldflags.append('-lc10') - if with_cuda: - extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') + # (Arham): commented out to develop zoom + # if with_cuda: + # extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else 
'-lc10_cuda') extra_ldflags.append('-ltorch_cpu') - if with_cuda: - extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda') + # if with_cuda: + # extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda') extra_ldflags.append('-ltorch') if not is_standalone: extra_ldflags.append('-ltorch_python') diff --git a/torch/zoom/__init__.py b/torch/zoom/__init__.py new file mode 100644 index 00000000000000..7b5a757d08520c --- /dev/null +++ b/torch/zoom/__init__.py @@ -0,0 +1,577 @@ +import importlib +import os +import threading +import traceback +import warnings +from functools import lru_cache +from typing import Any, Callable, cast, List, Optional, Tuple, Union + +import torch +import torch._C +from torch.types import Device +from .. import device as _device +from .._utils import _dummy_type, _LazySeedTracker, classproperty +from ._utils import _get_device_index +from .streams import Event, ExternalStream, Stream + + +try: + from torch._C import _hiprt # type: ignore[attr-defined] +except ImportError: + _hiprt = None + + +# Define dummy _ZoomDeviceProperties type if PyTorch was compiled without Zoom +if hasattr(torch._C, "_ZoomDeviceProperties"): + _ZoomDeviceProperties = torch._C._ZoomDeviceProperties +else: + _ZoomDeviceProperties = _dummy_type("_ZoomDeviceProperties") # type: ignore[assignment, misc] + +if hasattr(torch._C, "_zoom_exchangeDevice"): + _exchange_device = torch._C._zoom_exchangeDevice +else: + def _exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without Zoom support") + + +if hasattr(torch._C, "_zoom_maybeExchangeDevice"): + _maybe_exchange_device = torch._C._zoom_maybeExchangeDevice +else: + def _maybe_exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without Zoom support") + + + +_initialized = False +_tls = threading.local() +_initialization_lock = threading.Lock() +_queued_calls: List[ + Tuple[Callable[[], None], List[str]] +] = [] # don't invoke these until initialization occurs +_is_in_bad_fork = getattr(torch._C, "_zoom_isInBadFork", lambda: False) +_device_t = Union[_device, str, int, None] +_lazy_seed_tracker = _LazySeedTracker() +_cached_device_count: Optional[int] = None + +class DeferredZoomCallError(Exception): + pass + +def get_amp_supported_dtype() -> List[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] + +def _is_compiled() -> bool: + r"""Return true if compile with Zoom support.""" + return hasattr(torch._C, "_zoom_getDeviceCount") + +def is_available() -> bool: + r"""Return a bool indicating if Zoom is currently available.""" + if not _is_compiled(): + return False + return torch._C._zoom_getDeviceCount() > 0 + +def is_bf16_supported(): + r"""bfloat16 is supported on AMD GPU Archs""" + return True + +def is_initialized(): + r"""Return whether PyTorch's HIP state has been initialized.""" + return _initialized and not _is_in_bad_fork() + +def init(): + r"""Initialize PyTorch's HIP state. + + You may need to call this explicitly if you are interacting with + PyTorch via its C API, as Python bindings for Zoom functionality + will not be available until this initialization takes place. + + No-op if Zoom is already initialized. + """ + _lazy_init() + + +def _lazy_init(): + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # We be double-checked locking, boys! 
This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if is_initialized(): + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize Zoom in forked subprocess. To use Zoom with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not hasattr(torch._C, "_zoom_getDeviceCount"): + raise AssertionError("Torch not compiled with Zoom enabled") + if _hiprt is None: + raise AssertionError( + "HIP runtime functions unavailable. It looks like you have a broken build?" + ) + # This function throws if there's a driver initialization error, no GPUs + # are found or any other error occurs + # if "CUDA_MODULE_LOADING" not in os.environ: + # os.environ["CUDA_MODULE_LOADING"] = "LAZY" + torch._C._zoom_init() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! + _tls.is_initializing = True + + for calls in _lazy_seed_tracker.get_calls(): + if calls: + _queued_calls.append(calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"Zoom call failed lazily at initialization with error: {str(e)}\n\n" + f"Zoom call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise DeferredZoomCallError(msg) from e + finally: + delattr(_tls, "is_initializing") + _initialized = True + +def hiprt(): + _lazy_init() + return _hiprt + +class hipStatus: + SUCCESS: int = 0 + ERROR_NOT_READY: int = 34 + + +class ZoomError(RuntimeError): + def __init__(self, code: int) -> None: + msg = _hiprt.hipGetErrorString(_hiprt.hipError(code)) + super().__init__(f"{msg} ({code})") + + +def check_error(res: int) -> None: + if res != _hiprt.hipError.success: + raise ZoomError(res) + + +class _DeviceGuard: + def __init__(self, index: int): + self.idx = index + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.zoom._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.zoom._maybe_exchange_device(self.prev_idx) + return False + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.zoom._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.zoom._maybe_exchange_device(self.prev_idx) + return False + + +class device_of(device): + r"""Context-manager that changes the current device to that of given object. + + You can use both tensors and storages as arguments. If a given object is + not allocated on a GPU, this is a no-op. + + Args: + obj (Tensor or Storage): object allocated on the selected device. + """ + + def __init__(self, obj): + idx = obj.get_device() if obj.is_zoom else -1 + super().__init__(idx) + + +def set_device(device: _device_t) -> None: + r"""Set the current device. 
+ + Usage of this function is discouraged in favor of :any:`device`. In most + cases it's better to use ``ZOOM_VISIBLE_DEVICES`` environmental variable. + + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._zoom_setDevice(device) + + +def get_device_name(device: Optional[_device_t] = None) -> str: + r"""Get the name of a device. + + Args: + device (torch.device or int, optional): device for which to return the + name. This function is a no-op if this argument is a negative + integer. It uses the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + Returns: + str: the name of the device + """ + return get_device_properties(device).name + + +def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]: + r"""Get the HIP capability of a device. + + Args: + device (torch.device or int, optional): device for which to return the + device capability. This function is a no-op if this argument is + a negative integer. It uses the current device, given by + :func:`~torch.zoom.current_device`, if :attr:`device` is ``None`` + (default). + + Returns: + tuple(int, int): the major and minor HIP capability of the device + """ + prop = get_device_properties(device) + return prop.major, prop.minor + + +def get_device_properties(device: _device_t) -> _ZoomDeviceProperties: + r"""Get the properties of a device. + + Args: + device (torch.device or int or str): device for which to return the + properties of the device. + + Returns: + _ZoomDeviceProperties: the properties of the device + """ + _lazy_init() # will define _get_device_properties + device = _get_device_index(device, optional=True) + if device < 0 or device >= device_count(): + raise AssertionError("Invalid device id") + return _get_device_properties(device) # type: ignore[name-defined] + + +def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool: + r"""Check if peer access between two devices is possible.""" + _lazy_init() + device = _get_device_index(device, optional=True) + peer_device = _get_device_index(peer_device) + if device < 0 or device >= device_count(): + raise AssertionError("Invalid device id") + if peer_device < 0 or peer_device >= device_count(): + raise AssertionError("Invalid peer device id") + return torch._C._zoom_canDeviceAccessPeer(device, peer_device) + + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + _lazy_init() + return torch._C._zoom_getDevice() + +def synchronize(device: _device_t = None) -> None: + r"""Wait for all kernels in all streams on a Zoom device to complete. + + Args: + device (torch.device or int, optional): device for which to synchronize. + It uses the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + """ + _lazy_init() + with torch.zoom.device(device): + return torch._C._zoom_synchronize() + +def device_count() -> int: + r"""Return the number of GPUs available.""" + global _cached_device_count + if not _is_compiled(): + return 0 + if _cached_device_count is not None: + return _cached_device_count + r = torch._C._zoom_getDeviceCount() + # NB: Do not cache the device count prior to Zoom initialization, because + # the number of devices can change due to changes to ZOOM_VISIBLE_DEVICES + # setting prior to Zoom initialization. 
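+    # For example, changing ZOOM_VISIBLE_DEVICES after an early device_count()
+    # call but before the first real initialization would make a cached value
+    # stale, so only cache once _initialized is True.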
+ if _initialized: + _cached_device_count = r + return r + +def current_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.zoom.current_device`, if :attr:`device` is ``None`` + (default). + """ + _lazy_init() + streamdata = torch._C._zoom_getCurrentStream( + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def current_blas_handle(): + r"""Return cublasHandle_t pointer to current cuBLAS handle""" + _lazy_init() + return torch._C._zoom_getCurrentBlasHandle() + + +def set_sync_debug_mode(debug_mode: Union[int, str]) -> None: + r"""Set the debug mode for zoom synchronizing operations. + + Args: + debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations, + if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations. + + Warning: + This is an experimental feature, and not all synchronizing operations will trigger warning or error. In + particular, operations in torch.distributed and torch.sparse namespaces are not covered yet. + """ + _lazy_init() + if isinstance(debug_mode, str): + if debug_mode == "default": + debug_mode = 0 + elif debug_mode == "warn": + debug_mode = 1 + elif debug_mode == "error": + debug_mode = 2 + else: + raise RuntimeError( + "invalid value of debug_mode, expected one of `default`, `warn`, `error`" + ) + + torch._C._zoom_set_sync_debug_mode(debug_mode) + + +def get_sync_debug_mode() -> int: + r"""Return current value of debug mode for zoom synchronizing operations.""" + _lazy_init() + return torch._C._zoom_get_sync_debug_mode() + + +################################################################################ +# Define Storage and Tensor classes +################################################################################ + + +@staticmethod # type: ignore[misc] +def _lazy_new(cls, *args, **kwargs): + _lazy_init() + # We may need to call lazy init again if we are a forked child + # del _ZoomBase.__new__ + return super(_ZoomBase, cls).__new__(cls, *args, **kwargs) + + +class _ZoomBase: + is_zoom = True + is_sparse = False + + def type(self, *args, **kwargs): + # We could use a Protocol here to tell mypy that self has `get_device` method + # but it is only available in the typing module on Python >= 3.8 + # or on typing_extensions module on Python >= 3.6 + with device(self.get_device()): # type: ignore[attr-defined] + return super().type(*args, **kwargs) # type: ignore[misc] + + __new__ = _lazy_new + + +from torch.storage import _LegacyStorage, _warn_typed_storage_removal + + +class _ZoomLegacyStorage(_LegacyStorage): + @classmethod + def from_buffer(cls, *args, **kwargs): + _warn_typed_storage_removal() + raise RuntimeError("from_buffer: Not available for Zoom storage") + + @classmethod + def _new_with_weak_ptr(cls, *args, **kwargs): + raise RuntimeError("_new_with_weak_ptr: Not available for Zoom storage") + + @classmethod + def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None): + raise RuntimeError("_new_shared_filename: Not available for Zoom storage") + + +class ByteStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + 
return torch.uint8 + + +class DoubleStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.double + + +class FloatStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.float + + +class HalfStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.half + + +class LongStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.long + + +class IntStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.int + + +class ShortStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.short + + +class CharStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.int8 + + +class BoolStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.bool + + +class BFloat16Storage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.bfloat16 + + +class ComplexDoubleStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.cdouble + + +class ComplexFloatStorage(_ZoomLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.cfloat + + +del _LegacyStorage +del _ZoomLegacyStorage + +torch._storage_classes.add(DoubleStorage) +torch._storage_classes.add(FloatStorage) +torch._storage_classes.add(LongStorage) +torch._storage_classes.add(IntStorage) +torch._storage_classes.add(ShortStorage) +torch._storage_classes.add(CharStorage) +torch._storage_classes.add(ByteStorage) +torch._storage_classes.add(HalfStorage) +torch._storage_classes.add(BoolStorage) +torch._storage_classes.add(BFloat16Storage) +torch._storage_classes.add(ComplexDoubleStorage) +torch._storage_classes.add(ComplexFloatStorage) + +from .memory import * # noqa: F403 \ No newline at end of file diff --git a/torch/zoom/_memory_viz.py b/torch/zoom/_memory_viz.py new file mode 100644 index 00000000000000..8b39bebf35637d --- /dev/null +++ b/torch/zoom/_memory_viz.py @@ -0,0 +1,627 @@ +import pickle +import sys +import os +import io +import subprocess +import json +from functools import lru_cache +from typing import Any +from itertools import groupby +import base64 +import warnings +import operator + +cache = lru_cache(None) + +__all__ = ["format_flamegraph", "segments", "memory", "compare"] + +def _frame_fmt(f, full_filename=False): + i = f['line'] + fname = f['filename'] + if not full_filename: + fname = fname.split('/')[-1] + func = f['name'] + return f'{fname}:{i}:{func}' + +@cache +def _frame_filter(name, filename): + omit_functions = 
[ + "unwind::unwind", + "CapturedTraceback::gather", + "gather_with_cpp", + "_start", + "__libc_start_main", + "PyEval_", + "PyObject_", + "PyFunction_", + ] + omit_filenames = [ + "core/boxing", + "/Register", + "/Redispatch", + "pythonrun.c", + "Modules/main.c", + "Objects/call.c", + "Objects/methodobject.c", + "pycore_ceval.h", + "ceval.c", + "cpython/abstract.h", + ] + for of in omit_functions: + if of in name: + return False + for of in omit_filenames: + if of in filename: + return False + return True + +def _frames_fmt(frames, full_filename=False, reverse=False): + if reverse: + frames = reversed(frames) + return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])] + +def _block_extra_legacy(b): + if 'history' in b: + frames = b['history'][0].get('frames', []) + real_size = b['history'][0]['real_size'] + else: + real_size = b.get('requested_size', b['size']) + frames = [] + return frames, real_size + +def _block_extra(b): + if 'frames' not in b: + # old snapshot format made it more complicated to get frames/allocated size + return _block_extra_legacy(b) + return b['frames'], b['requested_size'] + +def format_flamegraph(flamegraph_lines, flamegraph_script=None): + if flamegraph_script is None: + flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl' + if not os.path.exists(flamegraph_script): + import urllib.request + print(f"Downloading flamegraph.pl to: {flamegraph_script}") + urllib.request.urlretrieve( + 'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script) + subprocess.check_call(['chmod', '+x', flamegraph_script]) + args = [flamegraph_script, '--countname', 'bytes'] + p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8') + assert p.stdin is not None + assert p.stdout is not None + p.stdin.write(flamegraph_lines) + p.stdin.close() + result = p.stdout.read() + p.stdout.close() + p.wait() + assert p.wait() == 0 + return result + +def _write_blocks(f, prefix, blocks): + def frames_fragment(frames): + if not frames: + return "" + return ';'.join(_frames_fmt(frames, reverse=True)) + for b in blocks: + if 'history' not in b: + frames, accounted_for_size = _block_extra(b) + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n') + else: + accounted_for_size = 0 + for h in b['history']: + sz = h['real_size'] + accounted_for_size += sz + if 'frames' in h: + frames = h['frames'] + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b['size'] - accounted_for_size + if gaps: + f.write(f'{prefix};{b["state"]}; {gaps}\n') + +def segments(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot['segments']: + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + _write_blocks(f, prefix, seg['blocks']) + return format_flamegraph(f.getvalue()) + +def memory(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot['segments']: + prefix = f'stream_{seg["stream"]}' + _write_blocks(f, prefix, seg['blocks']) + return format_flamegraph(f.getvalue()) + +def compare(before, after, format_flamegraph=format_flamegraph): + def _seg_key(seg): + return (seg['address'], seg['total_size']) + + def _seg_info(seg): + return f'stream_{seg["stream"]};seg_{seg["address"]}' + + f = io.StringIO() + + before_segs = {_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} + + print(f'only_before = {[a for 
a,_ in (before_segs - after_segs)]}') + print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}') + + for seg in before: + if _seg_key(seg) not in after_segs: + _write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks']) + + for seg in after: + if _seg_key(seg) not in before_segs: + _write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks']) + + return format_flamegraph(f.getvalue()) + +def _format_size(num): + # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return f"{num:3.1f}{unit}B" + num /= 1024.0 + return f"{num:.1f}YiB" + +class Bytes: + def __init__(self, value): + self.value = value + + def __add__(self, rhs): + return Bytes(self.value + rhs) + + def __repr__(self): + return _format_size(self.value) + +def calc_active(seg): + return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated') + +def _report_free(free_external, free_internal): + total = free_external + free_internal + suffix = '' + if total != 0: + pct = (free_internal / total) * 100 + suffix = f' ({pct:.1f}% internal)' + return f'{Bytes(total)}{suffix}' + +PAGE_SIZE = 1024 * 1024 * 20 +legend = f"""\ + +Legend: + [a ] - a segment in the allocator + ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment + a-z: pages filled with a single block's content + ' ': page is completely free + *: page if completely full with multiple blocks + 0-9: page is partially full with tensors of multiple blocks (9 == 90% full) + (X% internal) - of the free memory, X% is free because we rounded the size of the allocation. +""" + +def segsum(data): + r"""Visually reports how the allocator has filled its segments. + + This printout can help debug fragmentation issues since free fragments + will appear as gaps in this printout. The amount of free space is reported + for each segment. + We distinguish between internal free memory which occurs because the + allocator rounds the allocation size, and external free memory, which are + the gaps between allocations in a segment. 
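Both this report and the flamegraph helpers above render sizes through the `Bytes` wrapper and `_format_size` defined earlier in the file. A minimal usage sketch, assuming the module from this patch is importable and that `torch.zoom.memory._snapshot()` behaves as the CLI help below describes:

```python
# Sketch: print the segment-fill report for a snapshot taken on the zoom device.
# Sizes are rendered by the Bytes/_format_size helpers above,
# e.g. Bytes(PAGE_SIZE) prints as "20.0MiB".
import torch
from torch.zoom import _memory_viz

snap = torch.zoom.memory._snapshot()   # assumed API from this patch
print(_memory_viz.segsum(snap))
```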
+ Args: + data: snapshot dictionary created from _snapshot() + """ + segments = [] + out = io.StringIO() + out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") + total_reserved = 0 + total_allocated = 0 + free_external = 0 + free_internal = 0 + for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))): + total_reserved += seg['total_size'] + + seg_free_external = 0 + seg_free_internal = 0 + seg_allocated = 0 + all_ranges = [] + boffset = 0 + for b in seg['blocks']: + active = b['state'] == 'active_allocated' + if active: + _, allocated_size = _block_extra(b) + all_ranges.append((boffset, allocated_size, True)) + seg_allocated += allocated_size + seg_free_internal += b['size'] - allocated_size + else: + seg_free_external += b['size'] + + boffset += b['size'] + + total_allocated += seg_allocated + free_external += seg_free_external + free_internal += seg_free_internal + + nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1 + occupied = [' ' for _ in range(nseg)] + frac = [0.0 for _ in range(nseg)] + active_size = 0 + for i, (start_, size, active) in enumerate(all_ranges): + active_size += size + finish_ = (start_ + size) + start = start_ // PAGE_SIZE + finish = (finish_ - 1) // PAGE_SIZE + 1 + m = chr(ord('a' if active else 'A') + (i % 26)) + for j in range(start, finish): + s = max(start_, j * PAGE_SIZE) + e = min(finish_, (j + 1) * PAGE_SIZE) + frac[j] += (e - s) / PAGE_SIZE + if occupied[j] != ' ': + occupied[j] = '0123456789*'[int(frac[j] * 10)] + else: + occupied[j] = m + stream = '' if seg['stream'] == 0 else f', stream_{seg["stream"]}' + body = ''.join(occupied) + assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size'] + stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else '' + if seg['total_size'] >= PAGE_SIZE: + out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n') + out.write(f'segments: {len(data["segments"])}\n') + out.write(f'total_reserved: {Bytes(total_reserved)}\n') + out.write(f'total_allocated: {Bytes(total_allocated)}\n') + internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else '' + out.write(f'total_free: {_report_free(free_external, free_internal)}\n') + out.write(legend) + assert free_internal + free_external + total_allocated == total_reserved + return out.getvalue() + +def trace(data): + out = io.StringIO() + + def format(entries): + segment_intervals : list = [] + segment_addr_to_name = {} + allocation_addr_to_name = {} + + free_names : list = [] + next_name = 0 + + def _name(): + nonlocal next_name + if free_names: + return free_names.pop() + r, m = next_name // 26, next_name % 26 + next_name += 1 + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + + def find_segment(addr): + for name, saddr, size in segment_intervals: + if addr >= saddr and addr < saddr + size: + return name, saddr + for i, seg in enumerate(data['segments']): + saddr = seg['address'] + size = seg['allocated_size'] + if addr >= saddr and addr < saddr + size: + return f'seg_{i}', saddr + return None, None + count = 0 + out.write(f'{len(entries)} entries\n') + + + total_reserved = 0 + for seg in data['segments']: + total_reserved += seg['total_size'] + + for count, e in enumerate(entries): + if e['action'] == 'alloc': + addr, size = e['addr'], e['size'] + n = _name() + seg_name, seg_addr = find_segment(addr) + if seg_name is None: + seg_name = "MEM" + offset = addr + else: + offset = addr - 
seg_addr + out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n') + allocation_addr_to_name[addr] = (n, size, count) + count += size + elif e['action'] == 'free_requested': + addr, size = e['addr'], e['size'] + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f'del {name} # {Bytes(size)}\n') + elif e['action'] == 'free_completed': + addr, size = e['addr'], e['size'] + count -= size + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f'# free completed for {name} {Bytes(size)}\n') + if name in allocation_addr_to_name: + free_names.append(name) + del allocation_addr_to_name[name] + elif e['action'] == 'segment_alloc': + addr, size = e['addr'], e['size'] + name = _name() + out.write(f'{name} = hipMalloc({addr}, {Bytes(size)})\n') + segment_intervals.append((name, addr, size)) + segment_addr_to_name[addr] = name + elif e['action'] == 'segment_free': + addr, size = e['addr'], e['size'] + name = segment_addr_to_name.get(addr, addr) + out.write(f'hipFree({name}) # {Bytes(size)}\n') + if name in segment_addr_to_name: + free_names.append(name) + del segment_addr_to_name[name] + elif e['action'] == 'oom': + size = e['size'] + free = e['device_free'] + out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in Zoom\n') + else: + out.write(f'{e}\n') + out.write(f"TOTAL MEM: {Bytes(count)}") + for i, d in enumerate(data['device_traces']): + if d: + out.write(f'Device {i} ----------------\n') + format(d) + return out.getvalue() + + +_memory_viz_template = r""" + + + + + + + +""" + +def _format_viz(data, viz_kind, device): + if device is not None: + warnings.warn('device argument is deprecated, plots now contain all device') + buffer = pickle.dumps(data) + buffer += b'\x00' * (3 - len(buffer) % 3) + # Encode the buffer with base64 + encoded_buffer = base64.b64encode(buffer).decode('utf-8') + + json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}]) + return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \ + .replace('$SNAPSHOT', json_format) + +def trace_plot(data, device=None, plot_segments=False): + """Generate a visualization over time of the memory usage recorded by the trace as an html file. + + Args: + data: Memory snapshot as generated from torch.zoom.memory._snapshot() + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + plot_segments (bool, optional): Plots memory returned from hipMalloc, rather than individual allocations. + Defaults to False. + + Returns: + str: HTML of visualization + """ + return _format_viz(data, 'Active Memory Timeline' if not plot_segments else 'Active Cached Memory Timeline', device) + + +def _profile_to_snapshot(profile): + import torch + from torch.profiler._memory_profiler import Action, TensorKey + from torch._C._profiler import _EventType + memory_profile = profile._memory_profile() + + allocation_stacks = {} + for event in memory_profile._op_tree.sorted_nodes: + if event.tag == _EventType.Allocation: + parent = event.parent + python_parents = [] + while parent: + if parent.tag in (_EventType.PyCall, _EventType.PyCCall): + python_parents.append(parent) + parent = parent.parent + key = TensorKey.from_allocation(event.extra_fields) + + # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor) + # key will be None. I should add some way to identify these, I just haven't yet. 
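As a usage sketch for `trace_plot` above (hedged: it assumes the `'zoom'` device string and `torch.zoom.memory._snapshot()` from this patch work as the docstring describes):

```python
# Hypothetical end-to-end flow: allocate on the zoom device, take a snapshot,
# and write the "Active Memory Timeline" visualization to an HTML file.
import torch
from torch.zoom._memory_viz import trace_plot

x = torch.randn(1024, 1024, device="zoom")   # 'zoom' device assumed registered by this patch
snapshot = torch.zoom.memory._snapshot()     # dict with 'segments' and 'device_traces'
with open("zoom_timeline.html", "w") as f:
    f.write(trace_plot(snapshot))
```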
+ if key and event.extra_fields.alloc_size > 0: + allocation_stacks[key] = python_parents + + + device_count = torch.zoom.device_count() + snapshot = { + 'device_traces': [[] for _ in range(device_count + 1)], + 'segments': [{'device': device, + 'address': None, + 'total_size': 0, + 'stream': 0, + 'blocks': []} for device in range(device_count + 1)] + } + + def to_device(device): + if device.type == 'zoom': + return device.index + else: + return device_count + + def allocate(size, tensor_key, version, during_trace=True): + device = to_device(tensor_key.device) + addr = tensor_key.storage.ptr + + seg = snapshot['segments'][device] # type: ignore[index] + if seg['address'] is None or seg['address'] > addr: + seg['address'] = addr + seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later + category = memory_profile._categories.get(tensor_key, version) + category = category.name.lower() if category is not None else "unknown" + stack = allocation_stacks.get(tensor_key, ()) + stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack] + r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category} + if during_trace: + snapshot['device_traces'][device].append(r) # type: ignore[index] + return r + + def free(alloc, device): + for e in ('free_requested', 'free_completed'): + snapshot['device_traces'][device].append({'action': e, # type: ignore[index] + 'addr': alloc['addr'], + 'size': alloc['size'], + 'stream': 0, + 'frames': alloc['frames']}) + + kv_to_elem = {} + + + + # create the device trace + for time, action, (tensor_key, version), size in memory_profile.timeline: + if not isinstance(tensor_key, TensorKey): + continue + if action == Action.CREATE: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version) + elif action == Action.DESTROY: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + elif action == Action.INCREMENT_VERSION: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1) + elif action == Action.PREEXISTING: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False) + + + # create the final snapshot state + blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames']) + for (tensor_key, version), event in kv_to_elem.items()] + for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)): + seg = snapshot['segments'][device] # type: ignore[index] + last_addr = seg['address'] + for _, addr, size, frames in blocks: + if last_addr < addr: + seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'}) + seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames}) + last_addr = addr + size + if last_addr < seg['total_size']: + seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'}) + + snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined] + for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef] + seg['total_size'] -= seg['address'] + if not seg['blocks']: + seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'}) + + return snapshot + +def profile_plot(profile, device=None): + """Generate a visualization over time of the memory usage recorded by kineto memory profiling 
as an html file. + + Args: + profile: profile as generated by `torch.profiler.profile(profile_memory=True)` + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + + Returns: + str: HTML of visualization + """ + snapshot = _profile_to_snapshot(profile) + return _format_viz(snapshot, 'Active Memory Timeline', device) + + +def segment_plot(data: Any, device=None): + return _format_viz(data, 'Allocator State History', device) + +if __name__ == "__main__": + import os.path + thedir = os.path.realpath(os.path.dirname(__file__)) + if thedir in sys.path: + # otherwise we find zoom/random.py as random... + sys.path.remove(thedir) + import argparse + + fn_name = 'torch.zoom.memory._snapshot()' + pickled = f'pickled memory statistics from {fn_name}' + parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}') + + subparsers = parser.add_subparsers(dest='action') + + def _output(p): + p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)') + + description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.' + stats_a = subparsers.add_parser('stats', description=description) + stats_a.add_argument('input', help=pickled) + + description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.' + trace_a = subparsers.add_parser('trace', description=description) + trace_a.add_argument('input', help=pickled) + + description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)' + segments_a = subparsers.add_parser('segments', description=description) + segments_a.add_argument('input', help=pickled) + _output(segments_a) + + description = "Generate a flamegraph the program locations contributing to Zoom memory usage." + memory_a = subparsers.add_parser('memory', description=description) + memory_a.add_argument('input', help=pickled) + _output(memory_a) + + description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \ + 'or removed between two different memorys snapshots.' + compare_a = subparsers.add_parser('compare', description=description) + compare_a.add_argument('before', help=pickled) + compare_a.add_argument('after', help=pickled) + _output(compare_a) + + plots = ( + ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."), + ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.") + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument('input', help=pickled) + help = 'visualize trace from this device (default: chooses the only device with trace info or errors)' + trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help) + help = 'path to save the visualization(default: output.html)' + trace_plot_a.add_argument('-o', '--output', default='output.html', help=help) + if cmd == "trace_plot": + help = 'visualize change to segments rather than individual allocations' + trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help) + + + args = parser.parse_args() + + def _read(name): + if name == '-': + f = sys.stdin.buffer + else: + f = open(name, 'rb') + data = pickle.load(f) + if isinstance(data, list): # segments only... 
+ data = {'segments': data, 'traces': []} + return data + + def _write(name, data): + with open(name, 'w') as f: + f.write(data) + + if args.action == 'segments': + data = _read(args.input) + _write(args.output, segments(data)) + elif args.action == 'memory': + data = _read(args.input) + _write(args.output, memory(data)) + elif args.action == 'stats': + data = _read(args.input) + print(segsum(data)) + elif args.action == 'trace': + data = _read(args.input) + print(trace(data)) + elif args.action == 'compare': + before = _read(args.before) + after = _read(args.after) + _write(args.output, compare(before, after)) + elif args.action == 'trace_plot': + data = _read(args.input) + _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments)) + elif args.action == 'segment_plot': + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/torch/zoom/_utils.py b/torch/zoom/_utils.py new file mode 100644 index 00000000000000..a6a0b4fa39f13c --- /dev/null +++ b/torch/zoom/_utils.py @@ -0,0 +1,38 @@ +from typing import Any + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a Zoom device. Note that for a Zoom device without a specified index, + i.e., ``torch.device('zoom')``, this will return the current default zoom + device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default zoom + device if :attr:`optional` is ``True``. + """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["zoom", "cpu"]: + raise ValueError(f"Expected a zoom or cpu device, but got: {device}") + elif device.type != "zoom": + raise ValueError(f"Expected a zoom device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.zoom.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/torch/zoom/graphs.py b/torch/zoom/graphs.py new file mode 100644 index 00000000000000..c418abd3e6ef7a --- /dev/null +++ b/torch/zoom/graphs.py @@ -0,0 +1,479 @@ +import gc +import typing + +import torch +from torch.utils import _pytree +from .._utils import _dummy_type + +if not hasattr(torch._C, "_ZoomStreamBase"): + # Define dummy base classes + torch._C.__dict__["_HIPGraph"] = _dummy_type("_HIPGraph") + torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle") + torch._C.__dict__["_zoom_isCurrentStreamCapturing"] = _dummy_type( + "_zoom_isCurrentStreamCapturing" + ) + +from torch._C import ( # noqa: F401 + _zoom_isCurrentStreamCapturing, + _HIPGraph, + _graph_pool_handle, +) + + +def is_current_stream_capturing(): + r"""Return True if CUDA graph capture is underway on the current Zoom stream, False otherwise. + + If a Zoom context does not exist on the current device, returns False without initializing the context. 
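A short sketch of the inputs `_get_device_index` (from `torch/zoom/_utils.py` above) accepts; the `'zoom'` device type is assumed to be registered by this patch:

```python
import torch
from torch.zoom._utils import _get_device_index

_get_device_index(3)                                    # plain ints pass through unchanged
_get_device_index("zoom:1")                             # strings are parsed into torch.device first
_get_device_index(torch.device("zoom", 0))              # -> 0
_get_device_index(torch.device("zoom"), optional=True)  # -> current default zoom device
_get_device_index(torch.device("cpu"), allow_cpu=True)  # CPU only accepted here; returns -1
```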
+ """ + return _zoom_isCurrentStreamCapturing() + + +# Python shim helps Sphinx process docstrings more reliably. +def graph_pool_handle(): + r"""Return an opaque token representing the id of a graph memory pool. + + See :ref:`Graph memory management`. + + .. warning:: + This API is in beta and may change in future releases. + """ + return _graph_pool_handle() + + +# Python shim helps Sphinx process docstrings more reliably. +class HIPGraph(torch._C._HIPGraph): + r"""Wrapper around a HIP graph. + + .. warning:: + This API is in beta and may change in future releases. + """ + + def __new__(cls): + return super().__new__(cls) + + def capture_begin(self, pool=None, capture_error_mode="global"): + r"""Begin capturing Zoom work on the current stream. + + Typically, you shouldn't call ``capture_begin`` yourself. + Use :class:`~torch.zoom.graph` or :func:`~torch.zoom.make_graphed_callables`, + which call ``capture_begin`` internally. + + Arguments: + pool (optional): Token (returned by :func:`~torch.zoom.graph_pool_handle` or + :meth:`other_Graph_instance.pool()`) that hints this graph may share memory + with the indicated pool. See :ref:`Graph memory management`. + capture_error_mode (str, optional): specifies the hipStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as hipMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting + unless you're familiar with `hipStreamCaptureMode `_ + """ # noqa: B950 + super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) + + def capture_end(self): + r"""End HIP graph capture on the current stream. + + After ``capture_end``, ``replay`` may be called on this instance. + + Typically, you shouldn't call ``capture_end`` yourself. + Use :class:`~torch.zoom.graph` or :func:`~torch.zoom.make_graphed_callables`, + which call ``capture_end`` internally. + """ + super().capture_end() + + def replay(self): + r"""Replay the HIP work captured by this graph.""" + super().replay() + + def reset(self): + r"""Delete the graph currently held by this instance.""" + super().reset() + + def pool(self): + r"""Return an opaque token representing the id of this graph's memory pool. + + This id can optionally be passed to another graph's ``capture_begin``, + which hints the other graph may share the same memory pool. + """ + return super().pool() + + def enable_debug_mode(self): + r"""Enable debugging mode for HIPGraph.debug_dump.""" + return super().enable_debug_mode() + + def debug_dump(self, debug_path): + r""" + Arguments: + debug_path (required): Path to dump the graph to. + + Calls a debugging function to dump the graph if the debugging is + enabled via HIPGraph.enable_debug_mode() + """ + return super().debug_dump(debug_path) + + +class graph: + r"""Context-manager that captures HIP work into a :class:`torch.zoom.HIPGraph` object for later replay. + + See :ref:`CUDA Graphs ` for a general introduction, + detailed use, and constraints. + + Arguments: + hip_graph (torch.zoom.HIPGraph): Graph object used for capture. + pool (optional): Opaque token (returned by a call to :func:`~torch.zoom.graph_pool_handle()` or + :meth:`other_Graph_instance.pool()`) hinting this graph's capture + may share memory from the specified pool. See :ref:`Graph memory management`. 
+ stream (torch.zoom.Stream, optional): If supplied, will be set as the current stream in the context. + If not supplied, ``graph`` sets its own internal side stream as the current stream in the context. + capture_error_mode (str, optional): specifies the hipStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During hip graph capture, some actions, such as hipMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting + unless you're familiar with `hipStreamCaptureMode `_ + + .. note:: + For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture + used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture. + + .. warning:: + This API is in beta and may change in future releases. + + .. _hipStreamCaptureMode: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + """ # noqa: B950 + + default_capture_stream: typing.Optional["torch.zoom.Stream"] = None + + def __init__( + self, + hip_graph, + pool=None, + stream=None, + capture_error_mode: str = "global", + ): + # Lazy-init of default_capture_stream helps avoid circular-import errors. + # Not thread safe, but graphs already have the general (explicitly documented) + # restriction that only one capture may be underway at a time in the process. + if self.__class__.default_capture_stream is None: + self.__class__.default_capture_stream = torch.zoom.Stream() + + self.pool = () if pool is None else (pool,) + self.capture_stream = ( + stream if stream is not None else self.__class__.default_capture_stream + ) + assert self.capture_stream is not None + self.stream_ctx = torch.zoom.stream(self.capture_stream) + self.hip_graph = hip_graph + self.capture_error_mode = capture_error_mode + + def __enter__(self): + # Free as much memory as we can for the graph + torch.zoom.synchronize() + gc.collect() + torch.zoom.empty_cache() + + # Stackoverflow seems comfortable with this pattern + # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487 + self.stream_ctx.__enter__() + + self.hip_graph.capture_begin( + *self.pool, capture_error_mode=self.capture_error_mode + ) + + def __exit__(self, exc_type, exc_value, traceback): + self.hip_graph.capture_end() + self.stream_ctx.__exit__(exc_type, exc_value, traceback) + # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() + + +def make_graphed_callables( + callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None +): + r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions. + + Each graphed callable's forward pass runs its source callable's + forward CUDA work as a CUDA graph inside a single autograd node. + + The graphed callable's forward pass also appends + a backward node to the autograd graph. During backward, this node runs the + callable's backward work as a CUDA graph. + + Therefore, each graphed callable should be a drop-in replacement for its source callable + in an autograd-enabled training loop. + + See :ref:`Partial-network capture` for detailed use and constraints. + + If you pass a tuple of several callables, their captures will use the same memory pool. + See :ref:`Graph memory management` for when this is appropriate. 
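A hedged sketch of both capture styles in this file, assuming tensors on the `'zoom'` device behave like their CUDA counterparts under this patch:

```python
import torch

model = torch.nn.Linear(64, 64).to("zoom")      # 'zoom' device assumed
static_x = torch.randn(8, 64, device="zoom")

# Explicit capture with the `graph` context manager, then replay with new input data.
g = torch.zoom.HIPGraph()
with torch.zoom.graph(g):
    static_y = model(static_x)
static_x.copy_(torch.randn(8, 64, device="zoom"))
g.replay()   # recomputes static_y in place from the refreshed static_x

# Or let make_graphed_callables wrap the module into a graphed drop-in replacement.
x = torch.randn(8, 64, device="zoom", requires_grad=True)
graphed = torch.zoom.make_graphed_callables(model, (x,), num_warmup_iters=3)
graphed(x).sum().backward()
```

As the warnings further down note, `sample_args` may contain only Tensors and their `requires_grad` state must match the real inputs used in the training loop.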
+ + Arguments: + callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph. + See :ref:`Graph memory management` for when passing a tuple of callables + is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order + they'll run in the live workload. + sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable. + If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors. + If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors. + num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs + 11 iterations for warm up. Default: ``3``. + allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs + (and therefore their grad is always zero) is an error. Defaults to False. + pool (optional): Token (returned by :func:`~torch.zoom.graph_pool_handle` or + :meth:`other_Graph_instance.pool()`) that hints this graph may share memory + with the indicated pool. See :ref:`Graph memory management`. + .. note:: + The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state + that's expected for the corresponding real input in the training loop. + + .. warning:: + This API is in beta and may change in future releases. + + .. warning:: + ``sample_args`` for each callable must contain only Tensors. Other types are not allowed. + + .. warning:: + Returned callables do not support higher order differentiation (e.g., double backward). + + .. warning:: + In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters + may be trainable. Buffers must have ``requires_grad=False``. + + .. warning:: + After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`, + you may not add or remove any of that Module's parameters or buffers. + + .. warning:: + :class:`torch.nn.Module`\s passed to :func:`~torch.zoom.make_graphed_callables` must not have module hooks + registered on them at the time they are passed. However, registering hooks on modules *after* passing them + through :func:`~torch.zoom.make_graphed_callables` is allowed. + + .. warning:: + When running a graphed callable, you must pass its arguments in the same order and format + they appeared in that callable's ``sample_args``. + + .. warning:: + The automatic mixed precision is supported in :func:`~torch.zoom.make_graphed_callables` only with disabled + caching. The context manager `torch.zoom.amp.autocast()` must have `cache_enabled=False`. + """ + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." + ) + + just_one_callable = False + + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + flatten_sample_args = [] + + for c, args in zip(callables, sample_args): + if isinstance(c, torch.nn.Module): + assert ( + len(c._backward_hooks) == 0 + and len(c._forward_hooks) == 0 + and len(c._forward_pre_hooks) == 0 + ), ( + "Modules must not have hooks registered at the time they are passed. However, registering hooks " + + "on modules after passing them through make_graphed_callables is allowed." 
+ ) + assert all(b.requires_grad is False for b in c.buffers()), ( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + + "``requires_grad=False``." + ) + flatten_arg = _pytree.arg_tree_leaves(*args) + flatten_sample_args.append(tuple(flatten_arg)) + assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) + + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly + # passes to forward (ie, its sample_args) AND the module's parameter attributes. + per_callable_len_user_args = [len(args) for args in flatten_sample_args] + per_callable_module_params = [ + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables + ] + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(callables)) + ] + + fwd_graphs = [torch.zoom.HIPGraph() for _ in range(len(callables))] + bwd_graphs = [torch.zoom.HIPGraph() for _ in range(len(callables))] + + mempool = graph_pool_handle() if pool is None else pool + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. + torch.zoom.synchronize() + with torch.zoom.stream(torch.zoom.Stream()): + for func, args, static_input_surface in zip( + callables, sample_args, per_callable_static_input_surfaces + ): + for _ in range(num_warmup_iters): + outputs = _pytree.tree_leaves(func(*args)) + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + only_inputs=True, + allow_unused=allow_unused_input, + ) + del outputs, grad_inputs # type: ignore[possibly-undefined] + torch.zoom.synchronize() + + # All captures here share a mempool. To avoid replays corrupting each other's memory, + # the safest approach is to capture all passes in the same order they'll run: + # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + + # Capture forward graphs + per_callable_static_outputs = [] + per_callable_output_unflatten_spec = [] + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + with torch.zoom.graph(fwd_graph, pool=mempool): + outputs = func(*args) + + flatten_outputs, spec = _pytree.tree_flatten(outputs) + per_callable_static_outputs.append(tuple(flatten_outputs)) + per_callable_output_unflatten_spec.append(spec) + + # Capture backward graphs in reverse order + per_callable_static_grad_outputs = [] + per_callable_static_grad_inputs = [] + for static_input_surface, static_outputs, bwd_graph, module_params in zip( + reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + reversed(per_callable_module_params), + ): + # For now, assumes all static_outputs require grad + # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad." 
+ static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + + with torch.zoom.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad. + # I couldn't think of a slick one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs.append(static_grad_outputs) + per_callable_static_grad_inputs.append(static_grad_inputs) + + # Reverses the most recent two lists + per_callable_static_grad_outputs.reverse() + per_callable_static_grad_inputs.reverse() + # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. + + def make_graphed_autograd_function( + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): + class Graphed(torch.autograd.Function): + @staticmethod + def forward(ctx, *inputs): + # At this stage, only the user args may (potentially) be new tensors. + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + fwd_graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grads): + assert len(grads) == len(static_grad_outputs) + for g, grad in zip(static_grad_outputs, grads): + if g is not None: + # don't copy if autograd gods have been kind and the + # incoming grad is already in the right place + if g.data_ptr() != grad.data_ptr(): + g.copy_(grad) + bwd_graph.replay() + + # Input args that didn't require grad expect a None gradient. + assert isinstance(static_grad_inputs, tuple) + return tuple( + b.detach() if b is not None else b for b in static_grad_inputs + ) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. 
+ flatten_user_args = _pytree.arg_tree_leaves(*user_args) + out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) + return _pytree.tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callables + ret = [] + for i, func in enumerate(callables): + graphed = make_graphed_autograd_function( + fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_unflatten_spec[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i], + ) + + if isinstance(func, torch.nn.Module): + + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + return graphed(*user_args) + else: + return orig_fwd(*user_args) + + return new_fwd + + func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment] + ret.append(func) + else: + ret.append(graphed) + + if just_one_callable: + return ret[0] + + return tuple(ret) diff --git a/torch/zoom/memory.py b/torch/zoom/memory.py new file mode 100644 index 00000000000000..e910e6271fc8ff --- /dev/null +++ b/torch/zoom/memory.py @@ -0,0 +1,910 @@ +r"""This package adds support for device memory management implemented in Zoom.""" + +import collections +import contextlib +import ctypes +import pickle +import sys +import warnings +from inspect import signature + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import _C + +from torch.types import Device +from .._utils import _dummy_type +# from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized +from . import _get_device_index, _lazy_init, is_initialized + +from ._memory_viz import memory as _memory, segments as _segments + +__all__ = [ + "caching_allocator_alloc", + "caching_allocator_delete", + "set_per_process_memory_fraction", + "empty_cache", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "reset_max_memory_allocated", + "reset_max_memory_cached", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "memory_cached", + "max_memory_cached", + "memory_snapshot", + "memory_summary", + # "list_gpu_processes", + "mem_get_info", + "get_allocator_backend", + "ZoomPluggableAllocator", + "change_current_allocator", +] + + +if not hasattr(torch._C, "_zoom_ZoomAllocator"): + # Define dummy base classes + torch._C.__dict__["_zoom_ZoomAllocator"] = _dummy_type("_zoom_ZoomAllocator") + + +def _host_allocator(): + _lazy_init() + return torch._C._zoom_zoomHostAllocator() + + +@contextlib.contextmanager +def _free_mutex(): + torch._C._zoom_lock_mutex() + try: + yield + finally: + torch._C._zoom_unlock_mutex() + + +def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None): + r"""Perform a memory allocation using the Zoom memory allocator. + + Memory is allocated for a given device and a stream, this + function is intended to be used for interoperability with other + frameworks. Allocated memory is released through + :func:`~torch.zoom.caching_allocator_delete`. + + Args: + size (int): number of bytes to be allocated. + device (torch.device or int, optional): selected device. 
If it is + ``None`` the default Zoom device is used. + stream (torch.zoom.Stream or int, optional): selected stream. If is ``None`` then + the default stream for the selected device is used. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + if device is None: + device = torch.zoom.current_device() + device = _get_device_index(device) + if stream is None: + stream = torch.zoom.current_stream(device) + if isinstance(stream, torch.zoom.streams.Stream): + stream = stream.zoom_stream + if not isinstance(stream, int): + raise TypeError( + "Invalid type for stream argument, must be " + "`torch.zoom.Stream` or `int` representing a pointer " + "to a existing stream" + ) + with torch.zoom.device(device): + return torch._C._zoom_zoomCachingAllocator_raw_alloc(size, stream) + + +def caching_allocator_delete(mem_ptr): + r"""Delete memory allocated using the Zoom memory allocator. + + Memory allocated with :func:`~torch.zoom.caching_allocator_alloc`. + is freed here. The associated device and stream are tracked inside + the allocator. + + Args: + mem_ptr (int): memory address to be freed by the allocator. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + torch._C._zoom_zoomCachingAllocator_raw_delete(mem_ptr) + + +def set_per_process_memory_fraction( + fraction, device: Union[Device, int] = None +) -> None: + r"""Set memory fraction for a process. + + The fraction is used to limit an caching allocator to allocated memory on a Zoom device. + The allowed value equals the total visible memory multiplied fraction. + If trying to allocate more than the allowed value in a process, will raise an out of + memory error in allocator. + + Args: + fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction. + device (torch.device or int, optional): selected device. If it is + ``None`` the default Zoom device is used. + .. note:: + In general, the total available free memory is less than the total capacity. + """ + _lazy_init() + if device is None: + device = torch.zoom.current_device() + device = _get_device_index(device) + if not isinstance(fraction, float): + raise TypeError("Invalid type for fraction argument, must be `float`") + if fraction < 0 or fraction > 1: + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1") + + torch._C._zoom_setMemoryFraction(fraction, device) + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other GPU application and visible in + `nvidia-smi`. + + .. note:: + :func:`~torch.zoom.empty_cache` doesn't increase the amount of GPU + memory available for PyTorch. However, it may help reduce fragmentation + of GPU memory in certain cases. See :ref:`cuda-memory-management` for + more details about GPU memory management. + """ + if is_initialized(): + torch._C._zoom_emptyCache() + + +def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]: + r"""Return a dictionary of Zoom memory allocator statistics for a given device. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. 
+ - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from ``hipMalloc()``. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of October 2019, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of October 2019, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed ``hipMalloc`` calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of Zoom allocation calls. This includes both + cuMemMap and hipMalloc. + - ``"num_device_free"``: number of Zoom free calls. This includes both cuMemUnmap + and hipFree. + + The caching allocator can be configured via ENV to not split blocks larger than a + defined size (see Memory Management section of the Cuda Semantics documentation). + This helps avoid memory fragmentation but may have a performance + penalty. Additional outputs to assist with tuning and evaluating impact: + + - ``"max_split_size"``: blocks above this size will not be split. + - ``"oversize_allocations.{current,peak,allocated,freed}"``: + number of over-size allocation requests received by the memory allocator. + - ``"oversize_segments.{current,peak,allocated,freed}"``: + number of over-size reserved segments from ``hipMalloc()``. + + The caching allocator can be configured via ENV to round memory allocations in order + to reduce fragmentation. Sometimes the overhead from rounding can be higher than + the fragmentation it helps reduce. The following stat can be used to check if + rounding adds too much overhead: + + - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + memory requested by client code, compare this with allocated_bytes to check if + allocation rounding adds too much overhead. + + Args: + device (torch.device or int, optional): selected device. Returns + statistics for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + + .. note:: + With :ref:`backend:hipMallocAsync`, some stats are not + meaningful, and are always reported as zero. 
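The flattened keys follow the ``"metric.pool.field"`` pattern described above, so individual counters can be read straight out of the returned dict (a short sketch, assuming a zoom device is present):

```python
import torch

stats = torch.zoom.memory_stats()                   # flattened OrderedDict
print(stats["allocated_bytes.all.current"])         # bytes currently allocated by tensors
print(stats["reserved_bytes.all.peak"])             # peak memory reserved via hipMalloc
print(stats["num_alloc_retries"], stats["num_ooms"])

nested = torch.zoom.memory_stats_as_nested_dict()   # same data, nested by component
print(nested["allocated_bytes"]["all"]["current"])
```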
+ """ + result = [] + + def _recurse_add_to_result(prefix, obj): + if isinstance(obj, dict): + if len(prefix) > 0: + prefix += "." + for k, v in obj.items(): + _recurse_add_to_result(prefix + k, v) + else: + result.append((prefix, obj)) + + stats = memory_stats_as_nested_dict(device=device) + _recurse_add_to_result("", stats) + result.sort() + + return collections.OrderedDict(result) + + +def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]: + r"""Return the result of :func:`~torch.zoom.memory_stats` as a nested dictionary.""" + if not is_initialized(): + return {} + device = _get_device_index(device, optional=True) + return torch._C._zoom_memoryStats(device) + + +def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the Zoom memory allocator. + + See :func:`~torch.zoom.memory_stats` for details. Accumulated stats correspond to + the `"allocated"` and `"freed"` keys in each individual stat dict, as well as + `"num_alloc_retries"` and `"num_ooms"`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + return torch._C._zoom_resetAccumulatedMemoryStats(device) + + +def reset_peak_memory_stats(device: Union[Device, int] = None) -> None: + r"""Reset the "peak" stats tracked by the Zoom memory allocator. + + See :func:`~torch.zoom.memory_stats` for details. Peak stats correspond to the + `"peak"` key in each individual stat dict. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + return torch._C._zoom_resetPeakMemoryStats(device) + + +def reset_max_memory_allocated(device: Union[Device, int] = None) -> None: + r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device. + + See :func:`~torch.zoom.max_memory_allocated` for details. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. warning:: + This function now calls :func:`~torch.zoom.reset_peak_memory_stats`, which resets + /all/ peak memory stats. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + warnings.warn( + "torch.zoom.reset_max_memory_allocated now calls torch.zoom.reset_peak_memory_stats, " + "which resets /all/ peak memory stats.", + FutureWarning, + ) + return reset_peak_memory_stats(device=device) + + +def reset_max_memory_cached(device: Union[Device, int] = None) -> None: + r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + + See :func:`~torch.zoom.max_memory_cached` for details. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. 
warning:: + This function now calls :func:`~torch.zoom.reset_peak_memory_stats`, which resets + /all/ peak memory stats. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + warnings.warn( + "torch.zoom.reset_max_memory_cached now calls torch.zoom.reset_peak_memory_stats, " + "which resets /all/ peak memory stats.", + FutureWarning, + ) + return reset_peak_memory_stats(device=device) + + +def memory_allocated(device: Union[Device, int] = None) -> int: + r"""Return the current GPU memory occupied by tensors in bytes for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + This is likely less than the amount shown in `nvidia-smi` since some + unused memory can be held by the caching allocator and some context + needs to be created on GPU. See :ref:`cuda-memory-management` for more + details about GPU memory management. + """ + return memory_stats(device=device).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device: Union[Device, int] = None) -> int: + r"""Return the maximum GPU memory occupied by tensors in bytes for a given device. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.zoom.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. For example, these two + functions can measure the peak allocated memory usage of each iteration in a + training loop. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return memory_stats(device=device).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device: Union[Device, int] = None) -> int: + r"""Return the current GPU memory managed by the caching allocator in bytes for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return memory_stats(device=device).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device: Union[Device, int] = None) -> int: + r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.zoom.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. For example, these two functions + can measure the peak cached memory amount of each iteration in a training + loop. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. 
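The reset-then-measure pattern these docstrings describe looks roughly like this (sketch only; `loader` and `model` are placeholders for objects already on the zoom device):

```python
import torch

for step, batch in enumerate(loader):
    torch.zoom.reset_peak_memory_stats()            # start a fresh peak for this iteration
    loss = model(batch).sum()
    loss.backward()
    print(f"step {step}: "
          f"alloc={torch.zoom.memory_allocated()} B, "
          f"peak={torch.zoom.max_memory_allocated()} B, "
          f"reserved={torch.zoom.memory_reserved()} B")
```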
+ """ + return memory_stats(device=device).get("reserved_bytes.all.peak", 0) + + +def memory_cached(device: Union[Device, int] = None) -> int: + r"""Deprecated; see :func:`~torch.zoom.memory_reserved`.""" + warnings.warn( + "torch.zoom.memory_cached has been renamed to torch.zoom.memory_reserved", + FutureWarning, + ) + return memory_reserved(device=device) + + +def max_memory_cached(device: Union[Device, int] = None) -> int: + r"""Deprecated; see :func:`~torch.zoom.max_memory_reserved`.""" + warnings.warn( + "torch.zoom.max_memory_cached has been renamed to torch.zoom.max_memory_reserved", + FutureWarning, + ) + return max_memory_reserved(device=device) + + +def memory_snapshot(): + r"""Return a snapshot of the Zoom memory allocator state across all devices. + + Interpreting the output of this function requires familiarity with the + memory allocator internals. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return torch._C._zoom_memorySnapshot()["segments"] + + +def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str: + r"""Return a human-readable printout of the current memory allocator statistics for a given device. + + This can be useful to display periodically during training, or when + handling out-of-memory exceptions. + + Args: + device (torch.device or int, optional): selected device. Returns + printout for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + abbreviated (bool, optional): whether to return an abbreviated summary + (default: False). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + stats = memory_stats(device=device) + + def _format_size(sz, pref_sz): + prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"] + prefix = prefixes[0] + for new_prefix in prefixes[1:]: + if pref_sz < 768 * 1024: + break + prefix = new_prefix + sz //= 1024 + pref_sz /= 1024 + return f"{sz:6d} {prefix}" + + def _format_count(cnt, pref_cnt): + prefixes = [" ", "K", "M"] + prefix = prefixes[0] + for new_prefix in prefixes[1:]: + if pref_cnt < 750 * 1000: + break + prefix = new_prefix + cnt //= 1000 + pref_cnt /= 1000 + return f"{cnt:7d} {prefix} " + + metrics_to_display = [ + ("allocated_bytes", "Allocated memory", _format_size), + ("active_bytes", "Active memory", _format_size), + ("requested_bytes", "Requested memory", _format_size), + ("reserved_bytes", "GPU reserved memory", _format_size), + ("inactive_split_bytes", "Non-releasable memory", _format_size), + ("allocation", "Allocations", _format_count), + ("active", "Active allocs", _format_count), + ("segment", "GPU reserved segments", _format_count), + ("inactive_split", "Non-releasable allocs", _format_count), + ] + + lines = [] + lines.append("=" * 75) + lines.append(" {_:16} PyTorch Zoom memory summary, device ID {device:<17d} ") + lines.append("-" * 75) + lines.append( + " {_:9} HIP OOMs: {num_ooms:<12d} | {_:6} hipMalloc retries: {num_alloc_retries:<8d} " + ) + lines.append("=" * 75) + lines.append( + " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed " + ) + + for metric_key, metric_name, formatter in metrics_to_display: + lines.append("-" * 75) + submetrics = [("all", metric_name)] + if not abbreviated: + submetrics.append(("large_pool", " from large pool")) + submetrics.append(("small_pool", " from small pool")) + + current_prefval, peak_prefval, allocated_prefval, 
freed_prefval = ( + None, + None, + None, + None, + ) + + for submetric_key, submetric_name in submetrics: + prefix = metric_key + "." + submetric_key + "." + + current = stats[prefix + "current"] + peak = stats[prefix + "peak"] + allocated = stats[prefix + "allocated"] + freed = stats[prefix + "freed"] + + if current_prefval is None: + current_prefval = current + peak_prefval = peak + allocated_prefval = allocated + freed_prefval = freed + + lines.append( + f" {submetric_name:<21} | {formatter(current, current_prefval)} | {formatter(peak, peak_prefval)} | " + f"{formatter(allocated, allocated_prefval)} | {formatter(freed, freed_prefval)} ", + ) + + metrics_to_display = [ + ("oversize_allocations", "Oversize allocations", _format_count), + ("oversize_segments", "Oversize GPU segments", _format_count), + ] + + for metric_key, metric_name, formatter in metrics_to_display: + lines.append("-" * 75) + + prefix = metric_key + "." + + current = stats[prefix + "current"] + peak = stats[prefix + "peak"] + allocated = stats[prefix + "allocated"] + freed = stats[prefix + "freed"] + + lines.append( + f" {metric_name:<21} | {formatter(current, current)} | {formatter(peak, peak)} | " + f"{formatter(allocated, allocated)} | {formatter(freed, freed)} ", + ) + + lines.append("=" * 75) + + fmt_dict = {"_": "", "device": device} + for k, v in stats.items(): + fmt_dict[k.replace(".", "-")] = v + return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n" + + +# def list_gpu_processes(device: Union[Device, int] = None) -> str: +# r"""Return a human-readable printout of the running processes and their GPU memory use for a given device. + +# This can be useful to display periodically during training, or when +# handling out-of-memory exceptions. + +# Args: +# device (torch.device or int, optional): selected device. Returns +# printout for the current device, given by :func:`~torch.zoom.current_device`, +# if :attr:`device` is ``None`` (default). +# """ +# # try: +# # import pynvml # type: ignore[import] +# # except ModuleNotFoundError: +# # return "pynvml module not found, please install pynvml" +# # from pynvml import NVMLError_DriverNotLoaded + +# try: +# pynvml.nvmlInit() +# except NVMLError_DriverNotLoaded: +# return "cuda driver can't be loaded, is cuda enabled?" +# device = _get_nvml_device_index(device) +# handle = pynvml.nvmlDeviceGetHandleByIndex(device) +# procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) +# lines = [] +# lines.append(f"GPU:{device}") +# if len(procs) == 0: +# lines.append("no processes are running") +# for p in procs: +# mem = p.usedGpuMemory / (1024 * 1024) +# lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory") +# return "\n".join(lines) + + +def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]: + r"""Return the global free and total GPU memory for a given device using hipMemGetInfo. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.zoom.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more + details about GPU memory management. 
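+
+    Example (illustrative sketch; assumes a built Zoom backend and a visible device)::
+
+        >>> free, total = torch.zoom.mem_get_info()
+        >>> print(f"{free / 1024**3:.1f} GiB free of {total / 1024**3:.1f} GiB total")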
+ """ + if device is None: + device = torch.zoom.current_device() + device = _get_device_index(device) + return torch.zoom.hiprt().hipMemGetInfo(device) + + +def _record_memory_history_legacy( + enabled: bool, + record_context=True, + trace_alloc_max_entries=1, + trace_alloc_record_context=False, + device: Union[Device, int] = None, + record_context_cpp=False, +): + _C._zoom_record_memory_history_legacy( + enabled, + record_context, + trace_alloc_max_entries, + trace_alloc_record_context, + record_context_cpp, + ) + + +def _record_memory_history(enabled="all", *args, **kwargs): + """Enable recording of stack traces associated with memory + allocations, so you can tell what allocated any piece of memory in + :func:`torch.zoom.memory._snapshot()`. + + In addition too keeping stack traces with each current allocation and free, + this will also enable recording of a history of all alloc/free events. + + Use :func:`torch.zoom.memory._snapshot()` to retrieve this information, + and the tools in `_memory_viz.py` to visualize snapshots. + + The Python trace collection is fast (2us per trace), so you may consider + enabling this on production jobs if you anticipate ever having to debug + memory issues. + + C++ trace collection is also fast (~50ns/frame), which for many typical programs + works out to ~2us per trace, but can vary depending on stack depth. + + Args: + enabled (Literal[None, "state", "all"], optional): + `None`, disable recording memory history. + `"state"`, keep information for currenly allocated memory. + `"all"`, additionally keep a history of all alloc/free calls. + Defaults to "all". + context (Literal[None, "state", "alloc", "all"], optional): + `None`, Do not record any tracebacks. + `"state"`, Record tracebacks for currently allocated memory. + `"alloc"`, additionally keep tracebacks for alloc calls. + `"all"`, additionally keep tracebacks for free calls. + Defaults to "all". + stacks (Literal["python", "all"], optional): + `"python"`, include Python, TorchScript, and inductor frames in tracebacks + `"all"`, additionally include C++ frames + Defaults to "all". + max_entries (int, optional): Keep a maximum of `max_entries` + alloc/free events in the recorded history recorded. + """ + if isinstance(enabled, bool): + return _record_memory_history_legacy(enabled, *args, **kwargs) + else: + return _record_memory_history_impl(enabled, *args, **kwargs) + + +def _record_memory_history_impl( + enabled: Optional[str] = "all", + context: Optional[str] = "all", + stacks: str = "all", + max_entries: int = sys.maxsize, + device: Union[Device, int] = None, +): + _C._zoom_record_memory_history(enabled, context, stacks, max_entries) + + +_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined] + + +def _snapshot(device: Union[Device, int] = None): + """Save a snapshot of HIP memory state at the time it was called. + + The state is represented as a dictionary with the following structure. + + .. code-block:: python + + class Snapshot(TypedDict): + segments : List[Segment] + device_traces: List[List[TraceEntry]] + + class Segment(TypedDict): + # Segments are memory returned from a hipMalloc call. + # The size of reserved memory is the sum of all Segments. + # Segments are cached and reused for future allocations. + # If the reuse is smaller than the segment, the segment + # is split into more then one Block. + # empty_cache() frees Segments that are entirely inactive. 
+            address: int
+            total_size: int  # hipMalloc'd size of segment
+            stream: int
+            segment_type: Literal['small', 'large']  # 'large' (>1MB)
+            allocated_size: int  # size of memory in use
+            active_size: int  # size of memory in use or in active_awaiting_free state
+            blocks : List[Block]
+
+        class Block(TypedDict):
+            # A piece of memory returned from the allocator, or
+            # currently cached but inactive.
+            size: int
+            requested_size: int  # size requested during malloc, may be smaller than
+                                 # size due to rounding
+            address: int
+            state: Literal['active_allocated',  # used by a tensor
+                           'active_awaiting_free',  # waiting for another stream to finish using
+                                                    # this, then it will become free
+                           'inactive',]  # free for reuse
+            frames: List[Frame]  # stack trace from where the allocation occurred
+
+        class Frame(TypedDict):
+            filename: str
+            line: int
+            name: str
+
+        class TraceEntry(TypedDict):
+            # When `torch.zoom.memory._record_memory_history()` is enabled,
+            # the snapshot will contain TraceEntry objects that record each
+            # action the allocator took.
+            action: Literal[
+                'alloc',  # memory allocated
+                'free_requested',  # the allocation received a call to free memory
+                'free_completed',  # the memory that was requested to be freed is now
+                                   # able to be used in future allocation calls
+                'segment_alloc',  # the caching allocator asks hipMalloc for more memory
+                                  # and adds it as a segment in its cache
+                'segment_free',  # the caching allocator called hipFree to return memory
+                                 # to HIP, possibly trying to free up memory to
+                                 # allocate more segments or because empty_cache was called
+                'oom',  # the allocator threw an OOM exception. 'size' is
+                        # the requested number of bytes that did not succeed
+                'snapshot'  # the allocator generated a memory snapshot
+                            # useful to correlate a previously taken
+                            # snapshot with this trace
+            ]
+            addr: int  # not present for OOM
+            frames: List[Frame]
+            size: int
+            stream: int
+            device_free: int  # only present for OOM, the amount of
+                              # memory HIP still reports to be free
+
+    Returns:
+        The Snapshot dictionary object
+    """
+    return _C._zoom_memorySnapshot()
+
+
+def _dump_snapshot(filename="dump_snapshot.pickle"):
+    """
+    Save a pickled version of the `torch.zoom.memory._snapshot()` dictionary to a file.
+
+    This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz
+
+    Args:
+        filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
+    """
+    s = _snapshot()
+    with open(filename, "wb") as f:
+        pickle.dump(s, f)
+
+
+def _save_segment_usage(filename="output.svg", snapshot=None):
+    if snapshot is None:
+        snapshot = _snapshot()
+    with open(filename, "w") as f:
+        f.write(_segments(snapshot))
+
+
+def _save_memory_usage(filename="output.svg", snapshot=None):
+    if snapshot is None:
+        snapshot = _snapshot()
+    with open(filename, "w") as f:
+        f.write(_memory(snapshot))
+
+
+def _set_allocator_settings(env: str):
+    return torch._C._zoom_zoomCachingAllocator_set_allocator_settings(env)
+
+
+def get_allocator_backend() -> str:
+    r"""Return a string describing the active allocator backend as set by
+    ``PYTORCH_ZOOM_ALLOC_CONF``. Currently available backends are
+    ``native`` (PyTorch's native caching allocator) and ``hipMallocAsync``
+    (HIP's built-in asynchronous allocator).
+
+    .. note::
+        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
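+
+    Example (illustrative sketch; the returned value depends on
+    ``PYTORCH_ZOOM_ALLOC_CONF``)::
+
+        >>> torch.zoom.memory.get_allocator_backend()
+        'native'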
+ """ + return torch._C._zoom_getAllocatorBackend() + + +class _ZoomAllocator: + r"""Wrapper over internal Zoom memory allocators.""" + + def __init__(self, allocator: torch._C._zoom_ZoomAllocator): + self._allocator = allocator + + def allocator(self): + return self._allocator + + +class ZoomPluggableAllocator(_ZoomAllocator): + r"""Zoom memory allocator loaded from a so file.""" + + def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str): + r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes. + + To change the active allocator use the :func:`torch.memory.cuda.change_current_allocator` function. + + Args: + path_to_so_file(str): Path in the filesystem to the `.so` file containing + the allocator functions + alloc_fn_name(str): Name of the function to perform the memory allocation + in the so file. The signature must be: + void* alloc_fn_name(ssize_t size, int device, hipStream_t stream); + free_fn_name(str): Name of the function to perform the memory release + in the so file. The signature must be: + void free_fn_name(void* ptr, size_t size, hipStream_t stream); + + .. warning:: + This is currently supported only in unix OSs + + .. note:: + See :ref:`cuda-memory-management` for details on creating and using a custom allocator + """ + allocator = ctypes.CDLL(path_to_so_file) + alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value + free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value + assert alloc_fn is not None + assert free_fn is not None + self._allocator = torch._C._zoom_customAllocator(alloc_fn, free_fn) + + +def change_current_allocator(allocator: _ZoomAllocator) -> None: + r"""Change the currently used memory allocator to be the one provided. + + If the current allocator has already been used/initialized, this function will error. + + + Args: + allocator (torch.zoom.memory._ZoomAllocator): allocator to be set as the active one. + .. note:: + See :ref:`cuda-memory-management` for details on creating and using a custom allocator + """ + torch._C._zoom_changeCurrentAllocator(allocator.allocator()) + + +def _get_current_allocator() -> _ZoomAllocator: + r"""Return the allocator being currently used. + + .. note:: + See :ref:`cuda-memory-management` for details on creating and using a custom allocator + """ + return _ZoomAllocator(torch._C._zoom_getAllocator()) diff --git a/torch/zoom/random.py b/torch/zoom/random.py new file mode 100644 index 00000000000000..30c906063698bb --- /dev/null +++ b/torch/zoom/random.py @@ -0,0 +1,179 @@ +from typing import Iterable, List, Union + +import torch +from .. import Tensor +from . import _lazy_call, _lazy_init, current_device, device_count + +__all__ = [ + "get_rng_state", + "get_rng_state_all", + "set_rng_state", + "set_rng_state_all", + "manual_seed", + "manual_seed_all", + "seed", + "seed_all", + "initial_seed", +] + + +def get_rng_state(device: Union[int, str, torch.device] = "zoom") -> Tensor: + r"""Return the random number generator state of the specified GPU as a ByteTensor. + + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'zoom'`` (i.e., ``torch.device('zoom')``, the current Zoom device). + + .. warning:: + This function eagerly initializes Zoom. 
+ """ + _lazy_init() + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("zoom", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch.zoom.default_generators[idx] + return default_generator.get_state() + + +def get_rng_state_all() -> List[Tensor]: + r"""Return a list of ByteTensor representing the random number states of all devices.""" + results = [] + for i in range(device_count()): + results.append(get_rng_state(i)) + return results + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "zoom" +) -> None: + r"""Set the random number generator state of the specified GPU. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'zoom'`` (i.e., ``torch.device('zoom')``, the current Zoom device). + """ + with torch._C._DisableFuncTorch(): + new_state_copy = new_state.clone(memory_format=torch.contiguous_format) + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("zoom", device) + + def cb(): + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch.zoom.default_generators[idx] + default_generator.set_state(new_state_copy) + + _lazy_call(cb) + + +def set_rng_state_all(new_states: Iterable[Tensor]) -> None: + r"""Set the random number generator state of all devices. + + Args: + new_states (Iterable of torch.ByteTensor): The desired state for each device. + """ + for i, state in enumerate(new_states): + set_rng_state(state, i) + + +def manual_seed(seed: int) -> None: + r"""Set the seed for generating random numbers for the current GPU. + + It's safe to call this function if Zoom is not available; in that + case, it is silently ignored. + + Args: + seed (int): The desired seed. + + .. warning:: + If you are working with a multi-GPU model, this function is insufficient + to get determinism. To seed all GPUs, use :func:`manual_seed_all`. + """ + seed = int(seed) + + def cb(): + idx = current_device() + default_generator = torch.zoom.default_generators[idx] + default_generator.manual_seed(seed) + + _lazy_call(cb, seed=True) + + +def manual_seed_all(seed: int) -> None: + r"""Set the seed for generating random numbers on all GPUs. + + It's safe to call this function if Zoom is not available; in that + case, it is silently ignored. + + Args: + seed (int): The desired seed. + """ + seed = int(seed) + + def cb(): + for i in range(device_count()): + default_generator = torch.zoom.default_generators[i] + default_generator.manual_seed(seed) + + _lazy_call(cb, seed_all=True) + + +def seed() -> None: + r"""Set the seed for generating random numbers to a random number for the current GPU. + + It's safe to call this function if Zoom is not available; in that + case, it is silently ignored. + + .. warning:: + If you are working with a multi-GPU model, this function will only initialize + the seed on one GPU. To initialize all GPUs, use :func:`seed_all`. + """ + + def cb(): + idx = current_device() + default_generator = torch.zoom.default_generators[idx] + default_generator.seed() + + _lazy_call(cb) + + +def seed_all() -> None: + r"""Set the seed for generating random numbers to a random number on all GPUs. + + It's safe to call this function if Zoom is not available; in that + case, it is silently ignored. 
+ """ + + def cb(): + random_seed = 0 + seeded = False + for i in range(device_count()): + default_generator = torch.zoom.default_generators[i] + if not seeded: + default_generator.seed() + random_seed = default_generator.initial_seed() + seeded = True + else: + default_generator.manual_seed(random_seed) + + _lazy_call(cb) + + +def initial_seed() -> int: + r"""Return the current random seed of the current GPU. + + .. warning:: + This function eagerly initializes Zoom. + """ + _lazy_init() + idx = current_device() + default_generator = torch.zoom.default_generators[idx] + return default_generator.initial_seed() diff --git a/torch/zoom/streams.py b/torch/zoom/streams.py new file mode 100644 index 00000000000000..29a69fbb9d8ba6 --- /dev/null +++ b/torch/zoom/streams.py @@ -0,0 +1,241 @@ +import ctypes + +import torch +from torch._streambase import _EventBase, _StreamBase +from .._utils import _dummy_type + + +if not hasattr(torch._C, "_ZoomStreamBase"): + # Define dummy base classes + torch._C.__dict__["_ZoomStreamBase"] = _dummy_type("_ZoomStreamBase") + torch._C.__dict__["_ZoomEventBase"] = _dummy_type("_ZoomEventBase") + + +class Stream(torch._C._ZoomStreamBase, _StreamBase): + r"""Wrapper around a Zoom stream. + + A Zoom stream is a linear sequence of execution that belongs to a specific + device, independent from other streams. See :ref:`cuda-semantics` for + details. + + Args: + device(torch.device or int, optional): a device on which to allocate + the stream. If :attr:`device` is ``None`` (default) or a negative + integer, this will use the current device. + priority(int, optional): priority of the stream, should be 0 or + negative, where negative numbers indicate higher priority. By default, + streams have priority 0. + + """ + + def __new__(cls, device=None, priority=0, **kwargs): + # setting device manager is expensive, so we avoid it unless necessary + if device is None or ("stream_id" in kwargs and "device_index" in kwargs): + return super().__new__(cls, priority=priority, **kwargs) + else: + with torch.zoom.device(device): + return super().__new__(cls, priority=priority, **kwargs) + + def wait_event(self, event) -> None: + r"""Make all future work submitted to the stream wait for an event. + + Args: + event (torch.zoom.Event): an event to wait for. + + .. note:: This is a wrapper around ``hipStreamWaitEvent()``: see + `CUDA Stream documentation`_ for more info. + + This function returns without waiting for :attr:`event`: only future + operations are affected. + + .. _CUDA Stream documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html + """ + event.wait(self) + + def wait_stream(self, stream) -> None: + r"""Synchronize with another stream. + + All future work submitted to this stream will wait until all kernels + submitted to a given stream at the time of call complete. + + Args: + stream (Stream): a stream to synchronize. + + .. note:: This function returns without waiting for currently enqueued + kernels in :attr:`stream`: only future operations are affected. + """ + self.wait_event(stream.record_event()) + + def record_event(self, event=None): + r"""Record an event. + + Args: + event (torch.zoom.Event, optional): event to record. If not given, a new one + will be allocated. + + Returns: + Recorded event. + """ + if event is None: + event = Event() + event.record(self) + return event + + def query(self) -> bool: + r"""Check if all the work submitted has been completed. 
+ + Returns: + A boolean indicating if all kernels in this stream are completed. + """ + return super().query() + + def synchronize(self) -> None: + r"""Wait for all the kernels in this stream to complete. + + .. note:: This is a wrapper around ``hipStreamSynchronize()``: see + `CUDA Stream documentation`_ for more info. + """ + super().synchronize() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.zoom_stream) + + def __eq__(self, o) -> bool: + if isinstance(o, Stream): + return super().__eq__(o) + return False + + def __hash__(self): + return hash((self.zoom_stream, self.device)) + + def __repr__(self): + return f"" + + +class ExternalStream(Stream): + r"""Wrapper around an externally allocated Zoom stream. + + This class is used to wrap streams allocated in other libraries in order + to facilitate data exchange and multi-library interactions. + + .. note:: This class doesn't manage the stream life-cycle, it is the user + responsibility to keep the referenced stream alive while this class is + being used. + + Args: + stream_ptr(int): Integer representation of the `hipStream_t` value. + allocated externally. + device(torch.device or int, optional): the device where the stream + was originally allocated. If device is specified incorrectly, + subsequent launches using this stream may fail. + """ + + def __new__(cls, stream_ptr, device=None, **kwargs): + with torch.zoom.device(device): + return super().__new__(cls, stream_ptr=stream_ptr, **kwargs) + + +class Event(torch._C._ZoomEventBase, _EventBase): + r"""Wrapper around a Zoom event. + + Zoom events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize Zoom + streams. + + The underlying Zoom events are lazily initialized when the event is first + recorded or exported to another process. After creation, only streams on the + same device may record the event. However, streams on any device can wait on + the event. + + Args: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) + interprocess (bool): if ``True``, the event can be shared between processes + (default: ``False``) + + .. _CUDA Event Documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html + """ + + def __new__(cls, enable_timing=False, blocking=False, interprocess=False): + return super().__new__( + cls, + enable_timing=enable_timing, + blocking=blocking, + interprocess=interprocess, + ) + + @classmethod + def from_ipc_handle(cls, device, handle): + r"""Reconstruct an event from an IPC handle on the given device.""" + return super().from_ipc_handle(device, handle) + + def record(self, stream=None): + r"""Record the event in a given stream. + + Uses ``torch.zoom.current_stream()`` if no stream is specified. The + stream's device must match the event's device. + """ + if stream is None: + stream = torch.zoom.current_stream() + super().record(stream) + + def wait(self, stream=None) -> None: + r"""Make all future work submitted to the given stream wait for this event. + + Use ``torch.zoom.current_stream()`` if no stream is specified. + + .. note:: This is a wrapper around ``hipStreamWaitEvent()``: see + `CUDA Event documentation`_ for more info. + """ + if stream is None: + stream = torch.zoom.current_stream() + super().wait(stream) + + def query(self): + r"""Check if all work currently captured by event has completed. 
+ + Returns: + A boolean indicating if all work currently captured by event has + completed. + """ + return super().query() + + def elapsed_time(self, end_event): + r"""Return the time elapsed. + + Time reported in milliseconds after the event was recorded and + before the end_event was recorded. + """ + return super().elapsed_time(end_event) + + def synchronize(self) -> None: + r"""Wait for the event to complete. + + Waits until the completion of all work currently captured in this event. + This prevents the CPU thread from proceeding until the event completes. + + .. note:: This is a wrapper around ``hipEventSynchronize()``: see + `CUDA Event documentation`_ for more info. + """ + super().synchronize() + + def ipc_handle(self): + r"""Return an IPC handle of this event. + + If not recorded yet, the event will use the current device. + """ + return super().ipc_handle() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.zoom_event) + + def __repr__(self) -> str: + if self.zoom_event: + return f"" + else: + return "" From 16d3bea4ca71b9b2d88deadd9b31b593e7468299 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Mon, 23 Dec 2024 00:06:14 +0000 Subject: [PATCH 02/23] resolve some build deps --- aten/CMakeLists.txt | 2 + aten/src/ATen/CMakeLists.txt | 6 +- aten/src/ATen/native/zoom/ForeachFunctors.cuh | 681 ++++++++++++++++++ .../src/ATen/native/zoom/MultiTensorApply.cuh | 379 ++++++++++ aten/src/ATen/native/zoom/Pow.cuh | 58 ++ aten/src/ATen/native/zoom/PowKernel.cu | 209 ++++++ 6 files changed, 1332 insertions(+), 3 deletions(-) create mode 100644 aten/src/ATen/native/zoom/ForeachFunctors.cuh create mode 100644 aten/src/ATen/native/zoom/MultiTensorApply.cuh create mode 100644 aten/src/ATen/native/zoom/Pow.cuh create mode 100644 aten/src/ATen/native/zoom/PowKernel.cu diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index bda6aea327062f..d1459366a2e945 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -34,6 +34,7 @@ set(ATen_HIP_SRCS) set(ATen_HIP_SRCS_W_SORT_BY_KEY) set(ATen_HIP_TEST_SRCS) set(ATen_HIP_INCLUDE) +set(ATen_ZOOM_SRCS) set(ATen_MPS_SRCS) set(ATen_MPS_TEST_SRCS) set(ATen_XPU_SRCS) @@ -116,6 +117,7 @@ set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) +set(ATen_ZOOM_SRCS ${ATen_ZOOM_SRCS} PARENT_SCOPE) set(ATen_MPS_SRCS ${ATen_MPS_SRCS} PARENT_SCOPE) set(ATen_MPS_TEST_SRCS ${ATen_MPS_TEST_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS_W_SORT_BY_KEY ${ATen_HIP_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 1cd471cee47bc0..42ca9254a64885 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -82,9 +82,9 @@ file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp") file(GLOB miopen_h "miopen/*.h") file(GLOB miopen_cpp "miopen/*.cpp") -file(GLOB zoom_h "zoom/*.h" "zoom/detail/*.h" "zoom/*.cuh" "zoom/detail/*.cuh" "zoom/tunable/*.cuh" "zoom/tunable/*.h" "zoom/jit/*.cuh" "zoom/jit/*.h") -file(GLOB zoom_cpp "zoom/*.cpp" "zoom/detail/*.cpp" "zoom/tunable/*.cpp" "zoom/jit/*.cpp") -file(GLOB zoom_hip "zoom/*.cu" "zoom/detail/*.cu" "zoom/impl/*.cu" "zoom/tunable/*.cu") +file(GLOB zoom_h "zoom/*.h" "zoom/detail/*.h" "zoom/*.cuh" "zoom/detail/*.cuh" "zoom/tunable/*.h" "zoom/jit/*.cuh" "zoom/jit/*.h") +file(GLOB zoom_cpp "zoom/*.cpp" "zoom/detail/*.cpp" 
"zoom/jit/*.cpp") +file(GLOB zoom_hip "zoom/*.cu" "zoom/detail/*.cu") file(GLOB zoom_hiprtc_stub_h "zoom/hiprtc_stub/*.h") file(GLOB zoom_hiprtc_stub_cpp "zoom/hiprtc_stub/*.cpp") diff --git a/aten/src/ATen/native/zoom/ForeachFunctors.cuh b/aten/src/ATen/native/zoom/ForeachFunctors.cuh new file mode 100644 index 00000000000000..869e6fa3fd4389 --- /dev/null +++ b/aten/src/ATen/native/zoom/ForeachFunctors.cuh @@ -0,0 +1,681 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + +namespace { + +// TODO(crcrpar): Handle version bump in codegen. +// rel: +// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482 +inline void increment_version(TensorList tensors) { + for (const auto& t : tensors) { + t.unsafeGetTensorImpl()->bump_version(); + } +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListScalarListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ bool init_args( + T** args, + FusedOptimizerTensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ void load_args( + T r_args[][kILP], + T** args, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + for (int r_index = 0; r_index < depth; r_index++) { + r_args[r_index][ii] = 0; + if (i < n && i < chunk_size) { + r_args[r_index][ii] = args[r_index][i]; + } + } + } +} + +template +__device__ void store_args( + T* dst, + T* src, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const int64_t i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) + dst[i] = src[ii]; + } +} + +template +__device__ __forceinline__ void binary_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + 
static_cast(scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args + // has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(scalar))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} + +template +__device__ __forceinline__ void pointwise_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args + // has depth 3 + load_args<3>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} + +// +// Binary Functors +// +template +struct BinaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct BinaryOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct BinaryOpListAlphaFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + const auto tensor_loc = 
tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct BinaryOpScalarTensorFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + T* scalar, + opmath_t alpha) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), + // r_args has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +// +// Unary Functors +// + +template +struct ZeroFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata<1>& tl) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const auto all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for 
(int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + // store + load_store(args[0], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct UnaryOpFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +// +// Pointwise Functors +// + +template +struct PointwiseOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; 
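+    // block_to_tensor/block_to_chunk map this thread block to one kChunkSize-sized
+    // chunk of one tensor in the flattened list; numel_for_tensor below is that
+    // tensor's total element count, trimmed to the current chunk further down.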
+ auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[depth - 1][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + // store + load_store(args[2], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + store_args(args[2], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + static_assert(depth == 3 || depth == 4, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + static_assert(depth == 2 || depth == 3, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we 
put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(a, b); + } +}; + +template +struct reverse_power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(b, a); + } +}; + +} // namespace +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/MultiTensorApply.cuh b/aten/src/ATen/native/zoom/MultiTensorApply.cuh new file mode 100644 index 00000000000000..9efa863f49ceaf --- /dev/null +++ b/aten/src/ATen/native/zoom/MultiTensorApply.cuh @@ -0,0 +1,379 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace at::native { + +namespace { + +static constexpr int64_t kILP = 4; +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kBlockSize = 512; + +// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` +// TensorListMetadata has to be < 4KB - the limit for kernel launch argument +static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; +static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 72, + 60}; + +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (kILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store( + T* dst, + T* src, + int64_t dst_offset, + int64_t src_offset) { + using LT = at::native::memory::aligned_vector; + ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; +} + +template +struct TensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +struct TensorListScalarListMetadata { + const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; + scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; +}; + +// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of +// 4kb with `c10::complex` +template <> +struct TensorListScalarListMetadata, 1> { + const void* addresses[1] + [depth_to_max_tensors_scalarlist_of_complex_double[0]]; + int64_t + 
numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; + int block_to_chunk[depth_to_max_blocks[1 - 1]]; +}; + +template <> +struct TensorListScalarListMetadata, 2> { + const void* addresses[2] + [depth_to_max_tensors_scalarlist_of_complex_double[1]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; + int block_to_chunk[depth_to_max_blocks[2 - 1]]; +}; + +// NOTE(crcrpar): This is a conservative resolution to handle `state_steps` +// whose each element is `at::Tensor` of 1 element representing the number of +// `step`s called so far. +template +struct FusedOptimizerTensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +C10_LAUNCH_BOUNDS_1(kBlockSize) +__global__ void multi_tensor_apply_kernel( + T tensorListMeta, + U callable, + ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however + // it likes. + callable(kChunkSize, tensorListMeta, args...); +} + +} // namespace + +// multi_tensor_apply enables horizontal fusion across lists of tensors. +// For example, whereas you once had a for-loop of a + b = c, where a, b, +// and c are individual tensors in lists as, bs, and cs, you can now with +// fewer kernel launches compute as + bs = cs. +// +// You can also imagine bs to be a scalar list vs a tensor list. +// +// The function below takes in tensor lists, scalars, and a callable and +// chunks up the computation to launch as few kernels as possible by iterating +// through every "chunk" in every tensor (thus the nested for loops). In the +// simplest case, everything gets bundled into just one kernel launch, but +// due to blocksize constraints, we may need to launch multiple kernels. +// Each kernel launch is defined by one tensorListMeta construct, which we +// use to track and reset the necessary metadata for each launch. +template +void multi_tensor_apply( + std::vector>& tensor_lists, + at::ArrayRef scalars, + T callable, + ArgTypes... args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + using scalar_vals_t = typename T::opmath_t; + TensorListScalarListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // now we enter [chunking territory]. + // we will launch a kernel when EITHER the blocks get filled up OR + // the tensors get filled up. 
There will always be at least one block + // per tensor since the zero-sized ones will not enter the loop, so + // the nested forloop within represents iterating through the chunks + // of a single tensor. + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + // a tensor is not considered full unless all its chunks have been + // processed + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + c10::zoom::getCurrentZoomStream()>>>( + tensorListMeta, callable, args...); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + // all chunks have already been handled in the kernel + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { // blocks were full and tensor chunks remain + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.scalar_vals[0] = + tensorListMeta.scalar_vals[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // note: [finishing what we started] + // if there's remaining work to be done but the tensors/blocks aren't full + // yet we are at the end, submit the kernel to do the work! + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + c10::zoom::getCurrentZoomStream()>>>(tensorListMeta, callable, args...); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply( + std::vector>& tensor_lists, + T callable, + ArgTypes... args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + TensorListMetadata tensorListMeta; + tensorListMeta.start_tensor_this_launch = 0; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // see note: [chunking territory]. + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + c10::zoom::getCurrentZoomStream()>>>( + tensorListMeta, callable, args...); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + // Reset. 
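+        // Start the next launch from clean block bookkeeping; if the current
+        // tensor still has unprocessed chunks, its metadata is copied into
+        // slot 0 below and start_tensor_this_launch is updated so the next
+        // kernel launch resumes where this one stopped.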
+ loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + tensorListMeta.start_tensor_this_launch = t + 1; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + tensorListMeta.start_tensor_this_launch = t; + } + } + } + } + + // see note: [finishing what we started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + c10::zoom::getCurrentZoomStream()>>>(tensorListMeta, callable, args...); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply_for_fused_optimizer( + std::vector>& tensor_lists, + at::TensorList state_steps, + T callable, + ArgTypes... args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth"); + const auto num_tensors = tensor_lists[0].size(); + FusedOptimizerTensorListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (const auto& tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + tensorListMeta.state_steps_addresses[loc_tensor_info] = + state_steps[tensor_index].const_data_ptr(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][tensor_index].numel(); + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][tensor_index].const_data_ptr(); + } + loc_tensor_info++; + + // see above note: [chunking territory] + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + for (const auto& chunk : c10::irange(chunks)) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const auto tensor_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1]; + + if (tensor_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + c10::zoom::getCurrentZoomStream()>>>( + tensorListMeta, callable, args...); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + // Reset. 
+ loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.state_steps_addresses[0] = + tensorListMeta.state_steps_addresses[loc_tensor_info - 1]; + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // see above note: [finishing what we've started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + c10::zoom::getCurrentZoomStream()>>>(tensorListMeta, callable, args...); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Pow.cuh b/aten/src/ATen/native/zoom/Pow.cuh new file mode 100644 index 00000000000000..eee86031f8d932 --- /dev/null +++ b/aten/src/ATen/native/zoom/Pow.cuh @@ -0,0 +1,58 @@ +#pragma once +#include +#include + +namespace at { namespace native { + +namespace { + + +// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt. +// So we need to define the functions with the explicit function signatures. +// As for pow, the following signatures are defined as the device function: +// pow(float, int) +// pow(double, int) +// pow(float, float) +// pow(double, double) +#ifdef _MSC_VER +// Functions for pow +// pow for at::Half +static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow for at::BFloat16 +static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow (floating, floating/int) +template +static inline __host__ __device__ typename std::enable_if::value && (std::is_same::value || std::is_same::value), Base_type>::type + pow_(Base_type base, Exp_type exp) { + return std::pow(base, exp); +} +// pow (Otherwise) +template +static inline __host__ __device__ typename std::enable_if::value && !std::is_same::value, Base_type>::type + pow_(Base_type base, Exp_type exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +#else +template +static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { + return ::pow(base, exp); +} +#endif + +template +static inline __host__ __device__ std::enable_if_t::value, T> pow_( + T base, T exp) { + return at::native::powi(base, exp); +} + +template +static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { + return c10_complex_math::pow(base, exp); +} + +} // namespace +}} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/PowKernel.cu b/aten/src/ATen/native/zoom/PowKernel.cu new file mode 100644 index 00000000000000..e67e47201687ad --- /dev/null +++ b/aten/src/ATen/native/zoom/PowKernel.cu @@ -0,0 +1,209 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// Forward declare some unary kernels +void rsqrt_kernel_zoom(TensorIteratorBase& iter); +void sqrt_kernel_zoom(TensorIteratorBase& iter); +void reciprocal_kernel_zoom(TensorIteratorBase& iter); + +namespace { + +void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar); + +template +void 
pow_scalar_tensor_impl(TensorIteratorBase& iter, scalar_t base) { + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t exp) -> scalar_t { + return pow_(base, exp); + }); +} + +template +void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base) { + // For complex, thrust::pow uses the identity + // pow(a, b) = exp(log(a) * b) + const auto fct = std::log(base); + gpu_kernel(iter, [=]GPU_LAMBDA(c10::complex exp) -> c10::complex { + return std::exp(fct * exp); + }); +} + +/* complex support impl */ +CONSTEXPR_EXCEPT_WIN_CUDA char pow_scalar_base_name[] = "pow_scalar_base_kernel"; +template <> +void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base) { + using scalar_t = c10::complex; + using opmath_t = at::opmath_type; + // For complex, thrust::pow uses the identity + // pow(a, b) = exp(log(a) * b) + const auto fct = std::log(opmath_t{base}); +#if AT_USE_JITERATOR() + static const auto pow_kernel_string = + jiterator_stringify(template T pow_scalar_base_kernel(T exp, T fct) { + return std::exp(fct * exp); + }); + jitted_gpu_kernel( + iter, + pow_kernel_string, + /*scalar_pos=*/at::zoom::jit::BinaryFuncVariant::NoScalar, + /*scalar_val=*/0, + /*extra_args=*/std::make_tuple(fct)); +#else + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t exp) -> scalar_t { + return std::exp(fct * opmath_t{exp}); + }); +#endif +} + +namespace { + +#if AT_USE_JITERATOR() +/* complex support impl */ +CONSTEXPR_EXCEPT_WIN_CUDA char pow_name[] = "pow_kernel"; +static const auto pow_kernel_string = + jiterator_stringify(template T pow_kernel(T base, T exp) { + return std::pow(base, exp); + }); +#endif + +/* complex support impl */ +void pow_chalf_tensor_scalar_impl(TensorIteratorBase& iter, const Scalar& exp_scalar) { + using scalar_t = c10::complex; + using opmath_t = at::opmath_type; + auto exp = exp_scalar.to(); +#if AT_USE_JITERATOR() + jitted_gpu_kernel( + iter, + pow_kernel_string, + /*scalar_pos=*/at::zoom::jit::BinaryFuncVariant::NoScalar, + /*scalar_val=*/0, + /*extra_args=*/std::make_tuple(exp)); +#else + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base) -> scalar_t { + return std::pow(opmath_t{base}, exp); + }); +#endif +} + +} // anonymous namespace + +void pow_tensor_tensor_kernel(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (common_dtype == kComplexHalf) { + using scalar_t = c10::complex; + if (iter.is_cpu_scalar(1)) { + const auto base = iter.scalar_value(1); + iter.remove_operand(1); + pow_scalar_tensor_impl(iter, base); + } else if (iter.is_cpu_scalar(2)) { + const auto exp = iter.scalar_value(2); + iter.remove_operand(2); + pow_chalf_tensor_scalar_impl(iter, exp); + } else { + using opmath_t = at::opmath_type; + TORCH_INTERNAL_ASSERT(!iter.is_cpu_scalar(1) && !iter.is_cpu_scalar(2)); +#if AT_USE_JITERATOR() + jitted_gpu_kernel( + iter, pow_kernel_string); +#else + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { + using opmath_t = at::opmath_type; + return pow_(opmath_t{base}, opmath_t{exp}); + }); +#endif + } + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kHalf, kBFloat16, iter.common_dtype(), "pow_zoom", [&] { + if (iter.is_cpu_scalar(1)) { + const auto base = iter.scalar_value(1); + iter.remove_operand(1); + pow_scalar_tensor_impl(iter, base); + } else if (iter.is_cpu_scalar(2)) { + const auto exp = iter.scalar_value(2); + iter.remove_operand(2); + pow_tensor_scalar_kernel(iter, exp); + } else { + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { + return pow_(base, exp); + }); + } + }); + } +} + + 
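The complex paths above lean on the identity pow(base, exp) = exp(log(base) * exp), with log(base) computed once on the host and captured by the kernel lambda as fct. A minimal standalone sketch of that identity (plain host C++, illustrative only and not part of this patch):

    // Illustrative only: checks pow(a, b) == exp(log(a) * b) for complex inputs.
    #include <complex>
    #include <cstdio>

    int main() {
      const std::complex<float> base(2.0f, -1.5f);
      const std::complex<float> exp_val(0.75f, 0.25f);
      const auto fct = std::log(base);              // precomputed once, as in pow_scalar_tensor_impl
      const auto via_identity = std::exp(fct * exp_val);
      const auto direct = std::pow(base, exp_val);
      std::printf("identity: (%f, %f)  direct: (%f, %f)\n",
                  via_identity.real(), via_identity.imag(),
                  direct.real(), direct.imag());
      return 0;
    }

Precomputing the logarithm keeps the per-element work to one complex multiply and one exp, which is why both the GPU lambda and the jiterator string receive fct rather than base.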
+template +void pow_tensor_scalar_kernel_impl(TensorIteratorBase& iter, + Exp_type exp) { + const auto d_exp = static_cast(exp); + // .5 (sqrt), -.5 (rsqrt) and -1 (reciprocal) specializations are handled + // in pow_tensor_scalar_kernel + if (d_exp == 2) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return base * base; + }); + } else if (d_exp == 3) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return base * base * base; + }); + } else if (d_exp == -2) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return 1.0 / (base * base); + }); + } else { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return pow_(base, exp); + }); + } +} + +void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar) { + // Dispatch to fast specialization for sqrt, rsqrt and reciprocal + if (!exp_scalar.isComplex()) { + if (exp_scalar.equal(.5)) { + return sqrt_kernel_zoom(iter); + } else if (exp_scalar.equal(-0.5)) { + return rsqrt_kernel_zoom(iter); + } else if (exp_scalar.equal(-1.0)) { + return reciprocal_kernel_zoom(iter); + } + } + if (isComplexType(iter.common_dtype()) || exp_scalar.isComplex()) { + if (iter.common_dtype() == kComplexHalf) { + using scalar_t = c10::complex; + pow_chalf_tensor_scalar_impl(iter, exp_scalar); + return; + } + AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_zoom", [&]() { + const auto exp = exp_scalar.to(); + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t { + return pow_(base, exp); + }); + }); + } else if (isFloatingType(iter.common_dtype()) || exp_scalar.isIntegral(false)) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "pow_zoom", [&]() { + const auto exp = exp_scalar.to(); + pow_tensor_scalar_kernel_impl(iter, exp); + }); + } else { + TORCH_INTERNAL_ASSERT(false, "invalid combination of type in Pow function, common dtype:", iter.common_dtype(), + "exp is integral?", exp_scalar.isIntegral(false)); + } +} + +} // anonymous namespace + +REGISTER_PRIVATEUSE1_DISPATCH(pow_tensor_tensor_stub, &pow_tensor_tensor_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(pow_tensor_scalar_stub, &pow_tensor_scalar_kernel); + +} // namespace at::native \ No newline at end of file From 53deb9560b64702c9d95329f88e00052d9c3b0f4 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Wed, 25 Dec 2024 23:18:04 +0000 Subject: [PATCH 03/23] minimize, fix build, torchgen logic --- BUILD.bazel | 11 +- CMakeLists.txt | 1 + aten/CMakeLists.txt | 19 +- aten/src/ATen/CMakeLists.txt | 2 +- aten/src/ATen/Context.cpp | 5 - aten/src/ATen/Context.h | 3 - aten/src/ATen/detail/ZoomHooksInterface.h | 4 - aten/src/ATen/native/native_functions.yaml | 32 +- aten/src/ATen/native/zoom/AbsKernel.cu | 42 + aten/src/ATen/native/zoom/AmpKernels.cu | 252 ------ aten/src/ATen/native/zoom/CompareKernels.cu | 103 --- aten/src/ATen/native/zoom/Copy.cu | 63 +- aten/src/ATen/native/zoom/ForeachFunctors.cuh | 681 -------------- aten/src/ATen/native/zoom/MiscUtils.h | 32 - .../src/ATen/native/zoom/MultiTensorApply.cuh | 379 -------- aten/src/ATen/native/zoom/Nonzero.cu | 130 --- aten/src/ATen/native/zoom/Pow.cuh | 58 -- aten/src/ATen/native/zoom/PowKernel.cu | 209 ----- aten/src/ATen/native/zoom/TensorCompare.cu | 133 --- aten/src/ATen/native/zoom/TensorShape.cu | 833 ------------------ .../ATen/native/zoom/TensorTransformations.cu | 154 ---- aten/src/ATen/native/zoom/ZoomScalar.cu | 38 + .../ATen/native/zoom/reduction_template.cuh | 680 ++++++++++++++ aten/src/ATen/templates/UfuncZoom.cu | 17 + 
aten/src/ATen/zoom/ZoomContext.cpp | 1 - aten/src/ATen/zoom/ZoomContextLight.h | 50 +- aten/src/ATen/zoom/detail/ZoomHooks.cpp | 32 - aten/src/ATen/zoom/detail/ZoomHooks.h | 1 - buckbuild.bzl | 4 + build.bzl | 17 +- build.sh | 130 +++ caffe2/CMakeLists.txt | 7 +- cmake/Codegen.cmake | 11 + torch/csrc/zoom/Module.cpp | 162 ---- torchgen/dest/__init__.py | 1 + torchgen/dest/register_dispatch_key.py | 15 + torchgen/dest/ufunc.py | 33 + torchgen/gen.py | 40 +- torchgen/model.py | 12 +- ufunc_defs.bzl | 6 + 40 files changed, 1168 insertions(+), 3235 deletions(-) create mode 100644 aten/src/ATen/native/zoom/AbsKernel.cu delete mode 100644 aten/src/ATen/native/zoom/AmpKernels.cu delete mode 100644 aten/src/ATen/native/zoom/CompareKernels.cu delete mode 100644 aten/src/ATen/native/zoom/ForeachFunctors.cuh delete mode 100644 aten/src/ATen/native/zoom/MiscUtils.h delete mode 100644 aten/src/ATen/native/zoom/MultiTensorApply.cuh delete mode 100644 aten/src/ATen/native/zoom/Nonzero.cu delete mode 100644 aten/src/ATen/native/zoom/Pow.cuh delete mode 100644 aten/src/ATen/native/zoom/PowKernel.cu delete mode 100644 aten/src/ATen/native/zoom/TensorCompare.cu delete mode 100644 aten/src/ATen/native/zoom/TensorShape.cu delete mode 100644 aten/src/ATen/native/zoom/TensorTransformations.cu create mode 100644 aten/src/ATen/native/zoom/ZoomScalar.cu create mode 100644 aten/src/ATen/native/zoom/reduction_template.cuh create mode 100644 aten/src/ATen/templates/UfuncZoom.cu create mode 100644 build.sh diff --git a/BUILD.bazel b/BUILD.bazel index 3f7e6327452c09..c30d8c3df92327 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -9,7 +9,7 @@ load("@pytorch//tools/config:defs.bzl", "if_cuda") load("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops") load(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets") load(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources") -load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources") +load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources", "aten_ufunc_generated_zoom_sources") load("//:tools/bazel.bzl", "rules") define_targets(rules = rules) @@ -104,6 +104,12 @@ generated_cuda_cpp = [ "aten/src/ATen/RegisterSparseCsrCUDA.cpp", ] +generated_zoom_cpp = [ + "aten/src/ATen/ZoomFunctions.h", + "aten/src/ATen/ZoomFunctions_inl.h", + "aten/src/ATen/RegisterPrivateUse1.cpp", +] + generate_aten( name = "generated_aten_cpp", srcs = aten_generation_srcs, @@ -112,7 +118,8 @@ generate_aten( generated_cuda_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") + aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") + - aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + [ + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + + aten_ufunc_generated_zoom_sources("aten/src/ATen/{}") + [ "aten/src/ATen/Declarations.yaml", ] ), diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c6320e68d3903..528ebfb8f55a47 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,6 +203,7 @@ option(USE_CPP_CODE_COVERAGE "Compile C/C++ with code coverage flags" OFF) option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use 
Address+Undefined Sanitizers" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) +option(USE_ZOOM "Use ZOOM HIP Backend" OFF) option(USE_CUDA "Use CUDA" ON) cmake_dependent_option( USE_XPU "Use XPU. Only available on Linux." ON diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index d1459366a2e945..f1753f50c32fdc 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -30,11 +30,13 @@ set(ATen_CUDA_SRCS_W_SORT_BY_KEY) set(ATen_CUDA_TEST_SRCS) set(ATen_CUDA_INCLUDE) set(ATen_NVRTC_STUB_SRCS) +set(ATen_HIPRTC_STUB_SRCS) set(ATen_HIP_SRCS) +set(ATen_ZOOM_SRCS) set(ATen_HIP_SRCS_W_SORT_BY_KEY) set(ATen_HIP_TEST_SRCS) set(ATen_HIP_INCLUDE) -set(ATen_ZOOM_SRCS) +set(ATen_ZOOM_INCLUDE) set(ATen_MPS_SRCS) set(ATen_MPS_TEST_SRCS) set(ATen_XPU_SRCS) @@ -45,6 +47,7 @@ set(ATen_CPU_DEPENDENCY_LIBS) set(ATen_XPU_DEPENDENCY_LIBS) set(ATen_CUDA_DEPENDENCY_LIBS) set(ATen_HIP_DEPENDENCY_LIBS) +set(ATen_ZOOM_DEPENDENCY_LIBS) set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS) set(ATen_PUBLIC_HIP_DEPENDENCY_LIBS) set(ATEN_INSTALL_BIN_SUBDIR "bin" CACHE PATH "ATen install binary subdirectory") @@ -71,6 +74,17 @@ if(USE_ROCM) endif() endif() +if(USE_ZOOM) + include(LoadHIP) + if(NOT PYTORCH_FOUND_HIP) + message(WARNING "Could not load HIP, setting USE_ZOOM = OFF") + set(USE_ZOOM OFF) + else() + message(STATUS "Loaded HIP, Zoom Enabled") + endif() +endif() + + # Both CUDA and ROCM are enabled and found. Report an error. if(USE_CUDA AND USE_ROCM) message(FATAL_ERROR "Both CUDA and ROCm are enabled and found. PyTorch can only be built with either of them. Please turn one off by using either USE_CUDA=OFF or USE_ROCM=OFF.") @@ -124,6 +138,7 @@ set(ATen_HIP_SRCS_W_SORT_BY_KEY ${ATen_HIP_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) set(ATen_XPU_TEST_SRCS ${ATen_XPU_TEST_SRCS} PARENT_SCOPE) set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) +set(ATen_HIPRTC_STUB_SRCS ${ATen_HIPRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE) set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE) set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE) @@ -134,12 +149,14 @@ set(ATen_VEC_TEST_SRCS ${ATen_VEC_TEST_SRCS} PARENT_SCOPE) set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE) set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE) +set(ATen_ZOOM_INCLUDE ${ATen_ZOOM_INCLUDE} PARENT_SCOPE) set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_ZOOM_DEPENDENCY_LIBS ${ATen_ZOOM_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE) set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 42ca9254a64885..684b2c4cdeb905 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -613,7 +613,7 @@ endif() if(USE_ZOOM) set(ATen_ZOOM_SRCS ${all_zoom_cpp}) set(ATen_HIPRTC_STUB_SRCS ${zoom_hiprtc_stub_cpp}) - # list(APPEND ATen_ZOOM_DEPENDENCY_LIBS ATEN_ZOOM_FILES_GEN_LIB) + list(APPEND 
ATen_ZOOM_DEPENDENCY_LIBS ATEN_ZOOM_FILES_GEN_LIB) endif() set(ATEN_INCLUDE_DIR "${CMAKE_INSTALL_PREFIX}/${AT_INSTALL_INCLUDE_DIR}") diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 1136b05b265491..20679ab7ff5afa 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -153,7 +153,6 @@ static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" } bool Context::checkCuBLASConfigDeterministic() { bool cublas_config_deterministic = true; - #ifndef USE_ZOOM // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config // is set to deterministic setting if (hasCUDART() && (versionCUDART() >= 10020)) { @@ -164,10 +163,6 @@ bool Context::checkCuBLASConfigDeterministic() { ); } return cublas_config_deterministic; - #else - // Zoom uses hipBLAS with the rocBLAS backend - this is only deterministic if atomics are disabled - return checkHIPBlasDeterministic(); - #endif } void Context::alertCuBLASConfigNotDeterministic() const { diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index f241e91be6f731..4b71d3813353cd 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -127,9 +127,6 @@ class TORCH_API Context { static bool hasCuBLASLt() { return detail::getCUDAHooks().hasCuBLASLt(); } - static bool checkHIPBlasDeterministic() { - return detail::getZoomHooks().checkHIPBlasDeterministic(); - } static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } diff --git a/aten/src/ATen/detail/ZoomHooksInterface.h b/aten/src/ATen/detail/ZoomHooksInterface.h index 0e971a17e5a9c9..02bdd94ff1dada 100644 --- a/aten/src/ATen/detail/ZoomHooksInterface.h +++ b/aten/src/ATen/detail/ZoomHooksInterface.h @@ -91,10 +91,6 @@ struct TORCH_API ZoomHooksInterface : PrivateUse1HooksInterface { return false; } - virtual bool checkHIPBlasDeterministic() const { - TORCH_CHECK(false, "Cannot call checkHIPBlasDeterministic without torch_zoom library", ZOOM_HELP); - } - virtual const at::zoom::HIPRTC& hiprtc() const { TORCH_CHECK(false, "HIPRTC requires Zoom. ", ZOOM_HELP); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 10d8b1ad79cadf..b28fcfbfc2732e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -354,7 +354,7 @@ - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: abs_out + CPU, CUDA, PrivateUse1: abs_out MPS: abs_out_mps SparseCPU, SparseCUDA: abs_sparse_out SparseCsrCPU, SparseCsrCUDA: abs_sparse_csr_out @@ -413,12 +413,12 @@ - func: view_as_real(Tensor(a) self) -> Tensor(a) variants: function dispatch: - CPU, CUDA, MPS, Meta: view_as_real + CPU, CUDA, PrivateUse1, MPS, Meta: view_as_real - func: view_as_complex(Tensor(a) self) -> Tensor(a) variants: function dispatch: - CPU, CUDA, MPS, Meta: view_as_complex + CPU, CUDA, PrivateUse1, MPS, Meta: view_as_complex - func: sgn(Tensor self) -> Tensor variants: function, method @@ -931,7 +931,7 @@ - func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor(a) variants: function, method dispatch: - ZeroTensor, CPU, CUDA: as_strided_tensorimpl + ZeroTensor, CPU, CUDA, PrivateUse1: as_strided_tensorimpl Meta: as_strided_tensorimpl_meta_symint MPS: as_strided_tensorimpl_mps QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl @@ -2367,6 +2367,7 @@ dispatch: CPU: empty_cpu CUDA: empty_cuda + PrivateUse1: empty_zoom MPS: empty_mps Meta: empty_meta_symint MkldnnCPU: empty_mkldnn @@ -2444,6 +2445,7 @@ Meta: resize__symint CPU: resize_ CUDA: resize_cuda_ + PrivateUse1: resize_zoom_ MPS: resize_mps_ QuantizedCPU: quantized_resize_cpu_ SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_ @@ -2485,6 +2487,7 @@ dispatch: CPU: empty_strided_cpu CUDA: empty_strided_cuda + PrivateUse1: empty_strided_zoom MPS: empty_strided_mps Meta: empty_strided_meta_symint QuantizedCPU, QuantizedCUDA: empty_strided_unknown_quantized @@ -2634,12 +2637,14 @@ dispatch: CPU, Meta: eye_out_cpu CUDA: eye_out_cuda + PrivateUse1: eye_out_zoom MPS: eye_out_mps - func: eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, Meta: eye_out_cpu CUDA: eye_out_cuda + PrivateUse1: eye_out_zoom MPS: eye_out_mps - func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) @@ -2679,7 +2684,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: fill_ + CPU, CUDA, PrivateUse1: fill_ MPS: fill_scalar_mps QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ @@ -2691,7 +2696,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: fill_ + CPU, CUDA, PrivateUse1: fill_ MPS: fill_tensor_mps_ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ @@ -6501,6 +6506,7 @@ dispatch: CPU: _efficientzerotensor CUDA: _efficientzerotensor_cuda + PrivateUse1: _efficientzerotensor_zoom MPS: _efficientzerotensor_mps Meta: _efficientzerotensor_meta_symint autogen: _efficientzerotensor.out @@ -7726,6 +7732,7 @@ dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda + PrivateUse1: _local_scalar_dense_zoom MPS: _local_scalar_dense_mps variants: function @@ -7863,6 +7870,7 @@ CPU: set_storage_cpu_ Meta: set_storage_meta__symint CUDA: set_storage_cuda_ + PrivateUse1: set_storage_zoom_ MPS: set_storage_mps_ QuantizedCPU, QuantizedCUDA: set_storage_quantized_ autogen: set.source_Storage_storage_offset, set.source_Storage_storage_offset_out @@ -7890,6 +7898,7 @@ dispatch: CPU: set_cpu_ CUDA: set_cuda_ + PrivateUse1: set_zoom_ Meta: set_meta_ MPS: set_mps_ autogen: set, set.out @@ -7998,7 +8007,7 @@ device_check: NoCheck device_guard: False dispatch: - ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view + ZeroTensor, Meta, CPU, CUDA, PrivateUse1, QuantizedCPU, QuantizedCUDA, MPS: view MkldnnCPU: mkldnn_view NestedTensorCPU, NestedTensorCUDA: view_nested tags: core @@ -8765,7 +8774,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ne_Scalar_out + CPU, CUDA, PrivateUse1: ne_Scalar_out MPS: ne_scalar_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -8783,7 +8792,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ne_Tensor_out + CPU, CUDA, PrivateUse1: ne_Tensor_out MPS: ne_tensor_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -8828,7 +8837,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Scalar_out + CPU, CUDA, PrivateUse1: 
eq_Scalar_out MPS: eq_scalar_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -8847,7 +8856,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Tensor_out + CPU, CUDA, PrivateUse1: eq_Tensor_out MPS: eq_tensor_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -10123,6 +10132,7 @@ dispatch: CPU: cpu_equal CUDA: cuda_equal + PrivateUse1: zoom_equal MPS: mps_equal QuantizedCPU: equal_quantized_cpu diff --git a/aten/src/ATen/native/zoom/AbsKernel.cu b/aten/src/ATen/native/zoom/AbsKernel.cu new file mode 100644 index 00000000000000..dd6dc56f646bf9 --- /dev/null +++ b/aten/src/ATen/native/zoom/AbsKernel.cu @@ -0,0 +1,42 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + +namespace at::native { + + +CONSTEXPR_EXCEPT_WIN_CUDA constexpr char abs_name[] = "abs_kernel"; +void abs_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.dtype(); + static const auto abs_string = jiterator_stringify( + template T abs_kernel(T x) { return std::abs(x); }); + if (at::isComplexType(dtype)) { + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "abs_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/abs_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, abs_string); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, + ScalarType::BFloat16, + ScalarType::Bool, + iter.dtype(), + "abs_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/abs_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, abs_string); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/AmpKernels.cu b/aten/src/ATen/native/zoom/AmpKernels.cu deleted file mode 100644 index 14fa799fd6d283..00000000000000 --- a/aten/src/ATen/native/zoom/AmpKernels.cu +++ /dev/null @@ -1,252 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#define _USE_MATH_DEFINES - -#include - -#include -#include -#include -#include -#include -#include -#include - - -namespace { -// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e, -// to ensure the Cuda math library's isfinite is actually what gets called in -// _amp_non_finite_check_and_unscale_cuda_'s gpu_kernel lambda. -// -// isfinite_ensure_cuda_math is defined outside at::native because: -// - A bare call to "isfinite(val)" inside at::native causes nvcc to prefer the unrelated -// Tensor at::native::isfinite(const Tensor&), resulting in an error: -// "no suitable constructor exists to convert from "float" to "at::Tensor"" -// - Unfortunately, the Cuda math library documentation doesn't say how (or if) you can provide a full namespace path -// to ensure that its version of a particular function is invoked. It only shows bare (not-namespaced) -// calls to its routines inside kernel or device functions. -// - "std::isfinite(val)" in the gpu_kernel lambda causes an "unspecified launch failure" at runtime with cuda 9 on Windows. -// -// isfinite_ensure_cuda_math, declared at file scope outside the at::native region, uses isfinite as math library docs -// suggest and allows disambiguated usage in the lambda within the at::native region. 
-// GPU_LAMBDA is defined as __host__ __device__ (see Loops.cuh), so I need the __host__ keyword or else nvcc complains that -// "calling a __device__ function("isfinite_ensure_cuda_math") from a __host__ __device__ function("operator()") is not allowed." -static __host__ __device__ __forceinline__ int isfinite_ensure_zoom_math(float val) { - return isfinite(val); -} -} - -namespace at::native { - -namespace { -// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_zoom_. -// Handles individual tensors that are acceptable to unscale but not MTA-safe. -void _amp_non_finite_check_and_unscale_zoom_(Tensor& scaled_grad, - Tensor& found_inf, - const Tensor& inv_scale) -{ - // The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_zoom_, so no input checks. - - // It's not obvious gpu_kernel always guards onto its argument. Guarding here just in case. - const OptionalDeviceGuard device_guard(device_of(scaled_grad)); - - // Acts on scaled_grad in place. - auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - iter.dtype(), - "_amp_non_finite_check_and_unscale_zoom", - [&iter, &found_inf, &inv_scale] { - auto* found_inf_ptr = found_inf.mutable_data_ptr(); - auto* inv_scale_ptr = inv_scale.const_data_ptr(); - - using opmath_t = at::opmath_type; - - gpu_kernel(iter, - [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t { - auto val = static_cast(val_in); - if (!isfinite_ensure_zoom_math(val)) { - *found_inf_ptr = 1.f; - } - // Every thread accesses inv_scale, but it will hit in cache. - const auto inv_scale_val = *inv_scale_ptr; - return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); - }); - }); -} -} // anonymous namespace - - -// Multiplies each tensor in scaled_grads by inv_scale in-place. -// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0. -// Uses multi tensor apply (MTA) to process all MTA-safe tensors. -// -// Args: -// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or NaNs. -// found_inf: A single-element float tensor to which 1.0 will be written if any gradient contain infs/nans. -// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller. -// inv_scale: The inverse of the scale factor by which scaled_grads are currently multiplied. -void _amp_foreach_non_finite_check_and_unscale_zoom_(TensorList scaled_grads, - Tensor& found_inf, - const Tensor& inv_scale) -{ - if (scaled_grads.size() == 0) { - return; - } - - TORCH_CHECK(inv_scale.is_privateuseone(), "inv_scale must be a Zoom tensor."); - TORCH_CHECK(found_inf.is_privateuseone(), "found_inf must be a Zoom tensor."); - TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."); - TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); - TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor."); - TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); - - // Ensures client code (GradScaler) filtered scaled_grads by dtype. - check_foreach_api_restrictions(scaled_grads); - - std::vector> tensor_lists; - - // is_non_overlapping_and_dense() is not available in Python. - // GradScaler can't filter for it. We need to filter here. - if (can_use_fast_route(scaled_grads)) { - // Hopefully common case. 
- // can_use_fast_route is true, which confirms: - // - all scaled_grads are strided - // - all scaled_grads are non overlapping and dense - // - all scaled_grads are on the same device - // - all scaled_grads are of the same dtype - TORCH_CHECK(scaled_grads[0].is_privateuseone(), "scaled_grads must be Zoom tensors."); - // Sets up MTA launch to use scaled_grads as-is. - tensor_lists.emplace_back(scaled_grads.vec()); - } else { - // Hopefully uncommon case. - // can_use_fast_route is an all-or-nothing check. In this path it was false, - // so any of the above confirmations could have gone wrong. - // We filter MTA-safe tensors into an MTA-able list. - // If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel. - // If a tensor is unacceptable, we throw an error to blame GradScaler. - tensor_lists.resize(1); - tensor_lists[0].reserve(scaled_grads.size()); - auto expected_device = scaled_grads[0].device(); - const auto expected_dtype = scaled_grads[0].scalar_type(); - for (const Tensor& t : scaled_grads) { - // Ensures GradScaler filtered scaled_grads by device. - TORCH_CHECK(t.is_privateuseone(), "one of scaled_grads was not a Zoom tensor."); - TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device."); - TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor."); - if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) { - // t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel. - _amp_non_finite_check_and_unscale_zoom_(const_cast(t), - found_inf, - inv_scale); - } else { - tensor_lists[0].push_back(t); - } - } - if (tensor_lists[0].size() == 0) { - return; - } - } - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - tensor_lists[0][0].scalar_type(), - "_amp_foreach_non_finite_check_and_unscale_zoom", - [&tensor_lists, &found_inf, &inv_scale] { - auto* found_inf_ptr = found_inf.mutable_data_ptr(); - auto* inv_scale_ptr = inv_scale.const_data_ptr(); - - using opmath_t = at::opmath_type; - - // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly. - multi_tensor_apply<1>(tensor_lists, - UnaryOpFunctor(), - [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { - // There is a slight asymmetry here with the TensorIterator kernel above. - // MTA Functors ensure val comes in as opmath_t rather than scalar_t. - if (!isfinite_ensure_zoom_math(val)) { - *found_inf_ptr = 1.f; - } - // Every thread accesses inv_scale, but it will hit in cache. - const auto inv_scale_val = *inv_scale_ptr; - return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); - }); - }); -} - - -// amp_update_scale_zoom_kernel is launched with a single thread to compute the new scale. -// The scale factor is maintained and updated on the GPU to avoid synchronization. -__global__ void amp_update_scale_zoom_kernel(float* current_scale, - int* growth_tracker, - const float* found_inf, - double growth_factor, - double backoff_factor, - int growth_interval) -{ - if (*found_inf) { - *current_scale = (*current_scale)*backoff_factor; - *growth_tracker = 0; - } else { - // Entering this branch means we just carried out a successful step, - // so growth_tracker is incremented before comparing to growth_interval. - auto successful = (*growth_tracker) + 1; - if (successful == growth_interval) { - auto new_scale = static_cast((*current_scale)*growth_factor); - // Do not grow the scale past fp32 bounds to inf. 
- if (isfinite_ensure_zoom_math(new_scale)) { - *current_scale = new_scale; - } - *growth_tracker = 0; - } else { - *growth_tracker = successful; - } - } -} - - -// _amp_update_scale_zoom asynchronously updates the scale tensor in place. -// -// Args: -// current_scale: A one-element zoom float tensor containing the scale value. -// growth_tracker: A one-element torch.zoom.IntTensor containing the number of recent consecutive unskipped steps. -// found_inf: A one-element zoom float tensor. If > 0, indicates that infs/nans were found by the relevant -// prior _amp_non_finite_check_and_unscale_zoom call, and 0 if no infs/nans were found. -// growth_factor: Multiplier if no infs/NaNs were found (typically slightly > 1). -// backoff_factor: Multiplier if infs/NaNs were found (typically 0.5). -// growth_interval: Number of consecutive unskipped steps that must occur for current_scale to be multiplied by -// growth_factor. -// -// Returns: -// current_scale -Tensor& _amp_update_scale_zoom_(Tensor& current_scale, - Tensor& growth_tracker, - const Tensor& found_inf, - double growth_factor, - double backoff_factor, - int64_t growth_interval) -{ - TORCH_CHECK(growth_tracker.is_privateuseone(), "growth_tracker must be a Zoom tensor."); - TORCH_CHECK(current_scale.is_privateuseone(), "current_scale must be a Zoom tensor."); - TORCH_CHECK(found_inf.is_privateuseone(), "found_inf must be a Zoom tensor."); - TORCH_CHECK(growth_tracker.numel() == 1, "growth_tracker must be a 1-element tensor."); - TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor."); - TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); - TORCH_CHECK(growth_tracker.scalar_type() == at::ScalarType::Int, "growth_tracker must be an int tensor."); - TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float, "current_scale must be a float tensor."); - TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); - - amp_update_scale_zoom_kernel<<<1, 1, 0, c10::zoom::getCurrentZoomStream()>>>( - current_scale.mutable_data_ptr(), - growth_tracker.mutable_data_ptr(), - found_inf.const_data_ptr(), - growth_factor, - backoff_factor, - growth_interval); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - - return current_scale; -} - -} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CompareKernels.cu b/aten/src/ATen/native/zoom/CompareKernels.cu deleted file mode 100644 index 21da608a35fc94..00000000000000 --- a/aten/src/ATen/native/zoom/CompareKernels.cu +++ /dev/null @@ -1,103 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include - - -// NOTE: CUDA on Windows requires that the enclosing function -// of a __device__ lambda not have internal linkage. 
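The docstring of the removed _amp_update_scale_zoom_ above spells out the dynamic loss-scale rule: back off when infs/NaNs were found, otherwise grow after growth_interval consecutive clean steps, and never grow the scale past fp32 range. A minimal host-side sketch of that rule (illustrative only; ScaleState and update_scale are hypothetical stand-ins for the one-element GPU tensors the kernel operates on):

    // Illustrative sketch of the dynamic loss-scale update described above.
    #include <cmath>

    struct ScaleState {
      float current_scale;
      int growth_tracker;   // consecutive unskipped steps so far
    };

    void update_scale(ScaleState& s, bool found_inf,
                      double growth_factor, double backoff_factor, int growth_interval) {
      if (found_inf) {
        s.current_scale = static_cast<float>(s.current_scale * backoff_factor);
        s.growth_tracker = 0;
      } else {
        const int successful = s.growth_tracker + 1;
        if (successful == growth_interval) {
          const float new_scale = static_cast<float>(s.current_scale * growth_factor);
          if (std::isfinite(new_scale)) {   // do not grow the scale to inf
            s.current_scale = new_scale;
          }
          s.growth_tracker = 0;
        } else {
          s.growth_tracker = successful;
        }
      }
    }

The real kernel runs this with a single thread so the scale can stay resident on the device and the update never forces a host synchronization.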
- -namespace at::native { namespace { - -enum class OpType {GE, GT, LE, LT}; - -template -struct CompareFunctor{ - constexpr CompareFunctor(OpType op): op_(op) {}; - OpType op_; - __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { - if (op_ == OpType::GE) { - return a >= b; - } else if (op_ == OpType::GT) { - return a > b; - } else if (op_ == OpType::LE) { - return a <= b; - } else { //LT - return a < b; - } - } -}; - -// Reflects the comparison operator, so reflect(op)(a, b) == op(b, a) -OpType reflect(OpType x) { - switch (x) { - case OpType::GE: return OpType::LE; - case OpType::GT: return OpType::LT; - case OpType::LE: return OpType::GE; - case OpType::LT: return OpType::GT; - } - TORCH_INTERNAL_ASSERT(false, "Invalid OpType"); -} - -} // namespace (anonymous) - -template -void compare_scalar_kernel(TensorIteratorBase &iter, OpType op, scalar_t rhs) { - CompareFunctor f(op); - gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t lhs) -> bool { - return f(lhs, rhs); - }); -} - -template -void compare_kernel_impl(TensorIteratorBase &iter, OpType op) { - // If either input is a cpu scalar, perform the equivalent comparison - // where the scalar is on the right hand side. This saves us from - // generating two otherwise identical kernels with mirrored - // arguments. - if (iter.is_cpu_scalar(1)) { - const scalar_t lhs = iter.scalar_value(1); - iter.remove_operand(1); - const DeviceGuard device_guard(iter.device(1)); - compare_scalar_kernel(iter, reflect(op), lhs); - } else if (iter.is_cpu_scalar(2)) { - const scalar_t rhs = iter.scalar_value(2); - iter.remove_operand(2); - compare_scalar_kernel(iter, op, rhs); - } else { - CompareFunctor f(op); - gpu_kernel(iter, f); - } -} - -C10_NOINLINE void compare_kernel_with_scalars(TensorIteratorBase &iter, OpType op) { - AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "compare_zoom", [&]() { - compare_kernel_impl(iter, op); - }); -} - - -void ge_kernel_zoom(TensorIteratorBase& iter) { - compare_kernel_with_scalars(iter, OpType::GE); -} - -void gt_kernel_zoom(TensorIteratorBase& iter) { - compare_kernel_with_scalars(iter, OpType::GT); -} - -void le_kernel_zoom(TensorIteratorBase& iter) { - compare_kernel_with_scalars(iter, OpType::LE); -} - -void lt_kernel_zoom(TensorIteratorBase& iter) { - compare_kernel_with_scalars(iter, OpType::LT); -} - -REGISTER_PRIVATEUSE1_DISPATCH(ge_stub, &ge_kernel_zoom); -REGISTER_PRIVATEUSE1_DISPATCH(gt_stub, >_kernel_zoom); -REGISTER_PRIVATEUSE1_DISPATCH(le_stub, &le_kernel_zoom); -REGISTER_PRIVATEUSE1_DISPATCH(lt_stub, <_kernel_zoom); - -} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Copy.cu b/aten/src/ATen/native/zoom/Copy.cu index 3415806851f9fd..57436f844beedc 100644 --- a/aten/src/ATen/native/zoom/Copy.cu +++ b/aten/src/ATen/native/zoom/Copy.cu @@ -11,6 +11,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -23,8 +24,66 @@ namespace at::native { -void neg_kernel_zoom(TensorIteratorBase &iter); -void conj_kernel_zoom(TensorIteratorBase &iter); +// forward decl, defined below +void direct_copy_kernel_zoom(TensorIteratorBase &iter); + +// NB: Ignores the negative bit on tensors +CONSTEXPR_EXCEPT_WIN_CUDA char neg_name[] = "neg_kernel"; +void neg_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.dtype(); + if (at::isComplexType(dtype)) { + static const auto neg_string = jiterator_stringify( + template + T neg_kernel(T a) { + return -a; + } + ); // neg_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, 
"neg_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ neg_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 1>(iter, neg_string); + }); + + } else { + AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "neg_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return -a; + }); + }); + } +} + +// NB: Ignores the negative bit on tensors +CONSTEXPR_EXCEPT_WIN_CUDA char conj_name[] = "conj_kernel"; +void conj_kernel_zoom(TensorIteratorBase& iter) { + auto conj_chalf = [&] { + using scalar_t = c10::complex; + + static const auto conj_string = jiterator_stringify( + template + T conj_kernel(T z) { + return std::conj(z); + } + ); + jitted_gpu_kernel(iter, conj_string); + + }; + + AT_DISPATCH_SWITCH(iter.common_dtype(), "conj_zoom", + AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] { + // Conj is a no-op for non-complex types + direct_copy_kernel_zoom(iter); + }) + AT_DISPATCH_CASE_COMPLEX_TYPES([&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return std::conj(a); + }); + }) + AT_DISPATCH_CASE(kComplexHalf, conj_chalf) + ); +} void float8_copy_kernel_zoom(TensorIteratorBase &iter) { ScalarType dtype = iter.dtype(0); diff --git a/aten/src/ATen/native/zoom/ForeachFunctors.cuh b/aten/src/ATen/native/zoom/ForeachFunctors.cuh deleted file mode 100644 index 869e6fa3fd4389..00000000000000 --- a/aten/src/ATen/native/zoom/ForeachFunctors.cuh +++ /dev/null @@ -1,681 +0,0 @@ -#pragma once -#include -#include -#include -#include - -namespace at::native { - -namespace { - -// TODO(crcrpar): Handle version bump in codegen. -// rel: -// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482 -inline void increment_version(TensorList tensors) { - for (const auto& t : tensors) { - t.unsafeGetTensorImpl()->bump_version(); - } -} - -// Initializes args and checks if all args are aligned -template -__device__ bool init_args( - T** args, - TensorListMetadata& tl, - const int64_t chunk_idx, - const int64_t chunk_size, - const int64_t tensor_loc) { - bool all_aligned = true; - for (int i = 0; i < depth; i++) { - args[i] = (T*)tl.addresses[i][tensor_loc]; - args[i] += chunk_idx * chunk_size; - - if (!is_aligned(args[i])) { - all_aligned = false; - } - } - return all_aligned; -} - -// Initializes args and checks if all args are aligned -template -__device__ bool init_args( - T** args, - TensorListScalarListMetadata& tl, - const int64_t chunk_idx, - const int64_t chunk_size, - const int64_t tensor_loc) { - bool all_aligned = true; - for (int i = 0; i < depth; i++) { - args[i] = (T*)tl.addresses[i][tensor_loc]; - args[i] += chunk_idx * chunk_size; - - if (!is_aligned(args[i])) { - all_aligned = false; - } - } - return all_aligned; -} - -template -__device__ bool init_args( - T** args, - FusedOptimizerTensorListMetadata& tl, - const int64_t chunk_idx, - const int64_t chunk_size, - const int64_t tensor_loc) { - bool all_aligned = true; - for (int i = 0; i < depth; i++) { - args[i] = (T*)tl.addresses[i][tensor_loc]; - args[i] += chunk_idx * chunk_size; - - if (!is_aligned(args[i])) { - all_aligned = false; - } - } - return all_aligned; -} - -template -__device__ void load_args( - T r_args[][kILP], - T** args, - const int64_t i_start, - const int64_t chunk_size, - const int64_t n) { -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - const auto i = i_start + threadIdx.x + ii * blockDim.x; - for (int r_index = 0; r_index < depth; r_index++) { - 
r_args[r_index][ii] = 0; - if (i < n && i < chunk_size) { - r_args[r_index][ii] = args[r_index][i]; - } - } - } -} - -template -__device__ void store_args( - T* dst, - T* src, - const int64_t i_start, - const int64_t chunk_size, - const int64_t n) { -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - const int64_t i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) - dst[i] = src[ii]; - } -} - -template -__device__ __forceinline__ void binary_op_scalar( - T r_args[][kILP], - T** args, - opmath_t scalar, - const int64_t n, - const int64_t chunk_size, - const bool all_aligned, - Op op) { - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_args[0], args[0], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - op(static_cast(r_args[0][ii]), - static_cast(scalar))); - } - // store - load_store(args[res_arg_index], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * kILP) { - // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args - // has depth 1 - load_args<1>(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - op(static_cast(r_args[0][ii]), - static_cast(scalar))); - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } -} - -template -__device__ __forceinline__ void pointwise_op_scalar( - T r_args[][kILP], - T** args, - opmath_t scalar, - const int64_t n, - const int64_t chunk_size, - const bool all_aligned, - Op op) { - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_args[0], args[0], 0, i_start); - load_store(r_args[1], args[1], 0, i_start); - load_store(r_args[2], args[2], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - static_cast(r_args[0][ii]) + - scalar * - op(static_cast(r_args[1][ii]), - static_cast(r_args[2][ii]))); - } - // store - load_store(args[res_arg_index], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * kILP) { - // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args - // has depth 3 - load_args<3>(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - static_cast(r_args[0][ii]) + - scalar * - op(static_cast(r_args[1][ii]), - static_cast(r_args[2][ii]))); - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } -} - -// -// Binary Functors -// -template -struct BinaryOpScalarFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op, - opmath_t scalar) { - const int tensor_loc = tl.block_to_tensor[blockIdx.x]; - const int chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= 
chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - binary_op_scalar( - r_args, args, scalar, n, chunk_size, all_aligned, op); - } -}; - -template -struct BinaryOpScalarListFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListScalarListMetadata& tl, - Op op) { - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - opmath_t scalar = tl.scalar_vals[tensor_loc]; - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - binary_op_scalar( - r_args, args, scalar, n, chunk_size, all_aligned, op); - } -}; - -template -struct BinaryOpListAlphaFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op, - opmath_t alpha) { - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_args[0], args[0], 0, i_start); - load_store(r_args[1], args[1], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - op(static_cast(r_args[0][ii]), - alpha * static_cast(r_args[1][ii]))); - } - // store - load_store(args[res_arg_index], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * kILP) { - load_args(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - op(static_cast(r_args[0][ii]), - alpha * static_cast(r_args[1][ii]))); - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } - } -}; - -template -struct BinaryOpScalarTensorFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op, - T* scalar, - opmath_t alpha) { - const int tensor_loc = tl.block_to_tensor[blockIdx.x]; - const int chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_args[0], args[0], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast(op( - static_cast(r_args[0][ii]), - static_cast(alpha) * static_cast(*scalar))); - } - // store - load_store(args[res_arg_index], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += 
blockDim.x * kILP) { - // Regardless if depth is 1 (for inplace) or 2 (for out of place), - // r_args has depth 1 - load_args<1>(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast(op( - static_cast(r_args[0][ii]), - static_cast(alpha) * static_cast(*scalar))); - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } - } -}; - -// -// Unary Functors -// - -template -struct ZeroFunctor { - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata<1>& tl) { - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const auto all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = 0; - } - // store - load_store(args[0], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * kILP) { -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = 0; - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } - } -}; - -template -struct UnaryOpFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op) { - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_args[0], args[0], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = - static_cast(op(static_cast(r_args[0][ii]))); - } - // store - load_store(args[res_arg_index], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * kILP) { - load_args(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = - static_cast(op(static_cast(r_args[0][ii]))); - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } - } -}; - -// -// Pointwise Functors -// - -template -struct PointwiseOpScalarFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op, - opmath_t scalar) { - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - 
pointwise_op_scalar( - r_args, args, scalar, n, chunk_size, all_aligned, op); - } -}; - -template -struct PointwiseOpScalarListFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListScalarListMetadata& tl, - Op op) { - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - opmath_t scalar = tl.scalar_vals[tensor_loc]; - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - pointwise_op_scalar( - r_args, args, scalar, n, chunk_size, all_aligned, op); - } -}; - -template -struct PointwiseOpListFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op) { - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[depth - 1][kILP]; - - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_args[0], args[0], 0, i_start); - load_store(r_args[1], args[1], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - op(static_cast(r_args[0][ii]), - static_cast(r_args[1][ii]))); - } - // store - load_store(args[2], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * kILP) { - load_args(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = static_cast( - op(static_cast(r_args[0][ii]), - static_cast(r_args[1][ii]))); - } - store_args(args[2], r_args[0], i_start, chunk_size, n); - } - } - } -}; - -template -struct TernaryOpListFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op) { - static_assert(depth == 3 || depth == 4, ""); - static_assert(depth >= r_args_depth, ""); - static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - load_store(r_args[0], args[0], 0, i_start); - load_store(r_args[1], args[1], 0, i_start); - load_store(r_args[2], args[2], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = - op(static_cast(r_args[0][ii]), - static_cast(r_args[1][ii]), - static_cast(r_args[2][ii])); - } - load_store(args[res_arg_index], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - 
i_start += blockDim.x * kILP) { - load_args(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = - op(static_cast(r_args[0][ii]), - static_cast(r_args[1][ii]), - static_cast(r_args[2][ii])); - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } - } -}; - -template -struct TernaryOpScalarFunctor { - using opmath_t = at::opmath_type; - template - __device__ __forceinline__ void operator()( - int chunk_size, - TensorListMetadata& tl, - Op op, - opmath_t alpha) { - static_assert(depth == 2 || depth == 3, ""); - static_assert(depth >= r_args_depth, ""); - static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); - const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; - const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; - auto n = tl.numel_for_tensor[tensor_loc]; - - T* args[depth]; - const bool all_aligned = - init_args(args, tl, chunk_idx, chunk_size, tensor_loc); - n -= chunk_idx * chunk_size; - T r_args[r_args_depth][kILP]; - - // to make things simple, we put aligned case in a different code path - if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { - for (int64_t i_start = threadIdx.x; - i_start * kILP < n && i_start * kILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_args[0], args[0], 0, i_start); - load_store(r_args[1], args[1], 0, i_start); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = - op(static_cast(r_args[0][ii]), - static_cast(r_args[1][ii]), - alpha); - } - // store - load_store(args[res_arg_index], r_args[0], i_start, 0); - } - } else { - for (int64_t i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * kILP) { - load_args(r_args, args, i_start, chunk_size, n); -#pragma unroll - for (int ii = 0; ii < kILP; ii++) { - r_args[0][ii] = - op(static_cast(r_args[0][ii]), - static_cast(r_args[1][ii]), - alpha); - } - store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); - } - } - } -}; - -template -struct power_functor { - C10_DEVICE T operator()(const T& a, const T& b) const { - return at::native::pow_(a, b); - } -}; - -template -struct reverse_power_functor { - C10_DEVICE T operator()(const T& a, const T& b) const { - return at::native::pow_(b, a); - } -}; - -} // namespace -} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/MiscUtils.h b/aten/src/ATen/native/zoom/MiscUtils.h deleted file mode 100644 index 257c488bd7e98e..00000000000000 --- a/aten/src/ATen/native/zoom/MiscUtils.h +++ /dev/null @@ -1,32 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! 
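The foreach functors deleted above all follow the same pattern: each thread handles kILP elements at a time, and the 128-bit load/store fast path is taken only when every pointer is suitably aligned; otherwise a scalar tail path runs. A standalone host-side sketch of that pattern in plain C++, not HIP (the names Vec, is_aligned_for_vec and copy_ilp are illustrative, not from the patch):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

constexpr int kILP = 4;                 // elements handled per thread per step

template <typename T>
struct alignas(kILP * sizeof(T)) Vec {  // 16 bytes for T = float
  T val[kILP];
};

template <typename T>
bool is_aligned_for_vec(const T* p) {
  // Mirrors is_aligned(): the pointer must be a multiple of kILP * sizeof(T).
  return reinterpret_cast<std::uintptr_t>(p) % (kILP * sizeof(T)) == 0;
}

template <typename T>
void copy_ilp(T* dst, const T* src, std::size_t n) {
  std::size_t i = 0;
  if (is_aligned_for_vec(dst) && is_aligned_for_vec(src) && n % kILP == 0) {
    // Fast path: move kILP elements at a time, as load_store() does with
    // aligned_vector on the GPU.
    for (; i < n; i += kILP) {
      Vec<T> v;
      std::memcpy(&v, src + i, sizeof(v));
      std::memcpy(dst + i, &v, sizeof(v));
    }
  } else {
    for (; i < n; ++i) dst[i] = src[i];  // scalar fallback
  }
}

int main() {
  std::vector<float> a(16, 1.5f), b(16, 0.f);
  copy_ilp(b.data(), a.data(), a.size());
  std::cout << b[7] << "\n";  // prints 1.5
}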
-#pragma once -#include -#include -#include - -namespace at { -namespace native { - -static inline int zoom_int_cast(int64_t value, const char* varname) { - auto result = static_cast(value); - TORCH_CHECK(static_cast(result) == value, - "zoom_int_cast: The value of ", varname, "(", (long long)value, - ") is too large to fit into a int (", sizeof(int), " bytes)"); - return result; -} - -// Creates an array of size elements of type T, backed by pinned memory -// wrapped in a Storage -template -static inline Storage pin_memory(int64_t size) { - auto* allocator = zoom::getPinnedMemoryAllocator(); - int64_t adjusted_size = size * sizeof(T); - return Storage( - Storage::use_byte_size_t(), - adjusted_size, - allocator, - /*resizable=*/false); -} - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/zoom/MultiTensorApply.cuh b/aten/src/ATen/native/zoom/MultiTensorApply.cuh deleted file mode 100644 index 9efa863f49ceaf..00000000000000 --- a/aten/src/ATen/native/zoom/MultiTensorApply.cuh +++ /dev/null @@ -1,379 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include -#include - -namespace at::native { - -namespace { - -static constexpr int64_t kILP = 4; -static constexpr int64_t kChunkSize = 65536; -static constexpr int64_t kBlockSize = 512; - -// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` -// TensorListMetadata has to be < 4KB - the limit for kernel launch argument -static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; -static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; -static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; -static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { - 72, - 60}; - -template -__device__ __forceinline__ bool is_aligned(T* p) { - return ((uint64_t)p) % (kILP * sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store( - T* dst, - T* src, - int64_t dst_offset, - int64_t src_offset) { - using LT = at::native::memory::aligned_vector; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -template -struct TensorListMetadata { - const void* addresses[n][depth_to_max_tensors[n - 1]]; - int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; - int start_tensor_this_launch; -}; - -template -struct TensorListScalarListMetadata { - const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; - int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; - scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; -}; - -// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of -// 4kb with `c10::complex` -template <> -struct TensorListScalarListMetadata, 1> { - const void* addresses[1] - [depth_to_max_tensors_scalarlist_of_complex_double[0]]; - int64_t - numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; - c10::complex - scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; - unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; - int block_to_chunk[depth_to_max_blocks[1 - 1]]; -}; - -template <> -struct TensorListScalarListMetadata, 2> { - const void* addresses[2] - [depth_to_max_tensors_scalarlist_of_complex_double[1]]; - int64_t - numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; - 
c10::complex - scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; - unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; - int block_to_chunk[depth_to_max_blocks[2 - 1]]; -}; - -// NOTE(crcrpar): This is a conservative resolution to handle `state_steps` -// whose each element is `at::Tensor` of 1 element representing the number of -// `step`s called so far. -template -struct FusedOptimizerTensorListMetadata { - const void* addresses[n][depth_to_max_tensors[n - 1]]; - int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; - const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; - int start_tensor_this_launch; -}; - -template -C10_LAUNCH_BOUNDS_1(kBlockSize) -__global__ void multi_tensor_apply_kernel( - T tensorListMeta, - U callable, - ArgTypes... args) { - // Hand the chunk information to the user-supplied functor to process however - // it likes. - callable(kChunkSize, tensorListMeta, args...); -} - -} // namespace - -// multi_tensor_apply enables horizontal fusion across lists of tensors. -// For example, whereas you once had a for-loop of a + b = c, where a, b, -// and c are individual tensors in lists as, bs, and cs, you can now with -// fewer kernel launches compute as + bs = cs. -// -// You can also imagine bs to be a scalar list vs a tensor list. -// -// The function below takes in tensor lists, scalars, and a callable and -// chunks up the computation to launch as few kernels as possible by iterating -// through every "chunk" in every tensor (thus the nested for loops). In the -// simplest case, everything gets bundled into just one kernel launch, but -// due to blocksize constraints, we may need to launch multiple kernels. -// Each kernel launch is defined by one tensorListMeta construct, which we -// use to track and reset the necessary metadata for each launch. -template -void multi_tensor_apply( - std::vector>& tensor_lists, - at::ArrayRef scalars, - T callable, - ArgTypes... args) { - TORCH_CHECK( - tensor_lists.size() == depth, - "Number of tensor lists has to match the depth."); - const size_t n_tensors = tensor_lists[0].size(); - using scalar_vals_t = typename T::opmath_t; - TensorListScalarListMetadata tensorListMeta; - - int loc_block_info = 0; - int loc_tensor_info = 0; - for (size_t t = 0; t < n_tensors; t++) { - // short-circuit to avoid adding empty tensors to tensorListMeta - if (tensor_lists[0][t].numel() == 0) { - continue; - } - tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to(); - tensorListMeta.numel_for_tensor[loc_tensor_info] = - tensor_lists[0][t].numel(); - for (int d = 0; d < depth; d++) { - tensorListMeta.addresses[d][loc_tensor_info] = - tensor_lists[d][t].const_data_ptr(); - } - loc_tensor_info++; - - // now we enter [chunking territory]. - // we will launch a kernel when EITHER the blocks get filled up OR - // the tensors get filled up. There will always be at least one block - // per tensor since the zero-sized ones will not enter the loop, so - // the nested forloop within represents iterating through the chunks - // of a single tensor. 
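A host-side sketch of the chunking bookkeeping described in the comment above, in plain C++: every non-empty tensor is cut into kChunkSize-element chunks, each chunk becomes one block's worth of work, and a launch is flushed whenever the block table or the tensor table fills up. kMaxTensors, kMaxBlocks, Meta and flush() are illustrative stand-ins for the metadata struct and kernel launch in the deleted header:

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kChunkSize  = 65536;
constexpr int     kMaxTensors = 110;
constexpr int     kMaxBlocks  = 320;

struct Meta {
  std::vector<int> block_to_tensor;  // which tensor a block works on
  std::vector<int> block_to_chunk;   // which chunk of that tensor
};

void flush(Meta& m, int& launches) {
  ++launches;                        // a real kernel launch would happen here
  m.block_to_tensor.clear();
  m.block_to_chunk.clear();
}

int main() {
  std::vector<int64_t> numels = {1 << 20, 300, 0, 5 * kChunkSize + 7};
  Meta meta;
  int launches = 0, tensors_in_meta = 0;

  for (int64_t n : numels) {
    if (n == 0) continue;            // empty tensors never enter the tables
    ++tensors_in_meta;
    const int64_t chunks = (n + kChunkSize - 1) / kChunkSize;
    for (int64_t c = 0; c < chunks; ++c) {
      meta.block_to_tensor.push_back(tensors_in_meta - 1);
      meta.block_to_chunk.push_back(static_cast<int>(c));
      const bool tensors_full =
          tensors_in_meta == kMaxTensors && c == chunks - 1;
      const bool blocks_full =
          static_cast<int>(meta.block_to_chunk.size()) == kMaxBlocks;
      if (tensors_full || blocks_full) {
        flush(meta, launches);
        // If the tensor still has chunks left, carry it into the next launch.
        tensors_in_meta = (c == chunks - 1) ? 0 : 1;
      }
    }
  }
  if (!meta.block_to_chunk.empty()) flush(meta, launches);  // finish leftovers
  std::cout << "launches: " << launches << "\n";
}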
- const auto numel = tensor_lists[0][t].numel(); - const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); - for (auto chunk = 0; chunk < chunks; chunk++) { - tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; - tensorListMeta.block_to_chunk[loc_block_info] = chunk; - loc_block_info++; - - // a tensor is not considered full unless all its chunks have been - // processed - const bool tensors_full = - (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] && - chunk == chunks - 1); - const bool blocks_full = - (loc_block_info == depth_to_max_blocks[depth - 1]); - - if (tensors_full || blocks_full) { - multi_tensor_apply_kernel<<< - loc_block_info, - kBlockSize, - 0, - c10::zoom::getCurrentZoomStream()>>>( - tensorListMeta, callable, args...); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - - // Reset. - loc_block_info = 0; - // all chunks have already been handled in the kernel - if (chunk == chunks - 1) { - loc_tensor_info = 0; - } else { // blocks were full and tensor chunks remain - tensorListMeta.numel_for_tensor[0] = - tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; - tensorListMeta.scalar_vals[0] = - tensorListMeta.scalar_vals[loc_tensor_info - 1]; - for (int d = 0; d < depth; d++) { - tensorListMeta.addresses[d][0] = - tensorListMeta.addresses[d][loc_tensor_info - 1]; - } - loc_tensor_info = 1; - } - } - } - } - - // note: [finishing what we started] - // if there's remaining work to be done but the tensors/blocks aren't full - // yet we are at the end, submit the kernel to do the work! - if (loc_block_info != 0) { - multi_tensor_apply_kernel<<< - loc_block_info, - kBlockSize, - 0, - c10::zoom::getCurrentZoomStream()>>>(tensorListMeta, callable, args...); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } -} - -template -void multi_tensor_apply( - std::vector>& tensor_lists, - T callable, - ArgTypes... args) { - TORCH_CHECK( - tensor_lists.size() == depth, - "Number of tensor lists has to match the depth."); - const size_t n_tensors = tensor_lists[0].size(); - TensorListMetadata tensorListMeta; - tensorListMeta.start_tensor_this_launch = 0; - - int loc_block_info = 0; - int loc_tensor_info = 0; - for (size_t t = 0; t < n_tensors; t++) { - // short-circuit to avoid adding empty tensors to tensorListMeta - if (tensor_lists[0][t].numel() == 0) { - continue; - } - tensorListMeta.numel_for_tensor[loc_tensor_info] = - tensor_lists[0][t].numel(); - for (int d = 0; d < depth; d++) { - tensorListMeta.addresses[d][loc_tensor_info] = - tensor_lists[d][t].const_data_ptr(); - } - loc_tensor_info++; - - // see note: [chunking territory]. - const auto numel = tensor_lists[0][t].numel(); - const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); - for (auto chunk = 0; chunk < chunks; chunk++) { - tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; - tensorListMeta.block_to_chunk[loc_block_info] = chunk; - loc_block_info++; - - const bool tensors_full = - (loc_tensor_info == depth_to_max_tensors[depth - 1] && - chunk == chunks - 1); - const bool blocks_full = - (loc_block_info == depth_to_max_blocks[depth - 1]); - - if (tensors_full || blocks_full) { - multi_tensor_apply_kernel<<< - loc_block_info, - kBlockSize, - 0, - c10::zoom::getCurrentZoomStream()>>>( - tensorListMeta, callable, args...); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - - // Reset. 
- loc_block_info = 0; - if (chunk == chunks - 1) { - loc_tensor_info = 0; - tensorListMeta.start_tensor_this_launch = t + 1; - } else { - tensorListMeta.numel_for_tensor[0] = - tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; - for (int d = 0; d < depth; d++) { - tensorListMeta.addresses[d][0] = - tensorListMeta.addresses[d][loc_tensor_info - 1]; - } - loc_tensor_info = 1; - tensorListMeta.start_tensor_this_launch = t; - } - } - } - } - - // see note: [finishing what we started] - if (loc_block_info != 0) { - multi_tensor_apply_kernel<<< - loc_block_info, - kBlockSize, - 0, - c10::zoom::getCurrentZoomStream()>>>(tensorListMeta, callable, args...); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } -} - -template -void multi_tensor_apply_for_fused_optimizer( - std::vector>& tensor_lists, - at::TensorList state_steps, - T callable, - ArgTypes... args) { - TORCH_CHECK( - tensor_lists.size() == depth, - "Number of tensor lists has to match the depth"); - const auto num_tensors = tensor_lists[0].size(); - FusedOptimizerTensorListMetadata tensorListMeta; - - int loc_block_info = 0; - int loc_tensor_info = 0; - for (const auto& tensor_index : c10::irange(num_tensors)) { - // short-circuit to avoid adding empty tensors to tensorListMeta - if (tensor_lists[0][tensor_index].numel() == 0) { - continue; - } - tensorListMeta.state_steps_addresses[loc_tensor_info] = - state_steps[tensor_index].const_data_ptr(); - tensorListMeta.numel_for_tensor[loc_tensor_info] = - tensor_lists[0][tensor_index].numel(); - for (const auto& d : c10::irange(depth)) { - tensorListMeta.addresses[d][loc_tensor_info] = - tensor_lists[d][tensor_index].const_data_ptr(); - } - loc_tensor_info++; - - // see above note: [chunking territory] - const auto numel = tensor_lists[0][tensor_index].numel(); - const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); - TORCH_CHECK(chunks > -1); - for (const auto& chunk : c10::irange(chunks)) { - tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; - tensorListMeta.block_to_chunk[loc_block_info] = chunk; - loc_block_info++; - - const auto tensor_full = - (loc_tensor_info == depth_to_max_tensors[depth - 1] && - chunk == chunks - 1); - const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1]; - - if (tensor_full || blocks_full) { - multi_tensor_apply_kernel<<< - loc_block_info, - kBlockSize, - 0, - c10::zoom::getCurrentZoomStream()>>>( - tensorListMeta, callable, args...); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - - // Reset. 
- loc_block_info = 0; - if (chunk == chunks - 1) { - loc_tensor_info = 0; - } else { - tensorListMeta.numel_for_tensor[0] = - tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; - tensorListMeta.state_steps_addresses[0] = - tensorListMeta.state_steps_addresses[loc_tensor_info - 1]; - for (const auto& d : c10::irange(depth)) { - tensorListMeta.addresses[d][0] = - tensorListMeta.addresses[d][loc_tensor_info - 1]; - } - loc_tensor_info = 1; - } - } - } - } - - // see above note: [finishing what we've started] - if (loc_block_info != 0) { - multi_tensor_apply_kernel<<< - loc_block_info, - kBlockSize, - 0, - c10::zoom::getCurrentZoomStream()>>>(tensorListMeta, callable, args...); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } -} - -} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Nonzero.cu b/aten/src/ATen/native/zoom/Nonzero.cu deleted file mode 100644 index d735795bcc1720..00000000000000 --- a/aten/src/ATen/native/zoom/Nonzero.cu +++ /dev/null @@ -1,130 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include -#include -#include -#include -#include -#include //for MAX_DIMS -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#include -#endif - - -namespace at::native { - -namespace{ -template -struct NonZeroOp -{ - __host__ __device__ __forceinline__ bool operator()(const T& a) const { - return (a!=T(0)); - } -}; - -//TODO: actually support int64_t index_t -template -struct TensorDims { - index_t sizes[MAX_DIMS]; -}; - -template -__global__ void write_indices( - int64_t* inp, - TensorDims dims, - int ndim, - index_t n) { - auto index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < n) { - index_t div = 1; - int64_t idx_flat = inp[index]; -#pragma unroll - for (int dim = MAX_DIMS; dim >= 0; dim--) { - if (dim > ndim - 1) - continue; - auto dim_size = dims.sizes[dim]; - inp[index + dim * n] = (idx_flat / div) % dim_size; - div *= dim_size; - } - } -} - -} //anonymous namespace - -template -void nonzero_zoom_out_impl(const Tensor& self, Tensor& out){ - Tensor self_ = self.contiguous(); - int N = self_.numel(); - const hipStream_t stream = c10::zoom::getCurrentZoomStream(); -// compute number of nonzero elements - size_t temp_storage_bytes=0; - auto& allocator = *c10::zoom::ZoomCachingAllocator::get(); - auto num_nonzeros = allocator.allocate(sizeof(int)); - hipcub::TransformInputIterator, const scalar_t*> itr(self_.const_data_ptr(), NonZeroOp()); - hipcub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); - auto temp_storage = allocator.allocate(temp_storage_bytes); - hipcub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); - int num_nonzeros_h; - c10::zoom::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), hipMemcpyDeviceToHost, stream); - //expected output size is num_nonzeros x ndim - //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) - //we are able to directly use passed output with this size and strides, and we can also (per contract) - //resize passed output with incorrect sizes anyway we want. - //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. - bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); - at::Tensor out_temp = need_to_copy ? 
- Tensor(at::detail::empty_zoom({self.dim(), num_nonzeros_h}, out.options())) : - out.resize_({self.dim(), num_nonzeros_h}); - //Scalars are expected to produce output of size (1,0), so we can't write to it - if (self.dim() > 0) { - hipcub::CountingInputIterator counting_itr(0); - temp_storage_bytes = 0; - hipcub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, - out_temp.mutable_data_ptr(), (int*)num_nonzeros.get(), N, stream); - temp_storage = allocator.allocate(temp_storage_bytes); - hipcub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr, - out_temp.mutable_data_ptr(), (int*)num_nonzeros.get(), N, stream); - if (num_nonzeros_h > 0 && self.dim() > 1){ - TensorDims dims; - for (int i=0; i>>(out_temp.mutable_data_ptr(), - dims, self.dim(), num_nonzeros_h); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } - } - if (need_to_copy) { - out.copy_(out_temp.t()); - } else { - //transpose out so it is correct size - Tensor out_ = out_temp.t(); - out.set_(out_); - } -} - -Tensor& nonzero_out_zoom(const Tensor& self, Tensor& out){ - TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \ - See https://github.com/pytorch/pytorch/issues/51871"); - TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype()); - TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ", - out.device(), " and self on ", self.device()); - TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions"); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, - self.scalar_type(), "nonzero_zoom", - [&] {nonzero_zoom_out_impl(self, out);}); - return out; -} - -Tensor nonzero_zoom(const Tensor& self){ - Tensor out = at::detail::empty_zoom({0}, self.options().dtype(kLong)); - return at::native::nonzero_out_zoom(self, out); -} -} //namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Pow.cuh b/aten/src/ATen/native/zoom/Pow.cuh deleted file mode 100644 index eee86031f8d932..00000000000000 --- a/aten/src/ATen/native/zoom/Pow.cuh +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once -#include -#include - -namespace at { namespace native { - -namespace { - - -// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt. -// So we need to define the functions with the explicit function signatures. 
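The write_indices kernel in the Nonzero.cu hunk above converts each flat index into per-dimension coordinates with repeated division and modulo, innermost dimension first. A standalone C++ sketch of that conversion (the unravel helper is illustrative, not from the patch):

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> unravel(int64_t flat, const std::vector<int64_t>& sizes) {
  std::vector<int64_t> coords(sizes.size());
  int64_t div = 1;
  for (int dim = static_cast<int>(sizes.size()) - 1; dim >= 0; --dim) {
    coords[dim] = (flat / div) % sizes[dim];  // peel off the innermost dim first
    div *= sizes[dim];
  }
  return coords;
}

int main() {
  // For a contiguous 2x3x4 tensor, flat index 17 maps to (1, 1, 1).
  for (int64_t c : unravel(17, {2, 3, 4})) std::cout << c << ' ';
  std::cout << '\n';
}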
-// As for pow, the following signatures are defined as the device function: -// pow(float, int) -// pow(double, int) -// pow(float, float) -// pow(double, double) -#ifdef _MSC_VER -// Functions for pow -// pow for at::Half -static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { - return static_cast(std::pow(static_cast(base), static_cast(exp))); -} -// pow for at::BFloat16 -static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { - return static_cast(std::pow(static_cast(base), static_cast(exp))); -} -// pow (floating, floating/int) -template -static inline __host__ __device__ typename std::enable_if::value && (std::is_same::value || std::is_same::value), Base_type>::type - pow_(Base_type base, Exp_type exp) { - return std::pow(base, exp); -} -// pow (Otherwise) -template -static inline __host__ __device__ typename std::enable_if::value && !std::is_same::value, Base_type>::type - pow_(Base_type base, Exp_type exp) { - return static_cast(std::pow(static_cast(base), static_cast(exp))); -} -#else -template -static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { - return ::pow(base, exp); -} -#endif - -template -static inline __host__ __device__ std::enable_if_t::value, T> pow_( - T base, T exp) { - return at::native::powi(base, exp); -} - -template -static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { - return c10_complex_math::pow(base, exp); -} - -} // namespace -}} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/PowKernel.cu b/aten/src/ATen/native/zoom/PowKernel.cu deleted file mode 100644 index e67e47201687ad..00000000000000 --- a/aten/src/ATen/native/zoom/PowKernel.cu +++ /dev/null @@ -1,209 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at::native { - -// Forward declare some unary kernels -void rsqrt_kernel_zoom(TensorIteratorBase& iter); -void sqrt_kernel_zoom(TensorIteratorBase& iter); -void reciprocal_kernel_zoom(TensorIteratorBase& iter); - -namespace { - -void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar); - -template -void pow_scalar_tensor_impl(TensorIteratorBase& iter, scalar_t base) { - gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t exp) -> scalar_t { - return pow_(base, exp); - }); -} - -template -void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base) { - // For complex, thrust::pow uses the identity - // pow(a, b) = exp(log(a) * b) - const auto fct = std::log(base); - gpu_kernel(iter, [=]GPU_LAMBDA(c10::complex exp) -> c10::complex { - return std::exp(fct * exp); - }); -} - -/* complex support impl */ -CONSTEXPR_EXCEPT_WIN_CUDA char pow_scalar_base_name[] = "pow_scalar_base_kernel"; -template <> -void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base) { - using scalar_t = c10::complex; - using opmath_t = at::opmath_type; - // For complex, thrust::pow uses the identity - // pow(a, b) = exp(log(a) * b) - const auto fct = std::log(opmath_t{base}); -#if AT_USE_JITERATOR() - static const auto pow_kernel_string = - jiterator_stringify(template T pow_scalar_base_kernel(T exp, T fct) { - return std::exp(fct * exp); - }); - jitted_gpu_kernel( - iter, - pow_kernel_string, - /*scalar_pos=*/at::zoom::jit::BinaryFuncVariant::NoScalar, - /*scalar_val=*/0, - /*extra_args=*/std::make_tuple(fct)); -#else - gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t exp) -> scalar_t { - return 
std::exp(fct * opmath_t{exp}); - }); -#endif -} - -namespace { - -#if AT_USE_JITERATOR() -/* complex support impl */ -CONSTEXPR_EXCEPT_WIN_CUDA char pow_name[] = "pow_kernel"; -static const auto pow_kernel_string = - jiterator_stringify(template T pow_kernel(T base, T exp) { - return std::pow(base, exp); - }); -#endif - -/* complex support impl */ -void pow_chalf_tensor_scalar_impl(TensorIteratorBase& iter, const Scalar& exp_scalar) { - using scalar_t = c10::complex; - using opmath_t = at::opmath_type; - auto exp = exp_scalar.to(); -#if AT_USE_JITERATOR() - jitted_gpu_kernel( - iter, - pow_kernel_string, - /*scalar_pos=*/at::zoom::jit::BinaryFuncVariant::NoScalar, - /*scalar_val=*/0, - /*extra_args=*/std::make_tuple(exp)); -#else - gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base) -> scalar_t { - return std::pow(opmath_t{base}, exp); - }); -#endif -} - -} // anonymous namespace - -void pow_tensor_tensor_kernel(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (common_dtype == kComplexHalf) { - using scalar_t = c10::complex; - if (iter.is_cpu_scalar(1)) { - const auto base = iter.scalar_value(1); - iter.remove_operand(1); - pow_scalar_tensor_impl(iter, base); - } else if (iter.is_cpu_scalar(2)) { - const auto exp = iter.scalar_value(2); - iter.remove_operand(2); - pow_chalf_tensor_scalar_impl(iter, exp); - } else { - using opmath_t = at::opmath_type; - TORCH_INTERNAL_ASSERT(!iter.is_cpu_scalar(1) && !iter.is_cpu_scalar(2)); -#if AT_USE_JITERATOR() - jitted_gpu_kernel( - iter, pow_kernel_string); -#else - gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { - using opmath_t = at::opmath_type; - return pow_(opmath_t{base}, opmath_t{exp}); - }); -#endif - } - } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - kHalf, kBFloat16, iter.common_dtype(), "pow_zoom", [&] { - if (iter.is_cpu_scalar(1)) { - const auto base = iter.scalar_value(1); - iter.remove_operand(1); - pow_scalar_tensor_impl(iter, base); - } else if (iter.is_cpu_scalar(2)) { - const auto exp = iter.scalar_value(2); - iter.remove_operand(2); - pow_tensor_scalar_kernel(iter, exp); - } else { - gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { - return pow_(base, exp); - }); - } - }); - } -} - - -template -void pow_tensor_scalar_kernel_impl(TensorIteratorBase& iter, - Exp_type exp) { - const auto d_exp = static_cast(exp); - // .5 (sqrt), -.5 (rsqrt) and -1 (reciprocal) specializations are handled - // in pow_tensor_scalar_kernel - if (d_exp == 2) { - gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { - return base * base; - }); - } else if (d_exp == 3) { - gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { - return base * base * base; - }); - } else if (d_exp == -2) { - gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { - return 1.0 / (base * base); - }); - } else { - gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { - return pow_(base, exp); - }); - } -} - -void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar) { - // Dispatch to fast specialization for sqrt, rsqrt and reciprocal - if (!exp_scalar.isComplex()) { - if (exp_scalar.equal(.5)) { - return sqrt_kernel_zoom(iter); - } else if (exp_scalar.equal(-0.5)) { - return rsqrt_kernel_zoom(iter); - } else if (exp_scalar.equal(-1.0)) { - return reciprocal_kernel_zoom(iter); - } - } - if (isComplexType(iter.common_dtype()) || exp_scalar.isComplex()) { - if (iter.common_dtype() == kComplexHalf) { - using scalar_t = c10::complex; - 
pow_chalf_tensor_scalar_impl(iter, exp_scalar); - return; - } - AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_zoom", [&]() { - const auto exp = exp_scalar.to(); - gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t { - return pow_(base, exp); - }); - }); - } else if (isFloatingType(iter.common_dtype()) || exp_scalar.isIntegral(false)) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "pow_zoom", [&]() { - const auto exp = exp_scalar.to(); - pow_tensor_scalar_kernel_impl(iter, exp); - }); - } else { - TORCH_INTERNAL_ASSERT(false, "invalid combination of type in Pow function, common dtype:", iter.common_dtype(), - "exp is integral?", exp_scalar.isIntegral(false)); - } -} - -} // anonymous namespace - -REGISTER_PRIVATEUSE1_DISPATCH(pow_tensor_tensor_stub, &pow_tensor_tensor_kernel); -REGISTER_PRIVATEUSE1_DISPATCH(pow_tensor_scalar_stub, &pow_tensor_scalar_kernel); - -} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/TensorCompare.cu b/aten/src/ATen/native/zoom/TensorCompare.cu deleted file mode 100644 index e92d058c9b7222..00000000000000 --- a/aten/src/ATen/native/zoom/TensorCompare.cu +++ /dev/null @@ -1,133 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include - - -namespace at::native { - -namespace { - -void where_kernel_impl(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_zoom", [&] { - gpu_kernel( - iter, - [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t { - return cond_val ? self_val : other_val; - }); - }); -} - -void isposinf_kernel_impl(TensorIteratorBase &iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_zoom", [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA (scalar_t a) -> bool { return a == std::numeric_limits::infinity(); } - ); - }); -} - -void isneginf_kernel_impl(TensorIteratorBase &iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_zoom", [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA (scalar_t a) -> bool { return a == -std::numeric_limits::infinity(); } - ); - }); -} - -void clamp_kernel_impl(TensorIteratorBase& iter) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_zoom", [&] { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t v, scalar_t lower, scalar_t upper) -> scalar_t { - // Propagate nan, which doesn't propagate automatically for ROCm - if (at::_isnan(v)) { - return v; - } if (at::_isnan(lower)) { - return lower; - } if (at::_isnan(upper)) { - return upper; - } else { - return ::min(::max(v, lower), upper); - } - }); - }); -} - -void inline launch_clamp_scalar(TensorIteratorBase& iter, Scalar lim0, Scalar lim1, at::native::detail::ClampLimits minmax){ - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_zoom", [&] { - using opmath_t = at::opmath_type; - auto lim0_val = lim0.to(); - auto lim1_val = lim1.to(); - - gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { - // Propagate nan, which doesn't propagate automatically for ROCm - if (_isnan(static_cast(v))) { - return v; - } else if (minmax==at::native::detail::ClampLimits::Min){ - return ::max(static_cast(v), lim0_val); - } else if (minmax==at::native::detail::ClampLimits::Max){ - return ::min(static_cast(v), lim0_val); - } else { - return ::min(::max(static_cast(v), lim0_val), lim1_val); - } - }); 
- }); -} - - -void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min, const Scalar& max) { - launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax); -} - -void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min) { - launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min); -} - -void clamp_max_scalar_kernel_impl(TensorIteratorBase& iter, Scalar max) { - launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max); -} - -} // anonymous namespace - - -REGISTER_PRIVATEUSE1_DISPATCH(where_kernel, &where_kernel_impl); -REGISTER_PRIVATEUSE1_DISPATCH(isposinf_stub, &isposinf_kernel_impl); -REGISTER_PRIVATEUSE1_DISPATCH(isneginf_stub, &isneginf_kernel_impl); -REGISTER_PRIVATEUSE1_DISPATCH(clamp_stub, &clamp_kernel_impl); -REGISTER_PRIVATEUSE1_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl); -REGISTER_PRIVATEUSE1_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl); -REGISTER_PRIVATEUSE1_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl); - -template -__global__ void _assert_async_zoom_kernel(const scalar_t* input) { - ZOOM_KERNEL_ASSERT(input[0] != 0); -} - -__global__ void _assert_async_zoom_kernel(const c10::complex* input) { - ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); -} -__global__ void _assert_async_zoom_kernel(const c10::complex* input) { - ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); -} - -void _assert_async_zoom(const Tensor& self_tensor) { - const TensorBase &self = get_tensor_base(self_tensor); - auto n = self.numel(); - TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous"); - TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous"); - auto stream = c10::zoom::getCurrentZoomStream(); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_assert_async_zoom", [&] { - _assert_async_zoom_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr()); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - }); -} - -// TODO (tmanlaibaatar) Ignore assert msg for now -void _assert_async_msg_zoom(const Tensor& self_tensor, c10::string_view assert_msg) { - _assert_async_zoom(self_tensor); -} - -} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/TensorShape.cu b/aten/src/ATen/native/zoom/TensorShape.cu deleted file mode 100644 index 5fad25d8a76179..00000000000000 --- a/aten/src/ATen/native/zoom/TensorShape.cu +++ /dev/null @@ -1,833 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -#include -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include -#include -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#include -#endif - -namespace at::native { - -namespace detail { - -// NOTE [CUDA fast path for split_with_sizes_copy.out] -// split_with_sizes_copy.out for contiguous operands has the following -// properties: -// - Each src split consists of multiple chunks that are separated by a fixed -// stride. The number of chunks and the strides are the same across all src -// splits. -// - Each dst split is the concatenation of the chunks in its corresponding src -// splits. -// - The sizes of chunks vary across splits. -// - A (src, dst) chunk pair is not guaranteed to have the -// same alignment. 
-// -// The following strategies are employed to optimize for this workload: -// - The entire workload is fused into a single kernel to maximize I/O -// throughput and minimize wave quantization. -// - To account for both small and large chunk sizes, a "jagged grid" is used. -// Each chunk is processed by one or more blocks depending on its size. -// - Within each chunk, the region in which writes can be vectorized is -// identified. Within this region, writes are always vectorized and reads are -// oppurtunistically vectorized. -static constexpr int64_t BLOCK_SIZE = 128; -static constexpr int64_t BYTES_PER_THREAD = 16; -static constexpr int64_t BYTES_PER_BLOCK = BYTES_PER_THREAD * BLOCK_SIZE; - -static __host__ __device__ inline int64_t div_up(int64_t a, int64_t b) { - return (a + b - 1) / b; -} - -template -__device__ inline void stream_load128(uint4& val, const T* addr) { - uint64_t low, high; - low = reinterpret_cast(addr)[0]; - high = reinterpret_cast(addr)[1]; - reinterpret_cast(&val)[0] = low; - reinterpret_cast(&val)[1] = high; -} - -template -__device__ inline void stream_store128(T* addr, const uint4& val) { - uint64_t low, high; - low = reinterpret_cast(&val)[0]; - high = reinterpret_cast(&val)[1]; - reinterpret_cast(addr)[0] = low; - reinterpret_cast(addr)[1] = high; -} - -template -static __device__ inline bool is_aligned(const void* addr) { - return reinterpret_cast(addr) % sizeof(T) == 0; -} - -template -static __device__ inline void load128(uint4& val, const char* addr) { - for (size_t i = 0; i < detail::BYTES_PER_THREAD / sizeof(T); ++i) { - reinterpret_cast(&val)[i] = reinterpret_cast(addr)[i]; - } -} - -template <> -__device__ inline void load128(uint4& val, const char* addr) { - stream_load128(val, addr); -} - -static __device__ inline void load128(uint4& val, const char* addr) { - if (is_aligned(addr)) { - load128(val, addr); - } else if (is_aligned(addr)) { - load128(val, addr); - } else if (is_aligned(addr)) { - load128(val, addr); - } else { - load128(val, addr); - } -} - -static __device__ __inline__ void get_aligned_region( - char* ptr, - const int64_t chunk_size, - const int64_t alignment, - int64_t& align_off, - int64_t& aligned_size) { - const int64_t ptr_val = reinterpret_cast(ptr); - align_off = detail::div_up(ptr_val, alignment) * alignment - ptr_val; - aligned_size = (chunk_size - align_off) / alignment * alignment; -} - -static __device__ __inline__ void copy_chunk( - char* dst, - const char* src, - int64_t chunk_size, - int64_t thread_idx, - int64_t num_threads) { - if (chunk_size < num_threads) { - if (thread_idx < chunk_size) { - dst[thread_idx] = src[thread_idx]; - } - return; - } - - // Identify the region in which writes are guaranteed to be 128-bit aligned - int64_t align_off, aligned_size; - get_aligned_region( - dst, chunk_size, detail::BYTES_PER_THREAD, align_off, aligned_size); - - for (int64_t off = align_off + thread_idx * detail::BYTES_PER_THREAD; - off < align_off + aligned_size; - off += num_threads * detail::BYTES_PER_THREAD) { - uint4 val; - // Oppurtunistically vectorize reads - load128(val, &src[off]); - stream_store128(&dst[off], val); - } - - // Handle unaligned regions - if (thread_idx < align_off && thread_idx < chunk_size) { - dst[thread_idx] = src[thread_idx]; - } - if (align_off + aligned_size + thread_idx < chunk_size) { - dst[align_off + aligned_size + thread_idx] = - src[align_off + aligned_size + thread_idx]; - } -} - -static __global__ void split_with_sizes_copy_out_contiguous_no_cast_kernel( - char** dst_base_addrs, - 
char** src_base_addrs, - int64_t* split_chunk_sizes, - int64_t* block_idx_to_split_idx, - int64_t* blocks_cumsums, - int64_t src_stride, - int64_t num_chunks) { - const int64_t split_idx = block_idx_to_split_idx[blockIdx.x]; - const int64_t split_blocks = - blocks_cumsums[split_idx + 1] - blocks_cumsums[split_idx]; - const int64_t split_threads = split_blocks * blockDim.x; - const int64_t split_thread_idx = - (blockIdx.x - blocks_cumsums[split_idx]) * blockDim.x + threadIdx.x; - const int64_t split_chunk_size = split_chunk_sizes[split_idx]; - - char* dst_base_addr = dst_base_addrs[split_idx]; - char* src_base_addr = src_base_addrs[split_idx]; - - for (int64_t i = blockIdx.y; i < num_chunks; i += gridDim.y) { - copy_chunk( - dst_base_addr + i * split_chunk_size, - src_base_addr + i * src_stride, - split_chunk_size, - split_thread_idx, - split_threads); - } -} - -// Calculate the base addr for each split. -static inline std::vector get_split_base_addrs( - const at::Tensor& tensor, - at::IntArrayRef split_sizes, - int64_t dim) { - const auto* data_ptr = static_cast(tensor.const_data_ptr()); - const auto strides = tensor.strides(); - const auto element_sz = tensor.element_size(); - int64_t off = 0; - std::vector split_base_addrs; - split_base_addrs.reserve(split_sizes.size()); - for (const auto& split_size : split_sizes) { - split_base_addrs.push_back(reinterpret_cast(data_ptr + off)); - off += split_size * strides[dim] * element_sz; - } - return split_base_addrs; -} - -static inline std::vector get_dst_addrs(at::TensorList out) { - std::vector addrs; - addrs.reserve(out.size()); - for (const auto& tensor : out) { - addrs.push_back(reinterpret_cast(tensor.data_ptr())); - } - return addrs; -} - -// Calculate the chunk size for each split in bytes. -static inline std::vector get_split_chunk_sizes( - const at::Tensor& tensor, - at::IntArrayRef split_sizes, - int64_t dim) { - const auto stride = tensor.stride(dim); - const auto element_sz = tensor.element_size(); - std::vector split_chunk_sizes; - split_chunk_sizes.reserve(split_sizes.size()); - for (const auto& split_size : split_sizes) { - split_chunk_sizes.push_back(split_size * stride * element_sz); - } - return split_chunk_sizes; -} - -// Calculate the chunk stride in bytes. This is the same for all splits. -static inline int64_t get_chunk_stride(const at::Tensor& tensor, int64_t dim) { - int64_t stride = 1; - for (int64_t d = dim; d < tensor.dim(); ++d) { - stride *= tensor.sizes()[d]; - } - return stride * tensor.element_size(); -} - -// Calculate the number of chunks. This is the same for all splits. -static inline int64_t get_num_chunks(const at::Tensor& tensor, int64_t dim) { - int64_t num_chunks = tensor.numel(); - for (int64_t d = dim; d < tensor.dim(); ++d) { - num_chunks /= tensor.sizes()[d]; - } - return num_chunks; -} - -// Pack multiple std::vector into a single zoom tensor. 
-std::pair> pack_vecs( - std::vector*> vecs, - const at::Device& device) { - int64_t numel = 0; - for (const auto* vec : vecs) { - numel += vec->size(); - } - - auto packed = at::empty( - {numel}, at::TensorOptions().dtype(at::kLong).pinned_memory(true)); - size_t offset = 0; - for (const auto* vec : vecs) { - memcpy( - packed.data_ptr() + offset, - vec->data(), - sizeof(int64_t) * vec->size()); - offset += vec->size(); - } - packed = packed.to(device, /*non_blocking=*/true); - - std::vector ptrs; - ptrs.reserve(vecs.size()); - offset = 0; - for (const auto* vec : vecs) { - ptrs.push_back(packed.data_ptr() + offset); - offset += vec->size(); - } - return std::make_pair(std::move(packed), std::move(ptrs)); -} - -static inline std::vector get_chunk_cat_out_sizes( - IntArrayRef input_tensor_sizes, - int64_t dim, - int64_t num_chunks, - int64_t chunk_size, - int64_t out_element_size) { - std::vector view_sizes = std::vector( - input_tensor_sizes.begin(), input_tensor_sizes.begin() + dim); - view_sizes.insert( - view_sizes.end(), {num_chunks, chunk_size / out_element_size}); - return view_sizes; -} - -// Copy `max_chunk_size` bytes from `src` to `dst` by `num_threads`, and pad -// zero when `src` size (i.e., actual_chunk_size) is less than `max_chunk_size`. -// Assume elements of src and dst have the same data type. -template -__device__ __inline__ void copy_chunk_with_pad( - dst_t* dst_ptr, - src_t* src_ptr, - int64_t max_chunk_size, - int64_t actual_chunk_size, - int64_t thread_idx, - int64_t num_threads) { - // Supports type cast - if (!std::is_same_v) { - const int64_t max_num_elems = max_chunk_size / sizeof(dst_t); - const int64_t actual_num_elems = actual_chunk_size / sizeof(src_t); - int64_t elem_index = thread_idx; - while (elem_index < actual_num_elems) { - dst_ptr[elem_index] = - static_cast_with_inter_type::apply(src_ptr[elem_index]); - elem_index += num_threads; - } - while (elem_index < max_num_elems) { - dst_ptr[elem_index] = static_cast_with_inter_type::apply(0); - elem_index += num_threads; - } - return; - } - char* dst = reinterpret_cast(dst_ptr); - char* src = reinterpret_cast(src_ptr); - // Fast path when the number of threads is larger than the number of bytes to - // be copied (i.e., max_chunk_size). In this case, each thread only copies 1 - // byte. For 0 <= thread_idx < actual_chunk_size, the thread copies data from - // `src`. For actual_chunk_size <= thread_idx < max_chunk_size, the thread set - // the val=0 for padding. - if (max_chunk_size < num_threads) { - char val = static_cast(0); - if (thread_idx < actual_chunk_size) { - val = src[thread_idx]; - } - if (thread_idx < max_chunk_size) { - dst[thread_idx] = val; - } - return; - } - // Split dst array into three parts: - // [dst, dst+align_off), [dst+align_off, dst+align_end), [dst+align_end, - // dst+max_chunk_size) The second part is aligned with BYTES_PER_THREAD(=16 - // bytes) to enable `stream_store128`. - int64_t align_off, aligned_size; - get_aligned_region( - dst, actual_chunk_size, BYTES_PER_THREAD, align_off, aligned_size); - int64_t align_end = align_off + aligned_size; - for (int64_t i = align_off + thread_idx * BYTES_PER_THREAD; i < align_end; - i += num_threads * BYTES_PER_THREAD) { - uint4 val; - if (is_aligned(src + i)) { - stream_load128(val, src + i); - } else { - for (size_t j = 0; j < BYTES_PER_THREAD; ++j) { - reinterpret_cast(&val)[j] = src[i + j]; - } - } - stream_store128(&dst[i], val); - } - // Copy data for the first part of dst array [dst, dst+align_off). 
- // Check `thread_idx -static __global__ void chunk_cat_zoom_kernel( - src_t** src, - dst_t* dst, - int64_t* block_idx_to_tensor_idx, - int64_t* tensor_idx_to_start_tensor_bytes, - int64_t* start_block_idx_per_tensor_chunk, - int64_t* actual_tensor_sizes, - int64_t* pad_tensor_chunk_sizes, - int64_t* num_blocks_per_tensor_chunk, - int64_t slice_size, - int64_t chunk_size, - int64_t dst_to_src_ratio) { - const int64_t slice_idx = blockIdx.z; - const int64_t chunk_idx = blockIdx.y; - const int64_t tensor_idx = block_idx_to_tensor_idx[blockIdx.x]; - const int64_t tile_idx = - blockIdx.x - start_block_idx_per_tensor_chunk[tensor_idx]; - // Number of threads for the `tensor_idx`-th tensor chunk. - const int64_t num_threads = - num_blocks_per_tensor_chunk[tensor_idx] * BLOCK_SIZE; - const int64_t thread_idx = tile_idx * BLOCK_SIZE + threadIdx.x; - char* src_addr = reinterpret_cast(src)[tensor_idx] + - slice_idx * actual_tensor_sizes[tensor_idx] + - chunk_idx * pad_tensor_chunk_sizes[tensor_idx] / dst_to_src_ratio; - char* dst_addr = reinterpret_cast(dst) + slice_idx * slice_size + - chunk_idx * chunk_size + tensor_idx_to_start_tensor_bytes[tensor_idx]; - // Compute the actual number of bytes to copy from src. - const int64_t actual_copy_size = ::min( - pad_tensor_chunk_sizes[tensor_idx] / dst_to_src_ratio, - ::max( - (int64_t)0, - actual_tensor_sizes[tensor_idx] - - chunk_idx * pad_tensor_chunk_sizes[tensor_idx] / - dst_to_src_ratio)); - copy_chunk_with_pad( - reinterpret_cast(dst_addr), - reinterpret_cast(src_addr), - pad_tensor_chunk_sizes[tensor_idx], - actual_copy_size, - thread_idx, - num_threads); -} - -bool all_contiguous(TensorList tensors) { - bool contiguous = true; - for (const auto& t : tensors) { - contiguous &= t.is_non_overlapping_and_dense(); - } - return contiguous; -} - -// Get leading dimensions before `dim`-th dimension. -static inline int64_t get_leading_dim(at::IntArrayRef sizes, int64_t dim) { - int64_t leading_dim = 1; - if (dim > 0) { - leading_dim = c10::multiply_integers(sizes.slice(0, dim)); - } - return leading_dim; -} - -// Get trailing dimensions after `dim`-th dimension and padded size along -// `dim`-th dimension. -static inline std::pair get_pad_size( - at::IntArrayRef sizes, - int64_t dim, - int64_t num_chunks) { - int64_t trailing_numel = 1; - if (sizes.size() > (uint64_t)dim + 1) { - trailing_numel = - c10::multiply_integers(sizes.slice(dim + 1, sizes.size() - dim - 1)); - } - int64_t pad_size_along_dim = - detail::div_up(sizes[dim], num_chunks) * num_chunks; - return std::make_pair(pad_size_along_dim, trailing_numel); -} - -// Get the padded chunk size. -static inline int64_t get_chunk_size( - TensorList tensors, - int64_t dim, - int64_t num_chunks, - int64_t elem_size) { - auto num_tensors = tensors.size(); - int64_t chunk_size = 0; - for (const auto i : c10::irange(num_tensors)) { - auto [pad_size_along_dim, trailing_numel] = - get_pad_size(tensors[i].sizes(), dim, num_chunks); - const int64_t pad_tensor_chunk_size = - pad_size_along_dim * trailing_numel * elem_size / num_chunks; - chunk_size += pad_tensor_chunk_size; - } - return chunk_size; -} - -// Get metadata for chunk_cat. 
-std::tuple< - int64_t, - int64_t, - int64_t, - int64_t, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector> -get_chunk_cat_metadata( - TensorList tensors, - int64_t dim, - int64_t num_chunks, - int64_t dst_elem_size, - int64_t src_elem_size) { - TORCH_CHECK( - dst_elem_size % src_elem_size == 0, - "get_chunk_cat_metadata error: only support dst_elem_size % src_elem_size == 0"); - auto num_tensors = tensors.size(); - int64_t leading_dim = get_leading_dim(tensors[0].sizes(), dim); - std::vector pad_tensor_chunk_sizes; - std::vector num_blocks_per_tensor_chunk; - std::vector start_block_idx_per_tensor_chunk{0}; - std::vector actual_tensor_sizes; - std::vector tensor_idx_to_start_tensor_bytes{0}; - std::vector srcs; - pad_tensor_chunk_sizes.reserve(num_tensors); - num_blocks_per_tensor_chunk.reserve(num_tensors); - start_block_idx_per_tensor_chunk.reserve(num_tensors + 1); - actual_tensor_sizes.reserve(num_tensors); - tensor_idx_to_start_tensor_bytes.reserve(num_tensors + 1); - srcs.reserve(num_tensors); - // block_idx_to_tensor_idx cannot be reserved since the number of blocks is - // data dependent - std::vector block_idx_to_tensor_idx; - // Inline computing `chunk_size` to avoid redundant computation - int64_t chunk_size = 0; - for (const auto i : c10::irange(num_tensors)) { - at::Tensor tensor = tensors[i]; - srcs.push_back(reinterpret_cast(tensor.data_ptr())); - auto sizes = tensor.sizes(); - auto [pad_size_along_dim, trailing_numel] = - get_pad_size(sizes, dim, num_chunks); - const int64_t pad_tensor_chunk_size = - pad_size_along_dim * trailing_numel * dst_elem_size / num_chunks; - pad_tensor_chunk_sizes.push_back(pad_tensor_chunk_size); - chunk_size += pad_tensor_chunk_size; - // Number of blocks required to process this tensor chunk. - const int64_t num_blocks = - detail::div_up(pad_tensor_chunk_size, detail::BYTES_PER_BLOCK); - num_blocks_per_tensor_chunk.push_back(num_blocks); - start_block_idx_per_tensor_chunk.push_back( - start_block_idx_per_tensor_chunk.back() + num_blocks); - block_idx_to_tensor_idx.insert( - block_idx_to_tensor_idx.end(), num_blocks, i); - tensor_idx_to_start_tensor_bytes.push_back( - tensor_idx_to_start_tensor_bytes.back() + pad_tensor_chunk_size); - actual_tensor_sizes.push_back(sizes[dim] * trailing_numel * src_elem_size); - } - const int64_t num_blocks_per_chunk = start_block_idx_per_tensor_chunk.back(); - const int64_t slice_size = num_chunks * chunk_size; - return std::make_tuple( - chunk_size, - leading_dim, - num_blocks_per_chunk, - slice_size, - srcs, - block_idx_to_tensor_idx, - tensor_idx_to_start_tensor_bytes, - start_block_idx_per_tensor_chunk, - actual_tensor_sizes, - pad_tensor_chunk_sizes, - num_blocks_per_tensor_chunk); -} - -// See [CUDA kernel for chunk_cat_cuda] -template -void _chunk_cat_out_zoom_contiguous( - TensorList tensors, - int64_t dim, - int64_t num_chunks, - Tensor& out, - int64_t dst_elem_size, - int64_t src_elem_size) { - const auto device = tensors[0].device(); - // `get_chunk_cat_metadata` must return vectors and `pack_vecs` cannot be - // moved into `get_chunk_cat_metadata`. Otherwise `packed` would point to - // vectors allocated inside `get_chunk_cat_metadata` which become out of local - // scope. 
- auto - [chunk_size, - leading_dim, - num_blocks_per_chunk, - slice_size, - srcs, - block_idx_to_tensor_idx, - tensor_idx_to_start_tensor_bytes, - start_block_idx_per_tensor_chunk, - actual_tensor_sizes, - pad_tensor_chunk_sizes, - num_blocks_per_tensor_chunk] = - get_chunk_cat_metadata( - tensors, dim, num_chunks, dst_elem_size, src_elem_size); - auto packed = pack_vecs( - {&srcs, - &block_idx_to_tensor_idx, - &tensor_idx_to_start_tensor_bytes, - &start_block_idx_per_tensor_chunk, - &actual_tensor_sizes, - &pad_tensor_chunk_sizes, - &num_blocks_per_tensor_chunk}, - device); - std::vector view_sizes = get_chunk_cat_out_sizes( - tensors[0].sizes(), dim, num_chunks, chunk_size, dst_elem_size); - at::native::resize_output(out, view_sizes); - dim3 blocks(num_blocks_per_chunk, num_chunks, leading_dim); - dim3 threads(detail::BLOCK_SIZE, 1, 1); - hipLaunchKernelGGL(( detail::chunk_cat_zoom_kernel), - dim3(blocks), - dim3(threads), - 0, - c10::zoom::getCurrentZoomStream(), - /*srcs=*/reinterpret_cast(packed.second[0]), - reinterpret_cast(out.data_ptr()), - /*block_idx_to_tensor_idx=*/packed.second[1], - /*tensor_idx_to_start_tensor_bytes=*/packed.second[2], - /*start_block_idx_per_tensor_chunk=*/packed.second[3], - /*actual_tensor_sizes=*/packed.second[4], - /*pad_tensor_chunk_sizes=*/packed.second[5], - /*num_blocks_per_tensor_chunk=*/packed.second[6], - slice_size, - chunk_size, - dst_elem_size / src_elem_size); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); -} - -} // namespace detail - -// See [CUDA fast path for split_with_sizes_copy.out] -void split_with_sizes_copy_out_zoom_contiguous_no_cast( - const at::Tensor& self, - at::IntArrayRef split_sizes, - int64_t dim, - at::TensorList out) { - const auto device = self.device(); - const auto src_base_addrs = - detail::get_split_base_addrs(self, split_sizes, dim); - const auto dst_base_addrs = detail::get_dst_addrs(out); - const auto src_stride = detail::get_chunk_stride(self, dim); - const auto split_chunk_sizes = - detail::get_split_chunk_sizes(self, split_sizes, dim); - const auto num_chunks = detail::get_num_chunks(self, dim); - - // Calculate the number of blocks required for the first chunk across all - // splits, assuming each thread only processes BYTES_PER_THREAD bytes. - int64_t num_blocks = 0; - for (const auto& split_chunk_size : split_chunk_sizes) { - num_blocks += detail::div_up( - split_chunk_size, detail::BLOCK_SIZE * detail::BYTES_PER_THREAD); - } - - // Calculate the maximum number of blocks to launch. Only consider - // maxThreadsPerMultiProcessor as a limiting factor as the kernel uses no - // shared memory and little registers. Over-subscribe the SMs to hide I/O - // latency. - const auto num_sms = - at::zoom::getCurrentDeviceProperties()->multiProcessorCount; - const auto max_threads_per_sm = - at::zoom::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor; - const int64_t max_blocks = - num_sms * max_threads_per_sm / detail::BLOCK_SIZE * 2.0; - - // Make each thread process BYTES_PER_THREAD * iter_factor bytes to regulate - // block size. Spread iter_factor evenly between chunks_per_block and - // iters_per_chunk. 
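The grid-sizing arithmetic described in the comments above can be reproduced on the host in a few lines: the raw block count is compared against an occupancy-derived budget, and the resulting oversubscription factor is split between chunks_per_block and iters_per_chunk. A standalone C++ sketch under assumed device numbers (num_sms and max_threads_per_sm are made-up example values, not queried from a device):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t BLOCK_SIZE = 128;
constexpr int64_t BYTES_PER_THREAD = 16;

int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  std::vector<int64_t> split_chunk_sizes = {1 << 20, 1 << 14, 1 << 22};  // bytes
  const int64_t num_chunks = 64;
  const int64_t num_sms = 108, max_threads_per_sm = 2048;  // example device

  // Blocks needed for one chunk of every split, at BYTES_PER_THREAD per thread.
  int64_t num_blocks = 0;
  for (int64_t sz : split_chunk_sizes)
    num_blocks += div_up(sz, BLOCK_SIZE * BYTES_PER_THREAD);

  // Over-subscribe the SMs 2x to hide I/O latency.
  const int64_t max_blocks = num_sms * max_threads_per_sm / BLOCK_SIZE * 2;

  // Each thread now handles BYTES_PER_THREAD * iter_factor bytes, with the
  // factor spread between chunks handled per block and iterations per chunk.
  int64_t iter_factor = div_up(num_blocks * num_chunks, max_blocks);
  int64_t chunks_per_block = std::min<int64_t>(
      static_cast<int64_t>(std::ceil(std::sqrt(double(iter_factor)))),
      num_chunks);
  int64_t iters_per_chunk = div_up(iter_factor, chunks_per_block);

  std::cout << "blocks_per_chunk=" << num_blocks
            << " chunks_per_block=" << chunks_per_block
            << " iters_per_chunk=" << iters_per_chunk << "\n";
}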
- int64_t iter_factor = detail::div_up(num_blocks * num_chunks, max_blocks); - int64_t chunks_per_block = ::ceil(std::sqrt(iter_factor)); - chunks_per_block = ::min(chunks_per_block, num_chunks); - const int64_t iters_per_chunk = detail::div_up(iter_factor, chunks_per_block); - - // Launch a logically jagged grid of shape - // (chunk_size*, num_splits, num_chunks / chunks_per_block) - // backed by a physical grid of shape - // (sum(chunk_size), num_chunks / chunks_per_block). - // A block can find its split_idx via block_idx_to_split_idx. - std::vector block_idx_to_split_idx; - std::vector blocks_cumsums{0}; - block_idx_to_split_idx.reserve(num_blocks); - for (size_t split_idx = 0; split_idx < split_sizes.size(); ++split_idx) { - const auto blocks = detail::div_up( - split_chunk_sizes[split_idx], - detail::BLOCK_SIZE * detail::BYTES_PER_THREAD * iters_per_chunk); - block_idx_to_split_idx.insert( - block_idx_to_split_idx.end(), blocks, split_idx); - blocks_cumsums.push_back(blocks_cumsums.back() + blocks); - } - - dim3 blocks(blocks_cumsums.back(), num_chunks / chunks_per_block, 1); - dim3 threads(detail::BLOCK_SIZE, 1, 1); - - auto [_, ptrs] = detail::pack_vecs( - {&dst_base_addrs, - &src_base_addrs, - &split_chunk_sizes, - &block_idx_to_split_idx, - &blocks_cumsums}, - device); - - hipLaunchKernelGGL(( detail::split_with_sizes_copy_out_contiguous_no_cast_kernel), - dim3(blocks), - dim3(threads), - 0, - c10::zoom::getCurrentZoomStream(), - /*dst_base_addrs=*/reinterpret_cast(ptrs[0]), - /*src_base_addrs=*/reinterpret_cast(ptrs[1]), - /*split_chunk_sizes=*/ptrs[2], - /*block_idx_to_split_idx=*/ptrs[3], - /*blocks_cumsums=*/ptrs[4], - src_stride, - num_chunks); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); -} - -void split_with_sizes_copy_out_zoom( - const Tensor& self, - IntArrayRef split_sizes, - int64_t dim, - TensorList out) { - const bool is_capturing = c10::zoom::currentStreamCaptureStatusMayInitCtx() != - c10::zoom::CaptureStatus::None; - bool contiguous_no_cast = self.is_non_overlapping_and_dense(); - for (const auto& t : out) { - contiguous_no_cast &= t.is_non_overlapping_and_dense(); - contiguous_no_cast &= (t.dtype() == self.dtype()); - } - // TODO(yifu): make the fast path work for CUDA graph - if (!is_capturing && contiguous_no_cast) { - // Perform equivalent checks performed by the composite impl - if (dim < 0) { - dim = at::maybe_wrap_dim(dim, self.dim()); - } - TORCH_CHECK( - self.dim() != 0, "split expects at least a 1-dimensional tensor") - - const int64_t dim_size = self.size(dim); - int64_t split_sizes_sum = 0; - for (const auto i : c10::irange(split_sizes.size())) { - TORCH_CHECK( - split_sizes[i] >= 0, - "split_with_sizes expects split_sizes have only non-negative ", - "entries, but got split_sizes=", - split_sizes[i]); - split_sizes_sum += split_sizes[i]; - } - TORCH_CHECK( - split_sizes_sum == dim_size, - "split_with_sizes expects split_sizes to sum exactly to ", - dim_size, - " (input tensor's size at dimension ", - dim, - "), ", - "but got split_sizes=", - split_sizes); - - TORCH_CHECK( - out.size() == split_sizes.size(), - "split_with_sizes_copy_out() expected an out= argument of size ", - split_sizes.size(), - ", got size ", - out.size()); - - auto out_shape = self.sizes().vec(); - for (const auto i : c10::irange(split_sizes.size())) { - out_shape[dim] = split_sizes[i]; - if (resize_output_check(out[i], out_shape)) { - out[i].resize_(out_shape); - } - TORCH_CHECK( - out[i].dtype() == self.dtype(), - "Expected out tensor to have dtype ", - self.dtype(), - ", but got ", - 
out[i].dtype(), - " instead"); - TORCH_CHECK( - out[i].device() == self.device(), - "Expected out tensor to have device ", - self.device(), - ", but got ", - out[i].device(), - " instead"); - } - split_with_sizes_copy_out_zoom_contiguous_no_cast( - self, split_sizes, dim, out); - } else { - at::native::split_with_sizes_copy_out(self, split_sizes, dim, out); - } -} - -Tensor _chunk_cat_zoom(TensorList tensors, int64_t dim, int64_t num_chunks) { - dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks); - if (detail::all_contiguous(tensors)) { - // Return a tensor with the same dtype as input tensors - int64_t elem_size = tensors[0].element_size(); - int64_t chunk_size = - detail::get_chunk_size(tensors, dim, num_chunks, elem_size); - int64_t leading_dim = detail::get_leading_dim(tensors[0].sizes(), dim); - auto view_sizes = detail::get_chunk_cat_out_sizes( - tensors[0].sizes(), dim, num_chunks, chunk_size, elem_size); - Tensor out = - tensors[0] - .new_empty(chunk_size * num_chunks * leading_dim / elem_size) - .view(view_sizes); - // Type-agnostic copy since out and input tensors have the same type. - detail::_chunk_cat_out_zoom_contiguous( - tensors, dim, num_chunks, out, elem_size, elem_size); - return out; - } else { - return at::native::_chunk_cat(tensors, dim, num_chunks); - } -} - -Tensor& _chunk_cat_out_zoom( - TensorList tensors, - int64_t dim, - int64_t num_chunks, - Tensor& out) { - dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks); - TORCH_CHECK( - tensors[0].device() == out.device(), - "_chunk_cat_out_zoom: mismatch between input and out tensor devices"); - bool both_input_output_contiguous = - detail::all_contiguous(tensors) && out.is_non_overlapping_and_dense(); - if (both_input_output_contiguous && - (tensors[0].dtype() == at::ScalarType::BFloat16) && - (out.dtype() == at::ScalarType::Float)) { - // _chunk_cat_out_zoom_contiguous should also support other types, thanks to - // static_cast_with_inter_type. Here, we dispatch to BFloat16 in and float32 - // out since it is the only known use case. - detail::_chunk_cat_out_zoom_contiguous( - tensors, - dim, - num_chunks, - out, - out.element_size(), - tensors[0].element_size()); - } else if ( - both_input_output_contiguous && tensors[0].dtype() == out.dtype()) { - // Type-agnostic copy since out and input tensors have the same type. - detail::_chunk_cat_out_zoom_contiguous( - tensors, - dim, - num_chunks, - out, - out.element_size(), - tensors[0].element_size()); - } else { - at::native::_chunk_cat_out(tensors, dim, num_chunks, out); - } - return out; -} - -} // namespace at::native diff --git a/aten/src/ATen/native/zoom/TensorTransformations.cu b/aten/src/ATen/native/zoom/TensorTransformations.cu deleted file mode 100644 index fd84d2cb79a1bc..00000000000000 --- a/aten/src/ATen/native/zoom/TensorTransformations.cu +++ /dev/null @@ -1,154 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! 
-#include -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include - -#include -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - -#include -#include - -namespace at::native { - -template -C10_LAUNCH_BOUNDS_2(zoom::getApplyBlockSize(), zoom::getApplyBlocksPerSM()) -__global__ void kernel_pointwise_flip_apply2( - const zoom::detail::TensorInfo in_tensor_info, - zoom::detail::TensorInfo out_tensor_info, - IndexType N, - int flip_dim, - IndexType total_dims) { - for (IndexType linear_index = blockIdx.x * blockDim.x + threadIdx.x; linear_index < N; linear_index += gridDim.x * blockDim.x) { - IndexType dst_offset = 0; - if (flip_dim == 0) { - // flip 1st dim - dst_offset = (in_tensor_info.sizes[0] - 1 - linear_index / in_tensor_info.strides[0]) * in_tensor_info.strides[0] + linear_index % in_tensor_info.strides[0]; - } - else { - // flip last dim - IndexType i = total_dims - 1; - dst_offset = linear_index / in_tensor_info.strides[0] * in_tensor_info.strides[0] + (in_tensor_info.sizes[i] - 1 - linear_index % in_tensor_info.strides[0]); - } - out_tensor_info.data[dst_offset] = in_tensor_info.data[linear_index]; - } -} - -template -C10_LAUNCH_BOUNDS_1(zoom::getApplyBlockSize()) -__global__ void flip_zoom_kernel( - scalar_t* in_tensor, - scalar_t* out_tensor, - int64_t N, - int64_t* flip_dims, - int64_t flip_dims_size, - int64_t* strides, - int64_t* strides_contiguous, - int64_t* shape, - int64_t total_dims) { - int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; - if (linear_index >= N) { - return; - } - - int64_t cur_indices = linear_index, rem = 0, dst_offset = 0; - for (int64_t i = 0; i < total_dims; i++) { - int64_t temp = cur_indices; - cur_indices = cur_indices / strides_contiguous[i]; - rem = temp - cur_indices * strides_contiguous[i]; - // flip the indices if it is in flip_dims - for (int64_t j = 0; j < flip_dims_size; j++) { - if (i == flip_dims[j]) { - cur_indices = shape[i] - 1 - cur_indices; - } - } - dst_offset += cur_indices * strides[i]; - cur_indices = rem; - } - out_tensor[linear_index] = in_tensor[dst_offset]; -} - -template -C10_LAUNCH_BOUNDS_1(zoom::getApplyBlockSize()) -__global__ void roll_zoom_kernel( - const scalar_t* in_tensor, - scalar_t* out_tensor, - int64_t N, - int64_t roll_dim, - int64_t start, - int64_t size, - int64_t stride, - int64_t total_dims) { - int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; - if (linear_index >= N) { - return; - } - // roll dim idx is the index of linear_index along the rolling dimension. - int64_t roll_dim_idx = linear_index % (stride * size) / stride; - // index into the source data to find appropriate value. 
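// Editorial sketch, not part of the patch: a host-side restatement of the
// source-index computation that follows in the (deleted) roll kernel, where
// start = (size - shifts) mod size along the rolled dimension.
#include <cstdint>

int64_t roll_source_index(int64_t linear_index, int64_t start, int64_t size,
                          int64_t stride) {
  const int64_t roll_dim_idx = linear_index % (stride * size) / stride;
  return roll_dim_idx >= (size - start)
      ? linear_index - (size - start) * stride  // wrapped past the end
      : linear_index + start * stride;          // shifted forward by `start`
}
// Example: size = 4, stride = 1, shifts = 1, so start = (4 - 1) % 4 = 3;
// out[0] reads in[3], out[1] reads in[0], out[2] reads in[1], out[3] reads in[2].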
- int64_t source_idx = 0; - if( roll_dim_idx >= (size - start) ) { - source_idx = linear_index - ((size - start) * stride); - } else { - source_idx = linear_index + (start * stride); - } - out_tensor[linear_index] = in_tensor[source_idx]; -} - -// Roll a tensor along a dimension -Tensor roll_zoom(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { - if (dims.size() != 1 || shifts.size() != 1) { - return roll_common(self, shifts, dims); - } - - auto in_tensor = self; - if(!self.is_contiguous()) { - in_tensor = self.contiguous(); - } - auto out_tensor = at::empty_like(in_tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - if (out_tensor.numel() == 0) { - return out_tensor; - } - const int64_t N = in_tensor.numel(); - const int64_t dim = dims[0]; - const int64_t size = in_tensor.size(dim); - int64_t start = (size - shifts[0]) % size; - // Behavior of % is different in C++ vs Python for negative numbers. This - // corrects the difference. - if( start < 0 ) start = start + size; - - dim3 dim_block = zoom::getApplyBlock(); - dim3 dim_grid; - TORCH_CHECK(zoom::getApplyGrid(N, dim_grid, in_tensor.get_device()), "unable to get dim grid"); - - auto total_dims = in_tensor.dim(); - - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( - at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, - at::ScalarType::ComplexHalf, - in_tensor.scalar_type(), "roll_zoom", - [&] { - hipLaunchKernelGGL(( roll_zoom_kernel), dim3(dim_grid), dim3(dim_block), 0, c10::zoom::getCurrentZoomStream(), - in_tensor.const_data_ptr(), out_tensor.mutable_data_ptr(), N, - dim, start, - size, - in_tensor.stride(dim), - total_dims); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - }); - - return out_tensor; -} - -} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ZoomScalar.cu b/aten/src/ATen/native/zoom/ZoomScalar.cu new file mode 100644 index 00000000000000..370c8a28b3ebed --- /dev/null +++ b/aten/src/ATen/native/zoom/ZoomScalar.cu @@ -0,0 +1,38 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include + +namespace at::native { + +Scalar _local_scalar_dense_zoom(const Tensor& self) { + Scalar r; + AT_DISPATCH_V2( + self.scalar_type(), "_local_scalar_dense_zoom", AT_WRAP([&] { + // Create pinned memory for the scalar value to avoid implicit + // locking/sync in cuda library due to pageable memory + auto value = at::detail::empty_cpu( + {1}, /* size */ + c10::CppTypeToScalarType(), /* dtype */ + std::nullopt, /* layout */ + std::nullopt, /* device */ + true, /* pin_memory */ + std::nullopt /* memory format */ + ); + hipStream_t stream = c10::zoom::getCurrentZoomStream(); + c10::zoom::memcpy_and_sync((void *)value.const_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), hipMemcpyDeviceToHost, stream); + r = Scalar(*value.const_data_ptr()); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + return r; +} + +} // at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/reduction_template.cuh b/aten/src/ATen/native/zoom/reduction_template.cuh new file mode 100644 index 00000000000000..f868c450614b38 --- /dev/null +++ b/aten/src/ATen/native/zoom/reduction_template.cuh @@ -0,0 +1,680 @@ +namespace at { +namespace zoom { +//windows doesn't like large string literals, so split in two +const std::string reduction_template_0 = R"ESCAPE( + #define C10_HOST_DEVICE __host__ __device__ + #define C10_DEVICE __device__ + #if 
defined(__clang__) && defined(__HIP__) + #ifndef __forceinline__ + #define __forceinline__ inline __attribute__((always_inline)) + #endif + // until ROCm support for kernel asserts is restored + #define assert(expr) (static_cast(0)) + #endif + + template + __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + #if defined(__clang__) && defined(__HIP__) + return __shfl_down(value, delta, width); + #else + return __shfl_down_sync(mask, value, delta, width); + #endif + } + + + #if ${complex} + template + __device__ __forceinline__ std::complex WARP_SHFL_DOWN(std::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + return std::complex( + #if defined(__clang__) && defined(__HIP__) + __shfl_down(value.real(), delta, width), + __shfl_down(value.imag(), delta, width)); + #else + __shfl_down_sync(mask, value.real(), delta, width), + __shfl_down_sync(mask, value.imag(), delta, width)); + #endif + } + #endif + + // aligned vector generates vectorized load/store on CUDA + template + struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; + }; + + + C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. + size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; + } + + + + + struct ReduceConfig { + //has to match host-side ReduceConfig in the eager code + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + static constexpr int input_vec_size = 4; + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx < num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + 
offset = threadIdx.x + offset * blockDim.x; + } + return offset; + } + + + }; + + +//TODO this will need to be different for more generic reduction functions +namespace reducer { + + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + + inline __device__ ${functor} + + inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + + inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + // wrap a normal reduction that ignores the index + inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) { + return combine(acc, val); + } +} + + +struct ReduceJitOp { + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + using InputCalculator = OffsetCalculator<1>; + using OutputCalculator = OffsetCalculator<2>; + +// static constexpr bool can_accumulate_in_output = +// std::is_convertible::value +// && std::is_convertible::value; + + static constexpr int input_vec_size = ReduceConfig::input_vec_size; + + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + uint32_t output_idx = config.output_idx<${output_vec_size}>(); + uint32_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = Array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + + value = thread_reduce<${output_vec_size}>(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce<${output_vec_size}>(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce<${output_vec_size}>(value, shared_memory); + } + + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = global_reduce<${output_vec_size}>(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output<${output_vec_size}>(out, value); + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + #pragma 
unroll + for (int i = 0; i < ${output_vec_size}; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE Array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + assert(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. + return {input_vectorized_thread_reduce_impl(data)}; + } else { + uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](uint32_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](uint32_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + uint32_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = aligned_vector; + + uint32_t idx = config.input_idx(); + const uint32_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. 
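// Editorial sketch, not part of the patch or of the template string below: the
// "multiple accumulators" idiom named in the comment above, shown on a plain CPU
// sum. Independent partial sums break the loop-carried dependency on a single
// accumulator, which is what the unrolled value_list below relies on.
#include <cstddef>

float sum_with_four_accumulators(const float* data, std::size_t n) {
  float acc[4] = {0.f, 0.f, 0.f, 0.f};
  std::size_t i = 0;
  for (; i + 4 <= n; i += 4) {
    acc[0] += data[i + 0];  // the four updates do not depend on each other,
    acc[1] += data[i + 1];  // so they can overlap in the pipeline
    acc[2] += data[i + 2];
    acc[3] += data[i + 3];
  }
  for (; i < n; ++i) acc[0] += data[i];  // scalar tail
  return (acc[0] + acc[1]) + (acc[2] + acc[3]);
}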
+ arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + scalar_t values[input_vec_size]; + + load_t *values_vector = reinterpret_cast(&values[0]); + + while (idx * input_vec_size + input_vec_size - 1 < end) { + *values_vector = reinterpret_cast(data)[idx]; + #pragma unroll + for (uint32_t i = 0; i < input_vec_size; i++) { + value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + uint32_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = reducer::combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE Array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + uint32_t idx = config.input_idx(); + const uint32_t end = config.num_inputs; + const uint32_t stride = config.step_input; + const int vt0=${vt0}; + + using arg_vec_t = Array; + using load_t = aligned_vector; + const load_t* data = reinterpret_cast(data_); + + // Multiple accumulators to remove dependency between unrolled loops. + arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + values[i] = data[calc(idx + i * stride) / output_vec_size]; + } + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + values[i] = data[calc(idx) / output_vec_size]; + idx += stride; + } + idx = idx_; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + template + C10_DEVICE Array block_x_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[address_base] = value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + for (int offset = 1; offset < dim_x; offset <<= 1) { + #pragma unroll + for (int i 
= 0; i < output_vec_size; i++) { + arg_t other = reducer::warp_shfl_down(value[i], offset); + value[i] = reducer::combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE Array block_y_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + )ESCAPE"; + + const std::string reduction_template_1 = R"ESCAPE( + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE Array accumulate_in_output( + Array out, + Array value + ) const { + Array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + ret[i] = reducer::combine(*(out[i]), value[i]); + } + return ret; + } + + + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value + ) const { + assert(!final_output); + return (out_scalar_t)value; + } + + template + C10_DEVICE void set_results(const T x, const uint32_t base_offset) const { + assert(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + +//TODO - multi-output reduction - we won't be able to use thrust::pair +//just explicitly specify typed output reads/writes +//Currently implemented for max of two outputs +// template +// C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { +// if (noutputs >= 1) { +// auto res0 = (T1*)((char*)dst[0] + base_offset); +// *res0 = x.first; +// } +// if (noutputs >= 2) { +// // base offset is computed assuming element size being sizeof(T1), so we need to make a +// // correction to obtain the correct base offset +// auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); +// *res1 = x.second; +// } +// } + + template + C10_DEVICE void set_results_to_output(Array value, Array base_offset) const { + assert(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(reducer::project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE Array global_reduce(Array value, Array *acc, char* shared_memory) const { + using arg_vec_t = Array; + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + uint32_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + uint32_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote 
to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + value = ident; + if (config.should_block_x_reduce()) { + uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + uint32_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } else { + uint32_t input_offset = threadIdx.y; + uint32_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +extern "C" +__launch_bounds__(${max_threads_lb}, 4) +__global__ void reduction_${name}_kernel(ReduceJitOp r){ + r.run(); +} +)ESCAPE"; + +const std::string reduction_template = reduction_template_0 + reduction_template_1; + + +const std::string &get_reduction_template() { + return reduction_template; +} + +}} \ No newline at end of file diff --git a/aten/src/ATen/templates/UfuncZoom.cu b/aten/src/ATen/templates/UfuncZoom.cu new file mode 100644 index 00000000000000..689a78b42f9102 --- /dev/null +++ b/aten/src/ATen/templates/UfuncZoom.cu @@ -0,0 +1,17 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +${zoom_headers} +namespace at { +// NB: this is explicitly copied here (via codegen) rather than +// included via NativeFunctions.h to avoid recompiling this file when +// NativeFunctions.h changes +namespace meta { +${meta_declaration} +} +namespace native { +${native_declaration} +${native_definitions} +}} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/zoom/ZoomContext.cpp b/aten/src/ATen/zoom/ZoomContext.cpp index 3182fafed7493f..30bb6d79d53e0f 100644 --- a/aten/src/ATen/zoom/ZoomContext.cpp +++ b/aten/src/ATen/zoom/ZoomContext.cpp @@ -2,7 +2,6 @@ #include #include -// #include #include #include #include diff --git a/aten/src/ATen/zoom/ZoomContextLight.h b/aten/src/ATen/zoom/ZoomContextLight.h index 44a82879f05267..93ad2791cd4a85 100644 --- a/aten/src/ATen/zoom/ZoomContextLight.h +++ b/aten/src/ATen/zoom/ZoomContextLight.h @@ -1,21 +1,10 @@ #pragma once // Light-weight version of ZoomContext.h with fewer transitive includes -#define DISABLE_HIPBLASLT - #include - #include #include #include -#include 
-#include -#include -#ifndef DISABLE_HIPBLASLT -#include -#include -#endif - namespace c10 { struct Allocator; } @@ -23,24 +12,24 @@ struct Allocator; namespace at::zoom { /* -A common CUDA interface for ATen. +A common Zoom interface for ATen. -This interface is distinct from CUDAHooks, which defines an interface that links -to both CPU-only and CUDA builds. That interface is intended for runtime +This interface is distinct from ZoomHooks, which defines an interface that links +to both CPU-only and Zoom builds. That interface is intended for runtime dispatch and should be used from files that are included in both CPU-only and -CUDA builds. +Zoom builds. -CUDAContext, on the other hand, should be preferred by files only included in -CUDA builds. It is intended to expose CUDA functionality in a consistent +ZoomContext, on the other hand, should be preferred by files only included in +Zoom builds. It is intended to expose Zoom functionality in a consistent manner. -This means there is some overlap between the CUDAContext and CUDAHooks, but -the choice of which to use is simple: use CUDAContext when in a CUDA-only file, -use CUDAHooks otherwise. +This means there is some overlap between the ZoomContext and ZoomHooks, but +the choice of which to use is simple: use ZoomContext when in a Zoom-only file, +use ZoomHooks otherwise. -Note that CUDAContext simply defines an interface with no associated class. +Note that ZoomContext simply defines an interface with no associated class. It is expected that the modules whose functions compose this interface will -manage their own state. There is only a single CUDA context/state. +manage their own state. There is only a single Zoom context/state. */ /** @@ -51,9 +40,9 @@ inline int64_t getNumGPUs() { } /** - * CUDA is available if we compiled with CUDA, and there are one or more - * devices. If we compiled with CUDA but there is a driver problem, etc., - * this function will report CUDA is not available (rather than raise an error.) + * Zoom is available if we compiled with Zoom, and there are one or more + * devices. If we compiled with Zoom but there is a driver problem, etc., + * this function will report Zoom is not available (rather than raise an error.) 
*/ inline bool is_available() { return c10::zoom::device_count() > 0; @@ -71,15 +60,4 @@ TORCH_ZOOM_API bool canDeviceAccessPeer( TORCH_ZOOM_API c10::Allocator* getZoomDeviceAllocator(); -TORCH_ZOOM_API hipsparseHandle_t getCurrentHIPSparseHandle(); -TORCH_ZOOM_API hipblasHandle_t getCurrentHIPBlasHandle(); -#ifndef DISABLE_HIPBLASLT -TORCH_ZOOM_API hipblasLtHandle_t getCurrentHIPBlasLtHandle(); -#endif - - -#if defined(hipsolverVersionMajor) -TORCH_ZOOM_API hipsolverDnHandle_t getCurrentHIPSolverDnHandle(); -#endif - } // namespace at::zoom \ No newline at end of file diff --git a/aten/src/ATen/zoom/detail/ZoomHooks.cpp b/aten/src/ATen/zoom/detail/ZoomHooks.cpp index 828ef6993c45b7..51ba8ae7be3f7d 100644 --- a/aten/src/ATen/zoom/detail/ZoomHooks.cpp +++ b/aten/src/ATen/zoom/detail/ZoomHooks.cpp @@ -3,31 +3,17 @@ #include #include #include -// #include #include #include #include #include #include #include -// #include #include #include #include #include -// #if AT_CUDNN_ENABLED() -// #include -// #endif - -// #if AT_MAGMA_ENABLED() -// #include -// #endif - -// #if defined(USE_ROCM) -// #include -// #endif - #include #include #include @@ -39,22 +25,11 @@ namespace c10::zoom::_internal { void setHasPrimaryContext(bool (*func)(DeviceIndex)); } -// defined in Aten/zoom/HIPblasHandlePool.cpp -namespace at::zoom { - bool getHIPBlasAtomicsEnabled(); -} - namespace at::zoom::detail { const at::zoom::HIPRTC& hiprtc(); DeviceIndex current_device(); -// static void (*magma_init_fn)() = nullptr; - -// void set_magma_init_fn(void (*fn)()) { -// magma_init_fn = fn; -// } - namespace { bool _hasPrimaryContext(DeviceIndex device_index) { TORCH_CHECK(device_index >= 0 && device_index < c10::zoom::device_count(), @@ -149,13 +124,6 @@ bool ZoomHooks::hasROCM() const { return at::zoom::is_available(); } -// rocBLAS is deterministic if atomic operations are disabled -// for details on when rocBLAS is guaranteed to be bitwise deterministic see below: -// https://github.com/ROCm/rocBLAS/issues/1459#issuecomment-2272082035 -bool ZoomHooks::checkHIPBlasDeterministic() const { - return !at::zoom::getHIPBlasAtomicsEnabled(); -} - // #if defined(USE_DIRECT_NVRTC) || defined(USE_DIRECT_HIPRTC) static std::pair, at::zoom::HIPRTC*> load_hiprtc() { return std::make_pair(nullptr, at::zoom::load_hiprtc()); diff --git a/aten/src/ATen/zoom/detail/ZoomHooks.h b/aten/src/ATen/zoom/detail/ZoomHooks.h index 51cabb8bde377f..d5d813c9dbb87a 100644 --- a/aten/src/ATen/zoom/detail/ZoomHooks.h +++ b/aten/src/ATen/zoom/detail/ZoomHooks.h @@ -20,7 +20,6 @@ struct ZoomHooks : public ZoomHooksInterface { bool isPinnedPtr(const void* data) const override; const Generator& getDefaultZoomGenerator(DeviceIndex device_index = -1) const override; bool hasROCM() const override; - bool checkHIPBlasDeterministic() const override; const at::zoom::HIPRTC& hiprtc() const override; DeviceIndex current_device() const override; bool hasPrimaryContext(DeviceIndex device_index) const override; diff --git a/buckbuild.bzl b/buckbuild.bzl index 4c4fc9a89a280d..9ee843ef74aac0 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -50,6 +50,7 @@ load( "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources", + "aten_ufunc_generated_zoom_sources", ) def read_bool(section, field, default, required = True): @@ -398,6 +399,9 @@ def get_aten_generated_files(enabled_backends): # skipped src_files.extend(aten_ufunc_generated_cuda_sources()) + # TODO(Arham): redo logic once we have a zoom key and backend name + 
src_files.extend(aten_ufunc_generated_zoom_sources()) + + res = {} for file_name in src_files: res[file_name] = [file_name] diff --git a/build.bzl b/build.bzl index 5ab9f92acecca0..299307478a0fb5 100644 --- a/build.bzl +++ b/build.bzl @@ -3,6 +3,7 @@ load( "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources", + "aten_ufunc_generated_zoom_sources", ) def define_targets(rules): @@ -80,13 +81,18 @@ def define_targets(rules): aten_ufunc_generated_cuda_sources() ) + + gen_aten_outs_zoom = ( + GENERATED_H_ZOOM + GENERATED_CPP_ZOOM + + aten_ufunc_generated_zoom_sources() + ) + gen_aten_outs = ( GENERATED_H + GENERATED_H_CORE + GENERATED_CPP + GENERATED_CPP_CORE + aten_ufunc_generated_cpu_sources() + aten_ufunc_generated_cpu_kernel_sources() + [ "Declarations.yaml", - ] + gen_aten_outs_cuda + ] + gen_aten_outs_cuda + gen_aten_outs_zoom ) rules.genrule( @@ -208,6 +214,15 @@ GENERATED_CPP_CUDA = [ "RegisterQuantizedCUDA.cpp", ] +GENERATED_H_ZOOM = [ + "ZoomFunctions.h", + "ZoomFunctions_inl.h", +] + +GENERATED_CPP_ZOOM = [ + "RegisterPrivateUse1.cpp", +] + GENERATED_CPP = [ "Functions.cpp", "RegisterBackendSelect.cpp", diff --git a/build.sh b/build.sh new file mode 100644 index 00000000000000..74897f8830e56a --- /dev/null +++ b/build.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +rm -rf build +git clean -fdx -e .idea +git clean -fdX -e .idea + + +export USE_ZOOM=1 +export USE_ROCM=0 +export USE_CUDA=0 +#export USE_PER_OPERATOR_HEADERS=1 +export USE_CCACHE=1 +export BUILD_PYTHON=1 +export USE_NUMPY=1 +export USE_FLASH_ATTENTION=0 +#export BUILD_SHARED_LIBS=ON + +export BUILD_AOT_INDUCTOR_TEST=0 +#export BUILD_BINARY=0 +#export BUILD_CUSTOM_PROTOBUF=1 +export BUILD_DOCS=0 +export BUILD_EXECUTORCH=0 +export BUILD_FUNCTORCH=0 +export BUILD_JNI=0 +#export BUILD_LAZY_TS_BACKEND=1 +#export BUILD_LIBTORCH_CPU_WITH_DEBUG=0 +export BUILD_LITE_INTERPRETER=0 +export BUILD_MOBILE_AUTOGRAD=0 +export BUILD_MOBILE_BENCHMARK=0 +export BUILD_MOBILE_TEST=0 +export BUILD_ONNX_PYTHON=0 +export BUILD_STATIC_RUNTIME_BENCHMARK=0 +export BUILD_TEST=0 +export USE_ASAN=0 +export USE_C10D_GLOO=0 +export USE_C10D_MPI=0 +export USE_C10D_NCCL=0 +export USE_COLORIZE_OUTPUT=0 +export USE_COREML_DELEGATE=0 +export USE_CPP_CODE_COVERAGE=0 +export USE_CUDA=0 +export USE_CUDNN=0 +export USE_CUPTI_SO=0 +export USE_CUSPARSELT=0 +export USE_DISTRIBUTED=1 +export USE_FAKELOWP=0 +export USE_FBGEMM=0 +export USE_FLASH_ATTENTION=0 +export USE_GFLAGS=0 +export USE_GLOG=0 +export USE_GLOO=0 +export USE_GLOO_WITH_OPENSSL=0 +export USE_GNU_SOURCE=0 +export USE_GOLD_LINKER=0 +export USE_IBVERBS=0 +export USE_INTERNAL_PTHREADPOOL_IMPL=0 +export USE_ITT=0 +export USE_KINETO=0 +export USE_LIBUV=0 +export USE_LIGHTWEIGHT_DISPATCH=0 +export USE_LITE_INTERPRETER_PROFILER=0 +export USE_LITE_PROTO=0 +export USE_MAGMA=0 +export USE_MIMALLOC=0 +export USE_MKLDNN=0 +export USE_MKLDNN_CBLAS=0 +export USE_MPI=0 +export USE_NATIVE_ARCH=0 +export USE_NCCL=0 +export USE_NNAPI=0 +export USE_NNPACK=0 +export USE_NUMA=0 +export USE_NVRTC=0 +export USE_OBSERVERS=0 +export USE_OPENCL=0 +export USE_OPENMP=0 +export USE_PRECOMPILED_HEADERS=0 +export USE_PROF=0 +export USE_PTHREADPOOL=0 +export USE_PYTORCH_METAL=0 +export USE_PYTORCH_METAL_EXPORT=0 +export USE_PYTORCH_QNNPACK=0 +export USE_QNNPACK=0 +#export USE_RCCL=0 +export USE_REDIS=0 +#export USE_ROCM_KERNEL_ASSERT=0 +export USE_SANITIZER=0 +export USE_SLEEF_FOR_ARM_VEC256=0 +export USE_SNPE=0 +export USE_SOURCE_DEBUG_ON_MOBILE=0 +export USE_STATIC_CUDNN=0 +export
USE_STATIC_MKL=0 +export USE_STATIC_NCCL=0 +export USE_SYSTEM_BENCHMARK=0 +export USE_SYSTEM_CPUINFO=0 +export USE_SYSTEM_EIGEN_INSTALL=0 +export USE_SYSTEM_FP16=0 +export USE_SYSTEM_FXDIV=0 +export USE_SYSTEM_GLOO=0 +export USE_SYSTEM_GOOGLEBENCHMARK=0 +export USE_SYSTEM_GOOGLETEST=0 +export USE_SYSTEM_LIBS=0 +export USE_SYSTEM_NCCL=0 +export USE_SYSTEM_ONNX=0 +export USE_SYSTEM_PSIMD=0 +export USE_SYSTEM_PTHREADPOOL=0 +export USE_SYSTEM_PYBIND11=0 +export USE_SYSTEM_SLEEF=0 +export USE_SYSTEM_XNNPACK=0 +export USE_TBB=0 +export USE_TCP_OPENSSL_LINK=0 +export USE_TCP_OPENSSL_LOAD=0 +export USE_TENSORPIPE=1 +export USE_TSAN=0 +export USE_UCC=0 +export USE_VALGRIND=0 +export USE_VULKAN_FP16_INFERENCE=0 +export USE_VULKAN_RELAXED_PRECISION=0 +export USE_XNNPACK=0 +export USE_XPU=0 + +# for the ligerllama example we need distributed and tensorpipe, only because +# huggingface model.generate insists on querying torch.distributed and distributed relies on tensorpipe +# this could be a factor of nod-pytorch being out of date with upstream: +# https://github.com/pytorch/pytorch/issues/97397 + +python setup.py develop +python zoom_extension/examples/test.py +PYTORCH_TEST_WITH_SLOW=1 TORCH_TEST_DEVICES=zoom_extension/test/pytorch_test_base.py ./test.sh +python setup.py bdist_wheel \ No newline at end of file diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 1a43c7d53aa9fb..82d8b6b3372135 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1677,6 +1677,8 @@ if(MSVC AND BUILD_SHARED_LIBS) install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) elseif(USE_ROCM) install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) + elseif(USE_ZOOM) + install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) endif() endif() @@ -1822,16 +1824,13 @@ if(USE_ZOOM) target_link_libraries(torch_zoom PUBLIC c10_zoom) # target_link_libraries(torch_zoom PUBLIC c10) - # this is where lib amdhip64 is actually linked (e.g. HIP symbols) - # should be included in c10_zoom - # target_link_libraries(torch_zoom PUBLIC ${PYTORCH_HIP_LIBRARIES}) if(NOT INTERN_BUILD_MOBILE) # TODO: Cut this over to ATEN_HIP_FILES_GEN_LIB. At the moment, we # only generate CUDA files # NB: This dependency must be PRIVATE, because we don't install # ATEN_CUDA_FILES_GEN_LIB (it's a synthetic target just to get the # correct dependency from generated files.)
- #target_link_libraries(torch_zoom PRIVATE ATEN_ZOOM_FILES_GEN_LIB) + target_link_libraries(torch_zoom PRIVATE ATEN_ZOOM_FILES_GEN_LIB) endif() target_link_libraries(torch_zoom PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}) target_link_libraries(torch_zoom PRIVATE ${Caffe2_ZOOM_DEPENDENCY_LIBS}) diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index f022db009f4673..ce4762860bb524 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -201,6 +201,7 @@ if(INTERN_BUILD_ATEN_OPS) include("${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake") + include("${CMAKE_BINARY_DIR}/aten/src/ATen/zoom_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake") message(STATUS "${gen_type} outputs: ${gen_outputs}") @@ -210,6 +211,7 @@ if(INTERN_BUILD_ATEN_OPS) OUTPUT ${generated_${gen_type}} ${cuda_generated_${gen_type}} + ${zoom_generated_${gen_type}} ${core_generated_${gen_type}} ${cpu_vec_generated_${gen_type}} ${ops_generated_${gen_type}} @@ -218,6 +220,7 @@ if(INTERN_BUILD_ATEN_OPS) ${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake + ${CMAKE_BINARY_DIR}/aten/src/ATen/zoom_generated_${gen_type}.cmake COMMAND ${GEN_COMMAND_${gen_type}} DEPENDS ${all_python} ${${gen_type}_templates} ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml @@ -235,17 +238,25 @@ if(INTERN_BUILD_ATEN_OPS) ${generated_declarations_yaml} ${generated_unboxing_sources}) add_custom_target(ATEN_CUDA_FILES_GEN_TARGET DEPENDS ${cuda_generated_headers} ${cuda_generated_sources}) + add_custom_target(ATEN_ZOOM_FILES_GEN_TARGET DEPENDS + ${zoom_generated_headers} ${zoom_generated_sources}) add_library(ATEN_CPU_FILES_GEN_LIB INTERFACE) add_library(ATEN_CUDA_FILES_GEN_LIB INTERFACE) + add_library(ATEN_ZOOM_FILES_GEN_LIB INTERFACE) add_dependencies(ATEN_CPU_FILES_GEN_LIB ATEN_CPU_FILES_GEN_TARGET) add_dependencies(ATEN_CUDA_FILES_GEN_LIB ATEN_CUDA_FILES_GEN_TARGET) + add_dependencies(ATEN_ZOOM_FILES_GEN_LIB ATEN_ZOOM_FILES_GEN_TARGET) + message(zoom_gen_headers="${zoom_generated_headers}") + message(zoom_gen_sources="${zoom_generated_sources}") + message(cuda_gen_headers="${cuda_generated_headers}") message(cuda_gen_sources="${cuda_generated_sources}") if(USE_PER_OPERATOR_HEADERS) target_compile_definitions(ATEN_CPU_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS) target_compile_definitions(ATEN_CUDA_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS) + target_compile_definitions(ATEN_ZOOM_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS) endif() # Handle source files that need to be compiled multiple times for diff --git a/torch/csrc/zoom/Module.cpp b/torch/csrc/zoom/Module.cpp index 7a0470fad0613e..341f8484b30679 100644 --- a/torch/csrc/zoom/Module.cpp +++ b/torch/csrc/zoom/Module.cpp @@ -329,83 +329,6 @@ at::Scalar as_scalar(PyObject* arg) { return at::Scalar(THPUtils_unpackDouble(arg)); } -// Entrypoint for the callable created by torch.zoom.jiterator -// See jiterator.py for more details -// PyObject* THCPModule_zoomJiteratorCompileAndLaunchKernel( -// PyObject* _unused, -// PyObject* args) { -// HANDLE_TH_ERRORS - -// PyObject* code_string_o = nullptr; -// PyObject* kernel_name_o = nullptr; -// PyObject* return_by_ref_o = 
nullptr; -// PyObject* num_outputs_o = nullptr; -// PyObject* tensors_o = nullptr; -// PyObject* kwargs_o = nullptr; -// if (!PyArg_ParseTuple( -// args, -// "OOOOO|O", -// &code_string_o, -// &kernel_name_o, -// &return_by_ref_o, -// &num_outputs_o, -// &tensors_o, -// &kwargs_o)) { -// return nullptr; -// } - -// const std::string code_string = THPUtils_unpackString(code_string_o); -// const std::string kernel_name = THPUtils_unpackString(kernel_name_o); -// const bool return_by_ref = THPUtils_unpackBool(return_by_ref_o); -// const int num_outputs = static_cast(THPUtils_unpackLong(num_outputs_o)); - -// TORCH_CHECK( -// PyTuple_Check(tensors_o), -// "tensors argument is expected to " -// "be a tuple, but got ", -// THPUtils_typename(tensors_o)); -// Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors_o); - -// c10::SmallVector tensors; -// for (const auto i : c10::irange(num_tensors)) { -// PyObject* _tensor = PyTuple_GET_ITEM(tensors_o, i); -// TORCH_CHECK( -// THPVariable_Check(_tensor), -// i, -// " of input tensors tuple is not a Tensor"); - -// tensors.emplace_back(THPVariable_Unpack(_tensor)); -// } - -// c10::SmallVector extra_args; -// PyObject* key = nullptr; -// PyObject* value = nullptr; -// Py_ssize_t pos = 0; -// while (PyDict_Next(kwargs_o, &pos, &key, &value)) { -// extra_args.emplace_back(as_scalar(value)); -// } - -// c10::SmallVector outputs = at::zoom::CompileAndLaunchKernel( -// code_string, -// kernel_name, -// num_outputs, -// tensors, -// extra_args, -// return_by_ref); - -// if (num_outputs == 1) { -// return THPVariable_Wrap(outputs[0]); -// } else { -// PyObject* output_tuple = PyTuple_New(num_outputs); -// for (int i = 0; i < num_outputs; ++i) { -// PyTuple_SetItem(output_tuple, i, THPVariable_Wrap(outputs[i])); -// } -// return output_tuple; -// } - -// END_HANDLE_TH_ERRORS -// } - PyObject* THCPModule_zoomCachingAllocator_raw_delete( PyObject* _unused, PyObject* obj) { @@ -444,26 +367,6 @@ PyObject* THCPModule_zoomSynchronize(PyObject* _unused, PyObject* noargs) { END_HANDLE_TH_ERRORS } -// PyObject* THCPModule_zoomIPCCollect(PyObject* _unused, PyObject* noargs) { -// HANDLE_TH_ERRORS -// torch::zoomIPCCollect(); -// Py_RETURN_NONE; -// END_HANDLE_TH_ERRORS -// } - -// PyObject* THCPModule_zoomSleep(PyObject* _unused, PyObject* cycles) { -// HANDLE_TH_ERRORS -// TORCH_CHECK( -// THPUtils_checkLong(cycles), "torch.zoom._sleep(): expected 'int'"); -// int64_t unpacked_cycles = THPUtils_unpackLong(cycles); -// { -// pybind11::gil_scoped_release no_gil; -// at::zoom::sleep(unpacked_cycles); -// } -// Py_RETURN_NONE; -// END_HANDLE_TH_ERRORS -// } - // We need to ensure that as long as a thread will NEVER loose the GIL as long // as it holds the CUDA mutex. Otherwise another thread might be scheduled and // try to e.g. allocate a new tensor which will cause a deadlock. 
It's enough to @@ -929,30 +832,10 @@ static void registerZoomDeviceProperties(PyObject* module) { return stream.str(); }); - // m.def( - // "_zoom_record_memory_history_legacy", - // static_cast( - // torch::zoom::_record_memory_history)); - - // m.def( - // "_zoom_record_memory_history", - // static_cast, - // std::optional, - // const std::string&, - // size_t)>(torch::zoom::_record_memory_history)); - m.def("_zoom_isHistoryEnabled", []() { return c10::zoom::ZoomCachingAllocator::isHistoryEnabled(); }); - // m.def("_zoom_get_conv_benchmark_empty_cache", []() { - // return at::native::_cudnn_get_conv_benchmark_empty_cache(); - // }); - - // m.def("_cudnn_set_conv_benchmark_empty_cache", [](bool enable) { - // return at::native::_cudnn_set_conv_benchmark_empty_cache(enable); - // }); } // We choose to ignore certain blocks that are currently allocated @@ -1349,33 +1232,6 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { END_HANDLE_TH_ERRORS } -PyObject* THCPModule_getCurrentBlasHandle_wrap( - PyObject* self, - PyObject* noargs) { - HANDLE_TH_ERRORS - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - hipblasHandle_t handle = at::zoom::getCurrentHIPBlasHandle(); - return PyLong_FromVoidPtr(handle); - END_HANDLE_TH_ERRORS -} - - -// PyObject* THCPModule_rocm_is_backward_pass( -// PyObject* _unused, -// PyObject* noargs) { -// HANDLE_TH_ERRORS -// #if USE_ROCM -// if (at::ROCmBackwardPassGuard::is_backward_pass()) { -// Py_RETURN_TRUE; -// } else { -// Py_RETURN_FALSE; -// } -// #else -// Py_RETURN_FALSE; -// #endif -// END_HANDLE_TH_ERRORS -// } - static PyObject* THCPModule_isCurrentStreamCapturing_wrap( PyObject* self, PyObject* noargs) { @@ -1422,10 +1278,6 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_getDefaultStream_wrap, METH_O, nullptr}, - {"_zoom_getCurrentBlasHandle", - THCPModule_getCurrentBlasHandle_wrap, - METH_NOARGS, - nullptr}, {"_zoom_isCurrentStreamCapturing", THCPModule_isCurrentStreamCapturing_wrap, METH_NOARGS, @@ -1491,14 +1343,6 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_zoomGetSyncDebugMode, METH_NOARGS, nullptr}, - // {"_zoom_jiterator_compile_and_launch_kernel", - // THCPModule_zoomJiteratorCompileAndLaunchKernel, - // METH_VARARGS, - // nullptr}, - // {"_rocm_is_backward_pass", - // THCPModule_rocm_is_backward_pass, - // METH_NOARGS, - // nullptr}, {nullptr}}; PyMethodDef* THCPModule_methods() { @@ -1519,13 +1363,7 @@ void initHiprtBindings(PyObject* module); void initModule(PyObject* module) { // python::initCommMethods(module); -// // As weird as it seems, this file is also compiled for ROCm, -// // so this condition might not always be true... 
shared::initHiprtBindings(module); -// shared::initNvtxBindings(module); -// #if defined(USE_CUDNN) || defined(USE_ROCM) -// shared::initCudnnBindings(module); -// #endif registerZoomDeviceProperties(module); registerZoomPluggableAllocator(module); } diff --git a/torchgen/dest/__init__.py b/torchgen/dest/__init__.py index 0c684fc1915cb9..2c304b3188c407 100644 --- a/torchgen/dest/__init__.py +++ b/torchgen/dest/__init__.py @@ -16,4 +16,5 @@ compute_ufunc_cpu as compute_ufunc_cpu, compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel, compute_ufunc_cuda as compute_ufunc_cuda, + compute_ufunc_zoom as compute_ufunc_zoom, ) diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index fced019cc4e308..ac1f4c60d74429 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -30,6 +30,7 @@ DispatchKey, gets_generated_out_inplace_wrapper, is_cuda_dispatch_key, + is_zoom_dispatch_key, NativeFunction, NativeFunctionsGroup, SchemaKind, @@ -56,6 +57,8 @@ def gen_registration_headers( headers.append("#include ") else: headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.PrivateUse1: #TODO(Arham): remove once we have a zoom key + headers.append("#include ") elif backend_index.dispatch_key == DispatchKey.MPS: headers.append("#include ") elif per_operator_headers: @@ -81,9 +84,12 @@ def gen_empty_impl_names( DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA, + DispatchKey.PrivateUse1, # TODO (Arham) change keys DispatchKey.MPS, ): dispatch = str(backend_index.dispatch_key).lower() + if backend_index.dispatch_key == DispatchKey.PrivateUse1: + dispatch = "zoom" empty_impl = f"at::detail::empty_{dispatch}" empty_strided_impl = f"at::detail::empty_strided_{dispatch}" elif backend_index.dispatch_key in ( @@ -506,6 +512,10 @@ def generate_defn(cpp_sig: CppSignature) -> str: device_guard = ( f"globalContext().lazyInitCUDA();\n{device_guard}" ) + if is_zoom_dispatch_key(self.backend_index.dispatch_key): + device_guard = ( + f"globalContext().lazyInitPrivateUse1();\n{device_guard}" + ) else: # kernel is operating on existing tensors @@ -600,6 +610,7 @@ def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str: def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str: if self.backend_index.dispatch_key in [ DispatchKey.CUDA, + DispatchKey.PrivateUse1, # TODO (Arham): change keys DispatchKey.MPS, DispatchKey.CompositeExplicitAutogradNonFunctional, ]: @@ -631,6 +642,7 @@ def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA, + DispatchKey.PrivateUse1, # TODO (Arham): change keys DispatchKey.MPS, DispatchKey.CompositeExplicitAutogradNonFunctional, ) @@ -699,6 +711,9 @@ def gen_class( guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;" else: guard_field = "c10::cuda::OptionalCUDAGuard guard_;" + # TODO (Arham): change keys + elif self.backend_index.dispatch_key == DispatchKey.PrivateUse1: + guard_field = "c10::OptionalDeviceGuard guard_;" elif ( self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index ffc879afb6cdba..999f7489a8ff66 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -321,6 +321,39 @@ def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: }} """ +@with_native_function +def compute_ufunc_zoom(g: NativeFunctionsGroup) -> str: + # First, build the functors, 
indexing them by dtype + ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g) + # Next, build the conditionals + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.PrivateUse1)) + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunctor_sigs.items(): + dtype_cases.append( + f""" +AT_DISPATCH_CASE(at::ScalarType::{dtype}, + [&]() {{ + {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())} + }} +) +""" + ) + dtype_cases_str = "\n".join(dtype_cases) + stub_sig = StubSignature(g) + return f""" +{ufunctors} +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +{stub_sig.kernel_defn()} {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}", + {dtype_cases_str} + ); +}} +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); +{sig.defn()} {{ + {stub_sig.direct_call(sig.arguments())}; +}} +""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # diff --git a/torchgen/gen.py b/torchgen/gen.py index d715361146ea0e..057e12111f2ebe 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -69,6 +69,7 @@ DispatchKey, FRAGMENT_NAMESPACES, FunctionSchema, + is_zoom_dispatch_key, is_cuda_dispatch_key, is_generic_dispatch_key, is_ufunc_dispatch_key, @@ -194,7 +195,7 @@ def parse_native_yaml_struct( use_out_as_primary=True, external=False, # Only cuda-like devices in tree require device guards - device_guard=is_cuda_dispatch_key(k), + device_guard=is_cuda_dispatch_key(k) or is_zoom_dispatch_key(k), index=v, ) return ParsedYaml(rs, indices) @@ -1729,6 +1730,7 @@ def gen_aggregated_headers( selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], cpu_fm: FileManager, + zoom_fm: FileManager, cuda_fm: FileManager, functions_keys: Set[DispatchKey], dispatch_keys: Sequence[DispatchKey], @@ -1810,6 +1812,7 @@ def gen_aggregated_headers( for dispatch_key in dispatch_keys: fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + fm = zoom_fm if is_zoom_dispatch_key(dispatch_key) else fm if dispatch_key in functions_keys: inl_headers = f"#include " @@ -1849,6 +1852,7 @@ def gen_per_operator_headers( selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], cpu_fm: FileManager, + zoom_fm: FileManager, cuda_fm: FileManager, ops_fm: FileManager, functions_keys: Set[DispatchKey], @@ -1998,6 +2002,7 @@ def gen_per_operator_headers( ) fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + fm = zoom_fm if is_zoom_dispatch_key(dispatch_key) else fm inl_headers = f"#include " fm.write_with_template( @@ -2046,6 +2051,7 @@ def gen_headers( backend_indices: Dict[DispatchKey, BackendIndex], core_fm: FileManager, cpu_fm: FileManager, + zoom_fm: FileManager, cuda_fm: FileManager, ops_fm: FileManager, dispatch_keys: Sequence[DispatchKey], @@ -2061,6 +2067,7 @@ def gen_headers( selector=selector, backend_indices=backend_indices, cpu_fm=cpu_fm, + zoom_fm=zoom_fm, cuda_fm=cuda_fm, ops_fm=ops_fm, dispatch_keys=dispatch_keys, @@ -2076,6 +2083,7 @@ def gen_headers( selector=selector, backend_indices=backend_indices, cpu_fm=cpu_fm, + zoom_fm=zoom_fm, cuda_fm=cuda_fm, dispatch_keys=dispatch_keys, functions_keys=functions_keys, @@ -2186,6 +2194,7 @@ def gen_source_files( core_fm: FileManager, cpu_fm: FileManager, cpu_vec_fm: FileManager, + zoom_fm: FileManager, cuda_fm: FileManager, dispatch_keys: Sequence[DispatchKey], functions_keys: Set[DispatchKey], @@ -2209,6 +2218,13 @@ def gen_source_files( for dispatch_key in dispatch_keys: fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + if 
is_zoom_dispatch_key(dispatch_key): + fm = zoom_fm + extra_cuda_headers = """\ + #include + #include + #include + #include """ if per_operator_headers: @@ -2296,7 +2312,7 @@ def operator_headers() -> List[str]: "RegisterDispatchKey.cpp", lambda: { "extra_cuda_headers": extra_cuda_headers - if is_cuda_dispatch_key(dispatch_key) + if is_cuda_dispatch_key(dispatch_key) or is_zoom_dispatch_key(dispatch_key) else "", "external_backend_headers": "", "dispatch_headers": dest.gen_registration_headers( @@ -2350,6 +2366,21 @@ def operator_headers() -> List[str]: "native_definitions": dest.compute_ufunc_cuda(g), }, ) + elif dispatch_key is DispatchKey.PrivateUse1: # TODO(Arham): change keys + zoom_headers = "#include " + fm.write_with_template( + f"UfuncZoom_{name}.cu", + "UfuncZoom.cu", + lambda: { + "name": name, + "zoom_headers": zoom_headers, + "meta_declaration": compute_meta_function_declaration(g), + "native_declaration": dest.compute_native_function_declaration( + g, backend_indices[dispatch_key] + ), + "native_definitions": dest.compute_ufunc_zoom(g), + }, + ) else: raise AssertionError(f"unrecognized {dispatch_key} for ufunc") @@ -2887,6 +2918,7 @@ def main() -> None: core_fm = make_file_manager(options=options, install_dir=core_install_dir) cpu_fm = make_file_manager(options=options) cpu_vec_fm = make_file_manager(options=options) + zoom_fm = make_file_manager(options=options) cuda_fm = make_file_manager(options=options) ops_fm = make_file_manager(options=options, install_dir=ops_install_dir) aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir) @@ -2896,6 +2928,7 @@ def main() -> None: functions_keys = { DispatchKey.CPU, DispatchKey.CUDA, + DispatchKey.PrivateUse1, # TODO(Arham): change keys DispatchKey.CompositeImplicitAutograd, DispatchKey.CompositeImplicitAutogradNestedTensor, DispatchKey.CompositeExplicitAutograd, @@ -2936,6 +2969,7 @@ def main() -> None: core_fm=core_fm, cpu_fm=cpu_fm, cpu_vec_fm=cpu_vec_fm, + zoom_fm=zoom_fm, cuda_fm=cuda_fm, dispatch_keys=dispatch_keys, functions_keys=functions_keys, @@ -2957,6 +2991,7 @@ def main() -> None: backend_indices=backend_indices, core_fm=core_fm, cpu_fm=cpu_fm, + zoom_fm=zoom_fm, cuda_fm=cuda_fm, ops_fm=ops_fm, dispatch_keys=dispatch_keys, @@ -2977,6 +3012,7 @@ def main() -> None: (cpu_fm, ""), (cpu_vec_fm, "cpu_vec_"), (core_fm, "core_"), + (zoom_fm, "zoom_"), (cuda_fm, "cuda_"), (ops_fm, "ops_"), ]: diff --git a/torchgen/model.py b/torchgen/model.py index 2706f234c56b0a..40b81a69f18f87 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -259,9 +259,9 @@ def codegen_per_backend_entries() -> str: f"Missing {fk}{bc} from DispatchKey enum. 
Here is the autogenerated list we expect to have:\n\n{r}" ) - -STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU} -UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU} +# TODO(Arham): change keys +STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU, DispatchKey.PrivateUse1} +UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU, DispatchKey.PrivateUse1} # Set of supported dispatch keys dispatch_keys = [ @@ -270,6 +270,7 @@ def codegen_per_backend_entries() -> str: DispatchKey.SparseCsrCPU, DispatchKey.MkldnnCPU, DispatchKey.CUDA, + DispatchKey.PrivateUse1, # TODO(Arham): replace with zoom key DispatchKey.MPS, DispatchKey.SparseCUDA, DispatchKey.SparseCsrCUDA, @@ -314,6 +315,11 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool: DispatchKey.AutogradCUDA, } +def is_zoom_dispatch_key(dk: DispatchKey) -> bool: + return dk in { + DispatchKey.PrivateUse1 + } + # Structured kernel generation is only supported for certain key types; # otherwise use old-style diff --git a/ufunc_defs.bzl b/ufunc_defs.bzl index 4490f05be01519..f94b9e866765bb 100644 --- a/ufunc_defs.bzl +++ b/ufunc_defs.bzl @@ -23,3 +23,9 @@ def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"): "UfuncCUDA_{}.cu".format(n) for n in aten_ufunc_names ]] + +def aten_ufunc_generated_zoom_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncZoom_{}.cu".format(n) + for n in aten_ufunc_names + ]] \ No newline at end of file From 0b7cc75307b9149241a2a9c055dd1e479d69afe1 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Sat, 28 Dec 2024 00:25:57 +0000 Subject: [PATCH 04/23] add kernel deps for llama3 --- aten/src/ATen/native/SharedReduceOps.h | 15 +- aten/src/ATen/native/SoftMax.cpp | 8 +- aten/src/ATen/native/native_functions.yaml | 99 +- .../ATen/native/zoom/BinaryDivFloorKernel.cu | 83 + .../ATen/native/zoom/BinaryDivTrueKernel.cu | 61 + .../ATen/native/zoom/BinaryDivTruncKernel.cu | 53 + aten/src/ATen/native/zoom/BinaryInternal.h | 48 + aten/src/ATen/native/zoom/BinaryMulKernel.cu | 48 + aten/src/ATen/native/zoom/Bmm.cpp | 121 + aten/src/ATen/native/zoom/CompareKernels.cu | 103 + aten/src/ATen/native/zoom/Copy.cu | 29 +- aten/src/ATen/native/zoom/CumminmaxKernel.cu | 29 + aten/src/ATen/native/zoom/CumprodKernel.cu | 23 + aten/src/ATen/native/zoom/CumsumKernel.cu | 25 + aten/src/ATen/native/zoom/DeviceSqrt.cuh | 18 + aten/src/ATen/native/zoom/HIPbmm.cu | 126 + aten/src/ATen/native/zoom/Indexing.cu | 1798 ++++++++++ aten/src/ATen/native/zoom/KernelUtils.cuh | 97 + .../ATen/native/zoom/LegacyThrustHelpers.cu | 113 + .../ATen/native/zoom/LogcumsumexpKernel.cu | 124 + aten/src/ATen/native/zoom/Math.cuh | 3026 +++++++++++++++++ .../ATen/native/zoom/PersistentSoftmax.cuh | 402 +++ aten/src/ATen/native/zoom/Reduce.cuh | 1354 ++++++++ .../src/ATen/native/zoom/ReduceLogicKernel.cu | 38 + aten/src/ATen/native/zoom/ScanKernels.cpp | 115 + aten/src/ATen/native/zoom/ScanKernels.h | 18 + aten/src/ATen/native/zoom/ScanUtils.cuh | 459 +++ aten/src/ATen/native/zoom/Shape.cu | 521 +++ aten/src/ATen/native/zoom/SoftMax.cu | 1272 +++++++ aten/src/ATen/native/zoom/Sort.cpp | 128 + aten/src/ATen/native/zoom/Sort.cu | 384 +++ aten/src/ATen/native/zoom/Sort.h | 17 + aten/src/ATen/native/zoom/SortImpl.cu | 37 + aten/src/ATen/native/zoom/SortStable.cu | 286 ++ aten/src/ATen/native/zoom/SortStable.h | 19 + aten/src/ATen/native/zoom/SortUtils.cuh | 333 ++ aten/src/ATen/native/zoom/Sorting.cpp | 208 ++ aten/src/ATen/native/zoom/Sorting.cu | 282 ++ 
aten/src/ATen/native/zoom/Sorting.h | 18 + aten/src/ATen/native/zoom/SortingCommon.cuh | 188 + .../ATen/native/zoom/SortingRadixSelect.cuh | 410 +++ aten/src/ATen/native/zoom/TensorTopK.cpp | 96 + aten/src/ATen/native/zoom/TensorTopK.cu | 895 +++++ aten/src/ATen/native/zoom/TensorTopK.h | 14 + aten/src/ATen/native/zoom/TriangularOps.cu | 165 + .../native/zoom/UnaryGeometricCosKernel.cu | 58 + .../native/zoom/UnaryGeometricSinKernel.cu | 58 + aten/src/ATen/native/zoom/UnarySignKernels.cu | 121 + aten/src/ATen/native/zoom/block_reduce.cuh | 143 + torchgen/dest/ufunc.py | 2 +- 50 files changed, 14018 insertions(+), 70 deletions(-) create mode 100644 aten/src/ATen/native/zoom/BinaryDivFloorKernel.cu create mode 100644 aten/src/ATen/native/zoom/BinaryDivTrueKernel.cu create mode 100644 aten/src/ATen/native/zoom/BinaryDivTruncKernel.cu create mode 100644 aten/src/ATen/native/zoom/BinaryInternal.h create mode 100644 aten/src/ATen/native/zoom/BinaryMulKernel.cu create mode 100644 aten/src/ATen/native/zoom/Bmm.cpp create mode 100644 aten/src/ATen/native/zoom/CompareKernels.cu create mode 100644 aten/src/ATen/native/zoom/CumminmaxKernel.cu create mode 100644 aten/src/ATen/native/zoom/CumprodKernel.cu create mode 100644 aten/src/ATen/native/zoom/CumsumKernel.cu create mode 100644 aten/src/ATen/native/zoom/DeviceSqrt.cuh create mode 100644 aten/src/ATen/native/zoom/HIPbmm.cu create mode 100644 aten/src/ATen/native/zoom/Indexing.cu create mode 100644 aten/src/ATen/native/zoom/KernelUtils.cuh create mode 100644 aten/src/ATen/native/zoom/LegacyThrustHelpers.cu create mode 100644 aten/src/ATen/native/zoom/LogcumsumexpKernel.cu create mode 100644 aten/src/ATen/native/zoom/Math.cuh create mode 100644 aten/src/ATen/native/zoom/PersistentSoftmax.cuh create mode 100644 aten/src/ATen/native/zoom/Reduce.cuh create mode 100644 aten/src/ATen/native/zoom/ReduceLogicKernel.cu create mode 100644 aten/src/ATen/native/zoom/ScanKernels.cpp create mode 100644 aten/src/ATen/native/zoom/ScanKernels.h create mode 100644 aten/src/ATen/native/zoom/ScanUtils.cuh create mode 100644 aten/src/ATen/native/zoom/Shape.cu create mode 100644 aten/src/ATen/native/zoom/SoftMax.cu create mode 100644 aten/src/ATen/native/zoom/Sort.cpp create mode 100644 aten/src/ATen/native/zoom/Sort.cu create mode 100644 aten/src/ATen/native/zoom/Sort.h create mode 100644 aten/src/ATen/native/zoom/SortImpl.cu create mode 100644 aten/src/ATen/native/zoom/SortStable.cu create mode 100644 aten/src/ATen/native/zoom/SortStable.h create mode 100644 aten/src/ATen/native/zoom/SortUtils.cuh create mode 100644 aten/src/ATen/native/zoom/Sorting.cpp create mode 100644 aten/src/ATen/native/zoom/Sorting.cu create mode 100644 aten/src/ATen/native/zoom/Sorting.h create mode 100644 aten/src/ATen/native/zoom/SortingCommon.cuh create mode 100644 aten/src/ATen/native/zoom/SortingRadixSelect.cuh create mode 100644 aten/src/ATen/native/zoom/TensorTopK.cpp create mode 100644 aten/src/ATen/native/zoom/TensorTopK.cu create mode 100644 aten/src/ATen/native/zoom/TensorTopK.h create mode 100644 aten/src/ATen/native/zoom/TriangularOps.cu create mode 100644 aten/src/ATen/native/zoom/UnaryGeometricCosKernel.cu create mode 100644 aten/src/ATen/native/zoom/UnaryGeometricSinKernel.cu create mode 100644 aten/src/ATen/native/zoom/UnarySignKernels.cu create mode 100644 aten/src/ATen/native/zoom/block_reduce.cuh diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index 5b7167ee93dd29..9cdf5df112d716 100644 --- a/aten/src/ATen/native/SharedReduceOps.h 
+++ b/aten/src/ATen/native/SharedReduceOps.h @@ -11,8 +11,13 @@ #include #include #elif defined(__HIPCC__) -#include -#include + #ifdef USE_ZOOM + #include + #include + #else + #include + #include + #endif #endif #if defined(__CUDACC__) || defined(__HIPCC__) #include @@ -56,7 +61,11 @@ inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) { #include #define compat_pow c10::cuda::compat::pow #elif defined(__HIPCC__) -#include +#ifdef USE_ZOOM + #include + #else + #include + #endif #define compat_pow c10::hip::compat::pow #else #define compat_pow std::pow diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index 3188479b931f3b..fd2e8e282ad1d2 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -452,7 +452,7 @@ static Tensor softmax(const Tensor& input_, const int64_t dim_) { Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ + if ((input_.is_cuda() || input_.is_privateuseone()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ return at::_softmax(input_, dim_, true); } else { Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_; @@ -469,7 +469,7 @@ Tensor& softmax_out( std::optional dtype, Tensor& output_) { Tensor output_temp; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && + if ((input_.is_cuda() || input_.is_privateuseone()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) { if (!output_.is_contiguous()) { auto options = @@ -517,7 +517,7 @@ static Tensor log_softmax(const Tensor& input_, const int64_t dim_) { Tensor log_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ + if ((input_.is_cuda() || input_.is_privateuseone()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ return at::_log_softmax(input_, dim_, true); } else { Tensor converted = dtype.has_value()? input_.toType(dtype.value()) : input_; @@ -534,7 +534,7 @@ Tensor& log_softmax_out( std::optional dtype, Tensor& output_) { Tensor output_temp; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && + if (((input_.is_cuda() || input_.is_privateuseone())) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) { if (!output_.is_contiguous()) { auto options = diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b28fcfbfc2732e..a5876201f7e9c6 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -442,7 +442,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sgn_out + CPU, CUDA, PrivateUse1: sgn_out MPS: sgn_out_mps SparseCPU, SparseCUDA: sgn_sparse_out SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_out @@ -707,14 +707,14 @@ device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: all_out + CPU, CUDA, PrivateUse1: all_out MPS: all_out_mps - func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: all_dims_out + CPU, CUDA, PrivateUse1: all_dims_out CompositeExplicitAutograd: all_dims_out_default cpp_no_default_args: ['dim'] @@ -750,14 +750,14 @@ device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: any_out + CPU, CUDA, PrivateUse1: any_out MPS: any_out_mps - func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: any_dims_out + CPU, CUDA, PrivateUse1: any_dims_out CompositeExplicitAutograd: any_dims_out_default cpp_no_default_args: ['dim'] @@ -1259,7 +1259,7 @@ - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_not_out + CPU, CUDA, PrivateUse1: logical_not_out MPS: logical_not_out_mps tags: pointwise @@ -1352,6 +1352,7 @@ dispatch: CPU: bmm_out_cpu CUDA: bmm_out_cuda + PrivateUse1: bmm_out_zoom MPS: bmm_out_mps SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda @@ -1386,6 +1387,7 @@ dispatch: CPU: cat_out_cpu CUDA: cat_out_cuda + PrivateUse1: cat_out_zoom MPS: cat_out_mps QuantizedCPU: cat_out_quantized_cpu @@ -1797,7 +1799,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: cos_out + CPU, CUDA, PrivateUse1: cos_out MPS: cos_out_mps tags: pointwise @@ -1933,6 +1935,7 @@ dispatch: CPU: cummax_helper_cpu CUDA: cummax_helper_cuda + PrivateUse1: cummax_helper_zoom - func: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) device_check: NoCheck # TensorIterator @@ -1957,6 +1960,7 @@ dispatch: CPU: cummin_helper_cpu CUDA: cummin_helper_cuda + PrivateUse1: cummin_helper_zoom - func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor variants: function @@ -1976,7 +1980,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: cumprod_out + CPU, CUDA, PrivateUse1: cumprod_out MPS: cumprod_out_mps - func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor @@ -2008,7 +2012,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: cumsum_out + CPU, CUDA, PrivateUse1: cumsum_out MPS: cumsum_out_mps - func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor @@ -2137,7 +2141,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: div_out + CPU, CUDA, PrivateUse1: div_out MPS: div_out_mps SparseCPU, SparseCUDA: div_out_sparse_zerodim tags: pointwise @@ -2163,7 +2167,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: div_out_mode + CPU, CUDA, PrivateUse1: div_out_mode MPS: div_out_mode_mps SparseCPU, SparseCUDA: div_out_sparse_zerodim tags: pointwise @@ -2736,7 +2740,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: floor_divide + CPU, CUDA, PrivateUse1: floor_divide MPS: floor_divide_mps SparseCPU, SparseCUDA: floor_divide_sparse @@ -2744,14 +2748,14 @@ device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: floor_divide_ + CPU, CUDA, PrivateUse1: floor_divide_ MPS: floor_divide_mps_ SparseCPU, SparseCUDA: floor_divide_sparse_ - func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: floor_divide_out + CPU, CUDA, PrivateUse1: floor_divide_out MPS: floor_divide_out_mps SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim @@ -3132,7 +3136,7 @@ variants: function structured: True dispatch: - CPU, CUDA: isin_Tensor_Tensor_out + CPU, CUDA, PrivateUse1: isin_Tensor_Tensor_out MPS: isin_Tensor_Tensor_out_mps - func: isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor @@ -3143,7 +3147,7 @@ variants: function structured: True dispatch: - CPU, CUDA: isin_Tensor_Scalar_out + CPU, CUDA, PrivateUse1: isin_Tensor_Scalar_out - func: isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor variants: function @@ -3153,7 +3157,7 @@ variants: function structured: True dispatch: - CPU, CUDA: isin_Scalar_Tensor_out + CPU, CUDA, PrivateUse1: isin_Scalar_Tensor_out - func: isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor variants: function @@ -3246,6 +3250,7 @@ dispatch: CPU: kthvalue_out_cpu CUDA: kthvalue_out_cuda + PrivateUse1: kthvalue_out_zoom - func: kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method @@ -3680,6 +3685,7 @@ dispatch: CPU: log_softmax_cpu_out CUDA: log_softmax_cuda_out + PrivateUse1: log_softmax_zoom_out MPS: log_softmax_mps_out - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor @@ -3690,17 +3696,20 @@ dispatch: CPU: log_softmax_backward_cpu_out CUDA: log_softmax_backward_cuda_out + PrivateUse1: log_softmax_backward_zoom_out MPS: log_softmax_backward_mps_out - func: _logcumsumexp(Tensor self, int dim) -> Tensor dispatch: CPU: _logcumsumexp_cpu CUDA: _logcumsumexp_cuda + PrivateUse1: _logcumsumexp_zoom - func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: CPU: _logcumsumexp_out_cpu CUDA: _logcumsumexp_out_cuda + PrivateUse1: _logcumsumexp_out_zoom - func: logcumsumexp(Tensor self, int dim) -> Tensor variants: function, method @@ -3945,6 +3954,7 @@ dispatch: CPU: median_cpu CUDA: median_cuda + PrivateUse1: median_zoom MPS: median_mps autogen: median.out @@ -3957,6 +3967,7 @@ dispatch: CPU: median_out_cpu CUDA: median_out_cuda + PrivateUse1: median_out_zoom MPS: median_out_mps - func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -3969,6 +3980,7 @@ dispatch: CPU: nanmedian_cpu CUDA: nanmedian_cuda + PrivateUse1: nanmedian_zoom autogen: nanmedian.out - func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -3980,6 +3992,7 @@ dispatch: CPU: nanmedian_out_cpu CUDA: nanmedian_out_cuda + PrivateUse1: nanmedian_out_zoom - func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method @@ -4108,6 +4121,7 @@ dispatch: CPU: mm_out_cpu CUDA: mm_out_cuda + PrivateUse1: mm_out_zoom MPS: mm_out_mps SparseCPU, SparseCUDA: _sparse_mm_out SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm_out @@ -4192,7 +4206,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: mul_out + CPU, CUDA, PrivateUse1: mul_out MPS: mul_out_mps SparseCPU: mul_out_sparse_cpu SparseCUDA: mul_out_sparse_cuda @@ -4835,7 +4849,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: neg_out + CPU, CUDA, PrivateUse1: neg_out MPS: neg_out_mps SparseCPU, SparseCUDA: neg_out_sparse SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_out @@ -4898,7 +4912,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias + CPU, CUDA, PrivateUse1, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias # We don't need to support mkldnn since this is handled explicitly by the reshape operator. 
- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor @@ -5301,7 +5315,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sin_out + CPU, CUDA, PrivateUse1: sin_out MPS: sin_out_mps SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_out SparseCPU, SparseCUDA: sin_sparse_out @@ -5506,6 +5520,7 @@ dispatch: CPU: softmax_cpu_out CUDA: softmax_cuda_out + PrivateUse1: softmax_zoom_out MPS: softmax_mps_out - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor @@ -5518,6 +5533,7 @@ dispatch: CPU: softmax_backward_cpu_out CUDA: softmax_backward_cuda_out + PrivateUse1: softmax_backward_zoom_out MPS: softmax_backward_mps_out - func: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] @@ -6817,7 +6833,7 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: zero_ + CPU, CUDA, PrivateUse1: zero_ MPS: zero_mps_ Meta: zero_meta_ SparseCPU, SparseCUDA, SparseMeta: zero_sparse_ @@ -6831,7 +6847,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sub_out + CPU, CUDA, PrivateUse1: sub_out MPS: sub_out_mps SparseCPU, SparseCUDA: sub_out_sparse tags: pointwise @@ -7943,6 +7959,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + PrivateUse1: masked_fill__zoom QuantizedCPU: masked_fill__quantized_cpu QuantizedCUDA: masked_fill__quantized_cuda MPS: masked_fill__mps @@ -7962,6 +7979,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + PrivateUse1: masked_fill__zoom QuantizedCPU: masked_fill__quantized_cpu QuantizedCUDA: masked_fill__quantized_cuda MPS: masked_fill__mps @@ -7993,12 +8011,14 @@ - func: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor dispatch: CUDA: masked_softmax_cuda + PrivateUse1: masked_softmax_zoom CPU: masked_softmax_cpu autogen: _masked_softmax.out - func: _masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor dispatch: CUDA: masked_softmax_backward_cuda + PrivateUse1: masked_softmax_backward_zoom CPU: masked_softmax_backward_cpu autogen: _masked_softmax_backward.out @@ -8044,6 +8064,7 @@ dispatch: CPU: index_add_cpu_out CUDA: index_add_cuda_out + PrivateUse1: index_add_zoom_out MPS: index_add_mps_out - func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) @@ -8065,6 +8086,7 @@ dispatch: CPU: index_reduce_cpu_out CUDA: index_reduce_cuda_out + PrivateUse1: index_reduce_zoom_out - func: index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!) 
structured_delegate: index_reduce.out @@ -8725,6 +8747,7 @@ dispatch: CPU: triu_cpu CUDA: triu_cuda + PrivateUse1: triu_zoom MPS: triu_mps_out - func: triu(Tensor self, int diagonal=0) -> Tensor @@ -8736,6 +8759,7 @@ dispatch: CPU: tril_cpu CUDA: tril_cuda + PrivateUse1: tril_zoom MPS: tril_mps_out - func: tril(Tensor self, int diagonal=0) -> Tensor @@ -8759,6 +8783,7 @@ dispatch: CPU: trace_cpu CUDA: trace_cuda + PrivateUse1: trace_zoom MPS: trace_mps autogen: trace.out @@ -8874,7 +8899,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ge_Scalar_out + CPU, CUDA, PrivateUse1: ge_Scalar_out MPS: ge_scalar_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -8893,7 +8918,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ge_Tensor_out + CPU, CUDA, PrivateUse1: ge_Tensor_out MPS: ge_tensor_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -8938,7 +8963,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: le_Scalar_out + CPU, CUDA, PrivateUse1: le_Scalar_out MPS: le_scalar_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -8956,7 +8981,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: le_Tensor_out + CPU, CUDA, PrivateUse1: le_Tensor_out MPS: le_tensor_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -9001,7 +9026,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: gt_Scalar_out + CPU, CUDA, PrivateUse1: gt_Scalar_out MPS: gt_scalar_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9020,7 +9045,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: gt_Tensor_out + CPU, CUDA, PrivateUse1: gt_Tensor_out MPS: gt_tensor_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9065,7 +9090,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: lt_Scalar_out + CPU, CUDA, PrivateUse1: lt_Scalar_out MPS: lt_scalar_out_mps QuantizedCPU: lt_out_quantized_cpu tags: pointwise @@ -9083,7 +9108,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: lt_Tensor_out + CPU, CUDA, PrivateUse1: lt_Tensor_out MPS: lt_tensor_out_mps QuantizedCPU: lt_out_quantized_cpu tags: pointwise @@ -9141,6 +9166,7 @@ dispatch: CPU, QuantizedCPU: index_select_out_cpu_ CUDA, QuantizedCUDA: index_select_out_cuda + PrivateUse1: index_select_out_zoom MPS: index_select_out_mps - func: index_select(Tensor self, int dim, Tensor index) -> Tensor @@ -9150,6 +9176,7 @@ QuantizedCPU: index_select_quantized_cpu_ CUDA: index_select_cuda QuantizedCUDA: index_select_quantized_cuda + PrivateUse1: index_select_zoom SparseCPU: index_select_sparse_cpu SparseCUDA: index_select_sparse_cuda MPS: index_select_mps @@ -9574,7 +9601,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sign_out + CPU, CUDA, PrivateUse1: sign_out MPS: sign_out_mps SparseCPU, SparseCUDA: sign_sparse_out SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr_out @@ -9594,6 +9621,7 @@ dispatch: CPU: signbit_out CUDA: signbit_out + PrivateUse1: signbit_out MPS: signbit_out_mps SparseCPU, SparseCUDA: signbit_sparse_out SparseCsrCPU, SparseCsrCUDA: signbit_sparse_csr_out @@ -10009,7 +10037,7 @@ - func: sort.values_stable(Tensor self, *, bool? 
stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) structured: True dispatch: - CPU, CUDA: sort_stable_out + CPU, CUDA, PrivateUse1: sort_stable_out MPS: sort_stable_out_mps - func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) @@ -10059,6 +10087,7 @@ dispatch: CPU: topk_out_cpu CUDA: topk_out_cuda + PrivateUse1: topk_out_zoom MPS: topk_out_mps - func: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) @@ -10077,7 +10106,7 @@ device_check: NoCheck structured: True dispatch: - CPU, CUDA: all_all_out + CPU, CUDA, PrivateUse1: all_all_out MPS: all_all_out_mps - func: any(Tensor self) -> Tensor @@ -10092,7 +10121,7 @@ device_check: NoCheck structured: True dispatch: - CPU, CUDA: any_all_out + CPU, CUDA, PrivateUse1: any_all_out MPS: any_all_out_mps - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) diff --git a/aten/src/ATen/native/zoom/BinaryDivFloorKernel.cu b/aten/src/ATen/native/zoom/BinaryDivFloorKernel.cu new file mode 100644 index 00000000000000..7ad48ce8c7cb1e --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryDivFloorKernel.cu @@ -0,0 +1,83 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { +namespace binary_internal { + +void div_floor_kernel_zoom(TensorIteratorBase& iter) { + // See NOTE: [Floor Division in Python] + const auto dtype = iter.common_dtype(); + if (dtype == kByte) { + // In the special case of unsigned integer division, floor division is + // equivalent to truncation division (since the signs of the divisor and + // dividend are always the same) + return div_trunc_kernel_zoom(iter); + } else if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_zoom", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return c10::div_floor_integer(a, b); + }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. 
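+ // Worked example of the reciprocal floor-division path below: with a = -7 and a
+ // CPU-scalar b = 2, inv_b = 0.5 and fmod(-7, 2) = -1, so div = (-7 - (-1)) * 0.5 = -3;
+ // the remainder and b have opposite signs, so div is corrected to -4, which matches
+ // Python's floor division -7 // 2 == -4 (plain truncation would give -3).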
+ AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_zoom", [&]() { + using accscalar_t = at::acc_type; + auto b = iter.scalar_value(2); + if (C10_UNLIKELY(b == 0)) { + return div_true_kernel_zoom(iter); + } + + auto inv_b = accscalar_t(1.0) / b; + iter.remove_operand(2); + gpu_kernel(iter, [b, inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + auto mod = std::fmod(a, b); + auto div = (a - mod) * inv_b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = c10::hip::compat::copysign(scalar_t(0), a * inv_b); + } + return floordiv; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_zoom", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return c10::div_floor_floating(a, b); + }); + }); + } +} +} // namespace binary_internal + +REGISTER_PRIVATEUSE1_DISPATCH(div_floor_stub, &binary_internal::div_floor_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryDivTrueKernel.cu b/aten/src/ATen/native/zoom/BinaryDivTrueKernel.cu new file mode 100644 index 00000000000000..09b92154633f61 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryDivTrueKernel.cu @@ -0,0 +1,61 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { +namespace binary_internal { + +CONSTEXPR_EXCEPT_WIN_CUDA char div_name[] = "div_kernel"; +void div_true_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (iter.common_dtype() == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto div_string = jiterator_stringify( + template T div_kernel(T a, T b) { return a / b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, div_string); +#else + using opmath_t = at::opmath_type; + opmath_gpu_kernel_with_scalars(iter, DivFunctor()); +#endif + return; + } + if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. 
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_zoom", [&]() { + using opmath_t = at::opmath_type; + auto inv_b = opmath_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel( + iter, + BUnaryFunctor>( + MulFunctor(), inv_b)); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_zoom", [&]() { + DivFunctor f; + gpu_kernel_with_scalars(iter, f); + }); + } +} +} // namespace binary_internal + +REGISTER_PRIVATEUSE1_DISPATCH(div_true_stub, &binary_internal::div_true_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryDivTruncKernel.cu b/aten/src/ATen/native/zoom/BinaryDivTruncKernel.cu new file mode 100644 index 00000000000000..bc1f9a851ae327 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryDivTruncKernel.cu @@ -0,0 +1,53 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { +namespace binary_internal { + +void div_trunc_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.common_dtype(); + if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_zoom", [&]() { + gpu_kernel_with_scalars( + iter, + [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a / b; }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_zoom", [&]() { + using accscalar_t = at::acc_type; + auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel(iter, [inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + return std::trunc(a * inv_b); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_zoom", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return std::trunc(a / b); + }); + }); + } +} +} // namespace binary_internal + +REGISTER_PRIVATEUSE1_DISPATCH(div_trunc_stub, &binary_internal::div_trunc_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryInternal.h b/aten/src/ATen/native/zoom/BinaryInternal.h new file mode 100644 index 00000000000000..a42408c5207fa1 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryInternal.h @@ -0,0 +1,48 @@ +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. 
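+// This header declares the DivFunctor/MulFunctor elementwise functors (including the
+// bool specialization of MulFunctor, which uses && instead of *) and the
+// div_true_kernel_zoom / div_trunc_kernel_zoom entry points shared by the
+// BinaryDiv*Kernel.cu and BinaryMulKernel.cu translation units.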
+#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_zoom(TensorIteratorBase& iter); +void div_trunc_kernel_zoom(TensorIteratorBase& iter); +} // namespace binary_internal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/zoom/BinaryMulKernel.cu b/aten/src/ATen/native/zoom/BinaryMulKernel.cu new file mode 100644 index 00000000000000..dd42ba4d24880d --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryMulKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at::native { + +CONSTEXPR_EXCEPT_WIN_CUDA char mul_name[] = "mul_kernel"; +void mul_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (common_dtype == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto mul_string = jiterator_stringify( + template T mul_kernel(T a, T b) { return a * b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, mul_string); +#else + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); +#endif + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_zoom", [&]() { + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(mul_stub, &mul_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Bmm.cpp b/aten/src/ATen/native/zoom/Bmm.cpp new file mode 100644 index 00000000000000..f95e530655919f --- /dev/null +++ b/aten/src/ATen/native/zoom/Bmm.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + + +namespace at::native { + // Forward decl, defined in HIPbmm.cu + template + void batched_matmul(const T* A, const T* B, T* C, int M, int N, int K, int batch_size); + + const Tensor& bmm_out_hip_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2) { + // handle pathological cases + if (result.numel() == 0) { + return result; + } else if (batch1.size(2) == 0) { + return result.zero_(); + } + + c10::MaybeOwned result_ = c10::MaybeOwned::borrowed(result); + IntArrayRef result_strides = result.strides(); + IntArrayRef result_sizes = result.sizes(); + + int m = result_sizes[1]; + int n = result_sizes[2]; + int k = batch1.sizes()[2]; + int num_batches = result_->sizes()[0]; + + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "bmm_hip", [&] { + const scalar_t* batch1_ptr 
= batch1.const_data_ptr(); + const scalar_t* batch2_ptr = batch2.const_data_ptr(); + scalar_t* result_ptr = result_->mutable_data_ptr(); + + batched_matmul(batch1_ptr, batch2_ptr, result_ptr, m, n, k, num_batches); + }); + if (!result.is_same(*result_)) { + result.copy_(*result_); + } + return result; + + } + + TORCH_IMPL_FUNC(bmm_out_zoom)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) + { + NoNamesGuard guard; + bmm_out_hip_impl(result, result, batch1, batch2); + } + + Tensor& mm_out_hip_impl(Tensor& result, const Tensor& mat1, const Tensor& mat2) { + // Make sure to keep addmm_hip below in sync with this code; it + // preflights a check to try to avoid actually needing to call + // expand(). + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK( + mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() + ) + + TensorArg targs[]{{result, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2}}; + checkAllSameGPU(__func__, targs); + + IntArrayRef mat1_sizes = mat1.sizes(); + IntArrayRef mat2_sizes = mat2.sizes(); + at::ScalarType scalar_type = mat1.scalar_type(); + TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); + + // resize result tensor + at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]}); + IntArrayRef result_sizes = result.sizes(); + if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) { + return result; + } + + if (mat1.numel() == 0) { + // By definition, values in self should be ignored. nans and infs + // should not propagate + return result.zero_(); + } + + // m, n are the output dims of result; k is the reduction dim (mat1_sizes[1] == mat2_sizes[0]) + int m = mat1_sizes[0]; + int n = mat2_sizes[1]; + int k = mat1_sizes[1]; + + // TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result.is_conj()); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + scalar_type, + "mm_zoom", + [&] { + const scalar_t* mat1_ptr = mat1.const_data_ptr(); + const scalar_t* mat2_ptr = mat2.const_data_ptr(); + scalar_t* result_ptr = result.mutable_data_ptr(); + batched_matmul(mat1_ptr, mat2_ptr, result_ptr, m, n, k, 1); + }); + + return result; + } + + TORCH_IMPL_FUNC(mm_out_zoom)(const Tensor& self, const Tensor& mat2, const Tensor& result) + { + mm_out_hip_impl(const_cast(result), self, mat2); + } + +} // at::native + + diff --git a/aten/src/ATen/native/zoom/CompareKernels.cu b/aten/src/ATen/native/zoom/CompareKernels.cu new file mode 100644 index 00000000000000..7975d449d19592 --- /dev/null +++ b/aten/src/ATen/native/zoom/CompareKernels.cu @@ -0,0 +1,103 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage.
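For reference, bmm_out_zoom and mm_out_zoom above both funnel into batched_matmul with M and N as the output dimensions and K as the reduction dimension (mat1/batch1 supply the rows, mat2/batch2 the columns). A minimal host-side sketch of the same row-major indexing, standalone and illustrative only (batched_matmul_ref and the main() driver are hypothetical names, not part of the patch):

#include <cstdio>
#include <vector>

// Reference of the indexing used by the zoom batched GEMM:
// A: [batch, M, K], B: [batch, K, N], C: [batch, M, N], all row-major.
template <typename T>
void batched_matmul_ref(const T* A, const T* B, T* C,
                        int M, int N, int K, int batch_size) {
  for (int b = 0; b < batch_size; ++b)
    for (int row = 0; row < M; ++row)
      for (int col = 0; col < N; ++col) {
        float sum = 0.0f;  // accumulate in float, as the device kernel does
        for (int k = 0; k < K; ++k)
          sum += float(A[b * M * K + row * K + k]) * float(B[b * K * N + k * N + col]);
        C[b * M * N + row * N + col] = T(sum);
      }
}

int main() {
  // One batch: (2x3) x (3x2) -> (2x2).
  std::vector<float> A{1, 2, 3, 4, 5, 6}, B{1, 0, 0, 1, 1, 1}, C(4);
  batched_matmul_ref(A.data(), B.data(), C.data(), /*M=*/2, /*N=*/2, /*K=*/3, /*batch_size=*/1);
  std::printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]);  // prints: 4 5 10 11
}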
+ +namespace at::native { namespace { + +enum class OpType {GE, GT, LE, LT}; + +template +struct CompareFunctor{ + constexpr CompareFunctor(OpType op): op_(op) {}; + OpType op_; + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + if (op_ == OpType::GE) { + return a >= b; + } else if (op_ == OpType::GT) { + return a > b; + } else if (op_ == OpType::LE) { + return a <= b; + } else { //LT + return a < b; + } + } +}; + +// Reflects the comparison operator, so reflect(op)(a, b) == op(b, a) +OpType reflect(OpType x) { + switch (x) { + case OpType::GE: return OpType::LE; + case OpType::GT: return OpType::LT; + case OpType::LE: return OpType::GE; + case OpType::LT: return OpType::GT; + } + TORCH_INTERNAL_ASSERT(false, "Invalid OpType"); +} + +} // namespace (anonymous) + +template +void compare_scalar_kernel(TensorIteratorBase &iter, OpType op, scalar_t rhs) { + CompareFunctor f(op); + gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t lhs) -> bool { + return f(lhs, rhs); + }); +} + +template +void compare_kernel_impl(TensorIteratorBase &iter, OpType op) { + // If either input is a cpu scalar, perform the equivalent comparison + // where the scalar is on the right hand side. This saves us from + // generating two otherwise identical kernels with mirrored + // arguments. + if (iter.is_cpu_scalar(1)) { + const scalar_t lhs = iter.scalar_value(1); + iter.remove_operand(1); + const DeviceGuard device_guard(iter.device(1)); + compare_scalar_kernel(iter, reflect(op), lhs); + } else if (iter.is_cpu_scalar(2)) { + const scalar_t rhs = iter.scalar_value(2); + iter.remove_operand(2); + compare_scalar_kernel(iter, op, rhs); + } else { + CompareFunctor f(op); + gpu_kernel(iter, f); + } +} + +C10_NOINLINE void compare_kernel_with_scalars(TensorIteratorBase &iter, OpType op) { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "compare_zoom", [&]() { + compare_kernel_impl(iter, op); + }); } + + +void ge_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::GE); +} + +void gt_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::GT); +} + +void le_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::LE); +} + +void lt_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::LT); +} + +REGISTER_PRIVATEUSE1_DISPATCH(ge_stub, &ge_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(gt_stub, &gt_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(le_stub, &le_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(lt_stub, &lt_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Copy.cu b/aten/src/ATen/native/zoom/Copy.cu index 57436f844beedc..f1ad63e7cd7e63 --- a/aten/src/ATen/native/zoom/Copy.cu +++ b/aten/src/ATen/native/zoom/Copy.cu @@ -27,33 +27,8 @@ namespace at::native { // forward decl, defined below void direct_copy_kernel_zoom(TensorIteratorBase &iter); -// NB: Ignores the negative bit on tensors -CONSTEXPR_EXCEPT_WIN_CUDA char neg_name[] = "neg_kernel"; -void neg_kernel_zoom(TensorIteratorBase& iter) { - auto dtype = iter.dtype(); - if (at::isComplexType(dtype)) { - static const auto neg_string = jiterator_stringify( - template - T neg_kernel(T a) { - return -a; - } - ); // neg_string - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "neg_zoom", [&]() { - jitted_gpu_kernel< - /*name=*/ neg_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, neg_string); - }); - - } else { -
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "neg_zoom", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return -a; - }); - }); - } -} +// forward decl, defined in UnarySignKernels.cu +void neg_kernel_zoom(TensorIteratorBase& iter); // NB: Ignores the negative bit on tensors CONSTEXPR_EXCEPT_WIN_CUDA char conj_name[] = "conj_kernel"; diff --git a/aten/src/ATen/native/zoom/CumminmaxKernel.cu b/aten/src/ATen/native/zoom/CumminmaxKernel.cu new file mode 100644 index 00000000000000..5c3e3a6aa211f4 --- /dev/null +++ b/aten/src/ATen/native/zoom/CumminmaxKernel.cu @@ -0,0 +1,29 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +#include +#include + +#include +#include + +namespace at::native { + +void launch_cummax_zoom_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, + self.scalar_type(), "cummax_zoom", [&]() { + scalar_t init = self.is_floating_point() ? (-1*std::numeric_limits::infinity()) : std::numeric_limits::lowest(); + scan_dim_with_indices(self, values, indices, dim, init, std::greater_equal()); + }); +} + +void launch_cummin_zoom_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, + self.scalar_type(), "cummin_zoom", [&]() { + scalar_t init = self.is_floating_point() ? std::numeric_limits::infinity() : std::numeric_limits::max(); + scan_dim_with_indices(self, values, indices, dim, init, std::less_equal()); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CumprodKernel.cu b/aten/src/ATen/native/zoom/CumprodKernel.cu new file mode 100644 index 00000000000000..eaa48e306d4799 --- /dev/null +++ b/aten/src/ATen/native/zoom/CumprodKernel.cu @@ -0,0 +1,23 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +#include +#include + +namespace at::native { + +void launch_cumprod_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "cumprod_zoom", [&]() { + scalar_t init = 1; + scan_dim( + self, + result, + dim, + init, + std::multiplies()); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CumsumKernel.cu b/aten/src/ATen/native/zoom/CumsumKernel.cu new file mode 100644 index 00000000000000..41808fb8fae8ae --- /dev/null +++ b/aten/src/ATen/native/zoom/CumsumKernel.cu @@ -0,0 +1,25 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +#include +#include + +namespace at::native { + +void launch_cumsum_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + ScalarType::Half, ScalarType::BFloat16, + self.scalar_type(), "cumsum_zoom", + [&]() { + scalar_t init = 0; + scan_dim( + self, + result, + dim, + init, + std::plus()); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DeviceSqrt.cuh b/aten/src/ATen/native/zoom/DeviceSqrt.cuh new file mode 100644 index 00000000000000..d5833a9882fd82 --- /dev/null +++ b/aten/src/ATen/native/zoom/DeviceSqrt.cuh @@ -0,0 +1,18 @@ +#pragma once + +namespace at { namespace native { +// take these out when ROCm implements std:: math functions +#include +template +static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); + 
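The CumsumKernel.cu, CumprodKernel.cu and CumminmaxKernel.cu files above all specialize the same scan-along-a-dimension primitive, differing only in the identity element and binary operator: 0 with std::plus for cumsum, 1 with std::multiplies for cumprod, and the type's lowest/largest (or -/+ infinity) value with greater_equal/less_equal for cummax/cummin. A minimal host-side sketch of that contract for one contiguous row, with scan_1d as an illustrative stand-in for ATen's scan_dim (not a real API):

#include <cstdio>
#include <functional>
#include <vector>

// Inclusive scan over one contiguous row: out[i] = op(out[i-1], in[i]),
// seeded with an identity value, mirroring how the zoom scan kernels
// specialize the generic scan (cumsum: init = 0, std::plus;
// cumprod: init = 1, std::multiplies).
template <typename T, typename Op>
void scan_1d(const std::vector<T>& in, std::vector<T>& out, T init, Op op) {
  T acc = init;
  out.resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    acc = op(acc, in[i]);
    out[i] = acc;
  }
}

int main() {
  std::vector<int> x{1, 2, 3, 4}, sum, prod;
  scan_1d(x, sum, 0, std::plus<int>());        // 1 3 6 10
  scan_1d(x, prod, 1, std::multiplies<int>()); // 1 2 6 24
  std::printf("%d %d %d %d | %d %d %d %d\n",
              sum[0], sum[1], sum[2], sum[3],
              prod[0], prod[1], prod[2], prod[3]);
}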
+template <> +__forceinline__ __device__ float device_sqrt(float val) { + return ::sqrtf(val); +} + +template <> +__forceinline__ __device__ double device_sqrt(double val) { + return ::sqrt(val); +} +}} \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/HIPbmm.cu b/aten/src/ATen/native/zoom/HIPbmm.cu new file mode 100644 index 00000000000000..84f5eb2aaf6201 --- /dev/null +++ b/aten/src/ATen/native/zoom/HIPbmm.cu @@ -0,0 +1,126 @@ +#include +#include +#include +#include + +namespace at::native { + + // Helper function to convert hip_bfloat16 to float + __device__ float bfloat16_to_float(hip_bfloat16 a) { + union { + uint32_t int32; + float float32; + } u = {uint32_t(a.data) << 16}; + return u.float32; + } + + // Helper function to convert float to hip_bfloat16 + __device__ hip_bfloat16 float_to_bfloat16(float a) { + union { + float float32; + uint32_t int32; + } u = {a}; + hip_bfloat16 b; + b.data = uint16_t(u.int32 >> 16); + return b; + } + + template + __device__ float convert_to_float(T a) { + return a; + } + + template <> + __device__ float convert_to_float(hip_bfloat16 a) { + return bfloat16_to_float(a); + } + + template <> + __device__ float convert_to_float<__half>( __half a) { + return __half2float(a); + } + + template + __device__ T convert_from_float(float a) { + return static_cast(a); + } + + template <> + __device__ hip_bfloat16 convert_from_float(float a) { + return float_to_bfloat16(a); + } + + template <> + __device__ __half convert_from_float<__half>(float a) { + return __float2half(a); + } + + + template + __global__ void batched_matmul_kernel(const T* A, const T* B, T* C, + int M, int N, int K, int batch_size) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + int batch = blockIdx.z; + + if (row < M && col < N) { + float sum = 0.0f; + for (int k = 0; k < K; ++k) { + sum += convert_to_float(A[batch * M * K + row * K + k]) * + convert_to_float(B[batch * K * N + k * N + col]); + } + C[batch * M * N + row * N + col] = convert_from_float(sum); + } + } + + template + void batched_matmul(const T* A, const T* B, T* C, + int M, int N, int K, int batch_size) { + dim3 threadsPerBlock(16, 16); + dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, + (M + threadsPerBlock.y - 1) / threadsPerBlock.y, + batch_size); + + hipLaunchKernelGGL(batched_matmul_kernel, numBlocks, threadsPerBlock, 0, 0, + A, B, C, M, N, K, batch_size); + } + + // Specialization for at::Half + template <> + void batched_matmul(const at::Half* A, const at::Half* B, at::Half* C, + int M, int N, int K, int batch_size) { + dim3 threadsPerBlock(16, 16); + dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, + (M + threadsPerBlock.y - 1) / threadsPerBlock.y, + batch_size); + + hipLaunchKernelGGL(batched_matmul_kernel<__half>, numBlocks, threadsPerBlock, 0, 0, + reinterpret_cast(A), + reinterpret_cast(B), + reinterpret_cast<__half*>(C), + M, N, K, batch_size); + } + + // Specialization for at::BFloat16 + template <> + void batched_matmul(const at::BFloat16* A, const at::BFloat16* B, at::BFloat16* C, + int M, int N, int K, int batch_size) { + dim3 threadsPerBlock(16, 16); + dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, + (M + threadsPerBlock.y - 1) / threadsPerBlock.y, + batch_size); + + hipLaunchKernelGGL(batched_matmul_kernel, numBlocks, threadsPerBlock, 0, 0, + reinterpret_cast(A), + reinterpret_cast(B), + reinterpret_cast(C), + M, N, K, batch_size); + } + + // Explicit instantiations for supported types + 
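+ // (at::Half and at::BFloat16 inputs are handled by the explicit specializations above,
+ // which reinterpret_cast to __half / hip_bfloat16 before launching the kernel; the
+ // generic template is explicitly instantiated below.)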
template void batched_matmul(const float*, const float*, float*, int, int, int, int); + template void batched_matmul(const double*, const double*, double*, int, int, int, int); + template void batched_matmul(const half*, const half*, half*, int, int, int, int); + template void batched_matmul(const hip_bfloat16*, const hip_bfloat16*, hip_bfloat16*, int, int, int, int); + +} // at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Indexing.cu b/aten/src/ATen/native/zoom/Indexing.cu new file mode 100644 index 00000000000000..6cd4d946ea9cda --- /dev/null +++ b/aten/src/ATen/native/zoom/Indexing.cu @@ -0,0 +1,1798 @@ +#include +// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#include +#include +#include +#include +#include + +#include + +#include + +namespace { +template +__global__ void indexing_backward_kernel( + const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, + int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) { +//numel is total number of flattened indices, not expanded to dimensions that are not indexed. +//stride is the cumulative size of the not-indexed last dimensions +//stride_before is the stride of the dimension immediately preceding first indexed dimension +//if indexing starts from the 0th dimension, stride_before does not matter because blockIdx.z will be 0 in this case +//outer_dim is number of elements in the first unindexed dimensions + using opmath_t = at::opmath_type; + + // Each warp is responsible for an input into the LookupTable. + // If the preceding input has the same destination index as this input, then the warp + // exits immediately. The warp also processes subsequent inputs with the + // same value. 
+ // + // Input Warp + // 1 + // 1 ( exits without doing any work) + // 5 + // 8 + + // Number of values processed by each thread (grain size) + for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){ + int64_t idx = blockIdx.x * blockDim.y + threadIdx.y; + if (idx < numel + && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){ + do { + int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ; + // if not accumulate, we only keep the last duplicate index so skip those before it + if (!accumulate && (idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) { + idx++; + continue; + } + const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before; + const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride; + const opmath_t scale = (opmath_t)1.0; + + opmath_t gradient[SZ]; + opmath_t weight[SZ]; + + while (start_feature < stride) { + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int64_t feature_dim = start_feature + ii * C10_WARP_SIZE; + if (feature_dim < stride) { + gradient[ii] = static_cast(grad_output[grad_row + feature_dim]); + if (accumulate) { + weight[ii] = static_cast(grad_weight[weight_row + feature_dim]); + } + } + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + if (accumulate) { + weight[ii] += gradient[ii] * scale; + } else { + weight[ii] = gradient[ii] * scale; + } + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int64_t feature_dim = start_feature + ii * C10_WARP_SIZE; + if (feature_dim < stride) { + grad_weight[weight_row + feature_dim] = static_cast(weight[ii]); + } + } + start_feature += gridDim.y * blockDim.x * SZ; + } + + idx++; + } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]); + } + } +} + +template +__global__ void indexing_backward_kernel_stride_1( + const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, + int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) { + using opmath_t = at::opmath_type; + + // Number of values processed by each thread (grain size) + for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){ + int64_t idx = blockIdx.x * blockDim.y + threadIdx.y; + int64_t crnt_sorted_idx = sorted_indices[idx]; + + if ((idx < numel) && + (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1])) + { + // Determine the number of duplicates in advance + int64_t num_duplicates = 1; + while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) { + num_duplicates++; + } + + // Continue computing weights + const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before; + int64_t grad_row = 0; + const opmath_t scale = (opmath_t)1.0; + + if (!accumulate) { + grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride; + grad_weight[weight_row] = + static_cast(static_cast(grad_output[grad_row]) * scale); + } else { + opmath_t gradient = (opmath_t)0.0; + + int laneIdx = threadIdx.x % C10_WARP_SIZE; + int64_t num_warp_passes = num_duplicates / C10_WARP_SIZE; + for (int64_t i = 0; i < num_warp_passes; ++i) { + grad_row = ((int64_t) indices[idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride; + gradient += static_cast(grad_output[grad_row]) * scale; + } + WARP_SYNC(); + for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) { + gradient += WARP_SHFL_DOWN(gradient, offset); + } + + if (laneIdx == 0) { + for (int64_t i = num_warp_passes * C10_WARP_SIZE; 
i < num_duplicates; ++i) { + grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride; + gradient += static_cast(grad_output[grad_row]) * scale; + } + + grad_weight[weight_row] = static_cast(static_cast(grad_weight[weight_row]) + gradient); + } + } + } + } +} + +template +__global__ void indexing_backward_kernel_small_stride( + const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, + int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) { + using opmath_t = at::opmath_type; + + // Number of values processed by each thread (grain size) + for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){ + int64_t idx = blockIdx.x * blockDim.y + threadIdx.y; + int64_t tidx = threadIdx.x; + int64_t crnt_sorted_idx = sorted_indices[idx]; + + if ((idx < numel) && + (tidx < stride) && + (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1])) + { + // Determine the number of duplicates in advance + int64_t num_duplicates = 1; + while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) { + num_duplicates++; + } + + // Continue computing weights + const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before; + int64_t grad_row = 0; + const opmath_t scale = (opmath_t)1.0; + + if (!accumulate) { + grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride; + grad_weight[weight_row + tidx] = + static_cast(static_cast(grad_output[grad_row + tidx]) * scale); + } else { + opmath_t gradient = (opmath_t)0.0; + for (int64_t i = 0; i < num_duplicates; ++i) { + grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride; + gradient += static_cast(grad_output[grad_row + tidx]) * scale; + } + + grad_weight[weight_row + tidx] = static_cast(static_cast(grad_weight[weight_row + tidx]) + gradient); + } + } + } +} + +template +__global__ void indexing_backward_kernel_quantized( + const int64_t* sorted_indices, const int64_t* indices, const float* grad_output, scalar_t* grad_weight, + int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, + float inv_scale, int zero_point, int64_t qmin, int64_t qmax) { + + // This implementation is adopted from indexing_backward_kernel above. 
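+  // Unlike the kernel above, this variant never accumulates: for duplicate
+  // indices only the last write survives, and each value is re-quantized on
+  // the fly as q = clamp(zero_point + round(w / scale), qmin, qmax) before
+  // being stored into the quantized grad_weight tensor.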
+ using opmath_t = at::opmath_type; + for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){ + int64_t idx = blockIdx.x * blockDim.y + threadIdx.y; + if (idx < numel + && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){ + do { + int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ; + // we only keep the last duplicate index so skip those before it + if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) { + idx++; + continue; + } + const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before; + const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride; + const opmath_t scale = (opmath_t)1.0; + + opmath_t gradient[SZ]; + opmath_t weight[SZ]; + + while (start_feature < stride) { + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int64_t feature_dim = start_feature + ii * C10_WARP_SIZE; + if (feature_dim < stride) { + gradient[ii] = static_cast(grad_output[grad_row + feature_dim]); + } + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + weight[ii] = gradient[ii] * scale; + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int64_t feature_dim = start_feature + ii * C10_WARP_SIZE; + if (feature_dim < stride) { + // we do quantization here + int64_t qvalue = static_cast(zero_point + nearbyintf(weight[ii]* inv_scale)); + qvalue = min(max(qvalue, qmin), qmax); + grad_weight[weight_row + feature_dim] = static_cast(qvalue); + } + } + start_feature += gridDim.y * blockDim.x * SZ; + } + + idx++; + } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]); + } + } +} + + +} + + +namespace at::native { + +namespace { + +class ReduceMultiply { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + (void)numel; // suppress unused warning + gpuAtomicMul(self_data_start + index, *src_data); + } +}; +static ReduceMultiply reduce_multiply; + +class ReduceAdd { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + fastAtomicAdd(self_data_start, index, numel, *src_data, true); + } +}; +static ReduceAdd reduce_add; + +class ReduceMinimum { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + (void)numel; // suppress unused warning + gpuAtomicMin(self_data_start + index, *src_data); + } +}; +static ReduceMinimum reduce_minimum; + +class ReduceMaximum { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + (void)numel; // suppress unused warning + gpuAtomicMax(self_data_start + index, *src_data); + } +}; +static ReduceMaximum reduce_maximum; + +} + +static Tensor wrapIndexOnce(const Tensor & index, int64_t dim, int64_t dim_size, bool check_range=true) { +//we don't need to check range in backward - if there were out of bounds indices forward should already have errored out + if (index.numel() != 0 && check_range) { + at::_assert_async(index.max() < dim_size); + at::_assert_async(index.min() >= -dim_size); + } + return index.remainder(dim_size); +} + +static std::vector computeLinearStride(const Tensor & tensor) { + // computes the stride as if tensor were contiguous + auto sizes = tensor.sizes(); + std::vector stride(tensor.dim()); + if (stride.empty()) { + return stride; + } + 
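+  // The reverse partial_sum below fills in the remaining entries: for
+  // sizes [2, 3, 4] it produces strides [12, 4, 1], i.e. the strides the
+  // tensor would have if it were laid out contiguously.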
stride[tensor.dim() - 1] = 1; + std::partial_sum(sizes.rbegin(), sizes.rend() - 1, stride.rbegin() + 1, std::multiplies()); + return stride; +} + +static std::tuple +computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) { + auto strides = computeLinearStride(src); + const auto& device = src.options().device(); + + // Compute the linear index by multiplying the indexing tensors by the + // stride and summing them. All the indexing tensors have the same shape at + // this point. We also compute the number of dimensions before and after that + // are not being index. + Tensor linearIndex; + int64_t nElemBefore = 1, nElemAfter = 1, strideBefore =0; + for (const auto i: c10::irange(src.dim())) { + if (indices[i].defined()) { + // Cast index to the longType matching src's device + // This allows us to support ie indexing a cuda tensor with a cpu tensor + Tensor index = (wrapIndexOnce(indices[i], i, src.size(i), check_range) * strides[i]).to(device); + if (linearIndex.defined()) { + linearIndex += index; + } else { + linearIndex = index; + if (i>0) { + strideBefore = src.stride(i-1); // stride after undefined dimensions + } + } + } else if (linearIndex.defined()) { + nElemAfter *= src.size(i); + } else { + nElemBefore *= src.size(i); + } + } + + return std::make_tuple(std::move(linearIndex), nElemBefore, strideBefore, nElemAfter); +} + + +static std::tuple> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) { + checkIndexTensorTypes(orig, /*allow_int*/true); + // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors + auto indices = expandTensors(self, orig); + for (auto & i : indices) { + if (i.defined() && i.dtype() == at::kInt) { + i = i.to(at::kLong); + } + } + // next broadcast all index tensors together + indices = expand_outplace(indices); + // add missing null Tensors so that it matches self.dim() + while (indices.size() < (size_t)self.dim()) { + indices.emplace_back(); + } + // if the non-null indices are not all adjacent, transpose self and indices + // together so that they're adjacent at the front + std::vector inversePerm; + if (!hasContiguousSubspace(indices)) { + std::tie(self, indices, inversePerm) = transposeToFrontAndInvPerm(self, indices); + } + auto [linearIndex, nElemBefore, strideBefore, nElemAfter] = computeLinearIndex(self, indices, check_range); + return std::make_tuple(linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm); +} + + +void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices); + +namespace { + +int64_t largestIndex(const Tensor &self) { + int64_t result = 0; + for (const auto i: c10::irange(self.dim())) { + result += (self.sizes()[i] - 1) * self.strides()[i]; + } + return result; +} + +void index_put_with_sort_kernel(Tensor & self, const c10::List>& indices, const Tensor & value, bool accumulate, bool unsafe) { + TORCH_CHECK(!indices.empty() || is_expandable_to(value.sizes(), self.sizes()), "shape mismatch: value tensor of shape ", value.sizes(), + " cannot be broadcast to indexing result of shape ", self.sizes()); + if (indices.size() > (size_t)self.dim()) { + TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); + } + bool self_contiguous = self.is_contiguous(); + auto self_ = self_contiguous ? 
self : self.contiguous(); + Tensor linearIndex, src, expandedValue = value; + int64_t nElemBefore, strideBefore, sliceSize; + std::vector inversePerm; + std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe); + int64_t num_indices = linearIndex.numel(); + + if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) { + auto expanded_size = at::DimVector(expandedValue.sizes()); + auto size1 = expandedValue.sizes(); + auto size2 = linearIndex.sizes(); + if (are_expandable(size1, size2)) { + expanded_size = infer_size_dimvector(size1, size2); + } + if (nElemBefore > 1) { + expanded_size.insert(expanded_size.begin(), nElemBefore); + } + if (sliceSize > 1) { + expanded_size.insert(expanded_size.end(), sliceSize); + } + expandedValue = expandedValue.expand(expanded_size); + } + expandedValue = expandedValue.contiguous(); + + if (num_indices > 0 && sliceSize > 0) { + const bool permuted = !src.is_contiguous(); + auto src_ = permuted ? src.contiguous() : src; + linearIndex = linearIndex.reshape(-1); + auto sorted_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto orig_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + const hipStream_t stream = c10::zoom::getCurrentZoomStream(); + + linearIndex.divide_(sliceSize, "trunc"); + + if (num_indices < 50000) { + index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices); + } else + { + // Sort the inputs into sorted with the corresponding indices + auto range = at::arange(num_indices, linearIndex.options()); + // linearIndex can not be negative, and we take advantage of this + // fact to sort on less bits for better performance. + int64_t nbits = zoom::hipcub::get_num_bits(largestIndex(self_) / sliceSize); + zoom::hipcub::radix_sort_pairs( + linearIndex.const_data_ptr(), sorted_indices.mutable_data_ptr(), + range.const_data_ptr(), orig_indices.mutable_data_ptr(), + num_indices, false, 0, nbits); + } + + TORCH_INTERNAL_ASSERT( + linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(), + "number of flattened indices did not match number of elements in the value tensor: ", + linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel()); + const int UNROLL = 4; + const int indices_per_block = 4; + const int warp_size = at::zoom::warp_size(); + dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block), + std::min(at::zoom::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size*UNROLL))), + ::min(std::max(1,nElemBefore), at::zoom::getCurrentDeviceProperties()->maxGridSize[2])); + dim3 block(warp_size, indices_per_block); + + + if (sliceSize == 1) { + // This implementation is faster with high amounts of duplicates but could overflow + // if FP16 / BF16 is used + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, + expandedValue.scalar_type(), "indexing_backward_kernel_stride_1", [&] { + hipLaunchKernelGGL(( indexing_backward_kernel_stride_1), dim3(grid), dim3(block), 0, stream, + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore, + accumulate); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (sliceSize <= warp_size) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, + expandedValue.scalar_type(), "indexing_backward_kernel_small_stride", [&] { + 
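+          // sliceSize fits inside a single warp, so each thread owns one
+          // feature element (threadIdx.x) and the kernel needs neither
+          // unrolling nor a warp-level reduction.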
hipLaunchKernelGGL(( indexing_backward_kernel_small_stride), dim3(grid), dim3(block), 0, stream, + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore, + accumulate); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, + expandedValue.scalar_type(), "indexing_backward", [&] { + hipLaunchKernelGGL(( indexing_backward_kernel), dim3(grid), dim3(block), 0, stream, + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore, + accumulate); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } + } + + if (permuted) { + self.copy_(src_.permute(inversePerm)); + } else if (!self_contiguous) { + self.copy_(self_); + } + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel); + +void index_put_with_sort_quantized(Tensor & self, const c10::List>& indices, const Tensor & value, double scale, int zero_point, bool unsafe) { + if (indices.size() > (size_t)self.dim()) { + TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); + } + bool self_contiguous = self.is_contiguous(); + auto self_ = self_contiguous ? self : self.contiguous(); + Tensor linearIndex, src, expandedValue = value; + int64_t nElemBefore, strideBefore, sliceSize; + std::vector inversePerm; + std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe); + int64_t num_indices = linearIndex.numel(); + + if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) { + auto expanded_size = at::DimVector(expandedValue.sizes()); + auto size1 = expandedValue.sizes(); + auto size2 = linearIndex.sizes(); + if (are_expandable(size1, size2)) { + expanded_size = infer_size_dimvector(size1, size2); + } + if (nElemBefore > 1) { + expanded_size.insert(expanded_size.begin(), nElemBefore); + } + expandedValue = expandedValue.expand(expanded_size); + } + expandedValue = expandedValue.contiguous(); + + if (num_indices > 0 && sliceSize > 0) { + const bool permuted = !src.is_contiguous(); + auto src_ = permuted ? src.contiguous() : src; + linearIndex = linearIndex.reshape(-1); + auto sorted_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto orig_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + const hipStream_t stream = c10::zoom::getCurrentZoomStream(); + + linearIndex.divide_(sliceSize, "trunc"); + + // cub on CUDA <= 11.2 have a bug that for small sizes + // cub's sort can be much slower than thrust's merge sort + // this bug is fixed in CUDA 11.3 +#if (defined(TORCH_HIP_VERSION) && TORCH_HIP_VERSION < 11030) + if (num_indices < 50000) { + index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices); + } else +#endif + { + // Sort the inputs into sorted with the corresponding indices + auto range = at::arange(num_indices, linearIndex.options()); + // linearIndex can not be negative, and we take advantage of this + // fact to sort on less bits for better performance. 
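+      // get_num_bits(largestIndex(self_) / sliceSize) is the number of bits
+      // needed to represent the largest possible sort key, so the radix sort
+      // only sweeps those low-order bits instead of all 64.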
+ int64_t nbits = zoom::hipcub::get_num_bits(largestIndex(self_) / sliceSize); + zoom::hipcub::radix_sort_pairs( + linearIndex.const_data_ptr(), sorted_indices.mutable_data_ptr(), + range.const_data_ptr(), orig_indices.mutable_data_ptr(), + num_indices, false, 0, nbits); + } + + TORCH_INTERNAL_ASSERT( + linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(), + "number of flattened indices did not match number of elements in the value tensor: ", + linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel()); + const int UNROLL = 4; + const int indices_per_block = 4; + const int warp_size = at::zoom::warp_size(); + dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block), + std::min(at::zoom::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size*UNROLL))), + ::min(std::max(1,nElemBefore), at::zoom::getCurrentDeviceProperties()->maxGridSize[2])); + dim3 block(warp_size, indices_per_block); + + AT_DISPATCH_QINT_TYPES( + src.scalar_type(), "indexing_backward_quantized", [&] { + constexpr int64_t qmin = std::numeric_limits::min(); + constexpr int64_t qmax = std::numeric_limits::max(); + float inv_scale = 1.0f / static_cast(scale); + + hipLaunchKernelGGL(( indexing_backward_kernel_quantized), dim3(grid), dim3(block), 0, stream, + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore, + inv_scale, + zero_point, + qmin, + qmax); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + + if (permuted) { + self.copy_(src_.permute(inversePerm)); + } else if (!self_contiguous) { + self.copy_(self_); + } + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_quantized); +} //anonymous + + +// Check tensor dimensions for index operations, and return the slice size. +static ptrdiff_t getSliceSize(const Tensor & dst, + int dim, + const Tensor & index, + const Tensor & src) +{ + const auto dstDims = dst.dim(); + const auto srcDims = src.dim(); + + TORCH_CHECK(index.dim() <= 1, "Index must be vector or scalar"); + + ptrdiff_t dstSliceSize = 1; + TORCH_CHECK(dim >= 0 && dim < dstDims, "Indexing dim ", dim, " is out of bounds"); + for (const auto d: c10::irange(dstDims)) { + if (d != dim) { + dstSliceSize *= dst.size(d); + } + } + + TORCH_CHECK(dim < srcDims, "Indexing dim ", dim, " is out of bounds"); + TORCH_CHECK(index.numel() == src.size(dim), + "length of src.size[dim] is not equal to length of indices"); + + ptrdiff_t srcSliceSize = 1; + bool mismatch = false; + + if (dstDims != srcDims) mismatch = true; + + for (const auto d: c10::irange(srcDims)) { + if (d != dim) { + srcSliceSize *= src.size(d); + if (!mismatch && dst.size(d) != src.size(d)) mismatch = true; + } + } + + TORCH_CHECK(dstSliceSize == srcSliceSize, + "Source/destination tensor have different slice sizes (%ld vs %ld)", + dstSliceSize, srcSliceSize); + + if (mismatch) { + TORCH_WARN_ONCE( + "Warning: source/destination slices have same size but different " + "shape for an index operation. This behavior is deprecated.\n"); + } + + return dstSliceSize; +} + +// We prefer this kernel to avoid reloading index points if the number +// of indices is a small number. +// This kernel in fact works for all choices of problem size, but if +// the number of indices chosen is large, then the +// indexFuncLargeIndex kernel is a better choice to increase +// parallelism. 
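+// The dispatch code further down uses numIndex <= 16 as the cutoff between the
+// two kernels, with specializations for collapsed 1-D/2-D/3-D tensors with a
+// contiguous index and a generic (-1, -1, -1) instantiation as the catch-all.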
+template +__global__ void indexFuncSmallIndex(zoom::detail::TensorInfo dst, + zoom::detail::TensorInfo src, + zoom::detail::TensorInfo indices, + int dstAddDim, + int srcAddDim, + IndexType innerSize, + int64_t dstAddDimSize, + int64_t dstNumel, + const func_t& op, + T alpha) { + // In order to avoid reloading the index that we are copying, load + // it once to handle all of the points that are being selected, so + // it can be reused as much as possible. This kernel is chosen when + // this is a good choice (small number of chosen indices), since + // re-accessing indices in addition to src elements can be slow. + for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) { + // Lua indices begin at 1 + IndexType dstIndex = + indices.data[zoom::detail::IndexToOffset::get(srcIndex, indices)]; + ZOOM_KERNEL_ASSERT(dstIndex < dstAddDimSize); + + // We stride over the output ignoring the indexed dimension + // (innerSize), whose offset calculation is handled differently + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < innerSize; + linearIndex += gridDim.x * blockDim.x) { + IndexType dstOffset = + zoom::detail::IndexToOffset::get(linearIndex, dst); + dstOffset += dstIndex * dst.strides[dstAddDim]; + + IndexType srcOffset = + zoom::detail::IndexToOffset::get(linearIndex, src); + srcOffset += srcIndex * src.strides[srcAddDim]; + + T val = src.data[srcOffset] * alpha; + op(dst.data, dstOffset, dstNumel, &val); + } + + } +} + +// We prefer this kernel to balance parallelism across index points, +// if there are a large number of indices. +// This kernel in fact works for all choices of problem size, but if +// the number of indices chosen is small, then the +// indexFuncSmallIndex kernel is a better choice to reduce memory +// accesses. +template +__global__ void indexFuncLargeIndex(zoom::detail::TensorInfo dst, + zoom::detail::TensorInfo src, + zoom::detail::TensorInfo indices, + int dstAddDim, + int srcAddDim, + IndexType totalSize, + IndexType innerSize, + int64_t dstAddDimSize, + int64_t dstNumel, + const func_t& op, + T alpha) { + // We stride over the output including the indexed dimension + // (totalSize), and calculate the destination index point based on that + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < totalSize; + linearIndex += gridDim.x * blockDim.x) { + IndexType srcIndex, elementInSlice; + if (IndexIsMajor) { + srcIndex = linearIndex / innerSize; + elementInSlice = linearIndex % innerSize; + } + else { + elementInSlice = linearIndex / innerSize; + srcIndex = linearIndex % innerSize; + } + + // Lua indices begin at 1 + IndexType dstIndex = + indices.data[zoom::detail::IndexToOffset::get(srcIndex, indices)]; + ZOOM_KERNEL_ASSERT(dstIndex < dstAddDimSize); + + IndexType dstOffset = + zoom::detail::IndexToOffset::get(elementInSlice, dst); + dstOffset += dstIndex * dst.strides[dstAddDim]; + + IndexType srcOffset = + zoom::detail::IndexToOffset::get(elementInSlice, src); + srcOffset += srcIndex * src.strides[srcAddDim]; + + T val = src.data[srcOffset] * alpha; + op(dst.data, dstOffset, dstNumel, &val); + } +} + +// Compare the stride between adjacent slices (sliceStride) with strides in the +// other dimensions (i.e., strides *inside* each slice). +// +// - Returns true if some dimension inside the slice has lower stride than +// sliceStride. The simplest example is a 2-D contiguous tensor with sliceDim +// == 0 (that is, each slice is a row). 
+// +// In this case, we choose the CUDA kernel that processes the data in +// "index-major order". For example, if thread count equals slice size, then +// all threads process slice #0 in lockstep, and then slice #1, and so on. +// +// - Otherwise (i.e., sliceStride has the lowest value), this function returns +// false. The simplest example is a 2-D contiguous tensor with sliceDim == 1 +// (each slice is a column). +// +// In this case, we choose the CUDA kernel that processes the data in +// "elementInSlice-major order". For example, each thread can process element +// #0 of every slice, and then element #1 of every slice, and so on. +template +bool indexShouldBeMajor(zoom::detail::TensorInfo &info, + int sliceDim) +{ + // The stride between adjacent slices (e.g., between element #0 of slice #100 + // and element #0 of slice #101). + unsigned int sliceStride = info.strides[sliceDim]; + + for (const auto i: c10::irange(info.dims)) { + if (i != sliceDim && info.sizes[i] > 1 && info.strides[i] < sliceStride) { + return true; + } + } + + return false; +} + +void index_add_zoom_impl(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) { + if (!result.is_same(self)) { + result.copy_(self); + } + + // Scalars are treated as 1-d tensor + const Tensor self_ = (result.dim() == 0) ? result.view(1) : result; + const Tensor source_ = (source.dim() == 0) ? source.view(1) : source; + + TORCH_CHECK(result.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims"); + TORCH_CHECK(source.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims" ); + TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims"); + + if (globalContext().deterministicAlgorithms()){ + torch::List> indices; + indices.reserve(dim + 1); + for (const auto i: c10::irange(dim)) { + indices.emplace_back(); + } + indices.emplace_back(index.to(at::kLong)); + result.index_put_(indices, source * alpha, true); + return; + } + + // The `source` is partitioned into two parts: + // -the size of each slice we are indexing, which is the + // total size of the tensor ignoring dimension `dim`; + // -the number of index we are choosing, which is the total size + // of the tensor `index`. + const ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_); + const ptrdiff_t sourceTotalSize = source.numel(); + const int64_t selfAddDimSize = self_.size(dim); + const ptrdiff_t numIndex = index.numel(); + const int64_t selfNumel = self_.numel(); + + if (sliceSize == 0) { + return; + } + const hipStream_t stream = c10::zoom::getCurrentZoomStream(); + const bool indContig = index.is_contiguous(); + + const int mpc = at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ + hipLaunchKernelGGL(( indexFuncSmallIndex) \ + , dim3(smallIndexGrid), dim3(smallIndexBlock), 0, stream, \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sliceSize, selfAddDimSize, \ + selfNumel, reduce_add, alpha_value); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + +#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ + SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \ + hipLaunchKernelGGL(( indexFuncLargeIndex) \ + , dim3(largeIndexGrid), dim3(largeIndexBlock), 0, stream, \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sourceTotalSize, \ + (IDX_IS_MAJOR) ? 
sliceSize : numIndex, \ + selfAddDimSize, selfNumel, reduce_add, alpha_value); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + const dim3 smallIndexGrid(::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + const dim3 smallIndexBlock(::min(sliceSize, (ptrdiff_t)128)); + + const dim3 largeIndexGrid(::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + const dim3 largeIndexBlock(::min(sourceTotalSize, (ptrdiff_t)128)); + + if (zoom::detail::canUse32BitIndexMath(result) && + zoom::detail::canUse32BitIndexMath(source) && + zoom::detail::canUse32BitIndexMath(index)) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::ComplexHalf, result.scalar_type(), "index_add", [&] { + zoom::detail::TensorInfo selfInfo = + zoom::detail::getTensorInfo(self_); + const int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + const auto alpha_value = alpha.to(); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_zoom_", [&] () { + auto sourceInfo = + zoom::detail::getTensorInfo(source_); + const int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + auto indexInfo = + zoom::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + // A reasonable choice for when to have each thread iterate over + // index to choose + if (numIndex <= 16) { + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + } else { + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + } + } else { + const bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); + + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); + } + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); + } + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); + } + } + }); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { + zoom::detail::TensorInfo selfInfo = + zoom::detail::getTensorInfo(self_); + const int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + const auto alpha_value = alpha.to(); + + zoom::detail::TensorInfo sourceInfo = + zoom::detail::getTensorInfo(source_); + const int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_zoom_", [&] () { + zoom::detail::TensorInfo indexInfo = + zoom::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); + }); + }); + } + +#undef SMALL_INDEX +#undef LARGE_INDEX +} + +template +void index_reduce_func_zoom_impl( + 
const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& source, + bool include_self, + const ReductionType& reduce, + const func_t& reduce_func, + const Tensor& result) { + globalContext().alertNotDeterministic("index_reduce_zoom"); + + if (!result.is_same(self)) result.copy_(self); + + // Scalars are treated as 1-d tensor + Tensor self_ = (result.dim() == 0) ? result.view(1) : result; + Tensor source_ = (source.dim() == 0) ? source.view(1) : source; + + TORCH_CHECK(result.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims"); + TORCH_CHECK(source.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims" ); + TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims"); + + if (!include_self) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + self.scalar_type(), "index_reduce_func_zoom_exclude_input_init", [&] { + scalar_t init_val; + switch (reduce) { + case ReductionType::PROD: + init_val = (scalar_t)1; + break; + case ReductionType::MAX: + init_val = std::numeric_limits::has_infinity ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + break; + case ReductionType::MIN: + init_val = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + break; + default: + init_val = (scalar_t)0; + break; + } + // index_fill_ requires index to be a LongTensor + self_.index_fill_(dim, index.to(at::ScalarType::Long), init_val); + }); + } + + // The `source` is partitioned into two parts: + // -the size of each slice we are indexing, which is the + // total size of the tensor ignoring dimension `dim`; + // -the number of index we are choosing, which is the total size + // of the tensor `index`. + ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_); + ptrdiff_t sourceTotalSize = source.numel(); + int64_t selfReduceDimSize = self_.size(dim); + ptrdiff_t numIndex = index.numel(); + int64_t selfNumel = self_.numel(); + + if (sliceSize == 0) { + return; + } + const hipStream_t stream = c10::zoom::getCurrentZoomStream(); + bool indContig = index.is_contiguous(); + + int mpc = at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ + hipLaunchKernelGGL(( indexFuncSmallIndex) \ + , dim3(smallIndexGrid), dim3(smallIndexBlock), 0, stream, \ + selfInfo, sourceInfo, indexInfo, \ + selfReduceDim, sourceReduceDim, sliceSize, selfReduceDimSize, \ + selfNumel, reduce_func, alpha_value); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + +#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ + SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \ + hipLaunchKernelGGL(( indexFuncLargeIndex) \ + , dim3(largeIndexGrid), dim3(largeIndexBlock), 0, stream, \ + selfInfo, sourceInfo, indexInfo, \ + selfReduceDim, sourceReduceDim, sourceTotalSize, \ + (IDX_IS_MAJOR) ? 
sliceSize : numIndex, \ + selfReduceDimSize, selfNumel, reduce_func, alpha_value); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + dim3 smallIndexGrid(::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + dim3 smallIndexBlock(::min(sliceSize, (ptrdiff_t)128)); + + dim3 largeIndexGrid(::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + dim3 largeIndexBlock(::min(sourceTotalSize, (ptrdiff_t)128)); + + if (zoom::detail::canUse32BitIndexMath(result) && + zoom::detail::canUse32BitIndexMath(source) && + zoom::detail::canUse32BitIndexMath(index)) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "index_reduce", [&] { + zoom::detail::TensorInfo selfInfo = + zoom::detail::getTensorInfo(self_); + int selfReduceDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfReduceDim); + auto alpha_value = (scalar_t) 1; + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_reduce_zoom", [&] () { + auto sourceInfo = + zoom::detail::getTensorInfo(source_); + int sourceReduceDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceReduceDim); + + auto indexInfo = + zoom::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + // A reasonable choice for when to have each thread iterate over + // index to choose + if (numIndex <= 16) { + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + } else { + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + } + } else { + bool indexIsMajor = indexShouldBeMajor(selfInfo, selfReduceDim); + + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); + } + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); + } + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); + } + } + }); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce", [&] { + zoom::detail::TensorInfo selfInfo = + zoom::detail::getTensorInfo(self_); + int selfReduceDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfReduceDim); + auto alpha_value = (scalar_t) 1; + + zoom::detail::TensorInfo sourceInfo = + zoom::detail::getTensorInfo(source_); + int sourceReduceDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceReduceDim); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_reduce_zoom", [&] () { + zoom::detail::TensorInfo indexInfo = + zoom::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); + }); + }); + } + +#undef SMALL_INDEX +#undef LARGE_INDEX +} + +TORCH_IMPL_FUNC(index_add_zoom_out) +(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) { + 
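+  // Structured-kernel entry point: all the work is done by
+  // index_add_zoom_impl above; this wrapper only forwards its arguments.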
index_add_zoom_impl(self, dim, index, source, alpha, result); +} + +TORCH_IMPL_FUNC(index_reduce_zoom_out) +(const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& source, + const c10::string_view reduce, + bool include_self, + const Tensor& result) { + TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time."); + + if (reduce == "prod") { + index_reduce_func_zoom_impl(self, dim, index, source, include_self, ReductionType::PROD, reduce_multiply, result); + } else if (reduce == "mean") { + index_reduce_func_zoom_impl(self, dim, index, source, include_self, ReductionType::MEAN, reduce_add, result); + auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); + counts.index_add_(dim, index, at::ones_like(source)); + counts.masked_fill_(counts == 0, 1); + if (result.is_floating_point() || result.is_complex()) { + result.div_(counts); + } else { + result.div_(counts, "floor"); + } + } else if (reduce == "amax") { + index_reduce_func_zoom_impl(self, dim, index, source, include_self, ReductionType::MAX, reduce_maximum, result); + } else if (reduce == "amin") { + index_reduce_func_zoom_impl(self, dim, index, source, include_self, ReductionType::MIN, reduce_minimum, result); + } else { + TORCH_CHECK(false, "reduce argument must be either prod, mean, amax or amin, got ", reduce, "."); + } +} + +namespace { +// We prefer this kernel to avoid reloading index points if the number +// of indices is a small number. +// This kernel in fact works for all choices of problem size, but if +// the number of indices chosen is large, then the +// indexSelectLargeIndex kernel is a better choice to increase +// parallelism. +template +__global__ void indexSelectSmallIndex(zoom::detail::TensorInfo dst, + zoom::detail::TensorInfo src, + zoom::detail::TensorInfo indices, + int dstSelectDim, + int srcSelectDim, + IndexType innerSize, + int64_t srcSelectDimSize) { + // In order to avoid reloading the index that we are copying, load + // it once to handle all of the points that are being selected, so + // it can be reused as much as possible. This kernel is chosen when + // this is a good choice (small number of chosen indices), since + // re-accessing indices in addition to src elements can be slow. + for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) { + IndexType srcIndex = + indices.data[zoom::detail::IndexToOffset::get(dstIndex, indices)]; + ZOOM_KERNEL_ASSERT(srcIndex < srcSelectDimSize); + + // We stride over the output ignoring the indexed dimension + // (innerSize), whose offset calculation is handled differently + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < innerSize; + linearIndex += gridDim.x * blockDim.x) { + IndexType dstOffset = + zoom::detail::IndexToOffset::get(linearIndex, dst); + dstOffset += dstIndex * dst.strides[dstSelectDim]; + + IndexType srcOffset = + zoom::detail::IndexToOffset::get(linearIndex, src); + srcOffset += srcIndex * src.strides[srcSelectDim]; + + dst.data[dstOffset] = src.data[srcOffset]; + } + } +} + +// We prefer this kernel to balance parallelism across index points, +// if there are a large number of indices. +// This kernel in fact works for all choices of problem size, but if +// the number of indices chosen is small, then the +// indexSelectSmallIndex kernel is a better choice to reduce memory +// accesses. 
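+// When IndexIsMajor is true the flattened loop index maps to
+// (dstIndex, elementInSlice) = (linearIndex / innerSize, linearIndex % innerSize);
+// e.g. with innerSize == 4, linearIndex 10 selects slice 2, element 2. When it
+// is false, the two roles are swapped.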
+template +__global__ void indexSelectLargeIndex(zoom::detail::TensorInfo dst, + zoom::detail::TensorInfo src, + zoom::detail::TensorInfo indices, + int dstSelectDim, + int srcSelectDim, + IndexType totalSize, + IndexType innerSize, + int64_t srcSelectDimSize) { + // We stride over the output including the indexed dimension + // (totalSize), and calculate the destination index point based on that + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < totalSize; + linearIndex += gridDim.x * blockDim.x) { + IndexType dstIndex, elementInSlice; + if (IndexIsMajor) { + dstIndex = linearIndex / innerSize; + elementInSlice = linearIndex % innerSize; + } + else { + elementInSlice = linearIndex / innerSize; + dstIndex = linearIndex % innerSize; + } + + IndexType srcIndex = + indices.data[zoom::detail::IndexToOffset::get(dstIndex, indices)]; + ZOOM_KERNEL_ASSERT(srcIndex < srcSelectDimSize); + + IndexType dstOffset = + zoom::detail::IndexToOffset::get(elementInSlice, dst); + dstOffset += dstIndex * dst.strides[dstSelectDim]; + + IndexType srcOffset = + zoom::detail::IndexToOffset::get(elementInSlice, src); + srcOffset += srcIndex * src.strides[srcSelectDim]; + + dst.data[dstOffset] = src.data[srcOffset]; + } +} + +namespace { + +// When using a 0-dim scalar tensor, we need the legacy (THC) semantics of +// TensorInfo: Pretend that the scalar tensor is in fact a one-element vector. +template +zoom::detail::TensorInfo +tensorInfoLegacyIfScalar(zoom::detail::TensorInfo ti) { + if (ti.dims == 0) { + ti.dims = 1; + ti.sizes[0] = 1; + ti.strides[0] = 1; + } + return ti; +} + +} + +template +void index_select_out_zoom_impl( + Tensor& out, + const Tensor& self, + long dim, + const Tensor& index) { + ptrdiff_t numIndices = index.numel(); + int selfDims = self.dim() == 0 ? 1 : self.dim(); + + const hipStream_t stream = c10::zoom::getCurrentZoomStream(); + + TORCH_CHECK( + index.dim() <= 1, "Index is supposed to be an empty tensor or a vector"); + TORCH_CHECK( + !(self.dim() == 0 && numIndices != 1), "index_select(): Index to scalar can have only 1 value, got ", numIndices, " value(s)"); + TORCH_CHECK(dim < selfDims, "Indexing dim is out of bounds"); + + std::vector newSize = self.sizes().vec(); + if (self.dim() > 0) { + newSize[dim] = numIndices; + } + + if (self.is_quantized()){ + out = at::empty_quantized(newSize, out); + } else { + at::native::resize_output(out, newSize); + } + + ptrdiff_t outTotalSize = out.numel(); + if (outTotalSize == 0) { + return; + } + + bool indContig = index.is_contiguous(); + + // The `self` is partitioned into two parts: + // -the size of each slice we are indexing, which is the + // total size of the tensor ignoring dimension `dim`; + // -the number of indices we are choosing, which is the total size + // of the tensor `indices`. + int64_t selfSelectDimSize = self.dim() == 0 ? 
1 : self.size(dim); + ptrdiff_t sliceSize = outTotalSize / numIndices; + + int mpc = at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ + hipLaunchKernelGGL(( indexSelectSmallIndex) \ + , dim3(smallIndexGrid), dim3(smallIndexBlock), 0, stream, \ + outInfo, selfInfo, indicesInfo, \ + outSelectDim, selfSelectDim, static_cast(sliceSize), \ + selfSelectDimSize); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + +#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ + DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ + hipLaunchKernelGGL(( indexSelectLargeIndex) \ + , dim3(largeIndexGrid), dim3(largeIndexBlock), 0, stream, \ + outInfo, selfInfo, indicesInfo, \ + outSelectDim, selfSelectDim, static_cast(outTotalSize), \ + static_cast((IDX_IS_MAJOR) ? sliceSize : numIndices), \ + selfSelectDimSize); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + dim3 smallIndexGrid(::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + dim3 smallIndexBlock(::min(sliceSize, (ptrdiff_t)128)); + + dim3 largeIndexGrid(::min(ceil_div(outTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + dim3 largeIndexBlock(::min(outTotalSize, (ptrdiff_t)128)); + if (zoom::detail::canUse32BitIndexMath(out) && + zoom::detail::canUse32BitIndexMath(self) && + zoom::detail::canUse32BitIndexMath(index)) { + auto outInfo = tensorInfoLegacyIfScalar(zoom::detail::getTensorInfo(out)); + int outSelectDim = outInfo.collapseDims(dim); + outInfo.reduceDim(outSelectDim); + + auto selfInfo = tensorInfoLegacyIfScalar(zoom::detail::getTensorInfo(self)); + int selfSelectDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfSelectDim); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_zoom_impl", [&] () { + auto indicesInfo = tensorInfoLegacyIfScalar(zoom::detail::getTensorInfo(index)); + indicesInfo.collapseDims(); + + // A reasonable choice for when to have each thread iterate over + // indices to choose + if (numIndices <= 16) { + if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + } else { + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + } + } else { + bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim); + + if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); + } + } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); + } + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); + } + } + }); + } else { + auto outInfo = tensorInfoLegacyIfScalar(zoom::detail::getTensorInfo(out)); + int outSelectDim = outInfo.collapseDims(dim); + outInfo.reduceDim(outSelectDim); + + auto selfInfo = tensorInfoLegacyIfScalar(zoom::detail::getTensorInfo(self)); + int selfSelectDim = selfInfo.collapseDims(dim); + 
selfInfo.reduceDim(selfSelectDim); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_zoom_impl", [&] () { + auto indicesInfo = tensorInfoLegacyIfScalar(zoom::detail::getTensorInfo(index)); + indicesInfo.collapseDims(); + + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); + }); + } +#undef SMALL_INDEX +#undef LARGE_INDEX +} +} // anonymous namespace + +Tensor& index_select_out_zoom( + const Tensor& self, + int64_t dim, + const Tensor& index, + Tensor& out) { + static constexpr string_view DIM_WARNING = + "Tensor too large or too many (> 25) dimensions"; + TORCH_CHECK( + at::zoom::check_device({out, self, index}), + "Input, output and indices must be on the current device"); + at::assert_no_internal_overlap(out); + at::assert_no_overlap(out, self); + at::assert_no_overlap(out, index); + + dim = at::maybe_wrap_dim(dim, self); + TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); + TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); + if (self.is_quantized()){ + TORCH_CHECK( + self.qscheme() == kPerTensorAffine, + "Only per_tensor quantized quantized tensors are supported by index_select.") + AT_DISPATCH_QINT_TYPES(out.scalar_type(), "index_select_quant_zoom", [&] { + index_select_out_zoom_impl(out, self, dim, index); + }); + } else { + AT_DISPATCH_V2( + out.scalar_type(), + "index_select_zoom", + AT_WRAP([&] { index_select_out_zoom_impl(out, self, dim, index); }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + kComplexHalf, + kHalf, + kBool, + kBFloat16 + ); + } + + return out; +} + +Tensor index_select_zoom(const Tensor& self, int64_t dim, const Tensor& index) { + Tensor out = at::empty({0}, self.options()); + at::native::index_select_out_zoom(self, dim, index, out); + return out; +} + +Tensor index_select_quantized_zoom(const Tensor& self, int64_t dim, const Tensor& index) { + TORCH_CHECK( + self.qscheme() == kPerTensorAffine, + "Only per_tensor quantized quantized tensors are supported by index_select.") + Tensor out = at::empty_quantized({0}, self); + at::native::index_select_out_zoom(self, dim, index, out); + return out; +} + +namespace { + +void masked_fill_kernel(TensorIterator& iter, const Scalar& value) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + kBool, kHalf, kBFloat16, kComplexHalf, iter.common_dtype(), "masked_fill_", [&]() { + const auto value_ = value.to(); + gpu_kernel( + iter, [value_] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t { + if (mask) { + return value_; + } + return self; + }); + }); +} + +template +void zoom_masked_fill_kernel_quantized(TensorIterator& iter, scalar_t quantized_val) { + gpu_kernel( + iter, [quantized_val] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t { + if (mask) { + return quantized_val; + } + return self; + }); +} + +void masked_fill_kernel_quantized(TensorIterator& iter, const Scalar& value, double scale, int zero_point) { + TORCH_CHECK(iter.input_dtype(1) == at::ScalarType::Bool, "masked_fill only supports boolean masks, ", + "but got dtype ", iter.input_dtype(1)); + AT_DISPATCH_QINT_TYPES( + iter.common_dtype(), "masked_fill_", [&]() { + float float_val = value.to(); + const auto quantized_val = quantize_val(scale, zero_point, float_val); + + zoom_masked_fill_kernel_quantized(iter, quantized_val); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(masked_fill_kernel_quantized_stub, &masked_fill_kernel_quantized); + +} // anonymous namespace + +Tensor & masked_fill__zoom(Tensor & self, const Tensor & mask, const Scalar & value) { + TORCH_CHECK(self.device() == 
mask.device(), "expected self and mask to be on the same device, but got mask on ", + mask.device(), " and self on ", self.device()); + TORCH_CHECK(mask.scalar_type() == kBool, + "masked_fill only supports boolean masks, but got dtype ", mask.scalar_type()); + auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_"); + if (at::has_internal_overlap(self) == MemOverlap::Yes) { + TORCH_WARN( + "Use of masked_fill_ on expanded tensors is deprecated. " + "Please clone() the tensor before performing this operation. " + "This also applies to advanced indexing e.g. tensor[mask] = scalar"); + } + at::assert_no_partial_overlap(self, mask); + + c10::MaybeOwned b_mask = expand_inplace(self, mask, "masked_fill_"); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self) + .add_const_input(self) + .add_const_input(*b_mask) + .build(); + + masked_fill_kernel(iter, value); + namedinference::propagate_names_if_nonempty(self, maybe_outnames); + return self; +} + + +Tensor & masked_fill__zoom(Tensor & self, const Tensor & mask, const Tensor & value) { + TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor " + "with ", value.dim(), " dimension(s)."); + // We hit this function if either of the input tensor lives on CUDA. + // It is ok, if `value` is `CPU` tensor but we should not allow `self` or + // `mask` to be CPU tensor. Check for `self` and `mask` being on same device + // exists in `masked_fill__zoom` (Scalar version). + TORCH_CHECK(!self.device().is_cpu(), "masked_fill_: Expected inputs to be on same device") + return masked_fill__zoom(self, mask, value.item()); +} + +namespace { + +// ForwardIt: only legacy random access iterator is supported. +template +static __host__ __device__ __forceinline__ +ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) { + ForwardIt it; + typename std::iterator_traits::difference_type count, step; + // NOTE: std::distance(first, last) compiles but produces wrong results here, + // so only legacy random access iterators are safe in this code. + count = last - first; + + while (count > 0) { + it = first; + step = count / 2; + // avoiding std::advance(it, step), + // although it does work unlike std::distance + it += step; + if (is_lower ? 
*it < value : value >= *it) { + first = ++it; + count -= step + 1; + } + else { + count = step; + } + } + return first; +} + +} + +Tensor index_select_sparse_zoom(const Tensor& self, int64_t dim, const Tensor& index) { + const auto ndim = self.dim(); + TORCH_CHECK_INDEX(ndim, "index_select() cannot be applied to a 0-dim tensor."); + TORCH_CHECK_INDEX( + index.dim() == 1 && index.dtype() == at::kLong && index.options().layout() == at::kStrided, + "index_select() argument index must be 1-D strided (non-sparse) long-tensor."); + dim = maybe_wrap_dim(dim, ndim); + const auto size = self.size(dim); + const auto sparse_dim = self.sparse_dim(); + const auto dense_dim = self.dense_dim(); + const auto indices = self._indices(); + const auto values = self._values(); + const auto nnz = values.size(0); + const auto index_len = index.size(0); + auto res_sizes = self.sizes().vec(); + res_sizes[dim] = index_len; + + // If indexing into sparse dimensions + if (dim < sparse_dim) { + const auto make_output = [ + dim, sparse_dim, dense_dim, res_sizes, &self, &indices, &values + ]( + const Tensor& selected_dim_indices, + const Tensor& res_dim_indices + ) -> Tensor { + auto res_indices = indices.index_select(1, selected_dim_indices); + res_indices[dim] = res_dim_indices; + const auto res_values = values.index_select(0, selected_dim_indices); + + return at::_sparse_coo_tensor_with_dims_and_tensors( + sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options()); + }; + + // short-circuit if index is empty + if (!index_len) { + return make_output(index, index); + } + + const auto nneg_index = [&index, size]() -> Tensor { + auto nneg_index = at::empty_like(index, at::MemoryFormat::Contiguous); + + auto iter = TensorIteratorConfig() + .add_output(nneg_index) + .add_input(index) + .build(); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_sparse_zoom", [&]() { + gpu_kernel(iter, [size] GPU_LAMBDA (index_t idx) -> index_t { + ZOOM_KERNEL_ASSERT(idx >= -size && idx < size + && "index_select(): index out of bounds"); + return idx < 0 ? 
idx + size : idx; + }); + }); + return nneg_index; + }(); + + const auto dim_indices = indices[dim].contiguous(); + const auto idx_nneg_index = at::arange(index_len, nneg_index.options()); + const auto idx_dim_indices = at::arange(nnz, dim_indices.options()); + + Tensor sorted_dim_indices, argsort_dim_indices; + std::tie(sorted_dim_indices, argsort_dim_indices) = [&]() -> std::tuple { + if (dim == 0 && self.is_coalesced()) { + return std::make_tuple(dim_indices, idx_dim_indices); + } + else { + return dim_indices.sort(); + } + }(); + + Tensor intrsc_counts_nneg_index; + Tensor intrsc_first_match_nneg_index; + std::tie(intrsc_counts_nneg_index, intrsc_first_match_nneg_index) = [&]() -> std::tuple { + auto intrsc_counts_nneg_index = at::zeros_like(nneg_index); + auto intrsc_first_match_nneg_index = at::zeros_like(nneg_index); + + auto iter = TensorIteratorConfig() + .add_output(intrsc_first_match_nneg_index) + .add_input(nneg_index) + .add_input(idx_nneg_index) + .build(); + + AT_DISPATCH_INDEX_TYPES(nneg_index.scalar_type(), "index_select_sparse_zoom", [&]() { + index_t* ptr_intrsc_counts_nneg_index = intrsc_counts_nneg_index.mutable_data_ptr(); + const index_t* ptr_sorted_dim_indices = sorted_dim_indices.const_data_ptr(); + gpu_kernel( + iter, + [ptr_intrsc_counts_nneg_index, ptr_sorted_dim_indices, nnz] GPU_LAMBDA ( + index_t idx_val, index_t idx_idx + ) -> index_t { + auto* lb = find_bound( + ptr_sorted_dim_indices, + ptr_sorted_dim_indices + nnz, + idx_val + ); + auto* ub = find_bound( + ptr_sorted_dim_indices, + ptr_sorted_dim_indices + nnz, + idx_val + ); + const auto idx_count = ub - lb; + ptr_intrsc_counts_nneg_index[idx_idx] = idx_count; + + return lb - ptr_sorted_dim_indices; + } + ); + }); + + return std::make_tuple(intrsc_counts_nneg_index, intrsc_first_match_nneg_index); + }(); + + // Unavoidable sync since the shape of the result is not known in advance + auto res_len = intrsc_counts_nneg_index.sum().item(); + // Short-circuit if empty intersection + if (!res_len) { + auto empty_idx = at::empty({0}, nneg_index.options()); + return make_output(empty_idx, empty_idx); + } + + Tensor selected_dim_indices, res_dim_indices; + std::tie(selected_dim_indices, res_dim_indices) = [&]() -> std::tuple { + auto res_dim_indices = at::empty({res_len}, nneg_index.options()); + auto selected_dim_indices = at::empty_like(res_dim_indices); + auto selected_dim_indices_offsets = intrsc_counts_nneg_index.cumsum(0) + .sub_(intrsc_counts_nneg_index); + + // Need to have output as TensorIterator does not allow having void lambdas. + auto dummy_output = at::empty({1}, dim_indices.options()).expand(IntArrayRef({index_len})); + auto iter = TensorIteratorConfig() + .add_output(dummy_output) + // All iterations map to a single element in dummy_output by design, + // hence removed output memory overlap check. 
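+      // Example (hypothetical values, added for illustration): the kernel built on this
+      // iterator expands each query position into its matching nnz slots. With
+      // sorted_dim_indices = [0, 0, 2, 3], argsort_dim_indices = [1, 3, 0, 2] and
+      // nneg_index = [2, 0], the previous step yields intrsc_counts_nneg_index = [1, 2]
+      // and intrsc_first_match_nneg_index = [2, 0], so
+      // selected_dim_indices_offsets = cumsum(counts) - counts = [0, 1]. The loop in the
+      // kernel below then writes res_dim_indices = [0, 1, 1] and
+      // selected_dim_indices = [argsort[2], argsort[0], argsort[1]] = [0, 1, 3].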
+ .set_check_mem_overlap(false) + .add_input(idx_nneg_index) + .add_input(intrsc_counts_nneg_index) + .add_input(selected_dim_indices_offsets) + .add_input(intrsc_first_match_nneg_index) + .build(); + + AT_DISPATCH_INDEX_TYPES(nneg_index.scalar_type(), "index_select_sparse_zoom", [&]() { + index_t* ptr_res_dim_indices = res_dim_indices.mutable_data_ptr(); + index_t* ptr_selected_dim_indices = selected_dim_indices.mutable_data_ptr(); + const index_t* ptr_argsort_dim_indices = argsort_dim_indices.const_data_ptr(); + gpu_kernel( + iter, + [ptr_res_dim_indices, ptr_selected_dim_indices, ptr_argsort_dim_indices] GPU_LAMBDA ( + index_t idx_idx, index_t count, index_t offset, index_t first_match + ) -> index_t { + index_t* __restrict__ ptr_res_dim_indices_out = ptr_res_dim_indices + offset; + const index_t* __restrict__ ptr_argsort_dim_indices_in = ptr_argsort_dim_indices + first_match; + index_t* __restrict__ ptr_selected_dim_indices_out = ptr_selected_dim_indices + offset; + for (index_t i = 0; i < count; ++i) { + *ptr_res_dim_indices_out++ = idx_idx; + *ptr_selected_dim_indices_out++ = *ptr_argsort_dim_indices_in++; + } + + // A dummy return scalar for a dummy output + return static_cast(1); + } + ); + }); + + return std::make_tuple(selected_dim_indices, res_dim_indices); + }(); + + return make_output(selected_dim_indices, res_dim_indices); + } + // If indexing into dense dimensions + else { + // It is sufficient to just perform `index_select` on values + // if `dim` refers to dense dimensions. + const auto res_values = values.index_select(dim - sparse_dim + 1, index); + + return _sparse_coo_tensor_with_dims_and_tensors( + sparse_dim, dense_dim, res_sizes, indices, res_values, self.options()); + } +} + +} // at::native diff --git a/aten/src/ATen/native/zoom/KernelUtils.cuh b/aten/src/ATen/native/zoom/KernelUtils.cuh new file mode 100644 index 00000000000000..99c66efe21ffc4 --- /dev/null +++ b/aten/src/ATen/native/zoom/KernelUtils.cuh @@ -0,0 +1,97 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once +#include + +namespace at { +namespace native { + +__device__ __forceinline__ size_t +idx(const size_t nc, + const size_t height, + const size_t width, + const size_t h, + const size_t w) { + return (nc * height + h) * width + w; +} + +// for channels-last +__device__ __forceinline__ size_t +idx_cl( + const size_t n, const size_t h, const size_t w, const size_t c, + const size_t height, const size_t width, const size_t channel +) { + return ((n * height + h) * width + w) * channel + c; +} + +// fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization +// that speed up half-precision atomics. The situation with half +// precision atomics is that we have a slow __half atomic, and +// a fast vectored __half2 atomic (this can be worth up to a 6x +// speedup, see https://github.com/pytorch/pytorch/pull/21879). +// We can convert a __half atomic into a __half2 atomic by simply +// pairing the __half with a zero entry on the left/right depending +// on alignment... but only if this wouldn't cause an out of bounds +// access! Thus, you must specify tensor and numel so we can check +// if you would be out-of-bounds and use a plain __half atomic if +// you would be. 
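+//
+// Illustrative usage (a sketch, not part of this header): a scatter-add style kernel
+// would pass the base pointer, the flattened element index, the total number of
+// elements, and a runtime flag selecting the specialized path, e.g.
+//
+//   template <typename scalar_t, typename index_t>
+//   __global__ void scatter_add_sketch(scalar_t* out, const index_t* idx,
+//                                      const scalar_t* src, index_t n, index_t out_numel) {
+//     index_t i = blockIdx.x * blockDim.x + threadIdx.x;
+//     if (i < n) {
+//       at::native::fastAtomicAdd(out, idx[i], out_numel, src[i], /*fast_atomics=*/true);
+//     }
+//   }
+//
+// `scatter_add_sketch` and its parameter names are hypothetical; only fastAtomicAdd's
+// signature below comes from this file.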
+template <
+    typename scalar_t,
+    typename index_t,
+    typename std::enable_if<std::is_same<c10::Half, scalar_t>::value>::type* =
+        nullptr>
+__device__ __forceinline__ void fastSpecializedAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value) {
+  gpuAtomicAddNoReturn(
+      reinterpret_cast<at::Half*>(tensor) + index,
+      static_cast<at::Half>(value));
+}
+
+template <
+    typename scalar_t,
+    typename index_t,
+    typename std::enable_if<std::is_same<c10::BFloat16, scalar_t>::value>::type* =
+        nullptr>
+__device__ __forceinline__ void fastSpecializedAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value) {
+  gpuAtomicAddNoReturn(
+      reinterpret_cast<at::BFloat16*>(tensor) + index,
+      static_cast<at::BFloat16>(value));
+}
+
+template <
+    typename scalar_t,
+    typename index_t,
+    typename std::enable_if<
+        !std::is_same<c10::Half, scalar_t>::value &&
+        !std::is_same<c10::BFloat16, scalar_t>::value>::type* =
+        nullptr>
+__device__ __forceinline__ void fastSpecializedAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value) {
+  gpuAtomicAddNoReturn(tensor + index, value);
+}
+
+template <class scalar_t, class index_t>
+__device__ __forceinline__ void fastAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value,
+    bool fast_atomics) {
+  if (fast_atomics) {
+    fastSpecializedAtomicAdd(tensor, index, numel, value);
+  } else {
+    gpuAtomicAddNoReturn(tensor + index, value);
+  }
+}
+
+} // namespace native
+} // namespace at
\ No newline at end of file
diff --git a/aten/src/ATen/native/zoom/LegacyThrustHelpers.cu b/aten/src/ATen/native/zoom/LegacyThrustHelpers.cu
new file mode 100644
index 00000000000000..6379b68b9479b4
--- /dev/null
+++ b/aten/src/ATen/native/zoom/LegacyThrustHelpers.cu
@@ -0,0 +1,113 @@
+// !!! This is a file automatically generated by hipify!!!
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include
+#include
+#include
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include
+#else
+#include
+#endif
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace at::native {
+
+void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
+  sorted_indices.copy_(linearIndex);
+  const hipStream_t stream = c10::zoom::getCurrentZoomStream();
+  at::zoom::ThrustAllocator allocator;
+  auto policy = thrust::hip::par(allocator).on(stream);
+
+  using device_ptr = thrust::device_ptr<int64_t>;
+
+  // Fill sortedOrigIndices with sequential indices
+  const auto count_iter = thrust::counting_iterator<int64_t>(0);
+  auto orig_data = device_ptr(orig_indices.mutable_data_ptr<int64_t>());
+  thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
+
+  // Sort the inputs into sorted with the corresponding indices; a stable or
+  // multidimensional sort is not required, so just use Thrust directly.
+  // NB - not passing a comparator causes thrust to use radix sort, and it hurts perf
+  // A LOT, at least for medium (few K) sized indices
+  auto sorted_data = device_ptr(sorted_indices.mutable_data_ptr<int64_t>());
+  thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
+}
+
+#if !CUB_SUPPORTS_SCAN_BY_KEY()
+
+template <typename index_t>
+void embedding_dense_backward_zoom_scan(Tensor &sorted_indices, Tensor &count) {
+  hipStream_t stream = c10::zoom::getCurrentZoomStream();
+  at::zoom::ThrustAllocator allocator;
+  auto policy = thrust::hip::par(allocator).on(stream);
+
+  auto num_indices = count.numel();
+
+  // Compute an increasing sequence per unique item in sortedIndices:
+  // sorted: 2 5 5 5 7 7 8 9 9
+  // count:  1 1 2 3 1 2 1 1 2
+  auto sorted_data =
thrust::device_ptr(sorted_indices.const_data_ptr()); + auto count_data = thrust::device_ptr(count.mutable_data_ptr()); + thrust::inclusive_scan_by_key( + policy, + sorted_data, + sorted_data + num_indices, + thrust::make_constant_iterator(1), + count_data + ); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + thrust::inclusive_scan_by_key( + policy, + thrust::make_reverse_iterator(sorted_data + num_indices), + thrust::make_reverse_iterator(sorted_data), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::equal_to(), + thrust::maximum() + ); +} + +template +void embedding_dense_backward_zoom_scan(Tensor &sorted_indices, Tensor &count); +template +void embedding_dense_backward_zoom_scan(Tensor &sorted_indices, Tensor &count); + +#endif + +template +int64_t embedding_backward_zoom_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) { + auto stream = c10::zoom::getCurrentZoomStream(); + at::zoom::ThrustAllocator allocator; + auto policy = thrust::hip::par(allocator).on(stream); + const ptrdiff_t numel = sorted_indices.numel(); + auto sorted_indices_dev = thrust::device_ptr(sorted_indices.const_data_ptr()); + auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto dummy_dev = thrust::device_ptr(dummy.mutable_data_ptr()); + auto ends = thrust::unique_by_key_copy( + policy, + sorted_indices_dev, + sorted_indices_dev + numel, + thrust::make_counting_iterator(0), + dummy_dev, + thrust::device_ptr(segment_offsets.mutable_data_ptr())); + return thrust::get<0>(ends) - dummy_dev; +} + +template +int64_t embedding_backward_zoom_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets); +template +int64_t embedding_backward_zoom_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/LogcumsumexpKernel.cu b/aten/src/ATen/native/zoom/LogcumsumexpKernel.cu new file mode 100644 index 00000000000000..13f8c9af5af6a4 --- /dev/null +++ b/aten/src/ATen/native/zoom/LogcumsumexpKernel.cu @@ -0,0 +1,124 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +#include +#include + +#include +#include + +namespace at::native { + +// custom min and max to be used in logcumsumexp for complex arguments +template +__host__ __device__ c10::complex _logcumsumexp_minmax(const c10::complex& x, const c10::complex& y) { + scalar_t xr = std::real(x); + scalar_t yr = std::real(y); + if (::isnan(yr) || (::isnan(std::imag(y)))) { + return y; + } else if (::isnan(xr) || (::isnan(std::imag(x)))) { + return x; + } else if (min) { // min + return (xr < yr) ? x : y; + } else { // max + return (xr >= yr) ? x : y; + } +} + +template +__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) { + // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp + // Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM + auto isnan_x = at::_isnan(x); + auto isnan_y = at::_isnan(y); + scalar_t min = isnan_y ? y : (isnan_x ? x : std::min(x, y)); + scalar_t max = isnan_y ? y : (isnan_x ? 
x : std::max(x, y));
+  if (min != max || ::isfinite(min)) {
+    // nan will be propagated here
+    return ::log1p(std::exp(min - max)) + max;
+  } else {
+    // special case to correctly handle infinite cases
+    return x;
+  }
+}
+
+template <typename scalar_t>
+__host__ __device__ c10::complex<scalar_t> _fast_build_exp(const c10::complex<scalar_t>& x) {
+  // complex exponential function, but implemented manually to get fast compilation time
+  // this function only handles the case where the x is finite (not inf nor nan)
+  auto xreal = std::real(x);
+  auto ximag = std::imag(x);
+  auto exp_x_abs = std::exp(xreal);
+  auto exp_x_real = exp_x_abs * std::cos(ximag);
+  auto exp_x_imag = exp_x_abs * std::sin(ximag);
+  return {exp_x_real, exp_x_imag};
+}
+
+template <typename scalar_t>
+__host__ __device__ c10::complex<scalar_t> _fast_build_exp_inf(const c10::complex<scalar_t>& x) {
+  // complex exponential function, but implemented manually to get fast compilation time
+  // this function only handles the case where the real part of x is infinite
+  auto ximag = std::imag(x);
+  auto exp_x_abs = std::numeric_limits<scalar_t>::infinity();
+  auto sin = std::sin(ximag);
+  auto cos = std::cos(ximag);
+  // special case if the angle is exactly the multiple of pi/2
+  auto exp_x_real = (cos == 0) ? (scalar_t)0.0 : exp_x_abs * cos;
+  auto exp_x_imag = (sin == 0) ? (scalar_t)0.0 : exp_x_abs * sin;
+  return {exp_x_real, exp_x_imag};
+}
+
+template <typename scalar_t>
+__host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
+  c10::complex<scalar_t> min = _logcumsumexp_minmax<true>(x, y);
+  c10::complex<scalar_t> max = _logcumsumexp_minmax<false>(x, y);
+  scalar_t min_real = std::real(min);
+  scalar_t max_real = std::real(max);
+
+  if (::isnan(min_real) || ::isnan(std::imag(min))) {
+    // handling the "infectious" NaNs
+    return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
+  }
+  else if ((!::isfinite(min_real)) && (min_real == max_real)) {
+    if (min_real < 0) {
+      // handle the -inf case, the imaginary part here does not really matter as the exp(value)
+      // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
+ // It does not matter if we're taking the exp of this value + return min; + } else { + // handle the +inf case, we don't need the special precision for log1p for small values + // and to avoid producing nan in case of real(max) == real(min) == +inf + auto exp_min = _fast_build_exp_inf(min); + auto exp_max = _fast_build_exp_inf(max); + return ::log1p(exp_min + exp_max - 1); // log1p(x - 1) builds faster than log + } + } else { + auto minmax = min - max; + auto exp_minmax = _fast_build_exp(minmax); + return ::log1p(exp_minmax) + max; + } +} + +void launch_logcumsumexp_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) { +// Compile time for CUDA-11.4 is 3x slower than with CUDA-11.6+, specifically for complex numbers +#if defined(FBCODE_CAFFE2) || defined(OVRSOURCE) +#define _LCME_DISPATCH AT_DISPATCH_FLOATING_TYPES_AND2 +#else +#define _LCME_DISPATCH AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 +#endif + _LCME_DISPATCH(ScalarType::Half, ScalarType::BFloat16, + self.scalar_type(), "logcumsumexp_zoom", + [&]() { + using opmath_t = at::opmath_type; + scalar_t init = -std::numeric_limits::infinity(); + auto log_add_exp = [] C10_HOST_DEVICE (const scalar_t x_, const scalar_t y_) -> scalar_t { + const opmath_t x{x_}, y{y_}; + return _log_add_exp_helper(x, y); + }; + scan_dim(self, result, dim, init, log_add_exp); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Math.cuh b/aten/src/ATen/native/zoom/Math.cuh new file mode 100644 index 00000000000000..c7085693ee2240 --- /dev/null +++ b/aten/src/ATen/native/zoom/Math.cuh @@ -0,0 +1,3026 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace native { +// See note [Jiterator] +// TODO: elaborate in this comment on the structure of math.cuh + +const auto ndtri_string = jiterator_stringify( + /* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates polynomial of degree N: + * + * 2 N + * y = C + C x + C x +...+ C x + * 0 1 2 N + * + * Coefficients are stored in reverse order: + * + * coef[0] = C , ..., coef[N] = C . + * N 0 + */ + template + T polevl(const T x, const T A[], const int len) { + // NOTE: This `polevl` is different from other `polevl` + // implementation (in PyTorch) which expect the `len` to be + // `len(A) - 1` instead of `len(A)`. + T result = 0; + for (int i = 0; i < len; ++i) { + result = result * x + A[i]; + } + return result; + } + + /* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes the argument, x, for which the area under the Gaussian probability density function + * (integrated from minus infinity to x) is equal to y. + */ + template + T ndtri(T y0) { + + constexpr T zero = 0; + constexpr T one = 1; + + // Handles special cases + if (y0 == zero) { + return NEG_INFINITY; + } + if (y0 == one) { + return POS_INFINITY; + } + if (y0 < zero || y0 > one) { + return NAN; + } + + bool code = true; + T y = y0; + // Note: the constant 0.135... 
is equal to exp(-2) + if (y > one - T{0.13533528323661269189}) { + y = one - y; + code = false; + } + + if (y > T{0.13533528323661269189}) { + /* approximation for 0 <= |y - 0.5| <= 3/8 */ + static const T P0[5] = { + -5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0, + }; + + static const T Q0[9] = { + 1.00000000000000000000E0, + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0, + }; + + /* sqrt(2pi) */ + constexpr T s2pi = 2.50662827463100050242E0; + + y = y - T{0.5}; + const T y2 = y * y; + T x = y + y * (y2 * polevl(y2, P0, int{5}) / polevl(y2, Q0, int{9})); + return x * s2pi; + } + + T x = sqrt(T{-2.} * log(y)); + const T x0 = x - (log(x) / x); + + const T z = one / x; + T x1; + + /* y > exp(-32) = 1.2664165549e-14 */ + if (x < T{8.0}) { + /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8 + * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14. + */ + static const T P1[9] = { + 4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4, + }; + + static const T Q1[9] = { + 1.00000000000000000000E0, + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4, + }; + + x1 = z * polevl(z, P1, int{9}) / polevl(z, Q1, int{9}); + } else { + /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64 + * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890. + */ + static const T P2[9] = { + 3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9, + }; + + static const T Q2[9] = { + 1.00000000000000000000E0, + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9, + }; + + x1 = z * polevl(z, P2, int{9}) / polevl(z, Q2, int{9}); + } + + x = x0 - x1; + return (!code) ? x : -x; + } +); // ndtri_string + +const auto log_ndtr_string = jiterator_stringify( + template + T log_ndtr(T x) { + constexpr T SQRT1_2{0.707106781186547524400844362104849039}; // 1/sqrt(2) + T t = x * SQRT1_2; + if (x < T{-1.0}) { + return log(erfcx(-t) / 2) - t * t; + } else { + return log1p(-erfc(t) / 2); + } + } +); // log_ndtr_string + +const auto gcd_string = jiterator_stringify( + template + T gcd(const T a_in, const T b_in) { + T a = abs(a_in); + T b = abs(b_in); + + while (a != T{0}) { + T c = a; + a = b % a; + b = c; + } + + return b; + } +); // gcd_string + +const auto lcm_string = jiterator_stringify( + template + T gcd(const T a_in, const T b_in) { + T a = abs(a_in); + T b = abs(b_in); + + while (a != T{0}) { + T c = a; + a = b % a; + b = c; + } + + return b; + } + + template + T lcm(const T a, const T b) { + T g = gcd(a, b); + return (g == T{0}) ? 
T{0} : abs(a / g * b); + } +); // lcm_string + +/* + * For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h". + */ +// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma +const auto digamma_string = jiterator_stringify( + template + T digamma(T x) { + static const double PI_f64 = 3.14159265358979323846; + + // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard + if (x == 0) { + return copysign(POS_INFINITY, -x); + } + + T result = 0; + if (x < 0) { + // Short-circuits if x is a negative integer and returns NaN + // per the C++ standard + const bool x_is_integer = (x == trunc(x)); + if (x_is_integer) { + return NAN; + } + + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = modf(static_cast(x), &q); + result = - PI_f64 / tan(PI_f64 * r); + x = 1 - x; + } + + while (x < T{10}) { + result -= T{1} / x; + x += T{1}; + } + + if (x == T{10}) { + return result + T{2.25175258906672110764}; + } + + T y = 0; + if (x < T{1.0e17}) { + const T A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + + T z = T{1} / (x * x); + + T polevl_result = 0; + for (int i = 0; i <= 6; i++) { + polevl_result = polevl_result * z + A[i]; + } + y = z * polevl_result; + } + + return log(x) - (T{0.5} / x) - y + result; + } +); // digamma_string + +/* + * This function is derived from the implementation of the zeta function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. 
+ */ +const auto zeta_string = jiterator_stringify( + template + T zeta(T x, T q) { + const T MACHEP{1.11022302462515654042E-16}; + constexpr T zero{0}; + constexpr T half{0.5}; + constexpr T one{1}; + static const T A[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, /*1.307674368e12/691*/ + 7.47242496e10, + -2.950130727918164224e12, /*1.067062284288e16/3617*/ + 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ + -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ + 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ + -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ + }; + + int i = 0; + T a, b, k, s, t, w; + + // Short-circuits x -> +infty + if (x == one) { + return POS_INFINITY; + } + + // Short-circuits x < 1 -> NaN + if (x < one) { + return NAN; + } + + // Short-circuits negative q integers map to +infty, + // negative q non-integers map to NaN + if (q <= zero) { + if (q == floor(q)) { + return POS_INFINITY; + } + if (x != floor(x)) { + return NAN; + } + } + + s = pow(q, -x); + a = q; + i = 0; + b = zero; + while ((i < 9) || (a <= T{9.0})) { + i += 1; + a += one; + b = pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return s; + } + }; + + w = a; + s += b * w / (x - one); + s -= half * b; + a = one; + k = zero; + for (int i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / A[i]; + s = s + t; + t = fabs(t / s); + + if (t < MACHEP) { + return s; + } + + k += one; + a *= x + k; + b /= w; + k += one; + } + + return s; + } +); // zeta_string + +const auto trigamma_string = jiterator_stringify( + template + T trigamma(T x) { + const T PI{3.14159265358979323846}; + T sign = 1; + T result = 0; + + if (x < T{0.5}) { + sign = -1; + T sin_pi_x = sin(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + + for (int i = 0; i < 6; ++i) { + result += T{1} / (x * x); + x += 1; + } + + const T one{1}; + const T ixx = one / (x*x); + result += (one + one / (T{2}*x) + ixx * (one/T{6} - ixx * (one/T{30} - ixx * (one/T{42})))) / x; + return sign * result; +} +); // trigamma_string + +const auto lgamma_string = jiterator_stringify( + template + T lgamma_kernel(T a) { + return lgamma(a); + } +); // lgamma_string + +const auto polygamma_string = zeta_string + jiterator_stringify( + template + T polygamma(T x, int n) { + // already blocked if n <= 1 + const auto one = T{1}; + return ((n % 2) ? 
one : -one) * exp(lgamma(static_cast(n) + one)) * + zeta(static_cast(n + 1), x); + } +); // polygamma_string + +const auto exp2_string = jiterator_stringify( + template + T exp2_impl(T a) { + return exp2(a); + } + + namespace std { template class complex; } + template + std::complex exp2_impl(std::complex x) { + // There is no std::exp2 overload for complex, so instead + // use the identity 2^x = e^(ln(2) * x) + const auto ln_2 = static_cast(0.693147180559945309417232121458176); + return exp(ln_2 * x); + } + + template + T exp2_kernel(T a) { + return exp2_impl(a); + } +); // exp2_string + +const auto erfc_string = jiterator_stringify( + template + T erfc_kernel(T a) { + return erfc(a); + } +); // erfc_string + +const auto erfinv_string = jiterator_stringify( + template + T erfinv_kernel(T a) { + return erfinv(a); + } +); // erfinv_string + +const auto entr_string = jiterator_stringify( + template + T entr(T a) { + if (a != a) { + return a; + } + + if (a > 0) { + return -a * log(a); + } + + if (a == 0) { + return 0; + } + + return NEG_INFINITY; + } +); // entr_string + +// NOTE: `kaiser_window_string` depends on `i0_string` +// for its implementation. +const auto i0_string = jiterator_stringify( + template + T chbevl(T x, const T array[], const int len) { + + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + T i0(T _x) { + T x = fabs(_x); + + if (x <= T{8.0}) { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + static const T A[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + T y = (x / T{2.0}) - T{2.0}; + return exp(x) * chbevl(y, A, int{30}); + } + + // Handles x > 8 case + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 
+ */ + const T B[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return (exp(x) * chbevl(T{32.0} / x - T{2.0}, B, int{25})) / sqrt(x); + } +); // i0_string + +const auto i1_string = jiterator_stringify( + template + T chbevl(const T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + T i1(T _x) { + const T x = fabs(_x); + + if (x <= T{8.0}) { + // Chebyshev coefficients for exp(-x) i1(x) in the internal [0, 8] + // lim(x->0){ exp(-x) i1(x) / x } = 1/2 + static const T coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + const T y = x / T{2.0} - T{2.0}; + const T out = exp(x) * x * chbevl(y, coefficients, int{29}); + return (_x < T{0.0}) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval [8, infinity] + // lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi) + static const T coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + const T out = (exp(x) * chbevl(T{32.} / x - T{2.}, coefficients, int{25})) / sqrt(x); + return (_x < T{0.}) ? 
-out : out; + } +); // i1_string + +const auto i1e_string = jiterator_stringify( + template + T chbevl(const T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + // See double and float instantiations below + template + T i1e(T _x) { } + + // Double specialization (uses different coefficients than the float version) + template<> + double i1e(double _x) { + const double x = fabs(_x); + if (x <= double{8.}) { + // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8]. + // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2. + static const double coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + const double y = x / double{2.} - double{2.}; + const double out = chbevl(y, coefficients, int{29}) * x; + return (_x < 0.) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval (8, infinity]. + // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi). + // TODO: what's an "inverted interval"? Open on the left + // and closed on the right? + static const double coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + const double out = chbevl(double{32.} / x - double{2.}, coefficients, int{25}) / sqrt(x); + return (_x < double{0.}) ? -out : out; + } + + // Float specialization (uses different coefficients than the double version) + template<> + float i1e(float _x) { + const float x = fabsf(_x); + if (x <= float{8.}) { + // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8]. + // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2. 
+ static const float coefficients[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + const float y = x / float{2.} - float{2.}; + const float out = chbevl(y, coefficients, int{17}) * x; + return (_x < 0.) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval (8, infinity]. + // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi). + // TODO: what's an "inverted interval"? Open on the left + // and closed on the right? + static const float coefficients[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + const float out = chbevl(float{32.} / x - float{2.}, coefficients, int{7}) / sqrt(x); + return (_x < float{0.}) ? -out : out; + } +); // i1e_string + +const auto kaiser_window_string = i0_string + jiterator_stringify( + template + T kaiser_window(T a, T inv_alpha, T beta, T inv_i0_beta) { + T x = a * inv_alpha - T{1}; + T y = max(T{0}, T{1} - x * x); + return i0(beta * sqrt(y)) * inv_i0_beta; + } +); // kaiser_window_string + +const auto sinc_string = jiterator_stringify( + template + T sinc(T a) { + if (a == T(0)) { + return T(1); + } else { + constexpr T pi = T(3.14159265358979323846L); + T product = pi * a; + return std::sin(product) / product; + } + } +); // sinc_string + +const auto erfcx_string = jiterator_stringify( + /* The next function is taken from http://ab-initio.mit.edu/Faddeev */ + + /* Copyright (c) 2012 Massachusetts Institute of Technology + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + /* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by + Steven G. Johnson, October 2012. + + This function combines a few different ideas. + + First, for x > 50, it uses a continued-fraction expansion (same as + for the Faddeeva function, but with algebraic simplifications for z=i*x). 
+ + Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations, + but with two twists: + + a) It maps x to y = 4 / (4+x) in [0,1]. This simple transformation, + inspired by a similar transformation in the octave-forge/specfun + erfcx by Soren Hauberg, results in much faster Chebyshev convergence + than other simple transformations I have examined. + + b) Instead of using a single Chebyshev polynomial for the entire + [0,1] y interval, we break the interval up into 100 equal + subintervals, with a switch/lookup table, and use much lower + degree Chebyshev polynomials in each subinterval. This greatly + improves performance in my tests. + + For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), + with the usual checks for overflow etcetera. + + Performance-wise, it seems to be substantially faster than either + the SLATEC DERFC function [or an erfcx function derived therefrom] + or Cody's CALERF function (from netlib.org/specfun), while + retaining near machine precision in accuracy. + */ + + /* Given y100 = 100 * y, where y = 4 / (4 + x) for x >= 0, compute erfc(x). + + Uses a look-up table of 100 different Chebyshev polynomials + for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated + with the help of Maple and a little shell script. This allows + the Chebyshev polynomials to be of significantly lower degree (about 1/4) + compared to fitting the whole [0,1] interval with a single polynomial. + */ + + // TODO: review if this is computing in double when given a float input + template + T erfcx_y100(T y100) { + switch (static_cast(y100)) { + case 0: { + T t = 2*y100 - 1; + return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t; + } + case 1: { + T t = 2*y100 - 3; + return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 2: { + T t = 2*y100 - 5; + return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 3: { + T t = 2*y100 - 7; + return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t; + } + case 4: { + T t = 2*y100 - 9; + return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t; + } + case 5: { + T t = 2*y100 - 11; + return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t; + } + case 6: { + T t = 2*y100 - 13; + return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 7: { + T 
t = 2*y100 - 15; + return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t; + } + case 8: { + T t = 2*y100 - 17; + return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 9: { + T t = 2*y100 - 19; + return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 10: { + T t = 2*y100 - 21; + return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t; + } + case 11: { + T t = 2*y100 - 23; + return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 12: { + T t = 2*y100 - 25; + return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 13: { + T t = 2*y100 - 27; + return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 14: { + T t = 2*y100 - 29; + return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 15: { + T t = 2*y100 - 31; + return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 16: { + T t = 2*y100 - 33; + return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 17: { + T t = 2*y100 - 35; + return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t; + } + case 18: { + T t = 2*y100 - 37; + return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t; + } + case 19: { + T t = 2*y100 - 39; + return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 
0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 20: { + T t = 2*y100 - 41; + return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 21: { + T t = 2*y100 - 43; + return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 22: { + T t = 2*y100 - 45; + return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 23: { + T t = 2*y100 - 47; + return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t; + } + case 24: { + T t = 2*y100 - 49; + return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 25: { + T t = 2*y100 - 51; + return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 26: { + T t = 2*y100 - 53; + return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 27: { + T t = 2*y100 - 55; + return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 28: { + T t = 2*y100 - 57; + return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 29: { + T t = 2*y100 - 59; + return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 30: { + T t = 2*y100 - 61; + return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 31: { + T t = 2*y100 - 63; + return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 32: { + T t = 2*y100 - 65; + return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + 
(0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 33: { + T t = 2*y100 - 67; + return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 34: { + T t = 2*y100 - 69; + return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t; + } + case 35: { + T t = 2*y100 - 71; + return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 36: { + T t = 2*y100 - 73; + return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 37: { + T t = 2*y100 - 75; + return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 38: { + T t = 2*y100 - 77; + return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 39: { + T t = 2*y100 - 79; + return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t; + } + case 40: { + T t = 2*y100 - 81; + return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 41: { + T t = 2*y100 - 83; + return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 42: { + T t = 2*y100 - 85; + return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 43: { + T t = 2*y100 - 87; + return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 44: { + T t = 2*y100 - 89; + return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t; + } + case 45: { + T t = 2*y100 - 91; + return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + 
(0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 46: { + T t = 2*y100 - 93; + return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 47: { + T t = 2*y100 - 95; + return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 48: { + T t = 2*y100 - 97; + return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 49: { + T t = 2*y100 - 99; + return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 50: { + T t = 2*y100 - 101; + return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t; + } + case 51: { + T t = 2*y100 - 103; + return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 52: { + T t = 2*y100 - 105; + return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t; + } + case 53: { + T t = 2*y100 - 107; + return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t; + } + case 54: { + T t = 2*y100 - 109; + return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 55: { + T t = 2*y100 - 111; + return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 56: { + T t = 2*y100 - 113; + return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 57: { + T t = 2*y100 - 115; + return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 58: { + T t = 2*y100 - 117; + return 
0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 59: { + T t = 2*y100 - 119; + return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t; + } + case 60: { + T t = 2*y100 - 121; + return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t; + } + case 61: { + T t = 2*y100 - 123; + return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 62: { + T t = 2*y100 - 125; + return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 63: { + T t = 2*y100 - 127; + return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 64: { + T t = 2*y100 - 129; + return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 65: { + T t = 2*y100 - 131; + return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t; + } + case 66: { + T t = 2*y100 - 133; + return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 67: { + T t = 2*y100 - 135; + return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t; + } + case 68: { + T t = 2*y100 - 137; + return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t; + } + case 69: { + T t = 2*y100 - 139; + return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 70: { + T t = 2*y100 - 141; + return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * 
t) * t) * t; + } + case 71: { + T t = 2*y100 - 143; + return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 72: { + T t = 2*y100 - 145; + return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 73: { + T t = 2*y100 - 147; + return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t; + } + case 74: { + T t = 2*y100 - 149; + return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 75: { + T t = 2*y100 - 151; + return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t; + } + case 76: { + T t = 2*y100 - 153; + return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 77: { + T t = 2*y100 - 155; + return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 78: { + T t = 2*y100 - 157; + return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 79: { + T t = 2*y100 - 159; + return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 80: { + T t = 2*y100 - 161; + return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 81: { + T t = 2*y100 - 163; + return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 82: { + T t = 2*y100 - 165; + return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 83: { + T t = 2*y100 - 167; + return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + 
(0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 84: { + T t = 2*y100 - 169; + return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 85: { + T t = 2*y100 - 171; + return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 86: { + T t = 2*y100 - 173; + return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t; + } + case 87: { + T t = 2*y100 - 175; + return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 88: { + T t = 2*y100 - 177; + return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t; + } + case 89: { + T t = 2*y100 - 179; + return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 90: { + T t = 2*y100 - 181; + return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 91: { + T t = 2*y100 - 183; + return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 92: { + T t = 2*y100 - 185; + return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t; + } + case 93: { + T t = 2*y100 - 187; + return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 94: { + T t = 2*y100 - 189; + return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 95: { + T t = 2*y100 - 191; + return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 96: { + T t = 2*y100 - 193; + return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 
+ (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 97: { + T t = 2*y100 - 195; + return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t; + } + case 98: { + T t = 2*y100 - 197; + return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 99: { + T t = 2*y100 - 199; + return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t; + } + } + + // we only get here if y = 1, i.e. |x| < 4*eps, in which case + // erfcx is within 1e-15 of 1.. + return 1.; + } + + template + T erfcx(T x) { + // Short-circuits on NaN (returning NaN) + if (x != x) { + return x; + } + + if (x >= 0) { + if (x > T{50}) { // continued-fraction expansion is faster + const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi) + + if (x > T{5e7}) { // 1-term expansion, important to avoid overflow + return ispi / x; + } + + /* 5-term expansion (rely on compiler for CSE), simplified from: + ispi / (x+0.5/(x+1/(x+1.5/(x+2/x)))) */ + return ispi * ((x*x) * (x*x+T{4.5}) + T{2}) / (x * ((x*x) * (x*x+T{5}) + T{3.75})); + } + + // x >= 0 x <= 50 + return erfcx_y100(T{400} / (T{4} + x)); + } + + // x < 0 + if (x < T{-26.7}) { + return POS_INFINITY; + } else if (x < T{-6.1}) { + return T{2} * exp(x * x); + } + + // x < 0 and x >= -6.1 + return T{2} * exp(x * x) - erfcx_y100(T{400} / (T{4} - x)); + } +); // erfcx_string + +const auto airy_ai_string = jiterator_stringify( + template + T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + +7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + 
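As a quick cross-check of the large-argument branch of erfcx above, the 5-term rational form in the kernel string is an algebraic simplification of the nested continued fraction mentioned in its comment. The following minimal host-side sketch (not part of the patch; function names are illustrative) evaluates both forms side by side so the simplification can be verified numerically:

#include <cstdio>

double erfcx_cf_simplified(double x) {
  const double ispi = 0.56418958354775628694807945156;  // 1 / sqrt(pi)
  if (x > 5e7) return ispi / x;                          // 1-term expansion for very large x
  // 5-term expansion, same form as in the jiterator string above
  return ispi * ((x * x) * (x * x + 4.5) + 2.0) /
         (x * ((x * x) * (x * x + 5.0) + 3.75));
}

double erfcx_cf_nested(double x) {
  // The raw continued fraction the simplified form was derived from
  const double ispi = 0.56418958354775628694807945156;
  return ispi / (x + 0.5 / (x + 1.0 / (x + 1.5 / (x + 2.0 / x))));
}

int main() {
  for (double x : {60.0, 1e3, 1e6}) {
    std::printf("x=%-8g simplified=%.17g nested=%.17g\n",
                x, erfcx_cf_simplified(x), erfcx_cf_nested(x));
  }
  return 0;
}

Both functions compute the same rational expression, so the printed values should match to the last digit; this is only a sanity check of the algebra, not of the kernel itself.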
+8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (isinf(x)) { + return NAN; + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd = 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * sqrt(-x) / T(3.0) + T(0.25) * T(3.14159265358979323846); + + return T(5.64189583547756286948e-01) / sqrt(sqrt(-x)) * (sin(t) * (T(1.0) + z * z * afn / afd) - cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * sqrt(sqrt(x)) * exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > T(1.11022302462515654042e-16)) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; + } // T airy_ai(T x) +); // airy_ai_string + +const auto bessel_j0_string = jiterator_stringify( + template + T bessel_j0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T RP[] = { + -4.79443220978201773821e+09, + +1.95617491946556577543e+12, + -2.49248344360967716204e+14, + +9.70862251047306323952e+15, + }; + + static const T RQ[] = { + +4.99563147152651017219e+02, + 
+1.73785401676374683123e+05, + +4.84409658339962045305e+07, + +1.11855537045356834862e+10, + +2.11277520115489217587e+12, + +3.10518229857422583814e+14, + +3.18121955943204943306e+16, + +1.71086294081043136091e+18, + }; + + if (x < T(0)) { + x = -x; + } + + if (x <= T(5.0)) { + if (x < T(0.00001)) { + return T(1.0) - x * x / T(4.0); + } + + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq; + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_j0_forward(T x) +); // bessel_j0_string + +const auto bessel_y0_string = bessel_j0_string + jiterator_stringify( + template + T bessel_y0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T YP[] = { + +1.55924367855235737965e+04, + -1.46639295903971606143e+07, + +5.43526477051876500413e+09, + -9.82136065717911466409e+11, + +8.75906394395366999549e+13, + -3.46628303384729719441e+15, + +4.42733268572569800351e+16, + -1.84950800436986690637e+16, + }; + + static const T YQ[] = { + +1.04128353664259848412e+03, + +6.26107330137134956842e+05, + +2.68919633393814121987e+08, + +8.64002487103935000337e+10, + +2.02979612750105546709e+13, + +3.17157752842975028269e+15, + +2.50596256172653059228e+17, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return NEG_INFINITY; + } + + if (x < T(0.0)) { + NAN; + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return yp / yq + (T(0.636619772367581343075535053490057448) * log(x) * bessel_j0_forward(x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) 
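The large-argument branch of bessel_j0_forward above is the usual Hankel-type form, sqrt(2/(pi*x)) times a phase/amplitude combination of cos(x - pi/4) and sin(x - pi/4) with small rational corrections (the PP/PQ and QP/QQ polynomials). A tiny host sketch of just the leading term, compared against a library Bessel routine, illustrates the structure; std::cyl_bessel_j is C++17 and may be missing on some standard libraries (POSIX j0() is an alternative), so treat this as an assumption of the sketch:

#include <cmath>
#include <cstdio>

int main() {
  const double pi = 3.14159265358979323846;
  for (double x : {10.0, 50.0, 200.0}) {
    // Leading term of the large-x asymptotic expansion of J0
    double leading = std::sqrt(2.0 / (pi * x)) * std::cos(x - pi / 4.0);
    double ref = std::cyl_bessel_j(0.0, x);  // C++17 special math; see note above
    std::printf("x=%6g leading-term=% .10f cyl_bessel_j=% .10f\n", x, leading, ref);
  }
  return 0;
}

The agreement improves as x grows; the rational correction factors in the kernel string are what close the remaining gap.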
/ (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_y0_forward(T x) +); // bessel_y0_string + +const auto bessel_j1_string = jiterator_stringify( + template + T bessel_j1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T RP[] = { + -8.99971225705559398224e+08, + +4.52228297998194034323e+11, + -7.27494245221818276015e+13, + +3.68295732863852883286e+15, + }; + + static const T RQ[] = { + +6.20836478118054335476e+02, + +2.56987256757748830383e+05, + +8.35146791431949253037e+07, + +2.21511595479792499675e+10, + +4.74914122079991414898e+12, + +7.84369607876235854894e+14, + +8.95222336184627338078e+16, + +5.32278620332680085395e+18, + }; + + if (x < T(0.0)) { + return -bessel_j1_forward(-x); + } + + if (x <= T(5.0)) { + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_j1_forward(T x) +); // bessel_j1_string + +const auto bessel_y1_string = bessel_j1_string + jiterator_stringify( + template + T bessel_y1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + 
+7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T YP[] = { + +1.26320474790178026440e+09, + -6.47355876379160291031e+11, + +1.14509511541823727583e+14, + -8.12770255501325109621e+15, + +2.02439475713594898196e+17, + -7.78877196265950026825e+17, + }; + + static const T YQ[] = { + +5.94301592346128195359e+02, + +2.35564092943068577943e+05, + +7.34811944459721705660e+07, + +1.87601316108706159478e+10, + +3.88231277496238566008e+12, + +6.20557727146953693363e+14, + +6.87141087355300489866e+16, + +3.97270608116560655612e+18, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return NEG_INFINITY; + } + + if (x <= T(0.0)) { + return NAN; + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 5; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * log(x) - T(1.0) / x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_y1_forward(T x) +); // bessel_y1_string + +const auto chebyshev_polynomial_t_string = jiterator_stringify( + template + T chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (abs(x) < T(1.0))) { + return cos(n * acos(x)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_t_forward(T x, int64_t n) + + template + T chebyshev_polynomial_t_forward(T x, T n) { + return chebyshev_polynomial_t_forward(x, static_cast(n)); + } // chebyshev_polynomial_t_forward(T x, T n) +); // chebyshev_polynomial_t_string + +const auto chebyshev_polynomial_u_string = jiterator_stringify( + 
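The Chebyshev strings above all rest on the same three-term recurrence, T_k(x) = 2x*T_{k-1}(x) - T_{k-2}(x), with the trigonometric identity T_n(x) = cos(n*acos(x)) used as a fast path for larger n on (-1, 1). A minimal host-side sketch of that recurrence, cross-checked against the identity (not part of the patch; names are illustrative):

#include <cmath>
#include <cstdio>

double chebyshev_t(double x, int n) {
  if (n == 0) return 1.0;
  if (n == 1) return x;
  double p = 1.0, q = x, r = 0.0;
  for (int k = 2; k <= n; ++k) {
    r = 2.0 * x * q - p;  // T_k = 2x*T_{k-1} - T_{k-2}
    p = q;
    q = r;
  }
  return r;
}

int main() {
  const double x = 0.3;
  for (int n : {2, 5, 11}) {
    std::printf("n=%2d recurrence=%.15f cos(n*acos(x))=%.15f\n",
                n, chebyshev_t(x, n), std::cos(n * std::acos(x)));
  }
  return 0;
}

The U, V, and W variants differ only in their first-order seed and in the endpoint special cases handled explicitly in the strings above.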
template + T chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (sin(acos(x)) != T(0.0)) { + return sin((n + 1) * acos(x)) / sin(acos(x)); + } + + return (n + 1) * cos((n + 1) * acos(x)) / x; + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_u_forward(T x, int64_t n) + + template + T chebyshev_polynomial_u_forward(T x, T n) { + return chebyshev_polynomial_u_forward(x, static_cast(n)); + } // chebyshev_polynomial_u_forward(T x, T n) +); // chebyshev_polynomial_u_string + +const auto chebyshev_polynomial_v_string = jiterator_stringify( + template + T chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0)) { + return T(1.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (sin(acos(x) / T(2.0)) != T(1.0)) { + return cos((n + T(0.5)) * acos(x)) / cos(acos(x) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_v_forward(T x, int64_t n) + + template + T chebyshev_polynomial_v_forward(T x, T n) { + return chebyshev_polynomial_v_forward(x, static_cast(n)); + } // chebyshev_polynomial_v_forward(T x, T n) +); // chebyshev_polynomial_v_string + +const auto chebyshev_polynomial_w_string = jiterator_stringify( + template + T chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (cos(acos(x) / T(2.0)) != T(1.0)) { + return sin((n + T(0.5)) * acos(x)) / sin(acos(x) / T(2.0)); + } + + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x + T(1.0); + } + + T p = T(1.0); + T q = x + x + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_w_forward(T x, int64_t n) + + template + T chebyshev_polynomial_w_forward(T x, T n) { + return chebyshev_polynomial_w_forward(x, static_cast(n)); + } // chebyshev_polynomial_w_forward(T x, T n) +); // chebyshev_polynomial_w_string + +const auto hermite_polynomial_h_string = jiterator_stringify( + template + T hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r = T(0.0); + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; + } // hermite_polynomial_h_forward(T x, int64_t n) + + template + T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, static_cast(n)); + } // hermite_polynomial_h_forward(T x, T n) +); 
// hermite_polynomial_h_string + +const auto hermite_polynomial_he_string = jiterator_stringify( + template + T hermite_polynomial_he_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = x * q - k * p; + p = q; + q = r; + } + + return r; + } // hermite_polynomial_he_forward(T x, int64_t n) + + template + T hermite_polynomial_he_forward(T x, T n) { + return hermite_polynomial_he_forward(x, static_cast(n)); + } // hermite_polynomial_he_forward(T x, T n) +); // hermite_polynomial_he_string + +const auto laguerre_polynomial_l_string = jiterator_stringify( + template + T laguerre_polynomial_l_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(0.0)) { + return T(1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return T(1.0) - x; + } + + T p = T(1.0); + T q = T(1.0) - x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; + } // laguerre_polynomial_l_forward(T x, int64_t n) + + template + T laguerre_polynomial_l_forward(T x, T n) { + return laguerre_polynomial_l_forward(x, static_cast(n)); + } // laguerre_polynomial_l_forward(T x, T n) +); // laguerre_polynomial_l_string + +const auto legendre_polynomial_p_string = jiterator_stringify( + template + T legendre_polynomial_p_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = ((k + k + 1) * x * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; + } // legendre_polynomial_p_forward(T x, int64_t n) + + template + T legendre_polynomial_p_forward(T x, T n) { + return legendre_polynomial_p_forward(x, static_cast(n)); + } // legendre_polynomial_p_forward(T x, T n) +); // legendre_polynomial_p_string + +const auto modified_bessel_i0_string = jiterator_stringify( + template + T modified_bessel_i0_forward(T x) { + static const T A[] = { + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + }; + + static const T B[] = { + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + 
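The Hermite (probabilists') and Legendre strings above use their own three-term recurrences: He_{k+1}(x) = x*He_k(x) - k*He_{k-1}(x), and Bonnet's recurrence (k+1)*P_{k+1}(x) = (2k+1)*x*P_k(x) - k*P_{k-1}(x). A short host-side sketch mirroring those loops, checked against the degree-3 closed forms He_3(x) = x^3 - 3x and P_3(x) = (5x^3 - 3x)/2 (illustrative only, not part of the patch):

#include <cstdio>

double hermite_he(double x, int n) {
  if (n == 0) return 1.0;
  if (n == 1) return x;
  double p = 1.0, q = x, r = 0.0;
  for (int k = 1; k < n; ++k) {  // He_{k+1} = x*He_k - k*He_{k-1}
    r = x * q - k * p;
    p = q;
    q = r;
  }
  return r;
}

double legendre_p(double x, int n) {
  if (n == 0) return 1.0;
  if (n == 1) return x;
  double p = 1.0, q = x, r = 0.0;
  for (int k = 1; k < n; ++k) {  // (k+1)P_{k+1} = (2k+1)x*P_k - k*P_{k-1}
    r = ((2.0 * k + 1.0) * x * q - k * p) / (k + 1.0);
    p = q;
    q = r;
  }
  return r;
}

int main() {
  const double x = 0.7;
  std::printf("He_3(%.1f): recurrence=%.12f closed form=%.12f\n",
              x, hermite_he(x, 3), x * x * x - 3.0 * x);
  std::printf("P_3(%.1f):  recurrence=%.12f closed form=%.12f\n",
              x, legendre_p(x, 3), 0.5 * (5.0 * x * x * x - 3.0 * x));
  return 0;
}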
+1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + }; + + T p; + T q = 0.0; + + if (abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 30; index++) { + p = q; + q = a; + a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + return exp(abs(x)) * (T(0.5) * (a - p)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index]; + } + + return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)); + } // modified_bessel_i0_forward(T x) +); // modified_bessel_i0_string + +const auto modified_bessel_i1_string = jiterator_stringify( + template + T modified_bessel_i1_forward(T x) { + static const T A[] = { + +2.77791411276104639959e-18, + -2.11142121435816608115e-17, + +1.55363195773620046921e-16, + -1.10559694773538630805e-15, + +7.60068429473540693410e-15, + -5.04218550472791168711e-14, + +3.22379336594557470981e-13, + -1.98397439776494371520e-12, + +1.17361862988909016308e-11, + -6.66348972350202774223e-11, + +3.62559028155211703701e-10, + -1.88724975172282928790e-09, + +9.38153738649577178388e-09, + -4.44505912879632808065e-08, + +2.00329475355213526229e-07, + -8.56872026469545474066e-07, + +3.47025130813767847674e-06, + -1.32731636560394358279e-05, + +4.78156510755005422638e-05, + -1.61760815825896745588e-04, + +5.12285956168575772895e-04, + -1.51357245063125314899e-03, + +4.15642294431288815669e-03, + -1.05640848946261981558e-02, + +2.47264490306265168283e-02, + -5.29459812080949914269e-02, + +1.02643658689847095384e-01, + -1.76416518357834055153e-01, + +2.52587186443633654823e-01, + }; + + static const T B[] = { + +7.51729631084210481353e-18, + +4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + +2.96262899764595013876e-16, + +3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + +1.04202769841288027642e-14, + +4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + +2.03562854414708950722e-12, + +1.41258074366137813316e-11, + +3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + +7.78576235018280120474e-01, + }; + + T p; + T q = 0.0; + + if (abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 29; index++) { + p = q; + q = a; + a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + if (x < T(0.0)) { + return -(T(0.5) * (a - p) * abs(x) * exp(abs(x))); + } + + return T(0.5) * (a - p) * abs(x) * exp(abs(x)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index]; + } + + if (x < T(0.0)) { + return -(exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x))); + } + + return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)); + } // modified_bessel_i1_forward(T x) +); // modified_bessel_i1_string + +const auto 
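The modified Bessel strings above inline the Cephes-style Chebyshev series evaluation: a backward Clenshaw recurrence b_k = t*b_{k+1} - b_{k+2} + c_k over coefficients stored from highest to lowest order, returning 0.5*(b_0 - b_2). The argument mapping in the kernels is t = |x|/2 - 2 for |x| <= 8 and t = 32/|x| - 2 (or 8/x - 2 for the K functions) beyond that, with exp(|x|) and 1/sqrt(|x|) scaling applied outside the loop. A generic host-side version of just the inner loop, for reference (names and the demo coefficients are illustrative, not from the patch):

#include <cstdio>

// Cephes-style "chbevl" loop, as inlined in the Bessel strings above.
// Expects n >= 2; coefficients run from highest to lowest order.
template <typename T>
T chbevl(T t, const T* coeffs, int n) {
  T p = T(0);       // b_{k+2}
  T q = T(0);       // b_{k+1}
  T a = coeffs[0];  // b_k
  for (int i = 1; i < n; ++i) {
    p = q;
    q = a;
    a = t * q - p + coeffs[i];
  }
  return T(0.5) * (a - p);  // 0.5 * (b_0 - b_2)
}

int main() {
  const double demo[] = {0.25, -0.5, 1.0};   // arbitrary demo coefficients
  double t = 1.5 / 2.0 - 2.0;                // small-argument mapping at x = 1.5
  std::printf("chbevl(t, demo, 3) = %.12f\n", chbevl(t, demo, 3));
  return 0;
}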
modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify( + template + T modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return T(0.5) * (a - p) - log(0.5 * x) * modified_bessel_i0_forward(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return exp(-x) * (T(0.5) * (b - p)) / sqrt(x); + } // modified_bessel_k0_forward(T x) +); // modified_bessel_k0_string + +const auto scaled_modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify( + template + T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - log(T(0.5) * x) * modified_bessel_i0_forward(x)) * exp(x); + } + + T b = B[0]; + + for 
(uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / sqrt(x); + } // T scaled_modified_bessel_k0_forward(T x) +); // scaled_modified_bessel_k0_string + +const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( + template + T modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x; + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return exp(-x) * (T(0.5) * (b - p)) / sqrt(x); + } // modified_bessel_k1_forward(T x) +); // modified_bessel_k1_string + +const auto scaled_modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( + template + T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x 
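The "scaled" K0/K1 variants above return exp(x)*K_nu(x) rather than K_nu(x) itself; the point is numerical range, since K_nu(x) decays like exp(-x) and underflows in double precision well before x reaches the values these kernels accept. A tiny host-side illustration using only the leading asymptotic term K0(x) ~ sqrt(pi/(2x))*exp(-x) (a sketch, not part of the patch):

#include <cmath>
#include <cstdio>

int main() {
  const double pi = 3.14159265358979323846;
  for (double x : {50.0, 500.0, 800.0}) {
    // Leading asymptotic term of K0(x); the scaled form drops the exp(-x) factor.
    double k0_leading        = std::sqrt(pi / (2.0 * x)) * std::exp(-x);
    double scaled_k0_leading = std::sqrt(pi / (2.0 * x));
    std::printf("x=%5g  K0~%g  exp(x)*K0~%g\n", x, k0_leading, scaled_k0_leading);
  }
  return 0;
}

At x = 800 the unscaled value underflows to zero while the scaled value remains a perfectly ordinary double, which is exactly the situation the scaled kernels are for.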
< T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * exp(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return (T(0.5) * (b - p) / sqrt(x)); + } // T scaled_modified_bessel_k1_forward(T x) +); // scaled_modified_bessel_k1_string + +const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + return cos(n * acos(x + x - T(1.0))); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_t_forward(T x, T n) { + return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_t_forward(T x, T n) +); // shifted_chebyshev_polynomial_t_string + +const auto shifted_chebyshev_polynomial_u_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + if (sin(acos(x + x - T(1.0))) != T(0.0)) { + return sin((n + 1) * acos(x + x - T(1.0))) / sin(acos(x + x - T(1.0))); + } + + return (n + 1) * cos((n + 1) * acos(x + x - T(1.0))) / (x + x - T(1.0)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_u_forward(T x, T n) { + return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_u_forward(T x, T n) +); // shifted_chebyshev_polynomial_u_string + +const auto shifted_chebyshev_polynomial_v_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + if (sin(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return cos(((n) + T(0.5)) * acos(x + x - T(1.0))) / cos(acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x 
- T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_v_forward(T x, T n) { + return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_v_forward(T x, T n) +); // shifted_chebyshev_polynomial_v_string + +const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 4) && (abs(x + x - T(1.0)) < T(1.0))) { + if (cos(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return sin((n + T(0.5)) * acos(x + x - T(1.0))) / sin(acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_w_forward(T x, T n) { + return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_w_forward(T x, T n) +); // shifted_chebyshev_polynomial_w_string + +const auto spherical_bessel_j0_string = jiterator_stringify( + template + T spherical_bessel_j0_forward(T x) { + if (isinf(x)) { + return T(0.0); + } + + if (abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return sin(x) / x; + } // T spherical_bessel_j0_forward(T x) +); // spherical_bessel_j0_string + +} // namespace native +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/PersistentSoftmax.cuh b/aten/src/ATen/native/zoom/PersistentSoftmax.cuh new file mode 100644 index 00000000000000..64919d846c41eb --- /dev/null +++ b/aten/src/ATen/native/zoom/PersistentSoftmax.cuh @@ -0,0 +1,402 @@ +#include +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace { + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension. +// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024. +// The template arguments have the following meaning: +// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples. 
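The warp_reduce helper above performs an XOR-butterfly all-reduce: for offsets WARP_SIZE/2, WARP_SIZE/4, ..., 1, each lane combines its value with the value held by lane (i ^ offset), so after log2(WARP_SIZE) rounds every lane holds the reduction over the whole warp. A minimal host-side simulation of that access pattern (illustrative only; the real kernel exchanges values with WARP_SHFL_XOR on the device):

#include <cstdio>
#include <vector>

int main() {
  constexpr int kWarpSize = 64;  // ROCm wavefront width; use 32 to model a CUDA warp
  std::vector<double> lane(kWarpSize);
  for (int i = 0; i < kWarpSize; ++i) lane[i] = i + 1;  // expected sum: 64*65/2 = 2080

  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    std::vector<double> next(kWarpSize);
    for (int i = 0; i < kWarpSize; ++i) {
      next[i] = lane[i] + lane[i ^ offset];  // what the shuffle exchanges each round
    }
    lane.swap(next);
  }
  // Every "lane" now holds the full warp sum.
  std::printf("lane[0]=%g lane[63]=%g\n", lane[0], lane[kWarpSize - 1]);
  return 0;
}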
+// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small. +// A "WARP" contains "C10_WARPS_SIZE" threads, these treads are guaranteed to belong to the same warp. +// This is important because it means only __shfl_ instructions are required for reductions. +// Note that this means WARP_SIZE must be a power of two and <= architecture warp size. +// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future arch. +// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs. +// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed. +// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed. +// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t. +// This allows SoftMax to be fused with a cast immediately following the SoftMax. +// The mask should have the same shape as input, with a boolean indicate if the value is masked. +// The head_chunk_size is only used for transformer mask softmax, equals to H * D * D. +// For instance: +// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor. +// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor. +// input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor. + +template +__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + int idx_offset = first_batch * stride + local_idx; + + src += idx_offset; + dst += idx_offset; + + if (is_transformer_mask) { + mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx; + } else { + mask += idx_offset; + } + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. + // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + elements[i][it] = src[i*element_count+it*WARP_SIZE]; + } else { + elements[i][it] = -std::numeric_limits::infinity(); + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + bool is_meaningful_max = false; + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (is_masked) { + int idx = it*WARP_SIZE; + if ((idx + local_idx) < batch_element_count) { + if (!is_transformer_mask) { + idx += i*element_count; + } + if (!mask[idx]) { + max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + is_meaningful_max = true; + } + } + } else { + max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it]; + } + } + if (is_masked) { + if (!is_meaningful_max) { + max_value[i] = -std::numeric_limits::infinity(); + } + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (!is_masked) { + if (is_log_softmax) { + sum[i] += ::exp(elements[i][it] - max_value[i]); + } else { + elements[i][it] = ::exp(elements[i][it] - max_value[i]); + sum[i] += elements[i][it]; + } + } else { + int idx = it*WARP_SIZE; + bool valid = (idx + local_idx) < batch_element_count; + if (!is_transformer_mask) { + idx += i*element_count; + } + if (valid) { + if (!mask[idx]) { + if (is_log_softmax) { + sum[i] += ::exp(elements[i][it] - max_value[i]); + } else { + elements[i][it] = ::exp(elements[i][it] - max_value[i]); + sum[i] += elements[i][it]; + } + } else { + if (!is_log_softmax) { + // Masked values are treated as -infinity, and ::exp(-infinity) is 0. + elements[i][it] = 0; + } + } + } else { + if (!is_log_softmax) { + elements[i][it] = 0.; + } + } + } + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + if (is_log_softmax) sum[i] = ::log(sum[i]); + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_log_softmax) { + dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i]; + } else if (sum[i] == 0) { + dst[i*element_count+it*WARP_SIZE] = std::numeric_limits::quiet_NaN(); + } else { + dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i]; + } + } else { + break; + } + } + } +} + +template +__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
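Numerically, softmax_warp_forward above is the standard max-shifted formulation carried out per row in registers: find the row maximum, exponentiate the shifted values, reduce the sum, then either divide (softmax) or subtract max and log(sum) from the raw inputs (log-softmax). A single-row host sketch of the same arithmetic, assuming nothing beyond the standard library (not part of the patch):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> row = {1.0, 2.0, 3.0, 1000.0};  // the large value shows why the max shift matters
  double max_value = row[0];
  for (double v : row) max_value = std::max(max_value, v);

  double sum = 0.0;
  std::vector<double> shifted_exp(row.size());
  for (std::size_t i = 0; i < row.size(); ++i) {
    shifted_exp[i] = std::exp(row[i] - max_value);  // exponent <= 0, so no overflow
    sum += shifted_exp[i];
  }

  for (std::size_t i = 0; i < row.size(); ++i) {
    double softmax = shifted_exp[i] / sum;
    double log_softmax = (row[i] - max_value) - std::log(sum);  // same ordering as the kernel's store
    std::printf("softmax=%-12g log_softmax=%g\n", softmax, log_softmax);
  }
  return 0;
}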
2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x % WARP_SIZE; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + if (is_masked) { + mask += thread_offset; + } + + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. + // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE]; + output_reg[i][it] = output[i*element_count+it*WARP_SIZE]; + } else { + grad_reg[i][it] = acc_t(0); + output_reg[i][it] = acc_t(0); + } + } + } + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (!is_masked || !mask[i*element_count+it*WARP_SIZE]) { + sum[i] += grad_reg[i][it]; + } + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_masked && mask[i*element_count+it*WARP_SIZE]) { + gradInput[i*element_count+it*WARP_SIZE] = 0; + } + // compute gradients + else if (is_log_softmax) { + gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - ::exp(output_reg[i][it]) * sum[i]); + } else { + gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]); + } + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = at::zoom::warp_size(); + warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
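For the is_log_softmax path, softmax_warp_backward above applies the identity dx_i = dy_i - exp(y_i) * sum_j dy_j, where y is the log-softmax output and dy the incoming gradient. A small host-side finite-difference check of that formula for a single unmasked row (a sketch with illustrative names, not part of the patch):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<double> log_softmax(const std::vector<double>& x) {
  double m = x[0];
  for (double v : x) m = std::max(m, v);
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - m);
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] - m - std::log(sum);
  return y;
}

int main() {
  std::vector<double> x  = {0.2, -1.3, 2.4, 0.7};
  std::vector<double> dy = {0.5, -0.1, 0.3, 0.8};  // upstream gradient
  std::vector<double> y  = log_softmax(x);

  double dy_sum = 0.0;
  for (double g : dy) dy_sum += g;

  const double eps = 1e-6;
  for (std::size_t i = 0; i < x.size(); ++i) {
    double analytic = dy[i] - std::exp(y[i]) * dy_sum;  // the kernel's formula

    // central finite difference of L(x) = sum_j dy_j * log_softmax(x)_j
    std::vector<double> xp = x, xm = x;
    xp[i] += eps;
    xm[i] -= eps;
    std::vector<double> yp = log_softmax(xp), ym = log_softmax(xm);
    double lp = 0.0, lm = 0.0;
    for (std::size_t j = 0; j < x.size(); ++j) { lp += dy[j] * yp[j]; lm += dy[j] * ym[j]; }
    double numeric = (lp - lm) / (2.0 * eps);

    std::printf("i=%zu analytic=%+.9f numeric=%+.9f\n", i, analytic, numeric);
  }
  return 0;
}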
2 : 1; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ + hipLaunchKernelGGL(( softmax_warp_forward) \ + , dim3(blocks), dim3(threads), 0, c10::zoom::getCurrentZoomStream(), dst, \ + src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024 + default: + break; + } + } +} + +template +void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = at::zoom::warp_size(); + warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ + hipLaunchKernelGGL(( softmax_warp_backward) \ + , dim3(blocks), dim3(threads), 0, c10::zoom::getCurrentZoomStream(), \ + grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements, mask); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024 + default: + break; + } + } +} diff --git a/aten/src/ATen/native/zoom/Reduce.cuh b/aten/src/ATen/native/zoom/Reduce.cuh new file mode 100644 index 00000000000000..c22e4bd53f020d --- /dev/null +++ b/aten/src/ATen/native/zoom/Reduce.cuh @@ -0,0 +1,1354 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { namespace native { + +using at::detail::Array; + +static inline int64_t div_up(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +// returns floor(log2(n)) +static inline int last_pow2(int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +// returns reduced fraction numerator & denominator +C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. + size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; +} + +//template for changing MAX_NUM_THREADS based on op dtype +template +struct mnt_wrapper { + static constexpr int MAX_NUM_THREADS = 512; +}; + +template <> +struct mnt_wrapper >{ + static constexpr int MAX_NUM_THREADS = 256; +}; + +constexpr int max_reduce_threads(c10::ScalarType type) { + return type == kComplexDouble ? 
256 : 512; +} + +struct ReduceConfig { + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + static constexpr int input_vec_size = 4; + + ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs) + : element_size_bytes(element_size_bytes) + , num_inputs(num_inputs) + , num_outputs(num_outputs) {} + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + template + void set_block_dimension(int64_t dim0, int64_t dim1) { + const int max_num_threads = mnt_wrapper::MAX_NUM_THREADS / output_vec_size; + int dim0_pow2 = dim0 < max_num_threads ? static_cast(last_pow2(dim0)) : max_num_threads; + int dim1_pow2 = dim1 < max_num_threads ? static_cast(last_pow2(dim1)) : max_num_threads; + block_width = std::min(dim0_pow2, int(at::zoom::warp_size())); + block_height = std::min(dim1_pow2, int(max_num_threads / block_width)); + block_width = std::min(dim0_pow2, int(max_num_threads / block_height)); + num_threads = block_width * block_height; + } + + int split_input(int parallelism) { + int step = step_input; + step_input *= parallelism; + return step; + } + + int split_output(int parallelism) { + int step = step_output; + step_output *= parallelism; + return step; + } + + dim3 block() const { + return dim3(block_width, block_height); + } + + dim3 grid() const { + return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output); + } + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx < num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + offset = threadIdx.x + offset * blockDim.x; + } + return offset; + } + + int shared_memory_size() const { + if (!should_block_y_reduce() && + (!should_block_x_reduce() || + block_width <= at::zoom::warp_size())) { + return 0; + } + return element_size_bytes * num_threads * output_vec_size; + } + + int64_t global_memory_size() const { + if (!should_global_reduce()) { + return 0; + } + auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output; + if 
(!should_block_x_reduce()) { + size *= block().x * output_vec_size; + } + return size; + } + + int semaphore_size() const { + if (!should_global_reduce()) { + return 0; + } + return sizeof(int) * grid().x; + } + + int values_per_thread() const { + return div_up(num_inputs, step_input); + } +}; + +std::ostream& operator<<(std::ostream& out, const ReduceConfig& config); + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void reduce_kernel(R reduction) { + reduction.template run(); +} + +template +static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) { + int num_reduce_dims = iter.num_reduce_dims(); + int num_output_dims = iter.ndim() - num_reduce_dims; + int input_index = iter.ntensors() - 1; + int output_index = 0; + std::array strides = { + iter.strides(output_index).data() + num_reduce_dims, + iter.strides(input_index).data() + num_reduce_dims, + }; + auto shape = iter.shape().data() + num_reduce_dims; + return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data()); +} + +template +static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) { + int num_reduce_dims = iter.num_reduce_dims(); + int input_index = iter.ntensors() - 1; + std::array strides = { + iter.strides(input_index).data(), + }; + return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data()); +} + +template +struct func_wrapper_t { + using arg_t = typename binary_function_traits::arg1_t; + using scalar_t = typename binary_function_traits::arg2_t; + + func_t combine; + static inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + static __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + func_wrapper_t(const func_t& op) : combine(op) { + } + + // wrap a normal reduction that ignores the index + __device__ arg_t reduce(arg_t acc, scalar_t val, int64_t idx) const { + return combine(acc, val); + } +}; + +template +func_wrapper_t func_wrapper(const func_t& op) { + return func_wrapper_t { op }; +} + +template +struct ReduceJitOp { +//ReduceJitOp is almost like ReduceOp, but it doesn't have ops functor that specifies reduction operations +//Maybe we can find a way to unify ReduceOp and ReduceJitOp + using InputCalculator = OffsetCalculator<1, uint32_t>; + using OutputCalculator = OffsetCalculator<2, uint32_t>; + //TODO for now arg_t is always opmath_t of the input, later we'll need to change it + using arg_t = at::opmath_type; + + static constexpr int input_vec_size = ReduceConfig::input_vec_size; + //TODO - ReduceJitOp will probably need to be changed for reductions that need full functor, + //not just wrapper + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + ReduceJitOp( + ReduceConfig config, + InputCalculator input_calc, + OutputCalculator output_calc, + const void* src, + char* dst0, + optional dst1, + void* acc_buf, + void* cta_buf, + int* semaphores, + arg_t ident, + int noutputs, + int64_t base_idx) + : 
ident(ident), + config(config), + input_calc(input_calc), + output_calc(output_calc), + src(src), + acc_buf(acc_buf), + cta_buf(cta_buf), + semaphores(semaphores), + base_idx(base_idx), + noutputs(noutputs) { + dst[0] = dst0; + if (dst1.has_value()) { + dst[1] = dst1.value(); + } + } +}; + +template +struct ReduceOp { + using traits = function_traits; + using arg_t = typename std::decay::type>::type; + + using InputCalculator = OffsetCalculator<1, index_t>; + using OutputCalculator = OffsetCalculator<2, index_t>; + + static constexpr bool can_accumulate_in_output = + std::is_convertible::value + && std::is_convertible::value; + + static constexpr int input_vec_size = ReduceConfig::input_vec_size; + + ops_t ops; + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + ReduceOp( + ops_t ops, + ReduceConfig config, + InputCalculator input_calc, + OutputCalculator output_calc, + const void* src, + char* dst0, + optional dst1, + void* acc_buf, + void* cta_buf, + int* semaphores, + arg_t ident, + int noutputs, + int64_t base_idx) + : ops(ops), + ident(ident), + config(config), + input_calc(input_calc), + output_calc(output_calc), + src(src), + acc_buf(acc_buf), + cta_buf(cta_buf), + semaphores(semaphores), + base_idx(base_idx), + noutputs(noutputs) { + dst[0] = dst0; + if (dst1.has_value()) { + dst[1] = dst1.value(); + } + } + + template + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + index_t output_idx = config.output_idx(); + index_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = at::detail::Array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + value = thread_reduce(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + + using out_ptr_vec_t = at::detail::Array; + using offset_vec_t = at::detail::Array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = global_reduce(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < 
output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE at::detail::Array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + ZOOM_KERNEL_ASSERT(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. + return {input_vectorized_thread_reduce_impl(data)}; + } else { + index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](index_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](index_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + index_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(at::native::memory::aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = at::native::memory::aligned_vector; + + index_t idx = config.input_idx(); + const index_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. + arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + while (idx * input_vec_size + input_vec_size - 1 < end) { + const auto values_vec = memory::load_vector(data, idx); + #pragma unroll + for (index_t i = 0; i < input_vec_size; i++) { + value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + index_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + const auto value = c10::load(data + idx); + value_list[0] = ops.reduce(value_list[0], value, idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = ops.combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE at::detail::Array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + index_t idx = config.input_idx(); + const index_t end = config.num_inputs; + const index_t stride = config.step_input; + + using arg_vec_t = at::detail::Array; + using load_t = at::native::memory::aligned_vector; + + // Multiple accumulators to remove dependency between unrolled loops. 
+ arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + const auto offset = calc(idx + i * stride) / output_vec_size; + values[i] = memory::load_vector(data_, offset); + } + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + const auto offset = calc(idx) / output_vec_size; + values[i] = memory::load_vector(data_, offset); + idx += stride; + } + idx = idx_; + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + + template + C10_DEVICE at::detail::Array block_x_reduce(at::detail::Array value, char* shared_memory) const { + using args_vec_t = at::detail::Array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], other[i]); + } + shared[address_base] = value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + for (int offset = 1; offset < dim_x; offset <<= 1) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + arg_t other = ops.warp_shfl_down(value[i], offset); + value[i] = ops.combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE at::detail::Array block_y_reduce(at::detail::Array value, char* shared_memory) const { + using args_vec_t = at::detail::Array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE at::detail::Array accumulate_in_output( + at::detail::Array out, + 
at::detail::Array value, + typename std::enable_if::type* = nullptr + ) const { + at::detail::Array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + ret[i] = ops.combine(*(out[i]), value[i]); + } + return ret; + } + + template + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value, + typename std::enable_if::type* = nullptr + ) const { + ZOOM_KERNEL_ASSERT(!final_output); + return (out_scalar_t)value; + } + + // This function should never be called -- + // it's the version of `accumulate_in_output` + // when accumulation in the output is not possible. + template + C10_DEVICE at::detail::Array accumulate_in_output( + at::detail::Array, + at::detail::Array, + typename std::enable_if::type* = nullptr + ) const { + ZOOM_KERNEL_ASSERT(false); + return arg_t {}; + } + + // This function should never be called -- + // it's the version of `get_accumulated_output` + // when accumulation in the output is not possible. + template + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value, + typename std::enable_if::type* = nullptr + ) const { + ZOOM_KERNEL_ASSERT(false); + return *out; + } + + template + C10_DEVICE void set_results(const T x, const index_t base_offset) const { + ZOOM_KERNEL_ASSERT(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + + //Currently implemented for max of two outputs + template + C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { + if (noutputs >= 1) { + auto res0 = (T1*)((char*)dst[0] + base_offset); + *res0 = x.first; + } + if (noutputs >= 2) { + // base offset is computed assuming element size being sizeof(T1), so we need to make a + // correction to obtain the correct base offset + auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); + *res1 = x.second; + } + } + + template + C10_DEVICE void set_results_to_output(at::detail::Array value, at::detail::Array base_offset) const { + ZOOM_KERNEL_ASSERT(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(ops.project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE at::detail::Array global_reduce(at::detail::Array value, at::detail::Array *acc, char* shared_memory) const { + using arg_vec_t = at::detail::Array; + using out_ptr_vec_t = at::detail::Array; + using offset_vec_t = at::detail::Array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + index_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + index_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + value = ident; + if (config.should_block_x_reduce()) { + index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + index_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + index_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < 
output_vec_size; i++) { + value[i] = ops.combine(value[i], next[i]); + } + } + } else { + index_t input_offset = threadIdx.y; + index_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + index_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +template +static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) { + dim3 block = config.block(); + dim3 grid = config.grid(); + + auto stream = c10::zoom::getCurrentZoomStream(); + int shared_memory = config.shared_memory_size(); + + switch(config.output_vec_size) { + case 4: + reduce_kernel<<>>(reduction); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + case 2: + reduce_kernel<<>>(reduction); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + default: + reduce_kernel<<>>(reduction); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +inline void launch_jitted_reduce_kernel( + std::mutex &jiterator_mutex, + std::array &fn_cache, + const at::zoom::jit::KernelDescriptor &desc, + int vt0, const ReduceConfig& config, void *reduction) { + dim3 block = config.block(); + dim3 grid = config.grid(); + + int shared_memory = config.shared_memory_size(); + at::zoom::jit::hiprtcFunction* fn_ptr; + switch(config.output_vec_size) { + case 4: + fn_ptr = &fn_cache[0]; + break; + case 2: + fn_ptr = &fn_cache[1]; + break; + default: + fn_ptr = &fn_cache[2]; + } + if (!fn_ptr->function) { + int max_threads_codegen = + max_reduce_threads(desc.f_inputs_type) / config.output_vec_size; + auto code = at::zoom::jit::generate_reduction_code( + desc, vt0, true, false, config.output_vec_size, max_threads_codegen); + + *fn_ptr = at::zoom::jit::jit_pwise_function(code, "reduction_" + desc.name); + } + constexpr int kernel_args = 1; + void* args[kernel_args]; + args[0] = reduction; + at::zoom::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory); +} + + +class AccumulationBuffer { + public: + AccumulationBuffer() {} + + AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) { + out_ptr_ = (char*)out_ptr; + if (out_t_size >= acc_t_size) { + // reusing output buffer for accumulation. 
+ acc_ptr_ = (char*)out_ptr; + numerator_ = 1; + denominator_ = 1; + } else { + auto& allocator = *c10::zoom::ZoomCachingAllocator::get(); + buffer_ = allocator.allocate(size); + acc_ptr_ = (char*)buffer_.get(); + numerator_ = acc_t_size; + denominator_ = out_t_size; + reduce_fraction(numerator_, denominator_); + } + } + + char* get_acc_slice(char* out_ptr) { + if (acc_ptr_ == nullptr) { + return nullptr; + } + return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_); + } + + private: + char* acc_ptr_ = nullptr; + char* out_ptr_ = nullptr; + size_t numerator_; + size_t denominator_; + at::DataPtr buffer_; +}; + +template +int get_output_vec_size(const TensorIterator &iter) { + int vec_size = 4; + auto update_vec_size = [&vec_size](uint64_t n) { + while(n % vec_size != 0) { + vec_size /= 2; + } + }; + + uint64_t base_address = reinterpret_cast(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t); + update_vec_size(base_address); + + const int output_index = iter.num_reduce_dims(); + update_vec_size(iter.shape()[output_index]); + + int j = 0; + for(auto i : iter.strides(iter.noutputs())) { + if (j != output_index) { + update_vec_size(i / sizeof(scalar_t)); + } + j++; + } + return vec_size; +} + +template +ReduceConfig setReduceConfig(const TensorIterator& iter){ + // Start by assuming that each thread handles a single output and all + // the inputs for that output. + int64_t num_outputs = iter.num_output_elements(); + int64_t inputs_per_output = iter.numel() / num_outputs; + int input_index = iter.ntensors() - 1; + + auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output); + + int64_t dim0; + int64_t dim1; + int64_t fastest_moving_stride; + bool reduction_on_fastest_striding_dimension; + + if (iter.ndim() > 0) { + // Adjust block size to map block width to fastest changing dimension of input + // tensor. This grants the best possible memory accessing pattern, given that + // for non-contiguous tensor with space in between, we cannot have perfect + // memory coalescing. + reduction_on_fastest_striding_dimension = + (iter.num_reduce_dims() == iter.ndim()) || + (iter.strides(/*arg=*/input_index)[0] < + iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]); + // Notice that dim0 & dim1 does NOT guarantee any launch configuration here! + // dim0 & dim1 are more like the upper bound of the block dimension. The + // actual launch config and reduction scheme is determined by setting values + // to `config.input_mult` and `config.output_mult`. + // We try to max out dim1 so that we have enough threads per CTA to deliver + // performance for larger problem size. + if (reduction_on_fastest_striding_dimension) { + // Map block.x to the fastest reducing dimension. It implies: + // 1. block_x_reduce is required. + // 2. block.y now max out to num_outputs. + dim0 = inputs_per_output; + dim1 = num_outputs; + fastest_moving_stride = iter.strides(/*arg=*/input_index)[0]; + } else { + // Map block.x to the fastest non reducing dimension. It implies: + // 1. block_x_reduce is turned off. + // 2. block.y now max out to inputs_per_output. + dim0 = num_outputs; + dim1 = inputs_per_output; + fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]; + } + } else { + reduction_on_fastest_striding_dimension = true; + fastest_moving_stride = sizeof(scalar_t); + dim0 = 1; + dim1 = 1; + } + + // We do vectorization to gain better memory access, there are two cases which we call + // "vectorize along input" and "vectorize along output". 
Note that the "input/output" + // here does not mean we are vectorizing load/store instructions. We always only vectorize + // load instructions. + // + // Case 1: "vectorize along input" + // This case happens when we are reducing along fastest moving dimesion. In such case, threads + // with the same threadIdx.y works on the same reduction cooperatively and will produce results + // for the same output. In such case, values in each loaded vector always correspond to the same output. + // + // Case 2: "vectorize along output" + // This case happens when the fastest moving dimesion is not the dimension of reduction. In such case, + // threads with different threadIdx.x are independent and will produce results for different outputs. + // In such case, values in each loaded vector always correspond to different outputs. + if (fastest_moving_stride == sizeof(scalar_t)) { + if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= ReduceConfig::input_vec_size) { + // Case 1: "vectorize along input" + // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case, + // we should avoid vectorization. + config.vectorize_input = true; + dim0 /= config.input_vec_size; + } else if (!reduction_on_fastest_striding_dimension) { + // Case 2: "vectorize along output" + config.output_vec_size = get_output_vec_size(iter); + dim0 /= config.output_vec_size; + } + } + + // Adjust block_width and block_height + config.set_block_dimension(dim0, dim1); + + int block_width = config.block_width; + int block_height = config.block_height; + + if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) { + // Split the input across lanes if the input is contiguous in the reduced + // dimension. This will require reduction between threads using warp + // shuffle instructions and shared memory (if block_width > warpSize). + config.input_mult[0] = config.split_input(block_width); + } else { + // Otherwise split the output across lanes in a warp. + config.output_mult[0] = config.split_output(block_width); + } + + constexpr int min_values_per_thread = 16; + constexpr int max_values_per_thread = 256; + + if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) { + // Divide the input across warps in a thread-block, if that leaves at least + // 16 elements to be summed by each thread. This will require inter-warp + // reduction using shared memory. + config.input_mult[1] = config.split_input(block_height); + } else { + // Otherwise, each warp handles a separate output. + config.output_mult[1] = config.split_output(block_height); + } + + const int blocks_per_sm = at::zoom::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / config.num_threads; + const int num_mp = at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + const int target_grid_size = num_mp * blocks_per_sm; + int grid = config.grid().x; + if (config.input_mult[1] != 0 && config.values_per_thread() >= max_values_per_thread && grid <= target_grid_size) { + // Divide the input across thread-blocks if the amount of work per-thread + // is large enough and the size of the output is small enough. This will + // require a reduction using global memory. + // If we decide to split input across blocks, as long as we can get enough + // number of blocks (`target_grid_size`) to balance SM, we should still + // make the number of values per thread large for best performance. 
+ int ctas_per_output1 = div_up(target_grid_size, grid); + int ctas_per_output2 = div_up(config.values_per_thread(), min_values_per_thread); + int ctas_per_output3 = div_up(config.values_per_thread(), max_values_per_thread); + // We want the minimum of ctas_per_output1 and ctas_per_output2, so that each thread can have + // a large number of values to deal with. But we don't want values_per_thread to be larger than + // max_values_per_thread + config.ctas_per_output = std::max(std::min(ctas_per_output1, ctas_per_output2), ctas_per_output3); + if (config.ctas_per_output > 1) { + config.input_mult[2] = config.split_input(config.ctas_per_output); + } + } + return config; +}; + +template +inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0, + AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) { + AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1); + + using traits = function_traits; + using arg_t = typename traits::template arg<0>::type; + // at::Half/at::ComplexHalf overflow easily because their range is very small. + // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_half_or_chalf = + (std::is_same::value && + std::is_same::value) || + (std::is_same, scalar_t>::value && + std::is_same, out_scalar_t>::value); + // at::BFloat16 has lower precision and can lead to rounding errors. + // So when scalar_t and out_scalar_t are at::BFloat16, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_bfloat16 = + (std::is_same::value && + std::is_same::value); + static constexpr bool can_accumulate_in_output = + std::is_convertible::value && + !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16); + + bool can_use_32bit_indexing = iter.can_use_32bit_indexing(); + std::unique_ptr owned_buf_ptr; + // The acc_buf_ptr is a shared pointer. It is created on the first call and + // reused by all recursive function calls. + if (acc_buf_ptr == NULL) { + // acc_buf_ptr holds the buffer used for accumulation across multiple sub_iters + // when accumulation in the output is not possible.
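+ // The temporary is sized to cover the furthest-reachable output element (max over dims of + // shape * stride, converted from bytes to an element count) times sizeof(arg_t), so each + // sub-iterator accumulates in arg_t precision and only the final result is written as out_scalar_t.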
+ if (!can_accumulate_in_output && !can_use_32bit_indexing) { + int64_t output_memory_size = iter.element_size(0); + for (int dim = 0; dim < iter.ndim(); dim++) { + output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]); + } + output_memory_size /= iter.element_size(0); //iter.strides is in bytes + owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t), + sizeof(out_scalar_t), + (char*) iter.data_ptr(0), + output_memory_size * sizeof(arg_t))); + } else { + owned_buf_ptr.reset(new AccumulationBuffer()); + } + acc_buf_ptr = owned_buf_ptr.get(); + } + + if (!can_use_32bit_indexing) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + int64_t sub_iter_base_idx = sub_iter.view_offsets()[0]; + + gpu_reduce_kernel(sub_iter, ops, ident, + acc_buf_ptr, sub_iter_base_idx); + } + return; + } + + const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1); + char* out_data = (char*)iter.data_ptr(0); + const auto noutputs = iter.noutputs(); + optional out_data_extra; + if (noutputs > 1) { + out_data_extra = (char*)iter.data_ptr(1); + } else { + out_data_extra = nullopt; + } + char* acc_data = acc_buf_ptr->get_acc_slice(out_data); + + ReduceConfig config = setReduceConfig(iter); + at::DataPtr buffer; + at::DataPtr semaphores; + if (config.should_global_reduce()) { + auto& allocator = *c10::zoom::ZoomCachingAllocator::get(); + buffer = allocator.allocate(config.global_memory_size()); + semaphores = allocator.allocate(config.semaphore_size()); + + auto stream = c10::zoom::getCurrentZoomStream(); + C10_ZOOM_CHECK(hipMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream)); + } + + AT_ASSERT(can_use_32bit_indexing); + auto output_calc = make_output_calculator(iter); + auto input_calc = make_input_calculator(iter); + auto reduce = ReduceOp( + ops, + config, + input_calc, + output_calc, + in_data, + out_data, + out_data_extra, + acc_data, + buffer.get(), + (int*)semaphores.get(), + ident, + noutputs, + base_idx); + reduce.accumulate = iter.should_accumulate(); + reduce.final_output = iter.is_final_output(); + + launch_reduce_kernel::MAX_NUM_THREADS>(config, reduce); +} + +//TODO this is 100 lines of almost-copy-paste, because we have to have different template args for this function +//try unifying with gpu_reduce_kernel +template +inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0, + AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) { + AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1); + + //TODO - this will be different for more complicated reductions, but for now reductions using + //func_wrapper all have arg_t = opmath + using arg_t = at::opmath_type; + // at::Half/at::ComplexHalf overflows easily as it's range is very small. + // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_half_or_chalf = + (std::is_same::value && + std::is_same::value) || + (std::is_same, scalar_t>::value && + std::is_same, out_scalar_t>::value); + // at::BFloat16 has lower precision and can lead to rounding errors. + // So when scalar_t and out_scalar_t are at::BFloat16, we + // set can_accumulate_in_output to False. 
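+ // (bfloat16 keeps only an 8-bit significand, so accumulating partial results through a bfloat16 + // output would compound rounding error, which is why it is excluded here as well.)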
+ static constexpr bool is_inp_out_type_bfloat16 = + (std::is_same::value && + std::is_same::value); + static constexpr bool can_accumulate_in_output = + std::is_convertible::value && + !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16); + + bool can_use_32bit_indexing = iter.can_use_32bit_indexing(); + std::unique_ptr owned_buf_ptr; + + // The acc_buf_ptr is a shared pointer. It is create at the first entrance and + // reused by all recursive function calls. + if (acc_buf_ptr == NULL) { + // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter + // when accumulation in output is not possible. + if (!can_accumulate_in_output && !can_use_32bit_indexing) { + int64_t output_memory_size = iter.element_size(0); + for (int dim = 0; dim < iter.ndim(); dim++) { + output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]); + } + output_memory_size /= iter.element_size(0); //iter.strides is in bytes + owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), //TODO + sizeof(out_scalar_t), + (char*) iter.data_ptr(0), + output_memory_size * sizeof(out_scalar_t))); //TODO + } else { + owned_buf_ptr.reset(new AccumulationBuffer()); + } + acc_buf_ptr = owned_buf_ptr.get(); + } + + if (!can_use_32bit_indexing) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + int64_t sub_iter_base_idx = sub_iter.view_offsets()[0]; + + jitted_gpu_reduce_kernel(sub_iter, func, ident, + acc_buf_ptr, sub_iter_base_idx); + } + return; + } + + //TODO - for now we support a single input, we may be able to relax this constraint + const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1); + char* out_data = (char*)iter.data_ptr(0); + const auto noutputs = iter.noutputs(); + optional out_data_extra; + if (noutputs > 1) { + out_data_extra = (char*)iter.data_ptr(1); + } else { + out_data_extra = nullopt; + } + char* acc_data = acc_buf_ptr->get_acc_slice(out_data); + + ReduceConfig config = setReduceConfig(iter); + + at::DataPtr buffer; + at::DataPtr semaphores; + if (config.should_global_reduce()) { + auto& allocator = *c10::zoom::ZoomCachingAllocator::get(); + buffer = allocator.allocate(config.global_memory_size()); + semaphores = allocator.allocate(config.semaphore_size()); + + auto stream = c10::zoom::getCurrentZoomStream(); + C10_ZOOM_CHECK(hipMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream)); + } + + AT_ASSERT(can_use_32bit_indexing); + auto output_calc = make_output_calculator(iter); + auto input_calc = make_input_calculator(iter); + auto reduce = ReduceJitOp( + config, + input_calc, + output_calc, + in_data, + out_data, + out_data_extra, + acc_data, + buffer.get(), + (int*)semaphores.get(), + ident, + noutputs, + base_idx); + reduce.accumulate = iter.should_accumulate(); + reduce.final_output = iter.is_final_output(); + + constexpr int nInputs = 1; + constexpr int nOutputs = 1; + static auto desc = at::zoom::jit::make_kernel_descriptor< + out_scalar_t, scalar_t>(name, func, nInputs, nOutputs); + + static std::mutex jiterator_mutex; + static std::vector> fn_cache(c10::zoom::device_count()); + auto &cache = fn_cache[iter.device().index()]; + + launch_jitted_reduce_kernel( + jiterator_mutex, cache, desc, vt0, config, &reduce); +} + +}} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/ReduceLogicKernel.cu b/aten/src/ATen/native/zoom/ReduceLogicKernel.cu new file mode 100644 index 00000000000000..fb6bb731781358 --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceLogicKernel.cu @@ -0,0 +1,38 
@@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + +namespace at::native { + +void and_kernel_zoom(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "and_zoom", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) && static_cast(b)); + }), + true); + }); +} + +void or_kernel_zoom(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "or_zoom", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) || static_cast(b)); + }), + false); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(and_stub, &and_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(or_stub, &or_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/ScanKernels.cpp b/aten/src/ATen/native/zoom/ScanKernels.cpp new file mode 100644 index 00000000000000..3bd21f18615d6a --- /dev/null +++ b/aten/src/ATen/native/zoom/ScanKernels.cpp @@ -0,0 +1,115 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +static c10::MaybeOwned contiguous_out_arg(const Tensor &tensor) { + if (tensor.is_contiguous()) { + return c10::MaybeOwned::borrowed(tensor); + } + return c10::MaybeOwned::owned(at::empty(tensor.sizes(), tensor.options())); +} + +void cummax_helper_zoom(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { + TensorArg output_arg{ values, "output", 1 }; + TensorArg indices_arg{ indices, "indices", 2 }; + TensorArg input_arg{ self, "input", 3 }; + checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg}); + + auto values_ = contiguous_out_arg(values); + auto indices_ = contiguous_out_arg(indices); + launch_cummax_zoom_kernel(self, *values_, *indices_, dim); + if (!values.is_same(*values_)) { + values.copy_(*values_); + } + if (!indices.is_same(*indices_)) { + indices.copy_(*indices_); + } +} + +void cummin_helper_zoom(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { + TensorArg output_arg{ values, "output", 1 }; + TensorArg indices_arg{ indices, "indices", 2 }; + TensorArg input_arg{ self, "input", 3 }; + checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg}); + + auto values_ = contiguous_out_arg(values); + auto indices_ = contiguous_out_arg(indices); + launch_cummin_zoom_kernel(self, *values_, *indices_, dim); + if (!values.is_same(*values_)) { + values.copy_(*values_); + } + if (!indices.is_same(*indices_)) { + indices.copy_(*indices_); + } +} + +Tensor& _logcumsumexp_out_zoom(const Tensor& self, int64_t dim, Tensor& result) { + const auto wrap_dim = maybe_wrap_dim(dim, self.dim()); + result.resize_(self.sizes()); + if (self.dim() == 0) { + result.fill_(self); + return result; + } + if (self.numel() == 0) { + result.zero_(); + return result; + } + + TensorArg output_arg{ result, "output", 1 }; + TensorArg input_arg{ self, "input", 2 }; + checkAllSameGPU(__func__, {output_arg, input_arg}); + + auto result_ = contiguous_out_arg(result); + launch_logcumsumexp_zoom_kernel(*result_, self, wrap_dim); + if (!result.is_same(*result_)) { + result.copy_(*result_); + } + return result; +} + +Tensor _logcumsumexp_zoom(const Tensor& self, int64_t dim) { + Tensor result = 
at::empty_like(self, MemoryFormat::Contiguous); + return _logcumsumexp_out_zoom(self, dim, result); +} + +void cumsum_zoom_kernel(const Tensor& result, const Tensor& self, int64_t dim) { + if (self.is_floating_point() || self.is_complex()) { + // See Note [Writing Nondeterministic Operations] + // Issue reporting nondeterministic behavior: https://github.com/pytorch/pytorch/issues/75240 + globalContext().alertNotDeterministic("cumsum_zoom_kernel"); + } + auto result_ = contiguous_out_arg(result); + launch_cumsum_zoom_kernel(*result_, self, dim); + if (!result.is_same(*result_)) { + result.copy_(*result_); + } +} + +void cumprod_zoom_kernel(const Tensor& result, const Tensor& self, int64_t dim) { + auto result_ = contiguous_out_arg(result); + launch_cumprod_zoom_kernel(*result_, self, dim); + if (!result.is_same(*result_)) { + result.copy_(*result_); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(cumsum_stub, &cumsum_zoom_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(cumprod_stub, &cumprod_zoom_kernel); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/ScanKernels.h b/aten/src/ATen/native/zoom/ScanKernels.h new file mode 100644 index 00000000000000..f9a6f86f2c6ebe --- /dev/null +++ b/aten/src/ATen/native/zoom/ScanKernels.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; + +namespace native { + +// NOTE: these functions require output tensors to be contiguous +void launch_cummax_zoom_kernel(const TensorBase& self, const TensorBase& values, + const TensorBase& indices, int64_t dim); +void launch_cummin_zoom_kernel(const TensorBase& self, const TensorBase& values, + const TensorBase& indices, int64_t dim); +void launch_logcumsumexp_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); +void launch_cumsum_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); +void launch_cumprod_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); + +}} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/ScanUtils.cuh b/aten/src/ATen/native/zoom/ScanUtils.cuh new file mode 100644 index 00000000000000..2ff2970dce4dcd --- /dev/null +++ b/aten/src/ATen/native/zoom/ScanUtils.cuh @@ -0,0 +1,459 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +constexpr inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) { + integer log_num_threads_x = 0; + integer log_num_threads_y = 0; + while (((integer)1 << log_num_threads_x) < row_size) { + ++log_num_threads_x; + } + while (((integer)1 << log_num_threads_y) < num_rows) { + ++log_num_threads_y; + } + // we want to keep the ratio between the x-threads and y-threads about the same as + // the ratio between the row_size and num_rows, but the total number of threads in + // a block should be about 512 + integer diff = log_num_threads_x - log_num_threads_y; + // 9 is from log2(512) + log_num_threads_x = ((integer)9 + diff) / (integer)2; + // I found that in having larger log_num_threads_x can give significant speed up in some cases, + // but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it + // similar to the previous implementation + // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block. 
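+ // e.g. num_rows = 8, row_size = 1024: the while loops above give log_num_threads_x = 10 and + // log_num_threads_y = 3, so diff = 7 and (9 + 7) / 2 = 8; the clamp below keeps the result in [4, 9].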
+ log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9); + return log_num_threads_x; +} + +template +__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) { + if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) { + rhs = lhs; + rhs_idx = lhs_idx; + } +} +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). + */ +template +__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + int num_rows, int row_size, + const uint32_t num_threads, const uint32_t log_num_threads_x, + scalar_t init, BinaryFunction binary_op) { + // dynamic memory allocation for vbuf and ibuf + alignas(sizeof(double)) extern __shared__ char buf[]; + scalar_t* vbuf = reinterpret_cast(buf); // the size is num_threads * 2 + int64_t* ibuf = reinterpret_cast(vbuf + num_threads * 2); + const uint32_t num_threads_x = 1 << log_num_threads_x; + scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y; + int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y; + + for (int block_row = blockIdx.x * blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + int row = block_row + threadIdx.y; + const scalar_t *row_self = self_ + row * row_size; + scalar_t *row_values = values_ + row * row_size; + int64_t *row_indices = indices_ + row * row_size; + scalar_t block_total = init; + int64_t block_idx_final = 0; + const bool row_exists = row < num_rows; + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + int col1 = block_col + threadIdx.x; + int col2 = block_col + num_threads_x + threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = c10::load(&row_self[col1]); + row_idx_buf[threadIdx.x] = col1; + } else { + row_buf[threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]); + row_idx_buf[num_threads_x + threadIdx.x] = col2; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. 
The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (uint32_t s = 1; s <= num_threads_x; s <<= 1) { + if (row_exists) { + uint32_t a = (threadIdx.x / s) * (2 * s) + s; + uint32_t ti = a + (threadIdx.x % s); + uint32_t si = a - 1; + binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op); + } + __syncthreads(); + } + + // Write back to output. + if (row_exists) { + if (col1 < row_size){ + row_values[col1] = row_buf[threadIdx.x]; + row_indices[col1] = row_idx_buf[threadIdx.x]; + } + if (col2 < row_size) { + row_values[col2] = row_buf[num_threads_x + threadIdx.x]; + row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x]; + } + } + block_total = row_buf[2 * num_threads_x - 1]; + block_idx_final = row_idx_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +/* Perform an inclusive scan along an outer dimension of a tensor. + * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to compute the variance; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. + */ +template +__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *self = self_ + orow * row_size * num_irows + irow; + scalar_t *values = values_ + orow * row_size * num_irows + irow; + int64_t *indices = indices_ + orow * row_size * num_irows + irow; + scalar_t out = init; + int64_t out_idx = 0; + + for (auto col = decltype(row_size){0}; col < row_size; ++col) { + const auto val = c10::load(self); + if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) { + out = val; + out_idx = col; + } + *values = out; + *indices = out_idx; + self += num_irows; + values += num_irows; + indices += num_irows; + } + } + } +} + +inline void check_fits_in_unsigned(int64_t val, const char* name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK( + val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value"); +} + + +template +__host__ void scan_outer_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + int dim, scalar_t init, BinaryFunction binary_op) { + int64_t row_size = self.size(dim); + auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. 
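+ // e.g. a (2, 3, 4, 5) tensor scanned along dim = 1 gives num_orows = 2, row_size = 3, num_irows = 20.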
+ const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::zoom::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + tensor_kernel_scan_outer_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +__host__ void scan_innermost_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + // Treat all outer dimensions as a single dimension. + int row_size = self.size(ndim - 1); + int num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + dim3 grid(std::min(at::zoom::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y)))); + + const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t)); + tensor_kernel_scan_innermost_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_rows, row_size, num_threads, log_num_threads_x, init, binary_op); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices, //int64_t dim) { + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous()); + if (dim == ndim - 1) { + scan_innermost_dim_with_indices(*self_, values, indices, init, binary_op); + } else { + scan_outer_dim_with_indices(*self_, values, indices, dim, init, binary_op); + } +} + +// TODO: The implementation of `tensor_kernel_scan_outer_dim` and +// `tensor_kernel_scan_innermost_dim` is similar to +// `tensor_kernel_scan_outer_dim_with_indices` +// `tensor_kernel_scan_outer_dim_with_indices` and should be refactored to +// remove the duplication. + +/* Perform an inclusive scan along an outer dimension of a tensor. + * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to scan; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. 
+ */ +template +__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) +{ + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *src = src_ + orow * row_size * num_irows + irow; + scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; + scalar_t acc = init; + + for (uint32_t col = 0; col < row_size; ++col) { + acc = binary_op(acc, c10::load(src)); + *tgt = acc; + + src += num_irows; + tgt += num_irows; + } + } + } +} + +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). + */ +template +__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_, + const uint32_t num_rows, const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, BinaryFunction binary_op){ + const uint32_t num_threads_x = 1 << log_num_threads_x; + for (uint32_t block_row = blockIdx.x * blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + uint32_t row = block_row + threadIdx.y; + T block_total = init; + + const T *row_src = src_ + row * row_size; + T *row_tgt = tgt_ + row * row_size; + const bool row_exists = row < num_rows; + + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + uint32_t col1 = block_col + threadIdx.x; + uint32_t col2 = block_col + num_threads_x + threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = row_src[col1]; + } else { + row_buf[threadIdx.x] = init; + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = row_src[col2]; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + row_buf[0] = binary_op(row_buf[0], block_total); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (uint32_t m = 0; m <= log_num_threads_x; ++m) { + if (row_exists) { + uint32_t s = 1 << m; // s = 2 ^ m + uint32_t a = ((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s + uint32_t ti = a + (threadIdx.x % s); + uint32_t si = a - 1; + row_buf[ti] = binary_op(row_buf[ti], row_buf[si]); + } + __syncthreads(); + } + + // Write back to output. 
+ if (row_exists) { + if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x]; + if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x]; + } + block_total = row_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +template < + typename T, + class BinaryFunction> +__global__ void tensor_kernel_scan_innermost_dim( + T* tgt_, + const T* src_, + const uint32_t num_rows, + const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, + BinaryFunction binary_op) { + alignas(sizeof(double)) extern __shared__ char sbuf[]; + T* sbuf2 = reinterpret_cast(sbuf); + const uint32_t num_threads_x = 1 << log_num_threads_x; + T* row_buf = reinterpret_cast(sbuf2 + num_threads_x * 2 * threadIdx.y); + + tensor_kernel_scan_innermost_dim_impl( + row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op); +} + + +template +__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result, + int dim, scalar_t init, BinaryFunction binary_op) { + const int64_t row_size = self.size(dim); + auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. + const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::zoom::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_outer_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_innermost_dim(const TensorBase& self, const TensorBase& result, + scalar_t init, BinaryFunction binary_op) { + int64_t ndim = self.dim(); + // Treat all outer dimensions as a single dimension. 
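+  // Resulting launch shape, sketched under the 512-thread budget used below: since
+  // get_log_num_threads_x_inner_scan clamps log_num_threads_x to [4, 9], the block ranges
+  // from 16x32 threads (32 short rows scanned per block) to 512x1 threads (one long row
+  // scanned by the whole block). For a hypothetical input of shape [4, 5, 6] scanned along
+  // the last dimension, row_size = 6 and num_rows = 4 * 5 = 20.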
+ int64_t row_size = self.size(ndim - 1); + int64_t num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + int64_t maxGridDim = at::zoom::getCurrentDeviceProperties()->maxGridSize[0]; + dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y}))); + + check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_innermost_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_rows, row_size, log_num_threads_x, init, binary_op); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim(const TensorBase& self, const TensorBase& result, + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + TORCH_INTERNAL_ASSERT(result.is_contiguous()); + + if (self.numel() == self.size(dim)) { + zoom::hipcub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); + } else if (dim == ndim - 1) { + scan_innermost_dim(*self_, result, init, binary_op); + } else { + scan_outer_dim(*self_, result, dim, init, binary_op); + } +} + +}} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Shape.cu b/aten/src/ATen/native/zoom/Shape.cu new file mode 100644 index 00000000000000..58ed638ee6bf87 --- /dev/null +++ b/aten/src/ATen/native/zoom/Shape.cu @@ -0,0 +1,521 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +constexpr int CAT_ARRAY_BATCH_SIZE = 128; +constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4; +constexpr int ALIGNED_VEC_LOAD_BYTES = 16; + +namespace { + +inline bool is_aligned_vec4(const void* ptr) { + auto iptr = reinterpret_cast(ptr); + return !(iptr % alignof(int4)); +} + +inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) { + const int numSM = at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + + // X dim of grid for cat array cooperates on a single tensor in the cat. + // Given half of the GPU, full utilization will always occur. + + // This will have cating two tensors fill the entire grid, but prevent + // many threads from needlessly load meta data if their sizes is small. + + grid = dim3( 2LL * numSM, (long long) nTensors ); + + return true; +} + +template +inline std::tuple getCatGridRocm(unsigned int max_elements_per_tensor, + ptrdiff_t nTensors) { + constexpr unsigned int threads_per_block = 256; + constexpr unsigned int elements_per_thread = 8; + constexpr unsigned int max_tb_per_sm = 32; + + unsigned int max_threads = ceil_div(max_elements_per_tensor, elements_per_thread); + unsigned int thread_blocks = ceil_div(max_threads, threads_per_block); + + // Limit the number of thread blocks to prevent too many threads to load the metadata + // if they operate on very small tensors. 
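+  // Worked example with illustrative numbers only: for max_elements_per_tensor = 1 << 20,
+  //   max_threads   = ceil_div(1048576, 8)  = 131072
+  //   thread_blocks = ceil_div(131072, 256) = 512
+  // The cap below then limits this to num_sm * 32, so a device with, say, 120 CUs would
+  // launch min(120 * 32, 512) = 512 blocks of 256 threads, with grid.y covering the
+  // tensors in the current batch.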
+ + const unsigned int num_sm = at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + thread_blocks = std::min(num_sm * max_tb_per_sm, thread_blocks); + + dim3 block = dim3(threads_per_block); + dim3 grid = dim3(thread_blocks, (long long)nTensors); + + return std::make_tuple(grid, block); +} + +template +inline std::tuple getCatGridContig(unsigned int max_elements_per_tensor, + ptrdiff_t nTensors) { + constexpr unsigned int threads_per_block = 128; + constexpr unsigned int min_aligned_vec_per_thread = 1; + constexpr unsigned int max_tb_per_sm = 32; + + unsigned int elements_per_thread = ALIGNED_VEC_LOAD_BYTES / sizeof(T) * + min_aligned_vec_per_thread; + unsigned int max_threads = ceil_div(max_elements_per_tensor, elements_per_thread); + unsigned int thread_blocks = ceil_div(max_threads, threads_per_block); + + // Limit the number of thread blocks to prevent too many threads to load the metadata + // if they operate on very small tensors. + + const unsigned int num_sm = at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + thread_blocks = std::min(num_sm * max_tb_per_sm, thread_blocks); + + dim3 block = dim3(threads_per_block); + dim3 grid = dim3(thread_blocks, (long long)nTensors); + + return std::make_tuple(grid, block); +} + +// Similar to any other IndexToOffset calculation for copying along a given +// dimension. +template +struct CatArrIndexToOffset { + static inline __device__ IndexType compute( + const IndexType tensorSize[Dims], + const IndexType tensorStride[Dims], + const IndexType dimSize, + const unsigned int concatDim, + IndexType linearIndex) { + // linearIndex is not really linear index, but instead the offset in + // input tensor. If the input tensor is contiguous, then this offset + // is the linear index, but if the input tensor is channels last, then + // it is the linear index of the permuted contiguous tensor + IndexType offset = 0; + + #pragma unroll + for (int i = Dims - 1; i >= 1; --i) { + IndexType curDimSize = i == concatDim ? dimSize : tensorSize[i]; + IndexType nextDimIndex = linearIndex / curDimSize; + IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex; + IndexType curDimOffset = curDimIndex * tensorStride[i]; + offset += curDimOffset; + linearIndex = nextDimIndex; + } + + return offset + linearIndex * tensorStride[0]; + } +}; + +template +struct TensorSizeStride { + IndexType tensorSize[MaxDims]; + IndexType tensorStride[MaxDims]; +}; + +/** + * Kernel used to concatenated grimDim.y tensors into an output tensor. Uses a + * grid-stride loop based off of the blockIdx.x, threadIdx.x for each input to + * copy each element from each input tensor into the output. + * + * output: base pointer to the storage associated with the output tensor + * inputs: GPU-allocated array of input metadata for each input to concatenate + * in the kernel + * os: the size/stride vectors for the output tensor + * concatDim: dimension along which we are concatenating + * dimStride: the stride of the output tensor at the concatDim + * + * The most important assumption made is that the input tensors are contiguous. + */ + + +// pass meta data directly through kernel argument instead of pin memory +// In contiguous case, we will not need stride_size, setting it as 1 as placeholder +// to pass compile. 
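+// Rough size sketch (illustrative, assuming IndexType = unsigned int and a batch of
+// CAT_ARRAY_BATCH_SIZE = 128 inputs): the per-input arrays below occupy about
+//   128 * (8 + 4 + 4 + 4 + 1) = 2688 bytes ~ 2.7 KB
+// plus one TensorSizeStride entry in the contiguous case. That has to fit in the limited
+// kernel-argument space, which is why cat processes its inputs in batches instead of
+// passing metadata for every tensor in a single launch.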
+template +struct CatArrInputTensorMetadata { + const T* input[n]; + IndexType offset[n]; + IndexType dimSize[n]; + IndexType nElements[n]; + bool isContiguous[n]; + TensorSizeStride tensorStride[stride_size]; +}; + +template +__global__ void CatArrayBatchedCopy( + T* output, + CatArrInputTensorMetadata inputs, + TensorSizeStride os, + const int concatDim, + IndexType dimStride) { + + IndexType tid = blockIdx.x * blockDim.x + threadIdx.x; + IndexType nElements = inputs.nElements[blockIdx.y]; + TensorSizeStride ins = stride_size > 1 ? inputs.tensorStride[blockIdx.y] : inputs.tensorStride[0]; + bool isContig = inputs.isContiguous[blockIdx.y]; + + if(tid >= nElements) return; + + const T* data = inputs.input[blockIdx.y]; + IndexType offset = inputs.offset[blockIdx.y]; + IndexType dimSize = inputs.dimSize[blockIdx.y]; + IndexType dataOffset = offset * dimStride; + + IndexType stride = gridDim.x * blockDim.x; + + while( tid < nElements){ + IndexType elementOffset = CatArrIndexToOffset::compute( + os.tensorSize, os.tensorStride, dimSize, concatDim, tid); + if (isContig) { + output[dataOffset + elementOffset] = data[tid]; + } else { + IndexType inElementOffset = CatArrIndexToOffset::compute( + ins.tensorSize, ins.tensorStride, dimSize, concatDim, tid); + output[dataOffset + elementOffset] = data[inElementOffset]; + } + tid += stride; + } +} + +template +__global__ void CatArrayBatchedCopy_contig( + T* output, + CatArrInputTensorMetadata inputs, + TensorSizeStride os, + const int concatDim, + IndexType dimStride) { + + IndexType tid = blockIdx.x * blockDim.x + threadIdx.x; + IndexType nElements = inputs.nElements[blockIdx.y]; + + if(tid >= nElements) return; + + const T* data = inputs.input[blockIdx.y]; + IndexType offset = inputs.offset[blockIdx.y]; + IndexType dimSize = inputs.dimSize[blockIdx.y]; + IndexType dataOffset = offset * dimStride; + + IndexType stride = gridDim.x * blockDim.x; + + while( tid < nElements){ + IndexType elementOffset = CatArrIndexToOffset::compute( + os.tensorSize, os.tensorStride, dimSize, concatDim, tid); + output[dataOffset + elementOffset] = data[tid]; + tid += stride; + } +} + +/* + Specialized implementation of the CatArrayBatchedCopy written to generate wide memory loads + to improve memory bandwidth throughput. 
+*/ + +template +__global__ void CatArrayBatchedCopy_aligned16_contig( + T* output, + CatArrInputTensorMetadata inputs, + TensorSizeStride os, + const int concatDim, + IndexType dimStride) { + + // This kernel tries to use 128 bit loads + constexpr int kILP = ALIGNED_VEC_LOAD_BYTES / sizeof(T); + IndexType inputOffset = (blockIdx.x * blockDim.x + threadIdx.x) * kILP; + IndexType inputStride = gridDim.x * blockDim.x * kILP; + + IndexType nElements = inputs.nElements[blockIdx.y]; + if (inputOffset >= nElements) { + return; + } + + const T* data = inputs.input[blockIdx.y]; + IndexType offset = inputs.offset[blockIdx.y]; + IndexType dimSize = inputs.dimSize[blockIdx.y]; + IndexType dataOffset = offset * dimStride; + + IndexType v_elementOffset[kILP]; + T reg_data[kILP]; + + while (inputOffset + kILP <= nElements) { + for (int i = 0; i < kILP; ++i) { + v_elementOffset[i] = CatArrIndexToOffset::compute(os.tensorSize, + os.tensorStride, dimSize, concatDim, inputOffset + i); + } + + using LT = at::native::memory::aligned_vector; + ((LT*)reg_data)[0] = const_cast((LT*)(data + inputOffset))[0]; + + #pragma unroll + for (int i = 0; i < kILP; ++i) { + output[dataOffset + v_elementOffset[i]] = reg_data[i]; + } + + inputOffset += inputStride; + } + + // Handle remaining tail in case nElements does not divide + // exactly to kILP + + while (inputOffset < nElements) { + v_elementOffset[0] = CatArrIndexToOffset::compute(os.tensorSize, + os.tensorStride, dimSize, concatDim, inputOffset); + output[dataOffset + v_elementOffset[0]] = data[inputOffset]; + inputOffset++; + } +} + +template +void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, int64_t dimension, + int nDims, c10::MemoryFormat memory_format) { + // First, let's set up our kernel parameters. We start with a raw pointer to + // the storage for the output Tensor. + scalar_t *data = (scalar_t *)(out.mutable_data_ptr()); + CatArrInputTensorMetadata catMetaData; + TensorSizeStride outputParam; + + // Next, let's initialize the size, stride arrays for the output Tensor. + if (memory_format == c10::MemoryFormat::Contiguous) { + for (int i = 0; i < nDims; ++i) { + outputParam.tensorSize[i] = out.size(i); + outputParam.tensorStride[i] = out.stride(i); + } + } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) { + // permute the semantics of dims from NCHW to NHWC so that the input + // tensor is now contiguous + outputParam.tensorSize[0] = out.size(0); + outputParam.tensorStride[0] = out.stride(0); + for (int i = 1; i < nDims - 1; ++i) { + outputParam.tensorSize[i] = out.size(i + 1); + outputParam.tensorStride[i] = out.stride(i + 1); + } + outputParam.tensorSize[nDims - 1] = out.size(1); + outputParam.tensorStride[nDims - 1] = out.stride(1); + } else { + TORCH_CHECK(false, "unsupported memory format"); + } + + c10::zoom::ZoomStream stream = c10::zoom::getCurrentZoomStream(); + + // If all batches are contiguous we can call a specialized implementation + // which requires the input tensor addresses to be aligned to a + // 16 Byte boundary. 
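+  // Vector width used by that specialization, derived from ALIGNED_VEC_LOAD_BYTES = 16:
+  // kILP = 16 / sizeof(T), i.e. four elements per 128-bit load for 4-byte types and two
+  // for 8-byte types; the dispatch below additionally requires 4 <= sizeof(scalar_t) <= 8.
+  // Note that isAligned is pinned to false further down ("On ROCm,
+  // CatArrayBatchedCopy_contig is faster"), so in this port the aligned16 kernel is
+  // effectively never selected and contiguous batches take CatArrayBatchedCopy_contig.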
+ + bool isContig = true; + bool isAligned = true; + unsigned int max_elements_per_tensor = 0; + + // Now we loop + int batchCounter = 0; + int64_t offset = 0; + for (unsigned i = 0; i < inputs.size() ; i += batch_size) { + for (batchCounter = 0; + batchCounter < batch_size && + (i+batchCounter) < inputs.size(); + ++batchCounter) { + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].get().numel() > 0) { + dimSize = inputs[i+batchCounter].get().size(dimension); + } + + catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr()); + catMetaData.offset[batchCounter] = offset; + catMetaData.dimSize[batchCounter] = dimSize; + catMetaData.nElements[batchCounter] = inputs[i+batchCounter].get().numel(); + + // On ROCm, CatArrayBatchedCopy_contig is faster + isAligned = false; + + if (stride_size > 1) { + auto strides = inputs[i+batchCounter].get().strides(); + auto sizes = inputs[i+batchCounter].get().sizes(); + for(int j = 0; j < nDims; j++){ + catMetaData.tensorStride[batchCounter].tensorSize[j] = sizes[j]; + catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j]; + } + catMetaData.isContiguous[batchCounter] = false; + isContig = false; + } else { + catMetaData.isContiguous[batchCounter] = true; + } + + // Update offset + offset += dimSize; + + // We need max elements per tensor to compute grid parameters + max_elements_per_tensor = std::max(max_elements_per_tensor, + catMetaData.nElements[batchCounter]); + } + + // Skip if the tensor is empty. Otherwise, the grid dim is invalid + if (max_elements_per_tensor == 0) + continue; + + dim3 applyBlock, catGrid; + + // always base grid size on max_elements_per_tensor + { + std::tuple launchParams = getCatGridRocm( + max_elements_per_tensor, batchCounter); + catGrid = std::get<0>(launchParams); + applyBlock = std::get<1>(launchParams); + } + + if (memory_format != c10::MemoryFormat::Contiguous) { + switch (dimension) { + case 0: + break; + case 1: + dimension = nDims - dimension; + break; + default: + dimension--; + } + } + // Template Declarations for dim = 1, 2, 3, 4 +#define HANDLE_CASE(DIMS) \ + if (isContig && isAligned && sizeof(scalar_t) >= 4 && sizeof(scalar_t) <= 8) {\ + CatArrayBatchedCopy_aligned16_contig<<<\ + catGrid, applyBlock, 0, stream.stream()>>>(\ + data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ + } else if (isContig) {\ + CatArrayBatchedCopy_contig<<<\ + catGrid, applyBlock, 0, stream.stream()>>>(\ + data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ + } else {\ + CatArrayBatchedCopy<<<\ + catGrid, applyBlock, 0, stream.stream()>>>(\ + data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ + }\ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + switch (nDims) { + case 1: + HANDLE_CASE(1); + break; + case 2: + HANDLE_CASE(2); + break; + case 3: + HANDLE_CASE(3); + break; + case 4: + HANDLE_CASE(4); + break; + } +#undef HANDLE_CASE + } +} +// The kernels are templated on an opaque, self-aligned type of the correct +// size to avoid redundant kernels for different types of the same size. 
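+// Example of the de-duplication this buys: float and int32_t (and any other 4-byte dtype)
+// all dispatch as OpaqueType<4>, while double, int64_t and complex<float> dispatch as
+// OpaqueType<8>, so only one copy kernel is instantiated per element size rather than one
+// per scalar type.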
+template struct alignas(N) OpaqueType { char data[N]; }; + +} // namespace + +TORCH_IMPL_FUNC(cat_out_zoom) +(const ITensorListRef& tensors, + int64_t dim, + int64_t valid, + bool all_contiguous, + bool all_same_dtype, + bool all_same_sizes_and_stride, + MemoryFormat memory_format, + const Tensor& result) { + if (result.numel() == 0) { + return; + } + + auto materialized = tensors.materialize(); + + // We parallelize the copy if all 6 conditions pass: + // + // 1. There is more than one input tensor + // 2. The out tensor is 32-bit indexable + // 3. The number of dimensions is <= 4 + // 4. All input tensors are contiguous (output tensor may be non-contig) + // 5. All input tensors can use 32-bit indexing + + const bool all32BitIndexable = std::all_of(materialized.begin(), materialized.end(), + [] (const Tensor& t) { + return at::zoom::detail::canUse32BitIndexMath(t); + }); + + int nDims = materialized[valid].get().dim(); + + // We support the contiguous inputs and non-contiguous input (<=4 dims) in different ways + // For contiguous input, we don't need to pass stride meta data to cuda kernel through constant + // memory. Therefore, we could pass more inputs to cuda threads. + // For non-contiguous, we reduce the number of inputs passed to cuda kernel due to the limitation + // of constant memory. + + + + if (materialized.size() > 1 && + result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && + at::zoom::detail::canUse32BitIndexMath(result) && + all_contiguous && + all32BitIndexable && + all_same_dtype) { + if (isBitsType(result.scalar_type())) { + AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_zoom", [&]() { + using dtype = OpaqueType; + parallel_cat(result, materialized, dim, nDims, memory_format); + }); + } else { + AT_DISPATCH_V2(result.scalar_type(), "cat_zoom", AT_WRAP([&]() { + using dtype = OpaqueType; + parallel_cat(result, materialized, dim, nDims, memory_format); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + } + } else if (materialized.size() > 1 && + result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && + at::zoom::detail::canUse32BitIndexMath(result) && + nDims <= CAT_ARRAY_MAX_INPUT_DIMS && + all32BitIndexable && + all_same_dtype && + memory_format == c10::MemoryFormat::Contiguous) { + if (isBitsType(result.scalar_type())) { + AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_zoom", [&]() { + using dtype = OpaqueType; + parallel_cat(result, materialized, dim, nDims, memory_format); + }); + } else { + AT_DISPATCH_V2(result.scalar_type(), "cat_zoom", AT_WRAP([&]() { + using dtype = OpaqueType; + parallel_cat(result, materialized, dim, nDims, memory_format); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + } + } else { + int64_t offset = 0; + for (const Tensor& t : materialized) { + if (cat_should_skip_tensor(t)) continue; + int64_t dimSize = t.size(dim); + Tensor nt = at::narrow(result, dim, offset, dimSize); + copy_(nt, t); + offset += dimSize; + } + } +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/SoftMax.cu b/aten/src/ATen/native/zoom/SoftMax.cu new file mode 100644 index 00000000000000..2101dd42bb2d7a --- /dev/null +++ b/aten/src/ATen/native/zoom/SoftMax.cu @@ -0,0 +1,1272 @@ +// !!! This is a file automatically generated by hipify!!! 
+#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +namespace { + +constexpr int ALIGN_BYTES = 16; + +template +struct LogSoftMaxForwardEpilogue { + __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) + : max_input(max_input), logsum(::log(sum)) {} + + __device__ __forceinline__ OutT operator()(T input) const { + return static_cast(input - max_input - logsum); +} + + const AccumT max_input; + const AccumT logsum; +}; + +template +struct LogSoftMaxBackwardEpilogue { + __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) + : sum(sum) {} + + __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { + return static_cast(gradOutput - ::exp(static_cast(output)) * sum); + } + + const AccumT sum; +}; + +template +struct SoftMaxForwardEpilogue { + __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum) + : max_input(max_input) + , sum(sum) {} + + __device__ __forceinline__ OutT operator()(T input) const { + return static_cast(::exp(input - max_input) / sum); + } + + const AccumT max_input; + const AccumT sum; +}; + +template +struct SoftMaxBackwardEpilogue { + __device__ __forceinline__ SoftMaxBackwardEpilogue(AccumT sum) + : sum(sum) {} + + // XXX: gradOutput that we get here is really gradOutput * output + // Look for cmul in SoftMax_updateGradInput + __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { + return static_cast(gradOutput - output * sum); + } + + const AccumT sum; +}; + + + + +//////////////////////////////////////////////////////////////////////////////// +// Spatial kernel (fast with large inner_size and small dim_size) +//////////////////////////////////////////////////////////////////////////////// +// Let's assume that our input has been flattened to have only three dimension: +// outer x dim x inner +// The spatial algorithm tries to parallelize along all of them. +// Within a 2d block threadIdx.y parallelizes over dim slices, and threads that +// share it will speed up reductions over dim (along axis x). +// The 2d grid is used to parallelize inner dimension over y axis and outer over x. 
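+// Concrete mapping (illustrative): for a softmax over dim == 1 of an NCHW input of shape
+// [N, C, H, W], the flattening gives outer_size = N, dim_size = C and inner_size = H * W.
+// A block of shape (dim_threads, inner_threads) then cooperates along C with its x
+// threads while its y threads cover distinct (h, w) positions, and the grid of shape
+// (outer_blocks, inner_blocks) tiles N and H * W respectively.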
+inline dim3 SpatialSoftMax_getGridSize( + dim3 block, uint32_t max_active_blocks, + uint64_t outer_size, uint64_t inner_size) { + // First, tile as many blocks as we can over the y axis + uint32_t inner_blocks = (inner_size + block.y - 1) / block.y; + if (inner_blocks > max_active_blocks) + inner_blocks = max_active_blocks; + // Fill the x axis with as many blocks as we can fit (a little more is ok too) + uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks; + if (outer_blocks > outer_size) + outer_blocks = outer_size; + return dim3(outer_blocks, inner_blocks); +} + +const int max_threads = 1024; + +inline dim3 SpatialSoftMax_getBlockSize( + uint64_t dim_size, uint64_t inner_size) { + uint32_t inner_threads = inner_size; + inner_threads = ::min(inner_threads, static_cast(max_threads)); + uint32_t dim_threads = 1; + if (inner_threads <= 64 && dim_size >= 64) { + while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size) + dim_threads *= 2; + dim_threads /= 2; + } + return dim3(dim_threads, inner_threads); +} + + +template +void SpatialSoftMax_getLaunchSizes( + Kernel k, + uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, + dim3& grid, dim3& block, uint32_t& smem_size) { + block = SpatialSoftMax_getBlockSize(dim_size, inner_size); + uint32_t block_threads = block.x * block.y; + smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t); + int max_active_blocks; +#if defined(TORCH_HIP_VERSION) && TORCH_HIP_VERSION < 305 + // HIP function signature is not compatible yet. + uint32_t max_blocks; + hipOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks, + k, block_threads, smem_size); + max_active_blocks = max_blocks; +#else + hipOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + k, block_threads, smem_size); +#endif + max_active_blocks *= at::zoom::getCurrentDeviceProperties()->multiProcessorCount; + grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, inner_size); +} + +inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { + uint64_t block_size = 1; + uint64_t max_block_size = ::min(dim_size / ILP, static_cast(max_threads)); + + // In the vectorized case we want to trade off allowing more of the buffers to be accessed + // in a vectorized way against wanting a larger block size to get better utilisation. + // In general with ILP you can have (ILP-1)/ILP of the buffer accessed vectorised, at the risk + // of having a very small block size. We choose to keep >= 1/2 of the buffer vectorised while + // allowing a larger block size. + if (ILP > 1) { + max_block_size /= 2; + } + + while (block_size < (max_block_size)) block_size *= 2; + // Launch at least a single warp - the kernel assumes that. + block_size = ::max(block_size, static_cast(at::zoom::warp_size())); + return dim3(block_size); +} + +inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) { + uint64_t block_size = 1; + uint64_t max_block_size = ::min(dim_size, static_cast(max_threads)); + + // We need a block size that is a multiple of C10_WARP_SIZE in order + // to perform block size reductions using warp shuffle instructions. + // Since max_threads is also a multiple of C10_WARPS_SIZE we do not + // risk creating a block size larger than the limit. 
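+  // Worked example (illustrative): with dim_size = 1000 and max_threads = 1024,
+  // max_block_size = 1000. On a 64-wide wavefront (the usual C10_WARP_SIZE for AMD CDNA
+  // devices) 1000 is not a multiple of 64, so the block is rounded up to
+  // (1000 / 64 + 1) * 64 = 1024 threads; a dim_size of 512 would be used as-is.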
+ + if (max_block_size % C10_WARP_SIZE == 0) { + block_size = max_block_size; + } else { + block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE; + } + + return dim3(block_size); +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } + + __device__ __forceinline__ T combine(T a, T b) const { + return a + b; + } + + // Needed to allow warp level reduction as a first step in the + // thread block reduction + __device__ __forceinline__ T warp_shfl_down(T data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } + + __device__ __forceinline__ T combine(T a, T b) const { + return a < b ? b : a; + } + + // Needed to allow warp level reduction as a first step in the + // thread block reduction + __device__ __forceinline__ T warp_shfl_down(T data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +}; + +// Note that it's not a complete block-wide reduction. +// Only threads that share threadIdx.y reduce values. +template class ReduceOp> +__forceinline__ __device__ +T spatialBlockReduceX(T *shared, T val) { + ReduceOp r; + shared += threadIdx.y * blockDim.x; + + __syncthreads(); + + shared[threadIdx.x] = val; + + // NOTE: loop starts with __syncthreads() + int offset = blockDim.x / 2; + while (offset > 0) { + __syncthreads(); + if (threadIdx.x < offset) + shared[threadIdx.x] = r(shared[threadIdx.x], shared[threadIdx.x + offset]); + offset /= 2; + } + + __syncthreads(); + + return shared[0]; +} + +template class Epilogue> +__global__ void cunn_SpatialSoftMaxForward( + outscalar_t *output, const scalar_t *input, + index_t outer_size, index_t dim_size, index_t inner_size) +{ + extern __shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + const index_t outer_stride = inner_size * dim_size; + const index_t dim_stride = inner_size; + + for (index_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { + const index_t outer_offset = outer_index * outer_stride; + for (index_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) { + const index_t data_offset = outer_offset + inner_index; + //////////////////////////////////////////////////////////// + // These two blocks are really equivalent, but specializing on + // blockDim.x == 1 makes the kernel faster when it's unused. + // I didn't want to thread an extra template parameter, and nvcc + // seems to be smart enough to hoist the if outside of the loops. 
+ //////////////////////////////////////////////////////////// + + if (blockDim.x > 1) { + accscalar_t max_input = at::numeric_limits::lowest(); + for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) { + const accscalar_t value = static_cast(input[data_offset + d * dim_stride]); + max_input = Max()(max_input, value); + } + max_input = spatialBlockReduceX(sdata,max_input); + + accscalar_t sum = 0; + for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) + sum += ::exp(static_cast(input[data_offset + d * dim_stride]) + - max_input); + sum = spatialBlockReduceX(sdata, sum); + + Epilogue epilogue(max_input, sum); + for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) + output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]); + } else { + accscalar_t max_input = at::numeric_limits::lowest(); + for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) { + const accscalar_t value = static_cast(input[data_offset + d * dim_stride]); + max_input = Max()(max_input, value); + } + accscalar_t sum = 0; + for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) + sum += ::exp(static_cast(input[data_offset + d * dim_stride]) + - max_input); + Epilogue epilogue(max_input, sum); + for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) + output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]); + } + } + } +} + + + +template class Epilogue> +__global__ void cunn_SpatialSoftMaxBackward( + scalar_t *gradInput, const outscalar_t *output, const outscalar_t *gradOutput, + uint32_t outer_size, uint32_t dim_size, uint32_t inner_size) +{ + extern __shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + const uint32_t outer_stride = inner_size * dim_size; + const uint32_t dim_stride = inner_size; + + for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { + const uint32_t outer_offset = outer_index * outer_stride; + for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) { + const uint32_t data_offset = outer_offset + inner_index; + // See the comment in forward kernel + if (blockDim.x > 1) { + accscalar_t sum = 0; + for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) + sum += gradOutput[data_offset + d * dim_stride]; + sum = spatialBlockReduceX(sdata, sum); + + Epilogue epilogue(sum); + for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) { + gradInput[data_offset + d * dim_stride] = + epilogue(gradOutput[data_offset + d * dim_stride], + output[data_offset + d * dim_stride]); + } + } else { + accscalar_t sum = 0; + for (uint32_t d = 0; d < dim_size; d++) + sum += gradOutput[data_offset + d * dim_stride]; + + Epilogue epilogue(sum); + for (uint32_t d = 0; d < dim_size; d++) { + gradInput[data_offset + d * dim_stride] = + epilogue(gradOutput[data_offset + d * dim_stride], + output[data_offset + d * dim_stride]); + } + } + } + } +} + + +//////////////////////////////////////////////////////////////////////////////// +// Regular kernel (fast when dim_size is large; requires inner_size == 1) +//////////////////////////////////////////////////////////////////////////////// + + +template +struct MaxFloat +{ + __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { + return ::max(max, (AccumT)v); + } +}; + +template +struct AddFloat +{ + __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { + return sum + v; + } +}; + +template +struct SumExpFloat +{ + 
__device__ __forceinline__ SumExpFloat(AccumT v) + : max_k(v) {} + + __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { + return sum + ::exp(v - max_k); + } + + const AccumT max_k; +}; + +template class Reduction, typename AccumT> +__device__ __forceinline__ AccumT +blockReduce(AccumT* smem, AccumT val, + const Reduction& r, + AccumT defaultVal) +{ + // To avoid RaW races from chaining blockReduce calls together, we need a sync here + __syncthreads(); + + smem[threadIdx.x] = val; + + __syncthreads(); + + AccumT warpVal = defaultVal; + + // First warp will perform per-warp reductions for the remaining warps + uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1; + if (threadIdx.x < C10_WARP_SIZE) { + int lane = threadIdx.x % C10_WARP_SIZE; + if (lane < blockDim.x / C10_WARP_SIZE) { +#pragma unroll + for (int i = 0; i < C10_WARP_SIZE; ++i) { + warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]); + } + smem[lane] = warpVal; + } + } + + __syncthreads(); + + // First thread will perform a reduction of the above per-warp reductions + AccumT blockVal = defaultVal; + + if (threadIdx.x == 0) { + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { + blockVal = r(blockVal, smem[i]); + } + smem[0] = blockVal; + } + + // Sync and broadcast + __syncthreads(); + return smem[0]; +} + +// Performs a thread block reduction with a given functor but uses +// warp shuffles as the first step in the reduction +template class Reduction, typename T> +__device__ __forceinline__ +T blockReduceWarp(T* smem_cache, T value, const Reduction& op, T defaultVal) +{ + T result = zoom_utils::BlockReduce>(value, op, defaultVal, smem_cache); + if (threadIdx.x == 0) { + smem_cache[0] = result; + } + __syncthreads(); + return smem_cache[0]; +} + +template class Reduction, int ILP, typename T, typename AccumT, typename index_t=int> +__device__ __forceinline__ AccumT +ilpReduce(index_t shift, + const T* data, + index_t size, + const Reduction& r, + AccumT defaultVal) +{ + using LoadT = at::native::memory::aligned_vector; + AccumT threadVal = defaultVal; + index_t offset = threadIdx.x; + + // shift and do 1 + if(shift > 0){ + data -= shift; + size += shift; + if(threadIdx.x >= shift){ + threadVal = r(threadVal, data[offset]); + } + size -= blockDim.x; + data += blockDim.x; + } + index_t last = size % (ILP * blockDim.x); + + T v[ILP]; + LoadT* value = reinterpret_cast(&v); + + for (; offset * ILP < (size - last); offset += blockDim.x) { + *value = reinterpret_cast(data)[offset]; + + #pragma unroll + for (int j = 0; j < ILP; ++j) { + threadVal = r(threadVal, v[j]); + } + } + + offset = size - last + threadIdx.x; + // Epilogue + for (; offset < size; offset += blockDim.x) + threadVal = r(threadVal, data[offset]); + + return threadVal; +} + +/** + * This will apply the Epilogue with vectorized reads & writes when input & output have the same shift + */ +template class Epilogue> +__device__ __forceinline__ void +WriteFpropResultsVectorized( + int size, + const int shift, + const scalar_t *input, + outscalar_t *output, + Epilogue epilogue) { + using LoadT = at::native::memory::aligned_vector; + using StoreT = at::native::memory::aligned_vector; + + int offset = threadIdx.x; + + // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later + if (shift > 0) { + input -= shift; + output -= shift; + size += shift; + + if (threadIdx.x >= shift) { + output[offset] = epilogue(input[offset]); + } + size -= blockDim.x; + input += blockDim.x; + output += blockDim.x; + } + + const 
int last = size % (ILP * blockDim.x); + + scalar_t in_v[ILP]; + LoadT* in_value = reinterpret_cast(&in_v); + + outscalar_t out_v[ILP]; + const StoreT* out_value = reinterpret_cast(&out_v); + + for (; offset * ILP < (size - last); offset += blockDim.x) { + *in_value = reinterpret_cast(input)[offset]; + + #pragma unroll + for (int j = 0; j < ILP; ++j) { + out_v[j] = epilogue(in_v[j]); + } + + reinterpret_cast(output)[offset] = *out_value; + } + + offset = size - last + threadIdx.x; + // handle the tail + for (; offset < size; offset += blockDim.x) { + output[offset] = epilogue(input[offset]); + } +} + +template class Epilogue, typename index_t = int32_t> +__device__ __forceinline__ void +WriteBpropResultsVectorized( + index_t size, + const index_t shift, + scalar_t *gradInput, + const outscalar_t *output, + const outscalar_t *gradOutput, + Epilogue epilogue) { + using gradInputT = at::native::memory::aligned_vector; + using outputT = at::native::memory::aligned_vector; + + index_t offset = threadIdx.x; + + // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later + if (shift > 0) { + gradInput -= shift; + output -= shift; + gradOutput -= shift; + size += shift; + + if (threadIdx.x >= shift) { + gradInput[offset] = epilogue(gradOutput[offset], output[offset]); + } + size -= blockDim.x; + gradInput += blockDim.x; + output += blockDim.x; + gradOutput += blockDim.x; + } + + const index_t last = size % (ILP * blockDim.x); + + scalar_t dX[ILP]; + gradInputT *dX_v = reinterpret_cast(&dX); + + outscalar_t Y[ILP]; + outputT *Y_v = reinterpret_cast(&Y); + + outscalar_t dY[ILP]; + outputT *dY_v = reinterpret_cast(&dY); + + for (; offset * ILP < (size - last); offset += blockDim.x) { + *Y_v = reinterpret_cast(output)[offset]; + *dY_v = reinterpret_cast(gradOutput)[offset]; + + #pragma unroll + for (int j = 0; j < ILP; ++j) { + dX[j] = epilogue(dY[j], Y[j]); + } + + reinterpret_cast(gradInput)[offset] = *dX_v; + } + + offset = size - last + threadIdx.x; + for (; offset < size; offset += blockDim.x) { + gradInput[offset] = epilogue(gradOutput[offset], output[offset]); + } +} + +/** + * This will apply the Epilogue with non-vectorized reads & writes for the general case + */ +template class Epilogue> +__device__ __forceinline__ void +WriteFpropResults( + int classes, + const scalar_t *input, + outscalar_t *output, + Epilogue epilogue) { + for (int offset = threadIdx.x; offset < classes; offset += blockDim.x) { + output[offset] = epilogue(input[offset]); + } +} + +template class Epilogue, typename index_t> +__device__ __forceinline__ void +WriteBpropResults( + int classes, + scalar_t *gradInput, + const outscalar_t *output, + const outscalar_t *gradOutput, + Epilogue epilogue) { + + index_t offset = threadIdx.x; + + index_t last = classes % (ILP * blockDim.x); + + for (; offset < classes - last; offset += blockDim.x * ILP) { + outscalar_t tmpOutput[ILP]; + outscalar_t tmpGradOutput[ILP]; + + #pragma unroll + for (int j = 0; j < ILP; ++j) { + tmpOutput[j] = output[offset + j * blockDim.x]; + tmpGradOutput[j] = gradOutput[offset + j * blockDim.x]; + } + + #pragma unroll + for (int j = 0; j < ILP; ++j) { + gradInput[offset + j * blockDim.x] = epilogue(tmpGradOutput[j], tmpOutput[j]); + } + } + + // Remainder - no ILP + for (; offset < classes; offset += blockDim.x) { + gradInput[offset] = epilogue(gradOutput[offset], output[offset]); + } +} + +template class Epilogue> +__global__ void +cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes) +{ + extern 
__shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + + // forward pointers to batch[blockIdx.x] + // each block handles a sample in the mini-batch + input += static_cast(blockIdx.x) * classes; + output += static_cast(blockIdx.x) * classes; + + const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); + const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t); + + // find the max + accscalar_t threadMax = ilpReduce( + shift, input, classes, MaxFloat(), -at::numeric_limits::max()); + accscalar_t max_k = blockReduceWarp(sdata, threadMax, + Max(), -at::numeric_limits::max()); + + // reduce all values + accscalar_t threadExp = ilpReduce( + shift, input, classes, SumExpFloat(max_k), static_cast(0)); + accscalar_t sumAll = blockReduceWarp(sdata, threadExp, + Add(), static_cast(0)); + + Epilogue epilogue(max_k, sumAll); + + if (shift == output_shift) { + WriteFpropResultsVectorized(classes, shift, input, output, epilogue); + } else { + WriteFpropResults(classes, input, output, epilogue); + } +} + +template class Epilogue, typename index_t = int32_t> +__global__ void +cunn_SoftMaxForwardSmem(outscalar_t *output, const scalar_t *input, index_t classes) +{ + // Each thread block processes a sample in the batch + input += static_cast(blockIdx.x) * classes; + output += static_cast(blockIdx.x) * classes; + + accscalar_t threadMax = -at::numeric_limits::max(); + accscalar_t threadExp = static_cast(0); + + // The first smem segment is used to cache input values and the last + // segment is used for thread block reductions + extern __shared__ unsigned char smem[]; + auto smem_input_cache = reinterpret_cast(smem); + auto smem_reduction_cache = reinterpret_cast(smem + + classes * sizeof(scalar_t)); + + using LoadT = at::native::memory::aligned_vector; + const LoadT* const input_vec_ptr = reinterpret_cast(input); + LoadT* const smem_input_cache_vec_ptr = reinterpret_cast(smem_input_cache); + + // Download inputs to shared memory while doing the first step + // in max calculation + MaxFloat maxFunc; + for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) { + LoadT crnt_vec = input_vec_ptr[offset]; + smem_input_cache_vec_ptr[offset] = crnt_vec; + + #pragma unroll + for (int i = 0; i < ILP; ++i) { + threadMax = maxFunc(threadMax, crnt_vec.val[i]); + } + } + + accscalar_t max_k = blockReduceWarp(smem_reduction_cache, threadMax, + Max(), -at::numeric_limits::max()); + + // Reload input from shared memory to compute the sum. The previous + // reduce has performed a __syncthreads() so the smem contents are populated. 
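+  // Together the two passes implement the usual numerically stable softmax:
+  //   m = max_j x_j                  (first pass, blockReduceWarp over Max)
+  //   Z = sum_j exp(x_j - m)         (second pass below, blockReduceWarp over Add)
+  //   softmax_i     = exp(x_i - m) / Z
+  //   log_softmax_i = (x_i - m) - log(Z)
+  // Subtracting m before exponentiating keeps exp() in range for large-magnitude inputs.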
+ SumExpFloat sumExpFunc(max_k); + for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) { + LoadT crnt_vec = smem_input_cache_vec_ptr[offset]; + + #pragma unroll + for (int i = 0; i < ILP; ++i) { + threadExp = sumExpFunc(threadExp, crnt_vec.val[i]); + } + } + + accscalar_t sumAll = blockReduceWarp(smem_reduction_cache, threadExp, + Add(), static_cast(0)); + + Epilogue epilogue(max_k, sumAll); + + // Use vectorized stores to save the output + using StoreT = at::native::memory::aligned_vector; + StoreT* output_vec_ptr = reinterpret_cast(output); + for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) { + LoadT crnt_vec = smem_input_cache_vec_ptr[offset]; + StoreT out_vec; + + #pragma unroll + for (int i = 0; i < ILP; ++i) { + out_vec.val[i] = epilogue(crnt_vec.val[i]); + } + + output_vec_ptr[offset] = out_vec; + } +} + +C10_DEVICE bool inline is_32bit_representable(const int64_t value) { + return value < static_cast(std::numeric_limits::max()); +} + +template class Epilogue> +__global__ void +cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outscalar_t *gradOutput, int64_t classes) +{ + using LoadT = at::native::memory::aligned_vector; + using StoreT = at::native::memory::aligned_vector; + + extern __shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + gradInput += static_cast(blockIdx.x) * classes; + output += static_cast(blockIdx.x) * classes; + gradOutput += static_cast(blockIdx.x) * classes; + + const int64_t shift = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); + const int64_t output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t); + const int64_t grad_output_shift = ((uint64_t)gradOutput) % ALIGN_BYTES / sizeof(outscalar_t); + + const bool can_use_32bit_indexing = is_32bit_representable(shift) && is_32bit_representable(output_shift) && is_32bit_representable(grad_output_shift) && is_32bit_representable(classes); + accscalar_t threadSum; + if (can_use_32bit_indexing) { + threadSum = ilpReduce( + static_cast(grad_output_shift), gradOutput, classes, AddFloat(), accscalar_t(0)); + } else { + threadSum = ilpReduce( + grad_output_shift, gradOutput, classes, AddFloat(), accscalar_t(0)); + } + accscalar_t sum_k = blockReduce( + sdata, threadSum, Add(), accscalar_t(0)); + + Epilogue epilogue(sum_k); + + if (shift == output_shift && shift == grad_output_shift) { + if (can_use_32bit_indexing) { + WriteBpropResultsVectorized(classes, static_cast(shift), gradInput, output, gradOutput, epilogue); + } else { + WriteBpropResultsVectorized(classes, shift, gradInput, output, gradOutput, epilogue); + } + } else { + if (can_use_32bit_indexing) { + WriteBpropResults(classes, gradInput, output, gradOutput, epilogue); + } else { + WriteBpropResults(classes, gradInput, output, gradOutput, epilogue); + } + } +} + +template class Epilogue, bool is_log_softmax> +Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){ + if (half_to_float) { + TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only"); + } + auto input = input_.contiguous(); + static_assert(std::is_same, float>::value, "accscalar_t for half should be float"); + if (input.dim() == 0) input = input.view(1); + int64_t dim = maybe_wrap_dim(dim_, input.dim()); + TORCH_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions"); + int64_t outer_size = 1; + int64_t dim_size = input.size(dim); + + if 
(input.numel() > 0) { + int64_t inner_size = 1; + hipStream_t stream = c10::zoom::getCurrentZoomStream(); + for (int64_t i = 0; i < dim; ++i) + outer_size *= input.size(i); + for (int64_t i = dim + 1; i < input.dim(); ++i) + inner_size *= input.size(i); + // This kernel spawns a block per each element in the batch. + // XXX: it assumes that inner_size == 1 + + if (inner_size == 1) { + dim3 grid(outer_size); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { + using accscalar_t = acc_type; + if (!half_to_float) { + auto output_ptr = output.mutable_data_ptr(); + auto input_ptr = input.const_data_ptr(); + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + int64_t remaining = outer_size; + int64_t chunk_size = (1L << 30L) / dim_size; + while(remaining > 0) { + dispatch_softmax_forward( + output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr/* not masked */); + input_ptr += chunk_size * dim_size; + output_ptr += chunk_size * dim_size; + remaining -= chunk_size; + } + } else { + constexpr int ILP = sizeof(float4) / sizeof(scalar_t); + dim3 block = SoftMaxForward_getBlockSize(dim_size); + size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + auto max_elements_per_smem = (at::zoom::getCurrentDeviceProperties()->sharedMemPerBlock - + smem_reduction_sz) / sizeof(scalar_t); + + bool can_use_smem = dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); + can_use_smem &= !(dim_size % ILP); + + if (can_use_smem) { + size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; + hipLaunchKernelGGL(( cunn_SoftMaxForwardSmem) + , dim3(grid), dim3(block), smem_sz, stream, output_ptr, input_ptr, dim_size); + } else { + hipLaunchKernelGGL(( cunn_SoftMaxForward) + , dim3(grid), dim3(block), smem_reduction_sz, stream, output_ptr, input_ptr, dim_size); + } + + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } else { + auto output_ptr = output.mutable_data_ptr(); + auto input_ptr = input.const_data_ptr(); + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + int64_t remaining = outer_size; + int64_t chunk_size = (1<<30) / dim_size; + while(remaining > 0) { + dispatch_softmax_forward( + output_ptr, input_ptr, dim_size, dim_size, std::min(remaining, chunk_size), nullptr/* not masked */); + input_ptr += chunk_size * dim_size; + output_ptr += chunk_size * dim_size; + remaining -= chunk_size; + } + } else { + constexpr int ILP = sizeof(float4) / sizeof(scalar_t); + dim3 block = SoftMaxForward_getBlockSize(dim_size); + size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + auto max_elements_per_smem = (at::zoom::getCurrentDeviceProperties()->sharedMemPerBlock - + smem_reduction_sz) / sizeof(scalar_t); + + bool can_use_smem = dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); + can_use_smem &= !(dim_size % ILP); + + if (can_use_smem) { + size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; + hipLaunchKernelGGL(( cunn_SoftMaxForwardSmem) + , dim3(grid), dim3(block), smem_sz, stream, output_ptr, input_ptr, dim_size); + } else { + hipLaunchKernelGGL(( cunn_SoftMaxForward) + , dim3(grid), dim3(block), smem_reduction_sz, stream, output_ptr, input_ptr, dim_size); + } + + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } + }); + // This 
kernel runs in a 2D grid, where each application along y dimension has a fixed + // outer_size, and runs in parallel over inner_size. Dimension x is parallel over outer_size. + // Reductions over dim are done in a single-threaded manner. + } else { + uint32_t smem_size; + dim3 grid, block; + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { + using accscalar_t = acc_type; + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(input, INT_MAX) ? ScalarType::Int : ScalarType::Long, + "host_softmax_launcher", [&] { + if (!half_to_float) { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxForward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + hipLaunchKernelGGL(( cunn_SpatialSoftMaxForward) + , dim3(grid), dim3(block), smem_size, stream, + output.mutable_data_ptr(), input.const_data_ptr(), outer_size, dim_size, inner_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxForward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + hipLaunchKernelGGL(( cunn_SpatialSoftMaxForward) + , dim3(grid), dim3(block), smem_size, stream, + output.mutable_data_ptr(), input.const_data_ptr(), outer_size, dim_size, inner_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); + }); + } + } + return output; +} + +template class Epilogue, bool is_log_softmax> +void host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t dim_, bool half_to_float, const Tensor &gI){ + int64_t dim = maybe_wrap_dim(dim_, grad_.dim()); + if (grad_.numel() == 0) { + return; + } + auto grad = grad_.contiguous(); + static_assert(std::is_same, float>::value, "accscalar_t for half should be float"); + if (grad.dim() == 0) grad = grad.view(1); + TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions"); + auto output = output_.contiguous(); + if (output.dim() == 0) output = output.view(1); + int64_t outer_size = 1; + int64_t dim_size = output.size(dim); + int64_t inner_size = 1; + for (int64_t i = 0; i < dim; ++i) + outer_size *= output.size(i); + for (int64_t i = dim + 1; i < output.dim(); ++i) + inner_size *= output.size(i); +// See descriptions of kernels above. 
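+// The launch paths below mirror the forward: the persistent dispatch_softmax_backward
+// kernel for small rows (dim_size <= 1024 and at most 4 KB per row), cunn_SoftMaxBackward
+// for larger rows with inner_size == 1, and cunn_SpatialSoftMaxBackward otherwise. The
+// epilogues applied are the ones defined at the top of this file:
+//   softmax:     dX_i = g_i - Y_i * sum_j g_j      with g = grad * output
+//                (equivalently Y_i * (dY_i - sum_j Y_j * dY_j))
+//   log_softmax: dX_i = dY_i - exp(Y_i) * sum_j dY_j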
+ hipStream_t stream = c10::zoom::getCurrentZoomStream(); + if (inner_size == 1) { + dim3 grid(outer_size); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] { + using accscalar_t = acc_type; + if (!half_to_float) { + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + auto gI_ptr = gI.mutable_data_ptr(); + auto grad_ptr = grad.const_data_ptr(); + auto output_ptr = output.const_data_ptr(); + int64_t remaining = outer_size; + int64_t chunk_size = (1<<30) / dim_size; + while(remaining > 0) { + dispatch_softmax_backward( + gI_ptr, grad_ptr, output_ptr, dim_size, dim_size, std::min(remaining, chunk_size)); + gI_ptr += chunk_size * dim_size; + grad_ptr += chunk_size * dim_size; + output_ptr += chunk_size * dim_size; + remaining -= chunk_size; + } + } else { + constexpr int ILP = sizeof(float4) / sizeof(scalar_t); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + hipLaunchKernelGGL(( cunn_SoftMaxBackward) + , dim3(grid), dim3(block), block.x * sizeof(accscalar_t), stream, + gI.mutable_data_ptr(), output.const_data_ptr(), grad.const_data_ptr(), dim_size + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } else { + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + auto gI_ptr = gI.mutable_data_ptr(); + auto grad_ptr = grad.const_data_ptr(); + auto output_ptr = output.const_data_ptr(); + int64_t remaining = outer_size; + int64_t chunk_size = (1<<30) / dim_size; + while(remaining > 0) { + dispatch_softmax_backward( + gI_ptr, grad_ptr, output_ptr, dim_size, dim_size, std::min(remaining, chunk_size)); + gI_ptr += chunk_size * dim_size; + grad_ptr += chunk_size * dim_size; + output_ptr += chunk_size * dim_size; + remaining -= chunk_size; + } + } else { + constexpr int ILP = sizeof(float4) / sizeof(accscalar_t); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + hipLaunchKernelGGL(( cunn_SoftMaxBackward) + , dim3(grid), dim3(block), block.x * sizeof(accscalar_t), stream, + gI.mutable_data_ptr(), output.const_data_ptr(), grad.const_data_ptr(), dim_size + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } + }); + } else { + uint32_t smem_size; + dim3 grid, block; + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] { + using accscalar_t = acc_type; + if (!half_to_float) { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxBackward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + + hipLaunchKernelGGL(( cunn_SpatialSoftMaxBackward) + , dim3(grid), dim3(block), smem_size, stream, + gI.mutable_data_ptr(), output.const_data_ptr(), grad.const_data_ptr(), + outer_size, dim_size, inner_size + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxBackward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + + hipLaunchKernelGGL(( cunn_SpatialSoftMaxBackward) + , dim3(grid), dim3(block), smem_size, stream, + gI.mutable_data_ptr(), output.const_data_ptr(), grad.const_data_ptr(), + outer_size, dim_size, inner_size + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); + } +} +} + +TORCH_IMPL_FUNC(log_softmax_zoom_out) ( + const Tensor &input, + const int64_t dim, + const bool half_to_float, + const Tensor &output) { + host_softmax(input, dim, half_to_float, output); +} + +TORCH_IMPL_FUNC(log_softmax_backward_zoom_out) ( + const Tensor& grad, + const Tensor& output, + int64_t dim, + ScalarType input_dtype, + const Tensor& grad_input) { + bool half_to_float = 
grad.scalar_type() != input_dtype; + if (half_to_float) { + TORCH_CHECK( + (grad.scalar_type() == ScalarType::Float && + input_dtype == ScalarType::Half), + "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); + } + host_softmax_backward(grad, output, dim, half_to_float, grad_input); +} + +TORCH_IMPL_FUNC(softmax_zoom_out) ( + const Tensor &input, + const int64_t dim, + const bool half_to_float, + const Tensor &output) { + host_softmax(input, dim, half_to_float, output); +} + +TORCH_IMPL_FUNC(softmax_backward_zoom_out) +(const Tensor& grad, + const Tensor& output, + int64_t dim, + ScalarType input_dtype, + const Tensor& grad_input) { + bool half_to_float = grad.scalar_type() != input_dtype; + if (half_to_float) { + TORCH_CHECK( + (grad.scalar_type() == ScalarType::Float && + input_dtype == ScalarType::Half), + "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); + } + Tensor tmp = grad * output; + host_softmax_backward(tmp, output, dim, half_to_float, grad_input); +} + +Tensor masked_softmax_zoom(const Tensor& input_, const Tensor& mask_, const std::optional dim_, const c10::optional mask_type_) { + Tensor output = at::empty_like(input_, input_.options()); + TORCH_CHECK(mask_.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor"); + + TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined"); + int64_t mask_type = mask_type_.value(); + TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)"); + + // If input is [B, H, T, T] and mask is [B, T] + // we have special fast kernel + // mask_type == 1 => mask_ is a src_key_padding_mask + bool is_BxT_mask = (mask_type == 1) && (input_.dim() == 4 && mask_.dim() == 2 && input_.size(0) == mask_.size(0) && input_.size(2) == mask_.size(1) && input_.size(3) == mask_.size(1)); + + // If input is [B, H, T, T] and mask is [T, T] + // expand mask to [B, H, T, T] and treat it like regular mask + // TODO We should have special fast kernel for TxT mask as well + // mask_type == 0 => mask_ is a src_mask + bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1); + // If mask_type == 2, then mask_.sizes() must equal input_.sizes() + TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes()); + + auto input = input_.dim() == 0 ? input_.view(1) : input_; + auto mask = mask_.dim() == 0 ? mask_.view(1) : mask_; + if (is_TxT_mask) { + mask = mask.expand(input.sizes()); + } + int64_t dim = dim_.has_value() ? 
dim_.value() : input.dim() - 1; + + int softmax_elements = input.size(dim); + // Persistent softmax is only supported when all of the conditions are held: + // 1) softmax_elements <= 1024 + // 2) softmax_elements * input.element_size() <= 4096 + // 3) mask.is_contiguous() + // 4) dim == input.dim() - 1 + // Otherwise, we fallback to vanilla softmax (where we do not support transformer_mask since converting the mask is expensive) + if (softmax_elements > 1024 || softmax_elements * input.element_size() > 4096 || !mask.is_contiguous() || dim < input.dim()-1) { + if (is_BxT_mask) { + mask = mask.view({mask_.size(0), 1, 1, mask_.size(1)}).expand(input.sizes()); + } + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + input.scalar_type(), + "masked_softmax", + [&] { + output = at::softmax(input.masked_fill(mask, -std::numeric_limits::infinity()), dim); + }); + return output; + } + int batch_count = input.numel() / softmax_elements; + int chunk_size = input.numel() / input.size(0); + if (is_BxT_mask) { + // Only support when num_heads is even in transformer + TORCH_CHECK(input.size(1) % 2 == 0, "Only support when num_heads is even in transformer"); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + input.scalar_type(), + "masked_softmax", + [&] { + using accscalar_t = acc_type; + dispatch_softmax_forward( + output.mutable_data_ptr(), // dst + input.const_data_ptr(), // src + softmax_elements, + softmax_elements, + batch_count, + mask.const_data_ptr(), + chunk_size, + true // is_transformer_mask + ); + }); + + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + input.scalar_type(), + "masked_softmax", + [&] { + using accscalar_t = acc_type; + dispatch_softmax_forward( + output.mutable_data_ptr(), // dst + input.const_data_ptr(), // src + softmax_elements, + softmax_elements, + batch_count, + mask.const_data_ptr() + ); + }); + } + return output; +} + +Tensor masked_softmax_backward_zoom( + const Tensor& grad_, + const Tensor& output_, + const Tensor& mask_, + const std::optional dim_) { + Tensor grad_input = at::empty_like(grad_, grad_.options()); + if (grad_.numel() == 0) { + return grad_input; + } + + auto grad = grad_.contiguous(); + auto output = output_.contiguous(); + auto mask = mask_.contiguous(); + int64_t dim = dim_.has_value() ? maybe_wrap_dim(dim_.value(), output.dim()) : output.dim() - 1; + + grad = grad.dim() == 0 ? grad.view(1) : grad; + mask = mask.dim() == 0 ? mask.view(1) : mask; + output = output.dim() == 0 ? 
output.view(1) : output; + + TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions"); + TORCH_CHECK(grad.sizes() == mask.sizes(), "Mask shape should match grad shape"); + TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor"); + + int softmax_elements = output.size(dim); + int64_t batch_count = grad.numel() / softmax_elements; + + if (softmax_elements > 1024 || softmax_elements * grad.element_size() > 4096 || dim < grad.dim()-1) { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + grad_input.scalar_type(), + "masked_softmax_backward", + [&] { + grad_input = at::_softmax_backward_data( + grad, + output.masked_fill(mask, 0), + dim, + grad.scalar_type() + ); + }); + } else { + grad = grad * output; + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + grad_input.scalar_type(), + "masked_softmax_backward", + [&] { + using accscalar_t = acc_type; + dispatch_softmax_backward( + grad_input.mutable_data_ptr(), // gI_ptr + grad.const_data_ptr(), // grad_ptr + output.const_data_ptr(), // output_ptr + softmax_elements, // softmax_elements + softmax_elements, // softmax_elements_stride + batch_count, // batch_count + mask.const_data_ptr() /* not masked */ + ); + }); + } + return grad_input; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Sort.cpp b/aten/src/ATen/native/zoom/Sort.cpp new file mode 100644 index 00000000000000..5f34f230c0edf2 --- /dev/null +++ b/aten/src/ATen/native/zoom/Sort.cpp @@ -0,0 +1,128 @@ +// !!! This is a file automatically generated by hipify!!! +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +#include + +namespace at::native { + +std::vector infer_dense_strides_dim_last(const Tensor & self, int64_t dim); + +void fillSliceWithIndex(const Tensor& t, int dim) { + if (t.numel()) { + auto sizes = DimVector(t.dim(), 1); + sizes[dim] = t.sizes()[dim]; + auto range = at::arange(t.sizes()[dim], t.options()); + auto rangeview = range.view(sizes); + t.copy_(rangeview); + } +} + +// We perform a segmented sort in cub with inputs that have +// more than 1024/2048 elements along the selected dimension. +// Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace). +void sort_zoom_kernel( + const TensorBase& self_base, + const TensorBase& values_base, + const TensorBase& indices_base, + int64_t dim, + bool descending, + bool stable) { + // this algorithm is always stable + + // Macro for converting `TensorBase` -> `Tensor` without + // reference count bumps. +#define TOTENSOR(BASE, VAR) \ + OptionalTensorRef opt_##BASE(BASE); \ + const Tensor& VAR = *opt_##BASE; + + // Converting TensorBase into Tensor. + // We will need Tensor's methods from this point onwards. 
+ TOTENSOR(self_base, self); + TOTENSOR(values_base, values); + TOTENSOR(indices_base, indices); + + TORCH_CHECK(self.sizes()[dim] <= std::numeric_limits::max(), + "The dimension being sorted can not have more than INT_MAX elements."); + + const auto self_dtype = self.dtype(); + // FIXME: remove this check once cub sort supports bool + TORCH_CHECK(self_dtype != ScalarType::Bool, + "Sort currently does not support bool dtype on Zoom."); + TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble, + "Sort currently does not support complex dtypes on Zoom."); + + // use inplace algorithm for smaller input sizes without stable=True + if (should_use_small_sort(self, dim)) { + // from thc: sorted->values, indices->indices, input->self + fillSliceWithIndex(indices, dim); + + // We sort k/v pairs in-place; copy unsorted input to output + values.copy_(self); + + // Sort using our in-place k/v kernel that supports arbitrary + // layout + sortKeyValueInplace(values, indices, dim, descending, stable); + return; + } + + Tensor self_; + bool newself = false; + if (self.is_non_overlapping_and_dense() && self.stride(dim) == 1) { + self_ = self; + } else { + auto new_strides_unsort = infer_dense_strides_dim_last(self, dim); + self_ = at::empty_strided(self.sizes(), new_strides_unsort, self.options()); + self_.copy_(self); + newself = true; + } + + c10::MaybeOwned values_tmp, indices_tmp; + if (values.strides() == self_.strides() && (newself || get_overlap_status(self, values) == MemOverlapStatus::No)) { + values_tmp = c10::MaybeOwned::borrowed(values); + } else { + values_tmp = c10::MaybeOwned::owned( + at::empty_strided(self_.sizes(), self_.strides(), self_.options())); + } + + if (indices.strides() != self_.strides()) { + indices_tmp = c10::MaybeOwned::owned( + at::empty_strided(self_.sizes(), self_.strides(), self_.options().dtype(kLong))); + } else { + indices_tmp = c10::MaybeOwned::borrowed(indices); + } + + launch_stable_sort_kernel(self_, dim, descending, *values_tmp, *indices_tmp); + + if (!values_tmp->is_same(values)) { + values.copy_(*values_tmp); + } + if (!indices_tmp->is_same(indices)) { + indices.copy_(*indices_tmp); + } +} + +// TODO: we should handle this accordingly when we start using REGISTER_HIP_DISPATCH, +// since REGISTER_PRIVATEUSE1_DISPATCH won't work in this cpp file. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_PRIVATEUSE1_DISPATCH(sort_stub, &sort_zoom_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Sort.cu b/aten/src/ATen/native/zoom/Sort.cu new file mode 100644 index 00000000000000..466c705ced9b5c --- /dev/null +++ b/aten/src/ATen/native/zoom/Sort.cu @@ -0,0 +1,384 @@ +// !!! This is a file automatically generated by hipify!!! 
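Before the Sort.cu kernels below, a hedged sketch (assuming a libtorch build) of the borrow-or-copy pattern that sort_zoom_kernel above applies to values_tmp and indices_tmp: borrow the caller's output when its strides already match the working layout, otherwise sort into an owned scratch tensor and copy back at the end. The function name sort_into is illustrative.

#include <ATen/ATen.h>
#include <c10/util/MaybeOwned.h>

void sort_into(const at::Tensor& src_dense, at::Tensor& out) {
  c10::MaybeOwned<at::Tensor> work;
  if (out.strides() == src_dense.strides()) {
    // Layout already matches: write directly into the caller's tensor.
    work = c10::MaybeOwned<at::Tensor>::borrowed(out);
  } else {
    // Otherwise allocate a scratch tensor with the working layout.
    work = c10::MaybeOwned<at::Tensor>::owned(
        at::empty_strided(src_dense.sizes(), src_dense.strides(), out.options()));
  }

  // ... run the kernel that writes into *work ...

  if (!work->is_same(out)) {  // only pay the copy when a scratch tensor was used
    out.copy_(*work);
  }
}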
+#include +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at::native { + +template +static int minimum_grid_for_occupancy(T kernel, int max_block_size) { + int minGridSize = 0; + int blockSize; + C10_ZOOM_CHECK(hipOccupancyMaxPotentialBlockSize( + &minGridSize, + &blockSize, + kernel, + /*dynamicSMemSize=*/0, + max_block_size)); + return minGridSize; +} + +template +constexpr bool has_nan() { + if constexpr (std::numeric_limits::is_specialized) { + return std::numeric_limits::has_quiet_NaN; + } else if constexpr ( + c10::is_complex::value || + std::is_same_v || + std::is_same_v) { + return true; + } +} + +// For very small unstable sorts (n <= 32), use bitonicSortKVInPlace +// which can sort multiple arrays within the same block of threads, +// improving occupancy. +struct SmallBitonicSort { + template + void sort( + at::zoom::detail::TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::zoom::detail::TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending) { + constexpr int sort_size = 32; + constexpr int max_block_y = 16; + constexpr int items_per_thread = 2; + static_assert(sort_size % items_per_thread == 0, ""); + constexpr int block_x = sort_size / items_per_thread; + + TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size); + + // Scale batch size down if the grid would be too small + const auto min_grid = minimum_grid_for_occupancy( + bitonicSortKVInPlace< + A, -1, block_x, max_block_y, + K, V, LTOp, IndexType>, + block_x * max_block_y); + const auto max_batch = ::max(IndexType{1}, keySlices / min_grid); + const int block_y = ::min(IndexType(max_block_y), max_batch); + dim3 block(block_x, block_y); + + dim3 grid; + const int grid_count = (keySlices + block_y - 1) / block_y; + TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid), + "Too many slices to sort"); + const auto stream = c10::zoom::getCurrentZoomStream(); + + if (descending) { + hipLaunchKernelGGL(( bitonicSortKVInPlace) + , dim3(grid), dim3(block), 0, stream, + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + GTOp()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + hipLaunchKernelGGL(( bitonicSortKVInPlace) + , dim3(grid), dim3(block), 0, stream, + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + LTOp()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } +}; + +#if HAS_WARP_MERGE_SORT() + +// For small sorts (n <= 128) we use warpMergeSortKVInPlace which +// sorts one slice per warp and potentially multiple slices in the +// same block for improved occupancy with large batch sizes. 
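The small-sort paths above and below both size their grids with the same occupancy query. A hedged standalone HIP sketch of that calculation, using a hypothetical toy_kernel in place of the real sort kernel:

#include <hip/hip_runtime.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>

__global__ void toy_kernel() {}

int main() {
  int min_grid = 0, block_size = 0;
  // Ask the runtime for the smallest grid that still reaches full occupancy.
  if (hipOccupancyMaxPotentialBlockSize(&min_grid, &block_size, toy_kernel,
                                        /*dynSharedMemPerBlk=*/0,
                                        /*blockSizeLimit=*/512) != hipSuccess) {
    return 1;
  }

  const int64_t key_slices = 10000;  // independent slices to sort
  const int max_block_y = 16;        // upper bound on slices packed per block

  // Scale slices-per-block (block.y) down if it would shrink the grid
  // below min_grid, exactly as SmallBitonicSort / WarpMergeSort do above.
  const int64_t max_batch =
      std::max<int64_t>(1, key_slices / std::max(min_grid, 1));
  const int block_y = static_cast<int>(std::min<int64_t>(max_block_y, max_batch));
  const int64_t grid_count = (key_slices + block_y - 1) / block_y;

  std::printf("min_grid=%d block_y=%d grid=%lld\n", min_grid, block_y,
              static_cast<long long>(grid_count));
}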
+template +struct WarpMergeSort { + + template + void sort( + at::zoom::detail::TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::zoom::detail::TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending) { + constexpr int max_block_y = 16; + const int block_x = at::zoom::warp_size(); + + TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size); + + // Scale batch size down if the grid would be too small + const auto min_grid = minimum_grid_for_occupancy( + warpMergeSortKVInPlace< + A, -1, sort_size, max_block_y, + K, V, LTOp, IndexType>, + block_x * max_block_y); + const auto max_batch = ::max(IndexType{1}, keySlices / min_grid); + const int block_y = ::min(IndexType(max_block_y), max_batch); + dim3 block(block_x, block_y); + + dim3 grid; + const int grid_count = (keySlices + block_y - 1) / block_y; + TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid), + "Too many slices to sort"); + const auto stream = c10::zoom::getCurrentZoomStream(); + + if (descending) { + const K invalid_key = at::numeric_limits::lower_bound(); + hipLaunchKernelGGL(( warpMergeSortKVInPlace) + , dim3(grid), dim3(block), 0, stream, + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + GTOp(), + invalid_key); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + const K invalid_key = []{ + // NAN is sorted after inf + if constexpr(has_nan()) { + return K(NAN); + } + return at::numeric_limits::upper_bound(); + }(); + hipLaunchKernelGGL(( warpMergeSortKVInPlace) + , dim3(grid), dim3(block), 0, stream, + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + LTOp(), + invalid_key); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } +}; + +#endif // !HAS_WARP_MERGE_SORT() + +// For medium sizes (128 < n <= 4096) use radixSortKVInplace. 
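Before the medium-size radix path below, a hedged sketch of the padding-key rule the warp merge sort above depends on, written against std::numeric_limits rather than at::numeric_limits; padding_key is an illustrative name. Out-of-range lanes are filled with a key that is guaranteed to sort to the end, so the merge network never moves real elements past padding.

#include <iostream>
#include <limits>

template <typename K>
K padding_key(bool descending) {
  if (descending) {
    // The smallest value sorts last in a descending sort.
    return std::numeric_limits<K>::lowest();
  }
  if constexpr (std::numeric_limits<K>::has_quiet_NaN) {
    // Ascending: NaN is ordered after +inf, so it pads the tail.
    return std::numeric_limits<K>::quiet_NaN();
  } else {
    return std::numeric_limits<K>::max();
  }
}

int main() {
  std::cout << padding_key<float>(false) << ' '   // nan
            << padding_key<float>(true)  << ' '   // lowest float
            << padding_key<int>(false)   << '\n'; // 2147483647
}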
+struct MediumRadixSort { + + template + void sort( + at::zoom::detail::TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::zoom::detail::TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending) { + +#define HANDLE_CASE(SIZE, ITEMS_PER_THREAD) \ + fixed_size_sort( \ + keyInfo, \ + keySlices, \ + keySliceSize, \ + keySliceStride, \ + valueInfo, \ + valueSliceStride, \ + descending) + + int64_t ceilPowerOf2 = nextHighestPowerOf2(keySliceSize); + TORCH_INTERNAL_ASSERT(ceilPowerOf2 <= 4096); + switch (ceilPowerOf2) { + case 4096: + HANDLE_CASE(4096, 32); + break; + case 2048: + HANDLE_CASE(2048, 32); + break; + case 1024: + case 512: + case 256: + HANDLE_CASE(1024, 32); + break; + case 128: + case 64: +#if !HAS_WARP_MERGE_SORT() + HANDLE_CASE(128, 4); + break; +#endif + case 32: + case 16: + case 8: + case 4: + case 2: +#if HAS_WARP_MERGE_SORT() + TORCH_INTERNAL_ASSERT( + false, "Expected size <= 128 to be handled by a different algorithm"); +#else + HANDLE_CASE(32, 2); +#endif + break; + case 1: + /* Nothing to do, data already sorted */ + break; + default: + TORCH_INTERNAL_ASSERT(false); + } +#undef HANDLE_CASE + + } + + template + void fixed_size_sort( + at::zoom::detail::TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::zoom::detail::TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending) { + static_assert(sort_size % items_per_thread == 0, ""); + constexpr int block = sort_size / items_per_thread; + dim3 grid; + TORCH_INTERNAL_ASSERT(getGridFromTiles(keySlices, grid), + "Too many slices to sort"); + + const auto stream = c10::zoom::getCurrentZoomStream(); + hipLaunchKernelGGL(( radixSortKVInPlace) + , dim3(grid), dim3(block), 0, stream, + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + descending); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +}; + +template +void sortCommon(Sorter sorter, const TensorBase &key, const TensorBase &value, + int dim, bool descending) { + TORCH_CHECK(key.sizes() == value.sizes(), + "Key tensor must have same size as value tensor"); + int dims = value.dim(); + TORCH_CHECK(dims <= MAX_DIMS, "value tensor has too many dimensions"); + // if key and value tensors have the same size, we do not need to check both + + ptrdiff_t inElements = key.numel(); + + if (inElements == 0) { + return; + } + + int64_t keySliceSize = key.size(dim); + ptrdiff_t keySlices = inElements / keySliceSize; + +#define HANDLE_SORT_CASE(TYPE, A) \ + sorter.template sort( \ + keyInfo, \ + (TYPE) keySlices, \ + (TYPE) keySliceSize, \ + (TYPE) keyInfo.strides[collapseKeyDim], \ + valueInfo, \ + (TYPE) valueInfo.strides[collapseValueDim], \ + descending) + + // The constructed key/value tensor info is used to select the slice + // we are sorting on a per-block basis + // The constructed key/value tensor info is used to select the slice + // we are sorting on a per-block basis + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, key.scalar_type(), "sortKeyValueInplace", [&] { + if (at::zoom::detail::canUse32BitIndexMath(key)) { + at::zoom::detail::TensorInfo keyInfo = + at::zoom::detail::getTensorInfo(key); + at::zoom::detail::TensorInfo valueInfo = + at::zoom::detail::getTensorInfo(value); + + auto strideKey = keyInfo.strides[dim]; + keyInfo.sizes[dim] = 1; + int collapseKeyDim = keyInfo.collapseDims(dim); + keyInfo.strides[collapseKeyDim] = strideKey; + auto strideValue = 
valueInfo.strides[dim]; + valueInfo.sizes[dim]=1; + int collapseValueDim = valueInfo.collapseDims(dim); + valueInfo.strides[collapseValueDim] = strideValue; + + if (keyInfo.isContiguous()) { + HANDLE_SORT_CASE(unsigned int, -2); + } else { + switch (keyInfo.dims) { + case 2: + HANDLE_SORT_CASE(unsigned int, 2); + break; + default: + HANDLE_SORT_CASE(unsigned int, -1); + break; + } + } + + } else { + at::zoom::detail::TensorInfo keyInfo = + at::zoom::detail::getTensorInfo(key); + at::zoom::detail::TensorInfo valueInfo = + at::zoom::detail::getTensorInfo(value); + + auto strideKey = keyInfo.strides[dim]; + keyInfo.sizes[dim] = 1; + int collapseKeyDim = keyInfo.collapseDims(dim); + keyInfo.strides[collapseKeyDim] = strideKey; + auto strideValue = valueInfo.strides[dim]; + valueInfo.sizes[dim]=1; + int collapseValueDim = valueInfo.collapseDims(dim); + valueInfo.strides[collapseValueDim] = strideValue; + + // int64_t case is rare, just instantiate the generic version + HANDLE_SORT_CASE(uint64_t, -1); + } + }); +#undef HANDLE_SORT_CASE +} + +void sortKeyValueInplace( + const TensorBase& key, + const TensorBase& value, + int dim, + bool descending, + bool stable) { + const auto sort_size = key.size(dim); + if (sort_size <= 1) { + return; // Already sorted + } else if (!stable && sort_size <= 32) { + // NOTE: Bitonic sort is unstable + sortCommon(SmallBitonicSort{}, key, value, dim, descending); +#if HAS_WARP_MERGE_SORT() + } else if (sort_size <= 128) { + sortCommon(WarpMergeSort<128>{}, key, value, dim, descending); +#endif + } else { + sortCommon(MediumRadixSort{}, key, value, dim, descending); + } +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Sort.h b/aten/src/ATen/native/zoom/Sort.h new file mode 100644 index 00000000000000..77f33a5b8d7634 --- /dev/null +++ b/aten/src/ATen/native/zoom/Sort.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include +#include + +namespace at { +namespace native { + +inline bool should_use_small_sort(const TensorBase &self, int64_t dim) { + return self.size(dim) <= 4096; +} + +void sortKeyValueInplace( + const TensorBase &key, const TensorBase &value, int dim, + bool descending, bool stable=false); + +}} // namespace at::native diff --git a/aten/src/ATen/native/zoom/SortImpl.cu b/aten/src/ATen/native/zoom/SortImpl.cu new file mode 100644 index 00000000000000..5d779d0fd15ce5 --- /dev/null +++ b/aten/src/ATen/native/zoom/SortImpl.cu @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include + +namespace at::native { + +std::vector infer_dense_strides_dim_last(const Tensor & self, int64_t dim) { + int64_t ndim = self.dim(); + // sort the strides in descending order according to its value, + // keeping dim the last. 
+ std::vector strides = self.strides().vec(); + strides[dim] = -1; + std::vector original_dim(ndim); + for (int64_t i = 0; i < ndim; i++) { + original_dim[i] = i; + } + thrust::stable_sort_by_key( + thrust::host, strides.data(), strides.data() + ndim, original_dim.data(), + thrust::greater() + ); + // generate contiguous strides on permuted dims + std::vector new_strides(ndim); + std::vector new_strides_unsort(ndim); + int64_t cumprod = 1; + for (int64_t i = 0; i < ndim; i++) { + new_strides[ndim - 1 - i] = cumprod; + cumprod *= self.sizes()[original_dim[ndim - 1 - i]]; + } + // unsort new strides + for (int64_t i = 0; i < ndim; i++) { + new_strides_unsort[original_dim[i]] = new_strides[i]; + } + return new_strides_unsort; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/SortStable.cu b/aten/src/ATen/native/zoom/SortStable.cu new file mode 100644 index 00000000000000..62df3c4379e8af --- /dev/null +++ b/aten/src/ATen/native/zoom/SortStable.cu @@ -0,0 +1,286 @@ +// !!! This is a file automatically generated by hipify!!! +#include + +#define TORCH_ASSERT_NO_OPERATORS +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at::native { + +namespace { + +struct offset_t { + int stride; + int begin; + __device__ int operator[](int i) { + return stride * (begin + i); + } +}; +// Segmented sort by full sort algorithm:. +// Say we are sorting a (2, 3) tensor. We have in flattened form: +// values 0.4 1.2 5.3 6.2 1.3 2.3 +// indices 0 1 2 0 1 2 +// segment_id 0 0 0 1 1 1 + +// First we sort by values, globally: +// values 6.2 5.3 2.3 1.2 1.3 0.4 +// indices 0 2 2 1 1 0 +// segment_id 1 0 1 0 1 0 + +// Then we stable sort by segment id: +// values 5.3 1.2 0.4 6.2 2.3 1.3 +// indices 2 1 0 0 2 1 +// segment_id 0 0 0 1 1 1 + +// This method can only work if the slice we are sorting (`dim`) is +// innermost, and both values and indices are contiguous. We do this +// by re-arranging the input into this form as needed, which will +// unfortunately allocate memory if the request is not in this form. +// Vectorized sort is slower than iterated sort if the number of +// slices is small (since we're sorting twice, instead of invoking a +// smaller sort `numSlices` times), but the cub sort +// implementation here is a catch-all, so we're not looking for +// efficiency, but instead correctness. 
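Before the device kernels that implement it, here is a small host-side sketch of the two-pass idea just described, using std::sort and std::stable_sort on the flattened (2, 3) example from the comment: sort all (value, index) pairs globally, then stable-sort by segment id, which restores the per-segment grouping while keeping each segment internally sorted.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> values = {0.4f, 1.2f, 5.3f, 6.2f, 1.3f, 2.3f};
  const int nsort = 3;  // slice (segment) length
  const int n = static_cast<int>(values.size());

  std::vector<int> order(n);  // permutation of flat indices
  for (int i = 0; i < n; ++i) order[i] = i;

  // Pass 1: global sort by value (descending, matching the example).
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return values[a] > values[b]; });

  // Pass 2: stable sort by segment id; stability preserves the value order
  // established in pass 1 within each segment.
  std::stable_sort(order.begin(), order.end(),
                   [&](int a, int b) { return a / nsort < b / nsort; });

  for (int i : order) {
    std::printf("value=%.1f  index-in-slice=%d  segment=%d\n",
                values[i], i % nsort, i / nsort);
  }
  // Prints segment 0 as 5.3 1.2 0.4 and segment 1 as 6.2 2.3 1.3,
  // matching the final state in the comment above.
}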
+ +template +__global__ void sort_postprocess_kernel( + const scalar_t* in, + scalar_t* out, + int64_t* index, + const int2* i_s_ptr, + int nsegments, + int nsort) { + HIP_KERNEL_LOOP(i, nsegments * nsort) { + int segment = i / nsort; + int j = i % nsort; + + int offset = segment * nsort; + const scalar_t* in_ = in + offset; + scalar_t* out_ = out + offset; + int64_t* index_ = index + offset; + const int2* i_s_ptr_ = i_s_ptr + offset; + + int idx = i_s_ptr_[j].y; + index_[j] = idx; + out_[j] = in_[idx]; + } +} + +C10_LAUNCH_BOUNDS_1(at::zoom::detail::HIP_NUM_THREADS) +__global__ void fill_index_and_segment_kernel( + int2* data, + int numel, + at::zoom::detail::IntDivider nsort_divider) { + HIP_KERNEL_LOOP(idx, numel) { + auto div_mod = nsort_divider.divmod(idx); + auto segment = static_cast(div_mod.div); + auto sort = static_cast(div_mod.mod); + data[idx] = int2{segment, sort}; + } +} + +C10_LAUNCH_BOUNDS_1(at::zoom::detail::HIP_NUM_THREADS) +__global__ void fill_reverse_indices_kernel( + int64_t* data, + int numel, + at::zoom::detail::IntDivider nsort_divider) { + HIP_KERNEL_LOOP(idx, numel) { + data[idx] = nsort_divider.mod(idx); + } +} + +template +inline void segmented_sort_large_segments( + const int64_t nsegments, + const int64_t nsort, + const int64_t n, + const bool descending, + const scalar_t* self_ptr, + scalar_t* values_ptr, + int64_t* indices_ptr) { + using namespace at::zoom::detail; + auto allocator = at::zoom::getZoomDeviceAllocator(); + auto stream = c10::zoom::getCurrentZoomStream(); + dim3 block = HIP_NUM_THREADS; + dim3 grid = GET_BLOCKS(nsort); + c10::DeviceArray indices(*allocator, nsort); + at::zoom::detail::IntDivider nsort_divider(nsort); + hipLaunchKernelGGL(( fill_reverse_indices_kernel), dim3(grid), dim3(block), 0, stream, + indices.get(), nsort, nsort_divider); + const int64_t* initial_indices = indices.get(); + + for (auto i : c10::irange(nsegments)) { + at::zoom::hipcub::radix_sort_pairs( + self_ptr, values_ptr, initial_indices, indices_ptr, nsort, descending); + indices_ptr += nsort; + self_ptr += nsort; + values_ptr += nsort; + } +} + +template +inline void segmented_sort_pairs_by_full_sort( + const int64_t nsegments, + const int64_t nsort, + const int64_t n, + const bool descending, + const scalar_t* const self_ptr, + scalar_t* const values_ptr, + int64_t* const indices_ptr) { + int64_t segment_bits = std::max( + 1L, static_cast(::ceil(std::log2(nsegments)))); + + const auto numel = nsort * nsegments; + auto zoom_allocator = at::zoom::getZoomDeviceAllocator(); + auto indices_and_segment = zoom_allocator->allocate(numel * sizeof(int2)); + auto i_s_ptr = static_cast(indices_and_segment.get()); + + using namespace at::zoom::detail; + dim3 block = HIP_NUM_THREADS; + dim3 grid = GET_BLOCKS(numel); + auto stream = c10::zoom::getCurrentZoomStream(); + at::zoom::detail::IntDivider nsort_divider(nsort); + hipLaunchKernelGGL(( fill_index_and_segment_kernel), dim3(grid), dim3(block), 0, stream, + i_s_ptr, numel, nsort_divider); + + auto indices_and_segment2 = + zoom_allocator->allocate(nsegments * nsort * sizeof(int2)); + auto i_s_ptr2 = static_cast(indices_and_segment2.get()); + + at::zoom::hipcub::radix_sort_pairs( + self_ptr, nullptr, i_s_ptr, i_s_ptr2, n, descending); + + TORCH_INTERNAL_ASSERT(segment_bits <= 32); + + // sort on lower 32bits, i.e. 
segment index + at::zoom::hipcub::radix_sort_keys( + reinterpret_cast(i_s_ptr2), + reinterpret_cast(i_s_ptr), + n, + false, + 0, + segment_bits); + + hipLaunchKernelGGL(( sort_postprocess_kernel), + dim3((n + 511) / 512), + dim3(512), + 0, + c10::zoom::getCurrentZoomStream(), + self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort); +} + +template +void segmented_sort_pairs( + int64_t nsegments, + int64_t nsort, + int64_t n, + bool descending, + const scalar_t* self_ptr, + scalar_t* values_ptr, + int64_t* indices_ptr) { + const auto numel = nsort * nsegments; + auto zoom_allocator = at::zoom::getZoomDeviceAllocator(); + auto reverse_indices = zoom_allocator->allocate(numel * sizeof(int64_t)); + int64_t* reverse_indices_ptr = static_cast(reverse_indices.get()); + + using namespace at::zoom::detail; + dim3 block = HIP_NUM_THREADS; + dim3 grid = GET_BLOCKS(numel); + auto stream = c10::zoom::getCurrentZoomStream(); + at::zoom::detail::IntDivider nsort_divider(nsort); + hipLaunchKernelGGL(( fill_reverse_indices_kernel), dim3(grid), dim3(block), 0, stream, + reverse_indices_ptr, numel, nsort_divider); + + at::zoom::hipcub::segmented_sort_pairs( + self_ptr, + values_ptr, + reverse_indices_ptr, + indices_ptr, + n, + nsegments, + offset_t{(int)nsort, 0}, + offset_t{(int)nsort, 1}, + descending); +} + +} // namespace + +void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices) { + const auto numel = self.numel(); + if (numel == 0) { + return; + } + + int64_t numel_or_intmax = + ::min(numel, static_cast(std::numeric_limits::max())); + int64_t nsort = self.size(dim); + int64_t nbatch = (numel_or_intmax / nsort) * nsort; + TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort); + int64_t* indices_ptr = indices.mutable_data_ptr(); + + AT_DISPATCH_ALL_TYPES_AND3( + kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&] { + const scalar_t* self_ptr = self.const_data_ptr(); + scalar_t* values_ptr = values.mutable_data_ptr(); + int64_t remaining = numel; + while (remaining > 0) { + int64_t n = ::min(remaining, nbatch); + int64_t nsegments = n / nsort; + + if (nsegments == 1 || + nsort >= 1000000) { // rough heuristics where even a single + // sort occupies GPU + segmented_sort_large_segments( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } else if (nsegments < 128) { + segmented_sort_pairs_by_full_sort( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } else { + segmented_sort_pairs( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } + + remaining -= n; + self_ptr += n; + values_ptr += n; + indices_ptr += n; + } + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/SortStable.h b/aten/src/ATen/native/zoom/SortStable.h new file mode 100644 index 00000000000000..039c4307c522c9 --- /dev/null +++ b/aten/src/ATen/native/zoom/SortStable.h @@ -0,0 +1,19 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +// Stable-sort self into values, and set indices to the +// inverse-permutation from values back to self. +// Output tensors must be pre-allocated and contiguous. 
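For orientation before the header declaration below: a hedged sketch of the routing heuristic that launch_stable_sort_kernel above applies to each batch. The thresholds are taken from the code; StableSortPath and pick_path are illustrative names.

#include <cstdint>
#include <cstdio>

enum class StableSortPath { PerSegmentRadix, FullSortTwoPass, SegmentedPairs };

StableSortPath pick_path(int64_t nsegments, int64_t nsort) {
  if (nsegments == 1 || nsort >= 1000000) {
    // A single segment, or slices long enough that one sort fills the GPU.
    return StableSortPath::PerSegmentRadix;
  }
  if (nsegments < 128) {
    // Few segments: sort everything once, then regroup by segment id.
    return StableSortPath::FullSortTwoPass;
  }
  // Many small segments: hand the work to the segmented pairs sort.
  return StableSortPath::SegmentedPairs;
}

int main() {
  std::printf("%d %d %d\n",
              static_cast<int>(pick_path(1, 4096)),       // PerSegmentRadix
              static_cast<int>(pick_path(64, 4096)),      // FullSortTwoPass
              static_cast<int>(pick_path(100000, 256)));  // SegmentedPairs
}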
+void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/zoom/SortUtils.cuh b/aten/src/ATen/native/zoom/SortUtils.cuh new file mode 100644 index 00000000000000..95197f75cedba0 --- /dev/null +++ b/aten/src/ATen/native/zoom/SortUtils.cuh @@ -0,0 +1,333 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#pragma once +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define HAS_WARP_MERGE_SORT() (TORCH_HIP_VERSION >= 110600) + + +namespace at { namespace native { + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA, + K& kB, V& vB, bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSort(K *keys, + V *values, + bool *valid, + const Comparator& comp) { + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + __syncthreads(); + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + flag, comp); + } + } + + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + __syncthreads(); + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + false, comp); + } + + __syncthreads(); + +} + +// at::zoom::detail::TensorInfo version +// Sorts (key, value) pairs (in different tensors) in-place; i.e., +// modifies the input `keys` and `values` +template +C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y) +__global__ void +bitonicSortKVInPlace(at::zoom::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::zoom::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If the entire block is out of bounds exit early + if (blockIndex * blockDim.y >= keySlices) { + return; + } + // It's also possible for some rows of a block to be out of bounds + // but all thread need to run for __syncthreads to work. 
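Before the per-row validity handling continues below, a serial host-side rendering of the bitonic network that bitonicSort above implements; the device version additionally interleaves __syncthreads and carries validity flags so padded slots sort to the end.

#include <algorithm>
#include <cstdio>
#include <vector>

// Serial bitonic sort of a power-of-two sized array, ascending overall.
void bitonic_sort(std::vector<int>& a) {
  const size_t n = a.size();  // must be a power of two
  for (size_t size = 2; size <= n; size *= 2) {          // bitonic run length
    for (size_t stride = size / 2; stride > 0; stride /= 2) {
      for (size_t i = 0; i < n; ++i) {
        const size_t j = i ^ stride;                     // compare-exchange partner
        if (j > i) {
          const bool ascending = ((i & size) == 0);      // direction of this run
          if ((ascending && a[i] > a[j]) || (!ascending && a[i] < a[j])) {
            std::swap(a[i], a[j]);
          }
        }
      }
    }
  }
}

int main() {
  std::vector<int> a = {7, 3, 6, 0, 5, 2, 1, 4};
  bitonic_sort(a);
  for (int v : a) std::printf("%d ", v);  // 0 1 2 3 4 5 6 7
  std::printf("\n");
}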
+ const bool row_valid = linearIndex < keySlices; + + constexpr int items_per_thread = 2; + constexpr int Power2SortSize = block_dim_x * items_per_thread; + + // Storage for max_block_dim_y sorts performed in parallel + __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize]; + __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize]; + __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize]; + + auto sharedKeys = blockSharedKeys[threadIdx.y]; + auto sharedValues = blockSharedValues[threadIdx.y]; + auto sharedValid = blockSharedValid[threadIdx.y]; + + const IndexType keyStartOffset = + at::zoom::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::zoom::detail::IndexToOffset::get(linearIndex, values); + + // Load 2 values per thread into the shared workspace + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + bool valid = row_valid && idx < keySliceSize; + + sharedKeys[idx] = valid ? + keys.data[idx * keySliceStride + keyStartOffset] : K{}; + sharedValues[idx] = valid ? + values.data[idx * valueSliceStride + valueStartOffset] : V{}; + sharedValid[idx] = valid; + } + + // Sort! + bitonicSort( + sharedKeys, sharedValues, sharedValid, comp); + + if (!row_valid) { + return; + } + + // Store outputs + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + if (idx < keySliceSize) { + keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx]; + values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx]; + } + } +} + +#if HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y) +__global__ void +warpMergeSortKVInPlace( + at::zoom::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::zoom::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp, + K invalid_key) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If this row is out of bounds exit early + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::zoom::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::zoom::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_zoom_detail::cub); + + ZOOM_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE); + ZOOM_KERNEL_ASSERT(blockDim.y <= max_block_dim_y); + constexpr int items_per_thread = sort_size / C10_WARP_SIZE; + static_assert( + items_per_thread * C10_WARP_SIZE == sort_size, + "sort_size must be a multiple of C10_WARP_SIZE"); + + + using LoadKeys = cub::WarpLoad; + using LoadValues = cub::WarpLoad; + using Sort = cub::WarpMergeSort; + using StoreKeys = cub::WarpStore; + using StoreValues = cub::WarpStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage[max_block_dim_y]; + + auto& 
warp_storage = tmp_storage[threadIdx.y]; + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + const auto invalid_value = V{}; + LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + WARP_SYNC(); + LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + WARP_SYNC(); + + // Sort! We use stable sort to ensure that invalid values are never + // sorted before valid values. In testing it performed the same as + // .Sort, so there is no down-side. + Sort(warp_storage.sort).StableSort( + local_keys, local_values, comp, keySliceSize, invalid_key); + WARP_SYNC(); + + // Store outputs + StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + WARP_SYNC(); + StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +#endif // HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(block_size) +__global__ void +radixSortKVInPlace(at::zoom::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::zoom::detail::TensorInfo values, + IndexType valueSliceStride, + bool descending) { + static_assert(block_size > 0, ""); + + // Find the slice of the tensor that we are sorting + const IndexType linearIndex = getLinearBlockId(); + // Tiling the slices could have us be out of bounds, if there are a + // lot of slices to sort + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::zoom::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::zoom::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_zoom_detail::cub); + + using key_t = typename at::zoom::hipcub::detail::hip_type::type; + using LoadKeys = hipcub::BlockLoad; + using LoadValues = hipcub::BlockLoad; + using Sort = cub::BlockRadixSort; + using StoreKeys = hipcub::BlockStore; + using StoreValues = hipcub::BlockStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage; + + // cub's Block operations operate on a fixed number of items, but the + // actual slice we are sorting might be smaller. So, we need to make + // up the difference with keys that will always sort higher. + const K invalid_key = [descending] { + using radix_t = typename cub::Traits::UnsignedBits; + union { + K key; + radix_t radix; + } tmp; + tmp.radix = descending ? + cub::Traits::LOWEST_KEY : + cub::Traits::MAX_KEY; + return tmp.key; + }(); + const V invalid_value = static_cast(0); + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + __syncthreads(); + LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + __syncthreads(); + + // Sort! 
+ if (descending) { + Sort(tmp_storage.sort).SortDescending( + reinterpret_cast(local_keys), + local_values); + } else { + Sort(tmp_storage.sort).Sort( + reinterpret_cast(local_keys), + local_values); + } + __syncthreads(); + + // Store outputs + StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + __syncthreads(); + StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +}} // at::native diff --git a/aten/src/ATen/native/zoom/Sorting.cpp b/aten/src/ATen/native/zoom/Sorting.cpp new file mode 100644 index 00000000000000..405184c65a32f3 --- /dev/null +++ b/aten/src/ATen/native/zoom/Sorting.cpp @@ -0,0 +1,208 @@ +// !!! This is a file automatically generated by hipify!!! +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace at::native { +namespace { + +std::tuple kthvalue_out_impl_zoom( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim()); + int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim); + zero_numel_check_dims(self, dim, "kthvalue()"); + + TORCH_CHECK(k >= 1 && k <= slicesize, + "kthvalue(): selected number k out of range for dimension ", dim); + + at::assert_no_overlap(self, values); + + _reduction_with_indices_allocate_or_resize_output( + values, indices, self, dim, keepdim); + if (self.dim() == 0 && self.numel() == 1) { + values.copy_(self); + indices.zero_(); + return std::forward_as_tuple(values, indices); + } + + TORCH_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + // Based on required index size, run the algorithm with the + // appropriate index type + if (self.numel() != 0) { + launch_kthvalue_kernel(values, indices, self, dim, k); + } + + if (!keepdim) { + values.squeeze_(dim); + indices.squeeze_(dim); + } + return std::forward_as_tuple(values, indices); +} + +std::tuple median_with_indices_impl( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim, + bool ignore_nan) { + // See note [Writing Nondeterministic Operations] + // If there are duplicate elements of a median value, the procedure for choosing which + // of the duplicates to use for the indices output is nondeterministic. + at::globalContext().alertNotDeterministic("median Zoom with indices output"); + NoNamesGuard guard; + + dim = at::maybe_wrap_dim(dim, self.dim()); + Tensor in = self.dim() > 0 ? 
self.contiguous() : self.unsqueeze(0); + + checkDeviceType("median", {values, indices}, self.device().type()); + checkScalarType("median", {indices, "indices", 1}, kLong); + checkSameType("median", {values, "values", 0}, {self, "self", 2}); + + TORCH_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "median() cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + std::vector out_shape = self.sizes().vec(); + zero_numel_check_dims(self, dim, "median()"); + if (self.dim() > 0) { + assert(dim >= 0); + assert(dim < static_cast(out_shape.size())); + + if (keepdim) { + out_shape[dim] = 1; + } else { + out_shape.erase(out_shape.begin() + dim); + } + } + + values.resize_(out_shape); + indices.resize_(out_shape); + + // Only launch kernel for non-empty tensors + if (self.numel() > 0) { + // Ensure #dim is the same for all tensors required for reduction + Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim); + Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim); + + launch_median_kernel(vals, inds, in, dim, ignore_nan); + } + + guard.reset(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + + return std::forward_as_tuple(values, indices); +} + +Tensor median_impl(const Tensor& self, bool ignore_nan) { + NoNamesGuard guard; + + int64_t size = self.numel(); + // Return nan for empty tensors + if (size <= 0) { + return at::full({}, std::numeric_limits::quiet_NaN()).to(self.options()); + } + + // Sort input tensor to efficiently query for median element + Tensor sorted = std::get<0>(self.flatten().sort()); + + if (!ignore_nan) { + // For torch.median return either the middle element or nan (sorted as + // largest) if there are any + int64_t k = (size - 1) / 2; + return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]); + } else { + // For torch.nanmedian return the middle element among the non-nan values + int64_t k = ((size - 1) - sorted.isnan().sum().item()) / 2; + return sorted[k].clone(); // Clone so we aren't keeping `sorted` alive + } +} + +} // namespace (anonymous) + +std::tuple kthvalue_out_zoom( + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim, + Tensor& values, + Tensor& indices) { + // See note [Writing Nondeterministic Operations] + // If there are duplicate elements of the kth value, the procedure for choosing which + // of the duplicates to use for the indices output is nondeterministic. + at::globalContext().alertNotDeterministic("kthvalue Zoom"); + auto result = [&]() { + NoNamesGuard guard; + // `kthvalue_out_impl_zoom` expects contiguous in input `self`. 
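Stepping back before kthvalue_out_zoom's body continues below: a hedged host-side sketch of the selection rules implemented by median_impl above. torch.median propagates NaN (the sorted tail is NaN whenever any input is NaN), while torch.nanmedian picks the middle of the non-NaN values; median_of is an illustrative name.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float median_of(std::vector<float> v, bool ignore_nan) {
  // Sort with NaNs ordered last, mirroring how the sorted tensor is used above.
  std::sort(v.begin(), v.end(), [](float a, float b) {
    if (std::isnan(b)) return !std::isnan(a);
    if (std::isnan(a)) return false;
    return a < b;
  });
  const long long size = static_cast<long long>(v.size());
  if (!ignore_nan) {
    // median semantics: any NaN makes the result NaN.
    if (std::isnan(v.back())) return v.back();
    return v[(size - 1) / 2];
  }
  // nanmedian semantics: middle element among the non-NaN prefix.
  long long nan_count = std::count_if(v.begin(), v.end(),
                                      [](float x) { return std::isnan(x); });
  return v[((size - 1) - nan_count) / 2];
}

int main() {
  std::vector<float> x = {3.f, NAN, 1.f, 2.f};
  std::printf("median=%f nanmedian=%f\n", median_of(x, false), median_of(x, true));
  // median is NaN because of the NaN input; nanmedian is 2, the middle of {1, 2, 3}.
}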
+ return kthvalue_out_impl_zoom(values, indices, self.contiguous(), k, dim, keepdim); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; +} + +// Mark: median + +std::tuple median_out_zoom( + const Tensor& self, + int64_t dim, + bool keepdim, + Tensor& values, + Tensor& indices) { + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/false); +} + +Tensor median_zoom(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/false); +} + +std::tuple nanmedian_out_zoom( + const Tensor& self, + int64_t dim, + bool keepdim, + Tensor& values, + Tensor& indices) { + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/true); +} + +Tensor nanmedian_zoom(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/true); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Sorting.cu b/aten/src/ATen/native/zoom/Sorting.cu new file mode 100644 index 00000000000000..e3a0a647fc8181 --- /dev/null +++ b/aten/src/ATen/native/zoom/Sorting.cu @@ -0,0 +1,282 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace at::native { + +namespace { + +// Finds the rank k element, and its index, of the values along dimension dim +template +__global__ void gatherKthValue( + zoom::detail::TensorInfo input, + index_t inputSliceSize, + index_t k, + index_t numInputSlices, + index_t inputWithinSliceStride, + zoom::detail::TensorInfo kthValue, + zoom::detail::TensorInfo indices) { + // Indices are limited to integer fp precision, so counts can fit in + // int32, regardless of index_t + __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit + + index_t slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Find the start offset for our slice + index_t sliceStartIndex = + zoom::detail::IndexToOffset::get(slice, input); + index_t kthValueSliceStartIndex = + zoom::detail::IndexToOffset::get(slice, kthValue); + index_t indicesSliceStartIndex = + zoom::detail::IndexToOffset::get(slice, indices); + + const scalar_t* inputSliceStart = &input.data[sliceStartIndex]; + scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + // Find the k-th highest element in our input + scalar_t kValue = static_cast(0); + radixSelect< + scalar_t, + typename TopKTypeConfig::RadixType, + index_t>( + inputSliceStart, + k, + false, + inputSliceSize, + inputWithinSliceStride, + smem, + &kValue); + + // Find the index of the k-th highest element + index_t kValueIndex = 0; + bool foundKValue = false; + + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + scalar_t v = inRange ? 
doLdg(&inputSliceStart[i * inputWithinSliceStride]) + : static_cast(0); + bool isKValue = inRange && + ((v == kValue) || (at::_isnan(v) && at::_isnan(kValue))); + if (isKValue) { + kValueIndex = i; + foundKValue = true; + break; + } + } + + if (foundKValue) { + kthValueSliceStart[0] = kValue; + indicesSliceStart[0] = kValueIndex; + } +} + +// CUDA kernel to find the median, and its index, of the values along dimension dim +template +__global__ void gatherMedian( + zoom::detail::TensorInfo values, + zoom::detail::TensorInfo indices, + zoom::detail::TensorInfo input, + index_t inputSliceSize, + index_t numInputSlices, + index_t inputWithinSliceStride, + bool ignore_nan) { + // Shared memory for the subroutine RadixSelect. Note that RadixSelect converts the + // floating point type to int with the same relative ordering. + __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit + + index_t slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Finds the start offset for our slice + index_t valuesSliceStartIndex = + zoom::detail::IndexToOffset::get(slice, values); + index_t indicesSliceStartIndex = + zoom::detail::IndexToOffset::get(slice, indices); + index_t inputSliceStartIndex = + zoom::detail::IndexToOffset::get(slice, input); + + scalar_t* valuesSliceStart = &values.data[valuesSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + const scalar_t* inputSliceStart = &input.data[inputSliceStartIndex]; + + index_t nan_count = 0; + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]); + nan_count += at::_isnan(val) ? 1 : 0; + } + + // Counts number of nan values + // This code performs a parallel sum reduction (not the most efficient code) + __shared__ int64_t num_nan; + if (threadIdx.x == 0) { + num_nan = 0; + } + __syncthreads(); + if (nan_count > 0) { + gpuAtomicAddNoReturn(&num_nan, nan_count); + } + __syncthreads(); + + // For torch.median, if we found nan set k to last index so the computed value + // is nan, otherwise set k to the middle element of the non-nan values + index_t k = (!ignore_nan && num_nan > 0) ? 
inputSliceSize - 1 + : (inputSliceSize - num_nan - 1) / 2; + + // Find the median + scalar_t median = static_cast(0); + radixSelect< + scalar_t, + typename TopKTypeConfig::RadixType, + index_t>( + inputSliceStart, + k + 1, + false, + inputSliceSize, + inputWithinSliceStride, + smem, + &median); + + valuesSliceStart[0] = median; + + // Find the index of the median value in the slice + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]); + if (val == median || (at::_isnan(val) && at::_isnan(median))) { + indicesSliceStart[0] = i; + break; + } + } +} + +struct KthValueLauncher { + int64_t k; + + KthValueLauncher(int64_t k) : k(k) {} + + template + inline void launch( + zoom::detail::TensorInfo values_info, + int collapse_values_dim, + zoom::detail::TensorInfo indices_info, + int collapse_indices_dim, + zoom::detail::TensorInfo self_info, + int collapse_self_dim, + int64_t num_slices, + int64_t slice_size) { + (void)collapse_indices_dim; // Suppress unused variable warning + dim3 grid; + if (!getGridFromTiles(num_slices, grid)) { + AT_ERROR("slices are too many"); + } + + dim3 block(::min( + round_up(slice_size, (int64_t)at::zoom::warp_size()), (int64_t)1024)); + auto stream = c10::zoom::getCurrentZoomStream(); + hipLaunchKernelGGL(( gatherKthValue), dim3(grid), dim3(block), 0, stream, + self_info, + slice_size, + k, + num_slices, + /* The actual dimension that the k-selection is running in */ + /* may have changed from collapseDims() */ + self_info.strides[collapse_self_dim], + values_info, + indices_info); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +}; + +struct MedianLauncher { + bool ignore_nan; + + MedianLauncher(bool ignore_nan) : ignore_nan(ignore_nan) {} + + template + inline void launch( + zoom::detail::TensorInfo values_info, + int collapse_values_dim, + zoom::detail::TensorInfo indices_info, + int collapse_indices_dim, + zoom::detail::TensorInfo self_info, + int collapse_self_dim, + int64_t num_slices, + int64_t slice_size) { + (void)collapse_values_dim; // Suppress unused variable warning + (void)collapse_indices_dim; // Suppress unused variable warning + dim3 grid; + if (!getGridFromTiles(num_slices, grid)) { + AT_ERROR("slices are too many"); + } + + dim3 block(::min( + round_up(slice_size, (int64_t)at::zoom::warp_size()), (int64_t)1024)); + auto stream = c10::zoom::getCurrentZoomStream(); + hipLaunchKernelGGL(( gatherMedian), dim3(grid), dim3(block), 0, stream, + values_info, + indices_info, + self_info, + slice_size, + num_slices, + self_info.strides[collapse_self_dim], + ignore_nan); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +}; + +} // namespace (anonymous) + +void launch_kthvalue_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t k) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "kthvalue_zoom", [&] { + AT_DISPATCH_INDEX_TYPES( + zoom::detail::canUse32BitIndexMath(self) && + zoom::detail::canUse32BitIndexMath(values) && + zoom::detail::canUse32BitIndexMath(indices) ? 
ScalarType::Int : ScalarType::Long, + "kth_value_launcher", [&] { + run_launcher( + values, indices, self, dim, KthValueLauncher(k)); + }); + }); +} + +void launch_median_kernel( + const TensorBase &vals, const TensorBase &inds, + const TensorBase &self, int64_t dim, bool ignore_nan) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "median_out_impl", [&] { + if (zoom::detail::canUse32BitIndexMath(vals) && + zoom::detail::canUse32BitIndexMath(inds) && + zoom::detail::canUse32BitIndexMath(self)) { + run_launcher( + vals, inds, self, dim, MedianLauncher(ignore_nan)); + } else { + run_launcher( + vals, inds, self, dim, MedianLauncher(ignore_nan)); + } + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Sorting.h b/aten/src/ATen/native/zoom/Sorting.h new file mode 100644 index 00000000000000..bd10ffb1a02741 --- /dev/null +++ b/aten/src/ATen/native/zoom/Sorting.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at { +namespace native { + +void launch_kthvalue_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t k); +void launch_median_kernel( + const TensorBase &vals, const TensorBase &inds, + const TensorBase &in, int64_t dim, bool ignore_nan); + +}} // namespace at::native diff --git a/aten/src/ATen/native/zoom/SortingCommon.cuh b/aten/src/ATen/native/zoom/SortingCommon.cuh new file mode 100644 index 00000000000000..902145fd4fbfba --- /dev/null +++ b/aten/src/ATen/native/zoom/SortingCommon.cuh @@ -0,0 +1,188 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// Is this questionable namespace pollution? +constexpr int MAX_BLOCK_SIZE = 256; + +// Maximum size per grid dimension that we assume (compute capability >= 2.0) +constexpr int64_t MAX_GRID_SIZE = 65535LL; + +static bool getGridFromTiles(int64_t gridTiles, dim3& grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + + int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? 
MAX_GRID_SIZE : gridTiles; + } + } + + grid = dim3(gridX, gridY, gridZ); + return true; +} + +template +struct GTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || (lhs > rhs); + } +}; + +template +struct LTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || (lhs < rhs); + } +}; + +template +__device__ __forceinline__ index_t getLinearBlockId() { + return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + + blockIdx.x; +} + +// For slice sorting in Thrust; extracts a slice index from a linear +// index and uses that for comparison +struct SliceComp { + SliceComp(int64_t size) : sliceSize(size) {} + + __device__ bool operator()(const int64_t& a, const int64_t& b) const { + // Since the slices are guaranteed to be innermost, + // the segment is just via int64_t division + int64_t segA = a / sliceSize; + int64_t segB = b / sliceSize; + return segA < segB; + } + + const int64_t sliceSize; +}; + +// For sorting in Thurst; extracts a within-slice index from a linear index +struct GlobalIndexToPerSliceIndex { + GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {} + + __device__ inline void operator()(int64_t& v) const { + v = v % sliceSize; + } + + const int64_t sliceSize; +}; + +// Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks +static uint64_t nextHighestPowerOf2(uint64_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; +#ifndef _MSC_VER + n |= n >> 32; +#endif + n++; + + return n; +} + + +// WARNING: This function assumes input tensors are contiguous +template +void run_launcher( + const TensorBase &values, + const TensorBase &indices, + const TensorBase &self, + int64_t dim, + Launcher l) { + auto self_info = zoom::detail::getTensorInfo(self); + auto values_info = zoom::detail::getTensorInfo(values); + auto indices_info = zoom::detail::getTensorInfo(indices); + + int64_t slice_size = self.size(dim); + /* We use these structures solely to find the offset to */ + /* each slice we are operating on */ + self_info.reduceDim(dim); + values_info.reduceDim(dim); + indices_info.reduceDim(dim); + + /* Collapse all other dims */ + int collapse_self_dim = self_info.collapseDims(dim); + int collapse_values_dim = values_info.collapseDims(dim); + int collapse_indices_dim = indices_info.collapseDims(dim); + + int64_t num_slices = 1; + for (int i = 0; i < self_info.dims; ++i) { + num_slices *= self_info.sizes[i]; + } + + /* This is used as a template parameter to calculate indices. 
*/ + /* We only specialize it if all collapsed dim sizes are the */ + /* same; otherwise, we use -1 which is the specialization */ + /* parameter for arbitrary dimensions */ + int all_dims = self_info.dims; + if (values_info.dims != all_dims || indices_info.dims != all_dims) { + all_dims = -1; + } + + if (all_dims == 1) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 2) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 3) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } +} + +} // namespace native +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/SortingRadixSelect.cuh b/aten/src/ATen/native/zoom/SortingRadixSelect.cuh new file mode 100644 index 00000000000000..83f893b76f9d75 --- /dev/null +++ b/aten/src/ATen/native/zoom/SortingRadixSelect.cuh @@ -0,0 +1,410 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +template +struct TopKTypeConfig {}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + // Converts a float to an integer representation with the same + // sorting; i.e., for floats f1, f2: + // if f1 < f2 then convert(f1) < convert(f2) + // We use this to enable radix selection of floating-point values. + // This also gives a relative order for NaNs, but that's ok, as they + // will all be adjacent + // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff.. + // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00.. 
+ // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0 + // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(uint8_t v) { + return v; + } + + static inline __device__ uint8_t deconvert(RadixType v) { + return v; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int8_t v) { + return 128u + v; + } + + static inline __device__ int8_t deconvert(RadixType v) { + return v - 128; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int16_t v) { + static_assert(sizeof(short) == 2, ""); + return 32768u + v; + } + + static inline __device__ int16_t deconvert(RadixType v) { + return v - 32768; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int32_t v) { + static_assert(sizeof(int) == 4, ""); + return 2147483648u + v; + } + + static inline __device__ int32_t deconvert(RadixType v) { + return v - 2147483648u; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(int64_t v) { + static_assert(sizeof(int64_t) == 8, ""); + return 9223372036854775808ull + v; + } + + static inline __device__ int64_t deconvert(RadixType v) { + return v - 9223372036854775808ull; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(double v) { + RadixType x = __double_as_longlong(v); + RadixType mask = -((x >> 63)) | 0x8000000000000000; + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; + } + + static inline __device__ double deconvert(RadixType v) { + RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; + return __longlong_as_double(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::Half v) { + RadixType x = __half_as_ushort(v); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; + } + + static inline __device__ at::Half deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + return __ushort_as_half(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::BFloat16 v) { + RadixType x = v.x; + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; + } + + static inline __device__ at::BFloat16 deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + at::BFloat16 r; + r.x = (v ^ mask); + return r; + } +}; + +// This function counts the distribution of all input values in a +// slice we are selecting by radix digit at `radixDigitPos`, but only +// those that pass the filter `((v & desiredMask) == desired)`. +// This produces and broadcasts the seen counts for a single block only. +// `smem` must have at least `RadixSize` elements. 
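To make the ordering trick described above concrete (the comment at the top of TopKTypeConfig<float>, and the double/Half specializations alongside it), here is a minimal host-side sketch of the sign-flip mapping. This is an editorial illustration, not part of the patch: the helper name float_to_radix is invented here, std::memcpy stands in for the device bit-cast intrinsics (__float_as_int and friends), and the NaN handling mirrors the (v == v) checks above.

// Host-side sketch (illustrative only) of the order-preserving
// float -> uint32_t mapping used by the radix selection below.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

static uint32_t float_to_radix(float v) {
  uint32_t x;
  std::memcpy(&x, &v, sizeof(x));                  // reinterpret the float bits
  uint32_t mask = (x & 0x80000000u) ? 0xffffffffu  // negative: flip every bit
                                    : 0x80000000u; // non-negative: flip the sign bit
  return std::isnan(v) ? 0xffffffffu : (x ^ mask); // all NaNs collapse to the top
}

int main() {
  // The mapping is monotone: a < b as floats implies
  // float_to_radix(a) < float_to_radix(b) as unsigned integers,
  // so selecting by radix digit on the converted values selects by value.
  assert(float_to_radix(-2.0f) < float_to_radix(-1.0f));
  assert(float_to_radix(-1.0f) < float_to_radix(0.0f));
  assert(float_to_radix(0.0f)  < float_to_radix(1.5f));
  assert(float_to_radix(1.5f)  < float_to_radix(INFINITY));
  return 0;
}

The countRadixUsingMask device function that follows histograms these converted values one radix digit at a time, which is what makes the digit-by-digit narrowing in radixSelect possible.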
+template < + typename scalar_t, + typename bitwise_t, + typename index_t, + typename CountType, + int RadixSize, + int RadixBits> +__device__ void countRadixUsingMask( + CountType counts[RadixSize], + CountType* smem, + bitwise_t desired, + bitwise_t desiredMask, + int radixDigitPos, + index_t sliceSize, + index_t withinSliceStride, + const scalar_t* data) { + // Clear out per-thread counts from a previous round +#pragma unroll + for (int i = 0; i < RadixSize; ++i) { + counts[i] = 0; + } + + if (threadIdx.x < RadixSize) { + smem[threadIdx.x] = 0; + } + __syncthreads(); + + // Scan over all the data. Upon a read, the warp will accumulate + // counts per each digit in the radix using warp voting. + for (index_t i = threadIdx.x; i < sliceSize;) { + bitwise_t val = + TopKTypeConfig::convert(doLdg(&data[i * withinSliceStride])); + + bool hasVal = ((val & desiredMask) == desired); + bitwise_t digitInRadix = at::zoom::Bitfield::getBitfield( + val, radixDigitPos, RadixBits); + +#pragma unroll + for (uint32_t j = 0; j < RadixSize; ++j) { + bool vote = hasVal && (digitInRadix == j); + counts[j] += __popcll(WARP_BALLOT(vote)); + } + i += blockDim.x; + } + + // Now, for each warp, sum values + if (at::zoom::getLaneId() == 0) { +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + gpuAtomicAddNoReturn(&smem[i], counts[i]); + } + } + + __syncthreads(); + + // For each thread, read in the total counts +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + counts[i] = smem[i]; + } + + __syncthreads(); +} + +// Over what radix we are selecting values +constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) +constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_SIZE - 1); + +// This finds the unique value `v` that matches the pattern +// ((v & desired) == desiredMask) in our sorted int format +template +__device__ scalar_t findPattern( + scalar_t* smem, + const scalar_t* data, + index_t sliceSize, + index_t withinSliceStride, + bitwise_t desired, + bitwise_t desiredMask) { + if (threadIdx.x < 2) { + smem[threadIdx.x] = static_cast(0); + } + __syncthreads(); + + // All threads participate in the loop, in order to sync on the flag + index_t numIterations = + round_up(sliceSize, static_cast(blockDim.x)); + for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? 
doLdg(&data[i * withinSliceStride]) + : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + // There should not be conflicts if we are using findPattern, + // since the result is unique + smem[0] = static_cast(1); + smem[1] = v; // can't use val as the flag, since it could be 0 + } + + __syncthreads(); + + scalar_t found = smem[0]; + scalar_t val = smem[1]; + + __syncthreads(); + + // Check to see if a thread found the value + if (found != static_cast(0)) { + // all threads return this value + return val; + } + } + + // should not get here + ZOOM_KERNEL_ASSERT(false); + return static_cast(0); +} + +// Returns the top-Kth element found in the data using radix selection +template +__device__ void radixSelect( + const scalar_t* data, + index_t k, + bool largest, + index_t sliceSize, + index_t withinSliceStride, + int* smem, + scalar_t* topK) { + // Per-thread buckets into which we accumulate digit counts in our + // radix + int counts[RADIX_SIZE]; + + // We only consider elements x such that (x & desiredMask) == desired + // Initially, we consider all elements of the array, so the above + // statement is true regardless of input. + bitwise_t desired = 0; + bitwise_t desiredMask = 0; + + // We are looking for the top kToFind-th element when iterating over + // digits; this count gets reduced by elimination when counting + // successive digits + int kToFind = k; + + // We start at the most significant digit in our radix, scanning + // through to the least significant digit + for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0; + digitPos -= RADIX_BITS) { + // Count radix distribution for the current position and reduce + // across all threads + countRadixUsingMask< + scalar_t, + bitwise_t, + index_t, + int, + RADIX_SIZE, + RADIX_BITS>( + counts, + smem, + desired, + desiredMask, + digitPos, + sliceSize, + withinSliceStride, + data); + + auto found_unique = [&](int i, int count) -> bool { + /* All threads have the same value in counts here, so all */ + /* threads will return from the function. */ + if (count == 1 && kToFind == 1) { + /* There is a unique answer. */ + desired = at::zoom::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::zoom::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The answer is now the unique element v such that: */ + /* (v & desiredMask) == desired */ + /* However, we do not yet know what the actual element is. We */ + /* need to perform a search through the data to find the */ + /* element that matches this pattern. 
*/ + *topK = findPattern( + (scalar_t*)smem, + data, + sliceSize, + withinSliceStride, + desired, + desiredMask); + return true; + } + return false; + }; + auto found_non_unique = [&](int i, int count) -> bool { + if (count >= kToFind) { + desired = + at::zoom::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::zoom::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The top-Kth element v must now be one such that: */ + /* (v & desiredMask == desired) */ + /* but we haven't narrowed it down; we must check the next */ + /* least-significant digit */ + return true; + } + kToFind -= count; + return false; // continue the loop + }; + + // All threads participate in the comparisons below to know the + // final result + if (largest) { + // Process in descending order +#pragma unroll + for (int i = RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } else { + // Process in ascending order +#pragma unroll + for (int i = 0; i < RADIX_SIZE; ++i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } + } // end digitPos for + + // There is no unique result, but there is a non-unique result + // matching `desired` exactly + *topK = TopKTypeConfig::deconvert(desired); +} +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/zoom/TensorTopK.cpp b/aten/src/ATen/native/zoom/TensorTopK.cpp new file mode 100644 index 00000000000000..bf0539c0b2db97 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorTopK.cpp @@ -0,0 +1,96 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#else +#include +#include +#include +#endif + +namespace at::native { + +// TODO: remove this when CUDA <11.6 is no longer supported +void topk_out_with_sort( + const Tensor& self, + int64_t k, int64_t dim, bool largest, + const Tensor& values, + const Tensor& indices +) { + auto [sorted_values, sorted_indices] = at::privateuse1::sort(self, /* stable= */false, dim, largest); + values.copy_(sorted_values.narrow(dim, 0, k)); + indices.copy_(sorted_indices.narrow(dim, 0, k)); +} + +// TODO: remove this when CUDA <11.6 is no longer supported +bool disable_sort_for_topk(); +bool should_use_sort(const Tensor& self, int64_t dim) { + if (disable_sort_for_topk()) return false; + // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632 + if (self.dim() == 0) return false; + if (self.dtype() == kBool) return false; // Bool is not support by topk + int64_t slice_size = self.size(dim); + if (slice_size == 0) return false; + int64_t num_slices = self.numel() / slice_size; + return num_slices <= 10 && slice_size >= 100000; +} + +TORCH_IMPL_FUNC(topk_out_zoom) + (const Tensor& self, + int64_t k, int64_t dim, bool largest, bool sorted, + const Tensor& values, + const Tensor& indices) { + TensorArg topK_arg{values, "topK", 1}, indices_arg{indices, "indices", 2}, input_arg{self, "self", 3}; + checkAllSameGPU(__func__, {topK_arg, indices_arg, input_arg}); + + dim = at::maybe_wrap_dim(dim, self); + + if (should_use_sort(self, dim)) { + topk_out_with_sort(self, k, dim, largest, values, indices); + return; + } + + // If k is 0 the result is an empty tensor, so we don't need to launch a kernel. 
+ if (k == 0) { + return; + } + + launch_gather_topk_kernel(self, k, dim, largest, values, indices); + + // Sort the results if the user wants them sorted, since our + // selection routine does not ensure sorting + if (sorted && values.numel() > 1) { + if (should_use_small_sort(values, dim)) { + // This avoids any memory allocations and performs all sorting + // work inplace along the slice + + sortKeyValueInplace(values, indices, dim, largest); + } else { + // Depend upon the backup sort that returns indices, which we + // can use in conjunction with gather to produce the original + // indices. + // This is not the most efficient implementation, especially since + // there are memory allocations performed here. If the user desires + // greater performance, they should torch.gather() the results + // themselves using the reported indices, providing previously + // allocated tensors to receive the results. + + Tensor sortedIndices = at::empty_like(indices); + Tensor sortedValues = at::empty_like(values); + at::privateuse1::sort_outf(values, /* stable= */ false, dim, largest, sortedValues, sortedIndices); + indices.copy_(indices.gather(dim, sortedIndices)); + values.copy_(sortedValues); + } + } +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/TensorTopK.cu b/aten/src/ATen/native/zoom/TensorTopK.cu new file mode 100644 index 00000000000000..c4a431b2a1dc98 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorTopK.cu @@ -0,0 +1,895 @@ +#include +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace at::native; + +namespace at::native { + +// TODO: remove this when CUDA <11.6 is no longer supported +bool disable_sort_for_topk() { + return CUB_SUPPORTS_SCAN_BY_KEY(); +} + +namespace sbtopk { // single_block_topk + +template +struct AddOp { + __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) { + return (lhs + rhs); + } +}; + +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void gatherTopK(at::zoom::detail::TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, // aka `k` + bool largest, + + IndexType numInputSlices, + IndexType inputWithinSliceStride, + + at::zoom::detail::TensorInfo topK, + IndexType topKWithinSliceStride, + + at::zoom::detail::TensorInfo indices, + IndexType indicesWithinSliceStride, + T* kthValues) { + // Indices are limited to integer fp precision, so counts can fit in + // int32, regardless of IndexType + __shared__ int smem[64]; + IndexType slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Find the start offset for our slice + IndexType sliceStartIndex = + at::zoom::detail::IndexToOffset::get(slice, input); + IndexType topKSliceStartIndex = + at::zoom::detail::IndexToOffset::get(slice, topK); + IndexType indicesSliceStartIndex = + at::zoom::detail::IndexToOffset::get(slice, indices); + + const T* inputSliceStart = &input.data[sliceStartIndex]; + T* topKSliceStart = &topK.data[topKSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + // Find the k-th highest element in our input + T topKValue; + if (WithKthValues){ + topKValue = kthValues[slice]; + } else { + topKValue = static_cast(0); + radixSelect::RadixType, IndexType>( + inputSliceStart, outputSliceSize, largest, + inputSliceSize, inputWithinSliceStride, + smem, &topKValue); + } + const auto topKConverted = 
at::native::TopKTypeConfig::convert(topKValue); + + // Every value that is strictly less/greater than `pattern` + // (depending on sort dir) in sorted int format is in the top-K. + // The top-K value itself might not be unique. + // + // Since there are a variable number of elements that we see that + // are within the top-k, we don't know at what index to write out + // the resulting values. + // In order to get this, we perform an exclusive prefix sum of + // `hasTopK`. This will return the resulting index into which we + // need to write the result, if a thread has a result. + + // All threads need to participate in the loop and the prefix sum, + // but not necessarily in the load; hence loop bounds being rounded + // up to a multiple of the block dim. + IndexType numIterations = round_up(inputSliceSize, (IndexType) blockDim.x); + IndexType writeIndexStart = 0; + + for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + T v = + inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : static_cast(0); + const auto convertedV = at::native::TopKTypeConfig::convert(v); + bool hasTopK; + if (largest) { + hasTopK = inRange && (convertedV > topKConverted); + } else { + hasTopK = inRange && (convertedV < topKConverted); + } + + int index; + int carry; + at::zoom::exclusiveBinaryPrefixScan( + smem, hasTopK, &index, &carry, AddOp()); + + if (hasTopK) { + int writeIndex = writeIndexStart + index; + ZOOM_KERNEL_ASSERT(writeIndex < outputSliceSize); + + IndexType topKOffset = writeIndex * topKWithinSliceStride; + IndexType indexOffset = writeIndex * indicesWithinSliceStride; + + topKSliceStart[topKOffset] = v; + indicesSliceStart[indexOffset] = i; + } + + writeIndexStart += carry; + } + + // We need to fill in the rest with actual == top-K values. + // The number that we need is outputSliceSize - + // writeIndexStart. There might be more than that number available, + // in which case we have to choose the first seen set. We do this + // via a prefix sum to calculate indices for writing results. + ZOOM_KERNEL_ASSERT(outputSliceSize >= writeIndexStart); + IndexType topKRemaining = (outputSliceSize - writeIndexStart); + + for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + T v = + inRange ? 
doLdg(&inputSliceStart[i * inputWithinSliceStride]) : static_cast(0); + const auto convertedV = at::native::TopKTypeConfig::convert(v); + bool hasTopK = inRange && (convertedV == topKConverted); + + int index; + int carry; + at::zoom::exclusiveBinaryPrefixScan( + smem, hasTopK, &index, &carry, AddOp()); + + if (hasTopK && index < topKRemaining) { + int writeIndex = writeIndexStart + index; + ZOOM_KERNEL_ASSERT(writeIndex < outputSliceSize); + + IndexType topKOffset = writeIndex * topKWithinSliceStride; + IndexType indexOffset = writeIndex * indicesWithinSliceStride; + + topKSliceStart[topKOffset] = v; + indicesSliceStart[indexOffset] = i; + } + + if (carry >= topKRemaining) { + break; + } + + topKRemaining -= carry; + writeIndexStart += carry; + } + +}; + +template +void launch( + at::zoom::detail::TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, // aka `k` + bool largest, + + IndexType numInputSlices, + IndexType inputWithinSliceStride, + + at::zoom::detail::TensorInfo topK, + IndexType topKWithinSliceStride, + + at::zoom::detail::TensorInfo indices, + IndexType indicesWithinSliceStride) { + + dim3 grid; + TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk"); + int warp_size = at::zoom::warp_size(); + dim3 block(::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024)); + hipLaunchKernelGGL(( gatherTopK), dim3(grid), dim3(block), 0, c10::zoom::getCurrentZoomStream(), + input, + inputSliceSize, + outputSliceSize, + largest, + numInputSlices, + inputWithinSliceStride, + topK, + topKWithinSliceStride, + indices, + indicesWithinSliceStride, + nullptr); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} +} // namespace sbtopk + +namespace mbtopk { // multi_block_topk + +// Assumptions: +// The number of elements can be larger than UINT32_MAX, but +// the number of total blocks can not be larger than UINT32_MAX. +// So we can not have more than UINT32_MAX slices. The actual limit +// for number of slices could be a few fold smaller than UINT32_MAX, +// because we could be using multiple blocks per slice. 
+// Further more, the size of each input slice is also assumped to be +// smaller than UINT32_MAX + +constexpr int BLOCK_THREADS = 256; + +// Over what radix we are selecting values +constexpr int RADIX_BITS = 8; +constexpr int RADIX_DIGITS = 1 << RADIX_BITS; // 2 ^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_DIGITS - 1); +static_assert(RADIX_DIGITS <= BLOCK_THREADS, "radixFindKthValues kernel requires RADIX_DIGITS <= BLOCK_THREADS"); +constexpr int MIN_ITEMS_PER_THREAD = 4; +constexpr int MAX_ITEMS_PER_THREAD = 64; + +template +__global__ void fill(T* x, T value, IndexType size) { + IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + for (IndexType i = idx; i < size; i += gridDim.x * blockDim.x) { + x[i] = value; + } +} + +// find the kth smallest value, +// for largest topk, k_to_find = slice_size - k + 1 +template +C10_LAUNCH_BOUNDS_1(BLOCK_THREADS) +__global__ void radixFindKthValues( + at::zoom::detail::TensorInfo input, + uint32_t slice_size, + uint32_t* ks_to_find, // size: num_slices + + uint32_t num_slices, + IndexType withinSliceStride, + + int current_bit, + int items_per_thread, + uint32_t blocks_per_slice, + Bitwise desiredMask, + + // outputs + uint32_t* semaphores, // size: num_slices + Bitwise* desires, // size: num_slices + short* counts, // size: num_slices * blocks_per_slice * radix_digits + T* kthValues // size: num_slices, only write when current_bit reaches 0 + ) { + + int items_per_block = items_per_thread * BLOCK_THREADS; + int tidx = threadIdx.x; + uint32_t block_idx = getLinearBlockId(); + uint32_t slice_idx = block_idx / blocks_per_slice; + uint32_t blk_idx_in_slice = block_idx % blocks_per_slice; + if (slice_idx >= num_slices) { + return; + } + + Bitwise desired = desires[slice_idx]; + uint32_t k_to_find = ks_to_find[slice_idx]; + IndexType slice_start_index = at::zoom::detail::IndexToOffset::get(slice_idx, input); + const T* data = &input.data[slice_start_index]; + + typedef hipcub::BlockScan BlockScan; + static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < std::numeric_limits::max(), + "blockwise counter too large"); + union __align__(16) TempStorage { + uint32_t digit_counters[RADIX_DIGITS]; + uint32_t digit_count_cumsum[RADIX_DIGITS]; // only used if this it the last block for this slice + typename BlockScan::TempStorage scan_storage; + }; + __shared__ TempStorage temp_storage; + + // fill digit_counters with zeros + if (tidx < RADIX_DIGITS) { + temp_storage.digit_counters[tidx] = 0; + } + __syncthreads(); + + items_per_thread = (blk_idx_in_slice + 1 < blocks_per_slice) + ? 
items_per_thread + : at::ceil_div((int64_t)(slice_size - blk_idx_in_slice * items_per_block), (int64_t)BLOCK_THREADS); + + // collect digit counts and store in shared memory + for (int i = 0; i < items_per_thread; ++i) { + // Find the start offset for this slice + IndexType idx = blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx; + if (idx < slice_size) { + idx *= withinSliceStride; + Bitwise val = TopKTypeConfig::convert(doLdg(&data[idx])); + bool has_val = ((val & desiredMask) == (desired & desiredMask)); + Bitwise digit = at::zoom::Bitfield::getBitfield(val, current_bit, RADIX_BITS); + if (has_val) { + atomicAdd(&temp_storage.digit_counters[digit], 1); + } + } + } + + __syncthreads(); + + // load digit counter to register, one digit per thread + static_assert(RADIX_DIGITS <= BLOCK_THREADS, "this kernel requires RADIX_DIGITS <= BLOCK_THREADS"); + uint32_t digit_count = 0; + if (tidx < RADIX_DIGITS) { + digit_count = temp_storage.digit_counters[tidx]; + } + + // We always write out counts regardless if blocks_per_slice == 1 because + // it will be used to compute offsets for `gatherTopK`. + if (tidx < RADIX_DIGITS) { + counts[block_idx * RADIX_DIGITS + tidx] = digit_count; + } + // if blocks_per_slice == 1, there is no need to do cross-block reduction + // in this case we use counts saved at registers directly + if (blocks_per_slice > 1) { + __threadfence(); // make sure writes are globally visible + __syncthreads(); // make sure all writes are finished before update semaphores + } + + // the last block of each slice accumulates counters from multiple blocks and updates desired and ks_to_find + __shared__ bool s_is_last_block_done; + + if (tidx == 0) { + if (blocks_per_slice == 1) { + s_is_last_block_done = true; + } else { + uint32_t blocks_finished_old = atomicAdd(&semaphores[slice_idx], 1); + s_is_last_block_done = (blocks_finished_old == blocks_per_slice - 1); + } + } + + __syncthreads(); + + if (!s_is_last_block_done) + return; + + // accumulates counters from multiple blocks + if (tidx < RADIX_DIGITS && blocks_per_slice > 1) { + digit_count = 0; + for (int blk = 0; blk < blocks_per_slice; ++blk) { + digit_count += counts[(slice_idx * blocks_per_slice + blk) * RADIX_DIGITS + tidx]; + } + } + + // compute the block-wide inclusive prefix sum + uint32_t digit_count_cumsum; + BlockScan(temp_storage.scan_storage).InclusiveSum(digit_count, digit_count_cumsum); + __syncthreads(); + // every thread also need the perfix_sum of it's left value for comparison, so save a copy in shared mem + if (tidx < RADIX_DIGITS) { + temp_storage.digit_count_cumsum[tidx] = digit_count_cumsum; + } + __syncthreads(); + + if (tidx < RADIX_DIGITS) { + uint32_t digit_count_cumsum_left = (tidx == 0) ? 
0 : temp_storage.digit_count_cumsum[tidx - 1]; + + // if not the last pass: update desired and ks_to_find + // if last pass: write out the kth value + if (digit_count_cumsum_left < k_to_find && k_to_find <= digit_count_cumsum) { + desired = at::zoom::Bitfield::setBitfield(desired, tidx, current_bit, RADIX_BITS); + desires[slice_idx] = desired; + if (current_bit > 0) { + ks_to_find[slice_idx] = k_to_find - digit_count_cumsum_left; + } else { + kthValues[slice_idx] = TopKTypeConfig::deconvert(desired); + } + } + } + + // reset semaphores for the next pass + if (tidx == 0) { + semaphores[slice_idx] = 0; + } +} + +#if CUB_SUPPORTS_SCAN_BY_KEY() +// Assumption: k can not be larger than UINT32_MAX +template +C10_LAUNCH_BOUNDS_1(RADIX_DIGITS) // one thread per digit +__global__ void computeBlockwiseWithinKCounts( + Bitwise* desires, // size: num_slices + short* counts, // size: num_slices * blocks_per_slice * radix_digits + uint32_t blocks_per_slice, + int current_bit, + bool largest, + // outputs: + uint32_t* withinKCounts, // size: num_slices * blocks_per_slice == num_blocks + uint32_t num_blocks +) { + // This kernel should be launched with the same number of blocks as the `radixFindKthValues` kernel. + int tidx = threadIdx.x; + uint32_t block_idx = getLinearBlockId(); + uint32_t slice_idx = block_idx / blocks_per_slice; + + // The grid is computed from `getGridFromTiles`, when there are lots of + // elements, we will use both blockIdx.x and blockIdx.y, and maybe blockIdx.z + // when this is the case, the number of blocks that we are launching can be + // more than the number of blocks we need. So we need to check the range of + // `block_idx`. + if (block_idx >= num_blocks) { + return; + } + + Bitwise desired = doLdg(desires + slice_idx); + Bitwise desired_digit = at::zoom::Bitfield::getBitfield(desired, current_bit, RADIX_BITS); + + // if largest, then only threads that has tidx > desired_digit are active + // if !largest, then only threads that has tidx < desired_digit are active + // each active thread will read the count for its corresponding, and + // do warp reduction followed by shared memory reduction to get the total count + // non-active thread should not load, and non-active warp should not do reduction. 
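As a concrete picture of the shuffle-based reduction the preceding comment refers to (the loop count += WARP_SHFL_DOWN(count, offset) in the code below), here is a single-threaded host simulation of the data flow. It is an editorial sketch, not part of the patch: out-of-range source lanes are modeled as contributing zero, which leaves the total accumulated by lane 0 unchanged.

// Host simulation of a warp-wide sum via repeated shuffle-down steps.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int WARP_SIZE = 64;                    // e.g. the wavefront size on AMD GPUs
  std::array<uint32_t, WARP_SIZE> lane_val{};
  for (int lane = 0; lane < WARP_SIZE; ++lane)
    lane_val[lane] = static_cast<uint32_t>(lane % 5);   // arbitrary per-lane counts

  uint32_t expected = 0;
  for (uint32_t v : lane_val) expected += v;

  // Each pass halves the exchange distance: lane i adds the value held by lane i+offset.
  // After the offset == 1 pass, lane 0 holds the sum over the whole warp.
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    for (int lane = 0; lane < WARP_SIZE; ++lane) {
      uint32_t from_right = (lane + offset < WARP_SIZE) ? lane_val[lane + offset] : 0u;
      lane_val[lane] += from_right;                // models count += WARP_SHFL_DOWN(count, offset)
    }
  }
  std::printf("lane 0 total = %u (expected %u)\n", lane_val[0], expected);
  return 0;
}

The kernel body below applies this pattern once per warp and then combines the per-warp partial sums through the shared warp_counts array before the final accumulation into withinKCounts.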
+ bool warp_is_active, thread_is_active; + int warp = tidx / C10_WARP_SIZE; + if (largest) { + int end_of_warp = warp * C10_WARP_SIZE + C10_WARP_SIZE - 1; + warp_is_active = end_of_warp > desired_digit; + thread_is_active = tidx > desired_digit; + } else { + int start_of_warp = warp * C10_WARP_SIZE; + warp_is_active = start_of_warp < desired_digit; + thread_is_active = tidx < desired_digit; + } + uint32_t count = 0; + if (warp_is_active) { + if (thread_is_active) { + count = doLdg(counts + block_idx * RADIX_DIGITS + tidx); + } + for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) { + count += WARP_SHFL_DOWN(count, offset); + } + } + + constexpr int num_warps = RADIX_DIGITS / C10_WARP_SIZE; + __shared__ uint32_t warp_counts[num_warps]; + if (tidx % C10_WARP_SIZE == 0) { + warp_counts[warp] = count; + } + __syncthreads(); + static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE, + "Assuming only 1 warp is needed for final reduction"); + if (warp != 0) { + return; + } + count = 0; + if (tidx < num_warps) { + count = warp_counts[tidx]; + } + for (int offset = num_warps / 2; offset > 0; offset /= 2) { + count += WARP_SHFL_DOWN(count, offset); + } + if (tidx == 0) { + withinKCounts[block_idx] += count; + } +} + +// Assumption: slice_size can not be larger than UINT32_MAX +template +__global__ void computeBlockwiseKthCounts( + Bitwise* desires, // size: num_slices + short* counts, // size: num_slices * blocks_per_slice * radix_digits + uint32_t num_blocks, // the number of blocks used by `radixFindKthValues` kernel + uint32_t blocks_per_slice, + // outputs: + uint32_t* kthCounts // size: num_slices * blocks_per_slice == num_blocks +) { + HIP_KERNEL_LOOP_TYPE(idx, num_blocks, uint32_t) { + uint32_t slice_idx = idx / blocks_per_slice; + Bitwise desired = doLdg(desires + slice_idx); + Bitwise desired_digit = at::zoom::Bitfield::getBitfield(desired, 0, RADIX_BITS); + kthCounts[idx] = doLdg(counts + idx * RADIX_DIGITS + desired_digit); + } +} + +template +C10_LAUNCH_BOUNDS_1(BLOCK_THREADS) +__global__ void gatherTopK(at::zoom::detail::TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, // aka `k` + bool largest, + + uint32_t numInputSlices, + IndexType inputWithinSliceStride, + + at::zoom::detail::TensorInfo topK, + IndexType topKWithinSliceStride, + + at::zoom::detail::TensorInfo indices, + IndexType indicesWithinSliceStride, + + uint32_t items_per_thread, + uint32_t blocks_per_slice, + + T *kthValues, + uint32_t* withinKCounts, + uint32_t* kthCounts, + uint32_t num_blocks) { + + uint32_t items_per_block = items_per_thread * BLOCK_THREADS; + uint32_t tidx = threadIdx.x; + uint32_t block_idx = getLinearBlockId(); + + // The grid is computed from `getGridFromTiles`, when there are lots of + // elements, we will use both blockIdx.x and blockIdx.y, and maybe blockIdx.z + // when this is the case, the number of blocks that we are launching can be + // more than the number of blocks we need. So we need to check the range of + // `block_idx`. + if (block_idx >= num_blocks) { + return; + } + + uint32_t slice_idx = block_idx / blocks_per_slice; + uint32_t blk_idx_in_slice = block_idx % blocks_per_slice; + + items_per_thread = (blk_idx_in_slice + 1 < blocks_per_slice) + ? 
items_per_thread + : at::ceil_div((int64_t)(inputSliceSize - blk_idx_in_slice * items_per_block), (int64_t)BLOCK_THREADS); + + // Find the start offset for our slice + IndexType sliceStartIndex = + at::zoom::detail::IndexToOffset::get(slice_idx, input); + IndexType topKSliceStartIndex = + at::zoom::detail::IndexToOffset::get(slice_idx, topK); + IndexType indicesSliceStartIndex = + at::zoom::detail::IndexToOffset::get(slice_idx, indices); + + const T* inputSliceStart = &input.data[sliceStartIndex]; + T* topKSliceStart = &topK.data[topKSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + // Find the k-th highest element in our input + T kthValue = kthValues[slice_idx]; + const auto kthValueConverted = at::native::TopKTypeConfig::convert(kthValue); + + // Find the start index in output tensor of this block + uint32_t startWithinK = 0; + if (blk_idx_in_slice > 0) { + startWithinK = withinKCounts[block_idx - 1]; + } + uint32_t startKth = withinKCounts[slice_idx * blocks_per_slice + blocks_per_slice - 1]; + if (blk_idx_in_slice > 0) { + startKth += kthCounts[block_idx - 1]; + } + + // Read input, select topk out and write + typedef hipcub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + for (int i = 0; i < items_per_thread; ++i) { + // Find the start offset for this slice + IndexType idx = blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx; + T val; + int withinK = 0; + int kth = 0; + if (idx < inputSliceSize) { + val = doLdg(inputSliceStart + idx * inputWithinSliceStride); + const auto valConverted = at::native::TopKTypeConfig::convert(val); + withinK = (largest ? valConverted > kthValueConverted : valConverted < kthValueConverted); + kth = (valConverted == kthValueConverted); + } + + uint32_t withinKIndex; + uint32_t numWithinK; + BlockScan(temp_storage).ExclusiveSum(withinK, withinKIndex, numWithinK); + __syncthreads(); + if (withinK) { + uint32_t offset = withinKIndex + startWithinK; + topKSliceStart[offset * topKWithinSliceStride] = val; + indicesSliceStart[offset * indicesWithinSliceStride] = idx; + } + startWithinK += numWithinK; + + if (startKth < outputSliceSize) { + uint32_t kthIndex; + uint32_t numKth; + BlockScan(temp_storage).ExclusiveSum(kth, kthIndex, numKth); + __syncthreads(); + if (kth) { + uint32_t offset = kthIndex + startKth; + if (offset < outputSliceSize) { + topKSliceStart[offset * topKWithinSliceStride] = val; + indicesSliceStart[offset * indicesWithinSliceStride] = idx; + } + } + startKth += numKth; + } + } +} +#endif + +int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) { + // occupancy of this kernel is limited by registers per threads + constexpr int REGS_PER_THREAD = 40; // from nsight launch statistics + constexpr int REGS_PER_BLOCK = REGS_PER_THREAD * BLOCK_THREADS; + hipDeviceProp_t* prop = at::zoom::getCurrentDeviceProperties(); + int mpc = prop->multiProcessorCount; + int regs_per_mp = prop->regsPerBlock; + int max_blocks_per_mp = 32; + int blocks_per_mp = ::min(regs_per_mp / REGS_PER_BLOCK, max_blocks_per_mp); + int64_t items_per_thread = at::ceil_div((int64_t)(slice_size * num_slices), (int64_t)(mpc * blocks_per_mp * BLOCK_THREADS)); + items_per_thread = ::max(MIN_ITEMS_PER_THREAD, ::min((int)items_per_thread, MAX_ITEMS_PER_THREAD)); // clamp to (4, 64) + return items_per_thread; +} + +class BlockIdxToKey { + uint32_t blocks_per_slice; +public: + BlockIdxToKey(uint32_t blocks_per_slice): blocks_per_slice(blocks_per_slice) {} + __device__ __forceinline__ 
uint32_t operator()(uint32_t blk) const { + return blk / blocks_per_slice; + } +}; + +template +void launch( + at::zoom::detail::TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, // aka `k` + bool largest, + + uint32_t numInputSlices, + IndexType inputWithinSliceStride, + + at::zoom::detail::TensorInfo topK, + IndexType topKWithinSliceStride, + + at::zoom::detail::TensorInfo indices, + IndexType indicesWithinSliceStride) { + auto stream = c10::zoom::getCurrentZoomStream(); + + // configure items_per_thread based on device architecture and input size + int items_per_thread = get_items_per_thread(numInputSlices, inputSliceSize); + int items_per_block = items_per_thread * BLOCK_THREADS; + + using Bitwise = typename TopKTypeConfig::RadixType; + uint32_t blocks_per_slice = at::ceil_div((int64_t)inputSliceSize, (int64_t)items_per_block); + uint32_t num_blocks = numInputSlices * blocks_per_slice; + + // temporary storage + auto& allocator = *c10::zoom::ZoomCachingAllocator::get(); + + auto kthValues_buffer = allocator.allocate(numInputSlices * sizeof(T)); + T* kthValues = reinterpret_cast(kthValues_buffer.get()); + + TORCH_CHECK(blocks_per_slice <= std::numeric_limits::max(), "blocks_per_slice larger than uint32 maximum is not supported"); + auto semaphores_buffer = allocator.allocate(numInputSlices * sizeof(uint32_t)); + uint32_t* semaphores = reinterpret_cast(semaphores_buffer.get()); + C10_ZOOM_CHECK(hipMemsetAsync(semaphores, 0, numInputSlices * sizeof(uint32_t), stream)); + + auto ks_to_find_buffer = allocator.allocate(numInputSlices * sizeof(uint32_t)); + uint32_t* ks_to_find = reinterpret_cast(ks_to_find_buffer.get()); + uint32_t k_to_find = largest ? inputSliceSize - outputSliceSize + 1: outputSliceSize; + hipLaunchKernelGGL(( fill), dim3(::min(((int64_t)numInputSlices + 511) / 512, (int64_t)1073741824)), dim3(512), 0, stream, + ks_to_find, k_to_find, numInputSlices); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + auto desired_buffer = allocator.allocate(numInputSlices * sizeof(Bitwise)); + Bitwise* desired = reinterpret_cast(desired_buffer.get()); + + auto counts_buffer = allocator.allocate(num_blocks * RADIX_DIGITS * sizeof(short)); + short* counts = reinterpret_cast(counts_buffer.get()); + static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < std::numeric_limits::max(), + "blockwise counter too large"); + +#if CUB_SUPPORTS_SCAN_BY_KEY() + auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); + uint32_t* withinKCounts = reinterpret_cast(withinKCounts_buffer.get()); + C10_ZOOM_CHECK(hipMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream)); + + auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t)); + uint32_t* kthCounts = reinterpret_cast(kthCounts_buffer.get()); +#endif + + Bitwise desiredMask = 0; + dim3 grid; + TORCH_INTERNAL_ASSERT(getGridFromTiles(num_blocks, grid), "Too many slices for topk"); + dim3 block(BLOCK_THREADS); + + // iterate radix bits for multiple passes + for (int current_bit = sizeof(T) * 8 - RADIX_BITS; current_bit >= 0; current_bit -= RADIX_BITS) { + hipLaunchKernelGGL(( radixFindKthValues), dim3(grid), dim3(block), 0, stream, + input, + inputSliceSize, + ks_to_find, + numInputSlices, + inputWithinSliceStride, + current_bit, + items_per_thread, + blocks_per_slice, + desiredMask, + semaphores, + desired, + counts, + kthValues); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +#if CUB_SUPPORTS_SCAN_BY_KEY() + hipLaunchKernelGGL(( computeBlockwiseWithinKCounts), dim3(grid), dim3(RADIX_DIGITS), 0, stream, + 
desired, counts, blocks_per_slice, current_bit, largest, withinKCounts, num_blocks); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +#endif + desiredMask = at::zoom::Bitfield::setBitfield(desiredMask, RADIX_MASK, current_bit, RADIX_BITS); + } + +#if CUB_SUPPORTS_SCAN_BY_KEY() + hipLaunchKernelGGL(( computeBlockwiseKthCounts), dim3(::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824)), dim3(256), 0, stream, + desired, counts, num_blocks, blocks_per_slice, kthCounts); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + // Do a prefix scan of withinKCounts and kthCounts using slice_idx as keys to get the starting index of each block + using counting_iter_t = hipcub::CountingInputIterator; + using slice_idx_iter_t = hipcub::TransformInputIterator; + slice_idx_iter_t slice_idx_iter(counting_iter_t(0), BlockIdxToKey(blocks_per_slice)); + at::zoom::hipcub::inclusive_sum_by_key(slice_idx_iter, withinKCounts, withinKCounts, num_blocks); + at::zoom::hipcub::inclusive_sum_by_key(slice_idx_iter, kthCounts, kthCounts, num_blocks); + // copy topk values to output tensor + hipLaunchKernelGGL(( gatherTopK), dim3(grid), dim3(block), 0, stream, + input, inputSliceSize, outputSliceSize, largest, numInputSlices, inputWithinSliceStride, + topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread, + blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +#else + // Find topk values based on kth values + { + dim3 grid; + TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk"); + int warp_size = at::zoom::warp_size(); + dim3 block(::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024)); + hipLaunchKernelGGL(( sbtopk::gatherTopK), dim3(grid), dim3(block), 0, stream, + input, + inputSliceSize, + outputSliceSize, + largest, + numInputSlices, + inputWithinSliceStride, + topK, + topKWithinSliceStride, + indices, + indicesWithinSliceStride, + kthValues); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +#endif +} + +} // namespace mbtopk + +bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { + if (num_slices > std::numeric_limits::max() || + slice_size > std::numeric_limits::max()) return false; +#if CUB_SUPPORTS_SCAN_BY_KEY() + // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267 + return (num_slices <= 20 && slice_size >= 20000) || + (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) || + (num_slices > 40 && num_slices <= 80 && slice_size >= 8000) || + (num_slices > 80 && num_slices < 200 && slice_size >= 5000) || + (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) || + (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) || + (num_slices > 4000 && slice_size >= 400); +#else + // This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081 + return (num_slices <= 400 && slice_size >= 5000) || + (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) || + (num_slices >= 4000 && slice_size >= 300); +#endif +} + +void launch_gather_topk_kernel( + const TensorBase& self, int64_t k, int64_t dim, bool largest, + const TensorBase& values, const TensorBase& indices) { + int numDims = self.dim(); + numDims = numDims == 0 ? 1 : numDims; + TORCH_CHECK(numDims <= MAX_DIMS, "input tensor has too many dimensions"); + int64_t sliceSize = self.dim() == 0 ? 
1 : self.size(dim); + + auto input = self.contiguous(); + // static_cast is required to ensure that the correct type (INDEX_T) + // is provided to the kernel for the arguments. +#define RUN_K(INDEX_T, DIM, LAUNCH_FUNCTION_NAME) \ + LAUNCH_FUNCTION_NAME( \ + inputInfo, \ + static_cast(sliceSize), \ + static_cast(k), \ + largest, \ + static_cast(numInputSlices), \ + /* The actual dimension that the k-selection is running in */ \ + /* may have changed from collapseDims() */ \ + static_cast(inputInfo.strides[collapseInputDim]), \ + topKInfo, \ + static_cast(topKInfo.strides[collapseTopKDim]), \ + indicesInfo, \ + static_cast(indicesInfo.strides[collapseIndicesDim])); + +#define RUN_MB(INDEX_T, DIM) \ + if (should_use_multiblock(numInputSlices, sliceSize)) { \ + RUN_K(INDEX_T, DIM, mbtopk::launch); \ + } else { \ + RUN_K(INDEX_T, DIM, sbtopk::launch); \ + } + +#define RUN_DIM(INDEX_T) \ + if (allDims == 1) { \ + RUN_MB(INDEX_T, 1); \ + } else if (allDims == 2) { \ + RUN_MB(INDEX_T, 2); \ + } else if (allDims == 3) { \ + RUN_MB(INDEX_T, 3); \ + } else { \ + RUN_MB(INDEX_T, -1); \ + } + +#define RUN_T(INDEX_T) \ + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "topk_out_zoom", [&] { \ + at::zoom::detail::TensorInfo inputInfo = \ + at::zoom::detail::getTensorInfo(input); \ + at::zoom::detail::TensorInfo topKInfo = \ + at::zoom::detail::getTensorInfo(values); \ + at::zoom::detail::TensorInfo indicesInfo = \ + at::zoom::detail::getTensorInfo(indices); \ + /* tensorInfoLegacyIfScalar*/ \ + if (!input.dim()) { \ + inputInfo.dims = 1; \ + inputInfo.sizes[0] = 1; \ + inputInfo.strides[0] = 1; \ + topKInfo.dims = 1; \ + topKInfo.sizes[0] = 1; \ + topKInfo.strides[0] = 1; \ + indicesInfo.dims = 1; \ + indicesInfo.sizes[0] = 1; \ + indicesInfo.strides[0] = 1; \ + } \ + /* We use these structures solely to find the offset to */ \ + /* each slice we are operating on */ \ + inputInfo.sizes[dim] = 1; \ + topKInfo.sizes[dim] = 1; \ + indicesInfo.sizes[dim] = 1; \ + /* stash the stride of dim because it can be accidentally collapsed */ \ + auto strideTopK = topKInfo.strides[dim]; \ + auto strideIndices = indicesInfo.strides[dim]; \ + /* Collapse all other dims */ \ + int collapseInputDim = inputInfo.collapseDims(dim); \ + int collapseTopKDim = topKInfo.collapseDims(dim); \ + int collapseIndicesDim = indicesInfo.collapseDims(dim); \ + /* restore stride in case it was collapsed */ \ + topKInfo.strides[collapseTopKDim] = strideTopK; \ + indicesInfo.strides[collapseIndicesDim] = strideIndices; \ + int64_t numInputSlices = 1; \ + for (int i = 0; i < inputInfo.dims; ++i) { \ + numInputSlices *= inputInfo.sizes[i]; \ + } \ + \ + /* This is used as a template parameter to calculate indices. */ \ + /* We only specialize it if all collapsed dim sizes are the */ \ + /* same; otherwise, we use -1 which is the specialization */ \ + /* parameter for arbitrary dimensions */ \ + int allDims = inputInfo.dims; \ + if (topKInfo.dims != allDims || indicesInfo.dims != allDims) { \ + allDims = -1; \ + } \ + \ + RUN_DIM(INDEX_T); \ + }); + + // the below is safe with 0-dimensional tensors because it is based on + // TensorInfo which implicitly expands to 1-dimensional. 
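The dispatch that follows chooses between 32-bit and 64-bit index arithmetic: whenever the input, values, and indices tensors can all be addressed with 32-bit offsets, the kernels are instantiated with the narrower index type, which reduces register pressure and address-math cost. A minimal stand-alone sketch of the pattern is below; the helper names are hypothetical and the size check is deliberately simplified (the real canUse32BitIndexMath also accounts for strides and storage offsets).

// Illustrative index-type dispatch (hypothetical helper names, not the patch's API).
#include <cstdint>
#include <cstdio>
#include <limits>

template <typename IndexType>
void run_topk_impl(int64_t numel) {
  std::printf("topk with %zu-byte indices over %lld elements\n",
              sizeof(IndexType), static_cast<long long>(numel));
}

static bool fits_in_32bit(int64_t numel) {
  return numel <= std::numeric_limits<int32_t>::max();
}

int main() {
  const int64_t sizes[] = {int64_t{1} << 20, int64_t{1} << 33};
  for (int64_t numel : sizes) {
    if (fits_in_32bit(numel)) {
      run_topk_impl<uint32_t>(numel);   // cheaper indexing for small tensors
    } else {
      run_topk_impl<uint64_t>(numel);   // fall back to 64-bit offsets
    }
  }
  return 0;
}

The RUN_T macro above expands to exactly this choice, with RUN_DIM and RUN_MB further specializing on the collapsed dimensionality and on the single- versus multi-block kernels.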
+ if (input.numel() > 0) { + // Based on required index size, run the algorithm with the + // appropriate index type + if (at::zoom::detail::canUse32BitIndexMath(input) && + at::zoom::detail::canUse32BitIndexMath(values) && + at::zoom::detail::canUse32BitIndexMath(indices)) { + RUN_T(uint32_t); + } else { + RUN_T(uint64_t); + } + } +#undef RUN_T +#undef RUN_DIM +#undef RUN_K +} + +} // at::native diff --git a/aten/src/ATen/native/zoom/TensorTopK.h b/aten/src/ATen/native/zoom/TensorTopK.h new file mode 100644 index 00000000000000..9eebf2cd6040c4 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorTopK.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at { +namespace native { +void launch_gather_topk_kernel( + const TensorBase& self, + int64_t k, int64_t dim, bool largest, + const TensorBase& values, const TensorBase& indices); +}} diff --git a/aten/src/ATen/native/zoom/TriangularOps.cu b/aten/src/ATen/native/zoom/TriangularOps.cu new file mode 100644 index 00000000000000..b1bd67b8f501c5 --- /dev/null +++ b/aten/src/ATen/native/zoom/TriangularOps.cu @@ -0,0 +1,165 @@ +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +#include + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +namespace at::native { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +constexpr static int block_size = 128; + +template +C10_LAUNCH_BOUNDS_1(block_size) +__global__ void triu_tril_kernel( + zoom::detail::TensorInfo result_info, + const zoom::detail::TensorInfo self_info, + const int64_t k, + const int64_t N_padded, + const IndexType last_dim_padded) { + int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread; + if (linear_idx >= N_padded) { + return; + } + + auto dims = self_info.dims; + + // Compute column index amd row index + IndexType col = linear_idx % last_dim_padded; + linear_idx /= last_dim_padded; + IndexType row = linear_idx % self_info.sizes[dims - 2]; + + if constexpr (inplace) { + bool mask_all_true = upper ? (col - row >= k) : (col + elements_per_thread - row <= k); + if (mask_all_true) + return; + } + + // Compute offset + IndexType self_offset = 0, result_offset = 0; + self_offset += self_info.strides[dims - 1] * col; + result_offset += result_info.strides[dims - 1] * col; + linear_idx /= self_info.sizes[dims - 2]; + self_offset += self_info.strides[dims - 2] * row; + result_offset += result_info.strides[dims - 2] * row; + + // Compute remaining offsets + IndexType running_index; + #pragma unroll + for (IndexType i = dims - 3; i >= 0; --i) { + running_index = linear_idx % self_info.sizes[i]; + linear_idx /= self_info.sizes[i]; + self_offset += running_index * self_info.strides[i]; + result_offset += running_index * result_info.strides[i]; + } + + if constexpr (inplace) { + #pragma unroll + for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++) { + bool mask = upper ? 
(col + i - row >= k) : (col + i - row <= k); + if (!mask) + result_info.data[result_offset + i * result_info.strides[dims - 1]] = scalar_t(0); + } + } else { + scalar_t frag[elements_per_thread] = {}; + bool has_mask = (upper && col + elements_per_thread - row >= k) || (!upper && col - row <= k); + if (has_mask) { + #pragma unroll + for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++) + frag[i] = self_info.data[self_offset + i * self_info.strides[dims - 1]]; + + #pragma unroll + for (int i = 0; i < elements_per_thread; i++) { + bool mask = upper ? (col + i - row >= k) : (col + i - row <= k); + frag[i] = mask ? frag[i] : scalar_t(0); + } + } + + #pragma unroll + for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++) + result_info.data[result_offset + i * result_info.strides[dims - 1]] = frag[i]; + } +} + +template +void triu_tril_zoom_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::ComplexHalf, + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Bool, + self.scalar_type(), "triu_tril_zoom_template", [&] { + constexpr int elements_per_thread = sizeof(scalar_t) < 8 ? 8 / sizeof(scalar_t) : 1; + auto sizes = self.sizes(); + int64_t last_dim_padded = round_up(sizes.back(), elements_per_thread); + int64_t N_padded = c10::multiply_integers(sizes.begin(), sizes.end() - 1) * last_dim_padded; + dim3 dim_block = block_size; + dim3 dim_grid((N_padded / elements_per_thread + dim_block.x - 1) / dim_block.x); + if (zoom::detail::canUse32BitIndexMath(result) && zoom::detail::canUse32BitIndexMath(self)) { + auto result_info = zoom::detail::getTensorInfo(result); + auto self_info = zoom::detail::getTensorInfo(self); + BOOL_SWITCH(self.is_same(result), inplace, [&] { + hipLaunchKernelGGL(( triu_tril_kernel) + , dim3(dim_grid), dim3(dim_block), 0, c10::zoom::getCurrentZoomStream(), + result_info, self_info, k, N_padded, last_dim_padded); + }); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + auto result_info = zoom::detail::getTensorInfo(result); + auto self_info = zoom::detail::getTensorInfo(self); + BOOL_SWITCH(self.is_same(result), inplace, [&] { + hipLaunchKernelGGL(( triu_tril_kernel) + , dim3(dim_grid), dim3(dim_block), 0, c10::zoom::getCurrentZoomStream(), + result_info, self_info, k, N_padded, last_dim_padded); + }); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); +} + +TORCH_IMPL_FUNC(tril_zoom)(const Tensor& self, int64_t k, const Tensor &result) { + if (self.numel() != 0) { + triu_tril_zoom_template(result, self, k, "tril"); + } +} + +TORCH_IMPL_FUNC(triu_zoom)(const Tensor& self, int64_t k, const Tensor &result) { + if (self.numel() != 0) { + triu_tril_zoom_template(result, self, k, "triu"); + } +} + +Tensor trace_zoom(const Tensor& self) { + TORCH_CHECK(self.dim() == 2, "expected a matrix"); + return self.diagonal().sum(); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/UnaryGeometricCosKernel.cu b/aten/src/ATen/native/zoom/UnaryGeometricCosKernel.cu new file mode 100644 index 00000000000000..d76b21a6bcd095 --- /dev/null +++ b/aten/src/ATen/native/zoom/UnaryGeometricCosKernel.cu @@ -0,0 +1,58 @@ +// !!! This is a file automatically generated by hipify!!! 
+#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +#if AT_USE_JITERATOR() +CONSTEXPR_EXCEPT_WIN_CUDA char cos_name[] = "cos_impl"; +#endif // AT_USE_JITERATOR() + +void cos_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR() + static const auto cos_string = jiterator_stringify( + template T cos_impl(T a) { return std::cos(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cos_name", [&]() { + jitted_gpu_kernel< + /*name=*/cos_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, cos_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cos_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::cos(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "cos_zoom", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cos(a); }); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(cos_stub, &cos_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/UnaryGeometricSinKernel.cu b/aten/src/ATen/native/zoom/UnaryGeometricSinKernel.cu new file mode 100644 index 00000000000000..d7417fb6477a87 --- /dev/null +++ b/aten/src/ATen/native/zoom/UnaryGeometricSinKernel.cu @@ -0,0 +1,58 @@ +// !!! This is a file automatically generated by hipify!!! +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +#if AT_USE_JITERATOR() +CONSTEXPR_EXCEPT_WIN_CUDA char sin_name[] = "sin_impl"; +#endif + +void sin_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR() + static const auto sin_string = jiterator_stringify( + template T sin_impl(T a) { return std::sin(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sin_name", [&]() { + jitted_gpu_kernel< + /*name=*/sin_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, sin_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sin_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::sin(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "sin_zoom", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sin(a); }); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(sin_stub, &sin_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/UnarySignKernels.cu b/aten/src/ATen/native/zoom/UnarySignKernels.cu new file mode 100644 index 00000000000000..57362dfcd6007d --- /dev/null +++ b/aten/src/ATen/native/zoom/UnarySignKernels.cu @@ -0,0 +1,121 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +void logical_not_kernel_zoom(TensorIteratorBase& iter) { + // error check -- this is just ensuring we don't dispatch on types that aren't in ALL_TYPES_AND_COMPLEX_AND3(...) 
+ // so we don't have to maintain a separate list or to do double dispatch. + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(0), "logical_not_zoom", [&]() {}); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(1), "logical_not_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> bool { return !a; }); + }); +} + +// NB: Ignores the negative bit on tensors +CONSTEXPR_EXCEPT_WIN_CUDA char neg_name[] = "neg_kernel"; +void neg_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.dtype(); + if (at::isComplexType(dtype)) { + static const auto neg_string = jiterator_stringify( + template + T neg_kernel(T a) { + return -a; + } + ); // neg_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "neg_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ neg_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 1>(iter, neg_string); + }); + + } else { + AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "neg_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return -a; + }); + }); + } +} + +void sign_kernel_zoom(TensorIteratorBase& iter){ + if (iter.dtype() == ScalarType::Bool) { + gpu_kernel(iter, []GPU_LAMBDA(bool a){ + return a; + }); + } else { + AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sign_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return c10::signum(a); + }); + }); + } +} + +void signbit_kernel_zoom(TensorIteratorBase& iter){ + // NOTE: signbit does not always support integral arguments. + if (at::isIntegralType(iter.input_dtype(), /*includeBool=*/false)) { + AT_DISPATCH_INTEGRAL_TYPES(iter.input_dtype(), "signbit_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> bool { return is_negative(a); }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, ScalarType::Half, iter.input_dtype(), "signbit_zoom", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> bool { return signbit(opmath_t{a}); }); + }); + } +} + +template +C10_HOST_DEVICE static inline c10::complex sgn_wrapper(c10::complex z) { + if (z == c10::complex(0, 0)) { + return c10::complex(0, 0); + } else { + return z / std::abs(z); + } +} + +CONSTEXPR_EXCEPT_WIN_CUDA char sgn_name[] = "sgn_kernel"; +void sgn_kernel_zoom(TensorIteratorBase& iter){ + auto dtype = iter.dtype(); + static const auto sgn_string = jiterator_stringify( + template + T sgn_kernel(T z) { + const T zero = T(0); + if (z == zero) { + return zero; + } else { + return z / std::abs(z); + } + } + ); // sgn_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "sgn_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ sgn_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 1>(iter, sgn_string); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(logical_not_stub, &logical_not_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(neg_stub, &neg_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(sign_stub, &sign_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(signbit_stub, &signbit_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(sgn_stub, &sgn_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/block_reduce.cuh b/aten/src/ATen/native/zoom/block_reduce.cuh new file mode 100644 index 00000000000000..16f9b2ba6b492a --- /dev/null +++ b/aten/src/ATen/native/zoom/block_reduce.cuh @@ -0,0 +1,143 @@ +#pragma once + +#include + +#include +#include + +namespace at 
{ +namespace native { +namespace zoom_utils { + +constexpr int kHIPBlockReduceNumThreads = 512; +// Algorithmic limitation: BlockReduce does two WarpReduce calls, each +// of which reduces C10_WARP_SIZE elements. So, at most +// C10_WARP_SIZE**2 elements can be reduced at a time. +// NOTE: This is >= the max block size on current hardware anyway (1024). +constexpr int kHIPBlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; + +// Sums `val` across all threads in a warp. +// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template +__inline__ __device__ T WarpReduceSum(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN(val, offset); + } + return val; +} + +// Picks the maximum `val` across all threads in a warp. +// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template +__inline__ __device__ T WarpReduceMax(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset)); + } + return val; +} + +struct Block1D { + static __forceinline__ __device__ int Tid() { return threadIdx.x; } + + static __forceinline__ __device__ int Warps() { + return blockDim.x / C10_WARP_SIZE; + } +}; + +struct Block2D { + static __forceinline__ __device__ int Tid() { + return threadIdx.x + threadIdx.y * blockDim.x; + } + + static __forceinline__ __device__ int Warps() { + return blockDim.x * blockDim.y / C10_WARP_SIZE; + } +}; + +// Sums `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template +__inline__ __device__ T BlockReduceSum(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceSum(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : T(0); + if (wid == 0) { + val = WarpReduceSum(val); + } + return val; +} + +// Picks out the maximum `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template +__inline__ __device__ T BlockReduceMax(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceMax(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? 
shared[lid] : T(0); + if (wid == 0) { + val = WarpReduceMax(val); + } + return val; +} + +template +__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = op.combine(val, op.warp_shfl_down(val, offset)); + } + return val; +} + +template +__inline__ __device__ T +BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduce(val, op); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : identity_element; + if (wid == 0) { + val = WarpReduce(val, op); + } + return val; +} + +} // namespace zoom_utils +} // namespace native +} // namespace at \ No newline at end of file diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index 999f7489a8ff66..51dc66bfc6af88 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -349,7 +349,7 @@ def compute_ufunc_zoom(g: NativeFunctionsGroup) -> str: {dtype_cases_str} ); }} -REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); +REGISTER_PRIVATEUSE1_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); {sig.defn()} {{ {stub_sig.direct_call(sig.arguments())}; }} From b33031d27b8ccd1e5d271f68c5abaf2735469755 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Mon, 13 Jan 2025 01:04:45 +0000 Subject: [PATCH 05/23] fix matmul kernel --- aten/src/ATen/native/zoom/Bmm.cpp | 7 +++--- aten/src/ATen/native/zoom/HIPbmm.cu | 38 +++++++++++++++++------------ 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/zoom/Bmm.cpp b/aten/src/ATen/native/zoom/Bmm.cpp index f95e530655919f..53e87a7eb3913e 100644 --- a/aten/src/ATen/native/zoom/Bmm.cpp +++ b/aten/src/ATen/native/zoom/Bmm.cpp @@ -28,14 +28,15 @@ namespace at::native { } else if (batch1.size(2) == 0) { return result.zero_(); } + TORCH_CHECK(batch1.sizes()[2] == batch2.sizes()[1], "batch1 dim 2 must match batch2 dim 1"); c10::MaybeOwned result_ = c10::MaybeOwned::borrowed(result); IntArrayRef result_strides = result.strides(); IntArrayRef result_sizes = result.sizes(); - int m = result_sizes[1]; - int n = result_sizes[2]; - int k = batch1.sizes()[2]; + int m = batch1.sizes()[1]; + int n = batch1.sizes()[2]; + int k = batch2.sizes()[2]; int num_batches = result_->sizes()[0]; AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "bmm_hip", [&] { diff --git a/aten/src/ATen/native/zoom/HIPbmm.cu b/aten/src/ATen/native/zoom/HIPbmm.cu index 84f5eb2aaf6201..a77a31efaf1af6 100644 --- a/aten/src/ATen/native/zoom/HIPbmm.cu +++ b/aten/src/ATen/native/zoom/HIPbmm.cu @@ -2,9 +2,14 @@ #include #include #include +#include namespace at::native { + int num_threads() { + return 32; + } + // Helper function to convert hip_bfloat16 to float __device__ float bfloat16_to_float(hip_bfloat16 a) { union { @@ -63,64 +68,65 @@ namespace at::native { int col = blockIdx.x * blockDim.x + threadIdx.x; int batch = blockIdx.z; - if (row < M && col < N) { + if (row < M && col < K && batch < batch_size) { float sum = 0.0f; - for (int k = 0; k < K; ++k) { - sum += convert_to_float(A[batch * M * K + row * K + k]) * - convert_to_float(B[batch * K * N + k * N + col]); + for (int n = 0; n < N; ++n) { + sum += convert_to_float(A[batch * M * N + row * N + n]) * + convert_to_float(B[batch * N * K + 
n * K + col]); } - C[batch * M * N + row * N + col] = convert_from_float(sum); + C[batch * M * K + row * K + col] = convert_from_float(sum); } } template void batched_matmul(const T* A, const T* B, T* C, int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(16, 16); - dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, + dim3 threadsPerBlock(num_threads(), num_threads()); + dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, (M + threadsPerBlock.y - 1) / threadsPerBlock.y, batch_size); - hipLaunchKernelGGL(batched_matmul_kernel, numBlocks, threadsPerBlock, 0, 0, + hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel), numBlocks, threadsPerBlock, 0, 0, A, B, C, M, N, K, batch_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); } // Specialization for at::Half template <> void batched_matmul(const at::Half* A, const at::Half* B, at::Half* C, int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(16, 16); - dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, + dim3 threadsPerBlock(num_threads(), num_threads()); + dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, (M + threadsPerBlock.y - 1) / threadsPerBlock.y, batch_size); - hipLaunchKernelGGL(batched_matmul_kernel<__half>, numBlocks, threadsPerBlock, 0, 0, + hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel<__half>), numBlocks, threadsPerBlock, 0, 0, reinterpret_cast(A), reinterpret_cast(B), reinterpret_cast<__half*>(C), M, N, K, batch_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); } // Specialization for at::BFloat16 template <> void batched_matmul(const at::BFloat16* A, const at::BFloat16* B, at::BFloat16* C, int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(16, 16); - dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, + dim3 threadsPerBlock(num_threads(), num_threads()); + dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, (M + threadsPerBlock.y - 1) / threadsPerBlock.y, batch_size); - hipLaunchKernelGGL(batched_matmul_kernel, numBlocks, threadsPerBlock, 0, 0, + hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel), numBlocks, threadsPerBlock, 0, 0, reinterpret_cast(A), reinterpret_cast(B), reinterpret_cast(C), M, N, K, batch_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); } // Explicit instantiations for supported types template void batched_matmul(const float*, const float*, float*, int, int, int, int); template void batched_matmul(const double*, const double*, double*, int, int, int, int); - template void batched_matmul(const half*, const half*, half*, int, int, int, int); - template void batched_matmul(const hip_bfloat16*, const hip_bfloat16*, hip_bfloat16*, int, int, int, int); } // at::native \ No newline at end of file From 1aa7e9241d8c4b7190da5e0ce48bfe6d4d1e732d Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Thu, 16 Jan 2025 22:53:31 +0000 Subject: [PATCH 06/23] some ops + llama script running --- aten/src/ATen/native/native_functions.yaml | 85 +-- .../native/zoom/BinaryBitwiseOpsKernels.cu | 78 ++ .../zoom/DistributionExponentialKernel.cu | 16 + .../ATen/native/zoom/DistributionTemplates.h | 671 ++++++++++++++++++ aten/src/ATen/native/zoom/IndexKernel.cu | 463 ++++++++++++ aten/src/ATen/native/zoom/IndexKernel.h | 16 + aten/src/ATen/native/zoom/LaunchUtils.h | 18 + .../src/ATen/native/zoom/MultinomialKernel.cu | 462 ++++++++++++ .../ATen/native/zoom/ReduceAMinMaxKernel.cu | 45 ++ .../ATen/native/zoom/ReduceArgMaxKernel.cu | 46 ++ .../ATen/native/zoom/ReduceArgMinKernel.cu | 46 ++ .../src/ATen/native/zoom/ReduceLogicKernel.cu | 2 +- 
.../ATen/native/zoom/ReduceMaxValuesKernel.cu | 61 ++ .../ATen/native/zoom/ReduceMinValuesKernel.cu | 58 ++ aten/src/ATen/native/zoom/ReduceNormKernel.cu | 51 ++ aten/src/ATen/native/zoom/ReduceOps.cpp | 102 +++ aten/src/ATen/native/zoom/ReduceOps.h | 20 + .../ATen/native/zoom/ReduceSumProdKernel.cu | 215 ++++++ .../ATen/native/zoom/ScatterGatherKernel.cu | 573 +++++++++++++++ aten/src/ATen/native/zoom/UnaryOpsKernel.cu | 286 ++++++++ 20 files changed, 3271 insertions(+), 43 deletions(-) create mode 100644 aten/src/ATen/native/zoom/BinaryBitwiseOpsKernels.cu create mode 100644 aten/src/ATen/native/zoom/DistributionExponentialKernel.cu create mode 100644 aten/src/ATen/native/zoom/DistributionTemplates.h create mode 100644 aten/src/ATen/native/zoom/IndexKernel.cu create mode 100644 aten/src/ATen/native/zoom/IndexKernel.h create mode 100644 aten/src/ATen/native/zoom/LaunchUtils.h create mode 100644 aten/src/ATen/native/zoom/MultinomialKernel.cu create mode 100644 aten/src/ATen/native/zoom/ReduceAMinMaxKernel.cu create mode 100644 aten/src/ATen/native/zoom/ReduceArgMaxKernel.cu create mode 100644 aten/src/ATen/native/zoom/ReduceArgMinKernel.cu create mode 100644 aten/src/ATen/native/zoom/ReduceMaxValuesKernel.cu create mode 100644 aten/src/ATen/native/zoom/ReduceMinValuesKernel.cu create mode 100644 aten/src/ATen/native/zoom/ReduceNormKernel.cu create mode 100644 aten/src/ATen/native/zoom/ReduceOps.cpp create mode 100644 aten/src/ATen/native/zoom/ReduceOps.h create mode 100644 aten/src/ATen/native/zoom/ReduceSumProdKernel.cu create mode 100644 aten/src/ATen/native/zoom/ScatterGatherKernel.cu create mode 100644 aten/src/ATen/native/zoom/UnaryOpsKernel.cu diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a5876201f7e9c6..1664a6642b4cc4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -816,7 +816,7 @@ - func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA: argmax_out + CPU, CUDA, PrivateUse1: argmax_out MPS: argmax_out_mps - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor @@ -828,7 +828,7 @@ - func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
structured: True dispatch: - CPU, CUDA: argmin_out + CPU, CUDA, PrivateUse1: argmin_out MPS: argmin_out_mps - func: acosh(Tensor self) -> Tensor @@ -1194,7 +1194,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: bitwise_not_out + CPU, CUDA, PrivateUse1: bitwise_not_out MPS: bitwise_not_out_mps tags: pointwise @@ -2564,7 +2564,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: exp_out + CPU, CUDA, PrivateUse1: exp_out MPS: exp_out_mps tags: pointwise @@ -2609,7 +2609,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: expm1_out + CPU, CUDA, PrivateUse1: expm1_out MPS: expm1_out_mps SparseCPU, SparseCUDA: expm1_sparse_out SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_out @@ -3061,7 +3061,7 @@ precomputed: - indices -> DimVector sizes, DimVector strides dispatch: - CPU, CUDA, MPS: index_out + CPU, CUDA, MPS, PrivateUse1: index_out # Used by inductor to signal indexing without bounds checks # Note that we don't support boolean indexing, to avoid dynamic output shapes @@ -3076,7 +3076,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: index_copy_out + CPU, CUDA, PrivateUse1: index_copy_out - func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) variants: method @@ -3298,7 +3298,7 @@ - func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: nan_to_num_out + CPU, CUDA, PrivateUse1: nan_to_num_out MPS: nan_to_num_out_mps SparseCPU, SparseCUDA: nan_to_num_sparse_out tags: pointwise @@ -3797,7 +3797,7 @@ device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: aminmax_out + CPU, CUDA, PrivateUse1: aminmax_out MPS: aminmax_out_mps - func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor @@ -3822,7 +3822,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: max_out + CPU, CUDA, PrivateUse1: max_out MPS: max_out_mps - func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -4013,7 +4013,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: min_out + CPU, CUDA, PrivateUse1: min_out MPS: min_out_mps - func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -5132,7 +5132,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: rsqrt_out + CPU, CUDA, PrivateUse1: rsqrt_out MPS: rsqrt_out_mps tags: pointwise @@ -5764,7 +5764,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: sum_out + CPU, CUDA, PrivateUse1: sum_out MPS: sum_out_mps - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) @@ -5778,12 +5778,12 @@ - func: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method dispatch: - CPU, CUDA: nansum + CPU, CUDA, PrivateUse1: nansum MPS: nansum_mps - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU, CUDA: nansum_out + CPU, CUDA, PrivateUse1: nansum_out MPS: nansum_out_mps - func: sum_to_size(Tensor self, SymInt[] size) -> Tensor @@ -5816,7 +5816,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sqrt_out + CPU, CUDA, PrivateUse1: sqrt_out MPS: sqrt_out_mps SparseCPU, SparseCUDA: sqrt_sparse_out SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr_out @@ -5911,7 +5911,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: prod + CPU, CUDA, PrivateUse1: prod MPS: prod_mps autogen: prod.out tags: core @@ -5926,7 +5926,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: prod_out + CPU, CUDA, PrivateUse1: prod_out MPS: prod_out_mps - func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor @@ -6107,7 +6107,7 @@ - func: flip(Tensor self, int[] dims) -> Tensor variants: function, method dispatch: - CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip + CPU, QuantizedCPU, CUDA, QuantizedCUDA, PrivateUse1: flip MPS: flip_mps autogen: flip.out tags: core @@ -6770,7 +6770,7 @@ - func: frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent) dispatch: - CPU, CUDA: frexp_out + CPU, CUDA, PrivateUse1: frexp_out tags: pointwise # Deprecated (v.1.12) @@ -8048,7 +8048,7 @@ - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method dispatch: - CPU, CUDA: put_ + CPU, CUDA, PrivateUse1: put_ autogen: put.out - func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor @@ -8102,6 +8102,7 @@ dispatch: CPU: index_fill_ CUDA: index_fill_ + PrivateUse1: index_fill_ MPS: index_fill_mps_ autogen: index_fill.int_Scalar_out @@ -8115,7 +8116,7 @@ device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: index_fill_ + CPU, CUDA, PrivateUse1: index_fill_ MPS: index_fill_mps_ autogen: index_fill.int_Tensor_out @@ -8154,7 +8155,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_src_out + CPU, CUDA, PrivateUse1: scatter_src_out MPS: scatter_src_out_mps - func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor @@ -8170,7 +8171,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_value_out + CPU, CUDA, PrivateUse1: scatter_value_out MPS: scatter_value_out_mps - func: scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor @@ -8185,7 +8186,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_reduce_out + CPU, CUDA, PrivateUse1: scatter_reduce_out MPS: scatter_reduce_out_mps - func: scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor @@ -8200,7 +8201,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_value_reduce_out + CPU, CUDA, PrivateUse1: scatter_value_reduce_out MPS: scatter_value_reduce_out_mps - func: scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor @@ -8222,7 +8223,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_add + CPU, CUDA, PrivateUse1: scatter_add MPS: scatter_add_mps_out - func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor @@ -8241,7 +8242,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_reduce_two + CPU, CUDA, PrivateUse1: scatter_reduce_two - func: eq_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!) structured_delegate: eq.Scalar_out @@ -8259,7 +8260,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_and_out + CPU, CUDA, PrivateUse1: bitwise_and_out MPS: bitwise_and_out_mps tags: pointwise @@ -8326,7 +8327,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_or_out + CPU, CUDA, PrivateUse1: bitwise_or_out MPS: bitwise_or_out_mps tags: pointwise @@ -8393,7 +8394,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_xor_out + CPU, CUDA, PrivateUse1: bitwise_xor_out MPS: bitwise_xor_out_mps tags: pointwise @@ -8718,7 +8719,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: exponential_ + CPU, CUDA, PrivateUse1: exponential_ MPS: exponential_mps_ autogen: exponential, exponential.out @@ -9150,12 +9151,12 @@ - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: take_out + CPU, CUDA, PrivateUse1: take_out - func: take(Tensor self, Tensor index) -> Tensor variants: method, function dispatch: - CPU, CUDA: take + CPU, CUDA, PrivateUse1: take - func: take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) @@ -9248,7 +9249,7 @@ - func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA: gather_out + CPU, CUDA, PrivateUse1: gather_out MPS: gather_out_mps - func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor @@ -9464,13 +9465,13 @@ - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) tags: nondeterministic_seeded dispatch: - CPU, CUDA: multinomial_out + CPU, CUDA, PrivateUse1: multinomial_out MPS: multinomial_out_mps - func: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor variants: method, function dispatch: - CPU, CUDA: multinomial + CPU, CUDA, PrivateUse1: multinomial MPS: multinomial_mps tags: nondeterministic_seeded @@ -9905,14 +9906,14 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: min + CPU, CUDA, PrivateUse1: min MPS: min_mps QuantizedCPU: min_quantized_cpu - func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: min_unary_out + CPU, CUDA, PrivateUse1: min_unary_out QuantizedCPU: min_quantized_unary_out - func: fmin(Tensor self, Tensor other) -> Tensor @@ -9933,7 +9934,7 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: max + CPU, CUDA, PrivateUse1: max MPS: max_mps QuantizedCPU: max_quantized_cpu @@ -9980,7 +9981,7 @@ - func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: max_unary_out + CPU, CUDA, PrivateUse1: max_unary_out QuantizedCPU: max_quantized_unary_out - func: minimum(Tensor self, Tensor other) -> Tensor @@ -14092,7 +14093,7 @@ python_module: linalg structured: True dispatch: - CPU, CUDA: linalg_vector_norm_out + CPU, CUDA, PrivateUse1: linalg_vector_norm_out MPS: linalg_vector_norm_out_mps - func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor diff --git a/aten/src/ATen/native/zoom/BinaryBitwiseOpsKernels.cu b/aten/src/ATen/native/zoom/BinaryBitwiseOpsKernels.cu new file mode 100644 index 00000000000000..fbd3657a48b6fd --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryBitwiseOpsKernels.cu @@ -0,0 +1,78 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + +namespace at::native { + +template +struct BitwiseAndFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a & b; + } +}; + +template<> +struct BitwiseAndFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a && b; + } +}; + +void bitwise_and_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_and_zoom", [&]() { + BitwiseAndFunctor f; + opmath_symmetric_gpu_kernel_with_scalars(iter, f); + }); +} + +template +struct BitwiseOrFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a | b; + } +}; + +template<> +struct BitwiseOrFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a || b; + } +}; + +void bitwise_or_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_or_zoom", [&]() { + BitwiseOrFunctor f; + opmath_symmetric_gpu_kernel_with_scalars(iter, f); + }); +} + +template +struct BitwiseXorFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a ^ b; + } +}; + +template<> +struct BitwiseXorFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a != b; + } +}; + +void bitwise_xor_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_xor_zoom", [&]() { + BitwiseXorFunctor f; + opmath_symmetric_gpu_kernel_with_scalars(iter, f); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(bitwise_or_stub, &bitwise_or_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_zoom); + + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionExponentialKernel.cu b/aten/src/ATen/native/zoom/DistributionExponentialKernel.cu new file mode 100644 index 00000000000000..2dd9cece286995 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionExponentialKernel.cu @@ -0,0 +1,16 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + + +namespace at::native { + +void exponential_kernel(TensorIteratorBase& iter, double lambda, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::exponential_kernel(iter, lambda, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(exponential_stub, &exponential_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionTemplates.h b/aten/src/ATen/native/zoom/DistributionTemplates.h new file mode 100644 index 00000000000000..24981a26aa817b --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionTemplates.h @@ -0,0 +1,671 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +// launch bounds used for kernels utilizing TensorIterator +const uint32_t 
block_size_bound = 256; +const uint32_t grid_size_bound = 4; +// number of randoms given by distributions like hiprand_uniform4, hiprand_uniform2_double +// used in calculating philox offset. +const uint32_t hiprand4_engine_calls = 4; + +// utility function that calculates proper philox_offset +// for distributions utilizing TensorIterator. For distributions using +// TensorIterator, we are using a grid-stride loop with each +// thread yielding one element per thread. For the edge of the grid-stride +// loop, if the tensor size is large, the unroll loop will kick in and the float4 +// from hiprand4 will start getting utilized (for common tensor sizes, we end up +// using rand.x from each thread). Hence, the philox_offset is +// (number of elements per thread * number of engine calls), which makes +// sure that philox offset increment is not less than the number of randoms used +// in each thread. +std::tuple calc_execution_policy(int64_t total_elements) { + const uint64_t numel = static_cast(total_elements); + const uint32_t block_size = block_size_bound; + const uint32_t unroll = hiprand4_engine_calls; + dim3 dim_block(block_size); + dim3 grid((numel + block_size - 1) / block_size); + uint32_t blocks_per_sm = at::zoom::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min( + static_cast(at::zoom::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm, + grid.x); + //number of times random will be generated per thread, to offset philox counter in thc random state + uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1) + * hiprand4_engine_calls; + return std::make_tuple(counter_offset, grid, dim_block); +} + +// grid stride loop kernel for distributions +template +C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) +__global__ void distribution_elementwise_grid_stride_kernel(int numel, + PhiloxHIPState philox_args, + const dist_t dist_func, + const transform_t transform_func) { + auto seeds = at::zoom::philox::unpack(philox_args); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + + int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) * + blockDim.x * gridDim.x * unroll_factor; + for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { + auto rand = dist_func(&state); + #pragma unroll + for (int ii = 0; ii < unroll_factor; ii++) { + int li = linear_index + blockDim.x * gridDim.x * ii; + if (li < numel) { + transform_func(li, static_cast((&rand.x)[ii])); + } + } + __syncthreads(); + } +} + +/** + * distribution_nullary_kernel is analogous to gpu_kernel in + * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses + * TensorIterator to launch a kernel. However, the differences are + * - it launches a grid-stride loop based kernel. The kernel is not + * generic like elementwise_kernel in Loops.cuh and is specialized + * for the distribution kernels here. + * - For big size tensors, we can launch multiple kernels recursively + * (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox + * offset calculation is done in this function. + * + * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh + * to have grid-stride loop kernel and then use that to launch our distribution + * kernels? 
Note that we need a grid-stride loop kernel because, we found by testing + * that it achieves peak effective bandwidth. + */ +template +void distribution_nullary_kernel(at::TensorIteratorBase& iter, + RNG gen, + const dist_t& dist_func, + const transform_t transform_func) { + static_assert(unroll_factor >= 1, "unroll_factor must be >= 1."); + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + auto execution_policy = calc_execution_policy(numel); + auto counter_offset = std::get<0>(execution_policy); + auto grid = std::get<1>(execution_policy); + auto block = std::get<2>(execution_policy); + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(counter_offset); + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_nullary_kernel(sub_iter, + gen, dist_func, transform_func); + } + return; + } + + char* out_data = (char*)iter.data_ptr(0); + + auto stream = c10::zoom::getCurrentZoomStream(); + if (iter.is_trivial_1d()) { + auto strides = iter.get_inner_strides(); + int stride0 = strides[0]; + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + scalar_t* out = (scalar_t*)&out_data[stride0 * idx]; + *out = transform_func(rand); + } + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + auto offset_calc = make_offset_calculator<1>(iter); + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + auto offsets = offset_calc.get(idx); + scalar_t* out = (scalar_t*)&out_data[offsets[0]]; + *out = transform_func(rand); + } + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +// Binary kernel +template +__global__ void distribution_binary_elementwise_kernel( + int numel, + func_t f, + PhiloxHIPState philox_args, + typename function_traits::result_type *output_data, + const typename function_traits::template arg<1>::type *input_data_1, + const typename function_traits::template arg<2>::type *input_data_2, + inp_offset_calc_t inp_calc, + out_offset_calc_t out_calc) { + auto seeds = at::zoom::philox::unpack(philox_args); + + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + + input_t_1 inputs_1[thread_work_size()]; + input_t_2 inputs_2[thread_work_size()]; + + int base_index = block_work_size() * blockIdx.x; + int remaining = std::min(numel - base_index, block_work_size()); + + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // load data into registers + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = inp_calc.get(input_idx); + inputs_1[i] = input_data_1[offsets[0]]; + inputs_2[i] = input_data_2[offsets[1]]; + + thread_idx += num_threads(); + } + + // compute and store + thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = out_calc.get(input_idx); + output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]); + thread_idx += num_threads(); + } +} + +template +void 
distribution_binary_kernel(TensorIteratorBase &iter, PhiloxHIPState philox_args, const func_t &f) { + static_assert(std::is_same::template arg<0>::type, hiprandStatePhilox4_32_10_t&>::value, "the first argument of functor must be hiprandStatePhilox4_32_10_t"); + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + using output_t = typename function_traits::result_type; + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_binary_kernel(sub_iter, philox_args, f); + } + return; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing()); + + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + output_t *output_data = static_cast(iter.data_ptr(0)); + const input_t_1 *input_data_1 = static_cast(iter.data_ptr(1)); + const input_t_2 *input_data_2 = static_cast(iter.data_ptr(2)); + + int64_t grid = (numel + block_work_size() - 1) / block_work_size(); + auto stream = c10::zoom::getCurrentZoomStream(); + + if (iter.is_contiguous()) { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter)); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace +}} // namespace at::native + + +namespace at { +namespace native { +namespace templates { +namespace zoom { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) { + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_zoom", AT_WRAP([&] { + if (( + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && range >= 1ULL << 32) + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = hiprand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { + return hiprand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_zoom", [&] { + if (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { + auto random_func = [] __device__ (uint64_t rand) { + 
return transformation::uniform_int_full_range(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = hiprand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_zoom handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_zoom", [&] { + if (std::is_same::value || std::is_same::value) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = hiprand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [] __device__ (uint32_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { + return hiprand4(state); + }, + random_func); + } + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, RNG gen) { + random_kernel(iter, gen); + } +}; + +// ==================================================================================================================== + +template +void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same::value) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_uniform2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_uniform4(state); }, + transform); + } +} + +template +void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same::value) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_normal2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_normal4(state); }, + transform); + } +} + +// ==================================================== Normal ======================================================== + +template +void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) { + auto iter = TensorIterator::borrowing_nullary_op(self); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_zoom", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda to multiply std and add mean + auto normal_func = [mean, std] __device__ (accscalar_t rand) { + 
return static_cast(transformation::normal(rand, mean, std)); + }; + normal_and_transform(iter, gen, normal_func); + }); +} + +template +struct NormalKernel { + void operator()(const TensorBase &self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================== + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_zoom", [&] { + auto from = static_cast(from_); + auto to = static_cast(to_); + using opmath_t = at::opmath_type; + auto range = static_cast(to-from); + // define lambda to reverse bounds, multiply 'range' and add 'from_' + auto uniform_func = [range, from, to] __device__ (opmath_t rand) { + // Compute output value before reversing the bounds + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947 + auto value = static_cast(rand * range + from); + // reverse the bounds of hiprand4 from (0, 1] to [0, 1) + // Note that this method is from legacy THCTensorRandom and is likely to give + // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and + // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s. + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706 + auto reverse_bound_value = value == to ? from : value; + return reverse_bound_value; + }; + uniform_and_transform(iter, gen, uniform_func); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_zoom", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda for log_normal transformation + auto log_normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::log_normal(transformation::normal(rand, mean, std))); + }; + normal_and_transform(iter, gen, log_normal_func); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_zoom", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for geometric transformation + auto geometric_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::geometric(rand, p)); + }; + uniform_and_transform(iter, gen, geometric_func); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// 
================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_zoom", [&] { + using accscalar_t = at::acc_type; + auto lambda = static_cast(lambda_); + // define lambda for exponential transformation + auto exponential_func = [lambda] __device__ (accscalar_t rand) { + return static_cast(transformation::exponential(rand, lambda)); + }; + uniform_and_transform(iter, gen, exponential_func); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_zoom", [&] { + using accscalar_t = at::acc_type; + auto median = static_cast(median_); + auto sigma = static_cast(sigma_); + // define lambda for cauchy transformation + auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) { + return static_cast(transformation::cauchy(rand, median, sigma)); + }; + uniform_and_transform(iter, gen, cauchy_func); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ==================================================== Bernoulli ===================================================== + +template +void bernoulli_tensor_zoom_kernel( + const TensorBase &ret, const at::TensorBase &p, + PhiloxHIPState philox_args) { + auto functor = [philox_args] __device__( + int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4, + const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) { + auto seeds = at::zoom::philox::unpack(philox_args); + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // See Note [Register spilling in curand call for CUDA < 10] + float4 rand = hiprand_uniform4(&state); + switch (n) { + case 4: { + ZOOM_KERNEL_ASSERT(0 <= p4 && p4 <= 1); + v4 = static_cast(rand.w <= p4); + // fallthrough + } + case 3: { + ZOOM_KERNEL_ASSERT(0 <= p3 && p3 <= 1); + v3 = static_cast(rand.z <= p3); + // fallthrough + } + case 2: { + ZOOM_KERNEL_ASSERT(0 <= p2 && p2 <= 1); + v2 = static_cast(rand.y <= p2); + // fallthrough + } + case 1: { + ZOOM_KERNEL_ASSERT(0 <= p1 && p1 <= 1); + v1 = static_cast(rand.x <= p1); + } + } + }; + // The template argument `4` below indicates that we want to operate on four + // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. 
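+ // Concretely, Zoom_tensor_apply2 hands each participating thread up to four (ret, p) element
+ // pairs together with the count `n` of valid pairs, so a single hiprand_uniform4 call yields
+ // one uniform sample per pair and the switch above falls through to write exactly `n` outputs.
+ // Roughly the following chunking (illustrative sketch, not the actual apply implementation):
+ //   for (int64_t base = first_elem_of_this_thread; base < numel; base += total_stride) {
+ //     int n = min<int64_t>(4, numel - base);
+ //     functor(n, v1, v2, v3, v4, p1, p2, p3, p4);  // only the first n of v1..v4 are written
+ //   }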
+ at::zoom::Zoom_tensor_apply2(ret, p, functor); +} + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) { + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(10); + } + TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type()); + // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else + const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat; + auto p_zoom = p_.to(TensorOptions().device(self.device()).dtype(p_type)); + auto p = expand_inplace(self, p_zoom); + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_zoom_self_", [&] { + if (std::is_same::value) { + return bernoulli_tensor_zoom_kernel(self, *p, rng_engine_inputs); + } else { + return bernoulli_tensor_zoom_kernel(self, *p, rng_engine_inputs); + } + }); +} + +template +void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_zoom_", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for bernoulli transformation + auto bernoulli_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::bernoulli(rand, p)); + }; + uniform_and_transform(iter, gen, bernoulli_func); + }); +} + +template +struct BernoulliKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + bernoulli_kernel(iter, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}}}} \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/IndexKernel.cu b/aten/src/ATen/native/zoom/IndexKernel.cu new file mode 100644 index 00000000000000..3df2d1bb120407 --- /dev/null +++ b/aten/src/ATen/native/zoom/IndexKernel.cu @@ -0,0 +1,463 @@ +#include +// #define TORCH_ASSERT_NO_OPERATORS +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +static constexpr int launch_bound2 = 4; + +static constexpr int launch_size_nd = 128; + +template +C10_LAUNCH_BOUNDS_2(nt, launch_bound2) +__global__ void index_elementwise_kernel(const int64_t N, const func_t f) { + const auto tid = threadIdx.x; + const auto nv = nt * vt; + auto idx = nv * blockIdx.x + tid; + #pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void launch_kernel(const int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + const dim3 block(nt); + const dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + const auto stream = c10::zoom::getCurrentZoomStream(); + hipLaunchKernelGGL(( index_elementwise_kernel), dim3(grid), dim3(block), 0, stream, N, f); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +template +void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const func_t& f) { + const auto num_indices = index_size.size(); + AT_ASSERT(num_indices == index_stride.size()); + 
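+  // Overview for readers: the TensorIterator carries (output, input, index_0..index_{n-1});
+  // for every output element the lambda below gathers the n index values, wraps negative
+  // entries, and folds them into a single byte offset
+  //   offset = sum_i index_i * index_stride[i]
+  // which `f` then applies to the raw out/in pointers. launch_kernel dispatches it with nt
+  // threads per block and vt elements per thread, i.e. grid = ceil(numel / (nt * vt)); assuming
+  // the constants above are the instantiation values (nt = launch_size_nd = 128,
+  // vt = launch_bound2 = 4), thread 0 of block 0 visits linear indices 0, 128, 256, 384.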
AT_ASSERT(static_cast(num_indices) == iter.ntensors() - 2); + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_index_kernel(sub_iter, index_size, index_stride, f); + } + return; + } + + auto sizes = at::detail::Array(0); + auto strides = at::detail::Array(0); + auto index_ptrs = at::detail::Array(nullptr); + for (unsigned i = 0; i < num_indices; i++) { + sizes[i] = index_size[i]; + strides[i] = index_stride[i]; + index_ptrs[i] = (char*)iter.data_ptr(i + 2); + } + + char* const out_ptr = static_cast(iter.data_ptr(0)); + char* const in_ptr = static_cast(iter.data_ptr(1)); + + auto offset_calc = make_offset_calculator<3>(iter); + launch_kernel(iter.numel(), [=]__device__(int idx) { + const auto offsets = offset_calc.get(idx); + char* const out_data = out_ptr + offsets[0]; + const char* const in_data = in_ptr + offsets[1]; + + int64_t offset = 0; + #pragma unroll + for (int i = 0; i < num_indices; i++) { + int64_t index = *reinterpret_cast(index_ptrs[i] + offsets[2]); + ZOOM_KERNEL_ASSERT(-sizes[i] <= index && index < sizes[i] && "index out of bounds"); + if (index < 0) { + index += sizes[i]; + } + offset += index * strides[i]; + } + + f(out_data, in_data, offset); + }); +} + +// The kernels are templated on an opaque, self-aligned type of the correct +// size to avoid redundant kernels for different types of the same size. +template struct alignas(N) OpaqueType { char data[N]; }; + +template +void index_fill_kernel_impl( + TensorIterator& iter, + const int64_t dim, + const int64_t self_dim_size, + const int64_t self_dim_stride, + const scalar_t fill_val) { + if (0 == iter.numel()) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + index_fill_kernel_impl(sub_iter, dim, self_dim_size, self_dim_stride, fill_val); + } + return; + } + + char* const __restrict__ self_ptr = reinterpret_cast(iter.data_ptr(0)); + char* const __restrict__ idx_ptr = reinterpret_cast(iter.data_ptr(1)); + + const auto offset_calc = make_offset_calculator<2>(iter); + + const auto loop = [=]C10_DEVICE(int i) { + const auto offsets = offset_calc.get(i); + + auto* __restrict__ self_data = reinterpret_cast(self_ptr + offsets[0]); + auto idx = *reinterpret_cast(idx_ptr + offsets[1]); + ZOOM_KERNEL_ASSERT(idx >= -self_dim_size && idx < self_dim_size && "index out of bounds"); + if (idx < 0) { + idx += self_dim_size; + } + + self_data[idx * self_dim_stride] = fill_val; + }; + launch_kernel(iter.numel(), loop); +} + +template +void index_copy_kernel_impl( + TensorIterator& iter, + const int64_t dim, + const int64_t self_dim_size, + const int64_t self_dim_stride) { + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + index_copy_kernel_impl(sub_iter, dim, self_dim_size, self_dim_stride); + } + return; + } + + char* const __restrict__ self_ptr = reinterpret_cast(iter.data_ptr(0)); + char* const __restrict__ idx_ptr = reinterpret_cast(iter.data_ptr(1)); + char* const __restrict__ source_ptr = reinterpret_cast(iter.data_ptr(2)); + + const auto offset_calc = make_offset_calculator<3>(iter); + + const auto loop = [=]C10_DEVICE(int i) { + const auto offsets = offset_calc.get(i); + + auto* const __restrict__ self_data = reinterpret_cast(self_ptr + offsets[0]); + auto idx = *reinterpret_cast(idx_ptr + offsets[1]); + const auto* const __restrict__ source_data = reinterpret_cast(source_ptr + offsets[2]); 
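+    // Unlike gpu_index_kernel and index_fill_kernel_impl above, negative
+    // indices are not wrapped here; anything outside [0, self_dim_size)
+    // trips the device assert below.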
+ ZOOM_KERNEL_ASSERT(idx >= 0 && idx < self_dim_size && "index_copy_(): index out of bounds"); + + self_data[idx * self_dim_stride] = *source_data; + }; + launch_kernel(iter.numel(), loop); +} + +template +void index_kernel_impl(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) { + gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) { + *reinterpret_cast(out_data) = *reinterpret_cast(in_data + offset); + }); +} + +template +void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride) { + gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) { + *reinterpret_cast(out_data + offset) = *reinterpret_cast(in_data); + }); +} + +static void index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_zoom", [&] { + using dtype = OpaqueType; + index_kernel_impl(iter, index_size, index_stride); + }); +} + +static void index_fill_kernel( + TensorIterator& iter, + const int64_t dim, + const int64_t self_dim_size, + const int64_t self_dim_stride, + const Scalar& source) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf, + iter.dtype(), "index_fill_zoom", [&] { + using dtype = OpaqueType; + const auto fill_val = source.to(); + const auto fill_val_opaque = *reinterpret_cast(&fill_val); + index_fill_kernel_impl(iter, dim, self_dim_size, self_dim_stride, fill_val_opaque); + }); +} + +static void index_copy_kernel( + TensorIterator& iter, + const int64_t dim, + const int64_t self_dim_size, + const int64_t self_dim_stride) { + // See note [Writing Nondeterministic Operations] + // Nondeterministic when index contains duplicate entries + // this kernel will not be called when torch.use_deterministic_algorithms(True) + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf, + iter.dtype(), "index_copy_zoom", [&] { + using dtype = OpaqueType; + index_copy_kernel_impl(iter, dim, self_dim_size, self_dim_stride); + }); +} + +static void index_put_kernel(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const bool accumulate) { + TORCH_CHECK(!accumulate, "index_put does not support accumulate=true"); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_put", [&] { + using dtype = OpaqueType; + index_put_kernel_impl(iter, index_size, index_stride); + }); +} + +void index_put_kernel_quantized_zoom(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const bool accumulate, const double scale, const int zero_point) { + TORCH_CHECK(!accumulate, "index_put does not support accumulate=true"); + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "index_put", [&] { + constexpr int64_t qmin = std::numeric_limits::min(); + constexpr int64_t qmax = std::numeric_limits::max(); + const float inv_scale = 1.0f / static_cast(scale); + + gpu_index_kernel(iter, index_size, index_stride, [inv_scale, zero_point, qmin, qmax]C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) { + int64_t qvalue = static_cast(zero_point + nearbyintf(*(float*)in_data * inv_scale)); + 
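+      // Affine quantization of the gathered float value, roughly
+      //   q = clamp(round(x / scale) + zero_point, qmin, qmax),
+      // with the division folded into the precomputed inv_scale.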
qvalue = std::clamp(qvalue, qmin, qmax); + *(scalar_t*)(out_data + offset) = static_cast(qvalue); + }); + }); +} + +template +void zoom_take_put_kernel( + TensorIterator& iter, + const TensorBase& indexed, + const func_t& f) { + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + zoom_take_put_kernel(sub_iter, indexed, f); + } + return; + } + + const auto numel = indexed.numel(); + const bool is_contiguous = indexed.is_contiguous(); + + char* const __restrict__ iterated_ptr = reinterpret_cast(iter.data_ptr(0)); + char* const __restrict__ idx_ptr = reinterpret_cast(iter.data_ptr(1)); + + const auto offset_calc = make_offset_calculator<2>(iter); + using uindex_t = std::make_unsigned_t; + + // OffsetCalculator needs the sizes and strides reveresed + const auto indexed_sizes = std::vector(indexed.sizes().rbegin(), indexed.sizes().rend()); + const auto indexed_strides = std::vector(indexed.strides().rbegin(), indexed.strides().rend()); + const auto* indexed_strides_data = indexed_strides.data(); + const auto offset_indexed = OffsetCalculator<1, uindex_t>(indexed.dim(), + indexed_sizes.data(), + &indexed_strides_data); + + const auto loop = [=]C10_DEVICE(int i) { + const auto offsets = offset_calc.get(i); + + auto& iterated = *reinterpret_cast(iterated_ptr + offsets[0]); + const auto idx = *reinterpret_cast(idx_ptr + offsets[1]); + ZOOM_KERNEL_ASSERT(idx < numel && idx >= -numel && "zoom_take_put_kernel() index out of bounds"); + index_t offset = static_cast(idx); + if (offset < 0) { + offset += numel; + } + if (!is_contiguous) { + offset = offset_indexed.get(offset)[0]; + } + + f(iterated, offset); + }; + launch_kernel(iter.numel(), loop); +} + +void put_kernel(TensorIterator& iter, const TensorBase& output, const bool accumulate) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "put_zoom", [&] { + // Cannot use `OpaqueType`, as we need the actual type for `fastSpecializedgpuAtomicAdd` + AT_DISPATCH_INDEX_TYPES(zoom::detail::canUse32BitIndexMath(output) ? ScalarType::Int : ScalarType::Long, + "put_zoom_index", [&] { + auto* __restrict__ indexed_ptr = output.template data_ptr(); + if (accumulate) { + index_t numel = output.numel(); + zoom_take_put_kernel(iter, output, + [numel, indexed_ptr] __device__(scalar_t& iterated, const index_t offset) { + fastSpecializedAtomicAdd(indexed_ptr, offset, numel, iterated); + }); + } + else { + zoom_take_put_kernel(iter, output, + [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) { + indexed_ptr[offset] = iterated; + }); + } + }); + }); +} + +void take_kernel( + TensorIterator& iter, + const TensorBase& input) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "take_zoom", [&] { + // Cannot use `OpaqueType`, as Tensor::data_ptr> is not implemented + AT_DISPATCH_INDEX_TYPES(zoom::detail::canUse32BitIndexMath(input) ? 
ScalarType::Int : ScalarType::Long, + "take_zoom_index", [&] { + const auto* __restrict__ indexed_ptr = input.template const_data_ptr(); + zoom_take_put_kernel(iter, input, + [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) { + iterated = indexed_ptr[offset]; + }); + }); + }); +} + +namespace { + +__global__ void masked_scatter_size_check( + const int64_t* const mask_exclusive_sum, + const bool* const mask, + const int64_t srcSize) { + // Convert exclusive sum to inclusive sum + const auto totalElements = *mask_exclusive_sum + *mask; + ZOOM_KERNEL_ASSERT(totalElements <= srcSize); +} + +} // anonymous namespace + +void launch_masked_scatter_kernel( + const TensorBase &self, const TensorBase &mask, + const TensorBase &maskPrefixSum, const TensorBase &source) { + const auto srcSize = source.numel(); + const auto mask_cont = mask.contiguous(); + const auto mask_numel = mask.numel(); + + // Use a prefix sum to determine the output locations of the masked elements + auto maskPrefixSum_data = maskPrefixSum.mutable_data_ptr(); + auto mask_data = mask_cont.const_data_ptr(); + + at::zoom::hipcub::mask_exclusive_sum( + mask_data, maskPrefixSum_data, mask_numel); + + // Asynchronously check that the number of `1` elements present in the mask + // must be <= the number of elements available in `src`. + hipLaunchKernelGGL(( masked_scatter_size_check), dim3(1), dim3(1), 0, c10::zoom::getCurrentZoomStream(), + &maskPrefixSum_data[mask_numel - 1], &mask_data[mask_numel - 1], srcSize); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + // We are getting elements from `src` based on an offset from + // `maskPrefixSum`, so that should be made contiguous too + auto source_contig = source.contiguous(); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self) + .add_input(self) + .add_const_input(mask_cont) + .add_input(maskPrefixSum) + .build(); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + ScalarType::Bool, + ScalarType::BFloat16, + ScalarType::Half, + self.scalar_type(), + "masked_scatter_", + [&]() { + auto source_ptr = source_contig.const_data_ptr(); + gpu_kernel( + iter, [=] GPU_LAMBDA(const scalar_t a, const bool mask, const int64_t maskPrefixSum) -> scalar_t { + if (mask) { + return source_ptr[maskPrefixSum]; + } + return a; + }); + C10_ZOOM_CHECK(hipGetLastError()); + }); +} + +template +void flip_kernel_impl(TensorIterator& iter) { + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + flip_kernel_impl(sub_iter); + } + return; + } + + char* const __restrict__ out_ptr = reinterpret_cast(iter.data_ptr(0)); + const char* const __restrict__ in_ptr = reinterpret_cast(iter.data_ptr(1)); + + const auto offset_calc = make_offset_calculator<2, /*signed_strides=*/true>(iter); + + const auto loop = [=]C10_DEVICE(const int i) { + const auto offsets = offset_calc.get(i); + // offsets can be negative here, but it's fine + scalar_t* const __restrict__ out_data = reinterpret_cast(out_ptr + offsets[0]); + const scalar_t* const __restrict__ in_data = reinterpret_cast(in_ptr + offsets[1]); + *out_data = *in_data; + }; + launch_kernel(iter.numel(), loop); +} + +void flip_kernel(TensorIterator& iter, const bool quantized) { + if (quantized) { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_zoom", + [&] { + using dtype = OpaqueType; + flip_kernel_impl(iter); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, 
at::ScalarType::BFloat16, + iter.dtype(), "flip_zoom", + [&] { + using dtype = OpaqueType; + flip_kernel_impl(iter); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(index_stub, &index_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(index_fill_stub, &index_fill_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(index_copy_stub, &index_copy_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(index_put_stub, &index_put_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(put_stub, &put_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(take_stub, &take_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(flip_stub, &flip_kernel); + +REGISTER_PRIVATEUSE1_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_zoom); + + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/IndexKernel.h b/aten/src/ATen/native/zoom/IndexKernel.h new file mode 100644 index 00000000000000..edd9190deb0dba --- /dev/null +++ b/aten/src/ATen/native/zoom/IndexKernel.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at { +namespace native { +/// @param maskPrefixSum[in,out] +void launch_masked_scatter_kernel( + const TensorBase &self, const TensorBase &mask, + const TensorBase &maskPrefixSum, const TensorBase &source); +}} diff --git a/aten/src/ATen/native/zoom/LaunchUtils.h b/aten/src/ATen/native/zoom/LaunchUtils.h new file mode 100644 index 00000000000000..4d2f35a56a5837 --- /dev/null +++ b/aten/src/ATen/native/zoom/LaunchUtils.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +namespace native { + +// returns 2**floor(log2(n)) +static int lastPow2(unsigned int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +} // namespace native +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/MultinomialKernel.cu b/aten/src/ATen/native/zoom/MultinomialKernel.cu new file mode 100644 index 00000000000000..ca9709637cf030 --- /dev/null +++ b/aten/src/ATen/native/zoom/MultinomialKernel.cu @@ -0,0 +1,462 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include +#include +#include + +namespace at::native { + +namespace { + +template < + typename T, + typename = std::enable_if_t< + std::is_floating_point_v || std::is_convertible_v>> +inline __device__ bool _isinf(T x) { + if constexpr (std::is_floating_point_v) { + return ::isinf(x); + } else { + return ::isinf(static_cast(x)); + } +} + +#define MAX_NUM_BLOCKS 200 + +// Normalizes the L1 norm of every row to 1; used by multinomial +template +C10_LAUNCH_BOUNDS_1(zoom::detail::HIP_NUM_THREADS) +__global__ void renormRowsL1(scalar_t* dist, long rows, long cols) { + extern __shared__ unsigned char my_smem[]; + scalar_t *smem = reinterpret_cast(my_smem); + scalar_t zero = static_cast(0); + scalar_t val; + for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { + scalar_t sum = static_cast(0); + for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) { + val = dist[row * cols + col]; + ZOOM_KERNEL_ASSERT(!(val < zero)); // ! < 0 for NaN handling + sum = sum + val; + } + + sum = zoom_utils::BlockReduceSum(sum, smem); + if (threadIdx.x == 0) { + ZOOM_KERNEL_ASSERT(!(val < zero)); // ! 
< 0 for NaN handling + smem[0] = sum; + } + __syncthreads(); + + sum = smem[0]; + if (sum > zero) { + for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) { + dist[row * cols + col] = dist[row * cols + col] / sum; + } + } + } +} + +void renormRows(Tensor& t) { + TORCH_CHECK(t.dim() == 2); + int64_t rows = t.size(0); + int64_t cols = t.size(1); + + auto props = at::zoom::getCurrentDeviceProperties(); + TORCH_CHECK(props != nullptr); + int numSM = props->multiProcessorCount; + const int64_t maxThreads = std::min( + props->maxThreadsPerBlock, zoom_utils::kHIPBlockReduceMaxThreads); + + int warp_size = at::zoom::warp_size(); + dim3 grid(rows < numSM * 4 ? rows : numSM * 4); + dim3 block(std::min(maxThreads, warp_size * ceil_div(cols, int64_t{warp_size}))); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, t.scalar_type(), "renormRows_zoom", [&] { + renormRowsL1 + <<>>(t.mutable_data_ptr(), + rows, cols); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +template +__device__ int binarySearchForMultinomial(const scalar_t* cumdist, + const scalar_t* dist, + int size, + scalar_t val) { + int start = 0; + int end = size; + // cumdist[size - 1] = 0 => all zero prob dist + ZOOM_KERNEL_ASSERT(cumdist[size - 1] > static_cast(0)); + + while (end - start > 0) { + int mid = start + (end - start) / 2; + + scalar_t midVal = cumdist[mid]; + if (midVal < val) { + start = mid + 1; + } else { + end = mid; + } + } + + if (start == size) { + // No probability mass or precision problems; just return the + // first non-zero element by setting start to size-1 here, + // the code below will move it to the last non-zero probability + // this actually can happen when the random number is 1 + // (github pytorch issue #4858). + start = size - 1; + } + + while(start >= 1 && dist[start] == 0) start--; + + return start; +} + +template +__global__ void +sampleMultinomialWithReplacement(PhiloxHIPState philox_args, + int totalSamples, + int64_t* dest, + int64_t distributions, + int categories, + const scalar_t* normDistPrefixSum, + const scalar_t* normDist) { + // At the moment, each warp computes one sample value in the binary + // search due to divergence. It seems possible to compute multiple + // values and limit divergence though later on. 
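+  // Sampling scheme, in short: for every requested sample draw
+  // u ~ Uniform(0, 1], then binary-search the row's normalized prefix sum
+  // (its CDF) for the first bucket whose cumulative mass reaches u; that
+  // bucket index is the sampled category. Philox gets a distinct subsequence
+  // per thread index so the per-thread streams do not overlap.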
+ + auto seeds = at::zoom::philox::unpack(philox_args); + + // global index formula for 2D grid of 1D blocks + int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x; + + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + + // The block determines the distribution for which we generate a point + for (int64_t curDist = blockIdx.y; + curDist < distributions; + curDist += gridDim.y) { + for (int sample = blockIdx.x*blockDim.x + threadIdx.x; + sample < totalSamples; sample += blockDim.x*gridDim.x) { + + //we are losing 3 out of 4 generated numbers but it's ok + //this kernel is not very efficient anyway + auto rand = hiprand_uniform4(&state); + scalar_t r = static_cast(rand.x); + + // Find the bucket that a uniform sample lies in + int choice = binarySearchForMultinomial( + normDistPrefixSum + curDist * categories, + normDist + curDist * categories, + categories, + r); + + dest[curDist * totalSamples + sample] = choice; + + } + } +} + +template +C10_LAUNCH_BOUNDS_1(zoom::detail::HIP_NUM_THREADS) +__global__ void sampleMultinomialOnce( + int64_t* dest, + int64_t distributions, + int categories, + const scalar_t* sampled, + const scalar_t* dist, + int stride_dist, // dist->stride(0) + int stride_categories // dist->stride(1) +) { + extern __shared__ unsigned char my_smem[]; + __shared__ bool found; + __shared__ unsigned foundPos; + + accscalar_t *smem = reinterpret_cast(my_smem); + + accscalar_t accZero = static_cast(0); + scalar_t zero = static_cast(0); + + for (int64_t curDist = blockIdx.x; + curDist < distributions; curDist += gridDim.x) { + // Each block handles one distribution + // First pass, find the total sum of the distribution + accscalar_t sum = accZero; + scalar_t val; + for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) { + val = dist[curDist * stride_dist + cat * stride_categories]; + ZOOM_KERNEL_ASSERT(!at::_isnan(val)); + ZOOM_KERNEL_ASSERT(!_isinf(val)); + ZOOM_KERNEL_ASSERT(!(val < zero)); + sum = sum + static_cast(val); + } + + // threadIdx.x == 0 has the sum value from this + sum = zoom_utils::BlockReduceSum(sum, smem); + + // Broadcast sum and sample value + if (threadIdx.x == 0) { + // Make sure the sum of our distribution didn't overflow + ZOOM_KERNEL_ASSERT(!_isinf(val)); + ZOOM_KERNEL_ASSERT(sum > accZero); + + foundPos = 0; + smem[0] = sum; + smem[1] = sampled[curDist]; + } + __syncthreads(); + + sum = smem[0]; + scalar_t sample = static_cast(smem[1]); + __syncthreads(); + + if (sum == accZero) { + // Choose the first element + if (threadIdx.x == 0) { + dest[curDist] = 0; + } + + continue; + } + + int chunks = (categories + (int)blockDim.x - 1) / blockDim.x; + accscalar_t prevHighProb = accZero; + found = false; + + for (int chunk = 0; chunk < chunks && !found; ++chunk) { + // All threads in bounds load a value + int cat = chunk * blockDim.x + threadIdx.x; + + accscalar_t dist_val = cat < categories ? 
+ static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : + accZero; + + smem[threadIdx.x] = dist_val; + __syncthreads(); + + // Perform an inclusive prefix sum of the shared memory contents + for (int offset = 1; offset < blockDim.x; offset *= 2) { + accscalar_t val = accZero; + + if (threadIdx.x >= offset) { + val = smem[threadIdx.x - offset] + smem[threadIdx.x]; + } + + __syncthreads(); + if (threadIdx.x >= offset) { + smem[threadIdx.x] = val; + } + __syncthreads(); + } + + // Each thread will check to see if the sample falls in its + // bucket + scalar_t curBucket = + static_cast(smem[threadIdx.x] + prevHighProb); + scalar_t prevBucket = static_cast( + threadIdx.x == 0 ? prevHighProb + : smem[threadIdx.x - 1] + prevHighProb); + bool inBucket = + (cat < categories) && + (!(sample >= curBucket) && + (sample >= prevBucket) && + (dist_val > zero)); + + if (inBucket) { + // We're done; we have the sample + // Torch indices are 1-based + atomicMax(&foundPos, cat); + found = true; + } + + // Store the previous scan's high value for future use + prevHighProb = prevHighProb + smem[blockDim.x - 1]; + + __syncthreads(); + } + + if (threadIdx.x == 0) { + if (found) { + dest[curDist] = foundPos; + } else { + // This should address a rare bug where we don't select a valid index. This likely occurs when + // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but + // and our uniform sample is greater than this value. In this case we likely have unitialized memory + // in dest[curDist]. So basically we will loop through the distribution and pick the largest index + // where the distribution is non-zero. This is obviously terribly inefficient, but due to the + // rarity in which this occurs, this should not be an issue. + for (int cat = categories - 1; cat >= 0; --cat) { + if (dist[curDist * stride_dist + cat * stride_categories] > zero) { + dest[curDist] = cat; + break; + } + } + } + } + } +} + +void multinomial_with_replacement_kernel_impl( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + std::optional generator) { + auto gen = get_generator_or_default(generator, zoom::detail::getDefaultZoomGenerator()); + + int inputSize = self.dim(); + int64_t numDist = + inputSize == 1 ? 1 : self.size(0); + int numCategories = + inputSize == 1 ? self.size(0) : self.size(1); + + // Restructure data for 2d + auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self; + + result.resize_({numDist, n_sample}); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self_v.scalar_type(), "multinomial_kernel_zoom", [&] { + using accscalar_t = at::acc_type; + auto props = at::zoom::getCurrentDeviceProperties(); + TORCH_CHECK(props != nullptr); + int numSM = props->multiProcessorCount; + int maxThreads = props->maxThreadsPerBlock; + int maxShared = props->sharedMemPerBlock; + + int warp_size = at::zoom::warp_size(); + int requiredWarps = at::ceil_div(numCategories, warp_size); + int requiredThreads = std::min(maxThreads, requiredWarps * warp_size); + int requiredShared = requiredThreads * sizeof(accscalar_t); + + if (n_sample == 1 && maxShared >= requiredShared) { + // Optimized allocation-free implementation + // To exploit greater parallelism for the sampling, generate the + // Uniform random samples in a separate kernel launch, into + // temporarily allocated memory. 
The device RNG is thread-limited + Tensor sampled = at::detail::empty_zoom({numDist, n_sample}, self_v.options()); + at::native::uniform_(sampled, 0.0, 1.0, generator); + + dim3 block(requiredThreads); + dim3 grid(std::min(static_cast(numDist), numSM * 4)); + + sampleMultinomialOnce + <<>>( + result.mutable_data_ptr(), + numDist, + numCategories, + sampled.const_data_ptr(), + self_v.const_data_ptr(), + self_v.stride(0), + self_v.stride(1) + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + // Generic, slow implementation with memory allocations + + // For sampling without replacement, we modify the distribution + // for subsequent samples in this space + Tensor origDist = native::empty_like( + self_v, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + origDist.copy_(self_v); + + Tensor normDist = native::empty_like( + self_v, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + Tensor prefixSum = native::empty_like( + self_v, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + // Renorm along rows + normDist.copy_(origDist); + renormRows(normDist); + + // Prefix sum along rows + at::privateuse1::cumsum_out(prefixSum, normDist, 1); + + PhiloxHIPState rng_engine_inputs; + + // Binary search is warp divergent (so effectively we're running + // with just a single thread), but for better utilization, + // we need each block to have at least 4 warps. + dim3 block(128); + + // Each block will generate a sample from one + // distribution concurrently. + int grid_y=std::min(numDist, at::zoom::getCurrentDeviceProperties()->maxGridSize[1]); + dim3 grid((n_sample-1)/block.x+1, grid_y); + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + + // each thread generates a single sample for (numdist/numblocks.y) distributions, however, since we have to use + // curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]), + // offset is 4 times that. 
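+      // i.e. offset = ceil(numDist / grid.y) * 4. For example, numDist = 10
+      // with grid.y = 4 gives at most 3 distributions per thread, so the
+      // generator state is advanced by 3 * 4 = 12 Philox increments per launch.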
+ auto offset = ((numDist-1)/grid.y+1)*4; + rng_engine_inputs = gen->philox_hip_state(offset); + } + // Sample with replacement + + sampleMultinomialWithReplacement + <<>>( + rng_engine_inputs, + n_sample, + result.mutable_data_ptr(), + numDist, numCategories, + prefixSum.const_data_ptr(), + normDist.const_data_ptr()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); + + if (inputSize == 1) { + result.resize_({n_sample}); + } +} +} + +REGISTER_PRIVATEUSE1_DISPATCH( + multinomial_with_replacement_stub, + &multinomial_with_replacement_kernel_impl); +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/ReduceAMinMaxKernel.cu b/aten/src/ATen/native/zoom/ReduceAMinMaxKernel.cu new file mode 100644 index 00000000000000..e31cfc186de1ba --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceAMinMaxKernel.cu @@ -0,0 +1,45 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +template +void _min_max_values_kernel_zoom_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + MinMaxOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), + at::numeric_limits::lower_bound())); +} + +void aminmax_allreduce_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_zoom", [&] { + _min_max_values_kernel_zoom_impl(iter); + }); +} + +void aminmax_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_zoom", [&]() { + gpu_reduce_kernel( + iter, + MinMaxOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), + at::numeric_limits::lower_bound())); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceArgMaxKernel.cu b/aten/src/ATen/native/zoom/ReduceArgMaxKernel.cu new file mode 100644 index 00000000000000..b5f526ebff9ef2 --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceArgMaxKernel.cu @@ -0,0 +1,46 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native { + +template +void argmax_kernel_zoom_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + ArgMaxOps{}, + thrust::pair( + at::numeric_limits::lower_bound(), 0)); +}; + +void argmax_kernel_zoom(TensorIterator& iter) { + // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, + // we can convert float16 & bfloat16 to float and do all the operations in + // float. 
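+  // Note: iter.dtype(1) is the input dtype; the argmax output itself is
+  // always an int64 index, so the dispatch keys on the input dtype rather
+  // than on iter.dtype(0).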
+ if (iter.dtype(1) == kHalf) { + argmax_kernel_zoom_impl(iter); + } else if (iter.dtype(1) == kBFloat16) { + argmax_kernel_zoom_impl(iter); + } else { + AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_zoom", [&]() { + argmax_kernel_zoom_impl(iter); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(argmax_stub, &argmax_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceArgMinKernel.cu b/aten/src/ATen/native/zoom/ReduceArgMinKernel.cu new file mode 100644 index 00000000000000..5007d0abeeca3f --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceArgMinKernel.cu @@ -0,0 +1,46 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native { + +template +void argmin_kernel_zoom_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + ArgMinOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), 0)); +}; + +void argmin_kernel_zoom(TensorIterator& iter) { + // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, + // we can convert float16 & bfloat16 to float and do all the operations in + // float. + if (iter.dtype(1) == kHalf) { + argmin_kernel_zoom_impl(iter); + } else if (iter.dtype(1) == kBFloat16) { + argmin_kernel_zoom_impl(iter); + } else { + AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_zoom", [&]() { + argmin_kernel_zoom_impl(iter); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(argmin_stub, &argmin_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceLogicKernel.cu b/aten/src/ATen/native/zoom/ReduceLogicKernel.cu index fb6bb731781358..fafe22cc4b1fd3 100644 --- a/aten/src/ATen/native/zoom/ReduceLogicKernel.cu +++ b/aten/src/ATen/native/zoom/ReduceLogicKernel.cu @@ -35,4 +35,4 @@ void or_kernel_zoom(TensorIterator& iter) { REGISTER_PRIVATEUSE1_DISPATCH(and_stub, &and_kernel_zoom); REGISTER_PRIVATEUSE1_DISPATCH(or_stub, &or_kernel_zoom); -} // namespace at::native \ No newline at end of file +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceMaxValuesKernel.cu b/aten/src/ATen/native/zoom/ReduceMaxValuesKernel.cu new file mode 100644 index 00000000000000..7da6d4a0e0855b --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceMaxValuesKernel.cu @@ -0,0 +1,61 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native { + +template +struct MaxNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (at::_isnan(a) || a > b) ? 
a : b; + } +}; + +template +void max_values_kernel_zoom_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + func_wrapper(MaxNanFunctor()), + at::numeric_limits::lower_bound()); +} + +void max_values_kernel_zoom(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.dtype(), "max_values_zoom", [&]() { + max_values_kernel_zoom_impl(iter); + }); +} + +void max_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "max_zoom", [&]() { + gpu_reduce_kernel( + iter, + MaxOps{}, + thrust::pair( + at::numeric_limits::lower_bound(), 0)); + }); +} + +void max_all_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_zoom", [&] { + max_values_kernel_zoom_impl(iter); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(max_values_stub, &max_values_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceMinValuesKernel.cu b/aten/src/ATen/native/zoom/ReduceMinValuesKernel.cu new file mode 100644 index 00000000000000..e5acf1a000207c --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceMinValuesKernel.cu @@ -0,0 +1,58 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +namespace at::native { + +template +struct MinNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (at::_isnan(a) || a < b) ? a : b; + } +}; + +template +void min_values_kernel_zoom_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, func_wrapper (MinNanFunctor()), + at::numeric_limits::upper_bound()); +} + +void min_values_kernel_zoom(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_zoom", [&]() { + min_values_kernel_zoom_impl(iter); + }); +} + +void min_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_zoom", [&]() { + gpu_reduce_kernel( + iter, + MinOps{}, + thrust::pair(at::numeric_limits::upper_bound(), 0)); + }); +} + +void min_all_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_zoom", [&] { + min_values_kernel_zoom_impl(iter); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(min_values_stub, &min_values_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceNormKernel.cu b/aten/src/ATen/native/zoom/ReduceNormKernel.cu new file mode 100644 index 00000000000000..71b1db37634943 --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceNormKernel.cu @@ -0,0 +1,51 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// This reduction accumulates results as the type `acc_t`. By default, when +// `scalar_t` is complex, `acc_t` is the downgraded real number type. +// Otherwise, `acc_t` and `scalar_t` are the same type. 
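+// For a general p the reduction computes (sum_i |x_i|^p)^(1/p); p = 0 counts
+// non-zero entries and p = +/-inf reduce to max/min over |x_i|. The dedicated
+// NormZeroOps/NormOneOps/NormTwoOps branches below avoid the generic pow()
+// path for the common p = 0, 1, 2 cases.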
+template ::type, typename out_t=typename scalar_value_type::type> +void norm_kernel_zoom_impl(TensorIterator& iter, double p) { + if (p == static_cast(0)) { + gpu_reduce_kernel(iter, NormZeroOps(), 0); + } else if (p == static_cast(1)) { + gpu_reduce_kernel(iter, NormOneOps(), 0); + } else if (p == static_cast(2)) { + gpu_reduce_kernel(iter, NormTwoOps(), 0); + } else if (p == static_cast(INFINITY)) { + gpu_reduce_kernel(iter, AbsMaxOps(), 0); + } else if (p == static_cast(-INFINITY)) { + gpu_reduce_kernel(iter, AbsMinOps(), std::numeric_limits::infinity()); + } else { + gpu_reduce_kernel(iter, NormOps{acc_t(p)}, 0); + } +} + +void norm_launch_kernel(TensorIterator& iter, double ord) { + if (iter.dtype(0) == kHalf) { + return norm_kernel_zoom_impl(iter, ord); + } else if (iter.input_dtype() == kHalf && iter.dtype(0) == kFloat) { + // type promotion that does cast and reduction in a single kernel + return norm_kernel_zoom_impl(iter, ord); + } + else if(iter.dtype(0) == kBFloat16) { + return norm_kernel_zoom_impl(iter, ord); + } else if (iter.input_dtype() == kBFloat16 && iter.dtype(0) == kFloat) { + // type promotion that does cast and reduction in a single kernel + return norm_kernel_zoom_impl(iter, ord); + } + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_zoom", [&] { + norm_kernel_zoom_impl(iter, ord); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceOps.cpp b/aten/src/ATen/native/zoom/ReduceOps.cpp new file mode 100644 index 00000000000000..c57c4303ccea69 --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceOps.cpp @@ -0,0 +1,102 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +namespace at::native { +namespace { + +void norm_kernel_zoom(TensorIterator& iter, const Scalar& val) { + double p; + if (val.isIntegral(false)) { + p = val.to(); + } else if (val.isFloatingPoint()) { + p = val.to(); + } else { + TORCH_CHECK(false, "norm_kernel_zoom_impl expects norm to be integer or float"); + } + if (iter.numel() == 0) { + iter.output().fill_((p < 0) ? 
INFINITY : 0); + return; + } + + norm_launch_kernel(iter, p); + + if (isComplexType(iter.output().scalar_type())) { + at::imag(iter.output()).zero_(); + } + +} + +void min_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) { + auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong); + min_launch_kernel(iter); +} + +void max_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) { + auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong); + max_launch_kernel(iter); +} + +void aminmax_kernel_impl( + const Tensor& self, int64_t dim, bool keepdim, Tensor& min_result, Tensor& max_result) { + at::TensorIterator iter = make_reduction("aminmax_zoom", min_result, + max_result, self, dim, keepdim, self.scalar_type()); + if (iter.numel() != 0) { + aminmax_launch_kernel(iter); + } +} + +void min_all_kernel_impl(Tensor& result, const Tensor& input) { + auto dtype = input.scalar_type(); + auto iter = make_reduction("min_all", result, input, IntArrayRef{}, false, dtype); + min_all_launch_kernel(iter); +} + +void max_all_kernel_impl(Tensor& result, const Tensor& input) { + auto dtype = input.scalar_type(); + auto iter = make_reduction("max_all", result, input, IntArrayRef{}, false, dtype); + max_all_launch_kernel(iter); +} + +void aminmax_allreduce_kernel_impl(const Tensor& input, Tensor& min_result, Tensor& max_result) { + auto dtype = input.scalar_type(); + auto iter = make_reduction("aminmax_zoom", min_result, max_result, input, + IntArrayRef{}, false, dtype); + TORCH_CHECK(iter.numel() > 0, "min_max on a tensor with no elements is not defined."); + aminmax_allreduce_launch_kernel(iter); +} + +} // namespace (anonymous) + +REGISTER_PRIVATEUSE1_DISPATCH(min_stub, &min_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(max_stub, &max_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(min_all_stub, &min_all_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(max_all_stub, &max_all_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(aminmax_stub, &aminmax_kernel_impl); + +REGISTER_PRIVATEUSE1_DISPATCH(norm_stub, &norm_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceOps.h b/aten/src/ATen/native/zoom/ReduceOps.h new file mode 100644 index 00000000000000..a67a019ae49e2e --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceOps.h @@ -0,0 +1,20 @@ + +namespace at { +struct TensorIterator; +} + +namespace c10 { +class Scalar; +} + +namespace at { namespace native { + +void norm_launch_kernel(TensorIterator &iter, double val); +void min_launch_kernel(TensorIterator &iter); +void max_launch_kernel(TensorIterator &iter); +void aminmax_launch_kernel(TensorIterator &iter); +void min_all_launch_kernel(TensorIterator &iter); +void max_all_launch_kernel(TensorIterator &iter); +void aminmax_allreduce_launch_kernel(TensorIterator &iter); + +}} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ReduceSumProdKernel.cu b/aten/src/ATen/native/zoom/ReduceSumProdKernel.cu new file mode 100644 index 00000000000000..815dfb35ac0252 --- /dev/null +++ b/aten/src/ATen/native/zoom/ReduceSumProdKernel.cu @@ -0,0 +1,215 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +// #include +#include + +namespace at::native { + +template +struct sum_functor { + void operator()(TensorIterator& iter) { + 
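+    // Plain sum: combine with a + b, accumulating in acc_t (float when
+    // reduce_dispatch below routes half/bfloat16 inputs here); func_wrapper
+    // lifts the binary lambda into the ops interface gpu_reduce_kernel expects.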
gpu_reduce_kernel( + iter, func_wrapper([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { + return a + b; + })); + } +}; + +// jiterated specialization for `complex` +CONSTEXPR_EXCEPT_WIN_CUDA char sum_name[] = "sum"; +template <> +struct sum_functor> { +// jiterator reduction fails on windows +// Ref: https://github.com/pytorch/pytorch/issues/77305 +#if AT_USE_JITERATOR() && !defined(_MSC_VER) + void operator()(TensorIterator& iter) { + using scalar_t = c10::complex; + std::string func = jiterator_stringify( + arg_t combine(arg_t a, arg_t b) { + return a + b; + } + ); + jitted_gpu_reduce_kernel( + iter, func, 0.); + } +#else + void operator()(TensorIterator& iter) { + using scalar_t = c10::complex; + using acc_t = at::opmath_type; + gpu_reduce_kernel( + iter, func_wrapper([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { + return a + b; + }), acc_t{0.}); + } +#endif +}; + +template +struct nansum_functor { + void operator()(TensorIterator& iter) { + gpu_reduce_kernel( + iter, NanSumOps{}); + } +}; + +CONSTEXPR_EXCEPT_WIN_CUDA char nansum_name[] = "nansum"; +template +struct nansum_functor_complex { +#if AT_USE_JITERATOR() + void operator()(TensorIterator& iter) { + std::string func = jiterator_stringify( + arg_t combine(arg_t a, scalar_t b) { + return a + (std::isnan(b) ? arg_t{0.} : arg_t{b}); + } + ); + jitted_gpu_reduce_kernel( + iter, func, 0.); + } +#else + void operator()(TensorIterator& iter) { + using acc_t = at::opmath_type; + gpu_reduce_kernel( + iter, NanSumOps{}); + } +#endif +}; + +CONSTEXPR_EXCEPT_WIN_CUDA char prod_name[] = "prod"; +template +struct prod_functor { + // jiterator reduction fails on windows + // Ref: https://github.com/pytorch/pytorch/issues/77305 + #if AT_USE_JITERATOR() && !defined(_MSC_VER) + void operator()(TensorIterator& iter) { + std::string func = jiterator_stringify( + arg_t combine(arg_t a, arg_t b) { + return a * b; + } + ); + jitted_gpu_reduce_kernel( + iter, func, 1.); + } + #else + void operator()(TensorIterator& iter) { + gpu_reduce_kernel( + iter, func_wrapper([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { + return a * b; + }), 1.); + } + #endif +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context] +template <> +struct prod_functor { + void operator()(TensorIterator& iter) { + gpu_reduce_kernel( + iter, func_wrapper([] GPU_LAMBDA(bool a, bool b) -> bool { + return a && b; + }), 1); + } +}; + +// jiterated specialization for `complex` +template <> +struct prod_functor> { +// jiterator reduction fails on windows +// Ref: https://github.com/pytorch/pytorch/issues/77305 +#if AT_USE_JITERATOR() && !defined(_MSC_VER) + void operator()(TensorIterator& iter) { + using scalar_t = c10::complex; + std::string func = + jiterator_stringify(arg_t combine(arg_t a, arg_t b) { return a * b; }); + jitted_gpu_reduce_kernel(iter, func, 1.); + } +#else + void operator()(TensorIterator& iter) { + using scalar_t = c10::complex; + using acc_t = at::opmath_type; + gpu_reduce_kernel( + iter, + func_wrapper( + [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { return a * b; }), + acc_t{1.}); + } +#endif +}; + +// The function `reduce_dispatch` below dispatches to the kernel based +// on the type of `iter`. It takes care of the common logic +// for handling Half-Precision floating types. +// Otherwise the functor `op` is called to dispatch to the kernel +// of relevant type. +// +// Note: Functor `op` should take care of all the types to be supported +// except for `at::Half` and `at::BFloat16`. 
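+// For example, a kHalf input reduced into a kFloat output selects
+// OpFunctor<at::Half, float, float>, doing the cast and the reduction in a
+// single kernel instead of materializing a float copy of the input first.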
+template < + template < + typename scalar_t, + typename acc_t = scalar_t, + typename out_t = scalar_t> + typename OpFunctor, + typename GeneralDispatcher> +static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) { + if (iter.dtype() == kHalf) { + return OpFunctor{}(iter); + } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) { + // type promotion that does cast and reduction in a single kernel + return OpFunctor{}(iter); + } else if (iter.dtype() == kBFloat16) { + return OpFunctor{}(iter); + } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) { + // type promotion that does cast and reduction in a single kernel + return OpFunctor{}(iter); + } + op(iter); +} + +static void sum_kernel_zoom(TensorIterator& iter){ + auto general_dispatcher = [](TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kBool, kComplexHalf, iter.dtype(), "sum_zoom", [&]() { + sum_functor{}(iter); + }); + }; + + reduce_dispatch(iter, general_dispatcher); +} + +static void nansum_kernel_zoom(TensorIterator& iter) { + auto general_dispatcher = [](TensorIterator& iter) { + auto dtype = iter.dtype(); + if (at::isComplexType(dtype)) { + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "nansum_zoom", [&]() { + nansum_functor_complex{}(iter); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nansum_zoom", [&]() { + nansum_functor{}(iter); + }); + } + }; + + reduce_dispatch(iter, general_dispatcher); +} + +static void prod_kernel_zoom(TensorIterator& iter) { + auto general_dispatcher = [](TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kComplexHalf, kBool, iter.dtype(), "prod_zoom", [&]() { + prod_functor{}(iter); + }); + }; + + reduce_dispatch(iter, general_dispatcher); +} + +REGISTER_PRIVATEUSE1_DISPATCH(sum_stub, &sum_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(nansum_stub, &nansum_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(prod_stub, &prod_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ScatterGatherKernel.cu b/aten/src/ATen/native/zoom/ScatterGatherKernel.cu new file mode 100644 index 00000000000000..4d0121d83e5204 --- /dev/null +++ b/aten/src/ATen/native/zoom/ScatterGatherKernel.cu @@ -0,0 +1,573 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace at::native { + +// Implement as functors since lambdas don't get optimized. 
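+// Each reduction mode is backed by an atomic primitive so scatter with
+// duplicate indices stays race-free: multiply uses gpuAtomicMul, add (and
+// mean, which accumulates like add here, with the averaging expected to be
+// done by the caller) uses fastAtomicAdd, amin/amax use gpuAtomicMin/Max,
+// and plain scatter is an ordinary assignment with no guarantee about which
+// duplicate wins.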
+class ReduceMultiply { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + (void)numel; // suppress unused warning + gpuAtomicMul(self_data_start + index, *src_data); + } +}; +static ReduceMultiply reduce_multiply; + +class ReduceAdd { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + fastAtomicAdd(self_data_start, index, numel, *src_data, true); + } +}; +static ReduceAdd reduce_add; + +class ReduceMean { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + fastAtomicAdd(self_data_start, index, numel, *src_data, true); + } +}; +static ReduceMean reduce_mean; + +class ReduceMinimum { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + (void)numel; // suppress unused warning + gpuAtomicMin(self_data_start + index, *src_data); + } +}; +static ReduceMinimum reduce_minimum; + +class ReduceMaximum { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + (void)numel; // suppress unused warning + gpuAtomicMax(self_data_start + index, *src_data); + } +}; +static ReduceMaximum reduce_maximum; + +class TensorAssign { +public: + template + constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const { + (void)numel; // suppress unused warning + *(self_data_start + index) = *src_data; + } +}; +static TensorAssign tensor_assign; + +// The kernels are implemented on an opaque, +// self-aligned type of the correct size, +// to avoid redundant kernels for different types +// of the same size. +template struct alignas(N) OpaqueType { char data[N]; }; + +// essentially rewritten related to legacy::launch_kernel parts +template +C10_LAUNCH_BOUNDS_2(nt, vt) +__global__ void _scatter_gather_elementwise_kernel(int N, func_t f) { + constexpr int nv = nt * vt; + int idx = nv * blockIdx.x + threadIdx.x; + + #pragma unroll + for (int i = 0; i < vt; ++i) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + + const dim3 block(nt); + const dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + const auto stream = c10::zoom::getCurrentZoomStream(); + _scatter_gather_elementwise_kernel<<>>(N, f); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + + +template +struct _zoom_scatter_gather_internal_kernel { + template + void operator() ( + TensorIterator& iter, + int64_t index_size, + int64_t index_stride, + int64_t numel, // Do not use `const` qualifier here as it may cause issue in cuda 11.6.x. 
See #75434, #75545 + const func_t& f + ) { + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + _zoom_scatter_gather_internal_kernel()( + sub_iter, index_size, index_stride, numel, f + ); + } + return; + } + + char* self_ptr = (char*)iter.data_ptr(0); + char* src_ptr = (char*)iter.data_ptr(1); + char* index_ptr = (char*)iter.data_ptr(2); + + auto offset_calc = make_offset_calculator<3>(iter); + auto loop = [=]C10_DEVICE(int i) { + auto offsets = offset_calc.get(i); + + int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]); + ZOOM_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size + && "index out of bounds"); + + f( + (scalar_t*)(self_ptr + offsets[0]), + is_scatter_like ? idx_dim * index_stride : 0, + numel, + (scalar_t*)(src_ptr + offsets[1]) + (is_scatter_like ? 0 : idx_dim * index_stride) + ); + }; + + _launch_scatter_gather_kernel(iter.numel(), loop); + } +}; // struct _zoom_scatter_fill_internal_kernel + +template +struct zoom_scatter_gather_base_kernel { + void operator()( + const Tensor& self, int64_t dim, + const Tensor& index, const Tensor& src, + const std::string& method_name, + const ReduceAdd& f + ) { + at::assert_no_internal_overlap(self); + + auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); + auto self_strides = ensure_nonempty_vec(self.strides().vec()); + auto src_strides = ensure_nonempty_vec(src.strides().vec()); + + // restride self and src such that + // self.shape = src.shape = index.shape + // + // restride stride[dim] such that + // if (is_scatter_like) self.stride[dim] = 0 + // else src.stride[dim] = 0 + auto self_restrided = is_scatter_like ? + restride_dim(self, dim, index_sizes) + : self.as_strided(index_sizes, self_strides); + auto src_restrided = is_scatter_like ? + src.as_strided(index_sizes, src_strides) + : restride_dim(src, dim, index_sizes); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self_restrided) + .add_const_input(src_restrided) + .add_const_input(index) + .build(); + + auto self_dim_stride = ensure_nonempty_stride(self, dim); + auto self_dim_size = ensure_nonempty_size(self, dim); + + auto src_dim_stride = ensure_nonempty_stride(src, dim); + auto src_dim_size = ensure_nonempty_size(src, dim); + + auto index_size = is_scatter_like ? self_dim_size : src_dim_size; + auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride; + + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + iter.dtype(), + "zoom_scatter_gather_base_kernel_func", [&] { + using dtype = typename std::conditional, scalar_t>::type; + + _zoom_scatter_gather_internal_kernel()( + iter, index_size, index_stride, self.numel(), f + ); + } + ); + } + + void operator()( + const Tensor& self, int64_t dim, + const Tensor& index, const Tensor& src, + const std::string& method_name, + const TensorAssign& f + ) { + at::assert_no_internal_overlap(self); + + auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); + auto self_strides = ensure_nonempty_vec(self.strides().vec()); + auto src_strides = ensure_nonempty_vec(src.strides().vec()); + + // restride self and src such that + // self.shape = src.shape = index.shape + // + // restride stride[dim] such that + // if (is_scatter_like) self.stride[dim] = 0 + // else src.stride[dim] = 0 + auto self_restrided = is_scatter_like ? 
+ restride_dim(self, dim, index_sizes) + : self.as_strided(index_sizes, self_strides); + auto src_restrided = is_scatter_like ? + src.as_strided(index_sizes, src_strides) + : restride_dim(src, dim, index_sizes); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self_restrided) + .add_const_input(src_restrided) + .add_const_input(index) + .build(); + + auto self_dim_stride = ensure_nonempty_stride(self, dim); + auto self_dim_size = ensure_nonempty_size(self, dim); + + auto src_dim_stride = ensure_nonempty_stride(src, dim); + auto src_dim_size = ensure_nonempty_size(src, dim); + + auto index_size = is_scatter_like ? self_dim_size : src_dim_size; + auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride; + + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + iter.dtype(), + "zoom_scatter_gather_base_kernel_func", [&] { + using dtype = typename std::conditional, scalar_t>::type; + + _zoom_scatter_gather_internal_kernel()( + iter, index_size, index_stride, self.numel(), f + ); + } + ); + } + + template + void operator()( + const Tensor& self, int64_t dim, + const Tensor& index, const Tensor& src, + const std::string& method_name, + const func_t& f + ) { + at::assert_no_internal_overlap(self); + + auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); + auto self_strides = ensure_nonempty_vec(self.strides().vec()); + auto src_strides = ensure_nonempty_vec(src.strides().vec()); + + // restride self and src such that + // self.shape = src.shape = index.shape + // + // restride stride[dim] such that + // if (is_scatter_like) self.stride[dim] = 0 + // else src.stride[dim] = 0 + auto self_restrided = is_scatter_like ? + restride_dim(self, dim, index_sizes) + : self.as_strided(index_sizes, self_strides); + auto src_restrided = is_scatter_like ? + src.as_strided(index_sizes, src_strides) + : restride_dim(src, dim, index_sizes); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self_restrided) + .add_const_input(src_restrided) + .add_const_input(index) + .build(); + + auto self_dim_stride = ensure_nonempty_stride(self, dim); + auto self_dim_size = ensure_nonempty_size(self, dim); + + auto src_dim_stride = ensure_nonempty_stride(src, dim); + auto src_dim_size = ensure_nonempty_size(src, dim); + + auto index_size = is_scatter_like ? self_dim_size : src_dim_size; + auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride; + + + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + iter.dtype(), + "zoom_scatter_gather_base_kernel_func", [&] { + using dtype = typename std::conditional, scalar_t>::type; + + _zoom_scatter_gather_internal_kernel()( + iter, index_size, index_stride, self.numel(), f + ); + } + ); + } +}; // struct zoom_scatter_gather_base_kernel + +template +struct _zoom_scatter_fill_internal_kernel { + template + void operator()( + TensorIterator& iter, + scalar_t src_val, + int64_t index_size, + int64_t index_stride, + int64_t numel, // Do not use `const` qualifier here as it may cause issue in cuda 11.6.x. 
See #75434, #75545 + const func_t& f + ) { + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + _zoom_scatter_fill_internal_kernel()( + sub_iter, src_val, index_size, index_stride, numel, f + ); + } + return; + } + + char* self_ptr = (char*)iter.data_ptr(0); + char* index_ptr = (char*)iter.data_ptr(1); + + auto offset_calc = make_offset_calculator<2>(iter); + auto loop = [=]C10_DEVICE(int i) { + auto offsets = offset_calc.get(i); + + int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]); + ZOOM_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size + && "index out of bounds" + ); + + f( + (scalar_t*)(self_ptr + offsets[0]), + idx_dim * index_stride, + numel, + (scalar_t*)&src_val + ); + }; + + _launch_scatter_gather_kernel(iter.numel(), loop); + } +}; // struct _zoom_scatter_fill_internal_kernel + +template +struct zoom_scatter_fill_base_kernel { + template + void operator()( + const Tensor& self, int64_t dim, + const Tensor& index, Scalar src, + const std::string& method_name, + const func_t& f + ) { + at::assert_no_internal_overlap(self); + + auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); + + // restride self such that + // self.shape = index.shape and + // self.stride[dim] = 0 + auto self_restrided = restride_dim(self, dim, index_sizes); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self_restrided) + .add_const_input(index) + .build(); + + auto index_size = ensure_nonempty_size(self, dim); + auto index_stride = ensure_nonempty_stride(self, dim); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + iter.dtype(), + "zoom_scatter_fill_base_kernel_func", [&] { + using dtype = typename std::conditional, scalar_t>::type; + + auto src_scalar_val = src.to(); + auto src_val = *(dtype*)&src_scalar_val; + + _zoom_scatter_fill_internal_kernel()( + iter, src_val, index_size, index_stride, self.numel(), f + ); + } + ); + } + + void operator()( + const Tensor& self, int64_t dim, + const Tensor& index, Scalar src, + const std::string& method_name, + const ReduceMultiply& f + ) { + at::assert_no_internal_overlap(self); + + auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); + + // restride self such that + // self.shape = index.shape and + // self.stride[dim] = 0 + auto self_restrided = restride_dim(self, dim, index_sizes); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self_restrided) + .add_const_input(index) + .build(); + + auto index_size = ensure_nonempty_size(self, dim); + auto index_stride = ensure_nonempty_stride(self, dim); + + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + iter.dtype(), + "zoom_scatter_fill_base_kernel_reduce_multiply", [&] { + using dtype = typename std::conditional, scalar_t>::type; + + auto src_scalar_val = src.to(); + auto src_val = *(dtype*)&src_scalar_val; + + _zoom_scatter_fill_internal_kernel()( + iter, src_val, index_size, index_stride, self.numel(), f + ); + } + ); + } +}; // struct zoom_scatter_fill_base_kernel + +void gather_zoom_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) { + zoom_scatter_gather_base_kernel()( + result, dim, index, self, + "gather_out_zoom", tensor_assign); +} + +void scatter_zoom_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { + // 
When indices are not unique, the behavior is non-deterministic
+  globalContext().alertNotDeterministic("scatter_zoom_");
+  zoom_scatter_gather_base_kernel<>()(
+    self, dim, index, src,
+    "scatter_zoom_", tensor_assign);
+}
+
+void scatter_fill_zoom_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src) {
+  zoom_scatter_fill_base_kernel<>()(
+    self, dim, index, src,
+    "scatter_fill_zoom_", tensor_assign);
+}
+
+void scatter_add_zoom_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
+  // See Note [Writing Nondeterministic Operations]
+  // Nondeterministic because of atomicAdd usage
+  globalContext().alertNotDeterministic("scatter_add_zoom_kernel");
+  zoom_scatter_gather_base_kernel<true>()(
+    self, dim, index, src,
+    "scatter_add_zoom_", reduce_add);
+}
+
+void scatter_reduce_zoom_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
+                                const Tensor& src, const ReductionType& reduce) {
+  // See Note [Writing Nondeterministic Operations]
+  // Nondeterministic because of atomicAdd/AtomicMul usage
+  globalContext().alertNotDeterministic("scatter_reduce_zoom_kernel");
+  switch (reduce) {
+  case ReductionType::SUM :
+    zoom_scatter_gather_base_kernel<true>()(self, dim, index, src,
+                                            "scatter_reduce_zoom_add_", reduce_add);
+    break;
+  case ReductionType::PROD :
+    zoom_scatter_gather_base_kernel<true>()(self, dim, index, src,
+                                            "scatter_reduce_zoom_multiply_", reduce_multiply);
+    break;
+  default :
+    break;
+  }
+}
+
+void scatter_reduce_two_zoom_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
+                                    const Tensor& src, const ReductionType& reduce) {
+  switch (reduce) {
+  case ReductionType::SUM :
+    globalContext().alertNotDeterministic("scatter_reduce_zoom_sum_");
+    zoom_scatter_gather_base_kernel<true>()(self, dim, index, src,
+                                            "scatter_reduce_zoom_sum_", reduce_add);
+    break;
+  case ReductionType::PROD :
+    globalContext().alertNotDeterministic("scatter_reduce_zoom_prod_");
+    zoom_scatter_gather_base_kernel<true>()(self, dim, index, src,
+                                            "scatter_reduce_zoom_prod_", reduce_multiply);
+    break;
+  case ReductionType::MAX :
+    zoom_scatter_gather_base_kernel<true>()(self, dim, index, src,
+                                            "scatter_reduce_zoom_amax_", reduce_maximum);
+    break;
+  case ReductionType::MIN :
+    zoom_scatter_gather_base_kernel<true>()(self, dim, index, src,
+                                            "scatter_reduce_zoom_amin_", reduce_minimum);
+    break;
+  case ReductionType::MEAN :
+    globalContext().alertNotDeterministic("scatter_reduce_zoom_mean_");
+    zoom_scatter_gather_base_kernel<true>()(self, dim, index, src,
+                                            "scatter_reduce_zoom_mean_", reduce_mean);
+    break;
+  }
+}
+
+void scatter_scalar_reduce_zoom_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
+                                       const Scalar& value, const ReductionType& reduce) {
+  switch (reduce) {
+  case ReductionType::SUM :
+    zoom_scatter_fill_base_kernel<false>()(self, dim, index, value,
+                                           "scatter_fill_zoom_add_", reduce_add);
+    break;
+  case ReductionType::PROD :
+    zoom_scatter_fill_base_kernel<false>()(self, dim, index, value,
+                                           "scatter_fill_zoom_multiply_", reduce_multiply);
+    break;
+  default :
+    break;
+  }
+}
+
+
+REGISTER_PRIVATEUSE1_DISPATCH(gather_stub, &gather_zoom_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(scatter_stub, &scatter_zoom_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(scatter_fill_stub, &scatter_fill_zoom_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(scatter_add_stub, &scatter_add_zoom_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(scatter_reduce_stub, &scatter_reduce_zoom_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_zoom_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(scatter_reduce_two_stub, &scatter_reduce_two_zoom_kernel); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/UnaryOpsKernel.cu b/aten/src/ATen/native/zoom/UnaryOpsKernel.cu new file mode 100644 index 00000000000000..49ed65a45004ce --- /dev/null +++ b/aten/src/ATen/native/zoom/UnaryOpsKernel.cu @@ -0,0 +1,286 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +void bitwise_not_kernel_zoom(TensorIteratorBase& iter) { + if (iter.dtype() == ScalarType::Bool) { + gpu_kernel(iter, []GPU_LAMBDA(bool a) { + return !a; + }); + } else { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ~a; + }); + }); + } +} + +CONSTEXPR_EXCEPT_WIN_CUDA char exp_name[] = "exp_kernel"; +void exp_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { + #if AT_USE_JITERATOR() + static const auto exp_string = jiterator_stringify( + template + T exp_kernel(T x) { + return std::exp(x); + }); // exp_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "exp_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/exp_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, exp_string); + }); + #else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "exp_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return std::exp(static_cast(a)); + }); + }); + #endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, common_dtype, "exp_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return std::exp(a); + }); + }); + } +} + +void expm1_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + ScalarType::BFloat16, ScalarType::Half, + iter.common_dtype(), "expm1_zoom", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::expm1(a); + }); + }); +} + +// We manually overload rsqrt because std::rsqrt does not work with complex types. 
+template +C10_HOST_DEVICE static inline scalar_t rsqrt_wrapper(scalar_t v) { + return ::rsqrt(v); +} + +template +C10_HOST_DEVICE static inline c10::complex rsqrt_wrapper(c10::complex v) { + const c10::complex one = c10::complex(1.0, 0); + // std::sqrt for c10::complex is overloaded in c10/util/complex_math.h + return one / ::sqrt(v); +} + +CONSTEXPR_EXCEPT_WIN_CUDA char rsqrt_name[] = "rsqrt_kernel"; +void rsqrt_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { + #if AT_USE_JITERATOR() + static const auto rsqrt_string = jiterator_stringify( + template + T rsqrt_kernel(T x) { + const T one = T{1}; + return one / std::sqrt(x); + }); // rsqrt_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/rsqrt_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, rsqrt_string); + }); + #else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return rsqrt_wrapper(static_cast(a)); + }); + }); + #endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::BFloat16, ScalarType::Half, + iter.common_dtype(), "rsqrt_zoom", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float. + return rsqrt_wrapper(a); + }); + }); + } +} + +CONSTEXPR_EXCEPT_WIN_CUDA char sqrt_name[] = "sqrt_kernel"; +void sqrt_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { + #if AT_USE_JITERATOR() + static const auto sqrt_string = jiterator_stringify( + template + T sqrt_kernel(T x) { + return std::sqrt(x); + }); // sqrt_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/sqrt_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, sqrt_string); + }); + #else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::sqrt(static_cast(a)); + }); + }); + #endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, common_dtype, "sqrt_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return std::sqrt(a); + }); + }); + } +} + +void clamp_kernel_zoom(TensorIteratorBase& iter, const Scalar& min_value, const Scalar& max_value) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_zoom", [&]() { + auto lower = min_value.to(); + auto upper = max_value.to(); + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::min(::max(v, lower), upper); + } + }); + }); +} + +void clamp_min_kernel_zoom(TensorIteratorBase& iter, const Scalar& min_value) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_min_zoom", [&]() { + auto lower = min_value.to(); + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::max(v, lower); + } + }); + }); +} + +void clamp_max_kernel_zoom(TensorIteratorBase& iter, const Scalar& max_value) { + 
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_max_zoom", [&]() { + auto upper = max_value.to(); + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::min(v, upper); + } + }); + }); +} + +template +C10_HOST_DEVICE static inline scalar_t _nan_to_num_replace(scalar_t a, scalar_t nan_replacement, scalar_t pos_inf_replacement, scalar_t neg_inf_replacement) { + return at::_isnan(a) + ? nan_replacement + : (a == std::numeric_limits::infinity() + ? pos_inf_replacement + : (a == -std::numeric_limits::infinity() + ? neg_inf_replacement + : a)); +} + +void nan_to_num_kernel_zoom( + TensorIteratorBase& iter, + std::optional nan, + std::optional pos_inf, + std::optional neg_inf) { + if (isComplexType(iter.dtype())) { + AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "nan_to_num", [&]() { + using value_t = scalar_t::value_type; + value_t nan_replacement = static_cast(nan.value_or(0.)); + value_t pos_inf_replacement = pos_inf.has_value() + ? static_cast(pos_inf.value()) + : std::numeric_limits::max(); + value_t neg_inf_replacement = neg_inf.has_value() + ? static_cast(neg_inf.value()) + : std::numeric_limits::lowest(); + + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t { + value_t res_real = _nan_to_num_replace( + a.real(), nan_replacement, pos_inf_replacement, neg_inf_replacement); + value_t res_imag = _nan_to_num_replace( + a.imag(), nan_replacement, pos_inf_replacement, neg_inf_replacement); + return scalar_t(res_real, res_imag); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "nan_to_num_zoom", [&]() { + scalar_t nan_replacement = static_cast(nan.value_or(0.)); + scalar_t pos_inf_replacement = pos_inf.has_value() + ? static_cast(pos_inf.value()) + : std::numeric_limits::max(); + scalar_t neg_inf_replacement = neg_inf.has_value() + ? static_cast(neg_inf.value()) + : std::numeric_limits::lowest(); + + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t { + return _nan_to_num_replace( + a, nan_replacement, pos_inf_replacement, neg_inf_replacement); + }); + }); + } +} + +void frexp_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, + // The iter.dtype() here is the dtype of mantissa output. + // It's a floating point type and must be the same as the input's dtype. 
+ iter.dtype(), + "frexp_zoom", [&]() { + gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t a) -> thrust::tuple { + int32_t exponent; + scalar_t mantissa = std::frexp(a, &exponent); + return {mantissa, exponent}; + }); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(exp_stub, &exp_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(expm1_stub, &expm1_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(rsqrt_stub, &rsqrt_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(sqrt_stub, &sqrt_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(frexp_stub, &frexp_kernel_zoom); + +} // namespace at::native From aaef6b9087992179551d3e3f6503b75b770cf196 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Fri, 17 Jan 2025 00:02:02 +0000 Subject: [PATCH 07/23] remove deps on hipblas, hipblaslt, hipsparse, hipsolver, hipfft, roctx, miopen --- cmake/Dependencies.cmake | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index e29c89479f9dad..bc0d184cb8fd98 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1240,12 +1240,17 @@ if(USE_ROCM OR USE_ZOOM) # This is needed for library added by hip_add_library (same for hip_add_executable) hip_include_directories(${Caffe2_HIP_INCLUDE}) - set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS + if(USE_ZOOM) + set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS + ${PYTORCH_HIP_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB}) + list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS hip::hiprand) + else() + set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${PYTORCH_HIP_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB}) - list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES}) - - list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS - roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver) + list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES}) + list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS + roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver) + endif() # ---[ Kernel asserts # Kernel asserts is disabled for ROCm by default. 
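The REGISTER_PRIVATEUSE1_DISPATCH calls above hook the stock ATen stubs (exp, sqrt, frexp, the scatter/gather kernels, and so on) to the Zoom implementations through the PrivateUse1 key, so ordinary tensor ops reach them once the backend has been renamed. Below is a minimal smoke test of that path; it is a sketch only, and assumes a build of this branch with USE_ZOOM enabled, the 'zoom' rename applied as in the BuildingZoom.md instructions later in this series, and working host-to-device copies. The shapes and op choices are illustrative and not part of the patch.

```python
# Minimal smoke test (sketch): exercises a few of the stubs registered above.
# Assumes a USE_ZOOM build of this branch and that CPU <-> zoom copies work.
import torch

torch.utils.rename_privateuse1_backend('zoom')

x = torch.arange(16, dtype=torch.float32).reshape(4, 4).to('zoom')
index = torch.tensor([[0, 1, 2, 3]]).to('zoom')   # int64 indices, shape (1, 4)
src = torch.ones(1, 4).to('zoom')

y = x.sqrt().exp()                                # sqrt_stub / exp_stub
g = torch.gather(x, 0, index)                     # gather_zoom_kernel
s = torch.zeros(4, 4).to('zoom').scatter_add_(0, index, src)  # scatter_add_zoom_kernel

print(y.cpu())
print(g.cpu())
print(s.cpu())
```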
From ac54e3e7aa427ff6737d9911f1962a9dbaebcd4a Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Sun, 26 Jan 2025 05:30:36 +0000 Subject: [PATCH 08/23] llama example working, bmm triton kernel --- .github/workflows/build_zoom_backend.yml | 124 ++++++++++++ aten/src/ATen/native/native_functions.yaml | 36 ++-- aten/src/ATen/native/zoom/Bmm.cpp | 122 ------------ .../native/zoom/DistributionRandomKernel.cu | 27 +++ .../ATen/native/zoom/DistributionUniform.cu | 15 ++ aten/src/ATen/native/zoom/HIPbmm.cu | 132 ------------- aten/src/ATen/native/zoom/TensorCompare.cu | 133 +++++++++++++ test/test_ops.py | 3 +- torch/zoom/__init__.py | 2 +- torch/zoom/zoom_triton_mm.py | 182 ++++++++++++++++++ 10 files changed, 501 insertions(+), 275 deletions(-) create mode 100644 .github/workflows/build_zoom_backend.yml delete mode 100644 aten/src/ATen/native/zoom/Bmm.cpp create mode 100644 aten/src/ATen/native/zoom/DistributionRandomKernel.cu create mode 100644 aten/src/ATen/native/zoom/DistributionUniform.cu delete mode 100644 aten/src/ATen/native/zoom/HIPbmm.cu create mode 100644 aten/src/ATen/native/zoom/TensorCompare.cu create mode 100644 torch/zoom/zoom_triton_mm.py diff --git a/.github/workflows/build_zoom_backend.yml b/.github/workflows/build_zoom_backend.yml new file mode 100644 index 00000000000000..aa7053cafe8379 --- /dev/null +++ b/.github/workflows/build_zoom_backend.yml @@ -0,0 +1,124 @@ +name: "Build PyTorch" + +on: + workflow_dispatch: + inputs: + force_debug_with_tmate: + type: boolean + description: 'Run the build with tmate session' + required: false + default: false + debug_with_tmate: + type: boolean + description: 'Run the build with a tmate session ONLY in case of failure' + required: false + default: false + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + build: + + strategy: + fail-fast: false + matrix: + include: + - name: "ubuntu-22.04" + runs-on: "mi300" + # container: "rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0" + # runs-on: "nod-ai-shared-cpubuilder-manylinux-x86_64" + + runs-on: ${{ matrix.runs-on }} + + name: ${{ matrix.name }} + + env: + CACHE_DIR: ${{ github.workspace }}/.container-cache + # either the PR number or `branch-N` where N always increments + CACHE_KEY: linux-build-test-cpp-asserts-manylinux-v2-${{ format('{0}-{1}', github.ref_name, github.run_number) }} + + defaults: + run: + shell: bash + + permissions: + id-token: write + contents: write + + container: + image: ${{ matrix.container }} + + steps: + - name: "Check out repository" + uses: actions/checkout@v4.2.2 + with: + submodules: true + + - name: Enable cache + uses: actions/cache/restore@v3 + with: + path: ${{ env.CACHE_DIR }} + key: ${{ env.CACHE_KEY }} + restore-keys: linux-build-test-cpp- + + - name: "Build PyTorch" + id: build + run: | + + export CCACHE_DIR="${{ env.CACHE_DIR }}" + export CMAKE_C_COMPILER_LAUNCHER=ccache + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export CCACHE_SLOPPINESS=include_file_ctime,include_file_mtime,time_macros + + python -m venv venv + source venv/bin/activate + pip install -r requirements.txt + ./build.sh + + - name: "Audit" + id: audit + run: | + + sudo apt install patchelf + source venv/bin/activate + pip install auditwheel + auditwheel repair -w dist --plat manylinux_2_39_x86_64 dist/torch* + + - name: Save cache + uses: actions/cache/save@v3 + if: ${{ !cancelled() }} + with: + path: ${{ env.CACHE_DIR }} + key: ${{ env.CACHE_KEY }} + + - 
name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.name }}_artifact + path: dist + if-no-files-found: warn + + - name: Release current commit + uses: ncipollo/release-action@v1.12.0 + with: + artifacts: "dist/torch*.whl" + token: "${{ secrets.GITHUB_TOKEN }}" + tag: "latest" + name: "latest" + removeArtifacts: false + allowUpdates: true + replacesArtifacts: true + makeLatest: true + + - name: "Setup tmate session" + if: ${{ (failure() && inputs.debug_with_tmate) || inputs.force_debug_with_tmate }} + uses: mxschmitt/action-tmate@v3.18 + with: + limit-access-to-actor: true + install-dependencies: ${{ startsWith(matrix.runs-on, 'macos') || startsWith(matrix.runs-on, 'windows') }} \ No newline at end of file diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1664a6642b4cc4..6271e79c453abf 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1352,7 +1352,6 @@ dispatch: CPU: bmm_out_cpu CUDA: bmm_out_cuda - PrivateUse1: bmm_out_zoom MPS: bmm_out_mps SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda @@ -1513,7 +1512,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_out + CPU, CUDA, PrivateUse1: clamp_out MPS: clamp_out_mps tags: pointwise @@ -1522,7 +1521,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_Tensor_out + CPU, CUDA, PrivateUse1: clamp_Tensor_out MPS: clamp_Tensor_out_mps tags: pointwise @@ -1553,7 +1552,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_max_out + CPU, CUDA, PrivateUse1: clamp_max_out MPS: clamp_max_out_mps tags: pointwise @@ -1562,7 +1561,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_max_Tensor_out + CPU, CUDA, PrivateUse1: clamp_max_Tensor_out MPS: clamp_max_Tensor_out_mps tags: pointwise @@ -1593,7 +1592,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_min_out + CPU, CUDA, PrivateUse1: clamp_min_out MPS: clamp_min_out_mps tags: pointwise @@ -1602,7 +1601,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_min_Tensor_out + CPU, CUDA, PrivateUse1: clamp_min_Tensor_out MPS: clamp_min_Tensor_out_mps tags: pointwise @@ -3168,7 +3167,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, MPS: isnan + CPU, CUDA, MPS, PrivateUse1: isnan SparseCPU, SparseCUDA: isnan_sparse SparseCsrCPU, SparseCsrCUDA: isnan_sparse_csr autogen: isnan.out @@ -4121,7 +4120,6 @@ dispatch: CPU: mm_out_cpu CUDA: mm_out_cuda - PrivateUse1: mm_out_zoom MPS: mm_out_mps SparseCPU, SparseCUDA: _sparse_mm_out SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm_out @@ -6463,13 +6461,13 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA, MPS: where + CPU, CUDA, MPS, PrivateUse1: where tags: [core, pointwise] - func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MPS: where_self_out + CPU, CUDA, MPS, PrivateUse1: where_self_out - func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor variants: function @@ -7874,7 +7872,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, MPS: set_ + CPU, CUDA, Meta, MPS, PrivateUse1: set_ autogen: set.source_Storage, set.source_Storage_out tags: inplace_view @@ -7905,7 +7903,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, MPS: set_tensor_ + CPU, CUDA, Meta, MPS, PrivateUse1: set_tensor_ autogen: set.source_Tensor, set.source_Tensor_out tags: inplace_view @@ -8663,7 +8661,7 @@ variants: method tags: nondeterministic_seeded dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ Meta: random_meta_ MPS: random_mps_ autogen: random.from, random.from_out @@ -8673,7 +8671,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ Meta: random_meta_ MPS: random_mps_ autogen: random.to, random.to_out @@ -8683,7 +8681,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ MPS: random_mps_ Meta: random_meta_ autogen: random, random.out @@ -8693,7 +8691,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: uniform_ + CPU, CUDA, PrivateUse1: uniform_ MPS: uniform_mps_ Meta: uniform_meta_ autogen: uniform, uniform.out @@ -13077,7 +13075,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: isposinf_out + CPU, CUDA, PrivateUse1: isposinf_out SparseCPU, SparseCUDA: isposinf_sparse_out SparseCsrCPU, SparseCsrCUDA: isposinf_sparse_csr_out tags: pointwise @@ -13094,7 +13092,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: isneginf_out + CPU, CUDA, PrivateUse1: isneginf_out SparseCPU, SparseCUDA: isneginf_sparse_out SparseCsrCPU, SparseCsrCUDA: isneginf_sparse_csr_out tags: pointwise diff --git a/aten/src/ATen/native/zoom/Bmm.cpp b/aten/src/ATen/native/zoom/Bmm.cpp deleted file mode 100644 index 53e87a7eb3913e..00000000000000 --- a/aten/src/ATen/native/zoom/Bmm.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - - -namespace at::native { - // Forward decl, defined in HIPbmm.cu - template - void batched_matmul(const T* A, const T* B, T* C, int M, int N, int K, int batch_size); - - const Tensor& bmm_out_hip_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2) { - // handle pathological cases - if (result.numel() == 0) { - return result; - } else if (batch1.size(2) == 0) { - return result.zero_(); - } - TORCH_CHECK(batch1.sizes()[2] == batch2.sizes()[1], "batch1 dim 2 must match batch2 dim 1"); - - c10::MaybeOwned result_ = c10::MaybeOwned::borrowed(result); - IntArrayRef result_strides = result.strides(); - IntArrayRef result_sizes = result.sizes(); - - int m = batch1.sizes()[1]; - int n = batch1.sizes()[2]; - int k = batch2.sizes()[2]; - int num_batches = result_->sizes()[0]; - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "bmm_hip", [&] { - const scalar_t* batch1_ptr = batch1.const_data_ptr(); - const scalar_t* batch2_ptr = batch2.const_data_ptr(); - scalar_t* result_ptr = result_->mutable_data_ptr(); - - batched_matmul(batch1_ptr, 
batch2_ptr, result_ptr, m, n, k, num_batches); - }); - if (!result.is_same(*result_)) { - result.copy_(*result_); - } - return result; - - } - - TORCH_IMPL_FUNC(bmm_out_zoom)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) - { - NoNamesGuard guard; - bmm_out_hip_impl(result, result, batch1, batch2); - } - - Tensor& mm_out_hip_impl(Tensor& result, const Tensor& mat1, const Tensor& mat2) { - // Make sure to keep addmm_hip below in sync with this code; it - // preflights a check to try to avoid actually needing to call - // expand(). - TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ) - - TensorArg targs[]{{result, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2}}; - checkAllSameGPU(__func__, targs); - - IntArrayRef mat1_sizes = mat1.sizes(); - IntArrayRef mat2_sizes = mat2.sizes(); - at::ScalarType scalar_type = mat1.scalar_type(); - TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - // resize result tensor - at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]}); - IntArrayRef result_sizes = result.sizes(); - if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) { - return result; - } - - if (mat1.numel() == 0) { - // By definition, values in self should be ignored. nans and infs - // should not propagate - return result.zero_(); - } - - int m = mat1_sizes[0]; - int n = mat1_sizes[1]; - int k = mat2_sizes[1]; - - // TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result.is_conj()); - - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - scalar_type, - "mm_zoom", - [&] { - const scalar_t* mat1_ptr = mat1.const_data_ptr(); - const scalar_t* mat2_ptr = mat2.const_data_ptr(); - scalar_t* result_ptr = result.mutable_data_ptr(); - batched_matmul(mat1_ptr, mat2_ptr, result_ptr, m, n, k, 1); - }); - - return result; - } - - TORCH_IMPL_FUNC(mm_out_zoom)(const Tensor& self, const Tensor& mat2, const Tensor& result) - { - mm_out_hip_impl(const_cast(result), self, mat2); - } - -} // at::native - - diff --git a/aten/src/ATen/native/zoom/DistributionRandomKernel.cu b/aten/src/ATen/native/zoom/DistributionRandomKernel.cu new file mode 100644 index 00000000000000..7e8aa20d652bae --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionRandomKernel.cu @@ -0,0 +1,27 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_from_to_kernel(iter, range, base, gen); +} + +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_full_64_bits_range_kernel(iter, gen); +} + +void random_kernel(TensorIteratorBase& iter, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_kernel(iter, gen); +} + +REGISTER_PRIVATEUSE1_DISPATCH(random_from_to_stub, &random_from_to_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(random_stub, &random_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(random_full_64_bits_range_stub, 
&random_full_64_bits_range_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionUniform.cu b/aten/src/ATen/native/zoom/DistributionUniform.cu new file mode 100644 index 00000000000000..25ed5e7b8b1148 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionUniform.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void uniform_kernel(TensorIteratorBase& iter, double from, double to, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + templates::zoom::uniform_kernel(iter, from, to, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(uniform_stub, &uniform_kernel); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/HIPbmm.cu b/aten/src/ATen/native/zoom/HIPbmm.cu deleted file mode 100644 index a77a31efaf1af6..00000000000000 --- a/aten/src/ATen/native/zoom/HIPbmm.cu +++ /dev/null @@ -1,132 +0,0 @@ -#include -#include -#include -#include -#include - -namespace at::native { - - int num_threads() { - return 32; - } - - // Helper function to convert hip_bfloat16 to float - __device__ float bfloat16_to_float(hip_bfloat16 a) { - union { - uint32_t int32; - float float32; - } u = {uint32_t(a.data) << 16}; - return u.float32; - } - - // Helper function to convert float to hip_bfloat16 - __device__ hip_bfloat16 float_to_bfloat16(float a) { - union { - float float32; - uint32_t int32; - } u = {a}; - hip_bfloat16 b; - b.data = uint16_t(u.int32 >> 16); - return b; - } - - template - __device__ float convert_to_float(T a) { - return a; - } - - template <> - __device__ float convert_to_float(hip_bfloat16 a) { - return bfloat16_to_float(a); - } - - template <> - __device__ float convert_to_float<__half>( __half a) { - return __half2float(a); - } - - template - __device__ T convert_from_float(float a) { - return static_cast(a); - } - - template <> - __device__ hip_bfloat16 convert_from_float(float a) { - return float_to_bfloat16(a); - } - - template <> - __device__ __half convert_from_float<__half>(float a) { - return __float2half(a); - } - - - template - __global__ void batched_matmul_kernel(const T* A, const T* B, T* C, - int M, int N, int K, int batch_size) { - int row = blockIdx.y * blockDim.y + threadIdx.y; - int col = blockIdx.x * blockDim.x + threadIdx.x; - int batch = blockIdx.z; - - if (row < M && col < K && batch < batch_size) { - float sum = 0.0f; - for (int n = 0; n < N; ++n) { - sum += convert_to_float(A[batch * M * N + row * N + n]) * - convert_to_float(B[batch * N * K + n * K + col]); - } - C[batch * M * K + row * K + col] = convert_from_float(sum); - } - } - - template - void batched_matmul(const T* A, const T* B, T* C, - int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(num_threads(), num_threads()); - dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, - (M + threadsPerBlock.y - 1) / threadsPerBlock.y, - batch_size); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel), numBlocks, threadsPerBlock, 0, 0, - A, B, C, M, N, K, batch_size); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } - - // Specialization for at::Half - template <> - void batched_matmul(const at::Half* A, const at::Half* B, at::Half* C, - int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(num_threads(), num_threads()); - dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, - (M + threadsPerBlock.y - 1) / threadsPerBlock.y, - batch_size); - - 
hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel<__half>), numBlocks, threadsPerBlock, 0, 0, - reinterpret_cast(A), - reinterpret_cast(B), - reinterpret_cast<__half*>(C), - M, N, K, batch_size); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } - - // Specialization for at::BFloat16 - template <> - void batched_matmul(const at::BFloat16* A, const at::BFloat16* B, at::BFloat16* C, - int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(num_threads(), num_threads()); - dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, - (M + threadsPerBlock.y - 1) / threadsPerBlock.y, - batch_size); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel), numBlocks, threadsPerBlock, 0, 0, - reinterpret_cast(A), - reinterpret_cast(B), - reinterpret_cast(C), - M, N, K, batch_size); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } - - // Explicit instantiations for supported types - template void batched_matmul(const float*, const float*, float*, int, int, int, int); - template void batched_matmul(const double*, const double*, double*, int, int, int, int); - -} // at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/TensorCompare.cu b/aten/src/ATen/native/zoom/TensorCompare.cu new file mode 100644 index 00000000000000..e92d058c9b7222 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorCompare.cu @@ -0,0 +1,133 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + + +namespace at::native { + +namespace { + +void where_kernel_impl(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_zoom", [&] { + gpu_kernel( + iter, + [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t { + return cond_val ? 
self_val : other_val; + }); + }); +} + +void isposinf_kernel_impl(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_zoom", [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA (scalar_t a) -> bool { return a == std::numeric_limits::infinity(); } + ); + }); +} + +void isneginf_kernel_impl(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_zoom", [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA (scalar_t a) -> bool { return a == -std::numeric_limits::infinity(); } + ); + }); +} + +void clamp_kernel_impl(TensorIteratorBase& iter) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_zoom", [&] { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t v, scalar_t lower, scalar_t upper) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (at::_isnan(v)) { + return v; + } if (at::_isnan(lower)) { + return lower; + } if (at::_isnan(upper)) { + return upper; + } else { + return ::min(::max(v, lower), upper); + } + }); + }); +} + +void inline launch_clamp_scalar(TensorIteratorBase& iter, Scalar lim0, Scalar lim1, at::native::detail::ClampLimits minmax){ + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_zoom", [&] { + using opmath_t = at::opmath_type; + auto lim0_val = lim0.to(); + auto lim1_val = lim1.to(); + + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(static_cast(v))) { + return v; + } else if (minmax==at::native::detail::ClampLimits::Min){ + return ::max(static_cast(v), lim0_val); + } else if (minmax==at::native::detail::ClampLimits::Max){ + return ::min(static_cast(v), lim0_val); + } else { + return ::min(::max(static_cast(v), lim0_val), lim1_val); + } + }); + }); +} + + +void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min, const Scalar& max) { + launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax); +} + +void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min) { + launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min); +} + +void clamp_max_scalar_kernel_impl(TensorIteratorBase& iter, Scalar max) { + launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max); +} + +} // anonymous namespace + + +REGISTER_PRIVATEUSE1_DISPATCH(where_kernel, &where_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(isposinf_stub, &isposinf_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(isneginf_stub, &isneginf_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_stub, &clamp_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl); + +template +__global__ void _assert_async_zoom_kernel(const scalar_t* input) { + ZOOM_KERNEL_ASSERT(input[0] != 0); +} + +__global__ void _assert_async_zoom_kernel(const c10::complex* input) { + ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +} +__global__ void _assert_async_zoom_kernel(const c10::complex* input) { + ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +} + +void _assert_async_zoom(const Tensor& self_tensor) { + const TensorBase &self = get_tensor_base(self_tensor); + auto n = self.numel(); + TORCH_CHECK(n != 0, "Boolean value of Tensor with no 
values is ambiguous"); + TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous"); + auto stream = c10::zoom::getCurrentZoomStream(); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_assert_async_zoom", [&] { + _assert_async_zoom_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +// TODO (tmanlaibaatar) Ignore assert msg for now +void _assert_async_msg_zoom(const Tensor& self_tensor, c10::string_view assert_msg) { + _assert_async_zoom(self_tensor); +} + +} // namespace at::native \ No newline at end of file diff --git a/test/test_ops.py b/test/test_ops.py index 44f503ae9b6ed8..cd473ac92c4f4f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -32,6 +32,7 @@ instantiate_device_type_tests, onlyCPU, onlyCUDA, + onlyCUDAAndZOOM, onlyNativeDeviceTypes, OpDTypes, ops, @@ -283,7 +284,7 @@ def test_numpy_ref(self, device, dtype, op): ) # Tests that the cpu and gpu results are consistent - @onlyCUDA + @onlyCUDAAndZOOM @suppress_warnings @slowTest @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one) diff --git a/torch/zoom/__init__.py b/torch/zoom/__init__.py index 7b5a757d08520c..debc3c917f96ae 100644 --- a/torch/zoom/__init__.py +++ b/torch/zoom/__init__.py @@ -44,7 +44,7 @@ def _maybe_exchange_device(device: int) -> int: return -1 raise RuntimeError("PyTorch was compiled without Zoom support") - +from .zoom_triton_mm import * _initialized = False _tls = threading.local() diff --git a/torch/zoom/zoom_triton_mm.py b/torch/zoom/zoom_triton_mm.py new file mode 100644 index 00000000000000..6967ed7f8c1a77 --- /dev/null +++ b/torch/zoom/zoom_triton_mm.py @@ -0,0 +1,182 @@ +import torch +import triton +import triton.language as tl +from torch.library import register_kernel +torch.utils.rename_privateuse1_backend('zoom') + +@triton.heuristics({ + 'BLOCK_SIZE_M': lambda args: 128, + 'BLOCK_SIZE_N': lambda args: 64, + 'BLOCK_SIZE_K': lambda args: 32, + 'GROUP_SIZE_M': lambda args: 32, + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def batched_matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + B, + M, + N, + K, + stride_ab, + stride_am, + stride_ak, + stride_bb, + stride_bk, + stride_bn, + stride_cb, + stride_cm, + stride_cn, + a_scale_ptr, + b_scale_ptr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + APPLY_SCALE: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the batched matmul C = A x B. 
+ A has shape (B, M, K), B has shape (B, K, N) and C has shape (B, M, N) + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_batch = num_pid_m * num_pid_n + batch_id = pid // num_pid_in_batch + pid_in_batch = pid % num_pid_in_batch + + if GROUP_SIZE_M == 1: + pid_m = pid_in_batch // num_pid_n + pid_n = pid_in_batch % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_in_batch % group_size_m) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m + + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + batch_id * stride_ab + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + batch_id * stride_bb + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + if APPLY_SCALE: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + + acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if APPLY_SCALE: + accumulator = accumulator * a_scale * b_scale + + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(c_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + batch_id * stride_cb + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +# Wrapper for batched gemm kernel +def batched_matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""): + assert a.shape[2] == b.shape[1], "Incompatible matrix dimensions!!!" + assert a.shape[0] == b.shape[0], "Incompatible batch dimensions!!!" + assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!" + B, M, K = a.shape + _, K, N = b.shape + grid = lambda META: (B * triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + batched_matmul_kernel[grid]( + a, + b, + c, + B, + M, + N, + K, + a.stride(0), + a.stride(1), + a.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + c.stride(0), + c.stride(1), + c.stride(2), + a_scale, + b_scale, + APPLY_SCALE=scale_a8_b8, + ACTIVATION=activation, + ) + +# Activation function. 
+@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + +name_to_torch_types = { + 'int8': torch.int8, + 'int32': torch.int32, + 'fp16': torch.float16, + 'fp32': torch.float32, + 'bf16': torch.bfloat16, + 'fp8e5': torch.float8_e5m2fnuz, + 'fp8e4': torch.float8_e4m3fnuz, +} + +dtype_max = { + dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max + for dtype in [ + torch.float8_e5m2fnuz, + torch.float8_e4m3fnuz, + torch.int8, + ] +} + +def mm_out_zoom(self, mat2, out): + batched_matmul(self.unsqueeze(0), mat2.unsqueeze(0), out.unsqueeze(0), None, None, False) + +def bmm_out_zoom(self, mat2, out): + batched_matmul(self, mat2, out, None, None, False) + +@register_kernel("aten::mm.out", "zoom") +def mm_out(self, mat2, out): + mm_out_zoom(self, mat2, out) + +@register_kernel("aten::mm", "zoom") +def mm(self, mat2): + out = self.new_empty((self.size(0), mat2.size(1))) + mm_out_zoom(self, mat2, out) + return out + +@register_kernel("aten::bmm.out", "zoom") +def bmm_out(self, mat2, out): + bmm_out_zoom(self, mat2, out) + +@register_kernel("aten::bmm", "zoom") +def bmm(self, mat2): + out = self.new_empty((self.size(0), self.size(1), mat2.size(2))) + bmm_out_zoom(self, mat2, out) + return out + \ No newline at end of file From 74b12b7cc39027ad97365ca528c28b680ed51acb Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Mon, 27 Jan 2025 05:02:18 +0000 Subject: [PATCH 09/23] add build and llama3 demo instructions --- BuildingZoom.md | 136 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 BuildingZoom.md diff --git a/BuildingZoom.md b/BuildingZoom.md new file mode 100644 index 00000000000000..66918e23162819 --- /dev/null +++ b/BuildingZoom.md @@ -0,0 +1,136 @@ +# Setup Python Env + +To start out, we just need to follow the normal procedure to build PyTorch from source. For convenience I've included these steps here: + +```bash +conda create -n nod-pytorch python==3.10 +conda activate nod-pytorch +conda install cmake ninja +pip install -r requirements.txt +export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +python setup.py develop +``` + +# CMake Build + +Using the `USE_ZOOM` flag with CMake will enable building with HIP for ROCm without requiring any of the "HIPify" scripts in order to build. This will include HIP libraries and populate `torch.version.hip` appropriately. This flag is NOT yet entered into the `setup.py` script, so for now it needs to be added manually via `cmake` or `ccmake`. + +You'll need to set the `ROCM_PATH` and `HIP_ROOT_DIR` environment variables appropriately, by default on linux these should be `/opt/rocm/` and `/opt/rocm/hip` respectively. + +If you're running on Linux you can just use `build.sh` to build: +```bash +cd pytorch/ +source build.sh +``` + +Alternatively, if you want to manually setup your CMake build you can use the following commands: + +```bash +cd build/ +export PYTORCH_ROCM_ARCH=gfx90a +export ROCM_PATH=/opt/rocm +export HIP_ROOT_DIR=/opt/rocm/hip +cmake -DUSE_ZOOM=ON --build . 
--target install +``` + +# Running PyTorch with Zoom + +Programs using the zoom backend must be prefaced with this stub until we register a proper dispatch key in pytorch + +```python +import torch +import torch.zoom +torch.utils.rename_privateuse1_backend('zoom') +torch.utils.generate_methods_for_privateuse1_backend(unsupported_dtype=None) +``` + +# Installing Triton + +Since main Triton currently treats ROCm as if its masquerading as `torch.cuda`, we need a custom installation: + +```bash +git clone https://github.com/123epsilon/triton.git +cd triton/ +git checkout zoom +pip install pybind11 +pip install python/ +``` + +# Running LLama3 with Triton using LigerKernels and HuggingFace + +```bash +pip install liger-kernel +``` + +```python +# Run Llama 3 +import torch +from transformers import AutoTokenizer +from liger_kernel.transformers import AutoLigerKernelForCausalLM +from time import perf_counter as pf +torch.utils.rename_privateuse1_backend('zoom') + +# Set up the model and tokenizer +model_id = "meta-llama/Meta-Llama-3-8B" +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoLigerKernelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="zoom" +) + +# Function to generate text +def generate_text(prompt, max_length=30): + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + outputs = model.generate(**inputs, max_new_tokens=max_length) + return tokenizer.decode(outputs[0], skip_special_tokens=True) + +# Example usage +prompt = "Hey, how are you doing today?" +s = pf() +response = generate_text(prompt) +e = pf() +print(f"Prompt: {prompt}") +print(f"Response: {response}") + +print(f"{e-s} seconds") +``` + +```python +# Or run the instruct-tuned variant +import torch +import transformers +from liger_kernel.transformers import apply_liger_kernel_to_llama +torch.utils.rename_privateuse1_backend('zoom') + +apply_liger_kernel_to_llama() +model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + +pipeline = transformers.pipeline( + "text-generation", + model=model_id, + model_kwargs={"torch_dtype": torch.bfloat16}, + device_map="zoom", +) + +messages = [ + {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"}, + {"role": "user", "content": "Who are you?"}, +] + +terminators = [ + pipeline.tokenizer.eos_token_id, + pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") +] + +outputs = pipeline( + messages, + max_new_tokens=30, + eos_token_id=terminators, + do_sample=True, + temperature=0.6, + top_p=0.9, +) +print(outputs[0]["generated_text"][-1]) + +``` \ No newline at end of file From e7b9919f40145fe41e83da73a3f3dc1b5a17dc3e Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Tue, 28 Jan 2025 00:10:35 +0000 Subject: [PATCH 10/23] add range factories --- aten/src/ATen/native/native_functions.yaml | 4 + aten/src/ATen/native/zoom/RangeFactories.cu | 270 ++++++++++++++++++++ 2 files changed, 274 insertions(+) create mode 100644 aten/src/ATen/native/zoom/RangeFactories.cu diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 6271e79c453abf..5af124fc7703fc 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -797,6 +797,7 @@ dispatch: CPU, Meta: arange_out CUDA: arange_cuda_out + PrivateUse1: arange_zoom_out MPS: arange_mps_out cpp_no_default_args: ['step'] @@ -3431,6 +3432,7 @@ dispatch: CPU, Meta: linspace_out CUDA: linspace_cuda_out + PrivateUse1: linspace_zoom_out MPS: linspace_out_mps - func: 
linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)
@@ -3647,6 +3649,7 @@
   dispatch:
     CPU, Meta: logspace_out
     CUDA: logspace_cuda_out
+    PrivateUse1: logspace_zoom_out
 
 - func: logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
   category_override: factory
@@ -4795,6 +4798,7 @@
   dispatch:
     CPU, Meta: range_out
     CUDA: range_cuda_out
+    PrivateUse1: range_zoom_out
     MPS: range_mps_out
   cpp_no_default_args: ['step']
diff --git a/aten/src/ATen/native/zoom/RangeFactories.cu b/aten/src/ATen/native/zoom/RangeFactories.cu
new file mode 100644
index 00000000000000..5f7417703ca601
--- /dev/null
+++ b/aten/src/ATen/native/zoom/RangeFactories.cu
@@ -0,0 +1,270 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include
+#include
+#else
+#include
+#include
+#include
+#include
+#include
+#endif
+
+#define GPU_LAMBDA __device__ __host__
+
+namespace {
+
+
+constexpr int num_threads() {
+  return 128;
+}
+
+constexpr int thread_work_size = 1;
+constexpr int block_work_size = thread_work_size * num_threads();
+
+template <typename index_t, typename func_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void elementwise_kernel_with_index(index_t N, func_t f, typename function_traits<func_t>::result_type *data) {
+  #pragma unroll
+  for (int i = 0; i < thread_work_size; i++) {
+    index_t idx = block_work_size * blockIdx.x + num_threads() * i + threadIdx.x;
+    if (idx < N) {
+      data[idx] = f(idx);
+    }
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index(at::Tensor &output, func_t f) {
+  int64_t N = output.numel();
+  if (N == 0) {
+    return;
+  }
+  int64_t grid = (N + block_work_size - 1) / block_work_size;
+  auto stream = c10::zoom::getCurrentZoomStream();
+  using scalar_t = typename function_traits<func_t>::result_type;
+  if (N <= std::numeric_limits<int>::max()) {
+    elementwise_kernel_with_index<int><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
+    C10_ZOOM_KERNEL_LAUNCH_CHECK();
+  } else {
+    elementwise_kernel_with_index<int64_t><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
+    C10_ZOOM_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+} // namespace
+
+namespace at::native {
+
+Tensor& linspace_zoom_out(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
+  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
+
+  if (result.numel() != steps) {
+    result.resize_({steps});
+  }
+  bool is_contiguous = result.is_contiguous();
+  Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
+
+  if (steps == 0) {
+    // skip
+  } else if (steps == 1) {
+    r.fill_(start);
+  } else if (isIntegralType(r.scalar_type(), 0)) {
+    AT_DISPATCH_INTEGRAL_TYPES(r.scalar_type(), "linspace_zoom", [&]() {
+      scalar_t scalar_start = start.to<scalar_t>();
+      scalar_t scalar_end = end.to<scalar_t>();
+      // Cast `end` and `start` to `float`, since range can be larger than scalar_t for integral types
+      float step = (static_cast<float>(scalar_end) - static_cast<float>(scalar_start)) / (steps - 1);
+      const int64_t halfway = steps / 2;
+      gpu_kernel_with_index(r, [scalar_start, scalar_end, steps, step, halfway]GPU_LAMBDA(int64_t ind) -> scalar_t {
+        if (ind < halfway) {
+          return scalar_start + (step * ind);
+        }
+
+        return scalar_end - step * (steps - ind - 1);
+      });
+    });
+  } else {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, r.scalar_type(), "linspace_zoom", [&]() {
+      scalar_t scalar_start = start.to<scalar_t>();
+      scalar_t scalar_end = end.to<scalar_t>();
+      scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
+      const int64_t halfway = steps / 2;
+      gpu_kernel_with_index(r, [scalar_start, scalar_end, steps, step, halfway]GPU_LAMBDA(int64_t ind) -> scalar_t {
+        if (ind < halfway) {
+          return scalar_start + (step * ind);
+        }
+
+        return scalar_end - step * (steps - ind - 1);
+      });
+    });
+  }
+
+  if (!is_contiguous) {
+    result.copy_(r);
+  }
+
+  return result;
+}
+
+Tensor& logspace_zoom_out(const Scalar& start, const Scalar& end, int64_t steps, double base, Tensor& result) {
+  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
+
+  if (result.numel() != steps) {
+    result.resize_({steps});
+  }
+  bool is_contiguous = result.is_contiguous();
+  Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
+
+  if (steps == 0) {
+    // skip
+  } else if (steps == 1) {
+    if (isComplexType(r.scalar_type())){
+      r.fill_(std::pow(base, start.to<c10::complex<double>>()));
+    } else {
+      r.fill_(std::pow(base, start.to<double>()));
+    }
+  } else if (isIntegralType(r.scalar_type(), 0)) {
+    AT_DISPATCH_INTEGRAL_TYPES(r.scalar_type(), "logspace_zoom", [&]() {
+      float scalar_base = static_cast<float>(base); // Use float to avoid promotion to double
+      scalar_t scalar_start = start.to<scalar_t>();
+      scalar_t scalar_end = end.to<scalar_t>();
+      float step = static_cast<float>(scalar_end - scalar_start) / (steps - 1);
+      const int64_t halfway = steps / 2;
+      gpu_kernel_with_index(r, [scalar_start, scalar_end, scalar_base, steps, step, halfway]GPU_LAMBDA(int64_t ind) -> scalar_t {
+        if (ind < halfway) {
+          return std::pow(scalar_base, scalar_start + step * ind);
+        }
+        return std::pow(scalar_base, scalar_end - step * (steps - ind - 1));
+      });
+    });
+  } else {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, r.scalar_type(), "logspace_zoom", [&]() {
+      scalar_t scalar_base = static_cast<scalar_t>(base);
+      scalar_t scalar_start = start.to<scalar_t>();
+      scalar_t scalar_end = end.to<scalar_t>();
+      scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
+      const int64_t halfway = steps / 2;
+      gpu_kernel_with_index(r, [scalar_start, scalar_end, scalar_base, steps, step, halfway]GPU_LAMBDA(int64_t ind) -> scalar_t {
+        if (ind < halfway) {
+          return std::pow(scalar_base, scalar_start + step * ind);
+        }
+        return std::pow(scalar_base, scalar_end - step * (steps - ind - 1));
+      });
+    });
+  }
+
+  if (!is_contiguous) {
+    result.copy_(r);
+  }
+
+  return result;
+}
+
+Tensor& range_zoom_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.scalar_type(), "range_zoom", [&]() {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto xstart = start.to<accscalar_t>();
+    auto xend = end.to<accscalar_t>();
+    auto xstep = step.to<accscalar_t>();
+
+    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
+    TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
+                std::isfinite(static_cast<double>(xend)),
+                "unsupported range: ", xstart, " -> ", xend);
+    TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
+                "upper bound and larger bound inconsistent with step sign");
+    int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
+
+    if (result.numel() != size) {
+      result.resize_({size});
+    }
+    bool is_contiguous = result.is_contiguous();
+    Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
+
+    gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t {
+      accscalar_t inc = xstep * static_cast<accscalar_t>(ind);
+      accscalar_t val = xstart + inc;
+      return static_cast<scalar_t>(val);
+    });
+
+    if(!is_contiguous) {
+      result.copy_(r);
+    }
+
+  });
+
+  return result;
+}
+
+Tensor& arange_zoom_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "arange_zoom", [&]() {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto xstart = start.to<accscalar_t>();
+    auto xend = end.to<accscalar_t>();
+    auto xstep = step.to<accscalar_t>();
+
+    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
+    TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
+                std::isfinite(static_cast<double>(xend)),
+                "unsupported range: ", xstart, " -> ", xend);
+    TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
+                "upper bound and larger bound inconsistent with step sign");
+
+    // we use double precision for (start - end) / step
+    // to compute size_d for consistency across devices.
+    // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t,
+    // but double on cpu for the same,
+    // and the effective output size starts differing on CPU vs GPU because of precision issues, which
+    // we dont want.
+    // the corner-case we do want to take into account is int64_t, which has higher precision than double
+    double size_d;
+    if constexpr (std::is_same_v<scalar_t, int64_t>) {
+      int64_t sgn = (xstep > 0) - (xstep < 0);
+      size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
+    } else {
+      size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
+                         / step.to<double>());
+    }
+
+    TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
+                "invalid size, possible overflow?");
+    int64_t size = static_cast<int64_t>(size_d);
+    int64_t numel = result.numel();
+
+    if (numel != size) {
+      if(numel > 0){
+        TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(),
+                   " is ", numel, " which does not match the computed number of elements ", size,
+                   ". Note that this may occur as a result of rounding error. "
+                   "The out tensor will be resized to a tensor of shape (", size, ",).");
+      }
+      result.resize_({size});
+    }
+    bool is_contiguous = result.is_contiguous();
+    Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
+
+    gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t {
+      accscalar_t inc = xstep * static_cast<accscalar_t>(ind);
+      accscalar_t val = xstart + inc;
+      return static_cast<scalar_t>(val);
+    });
+
+    if(!is_contiguous) {
+      result.copy_(r);
+    }
+  });
+
+  return result;
+}
+
+} // namespace at::native
\ No newline at end of file

From 1eae71dac3ce6460dece13e121dc36824e79dbcd Mon Sep 17 00:00:00 2001
From: 123epsilon
Date: Tue, 28 Jan 2025 00:42:24 +0000
Subject: [PATCH 11/23] adjust find_package calls for zoom in cmake

---
 cmake/public/LoadHIP.cmake | 45 ++++++++++++++++++++++++--------------
 1 file changed, 28 insertions(+), 17 deletions(-)

diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index 107a6fbc15dac5..b7ab4c6d3d5aeb 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -151,23 +151,34 @@ if(HIP_FOUND)
   set(rocthrust_DIR ${ROCM_PATH}/lib/cmake/rocthrust)
   set(hipsolver_DIR ${ROCM_PATH}/lib/cmake/hipsolver)
-
-  find_package_and_print_version(hip REQUIRED)
-  find_package_and_print_version(hsa-runtime64 REQUIRED)
-  find_package_and_print_version(amd_comgr REQUIRED)
-  find_package_and_print_version(rocrand REQUIRED)
-  find_package_and_print_version(hiprand REQUIRED)
-  find_package_and_print_version(rocblas REQUIRED)
-  find_package_and_print_version(hipblas REQUIRED)
-  find_package_and_print_version(hipblaslt REQUIRED)
-  find_package_and_print_version(miopen REQUIRED)
-  find_package_and_print_version(hipfft REQUIRED)
-  find_package_and_print_version(hipsparse REQUIRED)
-  find_package_and_print_version(rccl)
-  find_package_and_print_version(rocprim REQUIRED)
-  find_package_and_print_version(hipcub REQUIRED)
-  find_package_and_print_version(rocthrust REQUIRED)
-  find_package_and_print_version(hipsolver REQUIRED)
+  if(USE_ROCM)
+    find_package_and_print_version(hip REQUIRED)
+    find_package_and_print_version(hsa-runtime64 REQUIRED)
+    find_package_and_print_version(amd_comgr REQUIRED)
+    find_package_and_print_version(rocrand REQUIRED)
+    find_package_and_print_version(hiprand REQUIRED)
+    find_package_and_print_version(rocblas REQUIRED)
+    find_package_and_print_version(hipblas REQUIRED)
+    find_package_and_print_version(hipblaslt REQUIRED)
+    find_package_and_print_version(miopen REQUIRED)
+    find_package_and_print_version(hipfft REQUIRED)
+    find_package_and_print_version(hipsparse REQUIRED)
+    find_package_and_print_version(rccl)
+    find_package_and_print_version(rocprim REQUIRED)
+    find_package_and_print_version(hipcub REQUIRED)
+    find_package_and_print_version(rocthrust REQUIRED)
+    find_package_and_print_version(hipsolver REQUIRED)
+  else() # USE_ZOOM
+    find_package_and_print_version(hip REQUIRED)
+    find_package_and_print_version(hsa-runtime64 REQUIRED)
+    find_package_and_print_version(amd_comgr REQUIRED)
+    find_package_and_print_version(rocrand REQUIRED)
+    find_package_and_print_version(hiprand REQUIRED)
+    find_package_and_print_version(miopen REQUIRED)
+    find_package_and_print_version(rocprim REQUIRED)
+    find_package_and_print_version(hipcub REQUIRED)
+    find_package_and_print_version(rocthrust REQUIRED)
+  endif()
 
   find_library(PYTORCH_HIP_LIBRARIES amdhip64 HINTS ${ROCM_PATH}/lib)

From 2ca34c835a82d7a06cc36112f7555f1bfbfcdf35 Mon Sep 17 00:00:00 2001
From: 123epsilon
Date: Tue, 4 Feb 2025 02:31:49 +0000
Subject: [PATCH 12/23] add sudo to build whl

---
 .github/workflows/build_zoom_backend.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git
a/.github/workflows/build_zoom_backend.yml b/.github/workflows/build_zoom_backend.yml index aa7053cafe8379..724b3f0e23e960 100644 --- a/.github/workflows/build_zoom_backend.yml +++ b/.github/workflows/build_zoom_backend.yml @@ -79,7 +79,7 @@ jobs: python -m venv venv source venv/bin/activate pip install -r requirements.txt - ./build.sh + sudo ./build.sh - name: "Audit" id: audit From 5d099e92698938775da93219124e553c256113b9 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Tue, 4 Feb 2025 22:27:14 +0000 Subject: [PATCH 13/23] chmod build script --- .github/workflows/build_zoom_backend.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_zoom_backend.yml b/.github/workflows/build_zoom_backend.yml index 724b3f0e23e960..3afa5e97be74ee 100644 --- a/.github/workflows/build_zoom_backend.yml +++ b/.github/workflows/build_zoom_backend.yml @@ -79,7 +79,8 @@ jobs: python -m venv venv source venv/bin/activate pip install -r requirements.txt - sudo ./build.sh + chmod +x ./build.sh + ./build.sh - name: "Audit" id: audit From 51f643283286ba60ee0269734391d98959844689 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Wed, 19 Feb 2025 00:49:36 +0000 Subject: [PATCH 14/23] CI checkout recursive --- .github/workflows/build_zoom_backend.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_zoom_backend.yml b/.github/workflows/build_zoom_backend.yml index 3afa5e97be74ee..5e8e27a7926805 100644 --- a/.github/workflows/build_zoom_backend.yml +++ b/.github/workflows/build_zoom_backend.yml @@ -58,7 +58,7 @@ jobs: - name: "Check out repository" uses: actions/checkout@v4.2.2 with: - submodules: true + submodules: recursive - name: Enable cache uses: actions/cache/restore@v3 From 6c373c534c5aa00c8ffb70bf9e77030ee68d5d1d Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Wed, 19 Feb 2025 02:28:26 +0000 Subject: [PATCH 15/23] clang-19 compat in intrusive_ptr --- c10/util/intrusive_ptr.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 035f22e3c1867b..8f50e91d8295cd 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -379,7 +379,7 @@ class intrusive_ptr final { intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept { // NOLINTNEXTLINE(*assign*) - return operator= (std::move(rhs)); + return this->template operator= (std::move(rhs)); } template @@ -397,7 +397,7 @@ class intrusive_ptr final { // NOLINTNEXTLINE(bugprone-unhandled-self-assignment) intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept { // NOLINTNEXTLINE(*assign-operator, *assignment-signature) - return operator= (rhs); + return this->template operator= (rhs); } template @@ -769,7 +769,7 @@ class weak_intrusive_ptr final { weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept { // NOLINTNEXTLINE(*assign*) - return operator= (std::move(rhs)); + return this->template operator= (std::move(rhs)); } template @@ -788,7 +788,7 @@ class weak_intrusive_ptr final { return *this; } // NOLINTNEXTLINE(*assign*) - return operator= (rhs); + return this->template operator= (rhs); } weak_intrusive_ptr& operator=( From 12f62e26c1008bc5adb74ee470d8a2f785173cff Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Wed, 19 Feb 2025 02:47:30 +0000 Subject: [PATCH 16/23] add venv to audit build step --- .github/workflows/build_zoom_backend.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build_zoom_backend.yml b/.github/workflows/build_zoom_backend.yml index 
5e8e27a7926805..8550d088aeb633 100644 --- a/.github/workflows/build_zoom_backend.yml +++ b/.github/workflows/build_zoom_backend.yml @@ -87,6 +87,7 @@ jobs: run: | sudo apt install patchelf + python -m venv venv source venv/bin/activate pip install auditwheel auditwheel repair -w dist --plat manylinux_2_39_x86_64 dist/torch* From a20c49366edd23bc8cb0c1ee1a63e68866f6ef67 Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Tue, 18 Mar 2025 21:03:20 +0000 Subject: [PATCH 17/23] add more kernels for autograd examples --- aten/src/ATen/native/Activation.cpp | 23 + aten/src/ATen/native/native_functions.yaml | 224 +++++- aten/src/ATen/native/zoom/Activation.cpp | 108 +++ aten/src/ATen/native/zoom/Activation.h | 20 + .../ATen/native/zoom/ActivationEluKernel.cu | 86 +++ .../ATen/native/zoom/ActivationGeluKernel.cu | 88 +++ .../ATen/native/zoom/ActivationGluKernel.cu | 141 ++++ .../native/zoom/ActivationHardshrinkKernel.cu | 39 + .../zoom/ActivationHardsigmoidKernel.cu | 74 ++ .../native/zoom/ActivationHardswishKernel.cu | 63 ++ .../native/zoom/ActivationHardtanhKernel.cu | 45 ++ .../native/zoom/ActivationLeakyReluKernel.cu | 62 ++ .../native/zoom/ActivationLogSigmoidKernel.cu | 64 ++ .../ATen/native/zoom/ActivationMishKernel.cu | 64 ++ .../ATen/native/zoom/ActivationPreluKernel.cu | 48 ++ .../ATen/native/zoom/ActivationSiluKernel.cu | 60 ++ .../native/zoom/ActivationSoftplusKernel.cu | 74 ++ .../native/zoom/ActivationSoftshrinkKernel.cu | 58 ++ .../native/zoom/ActivationThresholdKernel.cu | 52 ++ .../ATen/native/zoom/ForeachBinaryOpList.cu | 295 ++++++++ .../ATen/native/zoom/ForeachBinaryOpScalar.cu | 247 +++++++ .../native/zoom/ForeachBinaryOpScalarList.cu | 241 +++++++ .../zoom/ForeachBinaryOpScalarTensor.cu | 206 ++++++ aten/src/ATen/native/zoom/ForeachFunctors.cuh | 681 ++++++++++++++++++ .../native/zoom/ForeachMinMaxFunctors.cuh | 22 + .../ATen/native/zoom/ForeachPointwiseOp.cu | 272 +++++++ aten/src/ATen/native/zoom/ForeachReduceOp.cu | 352 +++++++++ aten/src/ATen/native/zoom/ForeachTernaryOp.cu | 159 ++++ aten/src/ATen/native/zoom/ForeachUnaryOp.cu | 408 +++++++++++ aten/src/ATen/native/zoom/Loss.cu | 627 ++++++++++++++++ .../src/ATen/native/zoom/MultiTensorApply.cuh | 379 ++++++++++ aten/src/ATen/native/zoom/NLLLoss2d.cu | 537 ++++++++++++++ aten/src/ATen/native/zoom/Pow.cuh | 58 ++ aten/src/ATen/native/zoom/RecordStream.cu | 17 + aten/src/ATen/native/zoom/RreluWithNoise.cu | 195 +++++ 35 files changed, 6057 insertions(+), 32 deletions(-) create mode 100644 aten/src/ATen/native/zoom/Activation.cpp create mode 100644 aten/src/ATen/native/zoom/Activation.h create mode 100644 aten/src/ATen/native/zoom/ActivationEluKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationGeluKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationGluKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationHardshrinkKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationHardsigmoidKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationHardswishKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationHardtanhKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationLeakyReluKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationLogSigmoidKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationMishKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationPreluKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationSiluKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationSoftplusKernel.cu create mode 100644 
aten/src/ATen/native/zoom/ActivationSoftshrinkKernel.cu create mode 100644 aten/src/ATen/native/zoom/ActivationThresholdKernel.cu create mode 100644 aten/src/ATen/native/zoom/ForeachBinaryOpList.cu create mode 100644 aten/src/ATen/native/zoom/ForeachBinaryOpScalar.cu create mode 100644 aten/src/ATen/native/zoom/ForeachBinaryOpScalarList.cu create mode 100644 aten/src/ATen/native/zoom/ForeachBinaryOpScalarTensor.cu create mode 100644 aten/src/ATen/native/zoom/ForeachFunctors.cuh create mode 100644 aten/src/ATen/native/zoom/ForeachMinMaxFunctors.cuh create mode 100644 aten/src/ATen/native/zoom/ForeachPointwiseOp.cu create mode 100644 aten/src/ATen/native/zoom/ForeachReduceOp.cu create mode 100644 aten/src/ATen/native/zoom/ForeachTernaryOp.cu create mode 100644 aten/src/ATen/native/zoom/ForeachUnaryOp.cu create mode 100644 aten/src/ATen/native/zoom/Loss.cu create mode 100644 aten/src/ATen/native/zoom/MultiTensorApply.cuh create mode 100644 aten/src/ATen/native/zoom/NLLLoss2d.cu create mode 100644 aten/src/ATen/native/zoom/Pow.cuh create mode 100644 aten/src/ATen/native/zoom/RecordStream.cu create mode 100644 aten/src/ATen/native/zoom/RreluWithNoise.cu diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index a0141f974923e6..be525a961d9d6c 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -787,6 +787,18 @@ Tensor log_sigmoid_backward_cuda(const Tensor& grad_output, const Tensor& input, return iter.output(); } +Tensor log_sigmoid_backward_zoom(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) { + auto grad_input = at::empty_like(grad_output); + // NOTE: buffer is only used by CPU dispatch, we just ignore it here + auto iter = at::TensorIteratorConfig() + .add_output(grad_input) + .add_const_input(input) + .add_const_input(grad_output) + .build(); + log_sigmoid_backward_stub(kPrivateUse1, iter); + return iter.output(); +} + Tensor log_sigmoid_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) { auto grad_input = at::empty_like(grad_output); auto iter = at::TensorIteratorConfig() @@ -810,6 +822,17 @@ Tensor& log_sigmoid_backward_cuda_out(const Tensor& grad_output, const Tensor& i return grad_input; } +Tensor& log_sigmoid_backward_zoom_out(const Tensor& grad_output, const Tensor& input, + const Tensor& buffer, Tensor& grad_input) { +auto iter = TensorIteratorConfig() +.add_output(grad_input) +.add_const_input(input) +.add_const_input(grad_output) +.build(); +log_sigmoid_backward_stub(kPrivateUse1, iter); +return grad_input; +} + Tensor& log_sigmoid_backward_cpu_out(const Tensor& grad_output, const Tensor& input, const Tensor& buffer, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5af124fc7703fc..fd33884a40b15a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1135,6 +1135,7 @@ dispatch: CPU: binary_cross_entropy_cpu CUDA: binary_cross_entropy_cuda + PrivateUse1: binary_cross_entropy_zoom MPS: binary_cross_entropy_mps - func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) @@ -1144,6 +1145,7 @@ dispatch: CPU: binary_cross_entropy_out_cpu CUDA: binary_cross_entropy_out_cuda + PrivateUse1: binary_cross_entropy_out_zoom MPS: binary_cross_entropy_out_mps - func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean) -> Tensor @@ -1152,6 +1154,7 @@ dispatch: CPU: binary_cross_entropy_backward_cpu CUDA: binary_cross_entropy_backward_cuda + PrivateUse1: binary_cross_entropy_backward_zoom MPS: binary_cross_entropy_backward_mps - func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -1160,6 +1163,7 @@ dispatch: CPU: binary_cross_entropy_backward_out_cpu CUDA: binary_cross_entropy_backward_out_cuda + PrivateUse1: binary_cross_entropy_backward_out_zoom MPS: binary_cross_entropy_backward_out_mps - func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor @@ -4995,7 +4999,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: relu + CPU, CUDA, PrivateUse1: relu MPS: relu_mps MkldnnCPU: mkldnn_relu QuantizedCPU: relu_quantized_cpu @@ -5009,7 +5013,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: relu_ + CPU, CUDA, PrivateUse1: relu_ MPS: relu_mps_ MkldnnCPU: mkldnn_relu_ QuantizedCPU: relu_quantized_cpu_ @@ -5032,14 +5036,14 @@ - func: _prelu_kernel(Tensor self, Tensor weight) -> Tensor dispatch: - CPU, CUDA: _prelu_kernel + CPU, CUDA, PrivateUse1: _prelu_kernel QuantizedCPU: _prelu_kernel_quantized_cpu MkldnnCPU: mkldnn_prelu MPS: prelu_mps - func: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) dispatch: - CPU, CUDA: _prelu_kernel_backward + CPU, CUDA, PrivateUse1: _prelu_kernel_backward MkldnnCPU: mkldnn_prelu_backward MPS: prelu_backward_mps @@ -5051,6 +5055,7 @@ dispatch: CPU: gelu_out_cpu CUDA: gelu_out_cuda + PrivateUse1: gelu_out_zoom MPS: gelu_out_mps - func: gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!) 
@@ -5079,6 +5084,7 @@ dispatch: CPU: gelu_backward_out_cpu CUDA: gelu_backward_out_cuda + PrivateUse1: gelu_backward_out_zoom MPS: gelu_backward_out_mps - func: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor @@ -5100,7 +5106,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: hardshrink_out + CPU, CUDA, PrivateUse1: hardshrink_out - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor structured_delegate: hardshrink.out @@ -5111,7 +5117,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: hardshrink_backward_out + CPU, CUDA, PrivateUse1: hardshrink_backward_out - func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor structured_delegate: hardshrink_backward.grad_input @@ -5204,7 +5210,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: silu_out + CPU, CUDA, PrivateUse1: silu_out MPS: silu_out_mps tags: pointwise @@ -5213,7 +5219,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: silu_backward_out + CPU, CUDA, PrivateUse1: silu_backward_out MPS: silu_backward_out_mps tags: pointwise @@ -5238,13 +5244,13 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: mish_out + CPU, CUDA, PrivateUse1: mish_out MPS: mish_out_mps - func: mish_backward(Tensor grad_output, Tensor self) -> Tensor python_module: nn dispatch: - CPU, CUDA: mish_backward + CPU, CUDA, PrivateUse1: mish_backward MPS: mish_backward_mps CompositeImplicitAutograd: math_mish_backward @@ -6040,14 +6046,14 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: threshold_out + CPU, CUDA, PrivateUse1: threshold_out MPS: threshold_out_mps - func: threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) 
structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: threshold_backward_out + CPU, CUDA, PrivateUse1: threshold_backward_out MPS: threshold_backward_out_mps SparseCPU, SparseCUDA: threshold_backward_sparse_out SparseCsrCPU, SparseCsrCUDA: threshold_backward_sparse_compressed_out @@ -10367,6 +10373,7 @@ dispatch: CPU: foreach_tensor_add_scalar_kernel_slow CUDA: foreach_tensor_add_scalar_kernel_cuda + PrivateUse1: foreach_tensor_add_scalar_kernel_zoom - func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10374,6 +10381,7 @@ dispatch: CPU: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_add_scalar_kernel_zoom_ autogen: _foreach_add.Scalar_out - func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] @@ -10382,6 +10390,7 @@ dispatch: CPU: foreach_tensor_add_list_kernel_slow CUDA: foreach_tensor_add_list_kernel_cuda + PrivateUse1: foreach_tensor_add_list_kernel_zoom - func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10389,6 +10398,7 @@ dispatch: CPU: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ + PrivateUse1: foreach_tensor_add_list_kernel_zoom_ autogen: _foreach_add.List_out - func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10397,6 +10407,7 @@ dispatch: CPU: foreach_tensor_add_scalarlist_kernel_slow CUDA: foreach_tensor_add_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_add_scalarlist_kernel_zoom - func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10404,6 +10415,7 @@ dispatch: CPU: foreach_tensor_add_scalarlist_kernel_slow_ CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_add_scalarlist_kernel_zoom_ autogen: _foreach_add.ScalarList_out - func: _foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[] @@ -10412,6 +10424,7 @@ dispatch: CPU: foreach_tensor_add_tensor_kernel_slow CUDA: foreach_tensor_add_tensor_kernel_cuda + PrivateUse1: foreach_tensor_add_tensor_kernel_zoom - func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10419,6 +10432,7 @@ dispatch: CPU: foreach_tensor_add_tensor_kernel_slow_ CUDA: foreach_tensor_add_tensor_kernel_cuda_ + PrivateUse1: foreach_tensor_add_tensor_kernel_zoom_ autogen: _foreach_add.Tensor_out - func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10427,6 +10441,7 @@ dispatch: CPU: foreach_tensor_sub_scalar_kernel_slow CUDA: foreach_tensor_sub_scalar_kernel_cuda + PrivateUse1: foreach_tensor_sub_scalar_kernel_zoom - func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10434,6 +10449,7 @@ dispatch: CPU: foreach_tensor_sub_scalar_kernel_slow_ CUDA: foreach_tensor_sub_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_sub_scalar_kernel_zoom_ autogen: _foreach_sub.Scalar_out - func: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] @@ -10442,6 
+10458,7 @@ dispatch: CPU: foreach_tensor_sub_list_kernel_slow CUDA: foreach_tensor_sub_list_kernel_cuda + PrivateUse1: foreach_tensor_sub_list_kernel_zoom - func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10449,6 +10466,7 @@ dispatch: CPU: foreach_tensor_sub_list_kernel_slow_ CUDA: foreach_tensor_sub_list_kernel_cuda_ + PrivateUse1: foreach_tensor_sub_list_kernel_zoom_ autogen: _foreach_sub.List_out - func: _foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10457,6 +10475,7 @@ dispatch: CPU: foreach_tensor_sub_scalarlist_kernel_slow CUDA: foreach_tensor_sub_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_sub_scalarlist_kernel_zoom - func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10464,6 +10483,7 @@ dispatch: CPU: foreach_tensor_sub_scalarlist_kernel_slow_ CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_sub_scalarlist_kernel_zoom_ autogen: _foreach_sub.ScalarList_out - func: _foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10472,6 +10492,7 @@ dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow CUDA: foreach_tensor_mul_scalar_kernel_cuda + PrivateUse1: foreach_tensor_mul_scalar_kernel_zoom - func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10479,6 +10500,7 @@ dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_scalar_kernel_zoom_ autogen: _foreach_mul.Scalar_out - func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10487,6 +10509,7 @@ dispatch: CPU: foreach_tensor_mul_list_kernel_slow CUDA: foreach_tensor_mul_list_kernel_cuda + PrivateUse1: foreach_tensor_mul_list_kernel_zoom - func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10494,6 +10517,7 @@ dispatch: CPU: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_list_kernel_zoom_ autogen: _foreach_mul.List_out - func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10502,6 +10526,7 @@ dispatch: CPU: foreach_tensor_mul_scalarlist_kernel_slow CUDA: foreach_tensor_mul_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_mul_scalarlist_kernel_zoom - func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10509,6 +10534,7 @@ dispatch: CPU: foreach_tensor_mul_scalarlist_kernel_slow_ CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_scalarlist_kernel_zoom_ autogen: _foreach_mul.ScalarList_out - func: _foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[] @@ -10517,6 +10543,7 @@ dispatch: CPU: foreach_tensor_mul_tensor_kernel_slow CUDA: foreach_tensor_mul_tensor_kernel_cuda + PrivateUse1: foreach_tensor_mul_tensor_kernel_zoom - func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10524,6 +10551,7 @@ dispatch: CPU: 
foreach_tensor_mul_tensor_kernel_slow_ CUDA: foreach_tensor_mul_tensor_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_tensor_kernel_zoom_ autogen: _foreach_mul.Tensor_out - func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10532,6 +10560,7 @@ dispatch: CPU: foreach_tensor_div_scalar_kernel_slow CUDA: foreach_tensor_div_scalar_kernel_cuda + PrivateUse1: foreach_tensor_div_scalar_kernel_zoom - func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10539,6 +10568,7 @@ dispatch: CPU: foreach_tensor_div_scalar_kernel_slow_ CUDA: foreach_tensor_div_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_div_scalar_kernel_zoom_ autogen: _foreach_div.Scalar_out - func: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10547,6 +10577,7 @@ dispatch: CPU: foreach_tensor_div_list_kernel_slow CUDA: foreach_tensor_div_list_kernel_cuda + PrivateUse1: foreach_tensor_div_list_kernel_zoom - func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10554,6 +10585,7 @@ dispatch: CPU: foreach_tensor_div_list_kernel_slow_ CUDA: foreach_tensor_div_list_kernel_cuda_ + PrivateUse1: foreach_tensor_div_list_kernel_zoom_ autogen: _foreach_div.List_out - func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10562,6 +10594,7 @@ dispatch: CPU: foreach_tensor_div_scalarlist_kernel_slow CUDA: foreach_tensor_div_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_div_scalarlist_kernel_zoom - func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10569,6 +10602,7 @@ dispatch: CPU: foreach_tensor_div_scalarlist_kernel_slow_ CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_div_scalarlist_kernel_zoom_ autogen: _foreach_div.ScalarList_out - func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[] @@ -10577,6 +10611,7 @@ dispatch: CPU: foreach_tensor_div_tensor_kernel_slow CUDA: foreach_tensor_div_tensor_kernel_cuda + PrivateUse1: foreach_tensor_div_tensor_kernel_zoom - func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10584,6 +10619,7 @@ dispatch: CPU: foreach_tensor_div_tensor_kernel_slow_ CUDA: foreach_tensor_div_tensor_kernel_cuda_ + PrivateUse1: foreach_tensor_div_tensor_kernel_zoom_ autogen: _foreach_div.Tensor_out - func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10592,6 +10628,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_scalar_kernel_zoom - func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10599,6 +10636,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalar_kernel_zoom_ autogen: _foreach_clamp_max.Scalar_out - func: _foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10607,6 +10645,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow CUDA: 
foreach_tensor_clamp_max_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom - func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10614,6 +10653,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom_ autogen: _foreach_clamp_max.List_out - func: _foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10622,6 +10662,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom - func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10629,6 +10670,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom_ autogen: _foreach_clamp_max.ScalarList_out - func: _foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10637,6 +10679,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom - func: _foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10644,6 +10687,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom_ autogen: _foreach_clamp_min.Scalar_out - func: _foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10652,6 +10696,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom - func: _foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10659,6 +10704,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom_ autogen: _foreach_clamp_min.List_out - func: _foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10667,6 +10713,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom - func: _foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10674,6 +10721,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom_ autogen: _foreach_clamp_min.ScalarList_out # foreach_minimum/maximum dispatches to clamp_max/min @@ -10683,6 +10731,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom - func: 
_foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10690,6 +10739,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom_ autogen: _foreach_maximum.Scalar_out # foreach_minimum/maximum dispatches to clamp_max/min @@ -10699,6 +10749,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom - func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10706,6 +10757,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom_ autogen: _foreach_maximum.List_out # foreach_minimum/maximum dispatches to clamp_max/min @@ -10715,6 +10767,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom - func: _foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10722,6 +10775,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom_ autogen: _foreach_maximum.ScalarList_out - func: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10730,6 +10784,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_scalar_kernel_zoom - func: _foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10737,6 +10792,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalar_kernel_zoom_ autogen: _foreach_minimum.Scalar_out - func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10745,6 +10801,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow CUDA: foreach_tensor_clamp_max_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom - func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10752,6 +10809,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom_ autogen: _foreach_minimum.List_out - func: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10760,6 +10818,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom - func: _foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10767,6 +10826,7 @@ dispatch: CPU: 
foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom_ autogen: _foreach_minimum.ScalarList_out - func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] @@ -10775,6 +10835,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalar_slow CUDA: foreach_tensor_addcdiv_scalar_cuda + PrivateUse1: foreach_tensor_addcdiv_scalar_zoom - func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10782,6 +10843,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalarlist_slow CUDA: foreach_tensor_addcdiv_scalarlist_cuda + PrivateUse1: foreach_tensor_addcdiv_scalarlist_zoom - func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10789,6 +10851,7 @@ dispatch: CPU: foreach_tensor_addcdiv_tensor_slow CUDA: foreach_tensor_addcdiv_tensor_cuda + PrivateUse1: foreach_tensor_addcdiv_tensor_zoom - func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10796,6 +10859,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalar_slow_ CUDA: foreach_tensor_addcdiv_scalar_cuda_ + PrivateUse1: foreach_tensor_addcdiv_scalar_zoom_ autogen: _foreach_addcdiv.Scalar_out - func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () @@ -10804,6 +10868,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalarlist_slow_ CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ + PrivateUse1: foreach_tensor_addcdiv_scalarlist_zoom_ autogen: _foreach_addcdiv.ScalarList_out - func: _foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () @@ -10812,6 +10877,7 @@ dispatch: CPU: foreach_tensor_addcdiv_tensor_slow_ CUDA: foreach_tensor_addcdiv_tensor_cuda_ + PrivateUse1: foreach_tensor_addcdiv_tensor_zoom_ autogen: _foreach_addcdiv.Tensor_out - func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] @@ -10820,6 +10886,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalar_slow CUDA: foreach_tensor_addcmul_scalar_cuda + PrivateUse1: foreach_tensor_addcmul_scalar_zoom - func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10827,6 +10894,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalarlist_slow CUDA: foreach_tensor_addcmul_scalarlist_cuda + PrivateUse1: foreach_tensor_addcmul_scalarlist_zoom - func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10834,6 +10902,7 @@ dispatch: CPU: foreach_tensor_addcmul_tensor_slow CUDA: foreach_tensor_addcmul_tensor_cuda + PrivateUse1: foreach_tensor_addcmul_tensor_zoom - func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor 
are on different devices @@ -10841,6 +10910,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ + PrivateUse1: foreach_tensor_addcmul_scalar_zoom_ autogen: _foreach_addcmul.Scalar_out - func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () @@ -10849,6 +10919,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalarlist_slow_ CUDA: foreach_tensor_addcmul_scalarlist_cuda_ + PrivateUse1: foreach_tensor_addcmul_scalarlist_zoom_ autogen: _foreach_addcmul.ScalarList_out - func: _foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () @@ -10857,6 +10928,7 @@ dispatch: CPU: foreach_tensor_addcmul_tensor_slow_ CUDA: foreach_tensor_addcmul_tensor_cuda_ + PrivateUse1: foreach_tensor_addcmul_tensor_zoom_ autogen: _foreach_addcmul.Tensor_out - func: _foreach_abs(Tensor[] self) -> Tensor[] @@ -10865,6 +10937,7 @@ dispatch: CPU: foreach_tensor_abs_slow CUDA: foreach_tensor_abs_cuda + PrivateUse1: foreach_tensor_abs_zoom - func: _foreach_abs_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10872,6 +10945,7 @@ dispatch: CPU: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ + PrivateUse1: foreach_tensor_abs_zoom_ autogen: _foreach_abs.out - func: _foreach_acos(Tensor[] self) -> Tensor[] @@ -10880,6 +10954,7 @@ dispatch: CPU: foreach_tensor_acos_slow CUDA: foreach_tensor_acos_cuda + PrivateUse1: foreach_tensor_acos_zoom - func: _foreach_acos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10887,6 +10962,7 @@ dispatch: CPU: foreach_tensor_acos_slow_ CUDA: foreach_tensor_acos_cuda_ + PrivateUse1: foreach_tensor_acos_zoom_ autogen: _foreach_acos.out - func: _foreach_asin(Tensor[] self) -> Tensor[] @@ -10895,6 +10971,7 @@ dispatch: CPU: foreach_tensor_asin_slow CUDA: foreach_tensor_asin_cuda + PrivateUse1: foreach_tensor_asin_zoom - func: _foreach_asin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10902,6 +10979,7 @@ dispatch: CPU: foreach_tensor_asin_slow_ CUDA: foreach_tensor_asin_cuda_ + PrivateUse1: foreach_tensor_asin_zoom_ autogen: _foreach_asin.out - func: _foreach_atan(Tensor[] self) -> Tensor[] @@ -10910,6 +10988,7 @@ dispatch: CPU: foreach_tensor_atan_slow CUDA: foreach_tensor_atan_cuda + PrivateUse1: foreach_tensor_atan_zoom - func: _foreach_atan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10917,6 +10996,7 @@ dispatch: CPU: foreach_tensor_atan_slow_ CUDA: foreach_tensor_atan_cuda_ + PrivateUse1: foreach_tensor_atan_zoom_ autogen: _foreach_atan.out - func: _foreach_ceil(Tensor[] self) -> Tensor[] @@ -10925,6 +11005,7 @@ dispatch: CPU: foreach_tensor_ceil_slow CUDA: foreach_tensor_ceil_cuda + PrivateUse1: foreach_tensor_ceil_zoom - func: _foreach_ceil_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10932,6 +11013,7 @@ dispatch: CPU: foreach_tensor_ceil_slow_ CUDA: foreach_tensor_ceil_cuda_ + PrivateUse1: foreach_tensor_ceil_zoom_ autogen: _foreach_ceil.out - func: _foreach_cos(Tensor[] self) -> Tensor[] @@ -10940,6 +11022,7 @@ dispatch: CPU: foreach_tensor_cos_slow CUDA: foreach_tensor_cos_cuda + PrivateUse1: foreach_tensor_cos_zoom - func: 
_foreach_cos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10947,6 +11030,7 @@ dispatch: CPU: foreach_tensor_cos_slow_ CUDA: foreach_tensor_cos_cuda_ + PrivateUse1: foreach_tensor_cos_zoom_ autogen: _foreach_cos.out - func: _foreach_cosh(Tensor[] self) -> Tensor[] @@ -10955,6 +11039,7 @@ dispatch: CPU: foreach_tensor_cosh_slow CUDA: foreach_tensor_cosh_cuda + PrivateUse1: foreach_tensor_cosh_zoom - func: _foreach_cosh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10962,6 +11047,7 @@ dispatch: CPU: foreach_tensor_cosh_slow_ CUDA: foreach_tensor_cosh_cuda_ + PrivateUse1: foreach_tensor_cosh_zoom_ autogen: _foreach_cosh.out - func: _foreach_erf(Tensor[] self) -> Tensor[] @@ -10970,6 +11056,7 @@ dispatch: CPU: foreach_tensor_erf_slow CUDA: foreach_tensor_erf_cuda + PrivateUse1: foreach_tensor_erf_zoom - func: _foreach_erf_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10977,6 +11064,7 @@ dispatch: CPU: foreach_tensor_erf_slow_ CUDA: foreach_tensor_erf_cuda_ + PrivateUse1: foreach_tensor_erf_zoom_ autogen: _foreach_erf.out - func: _foreach_erfc(Tensor[] self) -> Tensor[] @@ -10985,6 +11073,7 @@ dispatch: CPU: foreach_tensor_erfc_slow CUDA: foreach_tensor_erfc_cuda + PrivateUse1: foreach_tensor_erfc_zoom - func: _foreach_erfc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10992,6 +11081,7 @@ dispatch: CPU: foreach_tensor_erfc_slow_ CUDA: foreach_tensor_erfc_cuda_ + PrivateUse1: foreach_tensor_erfc_zoom_ autogen: _foreach_erfc.out - func: _foreach_exp(Tensor[] self) -> Tensor[] @@ -11000,6 +11090,7 @@ dispatch: CPU: foreach_tensor_exp_slow CUDA: foreach_tensor_exp_cuda + PrivateUse1: foreach_tensor_exp_zoom - func: _foreach_exp_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11007,6 +11098,7 @@ dispatch: CPU: foreach_tensor_exp_slow_ CUDA: foreach_tensor_exp_cuda_ + PrivateUse1: foreach_tensor_exp_zoom_ autogen: _foreach_exp.out - func: _foreach_expm1(Tensor[] self) -> Tensor[] @@ -11015,6 +11107,7 @@ dispatch: CPU: foreach_tensor_expm1_slow CUDA: foreach_tensor_expm1_cuda + PrivateUse1: foreach_tensor_expm1_zoom - func: _foreach_expm1_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11022,6 +11115,7 @@ dispatch: CPU: foreach_tensor_expm1_slow_ CUDA: foreach_tensor_expm1_cuda_ + PrivateUse1: foreach_tensor_expm1_zoom_ autogen: _foreach_expm1.out - func: _foreach_floor(Tensor[] self) -> Tensor[] @@ -11030,6 +11124,7 @@ dispatch: CPU: foreach_tensor_floor_slow CUDA: foreach_tensor_floor_cuda + PrivateUse1: foreach_tensor_floor_zoom - func: _foreach_floor_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11037,6 +11132,7 @@ dispatch: CPU: foreach_tensor_floor_slow_ CUDA: foreach_tensor_floor_cuda_ + PrivateUse1: foreach_tensor_floor_zoom_ autogen: _foreach_floor.out - func: _foreach_frac(Tensor[] self) -> Tensor[] @@ -11045,6 +11141,7 @@ dispatch: CPU: foreach_tensor_frac_slow CUDA: foreach_tensor_frac_cuda + PrivateUse1: foreach_tensor_frac_zoom - func: _foreach_frac_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels 
fall back to slow path when tensor are on different devices @@ -11052,6 +11149,7 @@ dispatch: CPU: foreach_tensor_frac_slow_ CUDA: foreach_tensor_frac_cuda_ + PrivateUse1: foreach_tensor_frac_zoom_ autogen: _foreach_frac.out - func: _foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[] @@ -11060,6 +11158,7 @@ dispatch: CPU: foreach_tensor_ternary_lerp_slow CUDA: foreach_tensor_lerp_ternary_cuda + PrivateUse1: foreach_tensor_lerp_ternary_zoom autogen: _foreach_lerp.List_out - func: _foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> () @@ -11068,6 +11167,7 @@ dispatch: CPU: foreach_tensor_ternary_lerp_slow_ CUDA: foreach_tensor_lerp_ternary_cuda_ + PrivateUse1: foreach_tensor_lerp_ternary_zoom_ autogen: _foreach_lerp.List_out - func: _foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[] @@ -11076,6 +11176,7 @@ dispatch: CPU: foreach_tensor_lerp_list_kernel_slow CUDA: foreach_tensor_lerp_list_cuda + PrivateUse1: foreach_tensor_lerp_list_zoom autogen: _foreach_lerp.Scalar_out - func: _foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> () @@ -11084,6 +11185,7 @@ dispatch: CPU: foreach_tensor_lerp_list_kernel_slow_ CUDA: foreach_tensor_lerp_list_cuda_ + PrivateUse1: foreach_tensor_lerp_list_zoom_ autogen: _foreach_lerp.Scalar_out - func: _foreach_lgamma(Tensor[] self) -> Tensor[] @@ -11092,6 +11194,7 @@ dispatch: CPU: foreach_tensor_lgamma_slow CUDA: foreach_tensor_lgamma_cuda + PrivateUse1: foreach_tensor_lgamma_zoom - func: _foreach_lgamma_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11099,6 +11202,7 @@ dispatch: CPU: foreach_tensor_lgamma_slow_ CUDA: foreach_tensor_lgamma_cuda_ + PrivateUse1: foreach_tensor_lgamma_zoom_ autogen: _foreach_lgamma.out - func: _foreach_log(Tensor[] self) -> Tensor[] @@ -11107,6 +11211,7 @@ dispatch: CPU: foreach_tensor_log_slow CUDA: foreach_tensor_log_cuda + PrivateUse1: foreach_tensor_log_zoom - func: _foreach_log_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11114,6 +11219,7 @@ dispatch: CPU: foreach_tensor_log_slow_ CUDA: foreach_tensor_log_cuda_ + PrivateUse1: foreach_tensor_log_zoom_ autogen: _foreach_log.out - func: _foreach_log10(Tensor[] self) -> Tensor[] @@ -11122,6 +11228,7 @@ dispatch: CPU: foreach_tensor_log10_slow CUDA: foreach_tensor_log10_cuda + PrivateUse1: foreach_tensor_log10_zoom - func: _foreach_log10_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11129,6 +11236,7 @@ dispatch: CPU: foreach_tensor_log10_slow_ CUDA: foreach_tensor_log10_cuda_ + PrivateUse1: foreach_tensor_log10_zoom_ autogen: _foreach_log10.out - func: _foreach_log1p(Tensor[] self) -> Tensor[] @@ -11137,6 +11245,7 @@ dispatch: CPU: foreach_tensor_log1p_slow CUDA: foreach_tensor_log1p_cuda + PrivateUse1: foreach_tensor_log1p_zoom - func: _foreach_log1p_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11144,6 +11253,7 @@ dispatch: CPU: foreach_tensor_log1p_slow_ CUDA: foreach_tensor_log1p_cuda_ + PrivateUse1: foreach_tensor_log1p_zoom_ autogen: _foreach_log1p.out - func: _foreach_log2(Tensor[] self) -> Tensor[] @@ -11152,6 +11262,7 @@ dispatch: CPU: foreach_tensor_log2_slow CUDA: foreach_tensor_log2_cuda + PrivateUse1: foreach_tensor_log2_zoom 
- func: _foreach_log2_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11159,6 +11270,7 @@ dispatch: CPU: foreach_tensor_log2_slow_ CUDA: foreach_tensor_log2_cuda_ + PrivateUse1: foreach_tensor_log2_zoom_ autogen: _foreach_log2.out - func: _foreach_neg(Tensor[] self) -> Tensor[] @@ -11167,6 +11279,7 @@ dispatch: CPU: foreach_tensor_neg_slow CUDA: foreach_tensor_neg_cuda + PrivateUse1: foreach_tensor_neg_zoom - func: _foreach_neg_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11174,6 +11287,7 @@ dispatch: CPU: foreach_tensor_neg_slow_ CUDA: foreach_tensor_neg_cuda_ + PrivateUse1: foreach_tensor_neg_zoom_ autogen: _foreach_neg.out - func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[] @@ -11182,6 +11296,7 @@ dispatch: CPU: foreach_tensor_norm_slow CUDA: foreach_tensor_norm_cuda + PrivateUse1: foreach_tensor_norm_zoom autogen: _foreach_norm.Scalar_out - func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] @@ -11190,6 +11305,7 @@ dispatch: CPU: foreach_tensor_pow_list_kernel_slow CUDA: foreach_tensor_pow_list_kernel_cuda + PrivateUse1: foreach_tensor_pow_list_kernel_zoom - func: _foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11197,6 +11313,7 @@ dispatch: CPU: foreach_tensor_pow_scalar_kernel_slow CUDA: foreach_tensor_pow_scalar_kernel_cuda + PrivateUse1: foreach_tensor_pow_scalar_kernel_zoom - func: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11204,6 +11321,7 @@ dispatch: CPU: foreach_tensor_pow_scalarlist_kernel_slow CUDA: foreach_tensor_pow_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_pow_scalarlist_kernel_zoom - func: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11211,6 +11329,7 @@ dispatch: CPU: foreach_scalar_pow_list_kernel_slow CUDA: foreach_scalar_pow_list_kernel_cuda + PrivateUse1: foreach_scalar_pow_list_kernel_zoom - func: _foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () device_check: NoCheck @@ -11218,6 +11337,7 @@ dispatch: CPU: foreach_tensor_pow_list_kernel_slow_ CUDA: foreach_tensor_pow_list_kernel_cuda_ + PrivateUse1: foreach_tensor_pow_list_kernel_zoom_ autogen: _foreach_pow.List_out - func: _foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> () @@ -11226,6 +11346,7 @@ dispatch: CPU: foreach_tensor_pow_scalar_kernel_slow_ CUDA: foreach_tensor_pow_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_pow_scalar_kernel_zoom_ autogen: _foreach_pow.Scalar_out - func: _foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> () @@ -11234,6 +11355,7 @@ dispatch: CPU: foreach_tensor_pow_scalarlist_kernel_slow_ CUDA: foreach_tensor_pow_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_pow_scalarlist_kernel_zoom_ autogen: _foreach_pow.ScalarList_out - func: _foreach_reciprocal(Tensor[] self) -> Tensor[] @@ -11242,6 +11364,7 @@ dispatch: CPU: foreach_tensor_reciprocal_slow CUDA: foreach_tensor_reciprocal_cuda + PrivateUse1: foreach_tensor_reciprocal_zoom - func: _foreach_reciprocal_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor 
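For context, a minimal Python smoke test of the PrivateUse1 entries registered above might look like the sketch below. It assumes the built wheel exposes the backend under the device name "zoom" and that importing torch.zoom is enough to register these kernels; both the device string and the import path are assumptions for illustration, not something this patch series guarantees.

    # Illustrative sketch only: the "zoom" device string and the explicit
    # torch.zoom import are assumptions about how the backend registers itself.
    import torch
    import torch.zoom  # provided by torch/zoom/__init__.py in this series

    x = torch.linspace(0, 1, steps=5, device="zoom")  # should hit linspace_zoom_out
    y = torch.arange(0, 10, 2, device="zoom")         # RangeFactories.cu kernels
    zs = torch._foreach_log2([x, x])                  # foreach_tensor_log2_zoom
    print(x.cpu(), y.cpu(), [z.cpu() for z in zs])
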
are on different devices @@ -11249,6 +11372,7 @@ dispatch: CPU: foreach_tensor_reciprocal_slow_ CUDA: foreach_tensor_reciprocal_cuda_ + PrivateUse1: foreach_tensor_reciprocal_zoom_ autogen: _foreach_reciprocal.out - func: _foreach_round(Tensor[] self) -> Tensor[] @@ -11257,6 +11381,7 @@ dispatch: CPU: foreach_tensor_round_slow CUDA: foreach_tensor_round_cuda + PrivateUse1: foreach_tensor_round_zoom - func: _foreach_round_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11264,6 +11389,7 @@ dispatch: CPU: foreach_tensor_round_slow_ CUDA: foreach_tensor_round_cuda_ + PrivateUse1: foreach_tensor_round_zoom_ autogen: _foreach_round.out - func: _foreach_sigmoid(Tensor[] self) -> Tensor[] @@ -11272,6 +11398,7 @@ dispatch: CPU: foreach_tensor_sigmoid_slow CUDA: foreach_tensor_sigmoid_cuda + PrivateUse1: foreach_tensor_sigmoid_zoom - func: _foreach_sigmoid_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11279,6 +11406,7 @@ dispatch: CPU: foreach_tensor_sigmoid_slow_ CUDA: foreach_tensor_sigmoid_cuda_ + PrivateUse1: foreach_tensor_sigmoid_zoom_ autogen: _foreach_sigmoid.out - func: _foreach_sign(Tensor[] self) -> Tensor[] @@ -11287,6 +11415,7 @@ dispatch: CPU: foreach_tensor_sign_slow CUDA: foreach_tensor_sign_cuda + PrivateUse1: foreach_tensor_sign_zoom - func: _foreach_sign_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11294,6 +11423,7 @@ dispatch: CPU: foreach_tensor_sign_slow_ CUDA: foreach_tensor_sign_cuda_ + PrivateUse1: foreach_tensor_sign_zoom_ autogen: _foreach_sign.out - func: _foreach_sin(Tensor[] self) -> Tensor[] @@ -11302,6 +11432,7 @@ dispatch: CPU: foreach_tensor_sin_slow CUDA: foreach_tensor_sin_cuda + PrivateUse1: foreach_tensor_sin_zoom - func: _foreach_sin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11309,6 +11440,7 @@ dispatch: CPU: foreach_tensor_sin_slow_ CUDA: foreach_tensor_sin_cuda_ + PrivateUse1: foreach_tensor_sin_zoom_ autogen: _foreach_sin.out - func: _foreach_sinh(Tensor[] self) -> Tensor[] @@ -11317,6 +11449,7 @@ dispatch: CPU: foreach_tensor_sinh_slow CUDA: foreach_tensor_sinh_cuda + PrivateUse1: foreach_tensor_sinh_zoom - func: _foreach_sinh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11324,6 +11457,7 @@ dispatch: CPU: foreach_tensor_sinh_slow_ CUDA: foreach_tensor_sinh_cuda_ + PrivateUse1: foreach_tensor_sinh_zoom_ autogen: _foreach_sinh.out - func: _foreach_sqrt(Tensor[] self) -> Tensor[] @@ -11332,6 +11466,7 @@ dispatch: CPU: foreach_tensor_sqrt_slow CUDA: foreach_tensor_sqrt_cuda + PrivateUse1: foreach_tensor_sqrt_zoom - func: _foreach_sqrt_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11339,6 +11474,7 @@ dispatch: CPU: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ + PrivateUse1: foreach_tensor_sqrt_zoom_ autogen: _foreach_sqrt.out - func: _foreach_tan(Tensor[] self) -> Tensor[] @@ -11347,6 +11483,7 @@ dispatch: CPU: foreach_tensor_tan_slow CUDA: foreach_tensor_tan_cuda + PrivateUse1: foreach_tensor_tan_zoom - func: _foreach_tan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ 
-11354,6 +11491,7 @@ dispatch: CPU: foreach_tensor_tan_slow_ CUDA: foreach_tensor_tan_cuda_ + PrivateUse1: foreach_tensor_tan_zoom_ autogen: _foreach_tan.out - func: _foreach_tanh(Tensor[] self) -> Tensor[] @@ -11362,6 +11500,7 @@ dispatch: CPU: foreach_tensor_tanh_slow CUDA: foreach_tensor_tanh_cuda + PrivateUse1: foreach_tensor_tanh_zoom - func: _foreach_tanh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11369,6 +11508,7 @@ dispatch: CPU: foreach_tensor_tanh_slow_ CUDA: foreach_tensor_tanh_cuda_ + PrivateUse1: foreach_tensor_tanh_zoom_ autogen: _foreach_tanh.out - func: _foreach_trunc(Tensor[] self) -> Tensor[] @@ -11377,6 +11517,7 @@ dispatch: CPU: foreach_tensor_trunc_slow CUDA: foreach_tensor_trunc_cuda + PrivateUse1: foreach_tensor_trunc_zoom - func: _foreach_trunc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11384,6 +11525,7 @@ dispatch: CPU: foreach_tensor_trunc_slow_ CUDA: foreach_tensor_trunc_cuda_ + PrivateUse1: foreach_tensor_trunc_zoom_ autogen: _foreach_trunc.out - func: _foreach_zero_(Tensor(a!)[] self) -> () @@ -11392,6 +11534,7 @@ dispatch: CPU: foreach_tensor_zero_slow_ CUDA: foreach_tensor_zero_cuda_ + PrivateUse1: foreach_tensor_zero_zoom_ autogen: _foreach_zero, _foreach_zero.out - func: _foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> () @@ -11400,6 +11543,7 @@ dispatch: CPU: foreach_tensor_copy_list_kernel_slow_ CUDA: foreach_tensor_copy_list_kernel_cuda_ + PrivateUse1: foreach_tensor_copy_list_kernel_zoom_ autogen: _foreach_copy.out - func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out @@ -11573,6 +11717,7 @@ dispatch: CPU: nll_loss_forward_out_cpu CUDA: nll_loss_forward_out_cuda + PrivateUse1: nll_loss_forward_out_zoom MPS: nll_loss_forward_out_mps - func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) @@ -11585,6 +11730,7 @@ dispatch: CPU: nll_loss_backward_out_cpu CUDA: nll_loss_backward_out_cuda + PrivateUse1: nll_loss_backward_out_zoom MPS: nll_loss_backward_out_mps - func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor @@ -11604,6 +11750,7 @@ dispatch: CPU: nll_loss2d_forward_out_cpu CUDA: nll_loss2d_forward_out_cuda + PrivateUse1: nll_loss2d_forward_out_zoom MPS: nll_loss2d_forward_out_mps - func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) @@ -11611,6 +11758,7 @@ dispatch: CPU: nll_loss2d_forward_cpu CUDA: nll_loss2d_forward_cuda + PrivateUse1: nll_loss2d_forward_zoom MPS: nll_loss2d_forward_mps - func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -11618,6 +11766,7 @@ dispatch: CPU: nll_loss2d_backward_out_cpu CUDA: nll_loss2d_backward_out_cuda + PrivateUse1: nll_loss2d_backward_out_zoom MPS: nll_loss2d_backward_out_mps - func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? 
weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor @@ -11625,6 +11774,7 @@ dispatch: CPU: nll_loss2d_backward_cpu CUDA: nll_loss2d_backward_cuda + PrivateUse1: nll_loss2d_backward_zoom MPS: nll_loss2d_backward_mps - func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) @@ -11702,7 +11852,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: elu_out + CPU, CUDA, PrivateUse1: elu_out MPS: elu_out_mps - func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor @@ -11715,7 +11865,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: elu_backward_out + CPU, CUDA, PrivateUse1: elu_backward_out MPS: elu_backward_out_mps - func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor @@ -11745,6 +11895,7 @@ dispatch: CPU: glu_backward_cpu_out CUDA: glu_backward_cuda_out + PrivateUse1: glu_backward_zoom_out MPS: glu_backward_mps_out - func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor @@ -11752,18 +11903,19 @@ dispatch: CPU: glu_backward_cpu CUDA: glu_backward_cuda + PrivateUse1: glu_backward_zoom MPS: glu_backward_mps - func: glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor python_module: nn dispatch: - CPU, CUDA: glu_jvp + CPU, CUDA, PrivateUse1: glu_jvp autogen: glu_jvp.out - func: glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor python_module: nn dispatch: - CPU, CUDA: glu_backward_jvp + CPU, CUDA, PrivateUse1: glu_backward_jvp autogen: glu_backward_jvp.out - func: hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -11772,7 +11924,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardsigmoid_out + CPU, CUDA, PrivateUse1: hardsigmoid_out MPS: hardsigmoid_out_mps QuantizedCPU: hardsigmoid_out_quantized_cpu @@ -11793,7 +11945,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: hardsigmoid_backward_out + CPU, CUDA, PrivateUse1: hardsigmoid_backward_out MPS: hardsigmoid_backward_out_mps - func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor @@ -11804,61 +11956,61 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA, MPS: hardtanh_out + CPU, CUDA, PrivateUse1, MPS: hardtanh_out QuantizedCPU: hardtanh_out_quantized_cpu - func: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA, MPS: hardtanh + CPU, CUDA, PrivateUse1, MPS: hardtanh QuantizedCPU: hardtanh_quantized_cpu tags: core - func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: - CPU, CUDA: hardtanh_backward_out + CPU, CUDA, PrivateUse1: hardtanh_backward_out MPS: hardtanh_backward_out_mps - func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor python_module: nn dispatch: - CPU, CUDA: hardtanh_backward + CPU, CUDA, PrivateUse1: hardtanh_backward MPS: hardtanh_backward_mps - func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) 
device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA, MPS: hardtanh_ + CPU, CUDA, PrivateUse1, MPS: hardtanh_ QuantizedCPU: hardtanh_quantized_cpu_ - func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish_out + CPU, CUDA, PrivateUse1: hardswish_out MPS: hardswish_out_mps - func: hardswish(Tensor self) -> Tensor device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish + CPU, CUDA, PrivateUse1: hardswish MPS: hardswish_mps - func: hardswish_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish_ + CPU, CUDA, PrivateUse1: hardswish_ MPS: hardswish_mps_ - func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor python_module: nn dispatch: - CPU, CUDA: hardswish_backward + CPU, CUDA, PrivateUse1: hardswish_backward MPS: hardswish_backward_mps autogen: hardswish_backward.out @@ -11868,7 +12020,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: leaky_relu_out + CPU, CUDA, PrivateUse1: leaky_relu_out MPS: leaky_relu_out_mps QuantizedCPU: leaky_relu_out_quantized_cpu @@ -11885,7 +12037,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: leaky_relu_backward_out + CPU, CUDA, PrivateUse1: leaky_relu_backward_out MPS: leaky_relu_backward_out_mps - func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor @@ -11913,6 +12065,7 @@ dispatch: CPU: log_sigmoid_forward_out_cpu CUDA: log_sigmoid_forward_out_cuda + PrivateUse1: log_sigmoid_forward_out_zoom MPS: log_sigmoid_forward_out_mps - func: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) @@ -11921,6 +12074,7 @@ dispatch: CPU: log_sigmoid_forward_cpu CUDA: log_sigmoid_forward_cuda + PrivateUse1: log_sigmoid_forward_zoom MPS: log_sigmoid_forward_mps - func: log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -11928,6 +12082,7 @@ dispatch: CPU: log_sigmoid_backward_cpu_out CUDA: log_sigmoid_backward_cuda_out + PrivateUse1: log_sigmoid_backward_zoom_out MPS: log_sigmoid_backward_mps_out - func: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor @@ -11935,6 +12090,7 @@ dispatch: CPU: log_sigmoid_backward_cpu CUDA: log_sigmoid_backward_cuda + PrivateUse1: log_sigmoid_backward_zoom MPS: log_sigmoid_backward_mps - func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) @@ -11943,12 +12099,14 @@ dispatch: CPU: rrelu_with_noise_out_cpu CUDA: rrelu_with_noise_out_cuda + PrivateUse1: rrelu_with_noise_out_zoom - func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? 
generator=None) -> Tensor python_module: nn dispatch: CPU: rrelu_with_noise_cpu CUDA: rrelu_with_noise_cuda + PrivateUse1: rrelu_with_noise_zoom tags: nondeterministic_seeded - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor @@ -11963,6 +12121,7 @@ dispatch: CPU: rrelu_with_noise_cpu_ CUDA: rrelu_with_noise_cuda_ + PrivateUse1: rrelu_with_noise_zoom_ - func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -11970,7 +12129,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: softplus_out + CPU, CUDA, PrivateUse1: softplus_out MPS: softplus_out_mps - func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor @@ -11983,7 +12142,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: softplus_backward_out + CPU, CUDA, PrivateUse1: softplus_backward_out MPS: softplus_backward_out_mps - func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor @@ -12009,7 +12168,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: softshrink_backward_out + CPU, CUDA, PrivateUse1: softshrink_backward_out MPS: softshrink_backward_out_mps - func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor @@ -13066,6 +13225,7 @@ variants: method dispatch: CUDA: record_stream_cuda + PrivateUse1: record_stream_zoom - func: isposinf(Tensor self) -> Tensor variants: function, method diff --git a/aten/src/ATen/native/zoom/Activation.cpp b/aten/src/ATen/native/zoom/Activation.cpp new file mode 100644 index 00000000000000..039585b1e71605 --- /dev/null +++ b/aten/src/ATen/native/zoom/Activation.cpp @@ -0,0 +1,108 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +// ----------------------------------- +// glu backward +// ----------------------------------- + +Tensor& glu_backward_zoom_out(const Tensor& grad_output, const Tensor& input, + int64_t dim, Tensor& grad_input) { + TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors"); + auto wrap_dim = maybe_wrap_dim(dim, input.dim()); + auto input_sizes = input.sizes(); + const int64_t nIn = input_sizes[wrap_dim]; + TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", + wrap_dim, " is size ", nIn); + + resize_output(grad_input, input_sizes); + + DimVector iter_shape(input_sizes); + const auto dim_size = nIn / 2; + iter_shape[wrap_dim] = dim_size; + TORCH_CHECK(grad_output.sizes() == IntArrayRef{iter_shape}); + + const auto iter = at::TensorIteratorConfig() + .add_output(grad_input) + .add_const_input(input) + .add_const_input(grad_output) + .resize_outputs(false) + .declare_static_shape(iter_shape) + .build(); + + if (iter.numel() == 0) { + return grad_input; + } + + const auto I_stride = input.strides()[wrap_dim] * dim_size; + const auto gI_stride = grad_input.strides()[wrap_dim] * dim_size; + + if (iter.can_use_32bit_indexing()) { + launch_glu_backward_kernel(iter, gI_stride, I_stride); + } else { + for (const auto& sub_iter: iter.with_32bit_indexing()) { + launch_glu_backward_kernel(sub_iter, gI_stride, I_stride); + } + } + return grad_input; +} + +Tensor 
glu_backward_zoom(const Tensor& grad_output, const Tensor& input, int64_t dim) { + auto grad_input = at::empty({0}, input.options()); + return glu_backward_zoom_out(grad_output, input, dim, grad_input); +} + +// ----------------------------------- +// log_sigmoid forward +// ----------------------------------- + +std::tuple log_sigmoid_forward_out_zoom(const Tensor& input, Tensor& result, Tensor& buffer) { + // NOTE: buffer is only used by CPU dispatch, we just ignore it here + auto iter = TensorIteratorConfig() + .add_output(result) + .add_const_input(input) + .build(); + launch_log_sigmoid_forward_kernel(iter); + return std::forward_as_tuple(result, buffer); +} + +std::tuple log_sigmoid_forward_zoom(const Tensor& input) { + auto result = at::empty_like(input); + auto buffer = at::empty({0}, input.options()); + log_sigmoid_forward_out_zoom(input, result, buffer); + return std::forward_as_tuple(result, buffer); +} + +TORCH_IMPL_FUNC(gelu_out_zoom) ( + const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/ +) { + GeluZoomKernelImpl(*this, get_gelutype_enum(approximate)); +} + +TORCH_IMPL_FUNC(gelu_backward_out_zoom) ( + const Tensor& /*grad*/, const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*grad_input*/ +) { + GeluBackwardZoomKernelImpl(*this, get_gelutype_enum(approximate)); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Activation.h b/aten/src/ATen/native/zoom/Activation.h new file mode 100644 index 00000000000000..309d316bd5fd7d --- /dev/null +++ b/aten/src/ATen/native/zoom/Activation.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at { namespace native { + +void launch_glu_backward_kernel(const TensorIteratorBase& iter, + int64_t gI_stride, int64_t I_stride); + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter); + +void GeluZoomKernelImpl(TensorIteratorBase& it, GeluType approximate); +void GeluBackwardZoomKernelImpl(TensorIteratorBase& it, GeluType approximate); + +}} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationEluKernel.cu b/aten/src/ATen/native/zoom/ActivationEluKernel.cu new file mode 100644 index 00000000000000..e3f296a2a0ed89 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationEluKernel.cu @@ -0,0 +1,86 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void elu_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > 0 ? 
aop * poscoef + : std::expm1(aop * negiptcoef) * negcoef; + }); + }); +} + +void elu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_backward_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef, is_result] GPU_LAMBDA( + scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + + if (is_result) { + return bop <= 0 ? aop * negiptcoef * (bop + negcoef) + : aop * poscoef; + } else { + return bop <= 0 + ? aop * negiptcoef * negcoef * std::exp(bop * negiptcoef) + : aop * poscoef; + } + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(elu_stub, &elu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(elu_backward_stub, &elu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationGeluKernel.cu b/aten/src/ATen/native/zoom/ActivationGeluKernel.cu new file mode 100644 index 00000000000000..7da8acc5b7ab17 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationGeluKernel.cu @@ -0,0 +1,88 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +void GeluZoomKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_cube = static_cast(x) * static_cast(x) * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + return opmath_t(0.5) * static_cast(x) * (opmath_t(1) + c10::hip::compat::tanh(inner)); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kAlpha = M_SQRT1_2; + return static_cast(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + }); + }); + } +} + +void GeluBackwardZoomKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_sq = static_cast(x) * static_cast(x); + auto x_cube = x_sq * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + auto tanh_inner = c10::hip::compat::tanh(inner); + + auto left = opmath_t(0.5) * static_cast(x); + auto right = opmath_t(1) + tanh_inner; + + auto left_derivative = opmath_t(0.5) * right; + + auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; + auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); + auto 
right_derivative = left * tanh_derivative * inner_derivative; + + return static_cast(dy) * (left_derivative + right_derivative); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); + constexpr opmath_t kAlpha = M_SQRT1_2; + const opmath_t cdf = + opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + const opmath_t pdf = + c10::hip::compat::exp( + opmath_t(-0.5) * static_cast(x) * static_cast(x)) * + kBeta; + return static_cast(dy) * (cdf + static_cast(x) * pdf); + }); + }); + } +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationGluKernel.cu b/aten/src/ATen/native/zoom/ActivationGluKernel.cu new file mode 100644 index 00000000000000..c98794cf016a03 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationGluKernel.cu @@ -0,0 +1,141 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// ----------------------------------- +// glu forward +// ----------------------------------- +void glu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_zoom", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a_, scalar_t b_) -> scalar_t { + const opmath_t a = a_; + const opmath_t b = b_; + const opmath_t one = opmath_t(1); + const opmath_t sigmoid = one / (one + std::exp(-b)); + return a * sigmoid; + }); + }); +} + +// ----------------------------------- +// glu forward ad +// ----------------------------------- +void glu_jvp_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_zoom", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, + [] GPU_LAMBDA( + scalar_t res_, scalar_t b_, scalar_t da_, scalar_t db_) + -> scalar_t { + const opmath_t res = res_; + const opmath_t b = b_; + const opmath_t da = da_; + const opmath_t db = db_; + const opmath_t one = opmath_t(1); + + const opmath_t sig_b = one / (one + std::exp(-b)); + return (da * sig_b + res * (db - sig_b * db)); + }); + }); +} + +// ----------------------------------- +// glu backward +// ----------------------------------- + +// Byte offsets don't require multiplication by sizeof(T), so are slightly +// cheaper. For fixed offsets, this removes all penalty from 64-bit indexing. +template +__device__ T* byte_offset(T* ptr, int64_t offset) { + using byte_ptr_t = typename std:: + conditional::value, const char*, char*>::type; + return reinterpret_cast(reinterpret_cast(ptr) + offset); +} + +template +__global__ void glu_backward_kernel( + int numel, + scalar_t* gI, + const scalar_t* I, + const scalar_t* gO, + OffsetCalc offset_calculator, + int64_t gI_byte_offset, + int64_t I_byte_offset) { + using opmath_t = at::opmath_type; + + const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + if (linear_index >= numel) { + return; + } + const auto offsets = offset_calculator.get(linear_index); + + // We explicitly iterate over the first half of the input tensor, and + // gI_byte_offset and I_byte_offset are the offsets to access the + // corresponding index in the second half of the tensor. 
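+  // For reference: the forward pass computes out = a * sigmoid(b), where a is
+  // the first half of the halving dimension and b the second half, so
+  //   d(out)/da = sigmoid(b)
+  //   d(out)/db = a * sigmoid(b) * (1 - sigmoid(b))
+  // which is exactly what the two stores into gA and gB below write out,
+  // scaled by the incoming gradient gO.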
+  const opmath_t a = I[offsets[1]];
+  const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset);
+  const opmath_t gO_val = gO[offsets[2]];
+
+  const auto one = opmath_t(1);
+  const opmath_t sigmoid = one / (one + std::exp(-b));
+
+  auto* gA = gI + offsets[0];
+  *gA = sigmoid * gO_val;
+
+  auto* gB = byte_offset(gA, gI_byte_offset);
+  *gB = (one - sigmoid) * sigmoid * gO_val * a;
+}
+
+void launch_glu_backward_kernel(
+    const TensorIteratorBase& iter,
+    int64_t gI_stride,
+    int64_t I_stride) {
+  const auto N = iter.numel();
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      N > 0 && N <= std::numeric_limits<int32_t>::max());
+  const auto offset_calculator = make_element_offset_calculator<3>(iter);
+  constexpr int64_t block_size = 256;
+  const int64_t grid = (N + block_size - 1) / block_size;
+  const auto stream = c10::zoom::getCurrentZoomStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kHalf, kBFloat16, iter.common_dtype(), "glu_backward_zoom", [&] {
+        auto gI = static_cast<scalar_t*>(iter.data_ptr(0));
+        auto I = static_cast<const scalar_t*>(iter.data_ptr(1));
+        auto gO = static_cast<const scalar_t*>(iter.data_ptr(2));
+        glu_backward_kernel<<<grid, block_size, 0, stream>>>(
+            N,
+            gI,
+            I,
+            gO,
+            offset_calculator,
+            gI_stride * sizeof(scalar_t),
+            I_stride * sizeof(scalar_t));
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      });
+}
+
+REGISTER_PRIVATEUSE1_DISPATCH(glu_stub, &glu_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(glu_jvp_stub, &glu_jvp_kernel);
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/zoom/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/zoom/ActivationHardshrinkKernel.cu
new file mode 100644
index 00000000000000..cb581dbc9d661a
--- /dev/null
+++ b/aten/src/ATen/native/zoom/ActivationHardshrinkKernel.cu
@@ -0,0 +1,39 @@
+#define TORCH_ASSERT_NO_OPERATORS
+#define _USE_MATH_DEFINES
+
+#include
+
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace at::native {
+namespace {
+
+void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "hardshrink_zoom",
+      [&]() {
+        auto lambd = value.to<scalar_t>();
+        gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t {
+          return (a >= -lambd && a <= lambd) ? scalar_t(0) : a;
+        });
+      });
+}
+} // namespace
+
+REGISTER_PRIVATEUSE1_DISPATCH(hardshrink_stub, &hardshrink_kernel);
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/zoom/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/zoom/ActivationHardsigmoidKernel.cu
new file mode 100644
index 00000000000000..3af90e876b6e81
--- /dev/null
+++ b/aten/src/ATen/native/zoom/ActivationHardsigmoidKernel.cu
@@ -0,0 +1,74 @@
+#define TORCH_ASSERT_NO_OPERATORS
+#define _USE_MATH_DEFINES
+
+#include
+
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace at::native {
+namespace {
+
+void hardsigmoid_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "hardsigmoid_zoom",
+      [&]() {
+        using opmath_t = at::opmath_type<scalar_t>;
+        const opmath_t zero(0.0f);
+        const opmath_t one_sixth(1.0f / 6.0f);
+        const opmath_t three(3.0f);
+        const opmath_t six(6.0f);
+        gpu_kernel(
+            iter,
+            [zero, one_sixth, three, six] GPU_LAMBDA(
+                scalar_t self_val) -> scalar_t {
+              opmath_t x = static_cast<opmath_t>(self_val);
+              return std::min(std::max(x + three, zero), six) * one_sixth;
+            });
+      });
+}
+
+void hardsigmoid_backward_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "hardsigmoid_backward_zoom",
+      [&]() {
+        using opmath_t = at::opmath_type<scalar_t>;
+        const opmath_t zero(0.0f);
+        const opmath_t three(3.0f);
+        const opmath_t neg_three(-3.0f);
+        const opmath_t one_sixth(1.0f / 6.0f);
+        gpu_kernel(
+            iter,
+            [zero, three, neg_three, one_sixth] GPU_LAMBDA(
+                scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
+              opmath_t grad_val = static_cast<opmath_t>(grad_val_);
+              opmath_t self_val = static_cast<opmath_t>(self_val_);
+              return (self_val > neg_three && self_val < three)
+                  ?
grad_val * one_sixth + : zero; + }); + }); +} + +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationHardswishKernel.cu b/aten/src/ATen/native/zoom/ActivationHardswishKernel.cu new file mode 100644 index 00000000000000..5b4704cbf85ab8 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationHardswishKernel.cu @@ -0,0 +1,63 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void hardswish_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_zoom", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t one_sixth(1.0f / 6.0f); + const opmath_t three(3.0f); + const opmath_t six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + opmath_t x = static_cast(self_val); + return x * std::min(std::max(x + three, zero), six) * one_sixth; + }); + }); +} + +void hardswish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_zoom", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t three(3.0f); + const opmath_t neg_three(-3.0f); + const opmath_t one_half(0.5f); + gpu_kernel( + iter, + [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + opmath_t grad_val = static_cast(grad_val_); + opmath_t self_val = static_cast(self_val_); + if (self_val < neg_three) { + return zero; + } else if (self_val <= three) { + return grad_val * ((self_val / three) + one_half); + } else { + return grad_val; + } + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(hardswish_stub, &hardswish_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationHardtanhKernel.cu b/aten/src/ATen/native/zoom/ActivationHardtanhKernel.cu new file mode 100644 index 00000000000000..ecd11f23e87fa3 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationHardtanhKernel.cu @@ -0,0 +1,45 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void hardtanh_backward_kernel( + TensorIterator& iter, + const Scalar& min, + const Scalar& max) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + iter.dtype(), "hardtanh_backward_zoom", [&]() { + using opmath_t = at::opmath_type; + auto min_val = min.to(); + auto max_val = max.to(); + gpu_kernel( + iter, + [min_val, max_val] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return (bop <= min_val) || (bop >= max_val) ? 
opmath_t(0) : aop; + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/zoom/ActivationLeakyReluKernel.cu new file mode 100644 index 00000000000000..94a9a8168c2b02 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationLeakyReluKernel.cu @@ -0,0 +1,62 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel(iter, [negval] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > opmath_t(0) ? aop : aop * negval; + }); + }); +} + +void leaky_relu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_backward_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel( + iter, [negval] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return aop > opmath_t(0) ? bop : bop * negval; + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(leaky_relu_stub, &leaky_relu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/zoom/ActivationLogSigmoidKernel.cu new file mode 100644 index 00000000000000..79bad5edc99db3 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationLogSigmoidKernel.cu @@ -0,0 +1,64 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// ----------------------------------- +// log_sigmoid forward +// ----------------------------------- + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_zoom", [&] { + using opmath_t = at::opmath_type; + + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t { + const opmath_t in = in_; + const auto min = std::min(opmath_t(0), in); + const auto z = std::exp(-std::abs(in)); + return min - std::log1p(z); + }); + }); +} + +namespace { +// ----------------------------------- +// log_sigmoid backward +// ----------------------------------- +void log_sigmoid_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_backward_zoom", [&] { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t in_, scalar_t grad_out_) -> scalar_t { + const opmath_t in = in_; + const opmath_t grad_out = grad_out_; + + auto in_negative = in < opmath_t(0); + auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0); + auto sign = in_negative ? 
opmath_t(1) : -opmath_t(1); + const auto z = std::exp(-std::abs(in)); + return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z))); + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationMishKernel.cu b/aten/src/ATen/native/zoom/ActivationMishKernel.cu new file mode 100644 index 00000000000000..75d69dd119185c --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationMishKernel.cu @@ -0,0 +1,64 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void mish_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc * + c10::hip::compat::tanh( + c10::hip::compat::log1p(c10::hip::compat::exp(x_acc))); + }); + }); +} + +void mish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_backward_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::hip::compat::exp(-x_acc)); + const opmath_t t_acc = c10::hip::compat::tanh( + c10::hip::compat::log1p(c10::hip::compat::exp(x_acc))); + return dy_acc * + (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc)); + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(mish_stub, &mish_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(mish_backward_stub, &mish_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationPreluKernel.cu b/aten/src/ATen/native/zoom/ActivationPreluKernel.cu new file mode 100644 index 00000000000000..512cc7224c5c85 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationPreluKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// ----------------------------------- +// prelu +// ----------------------------------- +void prelu_kernel(TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_zoom", [&] { + gpu_kernel(iter, + [] GPU_LAMBDA (scalar_t input, scalar_t weight) -> scalar_t { + return (input > 0) ? input : weight * input; + }); + }); +} + +void prelu_backward_kernel(TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_backward_zoom", [&] { + gpu_kernel_multiple_outputs(iter, + [] GPU_LAMBDA (scalar_t input, scalar_t weight, scalar_t grad) -> thrust::tuple { + auto mask = input > 0; + auto grad_input = mask ? grad : weight * grad; + auto grad_weight = mask ? 
scalar_t{0} : input * grad; + return {grad_input, grad_weight}; + }); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(prelu_stub, &prelu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(prelu_backward_stub, &prelu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationSiluKernel.cu b/aten/src/ATen/native/zoom/ActivationSiluKernel.cu new file mode 100644 index 00000000000000..04f7d204a3a97b --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationSiluKernel.cu @@ -0,0 +1,60 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void silu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc / (opmath_t(1) + ::exp(-x_acc)); + }); + }); +} + +void silu_backward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_backward_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::hip::compat::exp(-x_acc)); + return dy_acc * s_acc * (opmath_t(1) + x_acc * (opmath_t(1) - s_acc)); + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(silu_stub, &silu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(silu_backward_stub, &silu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationSoftplusKernel.cu b/aten/src/ATen/native/zoom/ActivationSoftplusKernel.cu new file mode 100644 index 00000000000000..ed3358d225af7f --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationSoftplusKernel.cu @@ -0,0 +1,74 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void softplus_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel(iter, [beta, threshold] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return (aop * beta) > threshold + ? aop + : (::log1p(std::exp(aop * beta))) / beta; + }); + }); +} + +void softplus_backward_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_backward_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel( + iter, + [beta, threshold] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + opmath_t z = std::exp(bop * beta); + return (bop * beta) > threshold ? 
aop + : aop * z / (z + opmath_t(1.)); + }); + }); +} + +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(softplus_stub, &softplus_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/zoom/ActivationSoftshrinkKernel.cu new file mode 100644 index 00000000000000..69e27e22b477fb --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationSoftshrinkKernel.cu @@ -0,0 +1,58 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softshrink_zoom", + [&]() { + auto lambd = value.to(); + gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t { + return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)); + }); + }); +} + +void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "shrink_backward_zoom", + [&]() { + auto lambd = value.to(); + gpu_kernel( + iter, + [lambd] GPU_LAMBDA( + scalar_t grad_val, scalar_t self_val) -> scalar_t { + return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) + : grad_val; + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(softshrink_stub, &softshrink_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(shrink_backward_stub, &shrink_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationThresholdKernel.cu b/aten/src/ATen/native/zoom/ActivationThresholdKernel.cu new file mode 100644 index 00000000000000..0d6a1c7e15f80a --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationThresholdKernel.cu @@ -0,0 +1,52 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +template +void threshold_kernel_impl( + TensorIteratorBase& iter, + scalar_t threshold, + scalar_t value) { + gpu_kernel_with_scalars( + iter, [=] GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t { + return x <= threshold ? 
value : other;
+      });
+}
+
+static void threshold_kernel_zoom(
+    TensorIteratorBase& iter,
+    const Scalar& threshold,
+    const Scalar& value) {
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "threshold_zoom",
+      [&] {
+        threshold_kernel_impl<scalar_t>(
+            iter, threshold.to<scalar_t>(), value.to<scalar_t>());
+      });
+}
+
+} // namespace
+
+REGISTER_PRIVATEUSE1_DISPATCH(threshold_stub, &threshold_kernel_zoom);
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/zoom/ForeachBinaryOpList.cu b/aten/src/ATen/native/zoom/ForeachBinaryOpList.cu
new file mode 100644
index 00000000000000..02e2c4d4fe942c
--- /dev/null
+++ b/aten/src/ATen/native/zoom/ForeachBinaryOpList.cu
@@ -0,0 +1,295 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include
+#include
+#include
+#include
+#include
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include
+#else
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#endif
+
+namespace at::native {
+
+template <typename T, template <class> class Op>
+std::vector<Tensor> foreach_tensor_list_op(
+    TensorList tensors1,
+    TensorList tensors2,
+    const Scalar& alpha = 1) {
+  std::vector<std::vector<at::Tensor>> tensor_lists;
+  std::vector<at::Tensor> vec_res;
+  vec_res.reserve(tensors1.size());
+  for (const auto& t : tensors1) {
+    vec_res.emplace_back(at::native::empty_like(t));
+  }
+
+  tensor_lists.emplace_back(tensors1.vec());
+  tensor_lists.emplace_back(tensors2.vec());
+  tensor_lists.emplace_back(std::move(vec_res));
+
+  using opmath_t = at::opmath_type<T>;
+  multi_tensor_apply<3>(
+      tensor_lists,
+      BinaryOpListAlphaFunctor<
+          T,
+          /* depth */ 3,
+          /* r_args_depth */ 2,
+          /* res_arg_index */ 2>(),
+      Op<opmath_t>(),
+      alpha.to<opmath_t>());
+
+  return tensor_lists[2];
+}
+
+template <typename T, template <class> class Op>
+void foreach_tensor_list_op_(
+    TensorList tensors1,
+    TensorList tensors2,
+    const Scalar& alpha = 1) {
+  std::vector<std::vector<at::Tensor>> tensor_lists;
+  tensor_lists.emplace_back(tensors1.vec());
+  tensor_lists.emplace_back(tensors2.vec());
+
+  using opmath_t = at::opmath_type<T>;
+  multi_tensor_apply<2>(
+      tensor_lists,
+      BinaryOpListAlphaFunctor<
+          T,
+          /* depth */ 2,
+          /* r_args_depth */ 2,
+          /* res_arg_index */ 0>(),
+      Op<opmath_t>(),
+      alpha.to<opmath_t>());
+  increment_version(tensors1);
+}
+
+template