Reduced code duplication in BdaBridge

This commit is contained in:
Tong Dong Qiu 2021-03-04 11:49:29 +01:00
parent 81c0a3d9f9
commit 8ea19c66aa

View File

@ -148,7 +148,7 @@ int checkZeroDiagonal(BridgeMatrix& mat) {
// iterate sparsity pattern from Matrix and put colIndices and rowPointers in arrays // iterate sparsity pattern from Matrix and put colIndices and rowPointers in arrays
// sparsity pattern should stay the same due to matrix-add-well-contributions // sparsity pattern should stay the same
// this could be removed if Dune::BCRSMatrix features an API call that returns colIndices and rowPointers // this could be removed if Dune::BCRSMatrix features an API call that returns colIndices and rowPointers
template <class BridgeMatrix> template <class BridgeMatrix>
void getSparsityPattern(BridgeMatrix& mat, std::vector<int> &h_rows, std::vector<int> &h_cols) { void getSparsityPattern(BridgeMatrix& mat, std::vector<int> &h_rows, std::vector<int> &h_cols) {
@ -185,8 +185,10 @@ void BdaBridge<BridgeMatrix, BridgeVector, block_size>::solve_system(BridgeMatri
static std::vector<int> h_rows; static std::vector<int> h_rows;
static std::vector<int> h_cols; static std::vector<int> h_cols;
const int dim = (*mat)[0][0].N(); const int dim = (*mat)[0][0].N();
const int N = mat->N()*dim; const int Nb = mat->N();
const int nnz = (h_rows.empty()) ? mat->nonzeroes()*dim*dim : h_rows.back()*dim*dim; const int N = Nb * dim;
const int nnzb = (h_rows.empty()) ? mat->nonzeroes() : h_rows.back();
const int nnz = nnzb * dim * dim;
if (dim != 3) { if (dim != 3) {
OpmLog::warning("cusparseSolver only accepts blocksize = 3 at this time, will use Dune for the remainder of the program"); OpmLog::warning("cusparseSolver only accepts blocksize = 3 at this time, will use Dune for the remainder of the program");
@ -195,8 +197,8 @@ void BdaBridge<BridgeMatrix, BridgeVector, block_size>::solve_system(BridgeMatri
} }
if (h_rows.capacity() == 0) { if (h_rows.capacity() == 0) {
h_rows.reserve(N+1); h_rows.reserve(Nb+1);
h_cols.reserve(nnz); h_cols.reserve(nnzb);
#if PRINT_TIMERS_BRIDGE #if PRINT_TIMERS_BRIDGE
Dune::Timer t; Dune::Timer t;
#endif #endif