Template implementation to allocate memory with CUDA.
This commit is contained in:
		
					parent
					
						
							
								9d305db3d1
							
						
					
				
			
			
				commit
				
					
						24a0a78ec6
					
				
			
		
					 1 changed files with 39 additions and 1 deletions
				
			
		| 
						 | 
					@ -5,8 +5,46 @@
 | 
				
			||||||
#ifndef RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
					#ifndef RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
				
			||||||
#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
					#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace safe_cuda {
 | 
					#include <concepts>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <variant>
 | 
				
			||||||
 | 
					#include <cuda_runtime_api.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace safe_cuda::unmanaged {
 | 
				
			||||||
 | 
					    using deviceMallocDestroyType = decltype(&cudaFree);
 | 
				
			||||||
 | 
					    using hostMallocDestroyType = decltype(&cudaFreeHost);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template<typename D>
 | 
				
			||||||
 | 
					    concept CudaArrayDestroyer = std::is_same_v<D, deviceMallocDestroyType> || std::is_same_v<D, hostMallocDestroyType>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template<typename T, CudaArrayDestroyer D> requires std::integral<T> || std::floating_point<T>
 | 
				
			||||||
 | 
					    using safePtrType = std::unique_ptr<T, D>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * \brief It allocates unmanaged memory with cuda runtime API.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * It can allocate unmanaged memory on device and on Host for pinned memory.
 | 
				
			||||||
 | 
					     * \tparam T bare and built-in type.
 | 
				
			||||||
 | 
					     * \tparam D destroyer type. It determines if it allocates on device or pinned host.
 | 
				
			||||||
 | 
					     * \param byteDataSize
 | 
				
			||||||
 | 
					     * \return
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    template<typename T, CudaArrayDestroyer D>
 | 
				
			||||||
 | 
					    std::variant<safePtrType<T, D>, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept {
 | 
				
			||||||
 | 
					        T *ptr_tmp = nullptr;
 | 
				
			||||||
 | 
					        if constexpr (std::is_same_v<D, deviceMallocDestroyType>) {
 | 
				
			||||||
 | 
					            const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize);
 | 
				
			||||||
 | 
					            if (error != cudaSuccess) {
 | 
				
			||||||
 | 
					                return error;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        } else if constexpr (std::is_same_v<D, hostMallocDestroyType>) {
 | 
				
			||||||
 | 
					            const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize);
 | 
				
			||||||
 | 
					            if (error != cudaSuccess) {
 | 
				
			||||||
 | 
					                return error;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return safePtrType<T, D>{ ptr_tmp, cudaFree };
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif //RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
					#endif //RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue