Fix: print subbyte<T> compilation error #2783
base: main
  CUTE_HOST_DEVICE
  void
- print(subbyte_reference<T> ref) {
+ print(subbyte_reference<T>&& ref) {
Thanks, but I believe this should be its own overload?
- print(subbyte_reference<T>&& ref) {
+ print(subbyte_reference<T> ref) {
+   cute::print(ref.get());
+ }
+
+ template <class T>
+ CUTE_HOST_DEVICE
+ void
+ print(subbyte_reference<T>&& ref) {
+   cute::print(ref.get());
+ }
Hi Cris,
Thank you for reviewing. However, I believe we cannot keep the print(subbyte_reference<T> ref) overload, because the same file already has another overload at line 376:
    template <class T>
    CUTE_HOST_DEVICE void
    print(subbyte_reference<T> const& x) {
      print(x.get());
    }
This conflicts with print(subbyte_reference<T> ref) and leads to a compilation error, as the error message I posted shows:
    error: more than one instance of overloaded function "cuda_kernel::print" matches the argument list:
        function template "void cute::print(const cute::subbyte_reference<T> &)" (declared at line 370 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
        function template "void cute::print(cute::subbyte_reference<T>)" (declared at line 198 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
        argument types are: (cute::subbyte_reference<cutlass::float_e2m1_t>)
      print(A_tensor_fp4(0)); print("\n");
Therefore, I changed print(subbyte_reference<T> ref) to print(subbyte_reference<T>&& ref).
Alternatively, removing the print(subbyte_reference<T> const& x) overload would also fix the problem. Could you please evaluate whether deleting that overload might affect other functionality? If it would cause any issues, I would recommend keeping my original modification.
Thank you
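For readers who want to reproduce the conflict outside CuTe, here is a minimal sketch. It assumes a hypothetical `proxy<T>` as a stand-in for `cute::subbyte_reference<T>`, and the names `sketch`, `print_is_callable`, and `void_t` are illustrative only. The two `print` overloads mirror the by-value and `const&` signatures quoted above, and a small detection trait shows that no call with a `proxy<int>` argument resolves to a unique overload.

```cpp
#include <type_traits>
#include <utility>

namespace sketch {

// Hypothetical stand-in for cute::subbyte_reference<T>: a small proxy object.
template <class T>
struct proxy {
  T* ptr;
  T get() const { return *ptr; }
};

// The two signatures that coexisted before the fix (return values are just tags).
template <class T> int print(proxy<T>)        { return 1; }  // by value
template <class T> int print(proxy<T> const&) { return 2; }  // by const reference

// C++11-compatible void_t, used to detect whether print(arg) is a valid call.
template <class...> struct make_void { using type = void; };
template <class... Ts> using void_t = typename make_void<Ts...>::type;

// Primary template: the call expression is ill-formed (here: ambiguous).
template <class Arg, class = void>
struct print_is_callable : std::false_type {};

// Chosen only when overload resolution for print(arg) finds a unique best match.
template <class Arg>
struct print_is_callable<Arg, void_t<decltype(sketch::print(std::declval<Arg>()))>>
    : std::true_type {};

// Both overloads are exact matches for every value category, so none is callable.
static_assert(!print_is_callable<proxy<int>>::value,        "rvalue call is ambiguous");
static_assert(!print_is_callable<proxy<int>&>::value,       "lvalue call is ambiguous");
static_assert(!print_is_callable<proxy<int> const&>::value, "const lvalue call is ambiguous");

}  // namespace sketch
```

The trait relies on an ambiguous call inside `decltype` being a substitution failure rather than a hard error, so the `static_assert`s document the ambiguity without breaking the build.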
I prefer deleting the const& version and leaving only the print(subbyte_reference<T> ref) interface. Then there is only one interface with no overloads.
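A minimal sketch of the single-interface approach, assuming a hypothetical `proxy<T>` stand-in for `subbyte_reference<T>` that is cheap to copy. The function name `print_value` and the helper `demo_sum` are illustrative, and the function returns the referenced value instead of printing it so the behaviour is easy to check. With only the by-value overload present, lvalues, const lvalues, and rvalues all resolve to the same function.

```cpp
#include <utility>

namespace single_overload {

// Hypothetical stand-in for cute::subbyte_reference<T>; assumed cheap to copy.
template <class T>
struct proxy {
  T* ptr;
  T get() const { return *ptr; }
};

// The only overload: takes the proxy by value. The real function would print
// ref.get(); returning it keeps this sketch easy to verify.
template <class T>
T print_value(proxy<T> ref) { return ref.get(); }

// Calls the single overload with every value category and sums the results.
inline int demo_sum() {
  int storage = 7;
  proxy<int> lvalue{&storage};
  proxy<int> const const_lvalue{&storage};
  return print_value(lvalue)                 // non-const lvalue
       + print_value(const_lvalue)           // const lvalue
       + print_value(proxy<int>{&storage})   // temporary (prvalue)
       + print_value(std::move(lvalue));     // xvalue
}

}  // namespace single_overload
```

Because there is nothing to choose between, overload resolution cannot be ambiguous; the cost is one copy of the proxy per call, which is negligible under the cheap-to-copy assumption.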
Absolutely, I agree with your suggestions. I just pushed an update. Please take another look and let me know if it's good to merge now.
Bug Fix
Which component has the problem?
CuTe C++
Describe the bug
When printing a subbyte tensor (element type cutlass::float_e2m1_t) created by make_fragment_like, a compilation error occurs.
Steps/Code to reproduce bug
With repo at commit: a2439551c765c5393aebe557ee75d3a0412d2211
Outputs
Compilation error:
    error: more than one instance of overloaded function "cuda_kernel::print" matches the argument list:
        function template "void cute::print(const cute::subbyte_reference<T> &)" (declared at line 370 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
        function template "void cute::print(cute::subbyte_reference<T>)" (declared at line 198 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
        argument types are: (cute::subbyte_reference<cutlass::float_e2m1_t>)
      print(A_tensor_fp4(0)); print("\n");
Expected behavior
The code compiles and produces the correct result.
Environment details
Additional context
Overload resolution cannot rank a pass-by-value parameter against a pass-by-const-reference parameter: for an argument of type subbyte_reference<T>, both are exact matches, so the call is ambiguous. Changing the by-value overload to accept only rvalues (&&) lets the compiler choose uniquely: an lvalue can bind only to the const& version, while an rvalue prefers the && version.
With the rvalue overload added, the code now compiles successfully and produces the expected results.
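The dispatch rule described above can be sketched in isolation. This is an illustration under assumptions, not the CuTe code: `proxy<T>` is a hypothetical stand-in for `subbyte_reference<T>`, and the overloads return a tag (`picked`) instead of printing so that the selected overload is observable.

```cpp
#include <utility>

namespace rvalue_pair {

// Hypothetical stand-in for cute::subbyte_reference<T>.
template <class T>
struct proxy {
  T* ptr;
  T get() const { return *ptr; }
};

// Tag reporting which overload was selected.
enum class picked { const_ref, rvalue_ref };

// Binds to lvalues (and would bind to rvalues if it were the only candidate).
template <class T>
picked print(proxy<T> const&) { return picked::const_ref; }

// Binds only to rvalues. proxy<T>&& is a plain rvalue reference here, not a
// forwarding reference, because T is deduced from inside proxy<T>.
template <class T>
picked print(proxy<T>&&) { return picked::rvalue_ref; }

}  // namespace rvalue_pair
```

Calling `print(p)` with a named `proxy<int> p` selects the `const&` overload, while `print(proxy<int>{...})` or `print(std::move(p))` selects the `&&` overload, because binding an rvalue reference to an rvalue ranks above binding a const lvalue reference to it.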
Could you please take a look and let me know your thoughts?
Thanks!
@ccecka @thakkarV