@chrisHuxi

Bug Fix

Which component has the problem?

CuTe C++

Describe the bug

Printing an element of a tensor of the sub-byte type cutlass::float_e2m1_t created by make_fragment_like causes a compilation error.

Steps/Code to reproduce bug

With repo at commit: a2439551c765c5393aebe557ee75d3a0412d2211

#include <cuda.h>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>  // cutlass::float_e2m1_t

using namespace cute;

__global__ void print_nvfp4_kernel(const __half *Aptr) {

  auto A_tensor = make_tensor(make_gmem_ptr((__half *)Aptr), make_shape(Int<16>{}, Int<16>{}), make_stride(16, Int<1>{}));

  Tensor glm_A_tensor_fp4 = make_tensor(make_gmem_ptr((cutlass::float_e2m1_t *)Aptr), make_shape(Int<16>{}, Int<16>{}), make_stride(16, Int<1>{}));

  Tensor A_tensor_fp4 = make_fragment_like<cutlass::float_e2m1_t>(A_tensor);
  
  if (cute::thread0()) {
    print("nvfp4: \n");
    print(glm_A_tensor_fp4(0)); print("\n"); // pass, result correct
    print(A_tensor_fp4(0)); print("\n"); // compilation error
  }
}

Outputs

Compilation error:

error: more than one instance of overloaded function "cuda_kernel::print" matches the argument list:
function template "void cute::print(const cute::subbyte_reference<T> &)" (declared at line 370 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
function template "void cute::print(cute::subbyte_reference<T>)" (declared at line 198 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
argument types are: (cute::subbyte_reference<cutlass::float_e2m1_t>)
print(A_tensor_fp4(0)); print("\n");

Expected behavior

Compilation succeeds and the printed result is correct.

Environment details

  • Compiler: g++ (Debian 12.2.0-14+deb12u1) 12.2.0
  • CUDA: Cuda compilation tools, release 12.8, V12.8.93
  • Build: cuda_12.8.r12.8/compiler.35583870_0

Additional context

The two overloads, one taking the argument by value and one by const reference, are equally good matches, so the compiler cannot choose between them. Changing the by-value overload to accept only rvalues lets overload resolution pick a unique candidate: an lvalue argument selects the const& version, while an rvalue argument selects the && version.
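As a rough illustration of that dispatch rule, here is a minimal host-only sketch. The names ref_like and which_print are hypothetical stand-ins, not the CuTe types or functions, and each overload returns a tag so the selected candidate can be checked at compile time.

```cpp
// Hypothetical stand-in for cute::subbyte_reference<T>; this is NOT the CuTe type.
template <class T>
struct ref_like {
  T value;
  constexpr T get() const { return value; }
};

// Candidate for lvalue arguments: an rvalue reference cannot bind to an lvalue,
// so this is the only viable overload when a named object is passed.
template <class T>
constexpr int which_print(ref_like<T> const&) { return 1; }

// Candidate for rvalue arguments: both overloads are viable, but binding an
// rvalue reference to an rvalue is the better conversion.
// Note: ref_like<T>&& is a plain rvalue reference, not a forwarding reference.
template <class T>
constexpr int which_print(ref_like<T>&&) { return 2; }

constexpr int call_with_lvalue() {
  ref_like<int> r{3};
  return which_print(r);                 // named object: lvalue
}

constexpr int call_with_rvalue() {
  return which_print(ref_like<int>{3});  // temporary: rvalue
}

static_assert(call_with_lvalue() == 1, "lvalue selects the const& overload");
static_assert(call_with_rvalue() == 2, "rvalue selects the && overload");
```

The static_asserts only confirm which overload the compiler picks for this stand-in type; they say nothing about how nvcc treats the real CuTe overload set.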

With the rvalue overload added, the code now compiles successfully and produces the expected results.
Could you please take a look and let me know your thoughts?

Thanks!

@ccecka @thakkarV

  CUTE_HOST_DEVICE
  void
- print(subbyte_reference<T> ref) {
+ print(subbyte_reference<T>&& ref) {
Contributor

Thanks, but I believe this should be its own overload?

Suggested change
- print(subbyte_reference<T>&& ref) {
+ print(subbyte_reference<T> ref) {
+   cute::print(ref.get());
+ }
+ template <class T>
+ CUTE_HOST_DEVICE
+ void
+ print(subbyte_reference<T>&& ref) {
    cute::print(ref.get());
  }

Author

@chrisHuxi Nov 20, 2025

Hi Cris,

Thank you for reviewing. However, I believe we cannot retain the print(subbyte_reference<T> ref) overload, because there is another overload in the same file at line 376:

template <class T>
CUTE_HOST_DEVICE void
print(subbyte_reference<T> const& x) {
  print(x.get());
}

This conflicts with print(subbyte_reference<T> ref) and leads to a compilation error, as you can see from the error message I posted:

error: more than one instance of overloaded function "cuda_kernel::print" matches the argument list:
function template "void cute::print(const cute::subbyte_reference<T> &)" (declared at line 370 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
function template "void cute::print(cute::subbyte_reference<T>)" (declared at line 198 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
argument types are: (cute::subbyte_reference<cutlass::float_e2m1_t>)
print(A_tensor_fp4(0)); print("\n");
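The ambiguity itself can be reproduced without CuTe. The sketch below assumes a hypothetical ref_like type and two illustrative namespaces (both_overloads and value_only, neither from the library). It uses the C++17 detection idiom, so an ambiguous call shows up as a false trait instead of a hard compile error.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-in for cute::subbyte_reference<T>; this is NOT the CuTe type.
template <class T>
struct ref_like { T value; };

namespace both_overloads {
  template <class T> void print(ref_like<T>) {}         // by value
  template <class T> void print(ref_like<T> const&) {}  // by const reference
}

namespace value_only {
  template <class T> void print(ref_like<T>) {}         // single by-value overload
}

// Detection idiom: the specialization is chosen only when the call expression
// is well-formed, i.e. overload resolution finds exactly one best candidate.
template <class Arg, class = void>
struct callable_both : std::false_type {};
template <class Arg>
struct callable_both<Arg,
    std::void_t<decltype(both_overloads::print(std::declval<Arg>()))>>
  : std::true_type {};

template <class Arg, class = void>
struct callable_value_only : std::false_type {};
template <class Arg>
struct callable_value_only<Arg,
    std::void_t<decltype(value_only::print(std::declval<Arg>()))>>
  : std::true_type {};

// Arg = ref_like<int> models an rvalue argument, Arg = ref_like<int>& an lvalue.
constexpr bool both_accepts_rvalue       = callable_both<ref_like<int>>::value;
constexpr bool both_accepts_lvalue       = callable_both<ref_like<int>&>::value;
constexpr bool value_only_accepts_rvalue = callable_value_only<ref_like<int>>::value;
constexpr bool value_only_accepts_lvalue = callable_value_only<ref_like<int>&>::value;

static_assert(!both_accepts_rvalue && !both_accepts_lvalue,
              "by-value and const& overloads are ambiguous for every argument");
static_assert(value_only_accepts_rvalue && value_only_accepts_lvalue,
              "a single by-value overload accepts lvalues and rvalues");
```

With both overloads present, neither lvalue nor rvalue arguments resolve, which matches the shape of the error above. The exact diagnostic wording is compiler-specific.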

Therefore, I changed print(subbyte_reference<T> ref) to print(subbyte_reference<T>&& ref).

Alternatively, removing the print(subbyte_reference<T> const& x) overload would also fix the problem. Could you please evaluate whether deleting this overload might affect other functionality? If it would cause any issues, I would recommend keeping my original modification.

Thank you

Contributor

I prefer deleting the const& version and leaving only the print(subbyte_reference<T> ref) interface. Then there is only one interface with no overloads.
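To illustrate why a single by-value interface can be sufficient for a proxy reference, here is a small host-only sketch. nibble_ref, element, and print_nibble are hypothetical names that only loosely model the idea of a sub-byte reference; the real cute::subbyte_reference has a different interface. The proxy is just a pointer plus a bit offset, so copying it is cheap and one overload serves both named proxies and temporaries.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical 4-bit proxy reference; NOT the CuTe class.
struct nibble_ref {
  std::uint8_t* byte;   // storage byte that packs two 4-bit values
  int           shift;  // 0 selects the low nibble, 4 the high nibble

  std::uint8_t get() const {
    return static_cast<std::uint8_t>((*byte >> shift) & 0xF);
  }
  void set(std::uint8_t v) {
    *byte = static_cast<std::uint8_t>((*byte & ~(0xF << shift)) | ((v & 0xF) << shift));
  }
};

// Returns a proxy to the i-th 4-bit element of a packed byte array.
inline nibble_ref element(std::uint8_t* storage, int i) {
  return nibble_ref{storage + i / 2, (i % 2) * 4};
}

// Single by-value overload: no const& or && sibling, so no ambiguity is possible.
// Returns the number of characters written, as std::printf does.
inline int print_nibble(nibble_ref ref) {
  return std::printf("%u", static_cast<unsigned>(ref.get()));
}

// Small self-check exercising both call forms.
inline void check_nibble_print() {
  std::uint8_t storage[1] = {0xA7};             // low nibble 7, high nibble 10
  nibble_ref lo = element(storage, 0);
  assert(print_nibble(lo) == 1);                  // lvalue argument, prints "7"
  assert(print_nibble(element(storage, 1)) == 2); // rvalue argument, prints "10"
  lo.set(3);
  assert(storage[0] == 0xA3);                     // only the low nibble changed
}
```

Calling check_nibble_print() from a main function runs the checks. Because the copy still points at the same storage byte, passing the proxy by value does not lose access to the underlying element.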

Author

Absolutely, I agree with your suggestions. I just pushed an update. Please take another look and let me know if it's good to merge now.
