diff --git a/opm/simulators/linalg/bda/rocsparseSolverBackend.cpp b/opm/simulators/linalg/bda/rocsparseSolverBackend.cpp index 45284061a..13273d493 100644 --- a/opm/simulators/linalg/bda/rocsparseSolverBackend.cpp +++ b/opm/simulators/linalg/bda/rocsparseSolverBackend.cpp @@ -36,6 +36,7 @@ #include #include +#include #define HIP_CHECK(stat) \ { \ @@ -117,10 +118,18 @@ void rocsparseSolverBackend::gpu_pbicgstab([[maybe_unused]] WellCont Timer t_total, t_prec(false), t_spmv(false), t_rest(false); +// HIP_VERSION is defined as (HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 + HIP_VERSION_PATCH) +#if HIP_VERSION >= 50400000 + ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation, + Nb, Nb, nnzb, &one, descr_M, + d_Avals, d_Arows, d_Acols, block_size, + spmv_info, d_x, &zero, d_r)); +#else ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation, Nb, Nb, nnzb, &one, descr_M, d_Avals, d_Arows, d_Acols, block_size, d_x, &zero, d_r)); +#endif ROCBLAS_CHECK(rocblas_dscal(blas_handle, N, &mone, d_r, 1)); ROCBLAS_CHECK(rocblas_daxpy(blas_handle, N, &one, d_b, 1, d_r, 1)); @@ -168,10 +177,17 @@ void rocsparseSolverBackend::gpu_pbicgstab([[maybe_unused]] WellCont } // spmv +#if HIP_VERSION >= 50400000 + ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation, + Nb, Nb, nnzb, &one, descr_M, + d_Avals, d_Arows, d_Acols, block_size, + spmv_info, d_pw, &zero, d_v)); +#else ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation, Nb, Nb, nnzb, &one, descr_M, d_Avals, d_Arows, d_Acols, block_size, d_pw, &zero, d_v)); +#endif if (verbosity >= 3) { HIP_CHECK(hipStreamSynchronize(stream)); t_spmv.stop(); @@ -214,10 +230,17 @@ void rocsparseSolverBackend::gpu_pbicgstab([[maybe_unused]] WellCont } // spmv +#if HIP_VERSION >= 50400000 + ROCSPARSE_CHECK(rocsparse_dbsrmv_ex(handle, dir, operation, + Nb, Nb, nnzb, &one, descr_M, + d_Avals, d_Arows, d_Acols, block_size, + spmv_info, d_s, &zero, d_t)); +#else ROCSPARSE_CHECK(rocsparse_dbsrmv(handle, dir, operation, Nb, Nb, nnzb, &one, descr_M, d_Avals, d_Arows, d_Acols, block_size, d_s, &zero, d_t)); +#endif if(verbosity >= 3){ HIP_CHECK(hipStreamSynchronize(stream)); t_spmv.stop(); @@ -382,7 +405,11 @@ bool rocsparseSolverBackend::analyze_matrix() { ROCSPARSE_CHECK(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); ROCSPARSE_CHECK(rocsparse_create_mat_info(&ilu_info)); +#if HIP_VERSION >= 50400000 + ROCSPARSE_CHECK(rocsparse_create_mat_info(&spmv_info)); +#endif + ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr_A)); ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr_M)); ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr_L)); @@ -424,6 +451,13 @@ bool rocsparseSolverBackend::analyze_matrix() { Nb, nnzbs_prec, descr_U, d_Mvals, d_Mrows, d_Mcols, \ block_size, ilu_info, rocsparse_analysis_policy_reuse, rocsparse_solve_policy_auto, d_buffer)); +#if HIP_VERSION >= 50400000 + ROCSPARSE_CHECK(rocsparse_dbsrmv_ex_analysis(handle, dir, operation, + Nb, Nb, nnzb, + descr_A, d_Avals, d_Arows, d_Acols, + block_size, spmv_info)); +#endif + if (verbosity >= 3) { HIP_CHECK(hipStreamSynchronize(stream)); std::ostringstream out; diff --git a/opm/simulators/linalg/bda/rocsparseSolverBackend.hpp b/opm/simulators/linalg/bda/rocsparseSolverBackend.hpp index 824e90116..4fa470774 100644 --- a/opm/simulators/linalg/bda/rocsparseSolverBackend.hpp +++ b/opm/simulators/linalg/bda/rocsparseSolverBackend.hpp @@ -30,6 +30,8 @@ #include #include +#include + namespace Opm { namespace Accelerator @@ -65,8 +67,11 @@ private: rocsparse_operation operation = rocsparse_operation_none; rocsparse_handle handle; rocblas_handle blas_handle; - rocsparse_mat_descr descr_M, descr_L, descr_U; + rocsparse_mat_descr descr_A, descr_M, descr_L, descr_U; rocsparse_mat_info ilu_info; +#if HIP_VERSION >= 50400000 + rocsparse_mat_info spmv_info; +#endif hipStream_t stream; rocsparse_int *d_Arows, *d_Mrows;