diff --git a/include/malloc_unmanaged.hpp b/include/malloc_unmanaged.hpp index 59b8c54..7700be2 100644 --- a/include/malloc_unmanaged.hpp +++ b/include/malloc_unmanaged.hpp @@ -6,18 +6,42 @@ #define RAIISAFECUDA_MALLOC_UNMANAGED_HPP #include +#include #include -#include + #include -namespace safe_cuda::unmanaged { - using deviceMallocDestroyType = decltype(&cudaFree); - using hostMallocDestroyType = decltype(&cudaFreeHost); - template - concept CudaArrayDestroyer = std::is_same_v || std::is_same_v; +namespace safe_cuda { + enum class allocType : std::uint8_t { + Unmanaged = 0, + Managed = 1, + Host = 2, + }; - template requires std::integral || std::floating_point + template + struct destroyType { + void operator()(T *ptr) const noexcept { + (void) cudaFree(ptr); + } + }; + + template + struct destroyType { + void operator()(T *ptr) const noexcept { + (void) cudaFree(ptr); + } + }; + + template + struct destroyType { + void operator()(T *ptr) const noexcept { + (void) cudaFreeHost(ptr); + } + }; + + template > + requires std::integral || std::floating_point using safePtrType = std::unique_ptr; /** @@ -25,25 +49,26 @@ namespace safe_cuda::unmanaged { * * It can allocate unmanaged memory on device and on Host for pinned memory. * \tparam T bare and built-in type. - * \tparam D destroyer type. It determines if it allocates on device or pinned host. + * \tparam alloc_type Type of allocation: Managed (default), Unmanage, Host. * \param byteDataSize * \return */ - template - std::variant, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept { + template + std::pair, cudaError_t> + cuda_malloc(const std::size_t byteDataSize) noexcept { T *ptr_tmp = nullptr; - if constexpr (std::is_same_v) { - const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize); - if (error != cudaSuccess) { - return error; - } - } else if constexpr (std::is_same_v) { - const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize); - if (error != cudaSuccess) { - return error; - } + cudaError_t error = cudaSuccess; + switch (alloc_type) { + case allocType::Unmanaged: + error = cudaMalloc(reinterpret_cast(&ptr_tmp), byteDataSize); + return { safePtrType{ ptr_tmp, destroyType{} }, error }; + case allocType::Host: + error = cudaMallocHost(reinterpret_cast(&ptr_tmp), byteDataSize); + return { safePtrType{ ptr_tmp, destroyType{} }, error }; + case allocType::Managed: + error = cudaMallocManaged(reinterpret_cast(&ptr_tmp), byteDataSize); + return { safePtrType{ ptr_tmp, destroyType{} }, error }; } - return safePtrType{ ptr_tmp, cudaFree }; } }