Reverted some changes

This commit is contained in:
babrodtk 2015-08-28 15:41:51 +02:00
parent d57adc6ed4
commit 8d82d9f89e
2 changed files with 49 additions and 47 deletions

View File

@ -99,29 +99,8 @@ namespace Opm
AutoDiffMatrix(const AutoDiffMatrix& other) AutoDiffMatrix(const AutoDiffMatrix& other) = default;
{ AutoDiffMatrix& operator=(const AutoDiffMatrix& other) = default;
*this = other;
}
AutoDiffMatrix& operator=(const AutoDiffMatrix& other)
{
type_ = other.type_;
rows_ = other.rows_;
cols_ = other.cols_;
switch(type_) {
case D:
diag_ = other.diag_;
break;
case S:
sparse_ = other.sparse_;
break;
default:
break;
}
return *this;
}
@ -418,8 +397,7 @@ namespace Opm
retval.type_ = S; retval.type_ = S;
retval.rows_ = lhs.rows_; retval.rows_ = lhs.rows_;
retval.cols_ = rhs.cols_; retval.cols_ = rhs.cols_;
retval.sparse_ = lhs.sparse_; retval.sparse_ = lhs.sparse_ + spdiag(Eigen::VectorXd::Ones(lhs.rows_));
retval.sparse_ += spdiag(Eigen::VectorXd::Ones(lhs.rows_));
return retval; return retval;
} }
@ -431,8 +409,7 @@ namespace Opm
retval.type_ = S; retval.type_ = S;
retval.rows_ = lhs.rows_; retval.rows_ = lhs.rows_;
retval.cols_ = rhs.cols_; retval.cols_ = rhs.cols_;
retval.sparse_ = lhs.sparse_; retval.sparse_ = lhs.sparse_ + spdiag(rhs.diag_);
retval.sparse_ += spdiag(rhs.diag_);
return retval; return retval;
} }
@ -440,8 +417,16 @@ namespace Opm
{ {
assert(lhs.type_ == S); assert(lhs.type_ == S);
assert(rhs.type_ == S); assert(rhs.type_ == S);
#if 1
AutoDiffMatrix retval;
retval.type_ = S;
retval.rows_ = lhs.rows_;
retval.cols_ = rhs.cols_;
retval.sparse_ = lhs.sparse_ + rhs.sparse_;
#else
AutoDiffMatrix retval = lhs; AutoDiffMatrix retval = lhs;
retval.sparse_ += rhs.sparse_; retval.sparse_ += rhs.sparse_;
#endif
return retval; return retval;
} }
@ -453,10 +438,21 @@ namespace Opm
{ {
assert(lhs.type_ == D); assert(lhs.type_ == D);
assert(rhs.type_ == D); assert(rhs.type_ == D);
#if 1
AutoDiffMatrix retval;
retval.type_ = D;
retval.rows_ = lhs.rows_;
retval.cols_ = rhs.cols_;
retval.diag_.resize(lhs.rows_);
for (int r = 0; r < lhs.rows_; ++r) {
retval.diag_[r] = lhs.diag_[r] * rhs.diag_[r];
}
#else
AutoDiffMatrix retval = lhs; AutoDiffMatrix retval = lhs;
for (int r = 0; r < lhs.rows_; ++r) { for (int r = 0; r < lhs.rows_; ++r) {
retval.diag_[r] *= rhs.diag_[r]; retval.diag_[r] *= rhs.diag_[r];
} }
#endif
return retval; return retval;
} }
@ -468,7 +464,11 @@ namespace Opm
retval.type_ = S; retval.type_ = S;
retval.rows_ = lhs.rows_; retval.rows_ = lhs.rows_;
retval.cols_ = rhs.cols_; retval.cols_ = rhs.cols_;
#if 1
fastDiagSparseProduct(lhs.diag_, rhs.sparse_, retval.sparse_);
#else
retval.sparse_ = std::move(fastDiagSparseProduct(lhs.diag_, rhs.sparse_)); retval.sparse_ = std::move(fastDiagSparseProduct(lhs.diag_, rhs.sparse_));
#endif
return retval; return retval;
} }
@ -480,7 +480,11 @@ namespace Opm
retval.type_ = S; retval.type_ = S;
retval.rows_ = lhs.rows_; retval.rows_ = lhs.rows_;
retval.cols_ = rhs.cols_; retval.cols_ = rhs.cols_;
#if 1
fastSparseDiagProduct(lhs.sparse_, rhs.diag_, retval.sparse_);
#else
retval.sparse_ = std::move(fastSparseDiagProduct(lhs.sparse_, rhs.diag_)); retval.sparse_ = std::move(fastSparseDiagProduct(lhs.sparse_, rhs.diag_));
#endif
return retval; return retval;
} }
@ -492,7 +496,11 @@ namespace Opm
retval.type_ = S; retval.type_ = S;
retval.rows_ = lhs.rows_; retval.rows_ = lhs.rows_;
retval.cols_ = rhs.cols_; retval.cols_ = rhs.cols_;
#if 1
fastSparseProduct(lhs.sparse_, rhs.sparse_, retval.sparse_);
#else
retval.sparse_ = std::move(fastSparseProduct<Sparse>(lhs.sparse_, rhs.sparse_)); retval.sparse_ = std::move(fastSparseProduct<Sparse>(lhs.sparse_, rhs.sparse_));
#endif
return retval; return retval;
} }

