75 lines
2.4 KiB
C++
75 lines
2.4 KiB
C++
//
|
|
// Created by postaron on 25/03/24.
|
|
//
|
|
|
|
#ifndef RAIISAFECUDA_MALLOC_UNMANAGED_HPP
|
|
#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
|
|
|
|
#include <concepts>
|
|
#include <cstdint>
|
|
#include <memory>
|
|
|
|
#include <cuda_runtime_api.h>
|
|
|
|
|
|
namespace safe_cuda {
|
|
/**
 * \brief Selects which CUDA runtime allocator/deallocator pair is used.
 *
 * The underlying type is std::uint8_t and the numeric values are fixed
 * explicitly.
 */
enum class allocType : std::uint8_t {
    /// Memory obtained with cudaMalloc.
    Unmanaged = 0,
    /// Memory obtained with cudaMallocManaged.
    Managed = 1,
    /// Memory obtained with cudaMallocHost.
    Host = 2,
};
|
|
|
|
/**
 * \brief Deleter functor that releases a pointer with cudaFree.
 *
 * This is the primary template: it is used for every allocation kind that
 * has no dedicated specialization.
 *
 * \tparam T pointee type.
 * \tparam alloc_type allocation kind; defaults to allocType::Managed.
 */
template<typename T, allocType alloc_type = allocType::Managed>
struct destroyType {
    /// Frees \p ptr with cudaFree. The returned status is deliberately discarded.
    void operator()(T *ptr) const noexcept {
        static_cast<void>(cudaFree(ptr));
    }
};
|
|
|
|
template<typename T>
|
|
struct destroyType<T, allocType::Unmanaged> {
|
|
void operator()(T *ptr) const noexcept {
|
|
(void) cudaFree(ptr);
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
struct destroyType<T, allocType::Host> {
|
|
void operator()(T *ptr) const noexcept {
|
|
(void) cudaFreeHost(ptr);
|
|
}
|
|
};
|
|
|
|
template<typename T, allocType alloc_type, typename D = destroyType<T, alloc_type> >
|
|
requires std::integral<T> || std::floating_point<T>
|
|
using safePtrType = std::unique_ptr<T, D>;
|
|
|
|
/**
|
|
* \brief It allocates unmanaged memory with cuda runtime API.
|
|
*
|
|
* It can allocate unmanaged memory on device and on Host for pinned memory.
|
|
* \tparam T bare and built-in type.
|
|
* \tparam alloc_type Type of allocation: Managed (default), Unmanage, Host.
|
|
* \param byteDataSize
|
|
* \return
|
|
*/
|
|
template<typename T, allocType alloc_type>
|
|
std::pair<safePtrType<T, alloc_type>, cudaError_t>
|
|
cuda_malloc(const std::size_t byteDataSize) noexcept {
|
|
T *ptr_tmp = nullptr;
|
|
cudaError_t error = cudaSuccess;
|
|
switch (alloc_type) {
|
|
case allocType::Unmanaged:
|
|
error = cudaMalloc(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
|
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
|
case allocType::Host:
|
|
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
|
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
|
case allocType::Managed:
|
|
error = cudaMallocManaged(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
|
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif //RAIISAFECUDA_MALLOC_UNMANAGED_HPP
|