Commit 43c00e5

Use one cusparse handle per thread to avoid race condition on cuspars… (#544)
## Summary by CodeRabbit

* **Refactor**
  * Improved GPU pointer and stream configuration to make GPU-accelerated solvers more consistent and reliable.
* **Tests**
  * Re-enabled a previously-skipped barrier solver test so it now runs as part of test suite.
* **Chores**
  * Restored default threading behavior in CI test runs by removing a forced thread-count override.

Authors:
- Hugo Linsenmaier (https://github.com/hlinsen)
- Ramakrishnap (https://github.com/rgsl888prabhu)

Approvers:
- Rajesh Gandham (https://github.com/rg20)

URL: #544
1 parent 5879493 commit 43c00e5

2 files changed, +5 -2 lines changed

cpp/src/dual_simplex/cusparse_view.cu

Lines changed: 4 additions & 0 deletions
@@ -138,6 +138,10 @@ cusparse_view_t<i_t, f_t>::cusparse_view_t(raft::handle_t const* handle_ptr,
     d_minus_one_(f_t(-1), handle_ptr->get_stream()),
     d_zero_(f_t(0), handle_ptr->get_stream())
 {
+  RAFT_CUBLAS_TRY(raft::linalg::detail::cublassetpointermode(
+    handle_ptr->get_cublas_handle(), CUBLAS_POINTER_MODE_DEVICE, handle_ptr->get_stream()));
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesetpointermode(
+    handle_ptr->get_cusparse_handle(), CUSPARSE_POINTER_MODE_DEVICE, handle_ptr->get_stream()));
   // TMP matrix data should already be on the GPU
   constexpr bool debug = false;
   if (debug) { printf("A hash: %zu\n", A.hash()); }
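
Note on the hunk above: the four added lines put the cuBLAS and cuSPARSE handles into device pointer mode, so scalar arguments are read from GPU memory. The `d_minus_one_` and `d_zero_` members initialized just before appear to be such device-side scalars. Pointer mode is state stored on the handle, and both libraries default to host mode. The following is a minimal sketch of why a constructor might pin the mode itself rather than rely on the caller. It is plain C++ so that it builds without CUDA, and `pointer_mode`, `mock_handle`, `mock_view` and `scalar_is_read_correctly` are hypothetical stand-ins, not cuSPARSE or RAFT names.

```cpp
#include <cassert>

// Hypothetical stand-ins for illustration only; not part of cuSPARSE, cuBLAS or RAFT.
enum class pointer_mode { host, device };

struct mock_handle {
  // Mirrors the library default: scalars are expected in host memory.
  pointer_mode mode = pointer_mode::host;
};

// Models a library call that interprets a scalar pointer according to the
// mode stored on the handle. Returns true when mode and location agree.
bool scalar_is_read_correctly(const mock_handle& h, bool scalar_lives_on_device)
{
  return (h.mode == pointer_mode::device) == scalar_lives_on_device;
}

// Models a view whose constructor sets the mode it depends on, so a freshly
// created handle (still in host mode) cannot be used with device scalars.
struct mock_view {
  explicit mock_view(mock_handle& h) : handle(h) { handle.mode = pointer_mode::device; }

  bool multiply_with_device_scalars() const
  {
    return scalar_is_read_correctly(handle, /*scalar_lives_on_device=*/true);
  }

  mock_handle& handle;
};
```

In this model, a new `mock_handle` passed straight to a call that uses device scalars would be misread, while one that has gone through the `mock_view` constructor is safe. This matters once each thread builds its own handle, because a new handle starts from the defaults.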

cpp/src/linear_programming/solve.cu

Lines changed: 1 addition & 2 deletions
@@ -672,9 +672,8 @@ optimization_problem_solution_t<i_t, f_t> run_concurrent(
   // Initialize the dual simplex structures before we run PDLP.
   // Otherwise, CUDA API calls to the problem stream may occur in both threads and throw graph
   // capture off
-  auto barrier_handle = raft::handle_t(*op_problem.get_handle_ptr());
   rmm::cuda_stream_view barrier_stream = rmm::cuda_stream_per_thread;
-  raft::resource::set_cuda_stream(barrier_handle, barrier_stream);
+  auto barrier_handle = raft::handle_t(barrier_stream);
   // Make sure allocations are done on the original stream
   problem.handle_ptr->sync_stream();
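
Note on the hunk above: the removed lines copied the problem's handle and then pointed the copy at the per-thread stream. The replacement builds a new handle directly from that stream. The commit title suggests that the copy still shared the underlying cuSPARSE handle with the original. The sketch below models that hazard in plain C++ under that assumption. `mock_library_handle`, `mock_handle` and both helper functions are hypothetical, and the diff alone does not confirm that `raft::handle_t` copies behave this way.

```cpp
#include <cassert>
#include <memory>

// Hypothetical model: library state (here just a stream id) lives behind a
// shared pointer, so copying the wrapper shares that state. This is an
// assumption made for illustration, not a statement about RAFT internals.
struct mock_library_handle {
  int stream_id = 0;
};

struct mock_handle {
  explicit mock_handle(int stream_id) : lib(std::make_shared<mock_library_handle>())
  {
    lib->stream_id = stream_id;
  }
  void set_stream(int stream_id) { lib->stream_id = stream_id; }
  int stream() const { return lib->stream_id; }

  std::shared_ptr<mock_library_handle> lib;  // shared by the default copy constructor
};

// Old pattern: copy the original handle, then retarget the copy.
// Returns the stream the ORIGINAL handle ends up bound to.
int original_stream_after_copy_and_retarget(int original_stream, int worker_stream)
{
  mock_handle original(original_stream);
  mock_handle worker = original;     // shares `lib` with `original`
  worker.set_stream(worker_stream);  // silently retargets `original` too
  return original.stream();
}

// New pattern: build an independent handle for the worker.
int original_stream_after_fresh_handle(int original_stream, int worker_stream)
{
  mock_handle original(original_stream);
  mock_handle worker(worker_stream);  // owns its own library state
  (void)worker;
  return original.stream();
}
```

In this model the first helper returns the worker's stream id, which shows how state set through the copy leaks into the original. The second helper returns the original's id unchanged. With two threads issuing calls concurrently, leakage of the first kind would be a race.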
