
Allow iterators in cub::DeviceRadixSort #868

@zasdfgbnm

Description


Currently, cub::DeviceRadixSort only supports operating on pointers. For example, SortPairs is declared as:

template <typename KeyT, typename ValueT>
static CUB_RUNTIME_FUNCTION cudaError_t
SortPairs(void *d_temp_storage, size_t &temp_storage_bytes,
          const KeyT *d_keys_in, KeyT *d_keys_out,
          const ValueT *d_values_in, ValueT *d_values_out,
          int num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8,
          cudaStream_t stream = 0, bool debug_synchronous = false)
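To make the restriction concrete, here is a host-only C++ sketch (not CUB code) of a sort-pairs routine with the same pointer-only shape for its values input. The names `sort_pairs_ptr` and `argsort_with_buffer` are hypothetical, and `std::stable_sort` over an index permutation stands in for the radix sort. Because `values_in` must be a pointer, a caller who only wants each key's original position still has to allocate and fill an index buffer first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical host-side analogue of a pointer-only SortPairs: the values
// must already live in memory and be passed as a raw pointer.
// NOTE: not CUB's implementation; std::stable_sort stands in for radix sort.
template <typename KeyT, typename ValueT>
void sort_pairs_ptr(const KeyT* keys_in, KeyT* keys_out,
                    const ValueT* values_in, ValueT* values_out,
                    std::size_t num_items) {
  // Stable permutation that orders the keys ascending.
  std::vector<std::size_t> perm(num_items);
  std::iota(perm.begin(), perm.end(), std::size_t{0});
  std::stable_sort(perm.begin(), perm.end(), [&](std::size_t a, std::size_t b) {
    return keys_in[a] < keys_in[b];
  });
  // Gather keys and values into sorted order.
  for (std::size_t i = 0; i < num_items; ++i) {
    keys_out[i] = keys_in[perm[i]];
    values_out[i] = values_in[perm[i]];
  }
}

// Caller side (hypothetical): to pair each key with its original position,
// an extra n-element buffer holding 0, 1, ..., n-1 is allocated and filled.
std::vector<int> argsort_with_buffer(const std::vector<float>& keys) {
  std::size_t n = keys.size();
  std::vector<int> indices(n);  // the buffer the issue would like to avoid
  std::iota(indices.begin(), indices.end(), 0);
  std::vector<float> keys_out(n);
  std::vector<int> values_out(n);
  sort_pairs_ptr(keys.data(), keys_out.data(), indices.data(),
                 values_out.data(), n);
  return values_out;
}
```

For keys {0.7, 0.1, 0.5, 0.1}, `argsort_with_buffer` returns {1, 3, 2, 0}; the stable ordering keeps the two equal keys in their original relative order.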

It would be useful if d_values_in could also accept an input iterator, not just a pointer.

One use case is pytorch/pytorch#53841. In that PR, we are working on a sorting problem where the input keys are random numbers and the input values are 0, 1, 2, 3, ..., N. Currently, we have to allocate a memory buffer just to hold 0, 1, 2, ..., N, which costs extra memory and the work of filling it. It would be nice if we could do something like:

cub::CountingInputIterator<int> iter(0);  // yields 0, 1, 2, ... without storing them
cub::DeviceRadixSort::SortPairs(..., /*d_values_in=*/iter, /*d_values_out=*/buffer, ...);
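Below is a host-only C++ sketch of what the requested generalization could look like, assuming the values input is simply templated on its own type instead of being fixed to `const ValueT *`. `CountingIter`, `sort_pairs_it`, and `argsort_with_iterator` are hypothetical names playing the roles of `cub::CountingInputIterator` and `SortPairs`; they are not CUB's API, and `std::stable_sort` stands in for radix sort. The index values are computed on access, so the caller allocates no 0, 1, 2, ... buffer.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical minimal counting iterator (stand-in for cub::CountingInputIterator):
// element i is start + i, computed on access; nothing is stored in memory.
template <typename T>
struct CountingIter {
  T start;
  T operator[](std::ptrdiff_t i) const { return start + static_cast<T>(i); }
};

// Hypothetical host-side sort_pairs whose values input is any type that
// supports values_in[i]: a raw pointer or an iterator such as CountingIter.
// std::stable_sort stands in for CUB's radix sort.
template <typename KeyT, typename ValuesInIt, typename ValueT>
void sort_pairs_it(const KeyT* keys_in, KeyT* keys_out,
                   ValuesInIt values_in, ValueT* values_out,
                   std::size_t num_items) {
  // Scratch permutation used by this simple sketch only.
  std::vector<std::size_t> perm(num_items);
  std::iota(perm.begin(), perm.end(), std::size_t{0});
  std::stable_sort(perm.begin(), perm.end(), [&](std::size_t a, std::size_t b) {
    return keys_in[a] < keys_in[b];
  });
  for (std::size_t i = 0; i < num_items; ++i) {
    keys_out[i] = keys_in[perm[i]];
    // Works for pointers and for CountingIter alike.
    values_out[i] = values_in[static_cast<std::ptrdiff_t>(perm[i])];
  }
}

// Caller side (hypothetical): original position of each key in sorted order,
// with the values input synthesized by CountingIter instead of a buffer.
std::vector<int> argsort_with_iterator(const std::vector<float>& keys) {
  std::size_t n = keys.size();
  std::vector<float> keys_out(n);
  std::vector<int> values_out(n);
  sort_pairs_it(keys.data(), keys_out.data(), CountingIter<int>{0},
                values_out.data(), n);
  return values_out;
}
```

Passing a raw pointer as `values_in` still compiles in this sketch, since a pointer also supports `values_in[i]`, so templating the values input does not break pointer-based callers.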

Metadata


Assignees

No one assigned

    Labels

    cub (For all items related to CUB)

    Type

    No type

    Projects

    Status

    Todo

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
