From 046ef6cdc08d876f8b86ed961c03fdf19b7b4870 Mon Sep 17 00:00:00 2001 From: Kjetil Olsen Lye Date: Thu, 30 Mar 2023 11:02:32 +0200 Subject: [PATCH] Replaced some macro size checks with function calls. --- opm/simulators/linalg/cuistl/CuVector.cpp | 63 +++++++++++++---------- opm/simulators/linalg/cuistl/CuVector.hpp | 3 ++ 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/opm/simulators/linalg/cuistl/CuVector.cpp b/opm/simulators/linalg/cuistl/CuVector.cpp index f5d48cbf5..ca246ebd5 100644 --- a/opm/simulators/linalg/cuistl/CuVector.cpp +++ b/opm/simulators/linalg/cuistl/CuVector.cpp @@ -26,16 +26,6 @@ #include #include -#define CHECKSIZE(x) \ - if (x.m_numberOfElements != m_numberOfElements) { \ - OPM_THROW(std::invalid_argument, \ - fmt::format("Given vector has {}, while we have {}.", x.m_numberOfElements, m_numberOfElements)); \ - } -#define CHECKPOSITIVESIZE \ - if (m_numberOfElements <= 0) { \ - OPM_THROW(std::invalid_argument, "We have 0 elements"); \ - } - namespace Opm::cuistl { @@ -66,7 +56,7 @@ template CuVector& CuVector::operator=(T scalar) { - CHECKPOSITIVESIZE + assertHasElements(); detail::setVectorValue(data(), detail::to_size_t(m_numberOfElements), scalar); return *this; } @@ -75,11 +65,9 @@ template CuVector& CuVector::operator=(const CuVector& other) { - CHECKPOSITIVESIZE - CHECKSIZE(other) - if (other.m_numberOfElements != m_numberOfElements) { - OPM_THROW(std::invalid_argument, "Can only copy from vector of same size."); - } + assertHasElements(); + assertSameSize(other); + OPM_CUDA_SAFE_CALL(cudaMemcpy(m_dataOnDevice, other.m_dataOnDevice, detail::to_size_t(m_numberOfElements) * sizeof(T), @@ -91,8 +79,8 @@ template CuVector::CuVector(const CuVector& other) : CuVector(other.m_numberOfElements) { - CHECKPOSITIVESIZE - CHECKSIZE(other) + assertHasElements(); + assertSameSize(other); OPM_CUDA_SAFE_CALL(cudaMemcpy(m_dataOnDevice, other.m_dataOnDevice, detail::to_size_t(m_numberOfElements) * sizeof(T), @@ -140,6 +128,25 @@ CuVector::setZeroAtIndexSet(const CuVector& indexSet) detail::setZeroAtIndexSet(m_dataOnDevice, indexSet.dim(), indexSet.data()); } +template +void +CuVector::assertSameSize(const CuVector& x) const +{ + if (x.m_numberOfElements != m_numberOfElements) { + OPM_THROW(std::invalid_argument, + fmt::format("Given vector has {}, while we have {}.", x.m_numberOfElements, m_numberOfElements)); + } +} + +template +void +CuVector::assertHasElements() const +{ + if (m_numberOfElements <= 0) { + OPM_THROW(std::invalid_argument, "We have 0 elements"); + } +} + template T* CuVector::data() @@ -151,7 +158,7 @@ template CuVector& CuVector::operator*=(const T& scalar) { - CHECKPOSITIVESIZE + assertHasElements(); OPM_CUBLAS_SAFE_CALL(detail::cublasScal(m_cuBlasHandle.get(), m_numberOfElements, &scalar, data(), 1)); return *this; } @@ -160,8 +167,8 @@ template CuVector& CuVector::axpy(T alpha, const CuVector& y) { - CHECKPOSITIVESIZE - CHECKSIZE(y) + assertHasElements(); + assertSameSize(y); OPM_CUBLAS_SAFE_CALL(detail::cublasAxpy(m_cuBlasHandle.get(), m_numberOfElements, &alpha, y.data(), 1, data(), 1)); return *this; } @@ -170,8 +177,8 @@ template T CuVector::dot(const CuVector& other) const { - CHECKPOSITIVESIZE - CHECKSIZE(other) + assertHasElements(); + assertSameSize(other); T result = T(0); OPM_CUBLAS_SAFE_CALL( detail::cublasDot(m_cuBlasHandle.get(), m_numberOfElements, data(), 1, other.data(), 1, &result)); @@ -181,7 +188,7 @@ template T CuVector::two_norm() const { - CHECKPOSITIVESIZE + assertHasElements(); T result = T(0); OPM_CUBLAS_SAFE_CALL(detail::cublasNrm2(m_cuBlasHandle.get(), m_numberOfElements, data(), 1, &result)); return result; @@ -222,8 +229,8 @@ template CuVector& CuVector::operator+=(const CuVector& other) { - CHECKPOSITIVESIZE - CHECKSIZE(other) + assertHasElements(); + assertSameSize(other); // TODO: [perf] Make a specialized version of this return axpy(1.0, other); } @@ -232,8 +239,8 @@ template CuVector& CuVector::operator-=(const CuVector& other) { - CHECKPOSITIVESIZE - CHECKSIZE(other) + assertHasElements(); + assertSameSize(other); // TODO: [perf] Make a specialized version of this return axpy(-1.0, other); } diff --git a/opm/simulators/linalg/cuistl/CuVector.hpp b/opm/simulators/linalg/cuistl/CuVector.hpp index 3aa1f6d07..de09f42e3 100644 --- a/opm/simulators/linalg/cuistl/CuVector.hpp +++ b/opm/simulators/linalg/cuistl/CuVector.hpp @@ -341,6 +341,9 @@ private: // This gives the added benifit that a size_t to int conversion error occurs during construction. const int m_numberOfElements; detail::CuBlasHandle& m_cuBlasHandle; + + void assertSameSize(const CuVector& other) const; + void assertHasElements() const; }; } // namespace Opm::cuistl #endif