Merge pull request #5404 from multitalentloes/add_dilu_LU_splitting

Add CuDILU LU splitting
This commit is contained in:
Kjetil Olsen Lye 2024-06-27 14:30:45 +02:00 committed by GitHub
commit 9b414419e7
10 changed files with 929 additions and 71 deletions
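In short, this change lets the CuDILU preconditioner optionally store the reordered matrix as three separate buffers (strictly lower part, strictly upper part, and diagonal) instead of a single combined BSR matrix, controlled by a new split_matrix flag that defaults to true. A minimal construction sketch, assuming a Dune matrix A built as in the new unit tests below; the aliases are illustrative only:

using M = Dune::BCRSMatrix<Dune::FieldMatrix<double, 2, 2>>;
using V = Opm::cuistl::CuVector<double>;
// split storage: L, U and the diagonal of the reordered matrix in separate buffers
auto splitPrec = Opm::cuistl::CuDILU<M, V, V>(A, /*split_matrix=*/true);
// combined storage: one reordered matrix, as before this change
auto combinedPrec = Opm::cuistl::CuDILU<M, V, V>(A, /*split_matrix=*/false);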

View File

@ -762,6 +762,7 @@ if(CUDA_FOUND)
cuda_check_last_error
cublas_handle
cujac
cudilu
cusparse_handle
cuSparse_matrix_operations
cuVector_operations

View File

@ -362,6 +362,7 @@ if (HAVE_CUDA)
ADD_CUDA_OR_HIP_FILE(TEST_SOURCE_FILES tests test_cusparse_safe_call.cpp)
ADD_CUDA_OR_HIP_FILE(TEST_SOURCE_FILES tests test_cuda_safe_call.cpp)
ADD_CUDA_OR_HIP_FILE(TEST_SOURCE_FILES tests test_cuda_check_last_error.cpp)
ADD_CUDA_OR_HIP_FILE(TEST_SOURCE_FILES tests test_cudilu.cpp)
ADD_CUDA_OR_HIP_FILE(TEST_SOURCE_FILES tests test_cujac.cpp)
ADD_CUDA_OR_HIP_FILE(TEST_SOURCE_FILES tests test_cuowneroverlapcopy.cpp)
ADD_CUDA_OR_HIP_FILE(TEST_SOURCE_FILES tests test_cuseqilu0.cpp)

View File

@ -28,6 +28,7 @@
#include <opm/simulators/linalg/hipistl/CuSeqILU0.hpp>
#include <opm/simulators/linalg/hipistl/PreconditionerAdapter.hpp>
#include <opm/simulators/linalg/hipistl/PreconditionerConvertFieldTypeAdapter.hpp>
#include <opm/simulators/linalg/hipistl/detail/cuda_safe_call.hpp>
#else
#include <opm/simulators/linalg/cuistl/CuBlockPreconditioner.hpp>
#include <opm/simulators/linalg/cuistl/CuDILU.hpp>
@ -35,5 +36,6 @@
#include <opm/simulators/linalg/cuistl/CuSeqILU0.hpp>
#include <opm/simulators/linalg/cuistl/PreconditionerAdapter.hpp>
#include <opm/simulators/linalg/cuistl/PreconditionerConvertFieldTypeAdapter.hpp>
#include <opm/simulators/linalg/cuistl/detail/cuda_safe_call.hpp>
#endif
#endif

View File

@ -347,9 +347,10 @@ struct StandardPreconditioners {
});
F::addCreator("CUDILU", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t, const C& comm) {
const bool split_matrix = prm.get<bool>("split_matrix", true);
using field_type = typename V::field_type;
using CuDILU = typename cuistl::CuDILU<M, cuistl::CuVector<field_type>, cuistl::CuVector<field_type>>;
auto cuDILU = std::make_shared<CuDILU>(op.getmat());
auto cuDILU = std::make_shared<CuDILU>(op.getmat(), split_matrix);
auto adapted = std::make_shared<cuistl::PreconditionerAdapter<V, V, CuDILU>>(cuDILU);
auto wrapped = std::make_shared<cuistl::CuBlockPreconditioner<V, V, Comm>>(adapted, comm);
@ -603,9 +604,10 @@ struct StandardPreconditioners<Operator, Dune::Amg::SequentialInformation> {
});
F::addCreator("CUDILU", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t) {
const bool split_matrix = prm.get<bool>("split_matrix", true);
using field_type = typename V::field_type;
using CUDILU = typename cuistl::CuDILU<M, cuistl::CuVector<field_type>, cuistl::CuVector<field_type>>;
return std::make_shared<cuistl::PreconditionerAdapter<V, V, CUDILU>>(std::make_shared<CUDILU>(op.getmat()));
return std::make_shared<cuistl::PreconditionerAdapter<V, V, CUDILU>>(std::make_shared<CUDILU>(op.getmat(), split_matrix));
});
F::addCreator("CUDILUFloat", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t) {

View File

@ -25,6 +25,7 @@
#include <opm/simulators/linalg/cuistl/CuDILU.hpp>
#include <opm/simulators/linalg/cuistl/CuSparseMatrix.hpp>
#include <opm/simulators/linalg/cuistl/CuVector.hpp>
#include <opm/simulators/linalg/cuistl/detail/cuda_safe_call.hpp>
#include <opm/simulators/linalg/cuistl/detail/cusparse_matrix_operations.hpp>
#include <opm/simulators/linalg/cuistl/detail/safe_conversion.hpp>
#include <opm/simulators/linalg/matrixblock.hh>
@ -62,11 +63,11 @@ createNaturalToReordered(Opm::SparseTable<size_t> levelSets)
return res;
}
// TODO: When this function is called we already have the natural ordered matrix on the GPU
// TODO: could it be possible to create the reordered one in a kernel to speed up the constructor?
template <class M, class field_type>
Opm::cuistl::CuSparseMatrix<field_type>
createReorderedMatrix(const M& naturalMatrix, std::vector<int> reorderedToNatural)
template <class M, class field_type, class GPUM>
void
createReorderedMatrix(const M& naturalMatrix,
std::vector<int> reorderedToNatural,
std::unique_ptr<GPUM>& reorderedGpuMat)
{
M reorderedMatrix(naturalMatrix.N(), naturalMatrix.N(), naturalMatrix.nonzeroes(), M::row_wise);
for (auto dstRowIt = reorderedMatrix.createbegin(); dstRowIt != reorderedMatrix.createend(); ++dstRowIt) {
@ -77,15 +78,39 @@ createReorderedMatrix(const M& naturalMatrix, std::vector<int> reorderedToNatura
}
}
// TODO: There is probably a faster way to copy by copying whole rows at a time
for (auto dstRowIt = reorderedMatrix.begin(); dstRowIt != reorderedMatrix.end(); ++dstRowIt) {
auto srcRow = naturalMatrix.begin() + reorderedToNatural[dstRowIt.index()];
for (auto elem = srcRow->begin(); elem != srcRow->end(); elem++) {
reorderedMatrix[dstRowIt.index()][elem.index()] = *elem;
reorderedGpuMat.reset(new auto (GPUM::fromMatrix(reorderedMatrix, true)));
}
template <class M, class field_type, class GPUM>
void
extractLowerAndUpperMatrices(const M& naturalMatrix,
std::vector<int> reorderedToNatural,
std::unique_ptr<GPUM>& lower,
std::unique_ptr<GPUM>& upper)
{
const size_t new_nnz = (naturalMatrix.nonzeroes() - naturalMatrix.N()) / 2;
M reorderedLower(naturalMatrix.N(), naturalMatrix.N(), new_nnz, M::row_wise);
M reorderedUpper(naturalMatrix.N(), naturalMatrix.N(), new_nnz, M::row_wise);
for (auto lowerIt = reorderedLower.createbegin(), upperIt = reorderedUpper.createbegin();
lowerIt != reorderedLower.createend();
++lowerIt, ++upperIt) {
auto srcRow = naturalMatrix.begin() + reorderedToNatural[lowerIt.index()];
for (auto elem = srcRow->begin(); elem != srcRow->end(); ++elem) {
if (elem.index() < srcRow.index()) { // add index to lower matrix if under the diagonal
lowerIt.insert(elem.index());
} else if (elem.index() > srcRow.index()) { // add element to upper matrix if above the diagonal
upperIt.insert(elem.index());
}
}
}
return Opm::cuistl::CuSparseMatrix<field_type>::fromMatrix(reorderedMatrix, true);
lower.reset(new auto (GPUM::fromMatrix(reorderedLower, true)));
upper.reset(new auto (GPUM::fromMatrix(reorderedUpper, true)));
return;
}
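// Note on the nnz split above: assuming a symmetric sparsity pattern with a
// full block diagonal, the strictly lower and strictly upper parts each hold
// (nonzeroes(A) - N) / 2 blocks. For the 6x6 matrix in test_cudilu.cpp,
// nonzeroes(A) = 16 and N = 6, so each triangular factor gets
// (16 - 6) / 2 = 5 nonzero blocks.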
} // NAMESPACE
@ -94,16 +119,16 @@ namespace Opm::cuistl
{
template <class M, class X, class Y, int l>
CuDILU<M, X, Y, l>::CuDILU(const M& A)
CuDILU<M, X, Y, l>::CuDILU(const M& A, bool split_matrix)
: m_cpuMatrix(A)
, m_levelSets(Opm::getMatrixRowColoring(m_cpuMatrix, Opm::ColoringType::LOWER))
, m_reorderedToNatural(createReorderedToNatural(m_levelSets))
, m_naturalToReordered(createNaturalToReordered(m_levelSets))
, m_gpuMatrix(CuSparseMatrix<field_type>::fromMatrix(m_cpuMatrix, true))
, m_gpuMatrixReordered(createReorderedMatrix<M, field_type>(m_cpuMatrix, m_reorderedToNatural))
, m_gpuNaturalToReorder(m_naturalToReordered)
, m_gpuReorderToNatural(m_reorderedToNatural)
, m_gpuDInv(m_gpuMatrix.N() * m_gpuMatrix.blockSize() * m_gpuMatrix.blockSize())
, m_splitMatrix(split_matrix)
{
// TODO: Should we in some way verify that this matrix is symmetric? Possibly only in debug mode.
@ -122,7 +147,15 @@ CuDILU<M, X, Y, l>::CuDILU(const M& A)
fmt::format("CuSparse matrix not same number of non zeroes as DUNE matrix. {} vs {}. ",
m_gpuMatrix.nonzeroes(),
A.nonzeroes()));
update();
if (m_splitMatrix) {
m_gpuMatrixReorderedDiag.reset(new auto(CuVector<field_type>(blocksize_ * blocksize_ * m_cpuMatrix.N())));
extractLowerAndUpperMatrices<M, field_type, CuSparseMatrix<field_type>>(
m_cpuMatrix, m_reorderedToNatural, m_gpuMatrixReorderedLower, m_gpuMatrixReorderedUpper);
} else {
createReorderedMatrix<M, field_type, CuSparseMatrix<field_type>>(
m_cpuMatrix, m_reorderedToNatural, m_gpuMatrixReordered);
}
computeDiagAndMoveReorderedData();
}
template <class M, class X, class Y, int l>
@ -136,34 +169,63 @@ void
CuDILU<M, X, Y, l>::apply(X& v, const Y& d)
{
OPM_TIMEBLOCK(prec_apply);
int levelStartIdx = 0;
for (int level = 0; level < m_levelSets.size(); ++level) {
const int numOfRowsInLevel = m_levelSets[level].size();
detail::computeLowerSolveLevelSet<field_type, blocksize_>(m_gpuMatrixReordered.getNonZeroValues().data(),
m_gpuMatrixReordered.getRowIndices().data(),
m_gpuMatrixReordered.getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
d.data(),
v.data());
levelStartIdx += numOfRowsInLevel;
}
{
int levelStartIdx = 0;
for (int level = 0; level < m_levelSets.size(); ++level) {
const int numOfRowsInLevel = m_levelSets[level].size();
if (m_splitMatrix) {
detail::computeLowerSolveLevelSetSplit<field_type, blocksize_>(
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedLower->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
d.data(),
v.data());
} else {
detail::computeLowerSolveLevelSet<field_type, blocksize_>(
m_gpuMatrixReordered->getNonZeroValues().data(),
m_gpuMatrixReordered->getRowIndices().data(),
m_gpuMatrixReordered->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
d.data(),
v.data());
}
levelStartIdx += numOfRowsInLevel;
}
levelStartIdx = m_cpuMatrix.N();
// upper triangular solve: (D + U_A) v = Dy
for (int level = m_levelSets.size() - 1; level >= 0; --level) {
const int numOfRowsInLevel = m_levelSets[level].size();
levelStartIdx -= numOfRowsInLevel;
detail::computeUpperSolveLevelSet<field_type, blocksize_>(m_gpuMatrixReordered.getNonZeroValues().data(),
m_gpuMatrixReordered.getRowIndices().data(),
m_gpuMatrixReordered.getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
v.data());
levelStartIdx = m_cpuMatrix.N();
// upper triangular solve: (D + U_A) v = Dy
for (int level = m_levelSets.size() - 1; level >= 0; --level) {
const int numOfRowsInLevel = m_levelSets[level].size();
levelStartIdx -= numOfRowsInLevel;
if (m_splitMatrix) {
detail::computeUpperSolveLevelSetSplit<field_type, blocksize_>(
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
v.data());
} else {
detail::computeUpperSolveLevelSet<field_type, blocksize_>(
m_gpuMatrixReordered->getNonZeroValues().data(),
m_gpuMatrixReordered->getRowIndices().data(),
m_gpuMatrixReordered->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
v.data());
}
}
}
}
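For reference, the two level-set sweeps above implement the standard DILU application. Writing $L_A$ and $U_A$ for the strictly lower and upper triangular parts of $A$ and $D$ for the DILU diagonal, the preconditioner is $M_{\mathrm{DILU}} = (D + L_A)\, D^{-1}\, (D + U_A)$, so apply(v, d) performs the forward solve $(D + L_A)\, y = d$ followed by the backward solve $(D + U_A)\, v = D y$, matching the "(D + U_A) v = Dy" comment in the code. The split path reads $L_A$ and $U_A$ from their own BSR buffers, while the combined path filters by column index inside the single reordered matrix.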
@ -185,29 +247,68 @@ void
CuDILU<M, X, Y, l>::update()
{
OPM_TIMEBLOCK(prec_update);
{
m_gpuMatrix.updateNonzeroValues(m_cpuMatrix, true); // send updated matrix to the gpu
computeDiagAndMoveReorderedData();
}
}
m_gpuMatrix.updateNonzeroValues(m_cpuMatrix, true); // send updated matrix to the gpu
template <class M, class X, class Y, int l>
void
CuDILU<M, X, Y, l>::computeDiagAndMoveReorderedData()
{
OPM_TIMEBLOCK(prec_update);
{
if (m_splitMatrix) {
detail::copyMatDataToReorderedSplit<field_type, blocksize_>(
m_gpuMatrix.getNonZeroValues().data(),
m_gpuMatrix.getRowIndices().data(),
m_gpuMatrix.getColumnIndices().data(),
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedDiag->data(),
m_gpuNaturalToReorder.data(),
m_gpuMatrixReorderedLower->N());
} else {
detail::copyMatDataToReordered<field_type, blocksize_>(m_gpuMatrix.getNonZeroValues().data(),
m_gpuMatrix.getRowIndices().data(),
m_gpuMatrixReordered->getNonZeroValues().data(),
m_gpuMatrixReordered->getRowIndices().data(),
m_gpuNaturalToReorder.data(),
m_gpuMatrixReordered->N());
}
detail::copyMatDataToReordered<field_type, blocksize_>(m_gpuMatrix.getNonZeroValues().data(),
m_gpuMatrix.getRowIndices().data(),
m_gpuMatrixReordered.getNonZeroValues().data(),
m_gpuMatrixReordered.getRowIndices().data(),
m_gpuNaturalToReorder.data(),
m_gpuMatrixReordered.N());
int levelStartIdx = 0;
for (int level = 0; level < m_levelSets.size(); ++level) {
const int numOfRowsInLevel = m_levelSets[level].size();
detail::computeDiluDiagonal<field_type, blocksize_>(m_gpuMatrixReordered.getNonZeroValues().data(),
m_gpuMatrixReordered.getRowIndices().data(),
m_gpuMatrixReordered.getColumnIndices().data(),
m_gpuReorderToNatural.data(),
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data());
levelStartIdx += numOfRowsInLevel;
int levelStartIdx = 0;
for (int level = 0; level < m_levelSets.size(); ++level) {
const int numOfRowsInLevel = m_levelSets[level].size();
if (m_splitMatrix) {
detail::computeDiluDiagonalSplit<field_type, blocksize_>(
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedLower->getColumnIndices().data(),
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getColumnIndices().data(),
m_gpuMatrixReorderedDiag->data(),
m_gpuReorderToNatural.data(),
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data());
} else {
detail::computeDiluDiagonal<field_type, blocksize_>(m_gpuMatrixReordered->getNonZeroValues().data(),
m_gpuMatrixReordered->getRowIndices().data(),
m_gpuMatrixReordered->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data());
}
levelStartIdx += numOfRowsInLevel;
}
}
}
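The diagonal computed above follows the usual DILU recursion; read off cuComputeDiluDiagonalSplit as a sketch rather than a formal specification, for each row $i$ taken in level-set order:

$$D_{ii} = A_{ii} - \sum_{j < i,\; A_{ij} \neq 0} A_{ij}\, D_{jj}^{-1}\, A_{ji}, \qquad \mathrm{dInv}_i = D_{ii}^{-1},$$

where mmx2Subtraction performs the blockwise $A_{ij} D_{jj}^{-1} A_{ji}$ update against the symmetric-opposite block located in the upper matrix, and invBlockInPlace inverts the accumulated block (hence the blocksize <= 3 restriction in the wrapper).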

View File

@ -27,7 +27,9 @@
#include <opm/simulators/linalg/cuistl/detail/CuMatrixDescription.hpp>
#include <opm/simulators/linalg/cuistl/detail/CuSparseHandle.hpp>
#include <opm/simulators/linalg/cuistl/detail/CuSparseResource.hpp>
#include <optional>
#include <vector>
#include <memory>
@ -55,6 +57,8 @@ public:
using range_type = Y;
//! \brief The field type of the preconditioner.
using field_type = typename X::field_type;
//! \brief The GPU matrix type
using CuMat = CuSparseMatrix<field_type>;
//! \brief Constructor.
//!
@ -62,7 +66,7 @@ public:
//! \param A The matrix to operate on.
//! \param split_matrix If true, store the reordered matrix as separate strictly lower, strictly upper and diagonal parts.
//!
explicit CuDILU(const M& A);
explicit CuDILU(const M& A, bool split_matrix = true);
//! \brief Prepare the preconditioner.
//! \note Does nothing at the time being.
@ -81,6 +85,9 @@ public:
//! \brief Updates the matrix data.
void update() final;
//! \brief Compute the diagonal of the DILU, and update the data of the reordered matrix
void computeDiagAndMoveReorderedData();
//! \returns false
static constexpr bool shouldCallPre()
@ -107,14 +114,22 @@ private:
//! \brief converts from an index in the natural ordering to the corresponding index in the reordered structure
std::vector<int> m_naturalToReordered;
//! \brief The A matrix stored on the gpu, and its reordered version
CuSparseMatrix<field_type> m_gpuMatrix;
CuSparseMatrix<field_type> m_gpuMatrixReordered;
CuMat m_gpuMatrix;
//! \brief Stores the matrix in its entirety reordered; left empty when matrix splitting is used
std::unique_ptr<CuMat> m_gpuMatrixReordered;
//! \brief If matrix splitting is enabled, the strictly lower and upper parts are stored separately
std::unique_ptr<CuMat> m_gpuMatrixReorderedLower;
std::unique_ptr<CuMat> m_gpuMatrixReorderedUpper;
//! \brief If matrix splitting is enabled, we also store the diagonal separately
std::unique_ptr<CuVector<field_type>> m_gpuMatrixReorderedDiag;
//! row conversion from natural to reordered matrix indices stored on the GPU
CuVector<int> m_gpuNaturalToReorder;
//! row conversion from reordered to natural matrix indices stored on the GPU
CuVector<int> m_gpuReorderToNatural;
//! \brief Stores the inverted diagonal that we use in DILU
CuVector<field_type> m_gpuDInv;
//! \brief Bool storing whether or not we should store matrices in a split format
bool m_splitMatrix;
};
} // end namespace Opm::cuistl

View File

@ -27,6 +27,7 @@
#include <opm/simulators/linalg/cuistl/detail/cusparse_safe_call.hpp>
#include <opm/simulators/linalg/cuistl/detail/cusparse_wrapper.hpp>
#include <opm/simulators/linalg/matrixblock.hh>
#include <type_traits>
namespace Opm::cuistl
{
@ -97,7 +98,22 @@ CuSparseMatrix<T>::fromMatrix(const MatrixType& matrix, bool copyNonZeroElements
rowIndices.push_back(0);
const size_t blockSize = matrix[0][0].N();
// We must find a pointer to the first stored element in the matrix.
// Iterate until we find a row with an entry; the first rows may be empty.
// TODO: Can this be done more cleanly with the DUNE API to access the raw data more directly?
constexpr size_t blockSizeTmp = MatrixType::block_type::rows;
T* nonZeroElementsTmp = nullptr;
for (auto rowIt = matrix.begin(); rowIt != matrix.end(); ++rowIt){
auto colIt = rowIt->begin();
if (colIt != rowIt->end()){
nonZeroElementsTmp = const_cast<T*>(&((*colIt)[0][0]));
break;
}
}
OPM_ERROR_IF(nonZeroElementsTmp == nullptr, "error converting DUNE matrix to CuSparse matrix");
const size_t blockSize = blockSizeTmp;
const size_t numberOfRows = matrix.N();
const size_t numberOfNonzeroBlocks = matrix.nonzeroes();
@ -119,7 +135,7 @@ CuSparseMatrix<T>::fromMatrix(const MatrixType& matrix, bool copyNonZeroElements
if (copyNonZeroElementsDirectly) {
const T* nonZeroElements = static_cast<const T*>(&((matrix[0][0][0][0])));
const T* nonZeroElements = nonZeroElementsTmp;
return CuSparseMatrix<T>(
nonZeroElements, rowIndices.data(), columnIndices.data(), numberOfNonzeroBlocks, blockSize, numberOfRows);
} else {
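The first-element search in fromMatrix above replaces taking &matrix[0][0][0][0] directly, which presumes the (0,0) block is stored. A small sketch of a sparsity pattern where that assumption breaks; names here are illustrative only:

// Row 0 stores only column 1, so matrix[0][0] does not exist and cannot be
// used to locate the start of the nonzero data.
using M1 = Dune::BCRSMatrix<Dune::FieldMatrix<double, 1, 1>>;
M1 m(2, 2, 3, M1::row_wise);
for (auto row = m.createbegin(); row != m.createend(); ++row) {
    if (row.index() == 0) {
        row.insert(1); // (0, 1) only; no (0, 0) block
    } else {
        row.insert(0); // (1, 0)
        row.insert(1); // (1, 1)
    }
}
m[0][1] = 1.0;
m[1][0] = 2.0;
m[1][1] = 3.0;
// The row/column iteration in fromMatrix finds &m[0][1][0][0] as the first
// stored element instead.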

View File

@ -19,6 +19,7 @@
#include <opm/common/ErrorMacros.hpp>
#include <opm/simulators/linalg/cuistl/detail/cusparse_matrix_operations.hpp>
#include <stdexcept>
namespace Opm::cuistl::detail
{
namespace
@ -226,6 +227,40 @@ namespace
}
}
template <class T, int blocksize>
__global__ void cuComputeLowerSolveLevelSetSplit(T* mat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v)
{
const auto reorderedRowIdx = startIdx + (blockDim.x * blockIdx.x + threadIdx.x);
if (reorderedRowIdx < rowsInLevelSet + startIdx) {
const size_t nnzIdx = rowIndices[reorderedRowIdx];
const size_t nnzIdxLim = rowIndices[reorderedRowIdx + 1];
const int naturalRowIdx = indexConversion[reorderedRowIdx];
T rhs[blocksize];
for (int i = 0; i < blocksize; i++) {
rhs[i] = d[naturalRowIdx * blocksize + i];
}
// TODO: remove the first condition in the for loop
for (int block = nnzIdx; block < nnzIdxLim; ++block) {
const int col = colIndices[block];
mmv<T, blocksize>(&mat[block * blocksize * blocksize], &v[col * blocksize], rhs);
}
mv<T, blocksize>(&dInv[reorderedRowIdx * blocksize * blocksize], rhs, &v[naturalRowIdx * blocksize]);
}
}
template <class T, int blocksize>
__global__ void cuComputeUpperSolveLevelSet(T* mat,
int* rowIndices,
@ -242,7 +277,6 @@ namespace
const int naturalRowIdx = indexConversion[reorderedRowIdx];
T rhs[blocksize] = {0};
for (int block = nnzIdxLim - 1; colIndices[block] > naturalRowIdx; --block) {
const int col = colIndices[block];
umv<T, blocksize>(&mat[block * blocksize * blocksize], &v[col * blocksize], rhs);
@ -252,6 +286,32 @@ namespace
}
}
template <class T, int blocksize>
__global__ void cuComputeUpperSolveLevelSetSplit(T* mat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v)
{
const auto reorderedRowIdx = startIdx + (blockDim.x * blockIdx.x + threadIdx.x);
if (reorderedRowIdx < rowsInLevelSet + startIdx) {
const size_t nnzIdx = rowIndices[reorderedRowIdx];
const size_t nnzIdxLim = rowIndices[reorderedRowIdx + 1];
const int naturalRowIdx = indexConversion[reorderedRowIdx];
T rhs[blocksize] = {0};
for (int block = nnzIdx; block < nnzIdxLim; ++block) {
const int col = colIndices[block];
umv<T, blocksize>(&mat[block * blocksize * blocksize], &v[col * blocksize], rhs);
}
mmv<T, blocksize>(&dInv[reorderedRowIdx * blocksize * blocksize], rhs, &v[naturalRowIdx * blocksize]);
}
}
template <class T, int blocksize>
__global__ void cuComputeDiluDiagonal(T* mat,
int* rowIndices,
@ -319,6 +379,61 @@ namespace
}
}
template <class T, int blocksize>
__global__ void cuComputeDiluDiagonalSplit(T* reorderedLowerMat,
int* lowerRowIndices,
int* lowerColIndices,
T* reorderedUpperMat,
int* upperRowIndices,
int* upperColIndices,
T* diagonal,
int* reorderedToNatural,
int* naturalToReordered,
const int startIdx,
int rowsInLevelSet,
T* dInv)
{
const auto reorderedRowIdx = startIdx + blockDim.x * blockIdx.x + threadIdx.x;
if (reorderedRowIdx < rowsInLevelSet + startIdx) {
const int naturalRowIdx = reorderedToNatural[reorderedRowIdx];
const size_t lowerRowStart = lowerRowIndices[reorderedRowIdx];
const size_t lowerRowEnd = lowerRowIndices[reorderedRowIdx + 1];
T dInvTmp[blocksize * blocksize];
for (int i = 0; i < blocksize; ++i) {
for (int j = 0; j < blocksize; ++j) {
dInvTmp[i * blocksize + j] = diagonal[reorderedRowIdx * blocksize * blocksize + i * blocksize + j];
}
}
for (int block = lowerRowStart; block < lowerRowEnd; ++block) {
const int col = naturalToReordered[lowerColIndices[block]];
int symOppositeIdx = upperRowIndices[col];
for (; symOppositeIdx < upperRowIndices[col + 1]; ++symOppositeIdx) {
if (naturalRowIdx == upperColIndices[symOppositeIdx]) {
break;
}
}
const int symOppositeBlock = symOppositeIdx;
mmx2Subtraction<T, blocksize>(&reorderedLowerMat[block * blocksize * blocksize],
&dInv[col * blocksize * blocksize],
&reorderedUpperMat[symOppositeBlock * blocksize * blocksize],
dInvTmp);
}
invBlockInPlace<T, blocksize>(dInvTmp);
for (int i = 0; i < blocksize; ++i) {
for (int j = 0; j < blocksize; ++j) {
dInv[reorderedRowIdx * blocksize * blocksize + i * blocksize + j] = dInvTmp[i * blocksize + j];
}
}
}
}
template <class T, int blocksize>
__global__ void cuMoveDataToReordered(
T* srcMatrix, int* srcRowIndices, T* dstMatrix, int* dstRowIndices, int* indexConversion, size_t numberOfRows)
@ -341,6 +456,55 @@ namespace
}
}
template <class T, int blocksize>
__global__ void cuMoveDataToReorderedSplit(T* srcMatrix,
int* srcRowIndices,
int* srcColumnIndices,
T* dstLowerMatrix,
int* dstLowerRowIndices,
T* dstUpperMatrix,
int* dstUpperRowIndices,
T* dstDiag,
int* naturalToReordered,
size_t numberOfRows)
{
const auto srcRow = blockDim.x * blockIdx.x + threadIdx.x;
if (srcRow < numberOfRows) {
const auto dstRow = naturalToReordered[srcRow];
const auto rowStart = srcRowIndices[srcRow];
const auto rowEnd = srcRowIndices[srcRow + 1];
auto lowerBlock = dstLowerRowIndices[dstRow];
auto upperBlock = dstUpperRowIndices[dstRow];
for (int srcBlock = rowStart; srcBlock < rowEnd; srcBlock++) {
int dstBlock;
T* dstBuffer;
if (srcColumnIndices[srcBlock] < srcRow) { // we are writing a value to the lower triangular matrix
dstBlock = lowerBlock;
++lowerBlock;
dstBuffer = dstLowerMatrix;
} else if (srcColumnIndices[srcBlock]
> srcRow) { // we are writing a value to the upper triangular matrix
dstBlock = upperBlock;
++upperBlock;
dstBuffer = dstUpperMatrix;
} else { // we are writing a value to the diagonal
dstBlock = dstRow;
dstBuffer = dstDiag;
}
for (int i = 0; i < blocksize; ++i) {
for (int j = 0; j < blocksize; ++j) {
dstBuffer[dstBlock * blocksize * blocksize + i * blocksize + j]
= srcMatrix[srcBlock * blocksize * blocksize + i * blocksize + j];
}
}
}
}
}
constexpr inline size_t getThreads([[maybe_unused]] size_t numberOfRows)
{
return 1024;
@ -351,6 +515,22 @@ namespace
const auto threads = getThreads(numberOfRows);
return (numberOfRows + threads - 1) / threads;
}
// Kernel is the function object type of the CUDA kernel
template <class Kernel>
inline int getCudaRecomendedThreadBlockSize(Kernel k)
{
int blockSize;
int tmpGridSize;
cudaOccupancyMaxPotentialBlockSize(&tmpGridSize, &blockSize, k, 0, 0);
return blockSize;
}
inline int getNumberOfBlocks(int wantedThreads, int threadBlockSize)
{
return (wantedThreads + threadBlockSize - 1) / threadBlockSize;
}
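// Note: cudaOccupancyMaxPotentialBlockSize is the CUDA runtime occupancy API;
// for a given kernel it suggests the thread block size that maximizes
// theoretical occupancy (the trailing 0, 0 arguments request no dynamic
// shared memory and no upper limit on the block size). The wrappers below
// combine the two helpers roughly as:
//
//   int threadBlockSize = getCudaRecomendedThreadBlockSize(kernel<T, b>);
//   int nThreadBlocks = getNumberOfBlocks(workItems, threadBlockSize);
//   kernel<T, b><<<nThreadBlocks, threadBlockSize>>>(/* args */);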
} // namespace
template <class T, int blocksize>
@ -382,6 +562,24 @@ computeLowerSolveLevelSet(T* reorderedMat,
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, d, v);
}
template <class T, int blocksize>
void
computeLowerSolveLevelSetSplit(T* reorderedMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeLowerSolveLevelSetSplit<T, blocksize>);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeLowerSolveLevelSetSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, d, v);
}
// perform the upper solve for all rows in the same level set
template <class T, int blocksize>
void
@ -398,6 +596,23 @@ computeUpperSolveLevelSet(T* reorderedMat,
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, v);
}
template <class T, int blocksize>
void
computeUpperSolveLevelSetSplit(T* reorderedMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeUpperSolveLevelSetSplit<T, blocksize>);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeUpperSolveLevelSetSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, v);
}
template <class T, int blocksize>
void
computeDiluDiagonal(T* reorderedMat,
@ -424,6 +639,41 @@ computeDiluDiagonal(T* reorderedMat,
}
}
template <class T, int blocksize>
void
computeDiluDiagonalSplit(T* reorderedLowerMat,
int* lowerRowIndices,
int* lowerColIndices,
T* reorderedUpperMat,
int* upperRowIndices,
int* upperColIndices,
T* diagonal,
int* reorderedToNatural,
int* naturalToReordered,
const int startIdx,
int rowsInLevelSet,
T* dInv)
{
if (blocksize <= 3) {
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeDiluDiagonalSplit<T, blocksize>);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeDiluDiagonalSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(reorderedLowerMat,
lowerRowIndices,
lowerColIndices,
reorderedUpperMat,
upperRowIndices,
upperColIndices,
diagonal,
reorderedToNatural,
naturalToReordered,
startIdx,
rowsInLevelSet,
dInv);
} else {
OPM_THROW(std::invalid_argument, "Inverting diagonal is not implemented for blocksizes > 3");
}
}
template <class T, int blocksize>
void
copyMatDataToReordered(
@ -433,12 +683,44 @@ copyMatDataToReordered(
srcMatrix, srcRowIndices, dstMatrix, dstRowIndices, naturalToReordered, numberOfRows);
}
template <class T, int blocksize>
void
copyMatDataToReorderedSplit(T* srcMatrix,
int* srcRowIndices,
int* srcColumnIndices,
T* dstLowerMatrix,
int* dstLowerRowIndices,
T* dstUpperMatrix,
int* dstUpperRowIndices,
T* dstDiag,
int* naturalToReordered,
size_t numberOfRows)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuMoveDataToReorderedSplit<T, blocksize>);
int nThreadBlocks = getNumberOfBlocks(numberOfRows, threadBlockSize);
cuMoveDataToReorderedSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(srcMatrix,
srcRowIndices,
srcColumnIndices,
dstLowerMatrix,
dstLowerRowIndices,
dstUpperMatrix,
dstUpperRowIndices,
dstDiag,
naturalToReordered,
numberOfRows);
}
#define INSTANTIATE_KERNEL_WRAPPERS(T, blocksize) \
template void invertDiagonalAndFlatten<T, blocksize>(T*, int*, int*, size_t, T*); \
template void copyMatDataToReordered<T, blocksize>(T*, int*, T*, int*, int*, size_t); \
template void copyMatDataToReorderedSplit<T, blocksize>(T*, int*, int*, T*, int*, T*, int*, T*, int*, size_t); \
template void computeDiluDiagonal<T, blocksize>(T*, int*, int*, int*, int*, const int, int, T*); \
template void computeDiluDiagonalSplit<T, blocksize>( \
T*, int*, int*, T*, int*, int*, T*, int*, int*, const int, int, T*); \
template void computeUpperSolveLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*); \
template void computeLowerSolveLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*);
template void computeLowerSolveLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*); \
template void computeUpperSolveLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*); \
template void computeLowerSolveLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*);
INSTANTIATE_KERNEL_WRAPPERS(float, 1);
INSTANTIATE_KERNEL_WRAPPERS(float, 2);

View File

@ -61,6 +61,33 @@ void computeLowerSolveLevelSet(T* reorderedMat,
const T* d,
T* v);
/**
* @brief Perform a lower solve on certain rows in a matrix that can safely be computed in parallel
* @param reorderedLowerMat pointer to GPU memory containing nonzerovalues of the sparse matrix. The matrix is reordered
* such that rows in the same level set are contiguous. This matrix is assumed to be strictly lower triangular
* @param rowIndices Pointer to vector on GPU containing row indices compliant with the BSR format
* @param colIndices Pointer to vector on GPU containing col indices compliant with the BSR format
* @param indexConversion Integer array containing a mapping from an index in the reordered matrix to its corresponding
* index in the naturally ordered matrix
* @param startIdx Index of the first row of the matrix to be solved
* @param rowsInLevelSet Number of rows in this level set, which equals the number of rows solved in parallel by this
* function
* @param dInv The diagonal matrix used by the Diagonal ILU preconditioner. Must be reordered in the same way as
* reorderedLowerMat
* @param d Stores the defect
* @param [out] v Will store the results of the lower solve
*/
template <class T, int blocksize>
void computeLowerSolveLevelSetSplit(T* reorderedLowerMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v);
/**
* @brief Perform an upper solve on certain rows in a matrix that can safely be computed in parallel
* @param reorderedMat pointer to GPU memory containing nonzerovalues of the sparse matrix. The matrix reordered such
@ -85,6 +112,31 @@ void computeUpperSolveLevelSet(T* reorderedMat,
int rowsInLevelSet,
const T* dInv,
T* v);
/**
* @brief Perform an upper solve on certain rows in a matrix that can safely be computed in parallel
* @param reorderedUpperMat pointer to GPU memory containing nonzerovalues of the sparse matrix. The matrix is reordered
* such that rows in the same level set are contiguous. This matrix is assumed to be strictly upper triangular
* @param rowIndices Pointer to vector on GPU containing row indices compliant with the BSR format
* @param colIndices Pointer to vector on GPU containing col indices compliant with the BSR format
* @param indexConversion Integer array containing a mapping from an index in the reordered matrix to its corresponding
* index in the naturally ordered matrix
* @param startIdx Index of the first row of the matrix to be solved
* @param rowsInLevelSet Number of rows in this level set, which equals the number of rows solved in parallel by this
* function
* @param dInv The diagonal matrix used by the Diagonal ILU preconditioner
* @param [out] v Will store the results of the upper solve. On entry it should contain the output of the lower solve
*/
template <class T, int blocksize>
void computeUpperSolveLevelSetSplit(T* reorderedUpperMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v);
/**
* @brief Computes the ILU0 of the diagonal elements of the reordered matrix and stores it in a reordered vector
@ -111,6 +163,41 @@ void computeDiluDiagonal(T* reorderedMat,
int startIdx,
int rowsInLevelSet,
T* dInv);
/**
* @brief Computes the ILU0 of the diagonal elements of the split reordered matrix and stores it in a reordered vector
* containing the diagonal blocks
* @param reorderedLowerMat pointer to GPU memory containing nonzerovalues of the strictly lower triangular sparse
* matrix. The matrix is reordered such that rows in the same level set are contiguous
* @param lowerRowIndices Pointer to vector on GPU containing row indices of the lower matrix compliant with the BSR
* format
* @param lowerColIndices Pointer to vector on GPU containing col indices of the lower matrix compliant with the BSR
* format
* @param reorderedUpperMat pointer to GPU memory containing nonzerovalues of the strictly upper triangular sparse
* matrix. The matrix is reordered such that rows in the same level set are contiguous
* @param upperRowIndices Pointer to vector on GPU containing row indices of the upper matrix compliant with the BSR
* format
* @param upperColIndices Pointer to vector on GPU containing col indices of the upper matrix compliant with the BSR
* format
* @param diagonal The diagonal elements of the reordered matrix
* @param reorderedToNatural Integer array containing a mapping from an index in the reordered matrix to its
* corresponding index in the naturally ordered matrix
* @param naturalToReordered Integer array containing a mapping from an index in the naturally ordered matrix to its
* corresponding index in the reordered matrix
* @param startIdx Index of the first row of the matrix to be solved
* @param rowsInLevelSet Number of rows in this level set, which equals the number of rows solved in parallel by this
* function
* @param [out] dInv The diagonal matrix used by the Diagonal ILU preconditioner
*/
template <class T, int blocksize>
void computeDiluDiagonalSplit(T* reorderedLowerMat,
int* lowerRowIndices,
int* lowerColIndices,
T* reorderedUpperMat,
int* upperRowIndices,
int* upperColIndices,
T* diagonal,
int* reorderedToNatural,
int* naturalToReordered,
int startIdx,
int rowsInLevelSet,
T* dInv);
/**
* @brief Reorders the elements of a matrix by copying them from one matrix to another using a permutation list
@ -125,5 +212,24 @@ void computeDiluDiagonal(T* reorderedMat,
template <class T, int blocksize>
void copyMatDataToReordered(
T* srcMatrix, int* srcRowIndices, T* dstMatrix, int* dstRowIndices, int* naturalToReordered, size_t numberOfRows);
/**
* @brief Reorders the elements of a matrix by copying them from one matrix to a split matrix using a permutation list
* @param srcMatrix The source matrix we will copy data from
* @param srcRowIndices Pointer to vector on GPU containing row indices for the source matrix compliant with the BSR
* format
* @param srcColumnIndices Pointer to vector on GPU containing column indices for the source matrix compliant with the
* BSR format
* @param [out] dstLowerMatrix The destination of entries that originate from the strictly lower triangular matrix
* @param dstLowerRowIndices Pointer to vector on GPU containing row indices for the destination lower matrix compliant
* with the BSR format
* @param [out] dstUpperMatrix The destination of entries that originate from the strictly upper triangular matrix
* @param dstUpperRowIndices Pointer to vector on GPU containing row indices for the destination upper matrix compliant
* with the BSR format
* @param [out] dstDiag The destination buffer for the diagonal part of the matrix
* @param naturalToReordered Permutation list that converts indices in the src matrix to the indices in the dst matrix
* @param numberOfRows The number of rows in the matrices
*/
template <class T, int blocksize>
void copyMatDataToReorderedSplit(T* srcMatrix,
int* srcRowIndices,
int* srcColumnIndices,
T* dstLowerMatrix,
int* dstLowerRowIndices,
T* dstUpperMatrix,
int* dstUpperRowIndices,
T* dstDiag,
int* naturalToReordered,
size_t numberOfRows);
} // namespace Opm::cuistl::detail
#endif

View File

@ -0,0 +1,332 @@
/*
Copyright 2024 SINTEF AS
This file is part of the Open Porous Media project (OPM).
OPM is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
OPM is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with OPM. If not, see <http://www.gnu.org/licenses/>.
*/
#include <config.h>
#define BOOST_TEST_MODULE TestCuDiluHelpers
#include <boost/test/unit_test.hpp>
#include <dune/common/fmatrix.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <memory>
#include <opm/simulators/linalg/DILU.hpp>
#include <opm/simulators/linalg/cuistl/CuDILU.hpp>
#include <opm/simulators/linalg/cuistl/CuSparseMatrix.hpp>
#include <opm/simulators/linalg/cuistl/CuVector.hpp>
#include <opm/simulators/linalg/cuistl/detail/cuda_safe_call.hpp>
#include <opm/simulators/linalg/cuistl/detail/cusparse_matrix_operations.hpp>
#include <random>
#include <vector>
using T = double;
using FM1x1 = Dune::FieldMatrix<T, 1, 1>;
using FM2x2 = Dune::FieldMatrix<T, 2, 2>;
using B1x1Vec = Dune::BlockVector<Dune::FieldVector<double, 1>>;
using B2x2Vec = Dune::BlockVector<Dune::FieldVector<double, 2>>;
using Sp1x1BlockMatrix = Dune::BCRSMatrix<FM1x1>;
using Sp2x2BlockMatrix = Dune::BCRSMatrix<FM2x2>;
using CuMatrix = Opm::cuistl::CuSparseMatrix<T>;
using CuIntVec = Opm::cuistl::CuVector<int>;
using CuFloatingPointVec = Opm::cuistl::CuVector<T>;
using CuDilu1x1 = Opm::cuistl::CuDILU<Sp1x1BlockMatrix, CuFloatingPointVec, CuFloatingPointVec>;
using CuDilu2x2 = Opm::cuistl::CuDILU<Sp2x2BlockMatrix, CuFloatingPointVec, CuFloatingPointVec>;
Sp1x1BlockMatrix
get1x1BlockTestMatrix()
{
/*
matA:
1 2 0 3 0 0
4 5 0 6 0 7
0 0 8 0 0 0
9 10 0 11 12 0
0 0 0 13 14 0
0 15 0 0 0 16
Expected reordering:
1 2 0 3 0 0
0 0 8 0 0 0
4 5 0 6 0 7
9 10 0 11 12 0
0 15 0 0 0 16
0 0 0 13 14 0
Expected lowerTriangularReorderedMatrix:
0 0 0 0 0 0
0 0 0 0 0 0
4 0 0 0 0 0
9 10 0 0 0 0
0 15 0 0 0 0
0 0 0 13 0 0
Expected upperTriangularReorderedMatrix:
0 2 0 3 0 0
0 0 0 0 0 0
0 0 0 6 0 7
0 0 0 0 12 0
0 0 0 0 0 0
0 0 0 0 0 0
*/
const int N = 6;
const int nonZeroes = 16;
// Create the Dune A matrix
Sp1x1BlockMatrix matA(N, N, nonZeroes, Sp1x1BlockMatrix::row_wise);
for (auto row = matA.createbegin(); row != matA.createend(); ++row) {
row.insert(row.index());
if (row.index() == 0) {
row.insert(row.index() + 1);
row.insert(row.index() + 3);
}
if (row.index() == 1) {
row.insert(row.index() - 1);
row.insert(row.index() + 2);
row.insert(row.index() + 4);
}
if (row.index() == 2) {
}
if (row.index() == 3) {
row.insert(row.index() - 3);
row.insert(row.index() - 2);
row.insert(row.index() + 1);
}
if (row.index() == 4) {
row.insert(row.index() - 1);
}
if (row.index() == 5) {
row.insert(row.index() - 4);
}
}
matA[0][0][0][0] = 1.0;
matA[0][1][0][0] = 2.0;
matA[0][3][0][0] = 3.0;
matA[1][0][0][0] = 4.0;
matA[1][1][0][0] = 5.0;
matA[1][3][0][0] = 6.0;
matA[1][5][0][0] = 7.0;
matA[2][2][0][0] = 8.0;
matA[3][0][0][0] = 9.0;
matA[3][1][0][0] = 10.0;
matA[3][3][0][0] = 11.0;
matA[3][4][0][0] = 12.0;
matA[4][3][0][0] = 13.0;
matA[4][4][0][0] = 14.0;
matA[5][1][0][0] = 15.0;
matA[5][5][0][0] = 16.0;
return matA;
}
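// Note: with the lower-triangular coloring CuDILU uses, the level sets of
// this matrix are {0, 2}, {1}, {3, 5}, {4}; rows in the same level set have
// no lower-triangular dependencies on each other and can be processed in
// parallel. Concatenating the sets gives the row order 0, 2, 1, 3, 5, 4,
// which is exactly the "Expected reordering" in the comment above.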
Sp2x2BlockMatrix
get2x2BlockTestMatrix()
{
/*
matA:
1 2 0 3 0 0
4 5 0 6 0 7
0 0 1 0 0 0
9 10 0 1 12 0
0 0 0 13 14 0
0 15 0 0 0 16
*/
const int N = 3;
const int nonZeroes = 9;
// Create the Dune A matrix
Sp2x2BlockMatrix matA(N, N, nonZeroes, Sp2x2BlockMatrix::row_wise);
for (auto row = matA.createbegin(); row != matA.createend(); ++row) {
row.insert(row.index());
if (row.index() == 0) {
row.insert(row.index() + 1);
row.insert(row.index() + 2);
}
if (row.index() == 1) {
row.insert(row.index() - 1);
row.insert(row.index() + 1);
}
if (row.index() == 2) {
row.insert(row.index() - 1);
row.insert(row.index() - 2);
}
}
matA[0][0][0][0] = 1.0;
matA[0][0][0][1] = 2.0;
matA[0][0][1][0] = 4.0;
matA[0][0][1][1] = 5.0;
matA[0][1][0][1] = 3.0;
matA[0][1][1][1] = 6.0;
matA[0][2][1][1] = 7.0;
matA[1][0][1][0] = 9.0;
matA[1][0][1][1] = 10.0;
matA[1][1][0][0] = 1.0;
matA[1][1][1][1] = 1.0;
matA[1][2][1][0] = 12.0;
matA[2][0][1][1] = 15.0;
matA[2][1][0][1] = 13.0;
matA[2][2][0][0] = 14.0;
matA[2][2][1][1] = 16.0;
return matA;
}
BOOST_AUTO_TEST_CASE(TestDiluApply)
{
Sp1x1BlockMatrix matA = get1x1BlockTestMatrix();
std::vector<double> input = {1.1, 1.2, 1.3, 1.4, 1.5, 1.6};
std::vector<double> output(6);
CuFloatingPointVec d_input(input);
CuFloatingPointVec d_output(output);
B1x1Vec h_input(6);
h_input[0] = 1.1;
h_input[1] = 1.2;
h_input[2] = 1.3;
h_input[3] = 1.4;
h_input[4] = 1.5;
h_input[5] = 1.6;
B1x1Vec h_output(6);
// Initialize preconditioner objects
Dune::MultithreadDILU<Sp1x1BlockMatrix, B1x1Vec, B1x1Vec> cpudilu(matA);
auto gpudilu = CuDilu1x1(matA);
// Apply both preconditioners
gpudilu.apply(d_output, d_input);
cpudilu.apply(h_output, h_input);
// put results in std::vector
std::vector<T> cpudilures;
for (auto e : h_output) {
cpudilures.push_back(e);
}
auto cudilures = d_output.asStdVector();
// check that the CuDILU results match those of the CPU DILU
for (size_t i = 0; i < cudilures.size(); ++i) {
BOOST_CHECK_CLOSE(cudilures[i], cpudilures[i], 1e-7);
}
}
BOOST_AUTO_TEST_CASE(TestDiluApplyBlocked)
{
// init matrix with 2x2 blocks
Sp2x2BlockMatrix matA = get2x2BlockTestMatrix();
auto gpudilu = CuDilu2x2(matA);
Dune::MultithreadDILU<Sp2x2BlockMatrix, B2x2Vec, B2x2Vec> cpudilu(matA);
// create input/output buffers for the apply
std::vector<double> input = {1.1, 1.2, 1.3, 1.4, 1.5, 1.6};
std::vector<double> output(6);
CuFloatingPointVec d_input(input);
CuFloatingPointVec d_output(output);
B2x2Vec h_input(3);
h_input[0][0] = 1.1;
h_input[0][1] = 1.2;
h_input[1][0] = 1.3;
h_input[1][1] = 1.4;
h_input[2][0] = 1.5;
h_input[2][1] = 1.6;
B2x2Vec h_output(3);
// call apply with cpu and gpu dilu
cpudilu.apply(h_output, h_input);
gpudilu.apply(d_output, d_input);
auto cudilures = d_output.asStdVector();
std::vector<T> cpudilures;
for (auto v : h_output) {
for (auto e : v) {
cpudilures.push_back(e);
}
}
// check that the values are close
for (size_t i = 0; i < cudilures.size(); ++i) {
BOOST_CHECK_CLOSE(cudilures[i], cpudilures[i], 1e-7);
}
}
BOOST_AUTO_TEST_CASE(TestDiluInitAndUpdateLarge)
{
// create gpu dilu preconditioner
Sp1x1BlockMatrix matA = get1x1BlockTestMatrix();
auto gpudilu = CuDilu1x1(matA);
matA[0][0][0][0] = 11.0;
matA[0][1][0][0] = 12.0;
matA[0][3][0][0] = 13.0;
matA[1][0][0][0] = 14.0;
matA[1][1][0][0] = 15.0;
matA[1][3][0][0] = 16.0;
matA[1][5][0][0] = 17.0;
matA[2][2][0][0] = 18.0;
matA[3][0][0][0] = 19.0;
matA[3][1][0][0] = 110.0;
matA[3][3][0][0] = 111.0;
matA[3][4][0][0] = 112.0;
matA[4][3][0][0] = 113.0;
matA[4][4][0][0] = 114.0;
matA[5][1][0][0] = 115.0;
matA[5][5][0][0] = 116.0;
// make sure the preconditioner is updated
gpudilu.update();
// create a CPU DILU preconditioner from the now-updated matrix
Dune::MultithreadDILU<Sp1x1BlockMatrix, B1x1Vec, B1x1Vec> cpudilu(matA);
std::vector<double> input = {1.1, 1.2, 1.3, 1.4, 1.5, 1.6};
std::vector<double> output(6);
CuFloatingPointVec d_input(input);
CuFloatingPointVec d_output(output);
B1x1Vec h_input(6);
h_input[0] = 1.1;
h_input[1] = 1.2;
h_input[2] = 1.3;
h_input[3] = 1.4;
h_input[4] = 1.5;
h_input[5] = 1.6;
B1x1Vec h_output(6);
// run an apply to see effect of update
gpudilu.apply(d_output, d_input);
cpudilu.apply(h_output, h_input);
// put results in std::vector
std::vector<T> cpudilures;
for (auto e : h_output) {
cpudilures.push_back(e);
}
auto cudilures = d_output.asStdVector();
// check that the CuDILU results match those of the CPU DILU
for (size_t i = 0; i < cudilures.size(); ++i) {
BOOST_CHECK_CLOSE(cudilures[i], cpudilures[i], 1e-7);
}
}