mirror of
https://github.com/OPM/opm-simulators.git
synced 2025-02-25 18:55:30 -06:00
Select AMGX mode based on template scalar type
This commit is contained in:
parent
0290fd0e9f
commit
ac5b6b53c5
@ -103,12 +103,14 @@ class AmgxPreconditioner : public Dune::PreconditionerWithUpdate<X,Y>
|
|||||||
public:
|
public:
|
||||||
//! \brief The matrix type the preconditioner is for
|
//! \brief The matrix type the preconditioner is for
|
||||||
using matrix_type = M;
|
using matrix_type = M;
|
||||||
|
//! \brief The field type of the matrix
|
||||||
|
using matrix_field_type = typename M::field_type;
|
||||||
//! \brief The domain type of the preconditioner
|
//! \brief The domain type of the preconditioner
|
||||||
using domain_type = X;
|
using domain_type = X;
|
||||||
//! \brief The range type of the preconditioner
|
//! \brief The range type of the preconditioner
|
||||||
using range_type = Y;
|
using range_type = Y;
|
||||||
//! \brief The field type of the preconditioner
|
//! \brief The field type of the vectors
|
||||||
using field_type = typename X::field_type;
|
using vector_field_type = typename X::field_type;
|
||||||
|
|
||||||
static constexpr int block_size = 1;
|
static constexpr int block_size = 1;
|
||||||
|
|
||||||
@ -135,11 +137,23 @@ public:
|
|||||||
// Setup frequency is set in the property tree
|
// Setup frequency is set in the property tree
|
||||||
setup_frequency_ = prm.get<int>("setup_frequency", 30);
|
setup_frequency_ = prm.get<int>("setup_frequency", 30);
|
||||||
|
|
||||||
// Create solver and matrix/vector handles
|
// Select appropriate AMGX mode based on matrix and vector scalar types
|
||||||
AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, AMGX_mode_dDDI, cfg_));
|
AMGX_Mode amgx_mode;
|
||||||
AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, AMGX_mode_dDDI));
|
if constexpr (std::is_same_v<matrix_field_type, double> && std::is_same_v<vector_field_type, double>) {
|
||||||
AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, AMGX_mode_dDDI));
|
amgx_mode = AMGX_mode_dDDI;
|
||||||
AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, AMGX_mode_dDDI));
|
} else if constexpr (std::is_same_v<matrix_field_type, float> && std::is_same_v<vector_field_type, double>) {
|
||||||
|
amgx_mode = AMGX_mode_dDFI;
|
||||||
|
} else if constexpr (std::is_same_v<matrix_field_type, float> && std::is_same_v<vector_field_type, float>) {
|
||||||
|
amgx_mode = AMGX_mode_dFFI;
|
||||||
|
} else {
|
||||||
|
OPM_THROW(std::runtime_error, "Unsupported combination of matrix and vector types in AmgxPreconditioner");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create solver and matrix/vector handles with selected mode
|
||||||
|
AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, amgx_mode, cfg_));
|
||||||
|
AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, amgx_mode));
|
||||||
|
AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, amgx_mode));
|
||||||
|
AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, amgx_mode));
|
||||||
|
|
||||||
// Setup matrix structure
|
// Setup matrix structure
|
||||||
std::vector<int> row_ptrs(N_ + 1);
|
std::vector<int> row_ptrs(N_ + 1);
|
||||||
@ -147,8 +161,8 @@ public:
|
|||||||
setupSparsityPattern(row_ptrs, col_indices);
|
setupSparsityPattern(row_ptrs, col_indices);
|
||||||
|
|
||||||
// initialize matrix with values
|
// initialize matrix with values
|
||||||
const field_type* values = &(A_[0][0][0][0]);
|
const matrix_field_type* values = &(A_[0][0][0][0]);
|
||||||
AMGX_SAFE_CALL(AMGX_pin_memory(const_cast<field_type*>(values), sizeof(field_type) * nnz_ * block_size * block_size));
|
AMGX_SAFE_CALL(AMGX_pin_memory(const_cast<matrix_field_type*>(values), sizeof(matrix_field_type) * nnz_ * block_size * block_size));
|
||||||
AMGX_SAFE_CALL(AMGX_matrix_upload_all(A_amgx_, N_, nnz_, block_size, block_size,
|
AMGX_SAFE_CALL(AMGX_matrix_upload_all(A_amgx_, N_, nnz_, block_size, block_size,
|
||||||
row_ptrs.data(), col_indices.data(),
|
row_ptrs.data(), col_indices.data(),
|
||||||
values, nullptr));
|
values, nullptr));
|
||||||
@ -162,8 +176,8 @@ public:
|
|||||||
*/
|
*/
|
||||||
~AmgxPreconditioner()
|
~AmgxPreconditioner()
|
||||||
{
|
{
|
||||||
const field_type* values = &(A_[0][0][0][0]);
|
const matrix_field_type* values = &(A_[0][0][0][0]);
|
||||||
AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast<field_type*>(values)));
|
AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast<matrix_field_type*>(values)));
|
||||||
if (solver_) {
|
if (solver_) {
|
||||||
AMGX_SAFE_CALL(AMGX_solver_destroy(solver_));
|
AMGX_SAFE_CALL(AMGX_solver_destroy(solver_));
|
||||||
}
|
}
|
||||||
@ -304,7 +318,7 @@ private:
|
|||||||
void copyMatrixToAmgx()
|
void copyMatrixToAmgx()
|
||||||
{
|
{
|
||||||
// Get direct pointer to matrix values
|
// Get direct pointer to matrix values
|
||||||
const field_type* values = &(A_[0][0][0][0]);
|
const matrix_field_type* values = &(A_[0][0][0][0]);
|
||||||
// Indexing explanation:
|
// Indexing explanation:
|
||||||
// A_[0] - First row of the matrix
|
// A_[0] - First row of the matrix
|
||||||
// [0] - First block in that row
|
// [0] - First block in that row
|
||||||
|
Loading…
Reference in New Issue
Block a user