Merge pull request #5332 from multitalentloes/remove_thrust_dependency

remove usage of thrust
This commit is contained in:
Arne Morten Kvarving 2024-05-03 14:14:14 +02:00 committed by GitHub
commit 0cafaf92cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 13 deletions

View File

@ -205,7 +205,7 @@ template <typename T>
// Inner product of this vector with `other`, taken over the entries selected by indexSet.
// The caller-provided `buffer` is forwarded to detail::innerProductAtIndices as its
// device scratch buffer; indexSet.dim() is passed as the number of elements.
// NOTE(review): this text is a diff rendered without +/- markers, so two return
// statements are shown; read as plain C++, only the first would ever execute.
T
CuVector<T>::dot(const CuVector<T>& other, const CuVector<int>& indexSet, CuVector<T>& buffer) const
{
// Presumably the pre-change call (no cuBLAS handle) -- confirm against the raw diff.
return detail::innerProductAtIndices(m_dataOnDevice, other.data(), buffer.data(), indexSet.dim(), indexSet.data());
// Presumably the post-change call: additionally passes m_cuBlasHandle.get() as the first argument.
return detail::innerProductAtIndices(m_cuBlasHandle.get(), m_dataOnDevice, other.data(), buffer.data(), indexSet.dim(), indexSet.data());
}
template <typename T>
@ -221,7 +221,7 @@ T
// Convenience overload of the indexed inner product: instead of taking a scratch buffer
// from the caller, it constructs a temporary CuVector<T> with indexSet.dim() entries on
// every call and forwards it to detail::innerProductAtIndices.
// NOTE(review): this text is a diff rendered without +/- markers, so two return
// statements are shown; read as plain C++, only the first would ever execute.
CuVector<T>::dot(const CuVector<T>& other, const CuVector<int>& indexSet) const
{
// Scratch buffer with one entry per index.
CuVector<T> buffer(indexSet.dim());
// Presumably the pre-change call (no cuBLAS handle) -- confirm against the raw diff.
return detail::innerProductAtIndices(m_dataOnDevice, other.data(), buffer.data(), indexSet.dim(), indexSet.data());
// Presumably the post-change call: additionally passes m_cuBlasHandle.get() as the first argument.
return detail::innerProductAtIndices(m_cuBlasHandle.get(), m_dataOnDevice, other.data(), buffer.data(), indexSet.dim(), indexSet.data());
}
template <typename T>

View File

