mirror of
https://github.com/OPM/opm-simulators.git
synced 2025-02-25 18:55:30 -06:00
Removed extra copy of nnzs, now sends pointer to start of Dune::BCRSMatrix data to cusparseSolver.
This commit is contained in:
parent
48900df882
commit
8b92c5dca6
@ -90,16 +90,11 @@ int checkZeroDiagonal(BridgeMatrix& mat) {
|
||||
}
|
||||
|
||||
|
||||
// convert matrix to blocked csr (bsr) arrays
|
||||
// if only_vals, do not convert rowPointers and colIndices
|
||||
// iterate sparsity pattern from Matrix and put colIndices and rowPointers in arrays
|
||||
// sparsity pattern should stay the same due to matrix-add-well-contributions
|
||||
template <class BridgeMatrix>
|
||||
void convertMatrixBsr(BridgeMatrix& mat, std::vector<double> &h_vals, std::vector<int> &h_rows, std::vector<int> &h_cols, int dim) {
|
||||
void getSparsityPattern(BridgeMatrix& mat, std::vector<int> &h_rows, std::vector<int> &h_cols, int dim) {
|
||||
int sum_nnzs = 0;
|
||||
int nnz = mat.nonzeroes()*dim*dim;
|
||||
|
||||
// copy nonzeros
|
||||
memcpy(h_vals.data(), &(mat[0][0][0][0]), sizeof(double)*nnz);
|
||||
|
||||
// convert colIndices and rowPointers
|
||||
if(h_rows.size() == 0){
|
||||
@ -116,13 +111,8 @@ void convertMatrixBsr(BridgeMatrix& mat, std::vector<double> &h_vals, std::vecto
|
||||
// set last rowpointer
|
||||
h_rows[mat.N()] = mat.nonzeroes();
|
||||
}
|
||||
} // end convertMatrixBsr()
|
||||
} // end getSparsityPattern()
|
||||
|
||||
// converts the BlockVector b to a flat array
|
||||
template <class BridgeVector>
|
||||
void convertBlockVectorToArray(BridgeVector& b, std::vector<double> &h_b) {
|
||||
memcpy(h_b.data(), &(b[0]), sizeof(double) * b.N() * b[0].dim());
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class BridgeMatrix, class BridgeVector>
|
||||
@ -133,8 +123,6 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U
|
||||
if(use_gpu){
|
||||
BdaResult result;
|
||||
result.converged = false;
|
||||
static std::vector<double> h_vals;
|
||||
static std::vector<double> h_b;
|
||||
static std::vector<int> h_rows;
|
||||
static std::vector<int> h_cols;
|
||||
int dim = (*mat)[0][0].N();
|
||||
@ -146,13 +134,16 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if(h_vals.capacity() == 0){
|
||||
h_vals.reserve(nnz);
|
||||
h_vals.resize(nnz);
|
||||
h_b.reserve(N);
|
||||
h_b.resize(N);
|
||||
if(h_rows.capacity() == 0){
|
||||
h_rows.reserve(N+1);
|
||||
h_cols.reserve(nnz);
|
||||
h_cols.reserve(nnz);
|
||||
#if PRINT_TIMERS_BRIDGE
|
||||
Dune::Timer t;
|
||||
#endif
|
||||
getSparsityPattern(*mat, h_rows, h_cols, dim);
|
||||
#if PRINT_TIMERS_BRIDGE
|
||||
printf("getSparsityPattern(): %.4f s\n", t.stop());
|
||||
#endif
|
||||
}
|
||||
|
||||
#if PRINT_TIMERS_BRIDGE
|
||||
@ -163,23 +154,13 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U
|
||||
checkZeroDiagonal(*mat);
|
||||
#endif
|
||||
|
||||
#if PRINT_TIMERS_BRIDGE
|
||||
Dune::Timer t;
|
||||
#endif
|
||||
|
||||
convertMatrixBsr(*mat, h_vals, h_rows, h_cols, dim);
|
||||
convertBlockVectorToArray(b, h_b);
|
||||
|
||||
#if PRINT_TIMERS_BRIDGE
|
||||
printf("Conversion to flat arrays: %.4f s\n", t.stop());
|
||||
#endif
|
||||
|
||||
/////////////////////////
|
||||
// actually solve
|
||||
|
||||
typedef cusparseSolverBackend::cusparseSolverStatus cusparseSolverStatus;
|
||||
|
||||
cusparseSolverStatus status = backend->solve_system(N, nnz, dim, h_vals.data(), h_rows.data(), h_cols.data(), h_b.data(), result);
|
||||
// assume that underlying data (nonzeroes) from mat (Dune::BCRSMatrix) are contiguous, if this is not the case, cusparseSolver is expected to produce garbage
|
||||
cusparseSolverStatus status = backend->solve_system(N, nnz, dim, static_cast<double*>(&(((*mat)[0][0][0][0]))), h_rows.data(), h_cols.data(), static_cast<double*>(&(b[0][0])), result);
|
||||
switch(status){
|
||||
case cusparseSolverStatus::CUSPARSE_SOLVER_SUCCESS:
|
||||
//OpmLog::info("cusparseSolver converged");
|
||||
|
@ -342,10 +342,6 @@ namespace Opm
|
||||
cusparseCreateBsrsv2Info(&info_U);
|
||||
cudaCheckLastError("Could not create analysis info");
|
||||
|
||||
cudaMemcpyAsync(d_bRows, rows, sizeof(int)*(Nb+1), cudaMemcpyHostToDevice, stream);
|
||||
cudaMemcpyAsync(d_bCols, cols, sizeof(int)*nnz, cudaMemcpyHostToDevice, stream);
|
||||
cudaMemcpyAsync(d_bVals, vals, sizeof(double)*nnz, cudaMemcpyHostToDevice, stream);
|
||||
|
||||
cusparseDbsrilu02_bufferSize(cusparseHandle, order, Nb, nnzb,
|
||||
descr_M, d_bVals, d_bRows, d_bCols, BLOCK_SIZE, info_M, &d_bufferSize_M);
|
||||
cusparseDbsrsv2_bufferSize(cusparseHandle, order, operation, Nb, nnzb,
|
||||
|
Loading…
Reference in New Issue
Block a user