Adding stream for Async malloc and free
This commit is contained in:
parent
a53e5f87be
commit
d9cbb14913
1 changed file with 12 additions and 5 deletions
|
|
@ -21,6 +21,8 @@ namespace safe_cuda {
|
||||||
|
|
||||||
template<typename T, allocType alloc_type = allocType::Managed>
|
template<typename T, allocType alloc_type = allocType::Managed>
|
||||||
struct destroyType {
|
struct destroyType {
|
||||||
|
cudaStream_t stream{ nullptr };
|
||||||
|
|
||||||
void operator()(T *ptr) const noexcept {
|
void operator()(T *ptr) const noexcept {
|
||||||
(void) cudaFree(ptr);
|
(void) cudaFree(ptr);
|
||||||
}
|
}
|
||||||
|
|
@ -28,13 +30,17 @@ namespace safe_cuda {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct destroyType<T, allocType::Unmanaged> {
|
struct destroyType<T, allocType::Unmanaged> {
|
||||||
|
cudaStream_t stream{ nullptr };
|
||||||
|
|
||||||
void operator()(T *ptr) const noexcept {
|
void operator()(T *ptr) const noexcept {
|
||||||
(void) cudaFree(ptr);
|
(void) cudaFreeAsync(ptr, stream);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct destroyType<T, allocType::Host> {
|
struct destroyType<T, allocType::Host> {
|
||||||
|
cudaStream_t stream{ nullptr };
|
||||||
|
|
||||||
void operator()(T *ptr) const noexcept {
|
void operator()(T *ptr) const noexcept {
|
||||||
(void) cudaFreeHost(ptr);
|
(void) cudaFreeHost(ptr);
|
||||||
}
|
}
|
||||||
|
|
@ -51,17 +57,18 @@ namespace safe_cuda {
|
||||||
* \tparam T bare and built-in type.
|
* \tparam T bare and built-in type.
|
||||||
* \tparam alloc_type Type of allocation: Managed (default), Unmanaged, Host.
|
* \tparam alloc_type Type of allocation: Managed (default), Unmanaged, Host.
|
||||||
* \param byteDataSize
|
* \param byteDataSize
|
||||||
|
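* \param stream CUDA stream (default nullptr) passed to cudaMallocAsync for Unmanaged allocations and stored in the deleter for the matching cudaFreeAsync.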
|
||||||
* \return
|
* \return
|
||||||
*/
|
*/
|
||||||
template<typename T, allocType alloc_type>
|
template<typename T, allocType alloc_type>
|
||||||
std::pair<safePtrType<T, alloc_type>, cudaError_t>
|
std::pair<safePtrType<T, alloc_type>, cudaError_t> cuda_malloc(const std::size_t byteDataSize,
|
||||||
cuda_malloc(const std::size_t byteDataSize) noexcept {
|
cudaStream_t stream = nullptr) noexcept {
|
||||||
T *ptr_tmp = nullptr;
|
T *ptr_tmp = nullptr;
|
||||||
cudaError_t error = cudaSuccess;
|
cudaError_t error = cudaSuccess;
|
||||||
switch (alloc_type) {
|
switch (alloc_type) {
|
||||||
case allocType::Unmanaged:
|
case allocType::Unmanaged:
|
||||||
error = cudaMalloc(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
error = cudaMallocAsync(reinterpret_cast<void **>(&ptr_tmp), byteDataSize, stream);
|
||||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{ stream } }, error };
|
||||||
case allocType::Host:
|
case allocType::Host:
|
||||||
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
error = cudaMallocHost(reinterpret_cast<void **>(&ptr_tmp), byteDataSize);
|
||||||
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
return { safePtrType<T, alloc_type>{ ptr_tmp, destroyType<T, alloc_type>{} }, error };
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue