rocsparseBILU0: add support for float Scalars

This commit is contained in:
Arne Morten Kvarving 2024-05-31 08:47:38 +02:00
parent 644aeb582f
commit 6bcdad6ceb

View File

@ -29,8 +29,9 @@
#include <opm/simulators/linalg/bda/Misc.hpp>
#include <sstream>
#include <thread>
#include <type_traits>
extern std::shared_ptr<std::thread> copyThread;
#if HAVE_OPENMP
@ -112,23 +113,63 @@ analyze_matrix(BlockedMatrix<Scalar>*,
ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr_U));
ROCSPARSE_CHECK(rocsparse_set_mat_fill_mode(descr_U, rocsparse_fill_mode_upper));
ROCSPARSE_CHECK(rocsparse_set_mat_diag_type(descr_U, rocsparse_diag_type_non_unit));
ROCSPARSE_CHECK(rocsparse_dbsrilu0_buffer_size(this->handle, this->dir, Nb, this->nnzbs_prec, descr_M, d_Mvals, d_Mrows, d_Mcols, block_size, ilu_info, &d_bufferSize_M));
ROCSPARSE_CHECK(rocsparse_dbsrsv_buffer_size(this->handle, this->dir, this->operation, Nb, this->nnzbs_prec,
descr_L, d_Mvals, d_Mrows, d_Mcols, block_size, ilu_info, &d_bufferSize_L));
ROCSPARSE_CHECK(rocsparse_dbsrsv_buffer_size(this->handle, this->dir, this->operation, Nb, this->nnzbs_prec,
descr_U, d_Mvals, d_Mrows, d_Mcols, block_size, ilu_info, &d_bufferSize_U));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrilu0_buffer_size(this->handle, this->dir, Nb,
this->nnzbs_prec, descr_M,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info, &d_bufferSize_M));
ROCSPARSE_CHECK(rocsparse_sbsrsv_buffer_size(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, descr_L,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info, &d_bufferSize_L));
ROCSPARSE_CHECK(rocsparse_sbsrsv_buffer_size(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, descr_U,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info, &d_bufferSize_U));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrilu0_buffer_size(this->handle, this->dir, Nb,
this->nnzbs_prec, descr_M,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info, &d_bufferSize_M));
ROCSPARSE_CHECK(rocsparse_dbsrsv_buffer_size(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, descr_L,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info, &d_bufferSize_L));
ROCSPARSE_CHECK(rocsparse_dbsrsv_buffer_size(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, descr_U,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info, &d_bufferSize_U));
}
d_bufferSize = std::max(d_bufferSize_M, std::max(d_bufferSize_L, d_bufferSize_U));
HIP_CHECK(hipMalloc((void**)&d_buffer, d_bufferSize));
// analysis of ilu LU decomposition
ROCSPARSE_CHECK(rocsparse_dbsrilu0_analysis(this->handle, this->dir, \
Nb, this->nnzbs_prec, descr_M, d_Mvals, d_Mrows, d_Mcols, \
block_size, ilu_info, rocsparse_analysis_policy_reuse, rocsparse_solve_policy_auto, d_buffer));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrilu0_analysis(this->handle, this->dir,
Nb, this->nnzbs_prec, descr_M,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_analysis_policy_reuse,
rocsparse_solve_policy_auto, d_buffer));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrilu0_analysis(this->handle, this->dir,
Nb, this->nnzbs_prec, descr_M,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_analysis_policy_reuse,
rocsparse_solve_policy_auto, d_buffer));
}
int zero_position = 0;
rocsparse_status status = rocsparse_bsrilu0_zero_pivot(this->handle, ilu_info, &zero_position);
@ -138,12 +179,33 @@ analyze_matrix(BlockedMatrix<Scalar>*,
}
// analysis of ilu apply
ROCSPARSE_CHECK(rocsparse_dbsrsv_analysis(this->handle, this->dir, this->operation, \
Nb, this->nnzbs_prec, descr_L, d_Mvals, d_Mrows, d_Mcols, \
block_size, ilu_info, rocsparse_analysis_policy_reuse, rocsparse_solve_policy_auto, d_buffer));
ROCSPARSE_CHECK(rocsparse_dbsrsv_analysis(this->handle, this->dir, this->operation, \
Nb, this->nnzbs_prec, descr_U, d_Mvals, d_Mrows, d_Mcols, \
block_size, ilu_info, rocsparse_analysis_policy_reuse, rocsparse_solve_policy_auto, d_buffer));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrsv_analysis(this->handle, this->dir, this->operation,
Nb, this->nnzbs_prec, descr_L,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_analysis_policy_reuse,
rocsparse_solve_policy_auto, d_buffer));
ROCSPARSE_CHECK(rocsparse_sbsrsv_analysis(this->handle, this->dir, this->operation,
Nb, this->nnzbs_prec, descr_U, d_Mvals,
d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_analysis_policy_reuse,
rocsparse_solve_policy_auto, d_buffer));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrsv_analysis(this->handle, this->dir, this->operation,
Nb, this->nnzbs_prec, descr_L,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_analysis_policy_reuse,
rocsparse_solve_policy_auto, d_buffer));
ROCSPARSE_CHECK(rocsparse_dbsrsv_analysis(this->handle, this->dir, this->operation,
Nb, this->nnzbs_prec, descr_U, d_Mvals,
d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_analysis_policy_reuse,
rocsparse_solve_policy_auto, d_buffer));
}
if (verbosity >= 3) {
HIP_CHECK(hipStreamSynchronize(this->stream));
@ -168,13 +230,25 @@ create_preconditioner(BlockedMatrix<Scalar>*,
{
Timer t;
bool result = true;
ROCSPARSE_CHECK(rocsparse_dbsrilu0(this->handle, this->dir, Nb, this->nnzbs_prec, descr_M, d_Mvals, d_Mrows, d_Mcols, block_size, ilu_info, rocsparse_solve_policy_auto, d_buffer));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrilu0(this->handle, this->dir, Nb,
this->nnzbs_prec, descr_M,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_solve_policy_auto, d_buffer));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrilu0(this->handle, this->dir, Nb,
this->nnzbs_prec, descr_M,
d_Mvals, d_Mrows, d_Mcols,
block_size, ilu_info,
rocsparse_solve_policy_auto, d_buffer));
}
// Check for zero pivot
int zero_position = 0;
rocsparse_status status = rocsparse_bsrilu0_zero_pivot(this->handle, ilu_info, &zero_position);
if(rocsparse_status_success != status)
if (rocsparse_status_success != status)
{
printf("L has structural and/or numerical zero at L(%d,%d)\n", zero_position, zero_position);
return false;
@ -257,13 +331,39 @@ apply(Scalar& y, Scalar& x) {
Timer t_apply;
ROCSPARSE_CHECK(rocsparse_dbsrsv_solve(this->handle, this->dir, \
this->operation, Nb, this->nnzbs_prec, &one, \
descr_L, d_Mvals, d_Mrows, d_Mcols, block_size, ilu_info, &y, d_t, rocsparse_solve_policy_auto, d_buffer));
if constexpr (std::is_same_v<Scalar,float>) {
ROCSPARSE_CHECK(rocsparse_sbsrsv_solve(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, &one,
descr_L, d_Mvals, d_Mrows,
d_Mcols, block_size, ilu_info,
&y, d_t, rocsparse_solve_policy_auto,
d_buffer));
ROCSPARSE_CHECK(rocsparse_dbsrsv_solve(this->handle, this->dir, \
this->operation, Nb, this->nnzbs_prec, &one, \
descr_U, d_Mvals, d_Mrows, d_Mcols, block_size, ilu_info, d_t, &x, rocsparse_solve_policy_auto, d_buffer));
ROCSPARSE_CHECK(rocsparse_sbsrsv_solve(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, &one,
descr_U, d_Mvals, d_Mrows,
d_Mcols, block_size, ilu_info,
d_t, &x, rocsparse_solve_policy_auto,
d_buffer));
} else {
ROCSPARSE_CHECK(rocsparse_dbsrsv_solve(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, &one,
descr_L, d_Mvals, d_Mrows,
d_Mcols, block_size, ilu_info,
&y, d_t, rocsparse_solve_policy_auto,
d_buffer));
ROCSPARSE_CHECK(rocsparse_dbsrsv_solve(this->handle, this->dir,
this->operation, Nb,
this->nnzbs_prec, &one,
descr_U, d_Mvals, d_Mrows,
d_Mcols, block_size, ilu_info,
d_t, &x, rocsparse_solve_policy_auto,
d_buffer));
}
if (verbosity >= 3) {
std::ostringstream out;