Use std::find and added comments

This commit is contained in:
tqiu 2021-01-18 17:10:46 +01:00
parent a8e524fc9d
commit 123e3fa89e

View File

@ -238,80 +238,100 @@ namespace bda
double *Ltmp = new double[Lmat->nnzbs * block_size * block_size];
for (int sweep = 0; sweep < num_sweeps; ++sweep) {
// algorithm
// for every block in A (LUmat):
// if i > j:
// Lij = (Aij - sum k=1 to j-1 {Lik*Ukj}) / Ujj
// else:
// Uij = (Aij - sum k=1 to i-1 {Lik*Ukj})
// for every row
for (int row = 0; row < Nb; row++) {
// update U
// Uij = (Aij - sum k=1 to i-1 {Lik*Ukj})
int jColStart = Ut->rowPointers[row];
int jColEnd = Ut->rowPointers[row + 1];
int colU = row; // rename for clarity, next row in Ut means next col in U
// for every block in this row
for (int ij = jColStart; ij < jColEnd; ij++) {
int col = Ut->colIndices[ij];
int rowU1 = Ut->colIndices[ij]; // actually rowIndices for U
// refine Uij element (or diagonal)
int i1 = LUmat->rowPointers[col];
int i2 = LUmat->rowPointers[col+1];
int kk = 0;
for(kk = i1; kk < i2; ++kk) {
ptrdiff_t c = LUmat->colIndices[kk];
if (c >= row) {
break;
}
}
int i1 = LUmat->rowPointers[rowU1];
int i2 = LUmat->rowPointers[rowU1+1];
// search on row rowU1, find blockIndex in LUmat of block with same col (colU) as Uij
// LUmat->nnzValues[kk] is block Aij
auto candidate = std::find(LUmat->colIndices + i1, LUmat->colIndices + i2, colU);
assert(candidate != LUmat->colIndices + i2);
auto kk = candidate - LUmat->colIndices;
double aij[bs*bs];
// copy block to Aij so operations can be done on it without affecting LUmat
memcpy(&aij[0], LUmat->nnzValues + kk * bs * bs, sizeof(double) * bs * bs);
int jk = Lmat->rowPointers[col];
int ik = (jk < Lmat->rowPointers[col+1]) ? Lmat->colIndices[jk] : Nb;
for (int k = jColStart; k < ij; ++k) {
int ki = Ut->colIndices[k];
while (ik < ki) {
++jk;
ik = Lmat->colIndices[jk];
}
if (ik == ki) {
blockMultSub<bs>(&aij[0], Lmat->nnzValues + jk * bs * bs, Ut->nnzValues + k * bs * bs);
int jk = Lmat->rowPointers[rowU1]; // points to row rowU1 in L
// if row rowU1 is empty, skip row
if (jk < Lmat->rowPointers[rowU1+1]) {
int colL = Lmat->colIndices[jk];
for (int k = jColStart; k < ij; ++k) {
int rowU2 = Ut->colIndices[k];
while (colL < rowU2) {
++jk; // check next block on row rowU1 of L
colL = Lmat->colIndices[jk];
}
if (colL == rowU2) {
// Aij -= (Lik * Ukj)
blockMultSub<bs>(&aij[0], Lmat->nnzValues + jk * bs * bs, Ut->nnzValues + k * bs * bs);
}
}
}
// Uij_new = Aij - sum
memcpy(Utmp + ij * bs * bs, &aij[0], sizeof(double) * bs * bs);
}
// update L
// Lij = (Aij - sum k=1 to j-1 {Lik*Ukj}) / Ujj
int iRowStart = Lmat->rowPointers[row];
int iRowEnd = Lmat->rowPointers[row + 1];
for (int ij = iRowStart; ij < iRowEnd; ij++) {
int j = Lmat->colIndices[ij];
// refine Lij element
// search on row 'row', find blockIndex in LUmat of block with same col (j) as Lij
// LUmat->nnzValues[kk] is block Aij
int i1 = LUmat->rowPointers[row];
int i2 = LUmat->rowPointers[row+1];
int kk = 0;
for(kk = i1; kk < i2; ++kk) {
ptrdiff_t c = LUmat->colIndices[kk];
if (c >= j) {
break;
}
}
auto candidate = std::find(LUmat->colIndices + i1, LUmat->colIndices + i2, j);
assert(candidate != LUmat->colIndices + i2);
auto kk = candidate - LUmat->colIndices;
double aij[bs*bs];
// copy block to Aij so operations can be done on it without affecting LUmat
memcpy(&aij[0], LUmat->nnzValues + kk * bs * bs, sizeof(double) * bs * bs);
int jk = Ut->rowPointers[j];
int ik = Ut->colIndices[jk];
int jk = Ut->rowPointers[j]; // actually colPointers, jk points to col j in U
int rowU = Ut->colIndices[jk]; // actually rowIndices, rowU is the row of block jk
// check if L has a matching block where colL == rowU
for (int k = iRowStart; k < ij; ++k) {
int ki = Lmat->colIndices[k];
while(ik < ki) {
++jk;
ik = Ut->colIndices[jk];
int colL = Lmat->colIndices[k];
while (rowU < colL) {
++jk; // check next block on col j of U
rowU = Ut->colIndices[jk];
}
if(ik == ki) {
if (rowU == colL) {
// Aij -= (Lik * Ukj)
blockMultSub<bs>(&aij[0], Lmat->nnzValues + k * bs * bs , Ut->nnzValues + jk * bs * bs);
}
}
// calculate aij / ujj
// calculate (Aij - sum) / Ujj
double ujj[bs*bs];
inverter(Ut->nnzValues + (Ut->rowPointers[j+1] - 1) * bs * bs, &ujj[0]);
// lij = aij / ujj
// Lij_new = (Aij - sum) / Ujj
blockMult<bs>(&aij[0], &ujj[0], Ltmp + ij * bs * bs);
}
}
// 1st sweep writes to Ltmp