diff --git a/include/malloc_unmanaged.hpp b/include/malloc_unmanaged.hpp index 7700be2..59b8c54 100644 --- a/include/malloc_unmanaged.hpp +++ b/include/malloc_unmanaged.hpp @@ -6,42 +6,18 @@ #define RAIISAFECUDA_MALLOC_UNMANAGED_HPP #include -#include #include - +#include #include +namespace safe_cuda::unmanaged { + using deviceMallocDestroyType = decltype(&cudaFree); + using hostMallocDestroyType = decltype(&cudaFreeHost); -namespace safe_cuda { - enum class allocType : std::uint8_t { - Unmanaged = 0, - Managed = 1, - Host = 2, - }; + template + concept CudaArrayDestroyer = std::is_same_v || std::is_same_v; - template - struct destroyType { - void operator()(T *ptr) const noexcept { - (void) cudaFree(ptr); - } - }; - - template - struct destroyType { - void operator()(T *ptr) const noexcept { - (void) cudaFree(ptr); - } - }; - - template - struct destroyType { - void operator()(T *ptr) const noexcept { - (void) cudaFreeHost(ptr); - } - }; - - template > - requires std::integral || std::floating_point + template requires std::integral || std::floating_point using safePtrType = std::unique_ptr; /** @@ -49,26 +25,25 @@ namespace safe_cuda { * * It can allocate unmanaged memory on device and on Host for pinned memory. * \tparam T bare and built-in type. - * \tparam alloc_type Type of allocation: Managed (default), Unmanage, Host. + * \tparam D destroyer type. It determines if it allocates on device or pinned host. * \param byteDataSize * \return */ - template - std::pair, cudaError_t> - cuda_malloc(const std::size_t byteDataSize) noexcept { + template + std::variant, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept { T *ptr_tmp = nullptr; - cudaError_t error = cudaSuccess; - switch (alloc_type) { - case allocType::Unmanaged: - error = cudaMalloc(reinterpret_cast(&ptr_tmp), byteDataSize); - return { safePtrType{ ptr_tmp, destroyType{} }, error }; - case allocType::Host: - error = cudaMallocHost(reinterpret_cast(&ptr_tmp), byteDataSize); - return { safePtrType{ ptr_tmp, destroyType{} }, error }; - case allocType::Managed: - error = cudaMallocManaged(reinterpret_cast(&ptr_tmp), byteDataSize); - return { safePtrType{ ptr_tmp, destroyType{} }, error }; + if constexpr (std::is_same_v) { + const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize); + if (error != cudaSuccess) { + return error; + } + } else if constexpr (std::is_same_v) { + const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize); + if (error != cudaSuccess) { + return error; + } } + return safePtrType{ ptr_tmp, cudaFree }; } } diff --git a/src/stream_related.cpp b/src/stream_related.cpp index 7fcac3c..a863685 100644 --- a/src/stream_related.cpp +++ b/src/stream_related.cpp @@ -3,7 +3,8 @@ namespace safe_cuda { std::variant, cudaError_t> create_stream() noexcept { cudaStream_t stream = nullptr; - if (const cudaError_t error = cudaStreamCreate(&stream); error != cudaSuccess) { + const cudaError_t error = cudaStreamCreate(&stream); + if (error != cudaSuccess) { return error; } return std::unique_ptr{ stream, cudaStreamDestroy }; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6a1f8d7..2bc2b61 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,7 +7,7 @@ if (NOT Catch2_FOUND) fetchcontent_declare( Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2 - GIT_TAG v3.11.0 + GIT_TAG v3.5.3 EXCLUDE_FROM_ALL OVERRIDE_FIND_PACKAGE ) @@ -20,15 +20,7 @@ enable_testing() include(Catch) include(CatchAddTests) -add_executable(tests - tests_stream.cpp - tests_safe_allocation.cpp -) -target_link_libraries(tests Catch2::Catch2WithMain raiiSafeCuda) -set_target_properties(tests PROPERTIES - CXX_STANDARD 20 - CXX_EXTENSIONS OFF - INTERPROCEDURAL_OPTIMIZATION ON -) +add_executable(tests tests_stream.cpp) +target_link_libraries(tests Catch2::Catch2WithMain) catch_discover_tests(tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/tests/tests_safe_allocation.cpp b/tests/tests_safe_allocation.cpp deleted file mode 100644 index 8ba3e89..0000000 --- a/tests/tests_safe_allocation.cpp +++ /dev/null @@ -1,16 +0,0 @@ -// -// Created by postaron on 13/12/2025. -// - -#include -#include - -#include "../include/malloc_unmanaged.hpp" - -TEST_CASE("Managed allocation", "[safe_alloc][0]") { - std::cout << "Safely allocates memory in CUDA" << std::endl; - const auto [safe_ptr, error] = safe_cuda::cuda_malloc(sizeof(int)); - REQUIRE(safe_ptr != nullptr); - REQUIRE(error == cudaSuccess); - std::cout << "Safely deallocates memory in CUDA" << std::endl; -}