Replace diagfinding with std::find

This commit is contained in:
Tong Dong Qiu 2021-03-03 09:50:33 +01:00
parent 5b4deab7e4
commit 18bf7c4b19

View File

@ -437,18 +437,10 @@ namespace bda
for (int row = 0; row < Nb; ++row) { for (int row = 0; row < Nb; ++row) {
int rowStart = LUmat->rowPointers[row]; int rowStart = LUmat->rowPointers[row];
int rowEnd = LUmat->rowPointers[row+1]; int rowEnd = LUmat->rowPointers[row+1];
bool diagFound = false;
for (int ij = rowStart; ij < rowEnd; ++ij) { auto candidate = std::find(LUmat->colIndices + rowStart, LUmat->colIndices + rowEnd, row);
int col = LUmat->colIndices[ij]; assert(candidate != LUmat->colIndices + rowEnd);
if (row == col) { diagIndex[row] = candidate - LUmat->colIndices;
diagIndex[row] = ij;
diagFound = true;
break;
}
}
if (!diagFound) {
OPM_THROW(std::logic_error, "Error did not find diagonal block in reordered matrix");
}
} }
queue->enqueueWriteBuffer(s.diagIndex, CL_TRUE, 0, Nb * sizeof(int), diagIndex); queue->enqueueWriteBuffer(s.diagIndex, CL_TRUE, 0, Nb * sizeof(int), diagIndex);
queue->enqueueWriteBuffer(s.Lcols, CL_TRUE, 0, Lmat->nnzbs * sizeof(int), Lmat->colIndices); queue->enqueueWriteBuffer(s.Lcols, CL_TRUE, 0, Lmat->nnzbs * sizeof(int), Lmat->colIndices);
@ -508,18 +500,10 @@ namespace bda
for (int row = 0; row < Nb; ++row) { for (int row = 0; row < Nb; ++row) {
int rowStart = LUmat->rowPointers[row]; int rowStart = LUmat->rowPointers[row];
int rowEnd = LUmat->rowPointers[row+1]; int rowEnd = LUmat->rowPointers[row+1];
bool diagFound = false;
for (int ij = rowStart; ij < rowEnd; ++ij) { auto candidate = std::find(LUmat->colIndices + rowStart, LUmat->colIndices + rowEnd, row);
int col = LUmat->colIndices[ij]; assert(candidate != LUmat->colIndices + rowEnd);
if (row == col) { diagIndex[row] = candidate - LUmat->colIndices;
diagIndex[row] = ij;
diagFound = true;
break;
}
}
if (!diagFound) {
OPM_THROW(std::logic_error, "Error did not find diagonal block in reordered matrix");
}
} }
queue->enqueueWriteBuffer(s.diagIndex, CL_TRUE, 0, Nb * sizeof(int), diagIndex); queue->enqueueWriteBuffer(s.diagIndex, CL_TRUE, 0, Nb * sizeof(int), diagIndex);
queue->enqueueWriteBuffer(s.LUcols, CL_TRUE, 0, LUmat->nnzbs * sizeof(int), LUmat->colIndices); queue->enqueueWriteBuffer(s.LUcols, CL_TRUE, 0, LUmat->nnzbs * sizeof(int), LUmat->colIndices);