From 5dd70eb8bb95ea41682f668f07003d492a8938e2 Mon Sep 17 00:00:00 2001 From: jakobtorben Date: Mon, 25 Nov 2024 21:46:07 +0100 Subject: [PATCH] Add AMGX preconditioner --- CMakeLists.txt | 13 + CMakeLists_files.cmake | 6 + opm-simulators-prereqs.cmake | 1 + opm/simulators/flow/Main.cpp | 12 + opm/simulators/linalg/AmgxPreconditioner.hpp | 318 ++++++++++++++++++ .../linalg/PreconditionerFactory_impl.hpp | 13 + 6 files changed, 363 insertions(+) create mode 100644 opm/simulators/linalg/AmgxPreconditioner.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e6d3b728a..640de229e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,6 +35,7 @@ option(USE_DAMARIS_LIB "Use the Damaris library for asynchronous I/O?" OFF) option(USE_GPU_BRIDGE "Enable the GPU bridge (GPU/AMGCL solvers)" ON) option(USE_TRACY_PROFILER "Enable tracy profiling" OFF) option(CONVERT_CUDA_TO_HIP "Convert CUDA code to HIP (to run on AMD cards)" OFF) +option(USE_AMGX "Enable AMGX support?" OFF) option(USE_HYPRE "Use the Hypre library for linear solvers?" OFF) set(OPM_COMPILE_COMPONENTS "2;3;4;5;6;7" CACHE STRING "The components to compile support for") option(USE_OPENCL "Enable OpenCL support?" ON) @@ -307,6 +308,18 @@ elseif(USE_HYPRE) set(HYPRE_FOUND OFF) endif() +# Find AMGX +if(USE_AMGX) + find_package(AMGX) + if(AMGX_FOUND) + set(HAVE_AMGX 1) + list(APPEND opm-simulators_LIBRARIES AMGX::AMGX) + else() + message(WARNING "AMGX requested but not found. Continuing without AMGX support.") + set(USE_AMGX OFF) + endif() +endif() + macro (config_hook) opm_need_version_of ("dune-common") opm_need_version_of ("dune-istl") diff --git a/CMakeLists_files.cmake b/CMakeLists_files.cmake index e3ece638b..05c83b1f6 100644 --- a/CMakeLists_files.cmake +++ b/CMakeLists_files.cmake @@ -1162,3 +1162,9 @@ if(HYPRE_FOUND) opm/simulators/linalg/HyprePreconditioner.hpp ) endif() + +if(AMGX_FOUND) + list(APPEND PUBLIC_HEADER_FILES + opm/simulators/linalg/AmgxPreconditioner.hpp + ) +endif() diff --git a/opm-simulators-prereqs.cmake b/opm-simulators-prereqs.cmake index 453a80102..204a6bc69 100644 --- a/opm-simulators-prereqs.cmake +++ b/opm-simulators-prereqs.cmake @@ -10,6 +10,7 @@ set (opm-simulators_CONFIG_VAR HAVE_OPENCL HAVE_OPENCL_HPP HAVE_AMGCL + HAVE_AMGX HAVE_VEXCL HAVE_ROCALUTION HAVE_ROCSPARSE diff --git a/opm/simulators/flow/Main.cpp b/opm/simulators/flow/Main.cpp index 91864fba9..e5edfa04f 100644 --- a/opm/simulators/flow/Main.cpp +++ b/opm/simulators/flow/Main.cpp @@ -40,6 +40,10 @@ #include #endif +#if HAVE_AMGX +#include +#endif + namespace Opm { Main::Main(int argc, char** argv, bool ownMPI) @@ -97,6 +101,10 @@ Main::~Main() HYPRE_Finalize(); #endif +#if HAVE_AMGX + AMGX_SAFE_CALL(AMGX_finalize()); +#endif + if (ownMPI_) { FlowGenericVanguard::setCommunication(nullptr); } @@ -180,6 +188,10 @@ void Main::initMPI() HYPRE_Init(); #endif #endif + +#if HAVE_AMGX + AMGX_SAFE_CALL(AMGX_initialize()); +#endif } void Main::handleVersionCmdLine_(int argc, char** argv, diff --git a/opm/simulators/linalg/AmgxPreconditioner.hpp b/opm/simulators/linalg/AmgxPreconditioner.hpp new file mode 100644 index 000000000..d2c679f69 --- /dev/null +++ b/opm/simulators/linalg/AmgxPreconditioner.hpp @@ -0,0 +1,318 @@ +/* + Copyright 2024 SINTEF AS + Copyright 2024 Equinor ASA + + This file is part of the Open Porous Media project (OPM). + + OPM is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + OPM is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OPM. If not, see . +*/ + +#ifndef OPM_AMGX_PRECONDITIONER_HEADER_INCLUDED +#define OPM_AMGX_PRECONDITIONER_HEADER_INCLUDED + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace Amgx { + +/** + * @brief Configuration structure for AMGX parameters. + * + * This structure holds the configuration parameters for the AMGX solver. + */ +struct AmgxConfig { + int determinism_flag = 0; + int print_grid_stats = 0; + int print_solve_stats = 0; + std::string solver = "AMG"; + std::string algorithm = "CLASSICAL"; + std::string interpolator = "D2"; + std::string selector = "PMIS"; + std::string smoother = "BLOCK_JACOBI"; + int presweeps = 3; + int postsweeps = 3; + double strength_threshold = 0.5; + int max_iters = 1; + + explicit AmgxConfig(const Opm::PropertyTree& prm) { + determinism_flag = prm.get("determinism_flag", determinism_flag); + print_grid_stats = prm.get("print_grid_stats", print_grid_stats); + print_solve_stats = prm.get("print_solve_stats", print_solve_stats); + solver = prm.get("solver", solver); + algorithm = prm.get("algorithm", algorithm); + interpolator = prm.get("interpolator", interpolator); + selector = prm.get("selector", selector); + smoother = prm.get("smoother", smoother); + presweeps = prm.get("presweeps", presweeps); + postsweeps = prm.get("postsweeps", postsweeps); + strength_threshold = prm.get("strength_threshold", strength_threshold); + max_iters = prm.get("max_iters", max_iters); + } + + std::string toString() const { + return "config_version=2, " + "determinism_flag=" + std::to_string(determinism_flag) + ", " + "print_grid_stats=" + std::to_string(print_grid_stats) + ", " + "print_solve_stats=" + std::to_string(print_solve_stats) + ", " + "solver=" + solver + ", " + "algorithm=" + algorithm + ", " + "interpolator=" + interpolator + ", " + "selector=" + selector + ", " + "smoother=" + smoother + ", " + "presweeps=" + std::to_string(presweeps) + ", " + "postsweeps=" + std::to_string(postsweeps) + ", " + "strength_threshold=" + std::to_string(strength_threshold) + ", " + "max_iters=" + std::to_string(max_iters); + } +}; + +/** + * @brief Wrapper for AMGX's AMG preconditioner. + * + * This class provides an interface to the AMG preconditioner from the AMGX library. + * It is designed to work with matrices, update vectors, and defect vectors specified + * by the template parameters. + * + * @tparam M The matrix type the preconditioner is for. + * @tparam X The type of the update vector. + * @tparam Y The type of the defect vector. + */ +template +class AmgxPreconditioner : public Dune::PreconditionerWithUpdate +{ +public: + //! \brief The matrix type the preconditioner is for + using matrix_type = M; + //! \brief The domain type of the preconditioner + using domain_type = X; + //! \brief The range type of the preconditioner + using range_type = Y; + //! \brief The field type of the preconditioner + using field_type = typename X::field_type; + + static constexpr int block_size = 1; + + /** + * @brief Constructor for the AmgxPreconditioner class. + * + * Initializes the preconditioner with the given matrix and property tree. + * + * @param A The matrix for which the preconditioner is constructed. + * @param prm The property tree containing configuration parameters. + */ + AmgxPreconditioner(const M& A, const Opm::PropertyTree prm) + : A_(A) + , N_(A.N()) + , nnz_(A.nonzeroes()) + { + OPM_TIMEBLOCK(prec_construct); + + // Create configuration + AmgxConfig config(prm); + AMGX_SAFE_CALL(AMGX_config_create(&cfg_, config.toString().c_str())); + AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc_, cfg_)); + + // Create solver and matrix/vector handles + AMGX_SAFE_CALL(AMGX_solver_create(&solver_, rsrc_, AMGX_mode_dDDI, cfg_)); + AMGX_SAFE_CALL(AMGX_matrix_create(&A_amgx_, rsrc_, AMGX_mode_dDDI)); + AMGX_SAFE_CALL(AMGX_vector_create(&x_amgx_, rsrc_, AMGX_mode_dDDI)); + AMGX_SAFE_CALL(AMGX_vector_create(&b_amgx_, rsrc_, AMGX_mode_dDDI)); + + // Setup matrix structure + std::vector row_ptrs(N_ + 1); + std::vector col_indices(nnz_); + setupSparsityPattern(row_ptrs, col_indices); + + // initialize matrix with values + const field_type* values = &(A_[0][0][0][0]); + AMGX_SAFE_CALL(AMGX_pin_memory(const_cast(values), sizeof(field_type) * nnz_ * block_size * block_size)); + AMGX_SAFE_CALL(AMGX_matrix_upload_all(A_amgx_, N_, nnz_, block_size, block_size, + row_ptrs.data(), col_indices.data(), + values, nullptr)); + update(); + } + + /** + * @brief Destructor for the AmgxPreconditioner class. + * + * Cleans up resources allocated by the preconditioner. + */ + ~AmgxPreconditioner() + { + const field_type* values = &(A_[0][0][0][0]); + AMGX_SAFE_CALL(AMGX_unpin_memory(const_cast(values))); + if (solver_) { + AMGX_SAFE_CALL(AMGX_solver_destroy(solver_)); + } + if (x_amgx_) { + AMGX_SAFE_CALL(AMGX_vector_destroy(x_amgx_)); + } + if (b_amgx_) { + AMGX_SAFE_CALL(AMGX_vector_destroy(b_amgx_)); + } + if (A_amgx_) { + AMGX_SAFE_CALL(AMGX_matrix_destroy(A_amgx_)); + } + // Destroying resources and config crashes when reinitializing + //if (rsrc_) { + // AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc_)); + //} + //if (cfg_) { + // AMGX_SAFE_CALL(AMGX_config_destroy(cfg_)); + //} + } + + /** + * @brief Pre-processing step before applying the preconditioner. + * + * This method is currently a no-op. + * + * @param v The update vector. + * @param d The defect vector. + */ + void pre(X& /*v*/, Y& /*d*/) override { + } + + /** + * @brief Applies the preconditioner to a vector. + * + * Performs one AMG cycle to solve the system. + * Involves uploading vectors to AMGX, applying the preconditioner, and downloading the result. + * + * @param v The update vector. + * @param d The defect vector. + */ + void apply(X& v, const Y& d) override + { + OPM_TIMEBLOCK(prec_apply); + + // Upload vectors to AMGX + AMGX_SAFE_CALL(AMGX_vector_upload(x_amgx_, N_, block_size, &v[0][0])); + AMGX_SAFE_CALL(AMGX_vector_upload(b_amgx_, N_, block_size, &d[0][0])); + + // Apply preconditioner + AMGX_SAFE_CALL(AMGX_solver_solve(solver_, b_amgx_, x_amgx_)); + + // Download result + AMGX_SAFE_CALL(AMGX_vector_download(x_amgx_, &v[0][0])); + } + + /** + * @brief Post-processing step after applying the preconditioner. + * + * This method is currently a no-op. + * + * @param v The update vector. + */ + void post(X& /*v*/) override { + } + + /** + * @brief Updates the preconditioner with the current matrix values. + * + * This method should be called whenever the matrix values change. + */ + void update() override + { + OPM_TIMEBLOCK(prec_update); + copyMatrixToAmgx(); + AMGX_SAFE_CALL(AMGX_solver_setup(solver_, A_amgx_)); + } + + /** + * @brief Returns the solver category. + * + * @return The solver category, which is sequential. + */ + Dune::SolverCategory::Category category() const override + { + return Dune::SolverCategory::sequential; + } + + /** + * @brief Checks if the preconditioner has a perfect update. + * + * @return True, indicating that the preconditioner can be perfectly updated. + */ + bool hasPerfectUpdate() const override + { + // The Amgx preconditioner can depend on the values of the matrix, so it must be recreated + return false; + } + +private: + /** + * @brief Sets up the sparsity pattern for the AMGX matrix. + * + * This method initializes the row pointers and column indices for the AMGX matrix. + * + * @param row_ptrs The row pointers for the AMGX matrix. + * @param col_indices The column indices for the AMGX matrix. + */ + void setupSparsityPattern(std::vector& row_ptrs, std::vector& col_indices) + { + int pos = 0; + row_ptrs[0] = 0; + for (auto row = A_.begin(); row != A_.end(); ++row) { + for (auto col = row->begin(); col != row->end(); ++col) { + col_indices[pos++] = col.index(); + } + row_ptrs[row.index() + 1] = pos; + } + } + + /** + * @brief Copies the matrix values to the AMGX matrix. + * + * This method updates the AMGX matrix with the current matrix values. + * The method assumes that the sparsity structure is the same and that + * the values are stored in a contiguous array. + */ + void copyMatrixToAmgx() + { + // Get direct pointer to matrix values + const field_type* values = &(A_[0][0][0][0]); + // Indexing explanation: + // A_[0] - First row of the matrix + // [0] - First block in that row + // [0] - First row within the 1x1 block + // [0] - First column within the 1x1 block + // update matrix with new values, assuming the sparsity structure is the same + AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(A_amgx_, N_, nnz_, values, nullptr)); + } + + const M& A_; //!< The matrix for which the preconditioner is constructed. + const int N_; //!< Number of rows in the matrix. + const int nnz_; //!< Number of non-zero elements in the matrix. + + AMGX_config_handle cfg_ = nullptr; //!< The AMGX configuration handle. + AMGX_resources_handle rsrc_ = nullptr; //!< The AMGX resources handle. + AMGX_solver_handle solver_ = nullptr; //!< The AMGX solver handle. + AMGX_matrix_handle A_amgx_ = nullptr; //!< The AMGX matrix handle. + AMGX_vector_handle x_amgx_ = nullptr; //!< The AMGX solution vector handle. + AMGX_vector_handle b_amgx_ = nullptr; //!< The AMGX right-hand side vector handle. +}; + +} // namespace Amgx + +#endif // OPM_AMGX_PRECONDITIONER_HEADER_INCLUDED diff --git a/opm/simulators/linalg/PreconditionerFactory_impl.hpp b/opm/simulators/linalg/PreconditionerFactory_impl.hpp index 04bedf0e2..25fea74a0 100644 --- a/opm/simulators/linalg/PreconditionerFactory_impl.hpp +++ b/opm/simulators/linalg/PreconditionerFactory_impl.hpp @@ -53,6 +53,9 @@ #include #endif +#if HAVE_AMGX +#include +#endif namespace Opm { @@ -547,6 +550,16 @@ struct StandardPreconditioners { return getRebuildOnUpdateWrapper>(op, crit, parms); } }); + +#if HAVE_AMGX + // Only add AMGX for scalar matrices + if constexpr (M::block_type::rows == 1 && M::block_type::cols == 1) { + F::addCreator("amgx", [](const O& op, const P& prm, const std::function&, std::size_t) { + return std::make_shared>(op.getmat(), prm); + }); + } +#endif + #if HAVE_HYPRE // Only add Hypre for scalar matrices if constexpr (M::block_type::rows == 1 && M::block_type::cols == 1 &&