mirror of
https://github.com/OPM/opm-simulators.git
synced 2025-01-14 04:31:56 -06:00
rocsparseSolverBackend: add support for float Scalars
This commit is contained in:
parent
452a0a0baa
commit
644aeb582f
@ -51,6 +51,7 @@
|
||||
#endif
|
||||
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
|
||||
namespace Opm::Accelerator {
|
||||
|
||||
@ -151,26 +152,55 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
|
||||
|
||||
// HIP_VERSION is defined as (HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 + HIP_VERSION_PATCH)
|
||||
#if HIP_VERSION >= 60000000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_x, &zero, d_r));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_x, &zero, d_r));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_x, &zero, d_r));
|
||||
}
|
||||
#elif HIP_VERSION >= 50400000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_x, &zero, d_r));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_x, &zero, d_r));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_x, &zero, d_r));
|
||||
}
|
||||
#else
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_x, &zero, d_r));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_x, &zero, d_r));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_x, &zero, d_r));
|
||||
}
|
||||
#endif
|
||||
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &mone, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_b, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_rw, 1));
|
||||
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm_0));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCBLAS_CHECK(rocblas_sscal(blas_handle, N, &mone, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &one, d_b, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_scopy(blas_handle, N, d_r, 1, d_rw, 1));
|
||||
ROCBLAS_CHECK(rocblas_scopy(blas_handle, N, d_r, 1, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_snrm2(blas_handle, N, d_r, 1, &norm_0));
|
||||
} else {
|
||||
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &mone, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_b, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_rw, 1));
|
||||
ROCBLAS_CHECK(rocblas_dcopy(blas_handle, N, d_r, 1, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm_0));
|
||||
}
|
||||
|
||||
if (verbosity >= 2) {
|
||||
std::ostringstream out;
|
||||
@ -183,14 +213,24 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
|
||||
}
|
||||
for (it = 0.5; it < maxit; it += 0.5) {
|
||||
rhop = rho;
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_r, 1, &rho));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_rw, 1, d_r, 1, &rho));
|
||||
} else {
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_r, 1, &rho));
|
||||
}
|
||||
|
||||
if (it > 1) {
|
||||
beta = (rho / rhop) * (alpha / omega);
|
||||
nomega = -omega;
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_v, 1, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &beta, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_r, 1, d_p, 1));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &nomega, d_v, 1, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_sscal(blas_handle, N, &beta, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &one, d_r, 1, d_p, 1));
|
||||
} else {
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_v, 1, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &beta, d_p, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_r, 1, d_p, 1));
|
||||
}
|
||||
}
|
||||
if (verbosity >= 3) {
|
||||
HIP_CHECK(hipStreamSynchronize(stream));
|
||||
@ -209,20 +249,41 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
|
||||
|
||||
// spmv
|
||||
#if HIP_VERSION >= 60000000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_pw, &zero, d_v));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_pw, &zero, d_v));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_pw, &zero, d_v));
|
||||
}
|
||||
#elif HIP_VERSION >= 50400000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_pw, &zero, d_v));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_pw, &zero, d_v));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_pw, &zero, d_v));
|
||||
}
|
||||
#else
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_pw, &zero, d_v));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_pw, &zero, d_v));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_pw, &zero, d_v));
|
||||
}
|
||||
#endif
|
||||
if (verbosity >= 3) {
|
||||
HIP_CHECK(hipStreamSynchronize(stream));
|
||||
@ -240,12 +301,22 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
|
||||
t_rest.start();
|
||||
}
|
||||
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_v, 1, &tmp1));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_rw, 1, d_v, 1, &tmp1));
|
||||
} else {
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_rw, 1, d_v, 1, &tmp1));
|
||||
}
|
||||
alpha = rho / tmp1;
|
||||
nalpha = -alpha;
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nalpha, d_v, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &alpha, d_pw, 1, d_x, 1));
|
||||
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &nalpha, d_v, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &alpha, d_pw, 1, d_x, 1));
|
||||
ROCBLAS_CHECK(rocblas_snrm2(blas_handle, N, d_r, 1, &norm));
|
||||
} else {
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nalpha, d_v, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &alpha, d_pw, 1, d_x, 1));
|
||||
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
|
||||
}
|
||||
if (verbosity >= 3) {
|
||||
HIP_CHECK(hipStreamSynchronize(stream));
|
||||
t_rest.stop();
|
||||
@ -272,20 +343,41 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
|
||||
|
||||
// spmv
|
||||
#if HIP_VERSION >= 60000000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_s, &zero, d_t));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_s, &zero, d_t));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_s, &zero, d_t));
|
||||
}
|
||||
#elif HIP_VERSION >= 50400000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_s, &zero, d_t));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_s, &zero, d_t));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
spmv_info, d_s, &zero, d_t));
|
||||
}
|
||||
#else
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_s, &zero, d_t));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_s, &zero, d_t));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation,
|
||||
Nb, Nb, nnzb, &one, descr_A,
|
||||
d_Avals, d_Arows, d_Acols, block_size,
|
||||
d_s, &zero, d_t));
|
||||
}
|
||||
#endif
|
||||
if (verbosity >= 3) {
|
||||
HIP_CHECK(hipStreamSynchronize(stream));
|
||||
@ -303,14 +395,25 @@ gpu_pbicgstab([[maybe_unused]] WellContributions<Scalar>& wellContribs,
|
||||
t_rest.start();
|
||||
}
|
||||
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_r, 1, &tmp1));
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_t, 1, &tmp2));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_t, 1, d_r, 1, &tmp1));
|
||||
ROCBLAS_CHECK(rocblas_sdot(blas_handle, N, d_t, 1, d_t, 1, &tmp2));
|
||||
} else {
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_r, 1, &tmp1));
|
||||
ROCBLAS_CHECK(rocblas_ddot(blas_handle, N, d_t, 1, d_t, 1, &tmp2));
|
||||
|
||||
}
|
||||
omega = tmp1 / tmp2;
|
||||
nomega = -omega;
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &omega, d_s, 1, d_x, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_t, 1, d_r, 1));
|
||||
|
||||
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &omega, d_s, 1, d_x, 1));
|
||||
ROCBLAS_CHECK(rocblas_saxpy(blas_handle, N, &nomega, d_t, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_snrm2(blas_handle, N, d_r, 1, &norm));
|
||||
} else {
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &omega, d_s, 1, d_x, 1));
|
||||
ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &nomega, d_t, 1, d_r, 1));
|
||||
ROCBLAS_CHECK(rocblas_dnrm2(blas_handle, N, d_r, 1, &norm));
|
||||
}
|
||||
if (verbosity >= 3) {
|
||||
HIP_CHECK(hipStreamSynchronize(stream));
|
||||
t_rest.stop();
|
||||
@ -480,15 +583,31 @@ analyze_matrix()
|
||||
ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr_A));
|
||||
|
||||
#if HIP_VERSION >= 60000000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_analysis(handle, dir, operation,
|
||||
Nb, Nb, nnzb,
|
||||
descr_A, d_Avals, d_Arows, d_Acols,
|
||||
block_size, spmv_info));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv_analysis(handle, dir, operation,
|
||||
Nb, Nb, nnzb,
|
||||
descr_A, d_Avals, d_Arows, d_Acols,
|
||||
block_size, spmv_info));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_analysis(handle, dir, operation,
|
||||
Nb, Nb, nnzb,
|
||||
descr_A, d_Avals, d_Arows, d_Acols,
|
||||
block_size, spmv_info));
|
||||
}
|
||||
#elif HIP_VERSION >= 50400000
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex_analysis(handle, dir, operation,
|
||||
Nb, Nb, nnzb,
|
||||
descr_A, d_Avals, d_Arows, d_Acols,
|
||||
block_size, spmv_info));
|
||||
if constexpr (std::is_same_v<Scalar,float>) {
|
||||
ROCSPARSE_CHECK(rocsparse_dbsrmv_ex_analysis(handle, dir, operation,
|
||||
Nb, Nb, nnzb,
|
||||
descr_A, d_Avals,
|
||||
d_Arows, d_Acols,
|
||||
block_size, spmv_info));
|
||||
} else {
|
||||
ROCSPARSE_CHECK(rocsparse_sbsrmv_ex_analysis(handle, dir, operation,
|
||||
Nb, Nb, nnzb,
|
||||
descr_A, d_Avals,
|
||||
d_Arows, d_Acols,
|
||||
block_size, spmv_info));
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!prec->analyze_matrix(&*mat)) {
|
||||
|
Loading…
Reference in New Issue
Block a user