mirror of
https://github.com/OPM/opm-simulators.git
synced 2025-02-25 18:55:30 -06:00
Replaced some macro size checks with function calls.
This commit is contained in:
@@ -26,16 +26,6 @@
|
||||
#include <opm/simulators/linalg/cuistl/detail/cuda_safe_call.hpp>
|
||||
#include <opm/simulators/linalg/cuistl/detail/vector_operations.hpp>
|
||||
|
||||
#define CHECKSIZE(x) \
|
||||
if (x.m_numberOfElements != m_numberOfElements) { \
|
||||
OPM_THROW(std::invalid_argument, \
|
||||
fmt::format("Given vector has {}, while we have {}.", x.m_numberOfElements, m_numberOfElements)); \
|
||||
}
|
||||
#define CHECKPOSITIVESIZE \
|
||||
if (m_numberOfElements <= 0) { \
|
||||
OPM_THROW(std::invalid_argument, "We have 0 elements"); \
|
||||
}
|
||||
|
||||
namespace Opm::cuistl
|
||||
{
|
||||
|
||||
@@ -66,7 +56,7 @@ template <class T>
|
||||
CuVector<T>&
|
||||
CuVector<T>::operator=(T scalar)
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
assertHasElements();
|
||||
detail::setVectorValue(data(), detail::to_size_t(m_numberOfElements), scalar);
|
||||
return *this;
|
||||
}
|
||||
@@ -75,11 +65,9 @@ template <class T>
|
||||
CuVector<T>&
|
||||
CuVector<T>::operator=(const CuVector<T>& other)
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
CHECKSIZE(other)
|
||||
if (other.m_numberOfElements != m_numberOfElements) {
|
||||
OPM_THROW(std::invalid_argument, "Can only copy from vector of same size.");
|
||||
}
|
||||
assertHasElements();
|
||||
assertSameSize(other);
|
||||
|
||||
OPM_CUDA_SAFE_CALL(cudaMemcpy(m_dataOnDevice,
|
||||
other.m_dataOnDevice,
|
||||
detail::to_size_t(m_numberOfElements) * sizeof(T),
|
||||
@@ -91,8 +79,8 @@ template <class T>
|
||||
CuVector<T>::CuVector(const CuVector<T>& other)
|
||||
: CuVector(other.m_numberOfElements)
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
CHECKSIZE(other)
|
||||
assertHasElements();
|
||||
assertSameSize(other);
|
||||
OPM_CUDA_SAFE_CALL(cudaMemcpy(m_dataOnDevice,
|
||||
other.m_dataOnDevice,
|
||||
detail::to_size_t(m_numberOfElements) * sizeof(T),
|
||||
@@ -140,6 +128,25 @@ CuVector<T>::setZeroAtIndexSet(const CuVector<int>& indexSet)
|
||||
detail::setZeroAtIndexSet(m_dataOnDevice, indexSet.dim(), indexSet.data());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
CuVector<T>::assertSameSize(const CuVector<T>& x) const
|
||||
{
|
||||
if (x.m_numberOfElements != m_numberOfElements) {
|
||||
OPM_THROW(std::invalid_argument,
|
||||
fmt::format("Given vector has {}, while we have {}.", x.m_numberOfElements, m_numberOfElements));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
CuVector<T>::assertHasElements() const
|
||||
{
|
||||
if (m_numberOfElements <= 0) {
|
||||
OPM_THROW(std::invalid_argument, "We have 0 elements");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T*
|
||||
CuVector<T>::data()
|
||||
@@ -151,7 +158,7 @@ template <class T>
|
||||
CuVector<T>&
|
||||
CuVector<T>::operator*=(const T& scalar)
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
assertHasElements();
|
||||
OPM_CUBLAS_SAFE_CALL(detail::cublasScal(m_cuBlasHandle.get(), m_numberOfElements, &scalar, data(), 1));
|
||||
return *this;
|
||||
}
|
||||
@@ -160,8 +167,8 @@ template <class T>
|
||||
CuVector<T>&
|
||||
CuVector<T>::axpy(T alpha, const CuVector<T>& y)
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
CHECKSIZE(y)
|
||||
assertHasElements();
|
||||
assertSameSize(y);
|
||||
OPM_CUBLAS_SAFE_CALL(detail::cublasAxpy(m_cuBlasHandle.get(), m_numberOfElements, &alpha, y.data(), 1, data(), 1));
|
||||
return *this;
|
||||
}
|
||||
@@ -170,8 +177,8 @@ template <class T>
|
||||
T
|
||||
CuVector<T>::dot(const CuVector<T>& other) const
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
CHECKSIZE(other)
|
||||
assertHasElements();
|
||||
assertSameSize(other);
|
||||
T result = T(0);
|
||||
OPM_CUBLAS_SAFE_CALL(
|
||||
detail::cublasDot(m_cuBlasHandle.get(), m_numberOfElements, data(), 1, other.data(), 1, &result));
|
||||
@@ -181,7 +188,7 @@ template <class T>
|
||||
T
|
||||
CuVector<T>::two_norm() const
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
assertHasElements();
|
||||
T result = T(0);
|
||||
OPM_CUBLAS_SAFE_CALL(detail::cublasNrm2(m_cuBlasHandle.get(), m_numberOfElements, data(), 1, &result));
|
||||
return result;
|
||||
@@ -222,8 +229,8 @@ template <class T>
|
||||
CuVector<T>&
|
||||
CuVector<T>::operator+=(const CuVector<T>& other)
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
CHECKSIZE(other)
|
||||
assertHasElements();
|
||||
assertSameSize(other);
|
||||
// TODO: [perf] Make a specialized version of this
|
||||
return axpy(1.0, other);
|
||||
}
|
||||
@@ -232,8 +239,8 @@ template <class T>
|
||||
CuVector<T>&
|
||||
CuVector<T>::operator-=(const CuVector<T>& other)
|
||||
{
|
||||
CHECKPOSITIVESIZE
|
||||
CHECKSIZE(other)
|
||||
assertHasElements();
|
||||
assertSameSize(other);
|
||||
// TODO: [perf] Make a specialized version of this
|
||||
return axpy(-1.0, other);
|
||||
}
|
||||
|
||||
@@ -341,6 +341,9 @@ private:
|
||||
// This gives the added benifit that a size_t to int conversion error occurs during construction.
|
||||
const int m_numberOfElements;
|
||||
detail::CuBlasHandle& m_cuBlasHandle;
|
||||
|
||||
void assertSameSize(const CuVector<T>& other) const;
|
||||
void assertHasElements() const;
|
||||
};
|
||||
} // namespace Opm::cuistl
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user