Replaced some macro size checks with function calls.

This commit is contained in:
Kjetil Olsen Lye
2023-03-30 11:02:32 +02:00
parent e32b6ac0a8
commit 046ef6cdc0
2 changed files with 38 additions and 28 deletions

View File

@@ -26,16 +26,6 @@
#include <opm/simulators/linalg/cuistl/detail/cuda_safe_call.hpp>
#include <opm/simulators/linalg/cuistl/detail/vector_operations.hpp>
#define CHECKSIZE(x) \
if (x.m_numberOfElements != m_numberOfElements) { \
OPM_THROW(std::invalid_argument, \
fmt::format("Given vector has {}, while we have {}.", x.m_numberOfElements, m_numberOfElements)); \
}
#define CHECKPOSITIVESIZE \
if (m_numberOfElements <= 0) { \
OPM_THROW(std::invalid_argument, "We have 0 elements"); \
}
namespace Opm::cuistl
{
@@ -66,7 +56,7 @@ template <class T>
CuVector<T>&
CuVector<T>::operator=(T scalar)
{
CHECKPOSITIVESIZE
assertHasElements();
detail::setVectorValue(data(), detail::to_size_t(m_numberOfElements), scalar);
return *this;
}
@@ -75,11 +65,9 @@ template <class T>
CuVector<T>&
CuVector<T>::operator=(const CuVector<T>& other)
{
CHECKPOSITIVESIZE
CHECKSIZE(other)
if (other.m_numberOfElements != m_numberOfElements) {
OPM_THROW(std::invalid_argument, "Can only copy from vector of same size.");
}
assertHasElements();
assertSameSize(other);
OPM_CUDA_SAFE_CALL(cudaMemcpy(m_dataOnDevice,
other.m_dataOnDevice,
detail::to_size_t(m_numberOfElements) * sizeof(T),
@@ -91,8 +79,8 @@ template <class T>
CuVector<T>::CuVector(const CuVector<T>& other)
: CuVector(other.m_numberOfElements)
{
CHECKPOSITIVESIZE
CHECKSIZE(other)
assertHasElements();
assertSameSize(other);
OPM_CUDA_SAFE_CALL(cudaMemcpy(m_dataOnDevice,
other.m_dataOnDevice,
detail::to_size_t(m_numberOfElements) * sizeof(T),
@@ -140,6 +128,25 @@ CuVector<T>::setZeroAtIndexSet(const CuVector<int>& indexSet)
detail::setZeroAtIndexSet(m_dataOnDevice, indexSet.dim(), indexSet.data());
}
template <typename T>
void
CuVector<T>::assertSameSize(const CuVector<T>& x) const
{
if (x.m_numberOfElements != m_numberOfElements) {
OPM_THROW(std::invalid_argument,
fmt::format("Given vector has {}, while we have {}.", x.m_numberOfElements, m_numberOfElements));
}
}
template <typename T>
void
CuVector<T>::assertHasElements() const
{
if (m_numberOfElements <= 0) {
OPM_THROW(std::invalid_argument, "We have 0 elements");
}
}
template <typename T>
T*
CuVector<T>::data()
@@ -151,7 +158,7 @@ template <class T>
CuVector<T>&
CuVector<T>::operator*=(const T& scalar)
{
CHECKPOSITIVESIZE
assertHasElements();
OPM_CUBLAS_SAFE_CALL(detail::cublasScal(m_cuBlasHandle.get(), m_numberOfElements, &scalar, data(), 1));
return *this;
}
@@ -160,8 +167,8 @@ template <class T>
CuVector<T>&
CuVector<T>::axpy(T alpha, const CuVector<T>& y)
{
CHECKPOSITIVESIZE
CHECKSIZE(y)
assertHasElements();
assertSameSize(y);
OPM_CUBLAS_SAFE_CALL(detail::cublasAxpy(m_cuBlasHandle.get(), m_numberOfElements, &alpha, y.data(), 1, data(), 1));
return *this;
}
@@ -170,8 +177,8 @@ template <class T>
T
CuVector<T>::dot(const CuVector<T>& other) const
{
CHECKPOSITIVESIZE
CHECKSIZE(other)
assertHasElements();
assertSameSize(other);
T result = T(0);
OPM_CUBLAS_SAFE_CALL(
detail::cublasDot(m_cuBlasHandle.get(), m_numberOfElements, data(), 1, other.data(), 1, &result));
@@ -181,7 +188,7 @@ template <class T>
T
CuVector<T>::two_norm() const
{
CHECKPOSITIVESIZE
assertHasElements();
T result = T(0);
OPM_CUBLAS_SAFE_CALL(detail::cublasNrm2(m_cuBlasHandle.get(), m_numberOfElements, data(), 1, &result));
return result;
@@ -222,8 +229,8 @@ template <class T>
CuVector<T>&
CuVector<T>::operator+=(const CuVector<T>& other)
{
CHECKPOSITIVESIZE
CHECKSIZE(other)
assertHasElements();
assertSameSize(other);
// TODO: [perf] Make a specialized version of this
return axpy(1.0, other);
}
@@ -232,8 +239,8 @@ template <class T>
CuVector<T>&
CuVector<T>::operator-=(const CuVector<T>& other)
{
CHECKPOSITIVESIZE
CHECKSIZE(other)
assertHasElements();
assertSameSize(other);
// TODO: [perf] Make a specialized version of this
return axpy(-1.0, other);
}

View File

@@ -341,6 +341,9 @@ private:
// This gives the added benifit that a size_t to int conversion error occurs during construction.
const int m_numberOfElements;
detail::CuBlasHandle& m_cuBlasHandle;
void assertSameSize(const CuVector<T>& other) const;
void assertHasElements() const;
};
} // namespace Opm::cuistl
#endif