From cab0efeec5fd677013124b95a6c48f1e1b146703 Mon Sep 17 00:00:00 2001 From: Kjetil Olsen Lye Date: Tue, 9 May 2023 14:02:49 +0200 Subject: [PATCH] Added cublasWarnIfError/CUBLAS_WARN_IF_ERROR. --- .../linalg/cuistl/detail/CuBlasHandle.cpp | 2 +- .../linalg/cuistl/detail/cublas_safe_call.hpp | 155 ++++++++++++++---- 2 files changed, 127 insertions(+), 30 deletions(-) diff --git a/opm/simulators/linalg/cuistl/detail/CuBlasHandle.cpp b/opm/simulators/linalg/cuistl/detail/CuBlasHandle.cpp index 3771e6139..5fa8a420f 100644 --- a/opm/simulators/linalg/cuistl/detail/CuBlasHandle.cpp +++ b/opm/simulators/linalg/cuistl/detail/CuBlasHandle.cpp @@ -30,7 +30,7 @@ CuBlasHandle::CuBlasHandle() CuBlasHandle::~CuBlasHandle() { - OPM_CUBLAS_SAFE_CALL(cublasDestroy(m_handle)); + OPM_CUBLAS_WARN_IF_ERROR(cublasDestroy(m_handle)); } cublasHandle_t diff --git a/opm/simulators/linalg/cuistl/detail/cublas_safe_call.hpp b/opm/simulators/linalg/cuistl/detail/cublas_safe_call.hpp index 8b72a6be6..7866c7337 100644 --- a/opm/simulators/linalg/cuistl/detail/cublas_safe_call.hpp +++ b/opm/simulators/linalg/cuistl/detail/cublas_safe_call.hpp @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -35,33 +36,76 @@ namespace Opm::cuistl::detail return #x; \ } -/** - * @brief getCublasErrorMessage Converts an error code returned from a cublas function a human readable string. - * @param code an error code from a cublas routine - * @return a human readable string. - */ -inline std::string -getCublasErrorMessage(int code) +namespace { - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_SUCCESS); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_NOT_INITIALIZED); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_ALLOC_FAILED); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_INVALID_VALUE); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_ARCH_MISMATCH); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_MAPPING_ERROR); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_EXECUTION_FAILED); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_INTERNAL_ERROR); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_NOT_SUPPORTED); - CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_LICENSE_ERROR); + /** + * @brief getCublasErrorMessage Converts an error code returned from a cublas function a human readable string. + * @param code an error code from a cublas routine + * @return a human readable string. + */ + inline std::string getCublasErrorCodeToString(int code) + { + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_SUCCESS); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_NOT_INITIALIZED); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_ALLOC_FAILED); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_INVALID_VALUE); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_ARCH_MISMATCH); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_MAPPING_ERROR); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_EXECUTION_FAILED); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_INTERNAL_ERROR); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_NOT_SUPPORTED); + CHECK_CUBLAS_ERROR_TYPE(code, CUBLAS_STATUS_LICENSE_ERROR); - return fmt::format("UNKNOWN CUBLAS ERROR {}.", code); -} + return fmt::format("UNKNOWN CUBLAS ERROR {}.", code); + } #undef CHECK_CUBLAS_ERROR_TYPE + +} // namespace + +/** + * @brief getCublasErrorMessage generates the error message to display for a given error. + * + * @param error the error code from cublas + * @param expression the expresison (say "cublasCreate(&handle)") + * @param filename the code file the error occured in (typically __FILE__) + * @param functionName name of the function the error occured in (typically __func__) + * @param lineNumber the line number the error occured in (typically __LINE__) + * + * @todo Refactor to use std::source_location once we shift to C++20 + * + * @return An error message to be displayed. + * + * @note This function is mostly for internal use. + */ +inline std::string +getCublasErrorMessage(cublasStatus_t error, + const std::string_view& expression, + const std::string_view& filename, + const std::string_view& functionName, + size_t lineNumber) +{ + return fmt::format("cuBLAS expression did not execute correctly. Expression was: \n\n" + " {}\n\n" + "in function {}, in {}, at line {}.\n" + "CuBLAS error code was: {}\n", + expression, + functionName, + filename, + lineNumber, + getCublasErrorCodeToString(error)); +} + /** * @brief cublasSafeCall checks the return type of the CUBLAS expression (function call) and throws an exception if it * does not equal CUBLAS_STATUS_SUCCESS. * + * @param error the error code from cublas + * @param expression the expresison (say "cublasCreate(&handle)") + * @param filename the code file the error occured in (typically __FILE__) + * @param functionName name of the function the error occured in (typically __func__) + * @param lineNumber the line number the error occured in (typically __LINE__) + * * Example usage: * @code{.cpp} * #include @@ -85,18 +129,51 @@ cublasSafeCall(cublasStatus_t error, size_t lineNumber) { if (error != CUBLAS_STATUS_SUCCESS) { - OPM_THROW(std::runtime_error, - fmt::format("cuBLAS expression did not execute correctly. Expression was: \n\n" - " {}\n\n" - "in function {}, in {}, at line {}.\n" - "CuBLAS error code was: {}\n", - expression, - functionName, - filename, - lineNumber, - getCublasErrorMessage(error))); + OPM_THROW(std::runtime_error, getCublasErrorMessage(error, expression, filename, functionName, lineNumber)); } } + +/** + * @brief cublasWarnIfError checks the return type of the CUBLAS expression (function call) and issues a warning if it + * does not equal CUBLAS_STATUS_SUCCESS. + * + * @param error the error code from cublas + * @param expression the expresison (say "cublasCreate(&handle)") + * @param filename the code file the error occured in (typically __FILE__) + * @param functionName name of the function the error occured in (typically __func__) + * @param lineNumber the line number the error occured in (typically __LINE__) + * + * @return the error sent in (for convenience). + * + * Example usage: + * @code{.cpp} + * #include + * #include + * + * void some_function() { + * cublasHandle_t cublasHandle; + * cublasWarnIfError(cublasCreate(&cublasHandle), "cublasCreate(&cublasHandle)", __FILE__, __func__, __LINE__); + * } + * @endcode + * + * @note It is probably easier to use the macro OPM_CUBLAS_WARN_IF_ERROR + * @note Prefer the cublasSafeCall/OPM_CUBLAS_SAFE_CALL counterpart unless you really don't want to throw an exception. + * + * @todo Refactor to use std::source_location once we shift to C++20 + */ +inline cublasStatus_t +cublasWarnIfError(cublasStatus_t error, + const std::string_view& expression, + const std::string_view& filename, + const std::string_view& functionName, + size_t lineNumber) +{ + if (error != CUBLAS_STATUS_SUCCESS) { + OpmLog::warning(getCublasErrorMessage(error, expression, filename, functionName, lineNumber)); + } + + return error; +} } // namespace Opm::cuistl::detail /** @@ -119,4 +196,24 @@ cublasSafeCall(cublasStatus_t error, #define OPM_CUBLAS_SAFE_CALL(expression) \ ::Opm::cuistl::detail::cublasSafeCall(expression, #expression, __FILE__, __func__, __LINE__) +/** + * @brief OPM_CUBLAS_WARN_IF_ERROR checks the return type of the cublas expression (function call) and issues a warning + * if it does not equal CUBLAS_STATUS_SUCCESS. + * + * Example usage: + * @code{.cpp} + * #include + * #include + * + * void some_function() { + * cublasHandle_t cublasHandle; + * OPM_CUBLAS_WARN_IF_ERROR(cublasCreate(&cublasHandle)); + * } + * @endcode + * + * @note Prefer the cublasSafeCall/OPM_CUBLAS_SAFE_CALL counterpart unless you really don't want to throw an exception. + */ +#define OPM_CUBLAS_WARN_IF_ERROR(expression) \ + ::Opm::cuistl::detail::cublasWarnIfError(expression, #expression, __FILE__, __func__, __LINE__) + #endif // OPM_CUBLAS_SAFE_CALL_HPP