Open
NVIDIA/cub
#374
Labels: cub (For all items related to CUB)
Description
Currently, cub::DeviceRadixSort only supports operating on pointers:
template <typename KeyT, typename ValueT>
static CUB_RUNTIME_FUNCTION cudaError_t SortPairs(
    void *d_temp_storage,
    size_t &temp_storage_bytes,
    const KeyT *d_keys_in,
    KeyT *d_keys_out,
    const ValueT *d_values_in,
    ValueT *d_values_out,
    int num_items,
    int begin_bit = 0,
    int end_bit = sizeof(KeyT) * 8,
    cudaStream_t stream = 0,
    bool debug_synchronous = false)

It would be good if the d_values_in could be an iterator.
One use case is pytorch/pytorch#53841. In that PR, we are working on a sorting problem where the input keys are random numbers and the input values are 0, 1, 2, 3, ..., N. Currently, we have to allocate a memory buffer just to store these values 0, 1, 2, ..., N, which is not optimal. It would be nice if we could do something like:
cub::CountingInputIterator iter(0);
cub::DeviceRadixSort::SortPairs(..., /*d_values_in=*/iter, /*d_values_out=*/buffer, ...);
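
To make the request concrete, here is a minimal host-side sketch of the idea, not CUB code. It assumes 32-bit unsigned keys and uses hypothetical names (`CountingIterator`, `scatter_pass`, `sort_pairs_reference`) that do not exist in CUB. The point it illustrates is that only the first radix pass reads the input values, so that pass can take anything indexable and the 0, 1, 2, ... sequence never has to be stored in its own buffer.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for cub::CountingInputIterator: element i is base + i,
// computed on access, so the sequence is never materialized.
template <typename T>
struct CountingIterator {
    T base;
    T operator[](std::size_t i) const { return base + static_cast<T>(i); }
};

// One stable LSD radix pass on the 8-bit digit at `shift`.
// vals_src only needs operator[], so a pointer or CountingIterator both work.
template <typename ValueSrc, typename ValueT>
void scatter_pass(const std::uint32_t* keys_src, ValueSrc vals_src,
                  std::uint32_t* keys_dst, ValueT* vals_dst,
                  std::size_t n, int shift) {
    std::size_t offset[257] = {0};
    // Histogram: offset[d + 1] counts the keys whose digit is d.
    for (std::size_t i = 0; i < n; ++i)
        ++offset[((keys_src[i] >> shift) & 0xFF) + 1];
    // Prefix sum: offset[d] becomes the first output slot for digit d.
    for (int d = 0; d < 256; ++d)
        offset[d + 1] += offset[d];
    // Scatter in input order, which keeps the pass stable.
    for (std::size_t i = 0; i < n; ++i) {
        std::size_t dst = offset[(keys_src[i] >> shift) & 0xFF]++;
        keys_dst[dst] = keys_src[i];
        vals_dst[dst] = vals_src[i];
    }
}

// Reference sort of (key, value) pairs by key, ascending.
// Only the first pass touches values_in; later passes use scratch buffers.
template <typename ValueInputIt, typename ValueT>
void sort_pairs_reference(const std::uint32_t* keys_in, std::uint32_t* keys_out,
                          ValueInputIt values_in, ValueT* values_out,
                          std::size_t n) {
    std::vector<std::uint32_t> ka(n), kb(n);
    std::vector<ValueT> va(n), vb(n);
    scatter_pass(keys_in, values_in, ka.data(), va.data(), n, 0);
    scatter_pass(ka.data(), va.data(), kb.data(), vb.data(), n, 8);
    scatter_pass(kb.data(), vb.data(), ka.data(), va.data(), n, 16);
    scatter_pass(ka.data(), va.data(), keys_out, values_out, n, 24);
}
```

A call such as `sort_pairs_reference(keys, keys_out, CountingIterator<int>{0}, values_out, n)` then yields the sorting permutation in `values_out`, and passing a plain `const int*` as `values_in` still works. Whether the actual device implementation can be generalized the same way depends on CUB internals that this sketch does not model.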