From ac5b6b53c5cc4bc258b83f55961ee4ea81ff7eb2 Mon Sep 17 00:00:00 2001 From: jakobtorben Date: Tue, 17 Dec 2024 15:39:37 +0100 Subject: [PATCH] Select AMGX mode based on template scalar type --- opm/simulators/linalg/AmgxPreconditioner.hpp | 38 +++++++++++++------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/opm/simulators/linalg/AmgxPreconditioner.hpp b/opm/simulators/linalg/AmgxPreconditioner.hpp index f582e265f..e65b2f6c6 100644 --- a/opm/simulators/linalg/AmgxPreconditioner.hpp +++ b/opm/simulators/linalg/AmgxPreconditioner.hpp @@ -103,12 +103,14 @@ class AmgxPreconditioner : public Dune::PreconditionerWithUpdate public: //! \brief The matrix type the preconditioner is for using matrix_type = M; + //! \brief The field type of the matrix + using matrix_field_type = typename M::field_type; //! \brief The domain type of the preconditioner using domain_type = X; //! \brief The range type of the preconditioner using range_type = Y; - //! \brief The field type of the preconditioner - using field_type = typename X::field_type; + //! \brief The field type of the vectors + using vector_field_type = typename X::field_type; static constexpr int block_size = 1; @@ -135,11 +137,23 @@ public: // Setup frequency is set in the property tree setup_frequency_ = prm.get("setup_frequency", 30); - // Create solver and matrix/vector handles - AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, AMGX_mode_dDDI, cfg_)); - AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, AMGX_mode_dDDI)); - AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, AMGX_mode_dDDI)); - AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, AMGX_mode_dDDI)); + // Select appropriate AMGX mode based on matrix and vector scalar types + AMGX_Mode amgx_mode; + if constexpr (std::is_same_v && std::is_same_v) { + amgx_mode = AMGX_mode_dDDI; + } else if constexpr (std::is_same_v && std::is_same_v) { + amgx_mode = AMGX_mode_dDFI; + } else if constexpr (std::is_same_v && std::is_same_v) { + amgx_mode = AMGX_mode_dFFI; + } else { + OPM_THROW(std::runtime_error, "Unsupported combination of matrix and vector types in AmgxPreconditioner"); + } + + // Create solver and matrix/vector handles with selected mode + AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, amgx_mode, cfg_)); + AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, amgx_mode)); + AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, amgx_mode)); + AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, amgx_mode)); // Setup matrix structure std::vector row_ptrs(N_ + 1); @@ -147,8 +161,8 @@ public: setupSparsityPattern(row_ptrs, col_indices); // initialize matrix with values - const field_type* values = &(A_[0][0][0][0]); - AMGX_SAFE_CALL(AMGX_pin_memory(const_cast(values), sizeof(field_type) * nnz_ * block_size * block_size)); + const matrix_field_type* values = &(A_[0][0][0][0]); + AMGX_SAFE_CALL(AMGX_pin_memory(const_cast(values), sizeof(matrix_field_type) * nnz_ * block_size * block_size)); AMGX_SAFE_CALL(AMGX_matrix_upload_all(A_amgx_, N_, nnz_, block_size, block_size, row_ptrs.data(), col_indices.data(), values, nullptr)); @@ -162,8 +176,8 @@ public: */ ~AmgxPreconditioner() { - const field_type* values = &(A_[0][0][0][0]); - AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast(values))); + const matrix_field_type* values = &(A_[0][0][0][0]); + AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast(values))); if (solver_) { AMGX_SAFE_CALL(AMGX_solver_destroy(solver_)); } @@ -304,7 +318,7 @@ private: void copyMatrixToAmgx() { // Get direct pointer to matrix values - const field_type* values = &(A_[0][0][0][0]); + const matrix_field_type* values = &(A_[0][0][0][0]); // Indexing explanation: // A_[0] - First row of the matrix // [0] - First block in that row