Select AMGX mode based on template scalar type

This commit is contained in:
jakobtorben 2024-12-17 15:39:37 +01:00
parent 0290fd0e9f
commit ac5b6b53c5

View File

@ -103,12 +103,14 @@ class AmgxPreconditioner : public Dune::PreconditionerWithUpdate<X,Y>
public: public:
//! \brief The matrix type the preconditioner is for //! \brief The matrix type the preconditioner is for
using matrix_type = M; using matrix_type = M;
//! \brief The field type of the matrix
using matrix_field_type = typename M::field_type;
//! \brief The domain type of the preconditioner //! \brief The domain type of the preconditioner
using domain_type = X; using domain_type = X;
//! \brief The range type of the preconditioner //! \brief The range type of the preconditioner
using range_type = Y; using range_type = Y;
//! \brief The field type of the preconditioner //! \brief The field type of the vectors
using field_type = typename X::field_type; using vector_field_type = typename X::field_type;
static constexpr int block_size = 1; static constexpr int block_size = 1;
@ -135,11 +137,23 @@ public:
// Setup frequency is set in the property tree // Setup frequency is set in the property tree
setup_frequency_ = prm.get<int>("setup_frequency", 30); setup_frequency_ = prm.get<int>("setup_frequency", 30);
// Create solver and matrix/vector handles // Select appropriate AMGX mode based on matrix and vector scalar types
AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, AMGX_mode_dDDI, cfg_)); AMGX_Mode amgx_mode;
AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, AMGX_mode_dDDI)); if constexpr (std::is_same_v<matrix_field_type, double> && std::is_same_v<vector_field_type, double>) {
AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, AMGX_mode_dDDI)); amgx_mode = AMGX_mode_dDDI;
AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, AMGX_mode_dDDI)); } else if constexpr (std::is_same_v<matrix_field_type, float> && std::is_same_v<vector_field_type, double>) {
amgx_mode = AMGX_mode_dDFI;
} else if constexpr (std::is_same_v<matrix_field_type, float> && std::is_same_v<vector_field_type, float>) {
amgx_mode = AMGX_mode_dFFI;
} else {
OPM_THROW(std::runtime_error, "Unsupported combination of matrix and vector types in AmgxPreconditioner");
}
// Create solver and matrix/vector handles with selected mode
AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, amgx_mode, cfg_));
AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, amgx_mode));
AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, amgx_mode));
AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, amgx_mode));
// Setup matrix structure // Setup matrix structure
std::vector<int> row_ptrs(N_ + 1); std::vector<int> row_ptrs(N_ + 1);
@ -147,8 +161,8 @@ public:
setupSparsityPattern(row_ptrs, col_indices); setupSparsityPattern(row_ptrs, col_indices);
// initialize matrix with values // initialize matrix with values
const field_type* values = &(A_[0][0][0][0]); const matrix_field_type* values = &(A_[0][0][0][0]);
AMGX_SAFE_CALL(AMGX_pin_memory(const_cast<field_type*>(values), sizeof(field_type) * nnz_ * block_size * block_size)); AMGX_SAFE_CALL(AMGX_pin_memory(const_cast<matrix_field_type*>(values), sizeof(matrix_field_type) * nnz_ * block_size * block_size));
AMGX_SAFE_CALL(AMGX_matrix_upload_all(A_amgx_, N_, nnz_, block_size, block_size, AMGX_SAFE_CALL(AMGX_matrix_upload_all(A_amgx_, N_, nnz_, block_size, block_size,
row_ptrs.data(), col_indices.data(), row_ptrs.data(), col_indices.data(),
values, nullptr)); values, nullptr));
@ -162,8 +176,8 @@ public:
*/ */
~AmgxPreconditioner() ~AmgxPreconditioner()
{ {
const field_type* values = &(A_[0][0][0][0]); const matrix_field_type* values = &(A_[0][0][0][0]);
AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast<field_type*>(values))); AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast<matrix_field_type*>(values)));
if (solver_) { if (solver_) {
AMGX_SAFE_CALL(AMGX_solver_destroy(solver_)); AMGX_SAFE_CALL(AMGX_solver_destroy(solver_));
} }
@ -304,7 +318,7 @@ private:
void copyMatrixToAmgx() void copyMatrixToAmgx()
{ {
// Get direct pointer to matrix values // Get direct pointer to matrix values
const field_type* values = &(A_[0][0][0][0]); const matrix_field_type* values = &(A_[0][0][0][0]);
// Indexing explanation: // Indexing explanation:
// A_[0] - First row of the matrix // A_[0] - First row of the matrix
// [0] - First block in that row // [0] - First block in that row