Finished doing real safe pointer type for built_in types
This commit is contained in:
parent
1ebcaf6f97
commit
81d04d6332
1 changed files with 46 additions and 21 deletions
|
|
@ -6,18 +6,42 @@
|
|||
#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
|
||||
|
||||
#include <concepts>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <variant>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace safe_cuda::unmanaged {
|
||||
using deviceMallocDestroyType = decltype(&cudaFree);
|
||||
using hostMallocDestroyType = decltype(&cudaFreeHost);
|
||||
|
||||
template<typename D>
|
||||
concept CudaArrayDestroyer = std::is_same_v<D, deviceMallocDestroyType> || std::is_same_v<D, hostMallocDestroyType>;
|
||||
namespace safe_cuda {
|
||||
enum class allocType : std::uint8_t {
|
||||
Unmanaged = 0,
|
||||
Managed = 1,
|
||||
Host = 2,
|
||||
};
|
||||
|
||||
template<typename T, CudaArrayDestroyer D> requires std::integral<T> || std::floating_point<T>
|
||||
template<typename T, allocType alloc_type = allocType::Managed>
|
||||
struct destroyType {
|
||||
void operator()(T *ptr) const noexcept {
|
||||
(void) cudaFree(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct destroyType<T, allocType::Unmanaged> {
|
||||
void operator()(T *ptr) const noexcept {
|
||||
(void) cudaFree(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct destroyType<T, allocType::Host> {
|
||||
void operator()(T *ptr) const noexcept {
|
||||
(void) cudaFreeHost(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, allocType alloc_type, typename D = destroyType<T, alloc_type> >
|
||||
requires std::integral<T> || std::floating_point<T>
|
||||
using safePtrType = std::unique_ptr<T, D>;
|
||||
|
||||
/**
|
||||
|
|
@ -25,25 +49,26 @@ namespace safe_cuda::unmanaged {
|
|||
*
|
||||
* It can allocate unmanaged memory on device and on Host for pinned memory.
|
||||
* \tparam T bare and built-in type.
|
||||
* \tparam D destroyer type. It determines if it allocates on device or pinned host.
|
||||
* \tparam alloc_type Type of allocation: Managed (default), Unmanage, Host.
|
||||
* \param byteDataSize
|
||||
* \return
|
||||
*/
|
||||
template<typename T, CudaArrayDestroyer D>
|
||||
std::variant<safePtrType<T, D>, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept {
|
||||
template<typename T, allocType alloc_type>
|
||||
std::pair<safePtrType<T, alloc_type>, cudaError_t>
|
||||
cuda_malloc(const std::size_t byteDataSize) noexcept {
|
||||
T *ptr_tmp = nullptr;
|
||||
if constexpr (std::is_same_v<D, deviceMallocDestroyType>) {
|
||||
const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize);
|
||||
if (error != cudaSuccess) {
|
||||
return error;
|
||||
}
|
||||
} else if constexpr (std::is_same_v<D, hostMallocDestroyType>) {
|
||||
const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize);
|
||||
if (error != cudaSuccess) {
|
||||
return error;
|
||||
}
|
||||
cudaError_t error = cudaSuccess;
|
||||
switch (alloc_type) {
|
||||
case allocType::Unmanaged:
|
||||
error = cudaMalloc(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
||||
case allocType::Host:
|
||||
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
||||
case allocType::Managed:
|
||||
error = cudaMallocManaged(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
||||
}
|
||||
return safePtrType<T, D>{ ptr_tmp, cudaFree };
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue