Adding stream for Async malloc and free
This commit is contained in:
parent
a53e5f87be
commit
d9cbb14913
1 changed files with 12 additions and 5 deletions
|
|
@ -21,6 +21,8 @@ namespace safe_cuda {
|
|||
|
||||
template<typename T, allocType alloc_type = allocType::Managed>
|
||||
struct destroyType {
|
||||
cudaStream_t stream{ nullptr };
|
||||
|
||||
void operator()(T *ptr) const noexcept {
|
||||
(void) cudaFree(ptr);
|
||||
}
|
||||
|
|
@ -28,13 +30,17 @@ namespace safe_cuda {
|
|||
|
||||
template<typename T>
|
||||
struct destroyType<T, allocType::Unmanaged> {
|
||||
cudaStream_t stream{ nullptr };
|
||||
|
||||
void operator()(T *ptr) const noexcept {
|
||||
(void) cudaFree(ptr);
|
||||
(void) cudaFreeAsync(ptr, stream);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct destroyType<T, allocType::Host> {
|
||||
cudaStream_t stream{ nullptr };
|
||||
|
||||
void operator()(T *ptr) const noexcept {
|
||||
(void) cudaFreeHost(ptr);
|
||||
}
|
||||
|
|
@ -51,17 +57,18 @@ namespace safe_cuda {
|
|||
* \tparam T bare and built-in type.
|
||||
* \tparam alloc_type Type of allocation: Managed (default), Unmanaged, Host.
|
||||
* \param byteDataSize Size of the requested allocation, in bytes.
|
||||
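* \param stream Stream passed to cudaMallocAsync for Unmanaged allocations and stored in the returned pointer's deleter (defaults to nullptr).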
|
||||
* \return Pair holding the owning pointer and the cudaError_t returned by the allocation call.
|
||||
*/
|
||||
template<typename T, allocType alloc_type>
|
||||
std::pair<safePtrType<T, alloc_type>, cudaError_t>
|
||||
cuda_malloc(const std::size_t byteDataSize) noexcept {
|
||||
std::pair<safePtrType<T, alloc_type>, cudaError_t> cuda_malloc(const std::size_t byteDataSize,
|
||||
cudaStream_t stream = nullptr) noexcept {
|
||||
T *ptr_tmp = nullptr;
|
||||
cudaError_t error = cudaSuccess;
|
||||
switch (alloc_type) {
|
||||
case allocType::Unmanaged:
|
||||
error = cudaMalloc(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
||||
error = cudaMallocAsync(reinterpret_cast<void **>(&ptr_tmp), byteDataSize, stream);
|
||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{ stream } }, error };
|
||||
case allocType::Host:
|
||||
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue