1
0
Fork 0

Adding stream for Async malloc and free

This commit is contained in:
Pcornat 2026-01-19 22:50:34 +01:00
commit d9cbb14913
Signed by: Pcornat
GPG key ID: E0326CC678A00BDD

View file

@ -21,6 +21,8 @@ namespace safe_cuda {
template<typename T, allocType alloc_type = allocType::Managed> template<typename T, allocType alloc_type = allocType::Managed>
struct destroyType { struct destroyType {
cudaStream_t stream{ nullptr };
void operator()(T *ptr) const noexcept { void operator()(T *ptr) const noexcept {
(void) cudaFree(ptr); (void) cudaFree(ptr);
} }
@ -28,13 +30,17 @@ namespace safe_cuda {
/**
 * \brief Deleter for pointers allocated with allocType::Unmanaged: releases them with cudaFreeAsync.
 * \tparam T pointee type.
 */
template<typename T>
struct destroyType<T, allocType::Unmanaged> {
	/// Stream handed to cudaFreeAsync; stays nullptr (the default stream) unless set at construction.
	// NOTE(review): the stream presumably has to still be valid when the deleter runs — confirm against callers.
	cudaStream_t stream{ nullptr };

	/**
	 * \brief Frees \p ptr through cudaFreeAsync on the stored stream.
	 * \param ptr pointer to release.
	 *
	 * The cudaError_t returned by cudaFreeAsync is deliberately discarded: the operator is noexcept
	 * and has no channel to report it.
	 */
	void operator()(T *ptr) const noexcept {
		(void) cudaFreeAsync(ptr, stream);
	}
};
template<typename T> template<typename T>
struct destroyType<T, allocType::Host> { struct destroyType<T, allocType::Host> {
cudaStream_t stream{ nullptr };
void operator()(T *ptr) const noexcept { void operator()(T *ptr) const noexcept {
(void) cudaFreeHost(ptr); (void) cudaFreeHost(ptr);
} }
@ -51,17 +57,18 @@ namespace safe_cuda {
* \tparam T bare and built-in type. * \tparam T bare and built-in type.
* \tparam alloc_type Type of allocation: Managed (default), Unmanaged, Host. * \tparam alloc_type Type of allocation: Managed (default), Unmanaged, Host.
* \param byteDataSize * \param byteDataSize
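* \param stream CUDA stream passed to cudaMallocAsync for Unmanaged allocations and stored in the returned pointer's deleter (defaults to nullptr).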
* \return * \return
*/ */
template<typename T, allocType alloc_type> template<typename T, allocType alloc_type>
std::pair<safePtrType<T, alloc_type>, cudaError_t> std::pair<safePtrType<T, alloc_type>, cudaError_t> cuda_malloc(const std::size_t byteDataSize,
cuda_malloc(const std::size_t byteDataSize) noexcept { cudaStream_t stream = nullptr) noexcept {
T *ptr_tmp = nullptr; T *ptr_tmp = nullptr;
cudaError_t error = cudaSuccess; cudaError_t error = cudaSuccess;
switch (alloc_type) { switch (alloc_type) {
case allocType::Unmanaged: case allocType::Unmanaged:
error = cudaMalloc(reinterpret_cast<void **>(&ptr_tmp), byteDataSize); error = cudaMallocAsync(reinterpret_cast<void **>(&ptr_tmp), byteDataSize, stream);
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error }; return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{ stream } }, error };
case allocType::Host: case allocType::Host:
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize); error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error }; return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };