Template implementation to allocate memory with CUDA.
This commit is contained in:
		
					parent
					
						
							
								9d305db3d1
							
						
					
				
			
			
				commit
				
					
						24a0a78ec6
					
				
			
		
					 1 changed files with 39 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -5,8 +5,46 @@
 | 
			
		|||
#ifndef RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
			
		||||
#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
			
		||||
 | 
			
		||||
namespace safe_cuda {
 | 
			
		||||
#include <concepts>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <variant>
 | 
			
		||||
#include <cuda_runtime_api.h>
 | 
			
		||||
 | 
			
		||||
namespace safe_cuda::unmanaged {
 | 
			
		||||
    using deviceMallocDestroyType = decltype(&cudaFree);
 | 
			
		||||
    using hostMallocDestroyType = decltype(&cudaFreeHost);
 | 
			
		||||
 | 
			
		||||
    template<typename D>
 | 
			
		||||
    concept CudaArrayDestroyer = std::is_same_v<D, deviceMallocDestroyType> || std::is_same_v<D, hostMallocDestroyType>;
 | 
			
		||||
 | 
			
		||||
    template<typename T, CudaArrayDestroyer D> requires std::integral<T> || std::floating_point<T>
 | 
			
		||||
    using safePtrType = std::unique_ptr<T, D>;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * \brief It allocates unmanaged memory with cuda runtime API.
 | 
			
		||||
     *
 | 
			
		||||
     * It can allocate unmanaged memory on device and on Host for pinned memory.
 | 
			
		||||
     * \tparam T bare and built-in type.
 | 
			
		||||
     * \tparam D destroyer type. It determines if it allocates on device or pinned host.
 | 
			
		||||
     * \param byteDataSize
 | 
			
		||||
     * \return
 | 
			
		||||
     */
 | 
			
		||||
    template<typename T, CudaArrayDestroyer D>
 | 
			
		||||
    std::variant<safePtrType<T, D>, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept {
 | 
			
		||||
        T *ptr_tmp = nullptr;
 | 
			
		||||
        if constexpr (std::is_same_v<D, deviceMallocDestroyType>) {
 | 
			
		||||
            const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize);
 | 
			
		||||
            if (error != cudaSuccess) {
 | 
			
		||||
                return error;
 | 
			
		||||
            }
 | 
			
		||||
        } else if constexpr (std::is_same_v<D, hostMallocDestroyType>) {
 | 
			
		||||
            const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize);
 | 
			
		||||
            if (error != cudaSuccess) {
 | 
			
		||||
                return error;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return safePtrType<T, D>{ ptr_tmp, cudaFree };
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif //RAIISAFECUDA_MALLOC_UNMANAGED_HPP
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue