-
Notifications
You must be signed in to change notification settings - Fork 449
Add DeviceMergeSort::StableSortKeysCopy API #565
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the contribution! A few minor suggestions below.
cub/device/device_merge_sort.cuh
Outdated
template <typename KeyInputIteratorT, | ||
typename KeyIteratorT, | ||
typename OffsetT, | ||
typename CompareOpT> | ||
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED | ||
CUB_RUNTIME_FUNCTION static cudaError_t | ||
StableSortKeysCopy(void *d_temp_storage, | ||
std::size_t &temp_storage_bytes, | ||
KeyInputIteratorT d_input_keys, | ||
KeyIteratorT d_output_keys, | ||
OffsetT num_items, | ||
CompareOpT compare_op, | ||
cudaStream_t stream, | ||
bool debug_synchronous) | ||
{ | ||
CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG | ||
|
||
return StableSortKeysCopy<KeyInputIteratorT, KeyIteratorT, OffsetT, CompareOpT>(d_temp_storage, | ||
temp_storage_bytes, | ||
d_input_keys, | ||
d_output_keys, | ||
num_items, | ||
compare_op, | ||
stream); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
template <typename KeyInputIteratorT, | |
typename KeyIteratorT, | |
typename OffsetT, | |
typename CompareOpT> | |
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED | |
CUB_RUNTIME_FUNCTION static cudaError_t | |
StableSortKeysCopy(void *d_temp_storage, | |
std::size_t &temp_storage_bytes, | |
KeyInputIteratorT d_input_keys, | |
KeyIteratorT d_output_keys, | |
OffsetT num_items, | |
CompareOpT compare_op, | |
cudaStream_t stream, | |
bool debug_synchronous) | |
{ | |
CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG | |
return StableSortKeysCopy<KeyInputIteratorT, KeyIteratorT, OffsetT, CompareOpT>(d_temp_storage, | |
temp_storage_bytes, | |
d_input_keys, | |
d_output_keys, | |
num_items, | |
compare_op, | |
stream); | |
} |
This is not required. The `debug_synchronous` overloads exist only so as not to break the existing API. For new functions, there is nothing to break, so let's remove this overload.
@@ -234,6 +234,18 @@ void TestKeys(std::int64_t num_items, | |||
CustomLess())); | |||
|
|||
AssertTrue(CheckResult(d_keys)); | |||
|
|||
CubDebugExit(cub::DeviceMergeSort::StableSortKeysCopy( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just realized that there might be value in something like:
CubDebugExit(cub::DeviceMergeSort::StableSortKeysCopy( | |
thrust::fill(d_keys.begin(), d_keys.end(), KeyType{}); | |
CubDebugExit(cub::DeviceMergeSort::StableSortKeysCopy( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution and the excellent documentation
CUB has a `SortKeysCopy` and a `StableSortKeys`. This PR adds `StableSortKeysCopy` so a sort in-place is not required. The in-place `StableSortKeys` requires a temporary vector if a copy is needed.

Since `StableSortKeys` calls `SortKeys` and `SortKeysCopy` uses the same kernels as `SortKeys` (which appear to already be stable-sort enabled), the new `StableSortKeysCopy` simply calls `SortKeysCopy`.