Also check outer ptr.

This commit is contained in:
Robert Kloefkorn
2016-02-16 15:24:52 +01:00
parent 4df4c9147a
commit 9661973387

View File

@@ -181,26 +181,70 @@ template<typename Lhs, typename Rhs>
inline bool inline bool
equalSparsityPattern(const Lhs& lhs, const Rhs& rhs) equalSparsityPattern(const Lhs& lhs, const Rhs& rhs)
{ {
// if both matrices have equal storage and non zeros match, we can check sparsity pattern // if both matrices have equal storage and non zeros match, we can check sparsity pattern
bool equal = (Lhs::IsRowMajor == Rhs::IsRowMajor) && (lhs.nonZeros() == rhs.nonZeros()); bool equal = (Lhs::IsRowMajor == Rhs::IsRowMajor) && (lhs.nonZeros() == rhs.nonZeros());
// check complete sparsity pattern // check complete sparsity pattern
if( equal ) if( equal )
{
typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
const Index nnz = lhs.nonZeros();
const Index* rhsJ = rhs.innerIndexPtr();
const Index* lhsJ = lhs.innerIndexPtr();
for(Index i=0; i<nnz; ++i )
{ {
if( lhsJ[ i ] != rhsJ[ i ] ) typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
return false; const Index nnz = lhs.nonZeros();
}
}
return equal; // outer indices
const Index* rhsO = rhs.outerIndexPtr();
const Index* lhsO = lhs.outerIndexPtr();
// inner indices
const Index* rhsJ = rhs.innerIndexPtr();
const Index* lhsJ = lhs.innerIndexPtr();
const Index outerSize = lhs.outerSize();
if( outerSize != rhs.outerSize() )
{
return false;
}
bool equalOuter = true;
bool equalInner = true;
const Index size = std::min( outerSize+1, nnz );
for( Index i=0; i<size; ++i)
{
equalOuter &= (lhsO[ i ] == rhsO[ i ]);
equalInner &= (lhsJ[ i ] == rhsJ[ i ]);
}
if( ! equalOuter || ! equalInner ) return false ;
if( outerSize+1 < nnz )
{
for(Index i=outerSize+1; i<nnz; ++i)
{
if( lhsJ[ i ] != rhsJ[ i ] ) return false;
}
}
else if( outerSize+1 > nnz )
{
for(Index o=nnz; o<=outerSize; ++o )
{
if( lhsO[ o ] != rhsO[ o ] )
return false;
}
}
else
{
return equalOuter && equalInner;
}
/*
for(Index o=0; o<=outerSize; ++o )
{
if( lhsO[ o ] != rhsO[ o ] )
return false;
}
*/
}
return equal;
} }
// this function substracts two sparse matrices // this function substracts two sparse matrices
@@ -228,7 +272,7 @@ fastSparseAdd(Lhs& lhs, const Rhs& rhs)
else else
{ {
// default Eigen operator+= // default Eigen operator+=
lhs += rhs; lhs = lhs + rhs;
} }
} }
@@ -257,7 +301,7 @@ fastSparseSubstract(Lhs& lhs, const Rhs& rhs)
else else
{ {
// default Eigen operator-= // default Eigen operator-=
lhs -= rhs; lhs = lhs - rhs;
} }
} }