diff --git a/include/malloc_unmanaged.hpp b/include/malloc_unmanaged.hpp index 7700be2..83ba80c 100644 --- a/include/malloc_unmanaged.hpp +++ b/include/malloc_unmanaged.hpp @@ -21,6 +21,8 @@ namespace safe_cuda { template struct destroyType { + cudaStream_t stream{ nullptr }; + void operator()(T *ptr) const noexcept { (void) cudaFree(ptr); } @@ -28,13 +30,17 @@ namespace safe_cuda { template struct destroyType { + cudaStream_t stream{ nullptr }; + void operator()(T *ptr) const noexcept { - (void) cudaFree(ptr); + (void) cudaFreeAsync(ptr, stream); } }; template struct destroyType { + cudaStream_t stream{ nullptr }; + void operator()(T *ptr) const noexcept { (void) cudaFreeHost(ptr); } @@ -51,17 +57,18 @@ namespace safe_cuda { * \tparam T bare and built-in type. * \tparam alloc_type Type of allocation: Managed (default), Unmanaged, Host. * \param byteDataSize Size of the allocation in bytes. + * \param stream Stream passed to cudaMallocAsync for Unmanaged allocations and stored in the deleter for cudaFreeAsync. * \return Pair of the owning pointer and the cudaError_t returned by the allocation call. */ template - std::pair, cudaError_t> - cuda_malloc(const std::size_t byteDataSize) noexcept { + std::pair, cudaError_t> cuda_malloc(const std::size_t byteDataSize, + cudaStream_t stream = nullptr) noexcept { T *ptr_tmp = nullptr; cudaError_t error = cudaSuccess; switch (alloc_type) { case allocType::Unmanaged: - error = cudaMalloc(reinterpret_cast(&ptr_tmp), byteDataSize); - return { safePtrType{ ptr_tmp, destroyType{} }, error }; + error = cudaMallocAsync(reinterpret_cast(&ptr_tmp), byteDataSize, stream); + return { safePtrType{ ptr_tmp, destroyType{ stream } }, error }; case allocType::Host: error = cudaMallocHost(reinterpret_cast(&ptr_tmp), byteDataSize); return { safePtrType{ ptr_tmp, destroyType{} }, error };