mirror of
https://github.com/OPM/opm-simulators.git
synced 2025-02-25 18:55:30 -06:00
Removed extra copy of nnzs, now sends pointer to start of Dune::BCRSMatrix data to cusparseSolver.
This commit is contained in:
parent
48900df882
commit
8b92c5dca6
@ -90,16 +90,11 @@ int checkZeroDiagonal(BridgeMatrix& mat) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// convert matrix to blocked csr (bsr) arrays
|
// iterate sparsity pattern from Matrix and put colIndices and rowPointers in arrays
|
||||||
// if only_vals, do not convert rowPointers and colIndices
|
|
||||||
// sparsity pattern should stay the same due to matrix-add-well-contributions
|
// sparsity pattern should stay the same due to matrix-add-well-contributions
|
||||||
template <class BridgeMatrix>
|
template <class BridgeMatrix>
|
||||||
void convertMatrixBsr(BridgeMatrix& mat, std::vector<double> &h_vals, std::vector<int> &h_rows, std::vector<int> &h_cols, int dim) {
|
void getSparsityPattern(BridgeMatrix& mat, std::vector<int> &h_rows, std::vector<int> &h_cols, int dim) {
|
||||||
int sum_nnzs = 0;
|
int sum_nnzs = 0;
|
||||||
int nnz = mat.nonzeroes()*dim*dim;
|
|
||||||
|
|
||||||
// copy nonzeros
|
|
||||||
memcpy(h_vals.data(), &(mat[0][0][0][0]), sizeof(double)*nnz);
|
|
||||||
|
|
||||||
// convert colIndices and rowPointers
|
// convert colIndices and rowPointers
|
||||||
if(h_rows.size() == 0){
|
if(h_rows.size() == 0){
|
||||||
@ -116,13 +111,8 @@ void convertMatrixBsr(BridgeMatrix& mat, std::vector<double> &h_vals, std::vecto
|
|||||||
// set last rowpointer
|
// set last rowpointer
|
||||||
h_rows[mat.N()] = mat.nonzeroes();
|
h_rows[mat.N()] = mat.nonzeroes();
|
||||||
}
|
}
|
||||||
} // end convertMatrixBsr()
|
} // end getSparsityPattern()
|
||||||
|
|
||||||
// converts the BlockVector b to a flat array
|
|
||||||
template <class BridgeVector>
|
|
||||||
void convertBlockVectorToArray(BridgeVector& b, std::vector<double> &h_b) {
|
|
||||||
memcpy(h_b.data(), &(b[0]), sizeof(double) * b.N() * b[0].dim());
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <class BridgeMatrix, class BridgeVector>
|
template <class BridgeMatrix, class BridgeVector>
|
||||||
@ -133,8 +123,6 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U
|
|||||||
if(use_gpu){
|
if(use_gpu){
|
||||||
BdaResult result;
|
BdaResult result;
|
||||||
result.converged = false;
|
result.converged = false;
|
||||||
static std::vector<double> h_vals;
|
|
||||||
static std::vector<double> h_b;
|
|
||||||
static std::vector<int> h_rows;
|
static std::vector<int> h_rows;
|
||||||
static std::vector<int> h_cols;
|
static std::vector<int> h_cols;
|
||||||
int dim = (*mat)[0][0].N();
|
int dim = (*mat)[0][0].N();
|
||||||
@ -146,13 +134,16 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(h_vals.capacity() == 0){
|
if(h_rows.capacity() == 0){
|
||||||
h_vals.reserve(nnz);
|
|
||||||
h_vals.resize(nnz);
|
|
||||||
h_b.reserve(N);
|
|
||||||
h_b.resize(N);
|
|
||||||
h_rows.reserve(N+1);
|
h_rows.reserve(N+1);
|
||||||
h_cols.reserve(nnz);
|
h_cols.reserve(nnz);
|
||||||
|
#if PRINT_TIMERS_BRIDGE
|
||||||
|
Dune::Timer t;
|
||||||
|
#endif
|
||||||
|
getSparsityPattern(*mat, h_rows, h_cols, dim);
|
||||||
|
#if PRINT_TIMERS_BRIDGE
|
||||||
|
printf("getSparsityPattern(): %.4f s\n", t.stop());
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
#if PRINT_TIMERS_BRIDGE
|
#if PRINT_TIMERS_BRIDGE
|
||||||
@ -163,23 +154,13 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U
|
|||||||
checkZeroDiagonal(*mat);
|
checkZeroDiagonal(*mat);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if PRINT_TIMERS_BRIDGE
|
|
||||||
Dune::Timer t;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
convertMatrixBsr(*mat, h_vals, h_rows, h_cols, dim);
|
|
||||||
convertBlockVectorToArray(b, h_b);
|
|
||||||
|
|
||||||
#if PRINT_TIMERS_BRIDGE
|
|
||||||
printf("Conversion to flat arrays: %.4f s\n", t.stop());
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/////////////////////////
|
/////////////////////////
|
||||||
// actually solve
|
// actually solve
|
||||||
|
|
||||||
typedef cusparseSolverBackend::cusparseSolverStatus cusparseSolverStatus;
|
typedef cusparseSolverBackend::cusparseSolverStatus cusparseSolverStatus;
|
||||||
|
// assume that underlying data (nonzeroes) from mat (Dune::BCRSMatrix) are contiguous, if this is not the case, cusparseSolver is expected to produce garbage
|
||||||
cusparseSolverStatus status = backend->solve_system(N, nnz, dim, h_vals.data(), h_rows.data(), h_cols.data(), h_b.data(), result);
|
cusparseSolverStatus status = backend->solve_system(N, nnz, dim, static_cast<double*>(&(((*mat)[0][0][0][0]))), h_rows.data(), h_cols.data(), static_cast<double*>(&(b[0][0])), result);
|
||||||
switch(status){
|
switch(status){
|
||||||
case cusparseSolverStatus::CUSPARSE_SOLVER_SUCCESS:
|
case cusparseSolverStatus::CUSPARSE_SOLVER_SUCCESS:
|
||||||
//OpmLog::info("cusparseSolver converged");
|
//OpmLog::info("cusparseSolver converged");
|
||||||
|
@ -342,10 +342,6 @@ namespace Opm
|
|||||||
cusparseCreateBsrsv2Info(&info_U);
|
cusparseCreateBsrsv2Info(&info_U);
|
||||||
cudaCheckLastError("Could not create analysis info");
|
cudaCheckLastError("Could not create analysis info");
|
||||||
|
|
||||||
cudaMemcpyAsync(d_bRows, rows, sizeof(int)*(Nb+1), cudaMemcpyHostToDevice, stream);
|
|
||||||
cudaMemcpyAsync(d_bCols, cols, sizeof(int)*nnz, cudaMemcpyHostToDevice, stream);
|
|
||||||
cudaMemcpyAsync(d_bVals, vals, sizeof(double)*nnz, cudaMemcpyHostToDevice, stream);
|
|
||||||
|
|
||||||
cusparseDbsrilu02_bufferSize(cusparseHandle, order, Nb, nnzb,
|
cusparseDbsrilu02_bufferSize(cusparseHandle, order, Nb, nnzb,
|
||||||
descr_M, d_bVals, d_bRows, d_bCols, BLOCK_SIZE, info_M, &d_bufferSize_M);
|
descr_M, d_bVals, d_bRows, d_bCols, BLOCK_SIZE, info_M, &d_bufferSize_M);
|
||||||
cusparseDbsrsv2_bufferSize(cusparseHandle, order, operation, Nb, nnzb,
|
cusparseDbsrsv2_bufferSize(cusparseHandle, order, operation, Nb, nnzb,
|
||||||
|
Loading…
Reference in New Issue
Block a user