Also check outer ptr.

This commit is contained in:
Robert Kloefkorn 2016-02-16 15:24:52 +01:00
parent 4df4c9147a
commit 9661973387

View File

@ -190,14 +190,58 @@ equalSparsityPattern(const Lhs& lhs, const Rhs& rhs)
typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index; typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
const Index nnz = lhs.nonZeros(); const Index nnz = lhs.nonZeros();
// outer indices
const Index* rhsO = rhs.outerIndexPtr();
const Index* lhsO = lhs.outerIndexPtr();
// inner indices
const Index* rhsJ = rhs.innerIndexPtr(); const Index* rhsJ = rhs.innerIndexPtr();
const Index* lhsJ = lhs.innerIndexPtr(); const Index* lhsJ = lhs.innerIndexPtr();
for(Index i=0; i<nnz; ++i ) const Index outerSize = lhs.outerSize();
if( outerSize != rhs.outerSize() )
{ {
if( lhsJ[ i ] != rhsJ[ i ] )
return false; return false;
} }
bool equalOuter = true;
bool equalInner = true;
const Index size = std::min( outerSize+1, nnz );
for( Index i=0; i<size; ++i)
{
equalOuter &= (lhsO[ i ] == rhsO[ i ]);
equalInner &= (lhsJ[ i ] == rhsJ[ i ]);
}
if( ! equalOuter || ! equalInner ) return false ;
if( outerSize+1 < nnz )
{
for(Index i=outerSize+1; i<nnz; ++i)
{
if( lhsJ[ i ] != rhsJ[ i ] ) return false;
}
}
else if( outerSize+1 > nnz )
{
for(Index o=nnz; o<=outerSize; ++o )
{
if( lhsO[ o ] != rhsO[ o ] )
return false;
}
}
else
{
return equalOuter && equalInner;
}
/*
for(Index o=0; o<=outerSize; ++o )
{
if( lhsO[ o ] != rhsO[ o ] )
return false;
}
*/
} }
return equal; return equal;
@ -228,7 +272,7 @@ fastSparseAdd(Lhs& lhs, const Rhs& rhs)
else else
{ {
// default Eigen operator+= // default Eigen operator+=
lhs += rhs; lhs = lhs + rhs;
} }
} }
@ -257,7 +301,7 @@ fastSparseSubstract(Lhs& lhs, const Rhs& rhs)
else else
{ {
// default Eigen operator-= // default Eigen operator-=
lhs -= rhs; lhs = lhs - rhs;
} }
} }