use macro to make device code valid in debug mode

This commit is contained in:
Tobias Meyer Andersen 2024-08-07 13:31:48 +02:00
parent 51e8bb7191
commit fe09d147b0

View File

@ -27,6 +27,14 @@
#include <vector>
#include <string>
// TODO: remove this line and instead include gpuDecorators.hpp from OPM common when it gets added
#if defined(__CUDA_ARCH__) || (defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ > 0)
#define OPM_IS_INSIDE_DEVICE_FUNCTION_TEMPORARY 1
#else
#define OPM_IS_INSIDE_DEVICE_FUNCTION_TEMPORARY 0
#endif
namespace Opm::cuistl
{
@ -362,27 +370,42 @@ private:
/// @param size The value to compare with the size of this view
__host__ __device__ void assertSameSize(size_t size) const
{
#if OPM_IS_INSIDE_DEVICE_FUNCTION_TEMPORARY
// TODO: find a better way to handle exceptions in kernels, this will possibly be printed many times
assert(size == m_numberOfElements && "Views did not have the same size");
#else
if (size != m_numberOfElements) {
OPM_THROW(std::invalid_argument,
fmt::format("Given view has {}, while we have {}.", size, m_numberOfElements));
fmt::format("Given view has {}, while this View has {}.", size, m_numberOfElements));
}
#endif
}
/// @brief Helper function to assert that the view has at least one element
__host__ __device__ void assertHasElements() const
{
#if OPM_IS_INSIDE_DEVICE_FUNCTION_TEMPORARY
// TODO: find a better way to handle exceptions in kernels, this will possibly be printed many times
assert(m_numberOfElements > 0 && "View have 0 elements");
#else
if (m_numberOfElements <= 0) {
OPM_THROW(std::invalid_argument, "We have 0 elements");
OPM_THROW(std::invalid_argument, "View have 0 elements");
}
#endif
}
/// @brief Helper function to determine if an index is within the range of valid indexes in the view
__host__ __device__ void assertInRange(size_t idx) const
{
#if OPM_IS_INSIDE_DEVICE_FUNCTION_TEMPORARY
// TODO: find a better way to handle exceptions in kernels, this will possibly be printed many times
assert(idx < m_numberOfElements && "The index provided was not in the range [0, buffersize-1]");
#else
if (idx >= m_numberOfElements) {
OPM_THROW(std::invalid_argument,
fmt::format("The index provided was not in the range [0, buffersize-1]"));
}
#endif
}
};