Compare commits
No commits in common. "a53e5f87be62bf17b5c7b9e6cd53fb4783afea1f" and "24a0a78ec6a6156dfb93969bedc61e68d5899a36" have entirely different histories.
a53e5f87be
...
24a0a78ec6
4 changed files with 26 additions and 74 deletions
|
|
@ -6,42 +6,18 @@
|
||||||
#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
|
#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
|
||||||
|
|
||||||
#include <concepts>
|
#include <concepts>
|
||||||
#include <cstdint>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <variant>
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
|
|
||||||
|
namespace safe_cuda::unmanaged {
|
||||||
|
using deviceMallocDestroyType = decltype(&cudaFree);
|
||||||
|
using hostMallocDestroyType = decltype(&cudaFreeHost);
|
||||||
|
|
||||||
namespace safe_cuda {
|
template<typename D>
|
||||||
enum class allocType : std::uint8_t {
|
concept CudaArrayDestroyer = std::is_same_v<D, deviceMallocDestroyType> || std::is_same_v<D, hostMallocDestroyType>;
|
||||||
Unmanaged = 0,
|
|
||||||
Managed = 1,
|
|
||||||
Host = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T, allocType alloc_type = allocType::Managed>
|
template<typename T, CudaArrayDestroyer D> requires std::integral<T> || std::floating_point<T>
|
||||||
struct destroyType {
|
|
||||||
void operator()(T *ptr) const noexcept {
|
|
||||||
(void) cudaFree(ptr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
struct destroyType<T, allocType::Unmanaged> {
|
|
||||||
void operator()(T *ptr) const noexcept {
|
|
||||||
(void) cudaFree(ptr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
struct destroyType<T, allocType::Host> {
|
|
||||||
void operator()(T *ptr) const noexcept {
|
|
||||||
(void) cudaFreeHost(ptr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T, allocType alloc_type, typename D = destroyType<T, alloc_type> >
|
|
||||||
requires std::integral<T> || std::floating_point<T>
|
|
||||||
using safePtrType = std::unique_ptr<T, D>;
|
using safePtrType = std::unique_ptr<T, D>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -49,26 +25,25 @@ namespace safe_cuda {
|
||||||
*
|
*
|
||||||
* It can allocate unmanaged memory on device and on Host for pinned memory.
|
* It can allocate unmanaged memory on device and on Host for pinned memory.
|
||||||
* \tparam T bare and built-in type.
|
* \tparam T bare and built-in type.
|
||||||
* \tparam alloc_type Type of allocation: Managed (default), Unmanage, Host.
|
* \tparam D destroyer type. It determines if it allocates on device or pinned host.
|
||||||
* \param byteDataSize
|
* \param byteDataSize
|
||||||
* \return
|
* \return
|
||||||
*/
|
*/
|
||||||
template<typename T, allocType alloc_type>
|
template<typename T, CudaArrayDestroyer D>
|
||||||
std::pair<safePtrType<T, alloc_type>, cudaError_t>
|
std::variant<safePtrType<T, D>, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept {
|
||||||
cuda_malloc(const std::size_t byteDataSize) noexcept {
|
|
||||||
T *ptr_tmp = nullptr;
|
T *ptr_tmp = nullptr;
|
||||||
cudaError_t error = cudaSuccess;
|
if constexpr (std::is_same_v<D, deviceMallocDestroyType>) {
|
||||||
switch (alloc_type) {
|
const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize);
|
||||||
case allocType::Unmanaged:
|
if (error != cudaSuccess) {
|
||||||
error = cudaMalloc(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
return error;
|
||||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
|
||||||
case allocType::Host:
|
|
||||||
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
|
||||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
|
||||||
case allocType::Managed:
|
|
||||||
error = cudaMallocManaged(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
|
||||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
|
||||||
}
|
}
|
||||||
|
} else if constexpr (std::is_same_v<D, hostMallocDestroyType>) {
|
||||||
|
const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize);
|
||||||
|
if (error != cudaSuccess) {
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return safePtrType<T, D>{ ptr_tmp, cudaFree };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@
|
||||||
namespace safe_cuda {
|
namespace safe_cuda {
|
||||||
std::variant<std::unique_ptr<CUstream_st, streamDestroyType>, cudaError_t> create_stream() noexcept {
|
std::variant<std::unique_ptr<CUstream_st, streamDestroyType>, cudaError_t> create_stream() noexcept {
|
||||||
cudaStream_t stream = nullptr;
|
cudaStream_t stream = nullptr;
|
||||||
if (const cudaError_t error = cudaStreamCreate(&stream); error != cudaSuccess) {
|
const cudaError_t error = cudaStreamCreate(&stream);
|
||||||
|
if (error != cudaSuccess) {
|
||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
return std::unique_ptr<CUstream_st, streamDestroyType>{ stream, cudaStreamDestroy };
|
return std::unique_ptr<CUstream_st, streamDestroyType>{ stream, cudaStreamDestroy };
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ if (NOT Catch2_FOUND)
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
Catch2
|
Catch2
|
||||||
GIT_REPOSITORY https://github.com/catchorg/Catch2
|
GIT_REPOSITORY https://github.com/catchorg/Catch2
|
||||||
GIT_TAG v3.11.0
|
GIT_TAG v3.5.3
|
||||||
EXCLUDE_FROM_ALL
|
EXCLUDE_FROM_ALL
|
||||||
OVERRIDE_FIND_PACKAGE
|
OVERRIDE_FIND_PACKAGE
|
||||||
)
|
)
|
||||||
|
|
@ -20,15 +20,7 @@ enable_testing()
|
||||||
include(Catch)
|
include(Catch)
|
||||||
include(CatchAddTests)
|
include(CatchAddTests)
|
||||||
|
|
||||||
add_executable(tests
|
add_executable(tests tests_stream.cpp)
|
||||||
tests_stream.cpp
|
target_link_libraries(tests Catch2::Catch2WithMain)
|
||||||
tests_safe_allocation.cpp
|
|
||||||
)
|
|
||||||
target_link_libraries(tests Catch2::Catch2WithMain raiiSafeCuda)
|
|
||||||
set_target_properties(tests PROPERTIES
|
|
||||||
CXX_STANDARD 20
|
|
||||||
CXX_EXTENSIONS OFF
|
|
||||||
INTERPROCEDURAL_OPTIMIZATION ON
|
|
||||||
)
|
|
||||||
|
|
||||||
catch_discover_tests(tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
|
catch_discover_tests(tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
//
|
|
||||||
// Created by postaron on 13/12/2025.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <catch2/catch_all.hpp>
|
|
||||||
|
|
||||||
#include "../include/malloc_unmanaged.hpp"
|
|
||||||
|
|
||||||
TEST_CASE("Managed allocation", "[safe_alloc][0]") {
|
|
||||||
std::cout << "Safely allocates memory in CUDA" << std::endl;
|
|
||||||
const auto [safe_ptr, error] = safe_cuda::cuda_malloc<int, safe_cuda::allocType::Managed>(sizeof(int));
|
|
||||||
REQUIRE(safe_ptr != nullptr);
|
|
||||||
REQUIRE(error == cudaSuccess);
|
|
||||||
std::cout << "Safely deallocates memory in CUDA" << std::endl;
|
|
||||||
}
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue