Select AMGX mode based on template scalar type

This commit is contained in:
jakobtorben 2024-12-17 15:39:37 +01:00
parent 0290fd0e9f
commit ac5b6b53c5

View File

@ -103,12 +103,14 @@ class AmgxPreconditioner : public Dune::PreconditionerWithUpdate<X,Y>
public:
//! \brief The matrix type the preconditioner is for
using matrix_type = M;
//! \brief The field type of the matrix
using matrix_field_type = typename M::field_type;
//! \brief The domain type of the preconditioner
using domain_type = X;
//! \brief The range type of the preconditioner
using range_type = Y;
//! \brief The field type of the preconditioner
using field_type = typename X::field_type;
//! \brief The field type of the vectors
using vector_field_type = typename X::field_type;
static constexpr int block_size = 1;
@ -135,11 +137,23 @@ public:
// Setup frequency is set in the property tree
setup_frequency_ = prm.get<int>("setup_frequency", 30);
// Create solver and matrix/vector handles
AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, AMGX_mode_dDDI, cfg_));
AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, AMGX_mode_dDDI));
AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, AMGX_mode_dDDI));
AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, AMGX_mode_dDDI));
// Select appropriate AMGX mode based on matrix and vector scalar types
AMGX_Mode amgx_mode;
if constexpr (std::is_same_v<matrix_field_type, double> && std::is_same_v<vector_field_type, double>) {
amgx_mode = AMGX_mode_dDDI;
} else if constexpr (std::is_same_v<matrix_field_type, float> && std::is_same_v<vector_field_type, double>) {
amgx_mode = AMGX_mode_dDFI;
} else if constexpr (std::is_same_v<matrix_field_type, float> && std::is_same_v<vector_field_type, float>) {
amgx_mode = AMGX_mode_dFFI;
} else {
OPM_THROW(std::runtime_error, "Unsupported combination of matrix and vector types in AmgxPreconditioner");
}
// Create solver and matrix/vector handles with selected mode
AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, amgx_mode, cfg_));
AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, amgx_mode));
AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, amgx_mode));
AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, amgx_mode));
// Setup matrix structure
std::vector<int> row_ptrs(N_ + 1);
@ -147,8 +161,8 @@ public:
setupSparsityPattern(row_ptrs, col_indices);
// initialize matrix with values
const field_type* values = &(A_[0][0][0][0]);
AMGX_SAFE_CALL(AMGX_pin_memory(const_cast<field_type*>(values), sizeof(field_type) * nnz_ * block_size * block_size));
const matrix_field_type* values = &(A_[0][0][0][0]);
AMGX_SAFE_CALL(AMGX_pin_memory(const_cast<matrix_field_type*>(values), sizeof(matrix_field_type) * nnz_ * block_size * block_size));
AMGX_SAFE_CALL(AMGX_matrix_upload_all(A_amgx_, N_, nnz_, block_size, block_size,
row_ptrs.data(), col_indices.data(),
values, nullptr));
@ -162,8 +176,8 @@ public:
*/
~AmgxPreconditioner()
{
const field_type* values = &(A_[0][0][0][0]);
AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast<field_type*>(values)));
const matrix_field_type* values = &(A_[0][0][0][0]);
AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast<matrix_field_type*>(values)));
if (solver_) {
AMGX_SAFE_CALL(AMGX_solver_destroy(solver_));
}
@ -304,7 +318,7 @@ private:
void copyMatrixToAmgx()
{
// Get direct pointer to matrix values
const field_type* values = &(A_[0][0][0][0]);
const matrix_field_type* values = &(A_[0][0][0][0]);
// Indexing explanation:
// A_[0] - First row of the matrix
// [0] - First block in that row