diff --git a/opm/simulators/linalg/bda/BdaBridge.cpp b/opm/simulators/linalg/bda/BdaBridge.cpp index 12ff2c0a7..3f6850276 100644 --- a/opm/simulators/linalg/bda/BdaBridge.cpp +++ b/opm/simulators/linalg/bda/BdaBridge.cpp @@ -90,16 +90,11 @@ int checkZeroDiagonal(BridgeMatrix& mat) { } -// convert matrix to blocked csr (bsr) arrays -// if only_vals, do not convert rowPointers and colIndices +// iterate sparsity pattern from Matrix and put colIndices and rowPointers in arrays // sparsity pattern should stay the same due to matrix-add-well-contributions template -void convertMatrixBsr(BridgeMatrix& mat, std::vector &h_vals, std::vector &h_rows, std::vector &h_cols, int dim) { +void getSparsityPattern(BridgeMatrix& mat, std::vector &h_rows, std::vector &h_cols, int dim) { int sum_nnzs = 0; - int nnz = mat.nonzeroes()*dim*dim; - - // copy nonzeros - memcpy(h_vals.data(), &(mat[0][0][0][0]), sizeof(double)*nnz); // convert colIndices and rowPointers if(h_rows.size() == 0){ @@ -116,13 +111,8 @@ void convertMatrixBsr(BridgeMatrix& mat, std::vector &h_vals, std::vecto // set last rowpointer h_rows[mat.N()] = mat.nonzeroes(); } -} // end convertMatrixBsr() +} // end getSparsityPattern() -// converts the BlockVector b to a flat array -template -void convertBlockVectorToArray(BridgeVector& b, std::vector &h_b) { - memcpy(h_b.data(), &(b[0]), sizeof(double) * b.N() * b[0].dim()); -} #endif template @@ -133,8 +123,6 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U if(use_gpu){ BdaResult result; result.converged = false; - static std::vector h_vals; - static std::vector h_b; static std::vector h_rows; static std::vector h_cols; int dim = (*mat)[0][0].N(); @@ -146,13 +134,16 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U exit(1); } - if(h_vals.capacity() == 0){ - h_vals.reserve(nnz); - h_vals.resize(nnz); - h_b.reserve(N); - h_b.resize(N); + if(h_rows.capacity() == 0){ h_rows.reserve(N+1); - h_cols.reserve(nnz); + h_cols.reserve(nnz); +#if PRINT_TIMERS_BRIDGE + Dune::Timer t; +#endif + getSparsityPattern(*mat, h_rows, h_cols, dim); +#if PRINT_TIMERS_BRIDGE + printf("getSparsityPattern(): %.4f s\n", t.stop()); +#endif } #if PRINT_TIMERS_BRIDGE @@ -163,23 +154,13 @@ void BdaBridge::solve_system(BridgeMatrix *mat OPM_UNUSED, BridgeVector &b OPM_U checkZeroDiagonal(*mat); #endif -#if PRINT_TIMERS_BRIDGE - Dune::Timer t; -#endif - - convertMatrixBsr(*mat, h_vals, h_rows, h_cols, dim); - convertBlockVectorToArray(b, h_b); - -#if PRINT_TIMERS_BRIDGE - printf("Conversion to flat arrays: %.4f s\n", t.stop()); -#endif ///////////////////////// // actually solve typedef cusparseSolverBackend::cusparseSolverStatus cusparseSolverStatus; - - cusparseSolverStatus status = backend->solve_system(N, nnz, dim, h_vals.data(), h_rows.data(), h_cols.data(), h_b.data(), result); + // assume that underlying data (nonzeroes) from mat (Dune::BCRSMatrix) are contiguous, if this is not the case, cusparseSolver is expected to produce garbage + cusparseSolverStatus status = backend->solve_system(N, nnz, dim, static_cast(&(((*mat)[0][0][0][0]))), h_rows.data(), h_cols.data(), static_cast(&(b[0][0])), result); switch(status){ case cusparseSolverStatus::CUSPARSE_SOLVER_SUCCESS: //OpmLog::info("cusparseSolver converged"); diff --git a/opm/simulators/linalg/bda/cusparseSolverBackend.cu b/opm/simulators/linalg/bda/cusparseSolverBackend.cu index 43e5629ab..51d0caf96 100644 --- a/opm/simulators/linalg/bda/cusparseSolverBackend.cu +++ b/opm/simulators/linalg/bda/cusparseSolverBackend.cu @@ -342,10 +342,6 @@ namespace Opm cusparseCreateBsrsv2Info(&info_U); cudaCheckLastError("Could not create analysis info"); - cudaMemcpyAsync(d_bRows, rows, sizeof(int)*(Nb+1), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(d_bCols, cols, sizeof(int)*nnz, cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(d_bVals, vals, sizeof(double)*nnz, cudaMemcpyHostToDevice, stream); - cusparseDbsrilu02_bufferSize(cusparseHandle, order, Nb, nnzb, descr_M, d_bVals, d_bRows, d_bCols, BLOCK_SIZE, info_M, &d_bufferSize_M); cusparseDbsrsv2_bufferSize(cusparseHandle, order, operation, Nb, nnzb,