Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ option(BUILD_TEST "Enable testing" OFF)
option(BUILD_BENCHMARK "Enable benchmarks" OFF)
option(WITH_COV "Enable code coverage" OFF)
option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_HIP "Enable HIP/ROCm support" OFF)

if(NOT WIN32 AND NOT DEFINED USE_CXX11_ABI)
find_package(Python3 COMPONENTS Interpreter REQUIRED)
Expand Down Expand Up @@ -46,7 +47,7 @@ if(USE_MKL_BLAS AND DEFINED BLAS_INCLUDE_DIR)
endif()
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/pyg_lib/csrc/config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/pyg_lib/csrc/config.h")

if(WITH_CUDA)
if(WITH_CUDA AND NOT WITH_HIP)
enable_language(CUDA)
add_definitions(-DWITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr -allow-unsupported-compiler")
Expand Down Expand Up @@ -75,11 +76,68 @@ if(WITH_CUDA)
include_directories(${CUCOLLECTIONS_DIR})
endif()

# HIP/ROCm support
if(WITH_HIP)
  # Locate the ROCm installation: honour the ROCM_PATH environment variable
  # when present, otherwise fall back to the conventional install prefix.
  if(NOT DEFINED ENV{ROCM_PATH})
    set(ROCM_PATH "/opt/rocm")
  else()
    set(ROCM_PATH $ENV{ROCM_PATH})
  endif()

  list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}")
  find_package(HIP REQUIRED)

  # GPU code paths in this project are guarded by WITH_CUDA; reuse the same
  # macro so the shared sources compile unchanged, and add USE_ROCM /
  # __HIP_PLATFORM_AMD__ for the HIP-specific branches.
  add_definitions(-DWITH_CUDA) # Use same macro for GPU code paths
  add_definitions(-DUSE_ROCM)
  add_definitions(-D__HIP_PLATFORM_AMD__)

  # Default GPU architectures for AMD (gfx900=Vega, gfx906=MI50, gfx908=MI100,
  # gfx90a=MI200, gfx940/942=MI300, gfx1100=RDNA3). Only default this when
  # the user (or a toolchain file, via -DCMAKE_HIP_ARCHITECTURES=...) has not
  # already chosen a target list, so explicit configuration is never stomped.
  # Set before enable_language(HIP) so language enablement sees the choice.
  if(NOT DEFINED CMAKE_HIP_ARCHITECTURES)
    set(CMAKE_HIP_ARCHITECTURES "gfx906;gfx908;gfx90a;gfx940;gfx942;gfx1100")
  endif()

  # Let CMake find HIP compiler (don't set hipcc directly - CMake 3.21+ prefers clang)
  enable_language(HIP)

  # rocBLAS/hipBLAS provide the GEMM kernels used for matmul operations.
  find_package(rocblas REQUIRED)
  find_package(hipblas REQUIRED)

  # rocThrust (HIP port of Thrust) - best-effort: only sources including
  # <thrust/...> headers need it.
  find_path(ROCTHRUST_INCLUDE_DIR thrust/version.h
            HINTS ${ROCM_PATH}/include)
  if(ROCTHRUST_INCLUDE_DIR)
    include_directories(${ROCTHRUST_INCLUDE_DIR})
    message(STATUS "Found rocThrust: ${ROCTHRUST_INCLUDE_DIR}")
  else()
    message(STATUS "rocThrust not found under ${ROCM_PATH}/include")
  endif()

  message(STATUS "Building with HIP/ROCm support")
  message(STATUS "  ROCM_PATH: ${ROCM_PATH}")
  message(STATUS "  HIP_VERSION: ${HIP_VERSION}")
endif()