@ -18,10 +18,10 @@
*/
#include <opm/common/ErrorMacros.hpp>
#include <opm/simulators/linalg/cuistl/detail/vector_operations.hpp>
// TODO: [perf] Get rid of thrust.
#include <opm/simulators/linalg/cuistl/detail/cublas_safe_call.hpp>
#include <opm/simulators/linalg/cuistl/detail/cublas_wrapper.hpp>
#include <opm/simulators/linalg/cuistl/CuVector.hpp>
#include <stdexcept>
#include <thrust/device_ptr.h>
#include <thrust/reduce.h>
namespace Opm::cuistl::detail
{
@ -139,19 +139,22 @@ template void setZeroAtIndexSet(int*, size_t, const int*);
// Inner product of deviceA and deviceB restricted to the positions listed in `indices`.
// `buffer` is scratch space with numberOfElements entries that is reduced to a single value.
// NOTE(review): this text is a diff rendered without +/- markers -- both the old and the
// new signature, and both reduction variants (thrust and cuBLAS), appear below.
template <class T>
T
// Presumably the pre-change signature (no cuBLAS handle) -- confirm against the raw diff.
innerProductAtIndices(const T* deviceA, const T* deviceB, T* buffer, size_t numberOfElements, const int* indices)
// Presumably the post-change signature: takes the cuBLAS handle as its first parameter.
innerProductAtIndices(cublasHandle_t cublasHandle, const T* deviceA, const T* deviceB, T* buffer, size_t numberOfElements, const int* indices)
{
// Launch the multiply kernel with a grid sized from numberOfElements. It is expected to fill
// `buffer` with the per-index products (kernel body not visible here -- TODO confirm).
elementWiseMultiplyKernel<<<getBlocks(numberOfElements), getThreads(numberOfElements)>>>(
deviceA, deviceB, buffer, numberOfElements, indices);
// TODO: [perf] Get rid of thrust and use a more direct reduction here.
// Thrust variant (presumably the removed code): sums the numberOfElements entries of buffer.
auto bufferAsDevicePointer = thrust::device_pointer_cast(buffer);
return thrust::reduce(bufferAsDevicePointer, bufferAsDevicePointer + numberOfElements);
// TODO: [perf] Get rid of the allocation here.
// cuBLAS variant: the dot product of an all-ones vector with `buffer` (both with stride 1)
// equals the sum of the buffer entries; the value is written to `result`.
CuVector<T> oneVector(numberOfElements);
oneVector = 1.0;
T result = 0.0;
// NOTE(review): numberOfElements is a size_t; check how the cublasDot wrapper's count
// parameter handles it, and what OPM_CUBLAS_SAFE_CALL does on a failing status.
OPM_CUBLAS_SAFE_CALL(cublasDot(cublasHandle, numberOfElements, oneVector.data(), 1, buffer, 1, &result));
return result;
}
// Explicit instantiations of innerProductAtIndices for double, float and int.
// NOTE(review): diff rendering without +/- markers -- the first three lines use the
// signature without a cuBLAS handle (presumably removed), the last three the signature
// that takes cublasHandle_t first (presumably added). Confirm against the raw diff.
template double innerProductAtIndices(const double*, const double*, double* buffer, size_t, const int*);
template float innerProductAtIndices(const float*, const float*, float* buffer, size_t, const int*);
template int innerProductAtIndices(const int*, const int*, int* buffer, size_t, const int*);
// NOTE(review): the int instantiation relies on a cublasDot overload for int being
// available in the wrapper header -- verify.
template double innerProductAtIndices(cublasHandle_t, const double*, const double*, double* buffer, size_t, const int*);
template float innerProductAtIndices(cublasHandle_t, const float*, const float*, float* buffer, size_t, const int*);
template int innerProductAtIndices(cublasHandle_t, const int*, const int*, int* buffer, size_t, const int*);
template <class T>
void prepareSendBuf(const T* deviceA, T* buffer, size_t numberOfElements, const int* indices)

View File

@ -19,6 +19,7 @@
#ifndef OPM_CUISTL_VECTOR_OPERATIONS_HPP
#define OPM_CUISTL_VECTOR_OPERATIONS_HPP
#include <cstddef>
#include <cublas_v2.h>
namespace Opm::cuistl::detail
{
@ -42,6 +43,7 @@ void setZeroAtIndexSet(T* deviceData, size_t numberOfElements, const int* indice
/**
* @brief innerProductAtIndices computes the inner product between deviceA[indices] and deviceB[indices]
* @param cublasHandle a valid (initialized) cublas handle
* @param deviceA data A (device memory)
* @param deviceB data B (device memory)
* @param buffer a buffer with number of elements equal to numberOfElements (device memory)
@ -53,7 +55,7 @@ void setZeroAtIndexSet(T* deviceData, size_t numberOfElements, const int* indice
* of those projected vectors.
*/
// NOTE(review): diff rendering without +/- markers -- the declaration without
// cublasHandle_t appears to be the pre-change line and the one taking cublasHandle_t
// as first parameter its replacement; confirm against the raw diff.
template <class T>
T innerProductAtIndices(const T* deviceA, const T* deviceB, T* buffer, size_t numberOfElements, const int* indices);
T innerProductAtIndices(cublasHandle_t cublasHandle, const T* deviceA, const T* deviceB, T* buffer, size_t numberOfElements, const int* indices);
template <class T>
void prepareSendBuf(const T* deviceA, T* buffer, size_t numberOfElements, const int* indices);