View File

@ -56,16 +56,16 @@ struct QuickSort< 0 >
}; };
template<typename ResultType, typename Lhs, typename Rhs> template<typename Lhs, typename Rhs, typename ResultType>
ResultType fastSparseProduct(const Lhs& lhs, const Rhs& rhs) void fastSparseProduct(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
// initialize result // initialize result
ResultType res(lhs.rows(), rhs.cols()); res = ResultType(lhs.rows(), rhs.cols());
// if one of the matrices does not contain non zero elements // if one of the matrices does not contain non zero elements
// the result will only contain an empty matrix // the result will only contain an empty matrix
if( lhs.nonZeros() == 0 || rhs.nonZeros() == 0 ) if( lhs.nonZeros() == 0 || rhs.nonZeros() == 0 )
return res; return;
typedef typename Eigen::internal::remove_all<Lhs>::type::Scalar Scalar; typedef typename Eigen::internal::remove_all<Lhs>::type::Scalar Scalar;
typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index; typedef typename Eigen::internal::remove_all<Lhs>::type::Index Index;
@ -137,19 +137,17 @@ ResultType fastSparseProduct(const Lhs& lhs, const Rhs& rhs)
} }
res.finalize(); res.finalize();
return res;
} }
inline inline void fastDiagSparseProduct(// const Eigen::DiagonalMatrix<double, Eigen::Dynamic>& lhs,
Eigen::SparseMatrix<double> const std::vector<double>& lhs,
fastDiagSparseProduct(const std::vector<double>& lhs, const Eigen::SparseMatrix<double>& rhs,
const Eigen::SparseMatrix<double>& rhs) Eigen::SparseMatrix<double>& res)
{ {
Eigen::SparseMatrix<double> res = rhs; res = rhs;
// Multiply rows by diagonal lhs. // Multiply rows by diagonal lhs.
int n = res.cols(); int n = res.cols();
@ -159,19 +157,17 @@ fastDiagSparseProduct(const std::vector<double>& lhs,
it.valueRef() *= lhs[it.row()]; // lhs.diagonal()(it.row()); it.valueRef() *= lhs[it.row()]; // lhs.diagonal()(it.row());
} }
} }
return res;
} }
inline inline void fastSparseDiagProduct(const Eigen::SparseMatrix<double>& lhs,
Eigen::SparseMatrix<double> // const Eigen::DiagonalMatrix<double, Eigen::Dynamic>& rhs,
fastSparseDiagProduct(const Eigen::SparseMatrix<double>& lhs, const std::vector<double>& rhs,
const std::vector<double>& rhs) Eigen::SparseMatrix<double>& res)
{ {
Eigen::SparseMatrix<double> res = lhs; res = lhs;
// Multiply columns by diagonal rhs. // Multiply columns by diagonal rhs.
int n = res.cols(); int n = res.cols();
@ -181,8 +177,6 @@ fastSparseDiagProduct(const Eigen::SparseMatrix<double>& lhs,
it.valueRef() *= rhs[col]; // rhs.diagonal()(col); it.valueRef() *= rhs[col]; // rhs.diagonal()(col);
} }
} }
return res;
} }