@bernhardmgruber bernhardmgruber commented Oct 15, 2025

For lack of a good example for a SASS test, I used this simple one:

```cpp
#include <cuda/barrier>
#include <cuda/ptx>

// selects a single leader thread from the block
__device__ bool elect_one() {
  // elect_sync is important to help the optimizer generate a uniform datapath
  return cuda::ptx::elect_sync(~0) && threadIdx.x < 32;
}

__global__ void example_kernel(int* gmem1, double* gmem2) {
  constexpr int tile_size = 1024;
  __shared__ alignas(16)    int smem1[tile_size];
  __shared__ alignas(16) double smem2[tile_size];
  #pragma nv_diag_suppress static_var_with_dynamic_init
  using barrier_t = cuda::barrier<cuda::thread_scope_block>;
  __shared__ barrier_t bar;
  // setup the barrier where only the leader thread arrives
  if (elect_one()) {
    init(&bar, 1);
    // issue two TMA bulk copy operations
    cuda::device::memcpy_async_tx(smem1, gmem1, cuda::aligned_size_t<16>(tile_size * sizeof(int)   ), bar);
    cuda::device::memcpy_async_tx(smem2, gmem2, cuda::aligned_size_t<16>(tile_size * sizeof(double)), bar);
    // arrive and update the barrier's expect_tx with the **total** number of loaded bytes
    (void)cuda::device::barrier_arrive_tx(bar, 1, tile_size * (sizeof(int) + sizeof(double)));
  }
  __syncthreads(); // need to sync so the barrier is set up when the other threads arrive and wait
  // wait for the current barrier phase to complete
  bar.wait_parity(0);
  // process data in smem ...
}
```
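As a quick sanity check on the byte counts in the kernel, here is a host-only sketch (no GPU code involved). It assumes the usual CUDA type sizes of 4 bytes for `int` and 8 bytes for `double`; the names `bytes_smem1`, `bytes_smem2`, and `expected_tx` are made up for this illustration and do not appear in the PR.

```cpp
#include <cstddef>

// Mirrors the constants of the example kernel; nothing here touches the GPU.
constexpr std::size_t tile_size   = 1024;
constexpr std::size_t bytes_smem1 = tile_size * sizeof(int);    // first bulk copy
constexpr std::size_t bytes_smem2 = tile_size * sizeof(double); // second bulk copy

// The transaction count handed to barrier_arrive_tx in the kernel.
constexpr std::size_t expected_tx = tile_size * (sizeof(int) + sizeof(double));

// The single arrive has to account for both copies together.
static_assert(expected_tx == bytes_smem1 + bytes_smem2,
              "expect_tx must equal the sum of both bulk copies");

// cuda::aligned_size_t<16> promises that each size is a multiple of 16 bytes.
static_assert(bytes_smem1 % 16 == 0 && bytes_smem2 % 16 == 0,
              "each bulk copy size must be a multiple of 16 bytes");
```

Under the stated size assumptions, the two copies move 4096 and 8192 bytes, so the barrier phase completes once 12288 bytes have landed and the leader has arrived.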

Compiled for sm_100, the SASS differs slightly, but the compiler merely flipped a branch:
image

@bernhardmgruber bernhardmgruber requested a review from a team as a code owner October 15, 2025 15:12
@github-project-automation github-project-automation bot moved this to Todo in CCCL Oct 15, 2025
@cccl-authenticator-app cccl-authenticator-app bot moved this from Todo to In Review in CCCL Oct 15, 2025
Contributor

@miscco miscco left a comment

Looks technically correct, but the formatting is atrocious. Could we add an else to the conditions, as all early branches are returning?

Comment on lines +132 to +139
```cpp
if (!::cuda::device::is_object_from(__barrier, ::cuda::device::address_space::cluster_shared))
{
  return __barrier.arrive(__update);
}
if (!::cuda::device::is_object_from(__barrier, ::cuda::device::address_space::shared))
{
  ::__trap();
}
```
Contributor

This is strange: the first condition already returns for anything that is not cluster_shared, so the second check seems wrong.

Contributor Author

Anything that is shared is also cluster_shared, because the shared memory address space is part of the cluster shared memory address space.

Contributor Author

Here, we trap for any barrier that is in cluster shared memory, but not in the shared memory of the current CTA.
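To make the subset relation concrete, here is a small host-side model of the two checks. It is only a sketch of the reasoning in this thread, not the libcu++ implementation: the `Location` and `Action` enums and the functions `in_shared`, `in_cluster_shared`, and `classify` are hypothetical stand-ins for `is_object_from` with `address_space::shared` and `address_space::cluster_shared`.

```cpp
// Hypothetical model of where a barrier object lives; not a libcu++ type.
enum class Location { global_mem, own_cta_shared, peer_cta_shared };

// Hypothetical outcome of the checks; what happens after both checks pass
// is not shown in this thread, so it is simply called `proceed` here.
enum class Action { generic_arrive, trap, proceed };

// Stand-in for is_object_from(obj, address_space::shared):
// true only for the shared memory of the current CTA.
constexpr bool in_shared(Location loc) {
  return loc == Location::own_cta_shared;
}

// Stand-in for is_object_from(obj, address_space::cluster_shared):
// the CTA's own shared memory is part of the cluster shared space.
constexpr bool in_cluster_shared(Location loc) {
  return loc == Location::own_cta_shared || loc == Location::peer_cta_shared;
}

// Same order of checks as the reviewed lines.
constexpr Action classify(Location loc) {
  if (!in_cluster_shared(loc)) {
    return Action::generic_arrive; // e.g. a barrier in global memory
  }
  if (!in_shared(loc)) {
    return Action::trap; // cluster shared, but another CTA's shared memory
  }
  return Action::proceed; // shared memory of the current CTA
}
```

In this model the second check is reachable: a barrier in a peer CTA's shared memory passes the first check and fails the second, which is the case that traps.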

Contributor

Could we turn that into an else if?

Contributor

Could this not be a precondition instead?

Contributor

😬 CI Workflow Results

🟥 Finished in 4h 41m: Pass: 42%/84 | Total: 7h 57m | Max: 39m 16s | Hits: 99%/25847

See results here.

::__cvta_generic_to_shared(&__barrier)))
: "memory");
}))
(if (::cuda::device::is_object_from(__barrier, ::cuda::device::address_space::cluster_shared)) { ::__trap(); }))
Contributor

cuda::std::terminate() ?

}
if (!::cuda::device::is_object_from(__barrier, ::cuda::device::address_space::shared))
{
::__trap();
Contributor

Could this be a _CCCL_ASSERT instead?

Contributor Author

I agree, we should probably assert here. We should also check whether the documentation makes it clear that barriers ought not to live in cluster shared memory.

unsigned int __activeA = ::__match_any_sync(__mask, __update);
unsigned int __activeB = ::__match_any_sync(__mask, reinterpret_cast<::cuda::std::uintptr_t>(&__barrier));
unsigned int __active = __activeA & __activeB;
int __inc = ::__popc(__active) * __update;
Contributor

Probably not worth moving these to their C++ versions.
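For readers unfamiliar with the intrinsics in the four lines under review, here is a host-side simulation of what they compute per lane. It is a sketch under assumptions: `LaneInput`, `match_any`, `popcount32`, and `combined_increment` are hypothetical names, and `match_any` only imitates the documented behaviour of `__match_any_sync` (a bitmask of the participating lanes that hold the same value as the calling lane).

```cpp
#include <cassert>
#include <cstdint>

constexpr int kWarpSize = 32;

// Per-lane inputs, standing in for what each thread passes in.
struct LaneInput {
  int update;                  // arrival count requested by this lane
  std::uintptr_t barrier_addr; // address of the barrier this lane targets
};

// Host stand-in for __match_any_sync: bitmask of the lanes in `mask`
// whose key equals the key of `lane`.
template <class Key>
std::uint32_t match_any(std::uint32_t mask, const LaneInput (&lanes)[kWarpSize], int lane, Key key) {
  assert(lane >= 0 && lane < kWarpSize);
  std::uint32_t result = 0;
  for (int other = 0; other < kWarpSize; ++other) {
    const bool participates = ((mask >> other) & 1u) != 0;
    if (participates && key(lanes[other]) == key(lanes[lane])) {
      result |= (1u << other);
    }
  }
  return result;
}

// Host stand-in for __popc: number of set bits.
int popcount32(std::uint32_t v) {
  int n = 0;
  for (; v != 0; v &= v - 1) {
    ++n;
  }
  return n;
}

// Combined increment as computed by the reviewed lines, seen from `lane`.
int combined_increment(std::uint32_t mask, const LaneInput (&lanes)[kWarpSize], int lane) {
  const std::uint32_t active_a =
      match_any(mask, lanes, lane, [](const LaneInput& l) { return l.update; });
  const std::uint32_t active_b =
      match_any(mask, lanes, lane, [](const LaneInput& l) { return l.barrier_addr; });
  const std::uint32_t active = active_a & active_b; // same update AND same barrier
  return popcount32(active) * lanes[lane].update;
}
```

Lanes are grouped only when both the update value and the barrier address agree, so one group's combined increment equals the sum of what its members would have contributed individually.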
