diff --git a/cudax/include/cuda/experimental/__memory_resource/any_resource.cuh b/cudax/include/cuda/experimental/__memory_resource/any_resource.cuh index 3de89efd038..440e0a13bb2 100644 --- a/cudax/include/cuda/experimental/__memory_resource/any_resource.cuh +++ b/cudax/include/cuda/experimental/__memory_resource/any_resource.cuh @@ -365,6 +365,21 @@ struct _CCCL_DECLSPEC_EMPTY_BASES synchronous_resource_ref template synchronous_resource_ref(::cuda::mr::resource_ref<_OtherProperties...>) = delete; + _CCCL_TEMPLATE(class... _OtherProperties) + _CCCL_REQUIRES((::cuda::std::__type_set_contains_v<::cuda::std::__type_set<_OtherProperties...>, _Properties...>) ) + synchronous_resource_ref& operator=(const synchronous_resource_ref<_OtherProperties...>& __other) noexcept + { + __basic_any_access::__cast_to( + const_cast&>(__other).__get_base(), __get_base()); + return *this; + } + + synchronous_resource_ref& operator=(const synchronous_resource_ref& __other) noexcept + { + __basic_any_access::__cast_to(const_cast(__other).__get_base(), __get_base()); + return *this; + } + using default_queries = ::cuda::mr::properties_list<_Properties...>; private: @@ -372,7 +387,16 @@ private: "The properties of cuda::experimental::synchronous_resource_ref must contain at least one execution " "space " "property!"); + + template + friend struct synchronous_resource_ref; + using __base::interface; + + __base& __get_base() noexcept + { + return *this; + } }; //! @brief Type erased wrapper around a `synchronous_resource` that satisfies \tparam _Properties @@ -389,6 +413,20 @@ struct _CCCL_DECLSPEC_EMPTY_BASES resource_ref // Inherit other constructors from __basic_any _LIBCUDACXX_DELEGATE_CONSTRUCTORS(resource_ref, ::cuda::__basic_any, experimental::__iasync_resource<_Properties...>&); + _CCCL_TEMPLATE(class... _OtherProperties) + _CCCL_REQUIRES((::cuda::std::__type_set_contains_v<::cuda::std::__type_set<_OtherProperties...>, _Properties...>) ) + resource_ref& operator=(const resource_ref<_OtherProperties...>& __other) noexcept + { + __basic_any_access::__cast_to(const_cast&>(__other).__get_base(), __get_base()); + return *this; + } + + resource_ref& operator=(const resource_ref& __other) noexcept + { + __basic_any_access::__cast_to(const_cast(__other).__get_base(), __get_base()); + return *this; + } + using default_queries = ::cuda::mr::properties_list<_Properties...>; private: @@ -398,6 +436,8 @@ private: template friend struct synchronous_resource_ref; + template + friend struct resource_ref; using __base::interface; diff --git a/cudax/test/memory_resource/any_async_resource.cu b/cudax/test/memory_resource/any_async_resource.cu index 503f2a0587e..091f1e13cf4 100644 --- a/cudax/test/memory_resource/any_async_resource.cu +++ b/cudax/test/memory_resource/any_async_resource.cu @@ -181,4 +181,24 @@ TEMPLATE_TEST_CASE_METHOD(test_fixture, "any_resource", "[container][resource]", this->counts = Counts(); } +TEMPLATE_TEST_CASE_METHOD( + test_fixture, "ref assignment operators", "[container][resource]", big_resource, small_resource) +{ + big_resource mr{42, this}; + cudax::resource_ref<::cuda::mr::host_accessible, get_data> ref{mr}; + CHECK(ref.allocate_sync(bytes(100), align(8)) == this); + CHECK(get_property(ref, get_data{}) == 42); + + big_resource mr2{43, this}; + cudax::resource_ref<::cuda::mr::host_accessible, get_data> ref2{mr2}; + ref = ref2; + CHECK(ref.allocate_sync(bytes(100), align(8)) == this); + CHECK(get_property(ref, get_data{}) == 43); + + cudax::resource_ref<::cuda::mr::host_accessible, get_data, extra_property> ref3{mr}; + ref = ref3; + CHECK(ref.allocate_sync(bytes(100), align(8)) == this); + CHECK(get_property(ref, get_data{}) == 42); +} + #endif // __CUDA_ARCH__ diff --git a/cudax/test/memory_resource/any_resource.cu b/cudax/test/memory_resource/any_resource.cu index c380ab02b8a..5d67ef0a257 100644 --- a/cudax/test/memory_resource/any_resource.cu +++ b/cudax/test/memory_resource/any_resource.cu @@ -287,3 +287,23 @@ TEMPLATE_TEST_CASE_METHOD( // Reset the counters: this->counts = Counts(); } + +TEMPLATE_TEST_CASE_METHOD( + test_fixture, "synchronous ref assignment operators", "[container][resource]", big_resource, small_resource) +{ + big_resource mr{42, this}; + cudax::synchronous_resource_ref<::cuda::mr::host_accessible, get_data> ref{mr}; + CHECK(ref.allocate_sync(bytes(100), align(8)) == this); + CHECK(get_property(ref, get_data{}) == 42); + + big_resource mr2{43, this}; + cudax::synchronous_resource_ref<::cuda::mr::host_accessible, get_data> ref2{mr2}; + ref = ref2; + CHECK(ref.allocate_sync(bytes(100), align(8)) == this); + CHECK(get_property(ref, get_data{}) == 43); + + cudax::synchronous_resource_ref<::cuda::mr::host_accessible, get_data, extra_property> ref3{mr}; + ref = ref3; + CHECK(ref.allocate_sync(bytes(100), align(8)) == this); + CHECK(get_property(ref, get_data{}) == 42); +} diff --git a/cudax/test/memory_resource/test_resource.cuh b/cudax/test/memory_resource/test_resource.cuh index b4beaa4fc64..8ea352cec8b 100644 --- a/cudax/test/memory_resource/test_resource.cuh +++ b/cudax/test/memory_resource/test_resource.cuh @@ -96,6 +96,9 @@ struct get_data using value_type = int; }; +struct extra_property +{}; + template struct test_resource { @@ -213,6 +216,7 @@ struct test_resource } friend constexpr void get_property(const test_resource&, ::cuda::mr::host_accessible) noexcept {} + friend constexpr void get_property(const test_resource&, extra_property) noexcept {} friend constexpr int get_property(const test_resource& self, get_data) noexcept { return self.data;