rocsparseSolverBackend: add support for float Scalars

This commit is contained in:
Arne Morten Kvarving 2024-05-31 08:47:38 +02:00
parent 452a0a0baa
commit 644aeb582f

View File

@ -51,6 +51,7 @@
#endif
#include <cstddef>
#include <type_traits>
namespace Opm::Accelerator {
@ -151,26 +152,55 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
// HIP_VERSION is defined as (HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 + HIP_VERSION_PATCH)
#if HIP_VERSION >= 60000000
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_x, &zero, d_r));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_x, &zero, d_r));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_x, &zero, d_r));
}
#elif HIP_VERSION >= 50400000
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_x, &zero, d_r));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_x, &zero, d_r));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_x, &zero, d_r));
}
#else
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_x, &zero, d_r));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_x, &zero, d_r));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_x, &zero, d_r));
}
#endif
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &mone, d_r, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_b, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_rw, 1));
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_p, 1));
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm_0));
if constexpr (std::is_same_v<Scalar,float>) {
ROCBLAS_CHECK(rocblas_sscal(blas_handle, N, &mone, d_r, 1));
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &one, d_b, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_scopy(blas_handle, N, d_r, 1, d_rw, 1));
ROCBLAS_CHECK(rocblas_scopy(blas_handle, N, d_r, 1, d_p, 1));
ROCBLAS_CHECK(rocblas_snrm2(blas_handle, N, d_r, 1, &norm_0));
} else {
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &mone, d_r, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_b, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_rw, 1));
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_p, 1));
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm_0));
}
if (verbosity >= 2) {
std::ostringstream out;
@ -183,14 +213,24 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
}
for (it = 0.5; it < maxit; it += 0.5) {
rhop = rho;
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_r, 1, &rho));
if constexpr (std::is_same_v<Scalar,float>) {
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_rw, 1, d_r, 1, &rho));
} else {
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_r, 1, &rho));
}
if (it > 1) {
beta = (rho / rhop) * (alpha / omega);
nomega = -omega;
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_v, 1, d_p, 1));
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &beta, d_p, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_r, 1, d_p, 1));
if constexpr (std::is_same_v<Scalar,float>) {
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &nomega, d_v, 1, d_p, 1));
ROCBLAS_CHECK(rocblas_sscal(blas_handle, N, &beta, d_p, 1));
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &one, d_r, 1, d_p, 1));
} else {
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_v, 1, d_p, 1));
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &beta, d_p, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_r, 1, d_p, 1));
}
}
if (verbosity >= 3) {
HIP_CHECK(hipStreamSynchronize(stream));
@ -209,20 +249,41 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
// spmv
#if HIP_VERSION >= 60000000
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_pw, &zero, d_v));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_pw, &zero, d_v));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_pw, &zero, d_v));
}
#elif HIP_VERSION >= 50400000
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_pw, &zero, d_v));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_pw, &zero, d_v));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_pw, &zero, d_v));
}
#else
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_pw, &zero, d_v));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_pw, &zero, d_v));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_pw, &zero, d_v));
}
#endif
if (verbosity >= 3) {
HIP_CHECK(hipStreamSynchronize(stream));
@ -240,12 +301,22 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
t_rest.start();
}
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_v, 1, &tmp1));
if constexpr (std::is_same_v<Scalar,float>) {
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_rw, 1, d_v, 1, &tmp1));
} else {
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_v, 1, &tmp1));
}
alpha = rho / tmp1;
nalpha = -alpha;
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nalpha, d_v, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &alpha, d_pw, 1, d_x, 1));
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
if constexpr (std::is_same_v<Scalar,float>) {
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &nalpha, d_v, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &alpha, d_pw, 1, d_x, 1));
ROCBLAS_CHECK(rocblas_snrm2(blas_handle, N, d_r, 1, &norm));
} else {
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nalpha, d_v, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &alpha, d_pw, 1, d_x, 1));
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
}
if (verbosity >= 3) {
HIP_CHECK(hipStreamSynchronize(stream));
t_rest.stop();
@ -272,20 +343,41 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
// spmv
#if HIP_VERSION >= 60000000
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_s, &zero, d_t));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_s, &zero, d_t));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_s, &zero, d_t));
}
#elif HIP_VERSION >= 50400000
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_s, &zero, d_t));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_s, &zero, d_t));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
spmv_info, d_s, &zero, d_t));
}
#else
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_s, &zero, d_t));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_s, &zero, d_t));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
Nb, Nb, nnzb, &one, descr_A,
d_Avals, d_Arows, d_Acols, block_size,
d_s, &zero, d_t));
}
#endif
if (verbosity >= 3) {
HIP_CHECK(hipStreamSynchronize(stream));
@ -303,14 +395,25 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
t_rest.start();
}
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_r, 1, &tmp1));
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_t, 1, &tmp2));
if constexpr (std::is_same_v<Scalar,float>) {
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_t, 1, d_r, 1, &tmp1));
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_t, 1, d_t, 1, &tmp2));
} else {
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_r, 1, &tmp1));
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_t, 1, &tmp2));
}
omega = tmp1 / tmp2;
nomega = -omega;
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &omega, d_s, 1, d_x, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_t, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
if constexpr (std::is_same_v<Scalar,float>) {
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &omega, d_s, 1, d_x, 1));
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &nomega, d_t, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_snrm2(blas_handle, N, d_r, 1, &norm));
} else {
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &omega, d_s, 1, d_x, 1));
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_t, 1, d_r, 1));
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
}
if (verbosity >= 3) {
HIP_CHECK(hipStreamSynchronize(stream));
t_rest.stop();
@ -480,15 +583,31 @@ analyze_matrix()
ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr_A));
#if HIP_VERSION >= 60000000
ROCSPARSE_CHECK(rocsparse_dbsrmv_analysis(handle, dir, operation,
Nb, Nb, nnzb,
descr_A, d_Avals, d_Arows, d_Acols,
block_size, spmv_info));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrmv_analysis(handle, dir, operation,
Nb, Nb, nnzb,
descr_A, d_Avals, d_Arows, d_Acols,
block_size, spmv_info));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrmv_analysis(handle, dir, operation,
Nb, Nb, nnzb,
descr_A, d_Avals, d_Arows, d_Acols,
block_size, spmv_info));
}
#elif HIP_VERSION >= 50400000
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex_analysis(handle, dir, operation,
Nb, Nb, nnzb,
descr_A, d_Avals, d_Arows, d_Acols,
block_size, spmv_info));
// Select the rocSPARSE bsrmv_ex analysis routine whose precision matches Scalar:
// the single-precision 's' variant for float, the double-precision 'd' variant
// otherwise. The two calls were previously swapped, so d_Avals was handed to the
// routine of the opposite precision on HIP 5.4 - 5.x builds.
if constexpr (std::is_same_v<Scalar,float>) {
    ROCSPARSE_CHECK(rocsparse_sbsrmv_ex_analysis(handle, dir, operation,
                                                 Nb, Nb, nnzb,
                                                 descr_A, d_Avals,
                                                 d_Arows, d_Acols,
                                                 block_size, spmv_info));
} else {
    ROCSPARSE_CHECK(rocsparse_dbsrmv_ex_analysis(handle, dir, operation,
                                                 Nb, Nb, nnzb,
                                                 descr_A, d_Avals,
                                                 d_Arows, d_Acols,
                                                 block_size, spmv_info));
}
#endif
if(!prec->analyze_matrix(&*mat)) {