set(CSRC pyg_lib/csrc)
file(GLOB_RECURSE ALL_SOURCES ${CSRC}/*.cpp)
# CUDA-only build: pull in the .cu kernels as CUDA sources.
if(WITH_CUDA AND NOT WITH_HIP)
  file(GLOB_RECURSE ALL_SOURCES ${ALL_SOURCES} ${CSRC}/*.cu)
endif()
if(WITH_HIP)
  # For HIP, compile .cu files as HIP sources
  file(GLOB_RECURSE HIP_SOURCES ${CSRC}/*.cu)
  # Exclude CUTLASS-dependent files (matmul_kernel.cu) and cuCollections files
  # (hash_map.cu) - these have HIP replacements in hip/ subdirs
  list(FILTER HIP_SOURCES EXCLUDE REGEX ".*/cuda/matmul_kernel\\.cu$")
  list(FILTER HIP_SOURCES EXCLUDE REGEX ".*/cuda/hash_map\\.cu$")
  set_source_files_properties(${HIP_SOURCES} PROPERTIES LANGUAGE HIP)
  list(APPEND ALL_SOURCES ${HIP_SOURCES})
  # Add HIP-specific implementations (rocBLAS matmul, HIP hashmap).
  # These need to be compiled as HIP to get proper HIP/ATen integration.
  file(GLOB_RECURSE HIP_CPP_SOURCES ${CSRC}/*/hip/*.cpp)
  set_source_files_properties(${HIP_CPP_SOURCES} PROPERTIES LANGUAGE HIP)
  list(APPEND ALL_SOURCES ${HIP_CPP_SOURCES})
  message(STATUS "HIP sources: ${HIP_SOURCES}")
  message(STATUS "HIP C++ sources: ${HIP_CPP_SOURCES}")
endif()
add_library(${PROJECT_NAME} SHARED ${ALL_SOURCES})
target_include_directories(${PROJECT_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
if(MKL_INCLUDE_FOUND)
Expand Down Expand Up @@ -122,13 +180,21 @@ if(OpenMP_CXX_FOUND)
target_link_libraries(${PROJECT_NAME} PRIVATE OpenMP::OpenMP_CXX)
endif()

# CCCL ships the Thrust/CUB/libcu++ headers used by the CUDA kernels;
# skipped for HIP builds, which use rocThrust instead.
if(WITH_CUDA AND NOT WITH_HIP)
  target_include_directories(${PROJECT_NAME} PRIVATE
    third_party/cccl/thrust
    third_party/cccl/cub
    third_party/cccl/libcudacxx/include)
endif()

if(WITH_HIP)
  # Make the HIP runtime headers visible to every translation unit.
  target_include_directories(${PROJECT_NAME} PRIVATE ${ROCM_PATH}/include)
  # GEMM operations dispatch through rocBLAS/hipBLAS; the ROCm cmake
  # package configs export them as roc::rocblas and roc::hipblas.
  target_link_libraries(${PROJECT_NAME} PRIVATE roc::rocblas roc::hipblas)
endif()

set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0")

if(BUILD_TEST)
Expand Down
163 changes: 163 additions & 0 deletions pyg_lib/csrc/classes/hip/hash_map_hip.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// ROCm/HIP implementation of GPU HashMap
// Uses sorted array + binary search via ATen (not thrust directly)
// This avoids the cuda/cccl header conflicts with rocThrust

#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <torch/library.h>

namespace pyg {
namespace classes {

namespace {

// Expands to one AT_DISPATCH_CASE per supported integer key type
// (int16 / int32 / int64).
#define DISPATCH_CASE_KEY(...) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

// Dispatches the lambda in __VA_ARGS__ on TYPE, restricted to the key
// types listed above; `scalar_t` is bound inside the lambda.
#define DISPATCH_KEY(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_KEY(__VA_ARGS__))

// Abstract interface for a device hash map keyed by an integer tensor.
// Lets HIPHashMap hold any key-type specialization behind one pointer.
struct HashMapImpl {
  virtual ~HashMapImpl() = default;
  // Looks up each element of `query`; the concrete implementation below
  // returns the stored value for found keys and -1 for missing ones.
  virtual at::Tensor get(const at::Tensor& query) = 0;
  // Returns the keys in their original (construction-time) order.
  virtual at::Tensor keys() = 0;
  // Number of stored keys.
  virtual int64_t size() = 0;
  // Scalar type of the stored keys.
  virtual at::ScalarType dtype() = 0;
  // Device the key tensor was on at construction.
  virtual at::Device device() = 0;
};

// HIP implementation using sorted arrays + binary search (ATen-only, no
// thrust). Keys are sorted once at construction; lookups use searchsorted.
template <typename KeyType>
struct HIPHashMapImpl : HashMapImpl {
 public:
  using ValueType = int64_t;

  // Builds the map from a key tensor. `load_factor` is accepted for
  // interface parity with the hash-table-based implementation but is
  // unused here: a sorted array has no load factor.
  HIPHashMapImpl(const at::Tensor& key, double load_factor)
      : device_(key.device()) {
    (void)load_factor;  // intentionally unused, see above
    const auto options = key.options();
    const auto value_options =
        options.dtype(c10::CppTypeToScalarType<ValueType>::value);

    // Values are the original positions 0..N-1 of each key.
    sorted_values_ = at::arange(key.numel(), value_options);

    // Sort a copy of the keys and permute the values accordingly, so that
    // sorted_keys_[i] maps to sorted_values_[i].
    sorted_keys_ = key.clone();
    auto sort_result = at::sort(sorted_keys_);
    sorted_keys_ = std::get<0>(sort_result);
    auto sort_indices = std::get<1>(sort_result);
    sorted_values_ = sorted_values_.index_select(0, sort_indices);
  }

  // Returns, per query element, the value stored for that key, or -1 if
  // the key is absent.
  at::Tensor get(const at::Tensor& query) override {
    auto out = sorted_values_.new_full({query.numel()}, -1);

    // Guard the empty cases: with no stored keys every lookup misses, and
    // the clamp to [0, numel()-1] below would be ill-formed (max = -1),
    // making the subsequent index_select fault.
    if (sorted_keys_.numel() == 0 || query.numel() == 0) {
      return out;
    }

    // searchsorted yields each query key's insertion position.
    auto positions = at::searchsorted(sorted_keys_, query);

    // A query greater than every key yields position == numel(); clamp so
    // the gather stays in range (the equality mask rejects it anyway).
    positions = at::clamp(positions, 0, sorted_keys_.numel() - 1);

    // A query key exists iff the key found at its position matches exactly.
    auto found_keys = sorted_keys_.index_select(0, positions);
    auto mask = (found_keys == query);

    // Select stored values where matched, -1 elsewhere.
    auto found_values = sorted_values_.index_select(0, positions);
    out = at::where(mask, found_values, out);

    return out;
  }

  // Returns keys in their original (pre-sort) order by inverting the sort
  // permutation held in sorted_values_.
  at::Tensor keys() override {
    auto perm = at::empty_like(sorted_values_);
    perm.scatter_(0, sorted_values_,
                  at::arange(sorted_values_.numel(), sorted_values_.options()));
    return sorted_keys_.index_select(0, perm);
  }

  int64_t size() override { return sorted_keys_.numel(); }

  // Maps the compile-time KeyType back to an ATen scalar type.
  at::ScalarType dtype() override {
    if (std::is_same<KeyType, int16_t>::value) {
      return at::kShort;
    } else if (std::is_same<KeyType, int32_t>::value) {
      return at::kInt;
    } else {
      return at::kLong;
    }
  }

  at::Device device() override { return device_; }

 private:
  at::Tensor sorted_keys_;    // keys in ascending order
  at::Tensor sorted_values_;  // original index of each sorted key
  at::Device device_;         // device of the construction-time key tensor
};

// Type-erased wrapper exposed to TorchScript: validates its tensor inputs
// and forwards to the key-type-specific HIPHashMapImpl.
struct HIPHashMap : torch::CustomClassHolder {
 public:
  HIPHashMap(const at::Tensor& key, double load_factor = 0.5) {
    at::TensorArg key_arg{key, "key", 0};
    at::CheckedFrom ctx{"HIPHashMap.init"};
    // The key must be a contiguous 1-D tensor on the GPU; ROCm devices
    // report as DeviceType::CUDA under PyTorch.
    at::checkDeviceType(ctx, key, at::DeviceType::CUDA); // CUDA type for ROCm
    at::checkDim(ctx, key_arg, 1);
    at::checkContiguous(ctx, key_arg);

    // Instantiate the implementation matching the key's scalar type.
    DISPATCH_KEY(key.scalar_type(), "hip_hash_map_init", [&] {
      map_ = std::make_unique<HIPHashMapImpl<scalar_t>>(key, load_factor);
    });
  }

  // Looks up a contiguous 1-D GPU query tensor; -1 marks missing keys.
  at::Tensor get(const at::Tensor& query) {
    at::TensorArg query_arg{query, "query", 0};
    at::CheckedFrom ctx{"HIPHashMap.get"};
    at::checkDeviceType(ctx, query, at::DeviceType::CUDA);
    at::checkDim(ctx, query_arg, 1);
    at::checkContiguous(ctx, query_arg);

    return map_->get(query);
  }

  // Trivial accessors forwarded to the implementation.
  at::Tensor keys() { return map_->keys(); }
  int64_t size() { return map_->size(); }
  at::ScalarType dtype() { return map_->dtype(); }
  at::Device device() { return map_->device(); }

 private:
  std::unique_ptr<HashMapImpl> map_;
};

} // namespace

// Note: For ROCm, we register CUDAHashMap but use HIP implementation.
// The dispatch key "CUDA" works for both CUDA and ROCm in PyTorch, so
// Python-side callers needn't distinguish the two backends.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.class_<HIPHashMap>("CUDAHashMap")
      .def(torch::init<at::Tensor&, double>())
      .def("get", &HIPHashMap::get)
      .def("keys", &HIPHashMap::keys)
      .def("size", &HIPHashMap::size)
      .def("dtype", &HIPHashMap::dtype)
      .def("device", &HIPHashMap::device)
      .def_pickle(
          // __getstate__: the key tensor fully determines the map, so it is
          // the entire serialized state.
          [](const c10::intrusive_ptr<HIPHashMap>& self) -> at::Tensor {
            return self->keys();
          },
          // __setstate__: rebuild from the keys; load_factor falls back to
          // its constructor default (0.5) on deserialization.
          [](const at::Tensor& state) -> c10::intrusive_ptr<HIPHashMap> {
            return c10::make_intrusive<HIPHashMap>(state);
          });
}

} // namespace classes
} // namespace pyg
8 changes: 8 additions & 0 deletions pyg_lib/csrc/library.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#include "library.h"

#ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda.h>
#endif
#endif

#include <torch/library.h>

Expand All @@ -18,7 +22,11 @@ namespace pyg {

int64_t cuda_version() {
#ifdef WITH_CUDA
#ifdef USE_ROCM
return HIP_VERSION;
#else
return CUDA_VERSION;
#endif
#else
return -1;
#endif
Expand Down
7 changes: 7 additions & 0 deletions pyg_lib/csrc/ops/cuda/sampled_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
#include <ATen/ATen.h>
#ifdef USE_ROCM
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#define getCurrentCUDAStream getCurrentHIPStream
#define C10_CUDA_KERNEL_LAUNCH_CHECK C10_HIP_KERNEL_LAUNCH_CHECK
#else
#include <ATen/cuda/CUDAContext.h>
#endif
#include <torch/library.h>

namespace pyg {
Expand Down
Loading
Loading