This commit is contained in:
Robert Kloefkorn 2016-02-13 16:02:42 +01:00
parent d5b6566e06
commit fadf5528e3

View File

@ -177,98 +177,87 @@ inline void fastSparseDiagProduct(const Eigen::SparseMatrix<double>& lhs,
} }
} }
namespace detail {
struct AddTo
{
template <class T>
inline void operator()( T& a, const T& b ) const
{
a += b;
}
};
struct SubstractFrom
{
template <class T>
inline void operator()( T& a, const T& b ) const
{
a -= b;
}
};
} // end namespace detail
// this function substracts two sparse matrices
// if the sparsity pattern is the same a faster add/substract is performed
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
void fastSparseBinaryOp(Lhs& lhs, const Rhs& rhs, const bool add ) inline bool
equalSparsityPattern(const Lhs& lhs, const Rhs& rhs)
{ {
// if non zeros don't match, matrices don't have the same sparsity pattern // if non zeros don't match, matrices don't have the same sparsity pattern
bool equal = lhs.nonZeros() == rhs.nonZeros(); bool equal = (lhs.nonZeros() == rhs.nonZeros());
typedef typename Eigen::internal::remove_all<Lhs>::type::Scalar Scalar;
typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
const Index nnz = lhs.nonZeros();
const Index* rhsJ = rhs.innerIndexPtr();
const Index* lhsJ = lhs.innerIndexPtr();
// check complete sparsity pattern // check complete sparsity pattern
if( equal ) if( equal )
{ {
typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
const Index nnz = lhs.nonZeros();
const Index* rhsJ = rhs.innerIndexPtr();
const Index* lhsJ = lhs.innerIndexPtr();
for(Index i=0; i<nnz; ++i ) for(Index i=0; i<nnz; ++i )
{ {
equal &= ( lhsJ[ i ] == rhsJ[ i ] ); //equal &= ( lhsJ[ i ] == rhsJ[ i ] );
if( lhsJ[ i ] != rhsJ[ i ] )
return false;
} }
} }
// if sparsity pattern does not match, return equal;
// fallback to default implementation
if( ! equal )
{
if( add ) {
lhs += rhs;
}
else {
lhs -= rhs;
}
return;
}
// fast add/substract using only the data pointers
const Scalar* rhsV = rhs.valuePtr();
Scalar* lhsV = lhs.valuePtr();
if( add )
{
for(Index i=0; i<nnz; ++i )
{
lhsV[ i ] += rhsV[ i ];
}
}
else
{
for(Index i=0; i<nnz; ++i )
{
lhsV[ i ] -= rhsV[ i ];
}
}
} }
// this function substracts two sparse matrices
// if the sparsity pattern is the same a faster add/substract is performed
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
void fastSparseAdd(Lhs& lhs, const Rhs& rhs ) inline void
fastSparseAdd(Lhs& lhs, const Rhs& rhs)
{ {
fastSparseBinaryOp( lhs, rhs, true ); //detail::AddTo() ); if( equalSparsityPattern( lhs, rhs ) )
{
typedef typename Eigen::internal::remove_all<Lhs>::type::Scalar Scalar;
typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
const Index nnz = lhs.nonZeros();
// fast add using only the data pointers
const Scalar* rhsV = rhs.valuePtr();
Scalar* lhsV = lhs.valuePtr();
for(Index i=0; i<nnz; ++i )
{
lhsV[ i ] += rhsV[ i ];
}
}
else
{
lhs += rhs;
}
} }
// this function substracts two sparse matrices
// if the sparsity pattern is the same a faster add/substract is performed
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
void fastSparseSubstract(Lhs& lhs, const Rhs& rhs ) inline void
fastSparseSubstract(Lhs& lhs, const Rhs& rhs)
{ {
fastSparseBinaryOp( lhs, rhs, false ); //detail::SubstractFrom() ); if( equalSparsityPattern( lhs, rhs ) )
{
typedef typename Eigen::internal::remove_all<Lhs>::type::Scalar Scalar;
typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
const Index nnz = lhs.nonZeros();
// fast add using only the data pointers
const Scalar* rhsV = rhs.valuePtr();
Scalar* lhsV = lhs.valuePtr();
for(Index i=0; i<nnz; ++i )
{
lhsV[ i ] -= rhsV[ i ];
}
}
else
{
lhs -= rhs;
}
} }
} // end namespace Opm } // end namespace Opm