rocsparseCPR: add support for float Scalars

This commit is contained in:
Arne Morten Kvarving 2024-05-31 08:47:38 +02:00
parent 6bcdad6ceb
commit 3ff678b58a

View File

@ -35,6 +35,8 @@
#include <opm/simulators/linalg/bda/Misc.hpp>
#include <type_traits>
namespace Opm::Accelerator {
using Opm::OpmLog;
@ -235,8 +237,13 @@ amg_cycle_gpu(const int level,
HIP_CHECK(hipMemcpyAsync(h_y.data(), &y, sizeof(Scalar) * Ncur, hipMemcpyDeviceToHost, this->stream));
// solve coarsest level using umfpack
this->umfpack.apply(h_x.data(), h_y.data());
// The if constexpr is needed to make the code compile
// since the umfpack member is an 'int' with float Scalar.
// We will never get here with float Scalar as we throw earlier.
// Solve coarsest level using umfpack
if constexpr (std::is_same_v<Scalar,double>) {
this->umfpack.apply(h_x.data(), h_y.data());
}
HIP_CHECK(hipMemcpyAsync(&x, h_x.data(), sizeof(Scalar) * Ncur, hipMemcpyHostToDevice, this->stream));