// Patch summary (explanatory preamble placed ahead of the `diff --git` header).
//
// File touched: cub/cub/device/dispatch/tuning/tuning_segmented_sort.cuh
// Hunk context: `struct policy_hub`; the closing brace is followed by the end
// of `namespace detail::segmented_sort`.
//
// What the hunk does:
//   * Removes `using MaxPolicy = Policy860;`.
//   * Adds `Policy900`, declared as ChainedPolicy<900, Policy900, Policy860>.
//   * Adds `Policy1200`, declared as ChainedPolicy<1200, Policy1200, Policy900>.
//   * Re-adds the alias as `using MaxPolicy = Policy1200;`, so the alias now names
//     the last policy added by this hunk.
//
// Shared by both new policies:
//   * BLOCK_THREADS = 256 and PARTITIONING_THRESHOLD = 500.
//   * LargeSegmentPolicy = AgentRadixSortDownsweepPolicy with the same argument list:
//     BLOCK_THREADS, 23, DominantT, BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT,
//     RADIX_RANK_MEMOIZE, BLOCK_SCAN_WARP_SCANS, and 6 when sizeof(KeyT) > 1, else 4.
//
// Where the two policies differ:
//   * Policy900:  ITEMS_PER_SMALL_THREAD  = Nominal4BItemsToItems(9)
//                 ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems(KEYS_ONLY ? 7 : 11)
//   * Policy1200: LARGE_ITEMS = sizeof(DominantT) > 4
//                 ITEMS_PER_SMALL_THREAD  = Nominal4BItemsToItems(LARGE_ITEMS ? 7 : 9)
//                 ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems(LARGE_ITEMS ? 9 : 7)
//
// NOTE(review): 900 and 1200 are presumably the target architecture versions used
// for policy selection - confirm against the ChainedPolicy definition.
// NOTE(review): the numeric values look like benchmark-derived tunings; nothing in
// this patch states their origin - confirm with the author before changing them.
// NOTE(review): in this rendering both `AgentSubWarpMergeSortPolicy` aliases have no
// template argument list and `Nominal4BItemsToItems` has no explicit template
// argument; these may have been lost in extraction - verify against the original patch.
// NOTE(review): the hunk header declares 7 old / 70 new lines, but the patch text is
// stored on a single line; line breaks must be restored before it can be applied.
diff --git a/cub/cub/device/dispatch/tuning/tuning_segmented_sort.cuh b/cub/cub/device/dispatch/tuning/tuning_segmented_sort.cuh index c19350147ab..d3399013ae1 100644 --- a/cub/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_segmented_sort.cuh @@ -391,7 +391,70 @@ struct policy_hub LOAD_LDG>; }; - using MaxPolicy = Policy860; + struct Policy900 : ChainedPolicy<900, Policy900, Policy860> + { + static constexpr int BLOCK_THREADS = 256; + static constexpr int PARTITIONING_THRESHOLD = 500; + using LargeSegmentPolicy = AgentRadixSortDownsweepPolicy< + BLOCK_THREADS, + 23, + DominantT, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS, + (sizeof(KeyT) > 1) ? 6 : 4>; + + static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems(9); + static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems(KEYS_ONLY ? 7 : 11); + + using SmallSegmentPolicy = + AgentSubWarpMergeSortPolicy; + using MediumSegmentPolicy = + AgentSubWarpMergeSortPolicy; + }; + + struct Policy1200 : ChainedPolicy<1200, Policy1200, Policy900> + { + static constexpr int BLOCK_THREADS = 256; + static constexpr int PARTITIONING_THRESHOLD = 500; + using LargeSegmentPolicy = AgentRadixSortDownsweepPolicy< + BLOCK_THREADS, + 23, + DominantT, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS, + (sizeof(KeyT) > 1) ? 6 : 4>; + + static constexpr bool LARGE_ITEMS = sizeof(DominantT) > 4; + static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems(LARGE_ITEMS ? 7 : 9); + static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems(LARGE_ITEMS ? 9 : 7); + + using SmallSegmentPolicy = + AgentSubWarpMergeSortPolicy; + using MediumSegmentPolicy = + AgentSubWarpMergeSortPolicy; + }; + + using MaxPolicy = Policy1200; }; } // namespace detail::segmented_sort