-
Notifications
You must be signed in to change notification settings - Fork 112
Draft: Generic atomic operation #2005
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
2fc602a
f69487a
f858624
2183864
95f82c7
ae67803
cd697e2
fc718fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,12 @@ | |
|
|
||
| #if defined(RAJA_ENABLE_DESUL_ATOMICS) | ||
|
|
||
| #include <cstdint> | ||
| #include <type_traits> | ||
| #include <utility> | ||
|
|
||
| #include "RAJA/util/macros.hpp" | ||
| #include "RAJA/util/TypeConvert.hpp" | ||
|
|
||
| #include "RAJA/policy/atomic_builtin.hpp" | ||
|
|
||
|
|
@@ -27,6 +32,32 @@ using raja_default_desul_scope = desul::MemoryScopeDevice; | |
| namespace RAJA | ||
| { | ||
|
|
||
| namespace detail | ||
| { | ||
|
|
||
| template<typename T> | ||
| RAJA_HOST_DEVICE RAJA_INLINE bool desul_atomicCAS_equal(const T& a, const T& b) | ||
| { | ||
| return a == b; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes me wonder whether it makes sense to also allow a used-defined comparison function for comparing
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you thinking about this just for the short-circuiting, or for determining whether the atomic CAS happened as well?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a little worried about the interface getting too complicated. |
||
| } | ||
|
|
||
| template<typename T, | ||
| std::enable_if_t<std::is_floating_point<T>::value, bool> = true> | ||
| RAJA_HOST_DEVICE RAJA_INLINE bool desul_atomicCAS_equal(const T& a, const T& b) | ||
| { | ||
| using R = std::conditional_t<sizeof(T) == sizeof(std::uint32_t), | ||
| std::uint32_t, | ||
| std::uint64_t>; | ||
| static_assert(sizeof(T) == sizeof(std::uint32_t) || | ||
| sizeof(T) == sizeof(std::uint64_t), | ||
| "desul_atomicCAS_equal only supports 32/64-bit floating point"); | ||
|
|
||
| return RAJA::util::reinterp_A_as_B<T, R>(a) == | ||
| RAJA::util::reinterp_A_as_B<T, R>(b); | ||
| } | ||
|
|
||
| } // namespace detail | ||
|
|
||
| RAJA_SUPPRESS_HD_WARN | ||
| template<typename AtomicPolicy, typename T> | ||
| RAJA_HOST_DEVICE RAJA_INLINE T atomicLoad(AtomicPolicy, T* acc) | ||
|
|
@@ -153,6 +184,34 @@ atomicCAS(AtomicPolicy, T* acc, T compare, T value) | |
| raja_default_desul_scope {}); | ||
| } | ||
|
|
||
| RAJA_SUPPRESS_HD_WARN | ||
| template<typename AtomicPolicy, typename T, typename Operation> | ||
| RAJA_HOST_DEVICE RAJA_INLINE T | ||
| atomicOperation(AtomicPolicy, T* acc, Operation&& operation) | ||
| { | ||
| T expected = desul::atomic_load(acc, | ||
| raja_default_desul_order {}, | ||
| raja_default_desul_scope {}); | ||
|
|
||
| while (true) { | ||
| const T desired = operation(expected); | ||
|
|
||
| if (desul_atomicCAS_equal(desired, expected)) { | ||
| return expected; // no-op | ||
| } | ||
|
|
||
| const T old = desul::atomic_compare_exchange(acc, expected, desired, | ||
| raja_default_desul_order {}, | ||
| raja_default_desul_scope {}); | ||
|
|
||
| if (desul_atomicCAS_equal(old, expected)) { | ||
| return old; // success | ||
| } | ||
|
Comment on lines
+199
to
+209
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we get away with one conditional per iteration?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we want short-circuiting, I don't think so.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll think about this some more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I don't think there's a way to do short-circuiting without two checks. If we are really worried about performance, we could expose a couple of different overloads of atomicOperation (with a functor like the different cas loop overloads we have for CUDA and HIP), or a single overload with a boolean parameter. There are definitely cases where shortcircuiting helps (like min/max or there's a chance you are multiplying by 1 or adding 0 to a non-builtin type). But there are plenty of cases where it would just be extra overhead.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense. To summarize some operators like min/max make sense to short circuit as they will act like the identity function in some cases. Others like plus/mult make less sense to short circuit as they (almost) always return a different value.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could have a single implementation with a short circuit and "old_matches_expected" operator where the short circuit functor defaults to one that always returns false and the "old_matches_expected" operator does a bit-wise comparison. |
||
|
|
||
| expected = old; // CAS failed, old is the latest observed value | ||
| } | ||
| } | ||
|
|
||
| } // namespace RAJA | ||
|
|
||
| #endif // RAJA_ENABLE_DESUL_ATOMICS | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we not already have something like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's duplicated in multiple places. Is there a centralized place you would suggest moving these implementations?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Somewhere in pattern or maybe internal would be a good place.