Merge pull request #5674 from multitalentloes/add_gpudilu_mixed_precision

Add gpudilu mixed precision
This commit is contained in:
Kjetil Olsen Lye
2025-01-07 14:57:17 +01:00
committed by GitHub
7 changed files with 258 additions and 115 deletions

View File

@@ -357,9 +357,10 @@ struct StandardPreconditioners {
F::addCreator("GPUDILU", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t, const C& comm) {
const bool split_matrix = prm.get<bool>("split_matrix", true);
const bool tune_gpu_kernels = prm.get<bool>("tune_gpu_kernels", true);
const bool store_factorization_as_float = prm.get<bool>("store_factorization_as_float", false);
using field_type = typename V::field_type;
using GpuDILU = typename gpuistl::GpuDILU<M, gpuistl::GpuVector<field_type>, gpuistl::GpuVector<field_type>>;
auto gpuDILU = std::make_shared<GpuDILU>(op.getmat(), split_matrix, tune_gpu_kernels);
auto gpuDILU = std::make_shared<GpuDILU>(op.getmat(), split_matrix, tune_gpu_kernels, store_factorization_as_float);
auto adapted = std::make_shared<gpuistl::PreconditionerAdapter<V, V, GpuDILU>>(gpuDILU);
auto wrapped = std::make_shared<gpuistl::GpuBlockPreconditioner<V, V, Comm>>(adapted, comm);
@@ -660,14 +661,16 @@ struct StandardPreconditioners<Operator, Dune::Amg::SequentialInformation> {
F::addCreator("GPUDILU", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t) {
const bool split_matrix = prm.get<bool>("split_matrix", true);
const bool tune_gpu_kernels = prm.get<bool>("tune_gpu_kernels", true);
const bool store_factorization_as_float = prm.get<bool>("store_factorization_as_float", false);
using field_type = typename V::field_type;
using GPUDILU = typename gpuistl::GpuDILU<M, gpuistl::GpuVector<field_type>, gpuistl::GpuVector<field_type>>;
return std::make_shared<gpuistl::PreconditionerAdapter<V, V, GPUDILU>>(std::make_shared<GPUDILU>(op.getmat(), split_matrix, tune_gpu_kernels));
return std::make_shared<gpuistl::PreconditionerAdapter<V, V, GPUDILU>>(std::make_shared<GPUDILU>(op.getmat(), split_matrix, tune_gpu_kernels, store_factorization_as_float));
});
F::addCreator("GPUDILUFloat", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t) {
const bool split_matrix = prm.get<bool>("split_matrix", true);
const bool tune_gpu_kernels = prm.get<bool>("tune_gpu_kernels", true);
const bool store_factorization_as_float = prm.get<bool>("store_factorization_as_float", false);
using block_type = typename V::block_type;
using VTo = Dune::BlockVector<Dune::FieldVector<float, block_type::dimension>>;
@@ -676,7 +679,7 @@ struct StandardPreconditioners<Operator, Dune::Amg::SequentialInformation> {
using Adapter = typename gpuistl::PreconditionerAdapter<VTo, VTo, GpuDILU>;
using Converter = typename gpuistl::PreconditionerConvertFieldTypeAdapter<Adapter, M, V, V>;
auto converted = std::make_shared<Converter>(op.getmat());
auto adapted = std::make_shared<Adapter>(std::make_shared<GpuDILU>(converted->getConvertedMatrix(), split_matrix, tune_gpu_kernels));
auto adapted = std::make_shared<Adapter>(std::make_shared<GpuDILU>(converted->getConvertedMatrix(), split_matrix, tune_gpu_kernels, store_factorization_as_float));
converted->setUnderlyingPreconditioner(adapted);
return converted;
});

View File

@@ -41,7 +41,7 @@ namespace Opm::gpuistl
{
template <class M, class X, class Y, int l>
GpuDILU<M, X, Y, l>::GpuDILU(const M& A, bool splitMatrix, bool tuneKernels)
GpuDILU<M, X, Y, l>::GpuDILU(const M& A, bool splitMatrix, bool tuneKernels, bool storeFactorizationAsFloat)
: m_cpuMatrix(A)
, m_levelSets(Opm::getMatrixRowColoring(m_cpuMatrix, Opm::ColoringType::LOWER))
, m_reorderedToNatural(detail::createReorderedToNatural(m_levelSets))
@@ -52,6 +52,7 @@ GpuDILU<M, X, Y, l>::GpuDILU(const M& A, bool splitMatrix, bool tuneKernels)
, m_gpuDInv(m_gpuMatrix.N() * m_gpuMatrix.blockSize() * m_gpuMatrix.blockSize())
, m_splitMatrix(splitMatrix)
, m_tuneThreadBlockSizes(tuneKernels)
, m_storeFactorizationAsFloat(storeFactorizationAsFloat)
{
// TODO: Should in some way verify that this matrix is symmetric, only do it debug mode?
@@ -80,6 +81,17 @@ GpuDILU<M, X, Y, l>::GpuDILU(const M& A, bool splitMatrix, bool tuneKernels)
m_gpuMatrixReordered = detail::createReorderedMatrix<M, field_type, GpuSparseMatrix<field_type>>(
m_cpuMatrix, m_reorderedToNatural);
}
if (m_storeFactorizationAsFloat) {
if (!m_splitMatrix){
OPM_THROW(std::runtime_error, "Matrix must be split when storing as float.");
}
m_gpuMatrixReorderedLowerFloat = std::make_unique<FloatMat>(m_gpuMatrixReorderedLower->getRowIndices(), m_gpuMatrixReorderedLower->getColumnIndices(), blocksize_);
m_gpuMatrixReorderedUpperFloat = std::make_unique<FloatMat>(m_gpuMatrixReorderedUpper->getRowIndices(), m_gpuMatrixReorderedUpper->getColumnIndices(), blocksize_);
m_gpuMatrixReorderedDiagFloat = std::make_unique<FloatVec>(m_gpuMatrix.N() * m_gpuMatrix.blockSize() * m_gpuMatrix.blockSize());
m_gpuDInvFloat = std::make_unique<FloatVec>(m_gpuMatrix.N() * m_gpuMatrix.blockSize() * m_gpuMatrix.blockSize());
}
computeDiagAndMoveReorderedData(m_moveThreadBlockSize, m_DILUFactorizationThreadBlockSize);
if (m_tuneThreadBlockSizes) {
@@ -111,17 +123,31 @@ GpuDILU<M, X, Y, l>::apply(X& v, const Y& d, int lowerSolveThreadBlockSize, int
for (int level = 0; level < m_levelSets.size(); ++level) {
const int numOfRowsInLevel = m_levelSets[level].size();
if (m_splitMatrix) {
detail::DILU::solveLowerLevelSetSplit<field_type, blocksize_>(
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedLower->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
d.data(),
v.data(),
lowerSolveThreadBlockSize);
if (m_storeFactorizationAsFloat) {
detail::DILU::solveLowerLevelSetSplit<blocksize_, field_type, float>(
m_gpuMatrixReorderedLowerFloat->getNonZeroValues().data(),
m_gpuMatrixReorderedLowerFloat->getRowIndices().data(),
m_gpuMatrixReorderedLowerFloat->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInvFloat->data(),
d.data(),
v.data(),
lowerSolveThreadBlockSize);
} else {
detail::DILU::solveLowerLevelSetSplit<blocksize_, field_type, field_type>(
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedLower->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
d.data(),
v.data(),
lowerSolveThreadBlockSize);
}
} else {
detail::DILU::solveLowerLevelSet<field_type, blocksize_>(
m_gpuMatrixReordered->getNonZeroValues().data(),
@@ -144,16 +170,29 @@ GpuDILU<M, X, Y, l>::apply(X& v, const Y& d, int lowerSolveThreadBlockSize, int
const int numOfRowsInLevel = m_levelSets[level].size();
levelStartIdx -= numOfRowsInLevel;
if (m_splitMatrix) {
detail::DILU::solveUpperLevelSetSplit<field_type, blocksize_>(
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
v.data(),
upperSolveThreadBlockSize);
if (m_storeFactorizationAsFloat){
detail::DILU::solveUpperLevelSetSplit<blocksize_, field_type, float>(
m_gpuMatrixReorderedUpperFloat->getNonZeroValues().data(),
m_gpuMatrixReorderedUpperFloat->getRowIndices().data(),
m_gpuMatrixReorderedUpperFloat->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInvFloat->data(),
v.data(),
upperSolveThreadBlockSize);
} else {
detail::DILU::solveUpperLevelSetSplit<blocksize_, field_type, field_type>(
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getColumnIndices().data(),
m_gpuReorderToNatural.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
v.data(),
upperSolveThreadBlockSize);
}
} else {
detail::DILU::solveUpperLevelSet<field_type, blocksize_>(
m_gpuMatrixReordered->getNonZeroValues().data(),
@@ -232,20 +271,44 @@ GpuDILU<M, X, Y, l>::computeDiagAndMoveReorderedData(int moveThreadBlockSize, in
for (int level = 0; level < m_levelSets.size(); ++level) {
const int numOfRowsInLevel = m_levelSets[level].size();
if (m_splitMatrix) {
detail::DILU::computeDiluDiagonalSplit<field_type, blocksize_>(
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedLower->getColumnIndices().data(),
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getColumnIndices().data(),
m_gpuMatrixReorderedDiag->data(),
m_gpuReorderToNatural.data(),
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
factorizationBlockSize);
if (m_storeFactorizationAsFloat) {
detail::DILU::computeDiluDiagonalSplit<blocksize_, field_type, float, true>(
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedLower->getColumnIndices().data(),
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getColumnIndices().data(),
m_gpuMatrixReorderedDiag->data(),
m_gpuReorderToNatural.data(),
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
m_gpuDInvFloat->data(),
m_gpuMatrixReorderedLowerFloat->getNonZeroValues().data(),
m_gpuMatrixReorderedUpperFloat->getNonZeroValues().data(),
factorizationBlockSize);
} else {
// TODO: should this be field type twice or field type then float in the template?
detail::DILU::computeDiluDiagonalSplit<blocksize_, field_type, float, false>(
m_gpuMatrixReorderedLower->getNonZeroValues().data(),
m_gpuMatrixReorderedLower->getRowIndices().data(),
m_gpuMatrixReorderedLower->getColumnIndices().data(),
m_gpuMatrixReorderedUpper->getNonZeroValues().data(),
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedUpper->getColumnIndices().data(),
m_gpuMatrixReorderedDiag->data(),
m_gpuReorderToNatural.data(),
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
nullptr,
nullptr,
nullptr,
factorizationBlockSize);
}
} else {
detail::DILU::computeDiluDiagonal<field_type, blocksize_>(
m_gpuMatrixReordered->getNonZeroValues().data(),

View File

@@ -53,6 +53,8 @@ public:
using field_type = typename X::field_type;
//! \brief The GPU matrix type
using CuMat = GpuSparseMatrix<field_type>;
using FloatMat = GpuSparseMatrix<float>;
using FloatVec = GpuVector<float>;
//! \brief Constructor.
//!
@@ -60,7 +62,7 @@ public:
//! \param A The matrix to operate on.
//! \param w The relaxation factor.
//!
explicit GpuDILU(const M& A, bool splitMatrix, bool tuneKernels);
explicit GpuDILU(const M& A, bool splitMatrix, bool tuneKernels, bool storeFactorizationAsFloat);
//! \brief Prepare the preconditioner.
//! \note Does nothing at the time being.
@@ -127,6 +129,11 @@ private:
std::unique_ptr<CuMat> m_gpuMatrixReorderedUpper;
//! \brief If matrix splitting is enabled, we also store the diagonal separately
std::unique_ptr<GpuVector<field_type>> m_gpuMatrixReorderedDiag;
//! \brief If mixed precision is enabled, store a float matrix
std::unique_ptr<FloatMat> m_gpuMatrixReorderedLowerFloat;
std::unique_ptr<FloatMat> m_gpuMatrixReorderedUpperFloat;
std::unique_ptr<FloatVec> m_gpuMatrixReorderedDiagFloat;
std::unique_ptr<FloatVec> m_gpuDInvFloat;
//! row conversion from natural to reordered matrix indices stored on the GPU
GpuVector<int> m_gpuNaturalToReorder;
//! row conversion from reordered to natural matrix indices stored on the GPU
@@ -137,6 +144,8 @@ private:
bool m_splitMatrix;
//! \brief Bool storing whether or not we will tune the threadblock sizes. Only used for AMD cards
bool m_tuneThreadBlockSizes;
//! \brief Bool storing whether or not we will store the factorization as float. Only used for mixed precision
bool m_storeFactorizationAsFloat;
//! \brief variables storing the threadblocksizes to use if using the tuned sizes and AMD cards
//! The default value of -1 indicates that we have not calibrated and selected a value yet
int m_upperSolveThreadBlockSize = -1;

View File

@@ -159,7 +159,7 @@ mmv(const T* a, const T* b, T* c)
// dst -= A*B*C
template <class T, int blocksize>
__device__ __forceinline__ void
mmx2Subtraction(T* A, T* B, T* C, T* dst)
mmx2Subtraction(const T* A, const T* B, const T* C, T* dst)
{
T tmp[blocksize * blocksize] = {0};
@@ -256,6 +256,19 @@ mvMixedGeneral(const MatrixScalar* A, const VectorScalar* b, ResultScalar* c)
}
}
// TODO: consider merging with existing block operations
// mixed precision general version of c += Ab
template <int blocksize, class MatrixScalar, class VectorScalar, class ResultScalar, class ComputeScalar>
__device__ __forceinline__ void
umvMixedGeneral(const MatrixScalar* A, const VectorScalar* b, ResultScalar* c)
{
for (int i = 0; i < blocksize; ++i) {
for (int j = 0; j < blocksize; ++j) {
c[i] += ResultScalar(ComputeScalar(A[i * blocksize + j]) * ComputeScalar(b[j]));
}
}
}
// TODO: consider merging with existing block operations
// Mixed precision general version of c -= Ab
template <int blocksize, class MatrixScalar, class VectorScalar, class ResultScalar, class ComputeScalar>

View File

@@ -59,16 +59,16 @@ namespace
}
}
template <class T, int blocksize>
__global__ void cuSolveLowerLevelSetSplit(T* mat,
template <int blocksize, class LinearSolverScalar, class MatrixScalar>
__global__ void cuSolveLowerLevelSetSplit(MatrixScalar* mat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v)
const MatrixScalar* dInv,
const LinearSolverScalar* d,
LinearSolverScalar* v)
{
const auto reorderedRowIdx = startIdx + (blockDim.x * blockIdx.x + threadIdx.x);
if (reorderedRowIdx < rowsInLevelSet + startIdx) {
@@ -77,7 +77,7 @@ namespace
const size_t nnzIdxLim = rowIndices[reorderedRowIdx + 1];
const int naturalRowIdx = indexConversion[reorderedRowIdx];
T rhs[blocksize];
LinearSolverScalar rhs[blocksize];
for (int i = 0; i < blocksize; i++) {
rhs[i] = d[naturalRowIdx * blocksize + i];
}
@@ -85,10 +85,10 @@ namespace
// TODO: removce the first condition in the for loop
for (int block = nnzIdx; block < nnzIdxLim; ++block) {
const int col = colIndices[block];
mmv<T, blocksize>(&mat[block * blocksize * blocksize], &v[col * blocksize], rhs);
mmvMixedGeneral<blocksize, MatrixScalar, LinearSolverScalar, LinearSolverScalar, LinearSolverScalar>(&mat[block * blocksize * blocksize], &v[col * blocksize], rhs);
}
mv<T, blocksize>(&dInv[reorderedRowIdx * blocksize * blocksize], rhs, &v[naturalRowIdx * blocksize]);
mvMixedGeneral<blocksize, MatrixScalar, LinearSolverScalar, LinearSolverScalar, LinearSolverScalar>(&dInv[reorderedRowIdx * blocksize * blocksize], rhs, &v[naturalRowIdx * blocksize]);
}
}
@@ -118,15 +118,15 @@ namespace
}
}
template <class T, int blocksize>
__global__ void cuSolveUpperLevelSetSplit(T* mat,
template <int blocksize, class LinearSolverScalar, class MatrixScalar>
__global__ void cuSolveUpperLevelSetSplit(MatrixScalar* mat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v)
const MatrixScalar* dInv,
LinearSolverScalar* v)
{
const auto reorderedRowIdx = startIdx + (blockDim.x * blockIdx.x + threadIdx.x);
if (reorderedRowIdx < rowsInLevelSet + startIdx) {
@@ -134,13 +134,13 @@ namespace
const size_t nnzIdxLim = rowIndices[reorderedRowIdx + 1];
const int naturalRowIdx = indexConversion[reorderedRowIdx];
T rhs[blocksize] = {0};
LinearSolverScalar rhs[blocksize] = {0};
for (int block = nnzIdx; block < nnzIdxLim; ++block) {
const int col = colIndices[block];
umv<T, blocksize>(&mat[block * blocksize * blocksize], &v[col * blocksize], rhs);
umvMixedGeneral<blocksize, MatrixScalar, LinearSolverScalar, LinearSolverScalar, LinearSolverScalar>(&mat[block * blocksize * blocksize], &v[col * blocksize], rhs);
}
mmv<T, blocksize>(&dInv[reorderedRowIdx * blocksize * blocksize], rhs, &v[naturalRowIdx * blocksize]);
mmvMixedGeneral<blocksize, MatrixScalar, LinearSolverScalar, LinearSolverScalar, LinearSolverScalar>(&dInv[reorderedRowIdx * blocksize * blocksize], rhs, &v[naturalRowIdx * blocksize]);
}
}
@@ -211,19 +211,24 @@ namespace
}
}
template <class T, int blocksize>
__global__ void cuComputeDiluDiagonalSplit(T* reorderedLowerMat,
// TODO: rewrite such that during the factorization there is a dInv of InputScalar type that stores intermediate results
// TOOD: The important part is to only cast after that is fully computed
template <int blocksize, class InputScalar, class OutputScalar, bool copyResultToOtherMatrix>
__global__ void cuComputeDiluDiagonalSplit(const InputScalar* srcReorderedLowerMat,
int* lowerRowIndices,
int* lowerColIndices,
T* reorderedUpperMat,
const InputScalar* srcReorderedUpperMat,
int* upperRowIndices,
int* upperColIndices,
T* diagonal,
const InputScalar* srcDiagonal,
int* reorderedToNatural,
int* naturalToReordered,
const int startIdx,
int rowsInLevelSet,
T* dInv)
InputScalar* dInv,
OutputScalar* dstDiag, // TODO: should this be diag or dInv?
OutputScalar* dstLowerMat,
OutputScalar* dstUpperMat)
{
const auto reorderedRowIdx = startIdx + blockDim.x * blockIdx.x + threadIdx.x;
if (reorderedRowIdx < rowsInLevelSet + startIdx) {
@@ -231,10 +236,10 @@ namespace
const size_t lowerRowStart = lowerRowIndices[reorderedRowIdx];
const size_t lowerRowEnd = lowerRowIndices[reorderedRowIdx + 1];
T dInvTmp[blocksize * blocksize];
InputScalar dInvTmp[blocksize * blocksize];
for (int i = 0; i < blocksize; ++i) {
for (int j = 0; j < blocksize; ++j) {
dInvTmp[i * blocksize + j] = diagonal[reorderedRowIdx * blocksize * blocksize + i * blocksize + j];
dInvTmp[i * blocksize + j] = srcDiagonal[reorderedRowIdx * blocksize * blocksize + i * blocksize + j];
}
}
@@ -250,18 +255,28 @@ namespace
const int symOppositeBlock = symOppositeIdx;
mmx2Subtraction<T, blocksize>(&reorderedLowerMat[block * blocksize * blocksize],
if constexpr (copyResultToOtherMatrix) {
// TODO: think long and hard about whether this performs only the wanted memory transfers
moveBlock<blocksize, InputScalar, OutputScalar>(&srcReorderedLowerMat[block * blocksize * blocksize], &dstLowerMat[block * blocksize * blocksize]);
moveBlock<blocksize, InputScalar, OutputScalar>(&srcReorderedUpperMat[symOppositeBlock * blocksize * blocksize], &dstUpperMat[symOppositeBlock * blocksize * blocksize]);
}
mmx2Subtraction<InputScalar, blocksize>(&srcReorderedLowerMat[block * blocksize * blocksize],
&dInv[col * blocksize * blocksize],
&reorderedUpperMat[symOppositeBlock * blocksize * blocksize],
&srcReorderedUpperMat[symOppositeBlock * blocksize * blocksize],
dInvTmp);
}
invBlockInPlace<T, blocksize>(dInvTmp);
invBlockInPlace<InputScalar, blocksize>(dInvTmp);
for (int i = 0; i < blocksize; ++i) {
for (int j = 0; j < blocksize; ++j) {
dInv[reorderedRowIdx * blocksize * blocksize + i * blocksize + j] = dInvTmp[i * blocksize + j];
}
// for (int i = 0; i < blocksize; ++i) {
// for (int j = 0; j < blocksize; ++j) {
// dInv[reorderedRowIdx * blocksize * blocksize + i * blocksize + j] = dInvTmp[i * blocksize + j];
// }
// }
moveBlock<blocksize, InputScalar, InputScalar>(dInvTmp, &dInv[reorderedRowIdx * blocksize * blocksize]);
if constexpr (copyResultToOtherMatrix) {
moveBlock<blocksize, InputScalar, OutputScalar>(dInvTmp, &dstDiag[reorderedRowIdx * blocksize * blocksize]); // important!
}
}
}
@@ -289,23 +304,23 @@ solveLowerLevelSet(T* reorderedMat,
}
template <class T, int blocksize>
template <int blocksize, class LinearSolverScalar, class MatrixScalar>
void
solveLowerLevelSetSplit(T* reorderedMat,
solveLowerLevelSetSplit(MatrixScalar* reorderedMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v,
const MatrixScalar* dInv,
const LinearSolverScalar* d,
LinearSolverScalar* v,
int thrBlockSize)
{
int threadBlockSize = ::Opm::gpuistl::detail::getCudaRecomendedThreadBlockSize(
cuSolveLowerLevelSetSplit<T, blocksize>, thrBlockSize);
cuSolveLowerLevelSetSplit<blocksize, LinearSolverScalar, MatrixScalar>, thrBlockSize);
int nThreadBlocks = ::Opm::gpuistl::detail::getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuSolveLowerLevelSetSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
cuSolveLowerLevelSetSplit<blocksize, LinearSolverScalar, MatrixScalar><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, d, v);
}
// perform the upper solve for all rows in the same level set
@@ -328,22 +343,22 @@ solveUpperLevelSet(T* reorderedMat,
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, v);
}
template <class T, int blocksize>
template <int blocksize, class LinearSolverScalar, class MatrixScalar>
void
solveUpperLevelSetSplit(T* reorderedMat,
solveUpperLevelSetSplit(MatrixScalar* reorderedMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v,
const MatrixScalar* dInv,
LinearSolverScalar* v,
int thrBlockSize)
{
int threadBlockSize = ::Opm::gpuistl::detail::getCudaRecomendedThreadBlockSize(
cuSolveUpperLevelSetSplit<T, blocksize>, thrBlockSize);
cuSolveUpperLevelSetSplit<blocksize, LinearSolverScalar, MatrixScalar>, thrBlockSize);
int nThreadBlocks = ::Opm::gpuistl::detail::getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuSolveUpperLevelSetSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
cuSolveUpperLevelSetSplit<blocksize, LinearSolverScalar, MatrixScalar><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, v);
}
@@ -376,51 +391,62 @@ computeDiluDiagonal(T* reorderedMat,
}
}
template <class T, int blocksize>
template <int blocksize, class InputScalar, class OutputScalar, bool copyResultToOtherMatrix>
void
computeDiluDiagonalSplit(T* reorderedLowerMat,
computeDiluDiagonalSplit(const InputScalar* srcReorderedLowerMat,
int* lowerRowIndices,
int* lowerColIndices,
T* reorderedUpperMat,
const InputScalar* srcReorderedUpperMat,
int* upperRowIndices,
int* upperColIndices,
T* diagonal,
const InputScalar* srcDiagonal,
int* reorderedToNatural,
int* naturalToReordered,
const int startIdx,
int rowsInLevelSet,
T* dInv,
InputScalar* dInv,
OutputScalar* dstDiag,
OutputScalar* dstLowerMat,
OutputScalar* dstUpperMat,
int thrBlockSize)
{
if (blocksize <= 3) {
int threadBlockSize = ::Opm::gpuistl::detail::getCudaRecomendedThreadBlockSize(
cuComputeDiluDiagonalSplit<T, blocksize>, thrBlockSize);
cuComputeDiluDiagonalSplit<blocksize, InputScalar, OutputScalar, copyResultToOtherMatrix>, thrBlockSize);
int nThreadBlocks = ::Opm::gpuistl::detail::getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeDiluDiagonalSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(reorderedLowerMat,
cuComputeDiluDiagonalSplit<blocksize, InputScalar, OutputScalar, copyResultToOtherMatrix><<<nThreadBlocks, threadBlockSize>>>(srcReorderedLowerMat,
lowerRowIndices,
lowerColIndices,
reorderedUpperMat,
srcReorderedUpperMat,
upperRowIndices,
upperColIndices,
diagonal,
srcDiagonal,
reorderedToNatural,
naturalToReordered,
startIdx,
rowsInLevelSet,
dInv);
dInv,
dstDiag,
dstLowerMat,
dstUpperMat);
} else {
OPM_THROW(std::invalid_argument, "Inverting diagonal is not implemented for blocksizes > 3");
}
}
// TODO: format
#define INSTANTIATE_KERNEL_WRAPPERS(T, blocksize) \
template void computeDiluDiagonal<T, blocksize>(T*, int*, int*, int*, int*, const int, int, T*, int); \
template void computeDiluDiagonalSplit<T, blocksize>( \
T*, int*, int*, T*, int*, int*, T*, int*, int*, const int, int, T*, int); \
template void computeDiluDiagonalSplit<blocksize, T, double, false>( \
const T*, int*, int*, const T*, int*, int*, const T*, int*, int*, const int, int, T*, double*, double*, double*, int); \
template void computeDiluDiagonalSplit<blocksize, T, float, false>( \
const T*, int*, int*, const T*, int*, int*, const T*, int*, int*, const int, int, T*, float*, float*, float*, int); \
template void computeDiluDiagonalSplit<blocksize, T, float, true>( \
const T*, int*, int*, const T*, int*, int*, const T*, int*, int*, const int, int, T*, float*, float*, float*, int); \
template void computeDiluDiagonalSplit<blocksize, T, double, true>( \
const T*, int*, int*, const T*, int*, int*, const T*, int*, int*, const int, int, T*, double*, double*, double*, int); \
template void solveUpperLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*, int); \
template void solveLowerLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*, int); \
template void solveUpperLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*, int); \
template void solveLowerLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*, int);
template void solveLowerLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*, int);
INSTANTIATE_KERNEL_WRAPPERS(float, 1);
INSTANTIATE_KERNEL_WRAPPERS(float, 2);
@@ -434,4 +460,30 @@ INSTANTIATE_KERNEL_WRAPPERS(double, 3);
INSTANTIATE_KERNEL_WRAPPERS(double, 4);
INSTANTIATE_KERNEL_WRAPPERS(double, 5);
INSTANTIATE_KERNEL_WRAPPERS(double, 6);
#define INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(blocksize, LinearSolverScalar, MatrixScalar) \
template void solveUpperLevelSetSplit<blocksize, LinearSolverScalar, MatrixScalar>( \
MatrixScalar*, int*, int*, int*, int, int, const MatrixScalar*, LinearSolverScalar*, int); \
template void solveLowerLevelSetSplit<blocksize, LinearSolverScalar, MatrixScalar>( \
MatrixScalar*, int*, int*, int*, int, int, const MatrixScalar*, const LinearSolverScalar*, LinearSolverScalar*, int);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(1, float, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(2, float, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(3, float, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(4, float, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(5, float, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(6, float, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(1, double, double);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(2, double, double);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(3, double, double);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(4, double, double);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(5, double, double);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(6, double, double);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(1, double, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(2, double, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(3, double, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(4, double, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(5, double, float);
INSTANTIATE_SOLVE_LEVEL_SET_SPLIT(6, double, float);
} // namespace Opm::gpuistl::detail::DILU

View File

@@ -71,16 +71,16 @@ void solveLowerLevelSet(T* reorderedMat,
* @param d Stores the defect
* @param [out] v Will store the results of the lower solve
*/
template <class T, int blocksize>
void solveLowerLevelSetSplit(T* reorderedUpperMat,
template <int blocksize, class LinearSolverScalar, class MatrixScalar>
void solveLowerLevelSetSplit(MatrixScalar* reorderedUpperMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v,
const MatrixScalar* dInv,
const LinearSolverScalar* d,
LinearSolverScalar* v,
int threadBlockSize);
/**
@@ -124,15 +124,15 @@ void solveUpperLevelSet(T* reorderedMat,
* @param [out] v Will store the results of the lower solve. To begin with it should store the output from the lower
* solve
*/
template <class T, int blocksize>
void solveUpperLevelSetSplit(T* reorderedUpperMat,
template <int blocksize, class LinearSolverScalar, class MatrixScalar>
void solveUpperLevelSetSplit(MatrixScalar* reorderedUpperMat,
int* rowIndices,
int* colIndices,
int* indexConversion,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v,
const MatrixScalar* dInv,
LinearSolverScalar* v,
int threadBlockSize);
/**
@@ -161,7 +161,6 @@ void computeDiluDiagonal(T* reorderedMat,
int rowsInLevelSet,
T* dInv,
int threadBlockSize);
template <class T, int blocksize>
/**
* @brief Computes the ILU0 of the diagonal elements of the split reordered matrix and stores it in a reordered vector
@@ -184,18 +183,22 @@ template <class T, int blocksize>
* function
* @param [out] dInv The diagonal matrix used by the Diagonal ILU preconditioner
*/
void computeDiluDiagonalSplit(T* reorderedLowerMat,
template <int blocksize, class InputScalar, class OutputScalar, bool copyResultToOtherMatrix>
void computeDiluDiagonalSplit(const InputScalar* srcReorderedLowerMat,
int* lowerRowIndices,
int* lowerColIndices,
T* reorderedUpperMat,
const InputScalar* srcReorderedUpperMat,
int* upperRowIndices,
int* upperColIndices,
T* diagonal,
const InputScalar* srcDiagonal,
int* reorderedToNatural,
int* naturalToReordered,
int startIdx,
int rowsInLevelSet,
T* dInv,
InputScalar* dInv,
OutputScalar* dstDiagonal,
OutputScalar* dstLowerMat,
OutputScalar* dstUpperMat,
int threadBlockSize);
} // namespace Opm::gpuistl::detail::DILU

View File

@@ -211,7 +211,7 @@ BOOST_AUTO_TEST_CASE(TestDiluApply)
// Initialize preconditioner objects
Dune::MultithreadDILU<Sp1x1BlockMatrix, B1x1Vec, B1x1Vec> cpudilu(matA);
auto gpudilu = GpuDilu1x1(matA, true, true);
auto gpudilu = GpuDilu1x1(matA, true, true, false);
// Use the apply
gpudilu.apply(d_output, d_input);
@@ -235,7 +235,7 @@ BOOST_AUTO_TEST_CASE(TestDiluApplyBlocked)
// init matrix with 2x2 blocks
Sp2x2BlockMatrix matA = get2x2BlockTestMatrix();
auto gpudilu = GpuDilu2x2(matA, true, true);
auto gpudilu = GpuDilu2x2(matA, true, true, false);
Dune::MultithreadDILU<Sp2x2BlockMatrix, B2x2Vec, B2x2Vec> cpudilu(matA);
// create input/output buffers for the apply
@@ -275,7 +275,7 @@ BOOST_AUTO_TEST_CASE(TestDiluInitAndUpdateLarge)
{
// create gpu dilu preconditioner
Sp1x1BlockMatrix matA = get1x1BlockTestMatrix();
auto gpudilu = GpuDilu1x1(matA, true, true);
auto gpudilu = GpuDilu1x1(matA, true, true, false);
matA[0][0][0][0] = 11.0;
matA[0][1][0][0] = 12.0;