Merge pull request #5433 from multitalentloes/useRecomendedBlockSize

Autotune thread block size
Kjetil Olsen Lye 2024-07-08 14:17:39 +02:00 committed by GitHub
commit f3b5e0d14d
7 changed files with 227 additions and 75 deletions

View File

@ -347,10 +347,12 @@ struct StandardPreconditioners {
});
F::addCreator("CUDILU", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t, const C& comm) {
const bool split_matrix = prm.get<double>("split_matrix", true);
const bool split_matrix = prm.get<bool>("split_matrix", true);
const bool tune_gpu_kernels = prm.get<bool>("tune_gpu_kernels", true);
// const bool tune_gpu_kernels = prm.get<bool>("tune_gpu_kernels", true);
using field_type = typename V::field_type;
using CuDILU = typename cuistl::CuDILU<M, cuistl::CuVector<field_type>, cuistl::CuVector<field_type>>;
auto cuDILU = std::make_shared<CuDILU>(op.getmat(), split_matrix);
auto cuDILU = std::make_shared<CuDILU>(op.getmat(), split_matrix, tune_gpu_kernels);
auto adapted = std::make_shared<cuistl::PreconditionerAdapter<V, V, CuDILU>>(cuDILU);
auto wrapped = std::make_shared<cuistl::CuBlockPreconditioner<V, V, Comm>>(adapted, comm);
@ -605,12 +607,15 @@ struct StandardPreconditioners<Operator, Dune::Amg::SequentialInformation> {
F::addCreator("CUDILU", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t) {
const bool split_matrix = prm.get<bool>("split_matrix", true);
const bool tune_gpu_kernels = prm.get<bool>("tune_gpu_kernels", true);
using field_type = typename V::field_type;
using CUDILU = typename cuistl::CuDILU<M, cuistl::CuVector<field_type>, cuistl::CuVector<field_type>>;
return std::make_shared<cuistl::PreconditionerAdapter<V, V, CUDILU>>(std::make_shared<CUDILU>(op.getmat(), split_matrix));
return std::make_shared<cuistl::PreconditionerAdapter<V, V, CUDILU>>(std::make_shared<CUDILU>(op.getmat(), split_matrix, tune_gpu_kernels));
});
F::addCreator("CUDILUFloat", [](const O& op, [[maybe_unused]] const P& prm, const std::function<V()>&, std::size_t) {
const bool split_matrix = prm.get<bool>("split_matrix", true);
const bool tune_gpu_kernels = prm.get<bool>("tune_gpu_kernels", true);
using block_type = typename V::block_type;
using VTo = Dune::BlockVector<Dune::FieldVector<float, block_type::dimension>>;
using matrix_type_to = typename Dune::BCRSMatrix<Dune::FieldMatrix<float, block_type::dimension, block_type::dimension>>;
@ -618,7 +623,7 @@ struct StandardPreconditioners<Operator, Dune::Amg::SequentialInformation> {
using Adapter = typename cuistl::PreconditionerAdapter<VTo, VTo, CuDILU>;
using Converter = typename cuistl::PreconditionerConvertFieldTypeAdapter<Adapter, M, V, V>;
auto converted = std::make_shared<Converter>(op.getmat());
auto adapted = std::make_shared<Adapter>(std::make_shared<CuDILU>(converted->getConvertedMatrix()));
auto adapted = std::make_shared<Adapter>(std::make_shared<CuDILU>(converted->getConvertedMatrix(), split_matrix, tune_gpu_kernels));
converted->setUnderlyingPreconditioner(adapted);
return converted;
});
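// Not part of this diff: a minimal sketch of how the two new options are supplied to the factory.
// The creators above only need a property object exposing get<bool>(key, default); a plain
// boost::property_tree::ptree is used here as a stand-in for the factory's property type P, the
// helper name makeCuDiluOptions is hypothetical, and the surrounding solver setup is omitted.
#include <boost/property_tree/ptree.hpp>
inline boost::property_tree::ptree makeCuDiluOptions(bool splitMatrix, bool tuneGpuKernels)
{
    boost::property_tree::ptree prm;
    prm.put("split_matrix", splitMatrix);        // store the matrix in split (lower/diagonal/upper) form
    prm.put("tune_gpu_kernels", tuneGpuKernels); // false skips tuning and keeps the occupancy-based default
    return prm;
}
// e.g. auto prm = makeCuDiluOptions(true, false);
// prm.get<bool>("tune_gpu_kernels", true) then yields false in the creators above.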

View File

@ -30,11 +30,16 @@
#include <opm/simulators/linalg/cuistl/detail/safe_conversion.hpp>
#include <opm/simulators/linalg/matrixblock.hh>
#include <vector>
#include <config.h>
#include <chrono>
#include <limits>
#include <tuple>
namespace
{
std::vector<int>
createReorderedToNatural(Opm::SparseTable<size_t> levelSets)
createReorderedToNatural(Opm::SparseTable<size_t>& levelSets)
{
auto res = std::vector<int>(Opm::cuistl::detail::to_size_t(levelSets.dataSize()));
int globCnt = 0;
@ -49,7 +54,7 @@ createReorderedToNatural(Opm::SparseTable<size_t> levelSets)
}
std::vector<int>
createNaturalToReordered(Opm::SparseTable<size_t> levelSets)
createNaturalToReordered(Opm::SparseTable<size_t>& levelSets)
{
auto res = std::vector<int>(Opm::cuistl::detail::to_size_t(levelSets.dataSize()));
int globCnt = 0;
@ -66,7 +71,7 @@ createNaturalToReordered(Opm::SparseTable<size_t> levelSets)
template <class M, class field_type, class GPUM>
void
createReorderedMatrix(const M& naturalMatrix,
std::vector<int> reorderedToNatural,
std::vector<int>& reorderedToNatural,
std::unique_ptr<GPUM>& reorderedGpuMat)
{
M reorderedMatrix(naturalMatrix.N(), naturalMatrix.N(), naturalMatrix.nonzeroes(), M::row_wise);
@ -84,7 +89,7 @@ createReorderedMatrix(const M& naturalMatrix,
template <class M, class field_type, class GPUM>
void
extractLowerAndUpperMatrices(const M& naturalMatrix,
std::vector<int> reorderedToNatural,
std::vector<int>& reorderedToNatural,
std::unique_ptr<GPUM>& lower,
std::unique_ptr<GPUM>& upper)
{
@ -119,7 +124,7 @@ namespace Opm::cuistl
{
template <class M, class X, class Y, int l>
CuDILU<M, X, Y, l>::CuDILU(const M& A, bool split_matrix)
CuDILU<M, X, Y, l>::CuDILU(const M& A, bool splitMatrix, bool tuneKernels)
: m_cpuMatrix(A)
, m_levelSets(Opm::getMatrixRowColoring(m_cpuMatrix, Opm::ColoringType::LOWER))
, m_reorderedToNatural(createReorderedToNatural(m_levelSets))
@ -128,7 +133,8 @@ CuDILU<M, X, Y, l>::CuDILU(const M& A, bool split_matrix)
, m_gpuNaturalToReorder(m_naturalToReordered)
, m_gpuReorderToNatural(m_reorderedToNatural)
, m_gpuDInv(m_gpuMatrix.N() * m_gpuMatrix.blockSize() * m_gpuMatrix.blockSize())
, m_splitMatrix(split_matrix)
, m_splitMatrix(splitMatrix)
, m_tuneThreadBlockSizes(tuneKernels)
{
// TODO: Should in some way verify that this matrix is symmetric, and only do it in debug mode?
@ -156,6 +162,14 @@ CuDILU<M, X, Y, l>::CuDILU(const M& A, bool split_matrix)
m_cpuMatrix, m_reorderedToNatural, m_gpuMatrixReordered);
}
computeDiagAndMoveReorderedData();
// HIP does not currently pick thread block sizes automatically as well as CUDA does,
// so we only do our own manual tuning when tuning is requested and HIP is used.
#ifdef USE_HIP
if (m_tuneThreadBlockSizes){
tuneThreadBlockSizes();
}
#endif
}
template <class M, class X, class Y, int l>
@ -183,7 +197,8 @@ CuDILU<M, X, Y, l>::apply(X& v, const Y& d)
numOfRowsInLevel,
m_gpuDInv.data(),
d.data(),
v.data());
v.data(),
m_applyThreadBlockSize);
} else {
detail::computeLowerSolveLevelSet<field_type, blocksize_>(
m_gpuMatrixReordered->getNonZeroValues().data(),
@ -194,7 +209,8 @@ CuDILU<M, X, Y, l>::apply(X& v, const Y& d)
numOfRowsInLevel,
m_gpuDInv.data(),
d.data(),
v.data());
v.data(),
m_applyThreadBlockSize);
}
levelStartIdx += numOfRowsInLevel;
}
@ -213,7 +229,8 @@ CuDILU<M, X, Y, l>::apply(X& v, const Y& d)
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
v.data());
v.data(),
m_applyThreadBlockSize);
} else {
detail::computeUpperSolveLevelSet<field_type, blocksize_>(
m_gpuMatrixReordered->getNonZeroValues().data(),
@ -223,7 +240,8 @@ CuDILU<M, X, Y, l>::apply(X& v, const Y& d)
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data(),
v.data());
v.data(),
m_applyThreadBlockSize);
}
}
}
@ -270,14 +288,16 @@ CuDILU<M, X, Y, l>::computeDiagAndMoveReorderedData()
m_gpuMatrixReorderedUpper->getRowIndices().data(),
m_gpuMatrixReorderedDiag->data(),
m_gpuNaturalToReorder.data(),
m_gpuMatrixReorderedLower->N());
m_gpuMatrixReorderedLower->N(),
m_updateThreadBlockSize);
} else {
detail::copyMatDataToReordered<field_type, blocksize_>(m_gpuMatrix.getNonZeroValues().data(),
m_gpuMatrix.getRowIndices().data(),
m_gpuMatrixReordered->getNonZeroValues().data(),
m_gpuMatrixReordered->getRowIndices().data(),
m_gpuNaturalToReorder.data(),
m_gpuMatrixReordered->N());
m_gpuMatrixReordered->N(),
m_updateThreadBlockSize);
}
int levelStartIdx = 0;
@ -296,7 +316,8 @@ CuDILU<M, X, Y, l>::computeDiagAndMoveReorderedData()
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data());
m_gpuDInv.data(),
m_updateThreadBlockSize);
} else {
detail::computeDiluDiagonal<field_type, blocksize_>(m_gpuMatrixReordered->getNonZeroValues().data(),
m_gpuMatrixReordered->getRowIndices().data(),
@ -305,13 +326,66 @@ CuDILU<M, X, Y, l>::computeDiagAndMoveReorderedData()
m_gpuNaturalToReorder.data(),
levelStartIdx,
numOfRowsInLevel,
m_gpuDInv.data());
m_gpuDInv.data(),
m_updateThreadBlockSize);
}
levelStartIdx += numOfRowsInLevel;
}
}
}
template <class M, class X, class Y, int l>
void
CuDILU<M, X, Y, l>::tuneThreadBlockSizes()
{
// TODO: generalize this code and put it somewhere outside of this class
long long bestApplyTime = std::numeric_limits<long long>::max();
long long bestUpdateTime = std::numeric_limits<long long>::max();
int bestApplyBlockSize = -1;
int bestUpdateBlockSize = -1;
int interval = 64;
// Temporary buffers for the apply
CuVector<field_type> tmpV(m_gpuMatrix.N() * m_gpuMatrix.blockSize());
CuVector<field_type> tmpD(m_gpuMatrix.N() * m_gpuMatrix.blockSize());
tmpD = 1;
for (int thrBlockSize = interval; thrBlockSize <= 1024; thrBlockSize += interval){
// Sometimes the first kernel launch can be slower, so take the time twice
for (int i = 0; i < 2; ++i){
auto beforeUpdate = std::chrono::high_resolution_clock::now();
m_updateThreadBlockSize = thrBlockSize;
update();
std::ignore = cudaDeviceSynchronize();
auto afterUpdate = std::chrono::high_resolution_clock::now();
if (cudaSuccess == cudaGetLastError()){ // kernel launch was valid
long long durationInMicroSec = std::chrono::duration_cast<std::chrono::microseconds>(afterUpdate - beforeUpdate).count();
if (durationInMicroSec < bestUpdateTime){
bestUpdateTime = durationInMicroSec;
bestUpdateBlockSize = thrBlockSize;
}
}
auto beforeApply = std::chrono::high_resolution_clock::now();
m_applyThreadBlockSize = thrBlockSize;
apply(tmpV, tmpD);
std::ignore = cudaDeviceSynchronize();
auto afterApply = std::chrono::high_resolution_clock::now();
if (cudaSuccess == cudaGetLastError()){ // kernel launch was valid
long long durationInMicroSec = std::chrono::duration_cast<std::chrono::microseconds>(afterApply - beforeApply).count();
if (durationInMicroSec < bestApplyTime){
bestApplyTime = durationInMicroSec;
bestApplyBlockSize = thrBlockSize;
}
}
}
}
m_applyThreadBlockSize = bestApplyBlockSize;
m_updateThreadBlockSize = bestUpdateBlockSize;
}
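// Not part of this commit: one way to address the TODO above and generalize the sweep into a
// free helper. The name and interface below are hypothetical; the callable is expected to set
// the block size it is given and launch the kernels to be timed, mirroring the loop above, and
// the headers used (<chrono>, <limits>, <tuple>) are already included in this file.
template <class LaunchAndTime>
int tuneThreadBlockSize(LaunchAndTime&& run, int interval = 64, int maxBlockSize = 1024)
{
    long long bestTime = std::numeric_limits<long long>::max();
    int bestBlockSize = -1;
    for (int thrBlockSize = interval; thrBlockSize <= maxBlockSize; thrBlockSize += interval) {
        // The first launch with a given size can be slower, so time it twice and keep the best.
        for (int i = 0; i < 2; ++i) {
            const auto before = std::chrono::high_resolution_clock::now();
            run(thrBlockSize);
            std::ignore = cudaDeviceSynchronize();
            const auto after = std::chrono::high_resolution_clock::now();
            if (cudaSuccess == cudaGetLastError()) { // only keep timings from valid launches
                const long long durationInMicroSec
                    = std::chrono::duration_cast<std::chrono::microseconds>(after - before).count();
                if (durationInMicroSec < bestTime) {
                    bestTime = durationInMicroSec;
                    bestBlockSize = thrBlockSize;
                }
            }
        }
    }
    return bestBlockSize; // -1 means no candidate size produced a valid launch
}
// Usage sketch: m_updateThreadBlockSize
//     = tuneThreadBlockSize([&](int size) { m_updateThreadBlockSize = size; update(); });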
} // namespace Opm::cuistl
#define INSTANTIATE_CUDILU_DUNE(realtype, blockdim) \
template class ::Opm::cuistl::CuDILU<Dune::BCRSMatrix<Dune::FieldMatrix<realtype, blockdim, blockdim>>, \

View File

@ -66,7 +66,7 @@ public:
//! \param A The matrix to operate on.
//! \param splitMatrix If true, store the matrix in a split format (separate lower, diagonal and upper parts).
//! \param tuneKernels If true, experimentally tune the thread block sizes of the GPU kernels.
//!
explicit CuDILU(const M& A, bool split_matrix = true);
explicit CuDILU(const M& A, bool splitMatrix, bool tuneKernels);
//! \brief Prepare the preconditioner.
//! \note Does nothing at the time being.
@ -88,6 +88,9 @@ public:
//! \brief Compute the diagonal of the DILU, and update the data of the reordered matrix
void computeDiagAndMoveReorderedData();
//! \brief Function that will experimentally tune the thread block sizes of the important CUDA kernels
void tuneThreadBlockSizes();
//! \returns false
static constexpr bool shouldCallPre()
@ -130,6 +133,12 @@ private:
CuVector<field_type> m_gpuDInv;
//! \brief Bool storing whether or not we should store matrices in a split format
bool m_splitMatrix;
//! \brief Bool storing whether or not we will tune the thread block sizes. Only used on AMD cards
bool m_tuneThreadBlockSizes;
//! \brief Thread block sizes to use when the tuned sizes are enabled on AMD cards
//! The default value of -1 indicates that we have not yet calibrated and selected a value
int m_applyThreadBlockSize = -1;
int m_updateThreadBlockSize = -1;
};
} // end namespace Opm::cuistl

View File

@ -19,6 +19,7 @@
#include <opm/common/ErrorMacros.hpp>
#include <opm/simulators/linalg/cuistl/detail/cusparse_matrix_operations.hpp>
#include <stdexcept>
#include <config.h>
namespace Opm::cuistl::detail
{
@ -505,25 +506,22 @@ namespace
}
}
constexpr inline size_t getThreads([[maybe_unused]] size_t numberOfRows)
{
return 1024;
}
inline size_t getBlocks(size_t numberOfRows)
{
const auto threads = getThreads(numberOfRows);
return (numberOfRows + threads - 1) / threads;
}
// Kernel here is the function object of the cuda kernel
template <class Kernel>
inline int getCudaRecomendedThreadBlockSize(Kernel k)
inline int getCudaRecomendedThreadBlockSize(Kernel k, int suggestedThrBlockSize=-1)
{
int blockSize;
if (suggestedThrBlockSize != -1){
return suggestedThrBlockSize;
}
// No tuned size was suggested: use the CUDA occupancy API to maximize occupancy, or just pick a fixed thread block size on HIP
#if USE_HIP
return 512;
#else
int blockSize = 1024;
int tmpGridSize;
cudaOccupancyMaxPotentialBlockSize(&tmpGridSize, &blockSize, k, 0, 0);
std::ignore = cudaOccupancyMaxPotentialBlockSize(&tmpGridSize, &blockSize, k, 0, 0);
return blockSize;
#endif
}
inline int getNumberOfBlocks(int wantedThreads, int threadBlockSize)
@ -538,8 +536,10 @@ void
invertDiagonalAndFlatten(T* mat, int* rowIndices, int* colIndices, size_t numberOfRows, T* vec)
{
if (blocksize <= 3) {
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuInvertDiagonalAndFlatten<T, blocksize>);
int nThreadBlocks = getNumberOfBlocks(numberOfRows, threadBlockSize);
cuInvertDiagonalAndFlatten<T, blocksize>
<<<getBlocks(numberOfRows), getThreads(numberOfRows)>>>(mat, rowIndices, colIndices, numberOfRows, vec);
<<<nThreadBlocks, threadBlockSize>>>(mat, rowIndices, colIndices, numberOfRows, vec);
} else {
OPM_THROW(std::invalid_argument, "Inverting diagonal is not implemented for blocksizes > 3");
}
@ -556,9 +556,12 @@ computeLowerSolveLevelSet(T* reorderedMat,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v)
T* v,
int thrBlockSize)
{
cuComputeLowerSolveLevelSet<T, blocksize><<<getBlocks(rowsInLevelSet), getThreads(rowsInLevelSet)>>>(
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeLowerSolveLevelSet<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeLowerSolveLevelSet<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, d, v);
}
@ -573,9 +576,10 @@ computeLowerSolveLevelSetSplit(T* reorderedMat,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v)
T* v,
int thrBlockSize)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeLowerSolveLevelSetSplit<T, blocksize>);
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeLowerSolveLevelSetSplit<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeLowerSolveLevelSetSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, d, v);
@ -590,9 +594,12 @@ computeUpperSolveLevelSet(T* reorderedMat,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v)
T* v,
int thrBlockSize)
{
cuComputeUpperSolveLevelSet<T, blocksize><<<getBlocks(rowsInLevelSet), getThreads(rowsInLevelSet)>>>(
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeUpperSolveLevelSet<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeUpperSolveLevelSet<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, v);
}
@ -605,9 +612,10 @@ computeUpperSolveLevelSetSplit(T* reorderedMat,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v)
T* v,
int thrBlockSize)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeLowerSolveLevelSetSplit<T, blocksize>);
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeUpperSolveLevelSetSplit<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeUpperSolveLevelSetSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
reorderedMat, rowIndices, colIndices, indexConversion, startIdx, rowsInLevelSet, dInv, v);
@ -622,11 +630,14 @@ computeDiluDiagonal(T* reorderedMat,
int* naturalToReordered,
const int startIdx,
int rowsInLevelSet,
T* dInv)
T* dInv,
int thrBlockSize)
{
if (blocksize <= 3) {
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeDiluDiagonal<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeDiluDiagonal<T, blocksize>
<<<getBlocks(rowsInLevelSet), getThreads(rowsInLevelSet)>>>(reorderedMat,
<<<nThreadBlocks, threadBlockSize>>>(reorderedMat,
rowIndices,
colIndices,
reorderedToNatural,
@ -652,10 +663,11 @@ computeDiluDiagonalSplit(T* reorderedLowerMat,
int* naturalToReordered,
const int startIdx,
int rowsInLevelSet,
T* dInv)
T* dInv,
int thrBlockSize)
{
if (blocksize <= 3) {
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeLowerSolveLevelSetSplit<T, blocksize>);
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeDiluDiagonalSplit<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(rowsInLevelSet, threadBlockSize);
cuComputeDiluDiagonalSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(reorderedLowerMat,
lowerRowIndices,
@ -677,9 +689,12 @@ computeDiluDiagonalSplit(T* reorderedLowerMat,
template <class T, int blocksize>
void
copyMatDataToReordered(
T* srcMatrix, int* srcRowIndices, T* dstMatrix, int* dstRowIndices, int* naturalToReordered, size_t numberOfRows)
T* srcMatrix, int* srcRowIndices, T* dstMatrix, int* dstRowIndices, int* naturalToReordered, size_t numberOfRows,
int thrBlockSize)
{
cuMoveDataToReordered<T, blocksize><<<getBlocks(numberOfRows), getThreads(numberOfRows)>>>(
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuMoveDataToReordered<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(numberOfRows, threadBlockSize);
cuMoveDataToReordered<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(
srcMatrix, srcRowIndices, dstMatrix, dstRowIndices, naturalToReordered, numberOfRows);
}
@ -694,9 +709,10 @@ copyMatDataToReorderedSplit(T* srcMatrix,
int* dstUpperRowIndices,
T* dstDiag,
int* naturalToReordered,
size_t numberOfRows)
size_t numberOfRows,
int thrBlockSize)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuComputeLowerSolveLevelSetSplit<T, blocksize>);
int threadBlockSize = getCudaRecomendedThreadBlockSize(cuMoveDataToReorderedSplit<T, blocksize>, thrBlockSize);
int nThreadBlocks = getNumberOfBlocks(numberOfRows, threadBlockSize);
cuMoveDataToReorderedSplit<T, blocksize><<<nThreadBlocks, threadBlockSize>>>(srcMatrix,
srcRowIndices,
@ -712,15 +728,15 @@ copyMatDataToReorderedSplit(T* srcMatrix,
#define INSTANTIATE_KERNEL_WRAPPERS(T, blocksize) \
template void invertDiagonalAndFlatten<T, blocksize>(T*, int*, int*, size_t, T*); \
template void copyMatDataToReordered<T, blocksize>(T*, int*, T*, int*, int*, size_t); \
template void copyMatDataToReorderedSplit<T, blocksize>(T*, int*, int*, T*, int*, T*, int*, T*, int*, size_t); \
template void computeDiluDiagonal<T, blocksize>(T*, int*, int*, int*, int*, const int, int, T*); \
template void copyMatDataToReordered<T, blocksize>(T*, int*, T*, int*, int*, size_t, int); \
template void copyMatDataToReorderedSplit<T, blocksize>(T*, int*, int*, T*, int*, T*, int*, T*, int*, size_t, int); \
template void computeDiluDiagonal<T, blocksize>(T*, int*, int*, int*, int*, const int, int, T*, int); \
template void computeDiluDiagonalSplit<T, blocksize>( \
T*, int*, int*, T*, int*, int*, T*, int*, int*, const int, int, T*); \
template void computeUpperSolveLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*); \
template void computeLowerSolveLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*); \
template void computeUpperSolveLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*); \
template void computeLowerSolveLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*);
T*, int*, int*, T*, int*, int*, T*, int*, int*, const int, int, T*, int); \
template void computeUpperSolveLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*, int); \
template void computeLowerSolveLevelSet<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*, int); \
template void computeUpperSolveLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, T*, int); \
template void computeLowerSolveLevelSetSplit<T, blocksize>(T*, int*, int*, int*, int, int, const T*, const T*, T*, int);
INSTANTIATE_KERNEL_WRAPPERS(float, 1);
INSTANTIATE_KERNEL_WRAPPERS(float, 2);

View File

@ -59,7 +59,8 @@ void computeLowerSolveLevelSet(T* reorderedMat,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v);
T* v,
int threadBlockSize);
/**
* @brief Perform a lower solve on certain rows in a matrix that can safely be computed in parallel
@ -86,7 +87,8 @@ void computeLowerSolveLevelSetSplit(T* reorderedUpperMat,
int rowsInLevelSet,
const T* dInv,
const T* d,
T* v);
T* v,
int threadBlockSize);
/**
* @brief Perform an upper solve on certain rows in a matrix that can safely be computed in parallel
@ -111,7 +113,8 @@ void computeUpperSolveLevelSet(T* reorderedMat,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v);
T* v,
int threadBlockSize);
template <class T, int blocksize>
/**
@ -136,7 +139,8 @@ void computeUpperSolveLevelSetSplit(T* reorderedUpperMat,
int startIdx,
int rowsInLevelSet,
const T* dInv,
T* v);
T* v,
int threadBlockSize);
/**
* @brief Computes the ILU0 of the diagonal elements of the reordered matrix and stores it in a reordered vector
@ -162,7 +166,8 @@ void computeDiluDiagonal(T* reorderedMat,
int* naturalToReordered,
int startIdx,
int rowsInLevelSet,
T* dInv);
T* dInv,
int threadBlockSize);
template <class T, int blocksize>
/**
@ -197,7 +202,8 @@ void computeDiluDiagonalSplit(T* reorderedLowerMat,
int* naturalToReordered,
int startIdx,
int rowsInLevelSet,
T* dInv);
T* dInv,
int threadBlockSize);
/**
* @brief Reorders the elements of a matrix by copying them from one matrix to another using a permutation list
@ -211,7 +217,7 @@ void computeDiluDiagonalSplit(T* reorderedLowerMat,
*/
template <class T, int blocksize>
void copyMatDataToReordered(
T* srcMatrix, int* srcRowIndices, T* dstMatrix, int* dstRowIndices, int* naturalToReordered, size_t numberOfRows);
T* srcMatrix, int* srcRowIndices, T* dstMatrix, int* dstRowIndices, int* naturalToReordered, size_t numberOfRows, int threadBlockSize);
/**
* @brief Reorders the elements of a matrix by copying them from one matrix to a split matrix using a permutation list
@ -229,7 +235,7 @@ void copyMatDataToReordered(
*/
template <class T, int blocksize>
void copyMatDataToReorderedSplit(
T* srcMatrix, int* srcRowIndices, int* srcColumnIndices, T* dstLowerMatrix, int* dstLowerRowIndices, T* dstUpperMatrix, int* dstUpperRowIndices, T* dstDiag, int* naturalToReordered, size_t numberOfRows);
T* srcMatrix, int* srcRowIndices, int* srcColumnIndices, T* dstLowerMatrix, int* dstLowerRowIndices, T* dstUpperMatrix, int* dstUpperRowIndices, T* dstDiag, int* naturalToReordered, size_t numberOfRows, int threadBlockSize);
} // namespace Opm::cuistl::detail
#endif

View File

@ -23,6 +23,7 @@
#include <opm/simulators/linalg/cuistl/detail/cuda_safe_call.hpp>
#include <opm/simulators/linalg/cuistl/CuVector.hpp>
#include <stdexcept>
#include <config.h>
namespace Opm::cuistl::detail
{
@ -115,10 +116,31 @@ namespace
}
} // namespace
// Kernel here is the function object of the cuda kernel
template <class Kernel>
inline int getCudaRecomendedThreadBlockSize(Kernel k)
{
#if USE_HIP
return 512;
#else
int blockSize;
int tmpGridSize;
std::ignore = cudaOccupancyMaxPotentialBlockSize(&tmpGridSize, &blockSize, k, 0, 0);
return blockSize;
#endif
}
inline int getNumberOfBlocks(int wantedThreads, int threadBlockSize)
{
return (wantedThreads + threadBlockSize - 1) / threadBlockSize;
}
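// Example: wantedThreads = 1000 with threadBlockSize = 256 gives (1000 + 255) / 256 = 4 blocks,
// so the last block is only partially used and kernels launched this way must guard against
// out-of-range thread indices.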
template <class T>
void
setVectorValue(T* deviceData, size_t numberOfElements, const T& value)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(setVectorValueKernel<T>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
setVectorValueKernel<<<nThreadBlocks, threadBlockSize>>>(
deviceData, numberOfElements, value);
}
@ -131,6 +153,8 @@ template <class T>
void
setZeroAtIndexSet(T* deviceData, size_t numberOfElements, const int* indices)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(setZeroAtIndexSetKernel<T>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
setZeroAtIndexSetKernel<<<nThreadBlocks, threadBlockSize>>>(
deviceData, numberOfElements, indices);
}
@ -142,6 +166,8 @@ template <class T>
T
innerProductAtIndices(cublasHandle_t cublasHandle, const T* deviceA, const T* deviceB, T* buffer, size_t numberOfElements, const int* indices)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(elementWiseMultiplyKernel<T>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
elementWiseMultiplyKernel<<<nThreadBlocks, threadBlockSize>>>(
deviceA, deviceB, buffer, numberOfElements, indices);
@ -160,6 +186,8 @@ template int innerProductAtIndices(cublasHandle_t, const int*, const int*, int*
template <class T>
void prepareSendBuf(const T* deviceA, T* buffer, size_t numberOfElements, const int* indices)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(prepareSendBufKernel<T>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
prepareSendBufKernel<<<nThreadBlocks, threadBlockSize>>>(deviceA, buffer, numberOfElements, indices);
OPM_CUDA_SAFE_CALL(cudaDeviceSynchronize()); // The buffers are prepared for MPI. Wait for them to finish.
}
@ -170,6 +198,8 @@ template void prepareSendBuf(const int* deviceA, int* buffer, size_t numberOfEle
template <class T>
void syncFromRecvBuf(T* deviceA, T* buffer, size_t numberOfElements, const int* indices)
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(syncFromRecvBufKernel<T>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
syncFromRecvBufKernel<<<nThreadBlocks, threadBlockSize>>>(deviceA, buffer, numberOfElements, indices);
//cudaDeviceSynchronize(); // Not needed, I guess...
}
@ -188,16 +218,28 @@ weightedDiagMV(const T* squareBlockVector,
{
switch (blocksize) {
case 1:
weightedDiagMV<T, 1><<<getBlocks(numberOfElements), getThreads(numberOfElements)>>>(
squareBlockVector, numberOfElements, relaxationFactor, srcVec, dstVec);
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(weightedDiagMV<T, 1>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
weightedDiagMV<T, 1><<<nThreadBlocks, threadBlockSize>>>(
squareBlockVector, numberOfElements, relaxationFactor, srcVec, dstVec);
}
break;
case 2:
weightedDiagMV<T, 2><<<getBlocks(numberOfElements), getThreads(numberOfElements)>>>(
squareBlockVector, numberOfElements, relaxationFactor, srcVec, dstVec);
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(weightedDiagMV<T, 2>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
weightedDiagMV<T, 2><<<nThreadBlocks, threadBlockSize>>>(
squareBlockVector, numberOfElements, relaxationFactor, srcVec, dstVec);
}
break;
case 3:
weightedDiagMV<T, 3><<<getBlocks(numberOfElements), getThreads(numberOfElements)>>>(
squareBlockVector, numberOfElements, relaxationFactor, srcVec, dstVec);
{
int threadBlockSize = getCudaRecomendedThreadBlockSize(weightedDiagMV<T, 3>);
int nThreadBlocks = getNumberOfBlocks(numberOfElements, threadBlockSize);
weightedDiagMV<T, 3><<<nThreadBlocks, threadBlockSize>>>(
squareBlockVector, numberOfElements, relaxationFactor, srcVec, dstVec);
}
break;
default:
OPM_THROW(std::invalid_argument, "blockvector Hadamard product not implemented for blocksize>3");

View File

@ -211,7 +211,7 @@ BOOST_AUTO_TEST_CASE(TestDiluApply)
// Initialize preconditioner objects
Dune::MultithreadDILU<Sp1x1BlockMatrix, B1x1Vec, B1x1Vec> cpudilu(matA);
auto gpudilu = CuDilu1x1(matA);
auto gpudilu = CuDilu1x1(matA, true, true);
// Use the apply
gpudilu.apply(d_output, d_input);
@ -235,7 +235,7 @@ BOOST_AUTO_TEST_CASE(TestDiluApplyBlocked)
// init matrix with 2x2 blocks
Sp2x2BlockMatrix matA = get2x2BlockTestMatrix();
auto gpudilu = CuDilu2x2(matA);
auto gpudilu = CuDilu2x2(matA, true, true);
Dune::MultithreadDILU<Sp2x2BlockMatrix, B2x2Vec, B2x2Vec> cpudilu(matA);
// create input/output buffers for the apply
@ -275,7 +275,7 @@ BOOST_AUTO_TEST_CASE(TestDiluInitAndUpdateLarge)
{
// create gpu dilu preconditioner
Sp1x1BlockMatrix matA = get1x1BlockTestMatrix();
auto gpudilu = CuDilu1x1(matA);
auto gpudilu = CuDilu1x1(matA, true, true);
matA[0][0][0][0] = 11.0;
matA[0][1][0][0] = 12.0;
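// Not part of this diff: a minimal sketch of an additional test that exercises the untuned code
// path by passing tuneKernels = false. It reuses the fixture helpers and aliases already used in
// this file; the test name is hypothetical.
BOOST_AUTO_TEST_CASE(TestDiluConstructWithoutTuning)
{
    Sp1x1BlockMatrix matA = get1x1BlockTestMatrix();
    auto gpudilu = CuDilu1x1(matA, true, false); // split matrix, but skip thread block size tuning
    gpudilu.update(); // should still run with the occupancy-based default block sizes
}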