diff --git a/include/malloc_unmanaged.hpp b/include/malloc_unmanaged.hpp index 59b8c54..7700be2 100644 --- a/include/malloc_unmanaged.hpp +++ b/include/malloc_unmanaged.hpp @@ -6,18 +6,42 @@ #define RAIISAFECUDA_MALLOC_UNMANAGED_HPP #include +#include #include -#include + #include -namespace safe_cuda::unmanaged { - using deviceMallocDestroyType = decltype(&cudaFree); - using hostMallocDestroyType = decltype(&cudaFreeHost); - template - concept CudaArrayDestroyer = std::is_same_v || std::is_same_v; +namespace safe_cuda { + enum class allocType : std::uint8_t { + Unmanaged = 0, + Managed = 1, + Host = 2, + }; - template requires std::integral || std::floating_point + template + struct destroyType { + void operator()(T *ptr) const noexcept { + (void) cudaFree(ptr); + } + }; + + template + struct destroyType { + void operator()(T *ptr) const noexcept { + (void) cudaFree(ptr); + } + }; + + template + struct destroyType { + void operator()(T *ptr) const noexcept { + (void) cudaFreeHost(ptr); + } + }; + + template > + requires std::integral || std::floating_point using safePtrType = std::unique_ptr; /** @@ -25,25 +49,26 @@ namespace safe_cuda::unmanaged { * * It can allocate unmanaged memory on device and on Host for pinned memory. * \tparam T bare and built-in type. - * \tparam D destroyer type. It determines if it allocates on device or pinned host. + * \tparam alloc_type Type of allocation: Managed (default), Unmanage, Host. * \param byteDataSize * \return */ - template - std::variant, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept { + template + std::pair, cudaError_t> + cuda_malloc(const std::size_t byteDataSize) noexcept { T *ptr_tmp = nullptr; - if constexpr (std::is_same_v) { - const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize); - if (error != cudaSuccess) { - return error; - } - } else if constexpr (std::is_same_v) { - const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize); - if (error != cudaSuccess) { - return error; - } + cudaError_t error = cudaSuccess; + switch (alloc_type) { + case allocType::Unmanaged: + error = cudaMalloc(reinterpret_cast(&ptr_tmp), byteDataSize); + return { safePtrType{ ptr_tmp, destroyType{} }, error }; + case allocType::Host: + error = cudaMallocHost(reinterpret_cast(&ptr_tmp), byteDataSize); + return { safePtrType{ ptr_tmp, destroyType{} }, error }; + case allocType::Managed: + error = cudaMallocManaged(reinterpret_cast(&ptr_tmp), byteDataSize); + return { safePtrType{ ptr_tmp, destroyType{} }, error }; } - return safePtrType{ ptr_tmp, cudaFree }; } } diff --git a/src/stream_related.cpp b/src/stream_related.cpp index a863685..7fcac3c 100644 --- a/src/stream_related.cpp +++ b/src/stream_related.cpp @@ -3,8 +3,7 @@ namespace safe_cuda { std::variant, cudaError_t> create_stream() noexcept { cudaStream_t stream = nullptr; - const cudaError_t error = cudaStreamCreate(&stream); - if (error != cudaSuccess) { + if (const cudaError_t error = cudaStreamCreate(&stream); error != cudaSuccess) { return error; } return std::unique_ptr{ stream, cudaStreamDestroy }; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2bc2b61..6a1f8d7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,7 +7,7 @@ if (NOT Catch2_FOUND) fetchcontent_declare( Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2 - GIT_TAG v3.5.3 + GIT_TAG v3.11.0 EXCLUDE_FROM_ALL OVERRIDE_FIND_PACKAGE ) @@ -20,7 +20,15 @@ enable_testing() include(Catch) include(CatchAddTests) -add_executable(tests tests_stream.cpp) -target_link_libraries(tests Catch2::Catch2WithMain) +add_executable(tests + tests_stream.cpp + tests_safe_allocation.cpp +) +target_link_libraries(tests Catch2::Catch2WithMain raiiSafeCuda) +set_target_properties(tests PROPERTIES + CXX_STANDARD 20 + CXX_EXTENSIONS OFF + INTERPROCEDURAL_OPTIMIZATION ON +) catch_discover_tests(tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/tests/tests_safe_allocation.cpp b/tests/tests_safe_allocation.cpp new file mode 100644 index 0000000..8ba3e89 --- /dev/null +++ b/tests/tests_safe_allocation.cpp @@ -0,0 +1,16 @@ +// +// Created by postaron on 13/12/2025. +// + +#include +#include + +#include "../include/malloc_unmanaged.hpp" + +TEST_CASE("Managed allocation", "[safe_alloc][0]") { + std::cout << "Safely allocates memory in CUDA" << std::endl; + const auto [safe_ptr, error] = safe_cuda::cuda_malloc(sizeof(int)); + REQUIRE(safe_ptr != nullptr); + REQUIRE(error == cudaSuccess); + std::cout << "Safely deallocates memory in CUDA" << std::endl; +}