1
0
Fork 0

Compare commits

..

No commits in common. "a53e5f87be62bf17b5c7b9e6cd53fb4783afea1f" and "24a0a78ec6a6156dfb93969bedc61e68d5899a36" have entirely different histories.

4 changed files with 26 additions and 74 deletions

View file

@ -6,42 +6,18 @@
#define RAIISAFECUDA_MALLOC_UNMANAGED_HPP #define RAIISAFECUDA_MALLOC_UNMANAGED_HPP
#include <concepts> #include <concepts>
#include <cstdint>
#include <memory> #include <memory>
#include <variant>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
namespace safe_cuda::unmanaged {
using deviceMallocDestroyType = decltype(&cudaFree);
using hostMallocDestroyType = decltype(&cudaFreeHost);
namespace safe_cuda { template<typename D>
enum class allocType : std::uint8_t { concept CudaArrayDestroyer = std::is_same_v<D, deviceMallocDestroyType> || std::is_same_v<D, hostMallocDestroyType>;
Unmanaged = 0,
Managed = 1,
Host = 2,
};
template<typename T, allocType alloc_type = allocType::Managed> template<typename T, CudaArrayDestroyer D> requires std::integral<T> || std::floating_point<T>
struct destroyType {
void operator()(T *ptr) const noexcept {
(void) cudaFree(ptr);
}
};
template<typename T>
struct destroyType<T, allocType::Unmanaged> {
void operator()(T *ptr) const noexcept {
(void) cudaFree(ptr);
}
};
template<typename T>
struct destroyType<T, allocType::Host> {
void operator()(T *ptr) const noexcept {
(void) cudaFreeHost(ptr);
}
};
template<typename T, allocType alloc_type, typename D = destroyType<T, alloc_type> >
requires std::integral<T> || std::floating_point<T>
using safePtrType = std::unique_ptr<T, D>; using safePtrType = std::unique_ptr<T, D>;
/** /**
@ -49,26 +25,25 @@ namespace safe_cuda {
* *
* It can allocate unmanaged memory on device and on Host for pinned memory. * It can allocate unmanaged memory on device and on Host for pinned memory.
* \tparam T bare and built-in type. * \tparam T bare and built-in type.
* \tparam alloc_type Type of allocation: Managed (default), Unmanage, Host. * \tparam D destroyer type. It determines if it allocates on device or pinned host.
* \param byteDataSize * \param byteDataSize
* \return * \return
*/ */
template<typename T, allocType alloc_type> template<typename T, CudaArrayDestroyer D>
std::pair<safePtrType<T, alloc_type>, cudaError_t> std::variant<safePtrType<T, D>, cudaError_t> cuda_malloc(const std::size_t byteDataSize) noexcept {
cuda_malloc(const std::size_t byteDataSize) noexcept {
T *ptr_tmp = nullptr; T *ptr_tmp = nullptr;
cudaError_t error = cudaSuccess; if constexpr (std::is_same_v<D, deviceMallocDestroyType>) {
switch (alloc_type) { const cudaError_t error = cudaMalloc(&ptr_tmp, byteDataSize);
case allocType::Unmanaged: if (error != cudaSuccess) {
error = cudaMalloc(reinterpret_cast<void **>(&ptr_tmp), byteDataSize); return error;
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error }; }
case allocType::Host: } else if constexpr (std::is_same_v<D, hostMallocDestroyType>) {
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize); const cudaError_t error = cudaMallocHost(&ptr_tmp, byteDataSize);
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error }; if (error != cudaSuccess) {
case allocType::Managed: return error;
error = cudaMallocManaged(reinterpret_cast<void **>(&ptr_tmp), byteDataSize); }
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
} }
return safePtrType<T, D>{ ptr_tmp, cudaFree };
} }
} }

View file

@ -3,7 +3,8 @@
namespace safe_cuda { namespace safe_cuda {
std::variant<std::unique_ptr<CUstream_st, streamDestroyType>, cudaError_t> create_stream() noexcept { std::variant<std::unique_ptr<CUstream_st, streamDestroyType>, cudaError_t> create_stream() noexcept {
cudaStream_t stream = nullptr; cudaStream_t stream = nullptr;
if (const cudaError_t error = cudaStreamCreate(&stream); error != cudaSuccess) { const cudaError_t error = cudaStreamCreate(&stream);
if (error != cudaSuccess) {
return error; return error;
} }
return std::unique_ptr<CUstream_st, streamDestroyType>{ stream, cudaStreamDestroy }; return std::unique_ptr<CUstream_st, streamDestroyType>{ stream, cudaStreamDestroy };

View file

@ -7,7 +7,7 @@ if (NOT Catch2_FOUND)
fetchcontent_declare( fetchcontent_declare(
Catch2 Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2
GIT_TAG v3.11.0 GIT_TAG v3.5.3
EXCLUDE_FROM_ALL EXCLUDE_FROM_ALL
OVERRIDE_FIND_PACKAGE OVERRIDE_FIND_PACKAGE
) )
@ -20,15 +20,7 @@ enable_testing()
include(Catch) include(Catch)
include(CatchAddTests) include(CatchAddTests)
add_executable(tests add_executable(tests tests_stream.cpp)
tests_stream.cpp target_link_libraries(tests Catch2::Catch2WithMain)
tests_safe_allocation.cpp
)
target_link_libraries(tests Catch2::Catch2WithMain raiiSafeCuda)
set_target_properties(tests PROPERTIES
CXX_STANDARD 20
CXX_EXTENSIONS OFF
INTERPROCEDURAL_OPTIMIZATION ON
)
catch_discover_tests(tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) catch_discover_tests(tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})

View file

@ -1,16 +0,0 @@
//
// Created by postaron on 13/12/2025.
//
#include <iostream>
#include <catch2/catch_all.hpp>
#include "../include/malloc_unmanaged.hpp"
TEST_CASE("Managed allocation", "[safe_alloc][0]") {
std::cout << "Safely allocates memory in CUDA" << std::endl;
const auto [safe_ptr, error] = safe_cuda::cuda_malloc<int, safe_cuda::allocType::Managed>(sizeof(int));
REQUIRE(safe_ptr != nullptr);
REQUIRE(error == cudaSuccess);
std::cout << "Safely deallocates memory in CUDA" << std::endl;
}