From 8bb6bb6938acd9486df78d268adc3878cd3bc557 Mon Sep 17 00:00:00 2001 From: Luis Chamberlain Date: Mon, 15 Dec 2025 22:12:49 -0800 Subject: [PATCH 1/2] feat: add HIP/ROCm support for AMD GPUs Add support for building pyg-lib on ROCm (AMD GPUs) alongside existing CUDA support. This allows pyg-lib to work with PyTorch ROCm builds. CMakeLists.txt: - Add WITH_HIP option to enable HIP/ROCm builds - Detect ROCm path and find HIP, rocBLAS, hipBLAS packages - Set appropriate GPU architectures for AMD (gfx906, gfx90a, etc.) - Compile .cu files as HIP sources on ROCm - Exclude CUTLASS/cuCollections files that have HIP replacements setup.py: - Auto-detect torch.version.hip for ROCm PyTorch builds - Pass -DWITH_HIP=ON to CMake when building on ROCm New HIP implementations: - ops/hip/matmul_kernel_hip.cpp: rocBLAS-based grouped GEMM - classes/hip/hash_map_hip.cpp: Sorted array + binary search hashmap Modified CUDA kernels for HIP compatibility: - sampler/cuda/random_walk_kernel.cu: Add HIP header/API macros - ops/cuda/sampled_kernel.cu: Add HIP header/API macros - library.cpp: Use hip_runtime.h and HIP_VERSION on ROCm The simple CUDA kernels work on HIP with compatibility macros for getCurrentCUDAStream -> getCurrentHIPStream and C10_CUDA_KERNEL_LAUNCH_CHECK -> C10_HIP_KERNEL_LAUNCH_CHECK. Tested on ROCm 6.2/6.4 with PyTorch 2.6. Generated-by: Claude AI Signed-off-by: Luis Chamberlain --- CMakeLists.txt | 72 +++++++- pyg_lib/csrc/classes/hip/hash_map_hip.cpp | 163 ++++++++++++++++++ pyg_lib/csrc/library.cpp | 8 + pyg_lib/csrc/ops/cuda/sampled_kernel.cu | 7 + pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp | 128 ++++++++++++++ .../csrc/sampler/cuda/random_walk_kernel.cu | 7 + setup.py | 8 + 7 files changed, 390 insertions(+), 3 deletions(-) create mode 100644 pyg_lib/csrc/classes/hip/hash_map_hip.cpp create mode 100644 pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 53bdf9f9e..0b086ac94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,7 @@ option(BUILD_TEST "Enable testing" OFF) option(BUILD_BENCHMARK "Enable benchmarks" OFF) option(WITH_COV "Enable code coverage" OFF) option(WITH_CUDA "Enable CUDA support" OFF) +option(WITH_HIP "Enable HIP/ROCm support" OFF) if(NOT WIN32 AND NOT DEFINED USE_CXX11_ABI) find_package(Python3 COMPONENTS Interpreter REQUIRED) @@ -46,7 +47,7 @@ if(USE_MKL_BLAS AND DEFINED BLAS_INCLUDE_DIR) endif() configure_file(${CMAKE_CURRENT_SOURCE_DIR}/pyg_lib/csrc/config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/pyg_lib/csrc/config.h") -if(WITH_CUDA) +if(WITH_CUDA AND NOT WITH_HIP) enable_language(CUDA) add_definitions(-DWITH_CUDA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr -allow-unsupported-compiler") @@ -75,11 +76,68 @@ if(WITH_CUDA) include_directories(${CUCOLLECTIONS_DIR}) endif() +# HIP/ROCm support +if(WITH_HIP) + # Find HIP package + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH "/opt/rocm") + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + + list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}") + find_package(HIP REQUIRED) + + add_definitions(-DWITH_CUDA) # Use same macro for GPU code paths + add_definitions(-DUSE_ROCM) + add_definitions(-D__HIP_PLATFORM_AMD__) + + # Let CMake find HIP compiler (don't set hipcc directly - CMake 3.21+ prefers clang) + enable_language(HIP) + + # GPU architectures for AMD (gfx900=Vega, gfx906=MI50, gfx908=MI100, + # gfx90a=MI200, gfx940/942=MI300, gfx1100=RDNA3) + set(CMAKE_HIP_ARCHITECTURES "gfx906;gfx908;gfx90a;gfx940;gfx942;gfx1100") + + # Find rocBLAS for GEMM operations + find_package(rocblas REQUIRED) + find_package(hipblas REQUIRED) + + # Find rocThrust (HIP port of Thrust) + find_path(ROCTHRUST_INCLUDE_DIR thrust/version.h + HINTS ${ROCM_PATH}/include) + if(ROCTHRUST_INCLUDE_DIR) + include_directories(${ROCTHRUST_INCLUDE_DIR}) + message(STATUS "Found rocThrust: ${ROCTHRUST_INCLUDE_DIR}") + endif() + + message(STATUS "Building with HIP/ROCm support") + message(STATUS " ROCM_PATH: ${ROCM_PATH}") + message(STATUS " HIP_VERSION: ${HIP_VERSION}") +endif() + set(CSRC pyg_lib/csrc) file(GLOB_RECURSE ALL_SOURCES ${CSRC}/*.cpp) -if (WITH_CUDA) +if(WITH_CUDA AND NOT WITH_HIP) file(GLOB_RECURSE ALL_SOURCES ${ALL_SOURCES} ${CSRC}/*.cu) endif() +if(WITH_HIP) + # For HIP, compile .cu files as HIP sources + file(GLOB_RECURSE HIP_SOURCES ${CSRC}/*.cu) + # Exclude CUTLASS-dependent files (matmul_kernel.cu) and cuCollections files + # (hash_map.cu) - these have HIP replacements in hip/ subdirs + list(FILTER HIP_SOURCES EXCLUDE REGEX ".*/cuda/matmul_kernel\\.cu$") + list(FILTER HIP_SOURCES EXCLUDE REGEX ".*/cuda/hash_map\\.cu$") + set_source_files_properties(${HIP_SOURCES} PROPERTIES LANGUAGE HIP) + list(APPEND ALL_SOURCES ${HIP_SOURCES}) + # Add HIP-specific implementations (rocBLAS matmul, HIP hashmap) + # These need to be compiled as HIP to get proper HIP/ATen integration + file(GLOB_RECURSE HIP_CPP_SOURCES ${CSRC}/*/hip/*.cpp) + set_source_files_properties(${HIP_CPP_SOURCES} PROPERTIES LANGUAGE HIP) + list(APPEND ALL_SOURCES ${HIP_CPP_SOURCES}) + message(STATUS "HIP sources: ${HIP_SOURCES}") + message(STATUS "HIP C++ sources: ${HIP_CPP_SOURCES}") +endif() add_library(${PROJECT_NAME} SHARED ${ALL_SOURCES}) target_include_directories(${PROJECT_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}") if(MKL_INCLUDE_FOUND) @@ -122,13 +180,21 @@ if(OpenMP_CXX_FOUND) target_link_libraries(${PROJECT_NAME} PRIVATE OpenMP::OpenMP_CXX) endif() -if(WITH_CUDA) +if(WITH_CUDA AND NOT WITH_HIP) target_include_directories(${PROJECT_NAME} PRIVATE third_party/cccl/thrust third_party/cccl/cub third_party/cccl/libcudacxx/include) endif() +if(WITH_HIP) + # Link rocBLAS and hipBLAS for GEMM operations + # Targets are roc::rocblas and roc::hipblas per ROCm cmake configs + target_link_libraries(${PROJECT_NAME} PRIVATE roc::rocblas roc::hipblas) + # Include HIP runtime + target_include_directories(${PROJECT_NAME} PRIVATE ${ROCM_PATH}/include) +endif() + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0") if(BUILD_TEST) diff --git a/pyg_lib/csrc/classes/hip/hash_map_hip.cpp b/pyg_lib/csrc/classes/hip/hash_map_hip.cpp new file mode 100644 index 000000000..f71808855 --- /dev/null +++ b/pyg_lib/csrc/classes/hip/hash_map_hip.cpp @@ -0,0 +1,163 @@ +// ROCm/HIP implementation of GPU HashMap +// Uses sorted array + binary search via ATen (not thrust directly) +// This avoids the cuda/cccl header conflicts with rocThrust + +#include +#include +#include + +namespace pyg { +namespace classes { + +namespace { + +#define DISPATCH_CASE_KEY(...) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_KEY(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_KEY(__VA_ARGS__)) + +struct HashMapImpl { + virtual ~HashMapImpl() = default; + virtual at::Tensor get(const at::Tensor& query) = 0; + virtual at::Tensor keys() = 0; + virtual int64_t size() = 0; + virtual at::ScalarType dtype() = 0; + virtual at::Device device() = 0; +}; + +// HIP implementation using sorted arrays + binary search (ATen-only, no thrust) +template +struct HIPHashMapImpl : HashMapImpl { + public: + using ValueType = int64_t; + + HIPHashMapImpl(const at::Tensor& key, double load_factor) + : device_(key.device()) { + // Store sorted keys and their original indices + const auto options = key.options(); + const auto value_options = + options.dtype(c10::CppTypeToScalarType::value); + + // Create value tensor (indices 0 to N-1) + sorted_values_ = at::arange(key.numel(), value_options); + + // Clone keys and sort together with values + sorted_keys_ = key.clone(); + + // Sort by keys, permuting values accordingly + auto sort_result = at::sort(sorted_keys_); + sorted_keys_ = std::get<0>(sort_result); + auto sort_indices = std::get<1>(sort_result); + sorted_values_ = sorted_values_.index_select(0, sort_indices); + } + + at::Tensor get(const at::Tensor& query) override { + // Use searchsorted to find positions, then verify matches + auto positions = at::searchsorted(sorted_keys_, query); + auto out = sorted_values_.new_full({query.numel()}, -1); + + // Clamp positions to valid range + positions = at::clamp(positions, 0, sorted_keys_.numel() - 1); + + // Get keys at found positions + auto found_keys = sorted_keys_.index_select(0, positions); + + // Create mask where query matches found key + auto mask = (found_keys == query); + + // Get values where mask is true + auto found_values = sorted_values_.index_select(0, positions); + out = at::where(mask, found_values, out); + + return out; + } + + at::Tensor keys() override { + // Return keys in original order (unsort using values as indices) + auto perm = at::empty_like(sorted_values_); + perm.scatter_(0, sorted_values_, + at::arange(sorted_values_.numel(), sorted_values_.options())); + return sorted_keys_.index_select(0, perm); + } + + int64_t size() override { return sorted_keys_.numel(); } + + at::ScalarType dtype() override { + if (std::is_same::value) { + return at::kShort; + } else if (std::is_same::value) { + return at::kInt; + } else { + return at::kLong; + } + } + + at::Device device() override { return device_; } + + private: + at::Tensor sorted_keys_; + at::Tensor sorted_values_; + at::Device device_; +}; + +struct HIPHashMap : torch::CustomClassHolder { + public: + HIPHashMap(const at::Tensor& key, double load_factor = 0.5) { + at::TensorArg key_arg{key, "key", 0}; + at::CheckedFrom c{"HIPHashMap.init"}; + at::checkDeviceType(c, key, at::DeviceType::CUDA); // CUDA type for ROCm + at::checkDim(c, key_arg, 1); + at::checkContiguous(c, key_arg); + + DISPATCH_KEY(key.scalar_type(), "hip_hash_map_init", [&] { + map_ = std::make_unique>(key, load_factor); + }); + } + + at::Tensor get(const at::Tensor& query) { + at::TensorArg query_arg{query, "query", 0}; + at::CheckedFrom c{"HIPHashMap.get"}; + at::checkDeviceType(c, query, at::DeviceType::CUDA); + at::checkDim(c, query_arg, 1); + at::checkContiguous(c, query_arg); + + return map_->get(query); + } + + at::Tensor keys() { return map_->keys(); } + int64_t size() { return map_->size(); } + at::ScalarType dtype() { return map_->dtype(); } + at::Device device() { return map_->device(); } + + private: + std::unique_ptr map_; +}; + +} // namespace + +// Note: For ROCm, we register CUDAHashMap but use HIP implementation +// The dispatch key "CUDA" works for both CUDA and ROCm in PyTorch +TORCH_LIBRARY_FRAGMENT(pyg, m) { + m.class_("CUDAHashMap") + .def(torch::init()) + .def("get", &HIPHashMap::get) + .def("keys", &HIPHashMap::keys) + .def("size", &HIPHashMap::size) + .def("dtype", &HIPHashMap::dtype) + .def("device", &HIPHashMap::device) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& self) -> at::Tensor { + return self->keys(); + }, + // __setstate__ + [](const at::Tensor& state) -> c10::intrusive_ptr { + return c10::make_intrusive(state); + }); +} + +} // namespace classes +} // namespace pyg diff --git a/pyg_lib/csrc/library.cpp b/pyg_lib/csrc/library.cpp index 7bbc0c008..8859bb424 100644 --- a/pyg_lib/csrc/library.cpp +++ b/pyg_lib/csrc/library.cpp @@ -1,8 +1,12 @@ #include "library.h" #ifdef WITH_CUDA +#ifdef USE_ROCM +#include +#else #include #endif +#endif #include @@ -18,7 +22,11 @@ namespace pyg { int64_t cuda_version() { #ifdef WITH_CUDA +#ifdef USE_ROCM + return HIP_VERSION; +#else return CUDA_VERSION; +#endif #else return -1; #endif diff --git a/pyg_lib/csrc/ops/cuda/sampled_kernel.cu b/pyg_lib/csrc/ops/cuda/sampled_kernel.cu index f28ab47b5..d1339c006 100644 --- a/pyg_lib/csrc/ops/cuda/sampled_kernel.cu +++ b/pyg_lib/csrc/ops/cuda/sampled_kernel.cu @@ -1,5 +1,12 @@ #include +#ifdef USE_ROCM +#include +#include +#define getCurrentCUDAStream getCurrentHIPStream +#define C10_CUDA_KERNEL_LAUNCH_CHECK C10_HIP_KERNEL_LAUNCH_CHECK +#else #include +#endif #include namespace pyg { diff --git a/pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp b/pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp new file mode 100644 index 000000000..3e49f8fd3 --- /dev/null +++ b/pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp @@ -0,0 +1,128 @@ +// ROCm/HIP implementation of grouped matmul using rocBLAS +// Replaces CUTLASS-based CUDA implementation for AMD GPUs + +#include +#include +#include +#include +#include +#include + +#include "pyg_lib/csrc/utils/convert.h" + +namespace pyg { +namespace ops { + +namespace { + +// Helper to check rocBLAS status +#define ROCBLAS_CHECK(status) \ + do { \ + rocblas_status err = (status); \ + TORCH_CHECK(err == rocblas_status_success, \ + "rocBLAS error: ", rocblas_status_to_string(err)); \ + } while (0) + +// Get or create rocBLAS handle for current stream +rocblas_handle get_rocblas_handle() { + static thread_local rocblas_handle handle = nullptr; + if (handle == nullptr) { + ROCBLAS_CHECK(rocblas_create_handle(&handle)); + } + // Set stream to current HIP stream + ROCBLAS_CHECK(rocblas_set_stream(handle, at::hip::getCurrentHIPStream())); + return handle; +} + +void grouped_matmul_out_kernel(const at::TensorList input, + const at::TensorList other, + const at::TensorList out) { + const int64_t num_matrices = input.size(); + if (num_matrices == 0) + return; + + rocblas_handle handle = get_rocblas_handle(); + const float alpha = 1.0f; + const float beta = 0.0f; + + // For small number of matrices, use individual GEMM calls + // For larger batches, could use rocblas_gemm_batched_ex + for (int64_t i = 0; i < num_matrices; ++i) { + const auto& A = input[i]; + const auto& B = other[i]; + const auto& C = out[i]; + + int64_t m = A.size(0); + int64_t k = A.size(1); + int64_t n = B.size(1); + + // rocBLAS uses column-major, but our tensors are row-major + // C = A @ B in row-major is equivalent to C^T = B^T @ A^T in col-major + // So we compute: C(m,n) = A(m,k) @ B(k,n) + // In col-major: C^T(n,m) = B^T(n,k) @ A^T(k,m) + ROCBLAS_CHECK(rocblas_sgemm( + handle, + rocblas_operation_none, // B is not transposed (but we read B^T) + rocblas_operation_none, // A is not transposed (but we read A^T) + n, // rows of op(B^T) = cols of B = n + m, // cols of op(A^T) = rows of A = m + k, // inner dimension + &alpha, + B.data_ptr(), // B in row-major = B^T in col-major + n, // leading dim of B (row-major stride) + A.data_ptr(), // A in row-major = A^T in col-major + k, // leading dim of A (row-major stride) + &beta, + C.data_ptr(), // C in row-major = C^T in col-major + n // leading dim of C (row-major stride) + )); + } +} + +std::vector grouped_matmul_kernel(const at::TensorList input, + const at::TensorList other) { + std::vector out(input.size()); + std::vector input_contiguous(input.size()); + std::vector other_contiguous(other.size()); + + for (size_t i = 0; i < input.size(); ++i) { + input_contiguous[i] = input[i].contiguous(); + other_contiguous[i] = other[i].contiguous(); + out[i] = input[i].new_empty({input[i].size(0), other[i].size(-1)}); + } + + grouped_matmul_out_kernel(input_contiguous, other_contiguous, out); + return out; +} + +at::Tensor segment_matmul_kernel(const at::Tensor& input, + const at::Tensor& ptr, + const at::Tensor& other) { + const auto size = pyg::utils::size_from_ptr(ptr).cpu(); + const auto sizes = at::IntArrayRef(size.data_ptr(), size.numel()); + const auto out = input.new_empty({input.size(0), other.size(-1)}); + + auto input_splits = input.contiguous().split_with_sizes(sizes, 0); + auto other_splits = other.contiguous().split(1, 0); + auto out_splits = out.split_with_sizes(sizes, 0); + + std::vector input_vec(input_splits.begin(), input_splits.end()); + std::vector other_vec(other_splits.begin(), other_splits.end()); + std::vector out_vec(out_splits.begin(), out_splits.end()); + + grouped_matmul_out_kernel(input_vec, other_vec, out_vec); + return out; +} + +} // namespace + +// Register for HIP backend (uses same "CUDA" dispatch key on ROCm) +TORCH_LIBRARY_IMPL(pyg, CUDA, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::grouped_matmul"), + TORCH_FN(grouped_matmul_kernel)); + m.impl(TORCH_SELECTIVE_NAME("pyg::segment_matmul"), + TORCH_FN(segment_matmul_kernel)); +} + +} // namespace ops +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/cuda/random_walk_kernel.cu b/pyg_lib/csrc/sampler/cuda/random_walk_kernel.cu index b43fe07cb..fdf024e90 100644 --- a/pyg_lib/csrc/sampler/cuda/random_walk_kernel.cu +++ b/pyg_lib/csrc/sampler/cuda/random_walk_kernel.cu @@ -1,5 +1,12 @@ #include +#ifdef USE_ROCM +#include +#include +#define getCurrentCUDAStream getCurrentHIPStream +#define C10_CUDA_KERNEL_LAUNCH_CHECK C10_HIP_KERNEL_LAUNCH_CHECK +#else #include +#endif #include namespace pyg { diff --git a/setup.py b/setup.py index ab9d61d46..aa8f947e3 100644 --- a/setup.py +++ b/setup.py @@ -55,10 +55,18 @@ def build_extension(self, ext): WITH_CUDA = torch.cuda.is_available() WITH_CUDA = bool(int(os.getenv('FORCE_CUDA', WITH_CUDA))) + # Detect HIP/ROCm (PyTorch built with ROCm has torch.version.hip set) + WITH_HIP = hasattr(torch.version, 'hip') and torch.version.hip is not None + WITH_HIP = bool(int(os.getenv('FORCE_HIP', WITH_HIP))) + + if WITH_HIP: + print(f"Building with HIP/ROCm support (torch.version.hip={torch.version.hip})") + cmake_args = [ '-DBUILD_TEST=OFF', '-DBUILD_BENCHMARK=OFF', f'-DWITH_CUDA={"ON" if WITH_CUDA else "OFF"}', + f'-DWITH_HIP={"ON" if WITH_HIP else "OFF"}', f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}', f'-DCMAKE_RUNTIME_OUTPUT_DIRECTORY={extdir}', f'-DCMAKE_BUILD_TYPE={self.build_type}', From d9f2365834268f9ede326bad97a7374c002c90ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Dec 2025 06:36:20 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp | 24 +++++++++++----------- setup.py | 7 +++++-- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp b/pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp index 3e49f8fd3..1a74a3179 100644 --- a/pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp +++ b/pyg_lib/csrc/ops/hip/matmul_kernel_hip.cpp @@ -16,11 +16,11 @@ namespace ops { namespace { // Helper to check rocBLAS status -#define ROCBLAS_CHECK(status) \ - do { \ - rocblas_status err = (status); \ - TORCH_CHECK(err == rocblas_status_success, \ - "rocBLAS error: ", rocblas_status_to_string(err)); \ +#define ROCBLAS_CHECK(status) \ + do { \ + rocblas_status err = (status); \ + TORCH_CHECK(err == rocblas_status_success, \ + "rocBLAS error: ", rocblas_status_to_string(err)); \ } while (0) // Get or create rocBLAS handle for current stream @@ -68,14 +68,14 @@ void grouped_matmul_out_kernel(const at::TensorList input, m, // cols of op(A^T) = rows of A = m k, // inner dimension &alpha, - B.data_ptr(), // B in row-major = B^T in col-major - n, // leading dim of B (row-major stride) - A.data_ptr(), // A in row-major = A^T in col-major - k, // leading dim of A (row-major stride) + B.data_ptr(), // B in row-major = B^T in col-major + n, // leading dim of B (row-major stride) + A.data_ptr(), // A in row-major = A^T in col-major + k, // leading dim of A (row-major stride) &beta, - C.data_ptr(), // C in row-major = C^T in col-major - n // leading dim of C (row-major stride) - )); + C.data_ptr(), // C in row-major = C^T in col-major + n // leading dim of C (row-major stride) + )); } } diff --git a/setup.py b/setup.py index aa8f947e3..dc652024a 100644 --- a/setup.py +++ b/setup.py @@ -56,11 +56,14 @@ def build_extension(self, ext): WITH_CUDA = bool(int(os.getenv('FORCE_CUDA', WITH_CUDA))) # Detect HIP/ROCm (PyTorch built with ROCm has torch.version.hip set) - WITH_HIP = hasattr(torch.version, 'hip') and torch.version.hip is not None + WITH_HIP = hasattr(torch.version, + 'hip') and torch.version.hip is not None WITH_HIP = bool(int(os.getenv('FORCE_HIP', WITH_HIP))) if WITH_HIP: - print(f"Building with HIP/ROCm support (torch.version.hip={torch.version.hip})") + print( + f"Building with HIP/ROCm support (torch.version.hip={torch.version.hip})" + ) cmake_args = [ '-DBUILD_TEST=OFF',