diff --git a/opm/autodiff/AutoDiffMatrix.hpp b/opm/autodiff/AutoDiffMatrix.hpp index 430f65317..01823bc03 100644 --- a/opm/autodiff/AutoDiffMatrix.hpp +++ b/opm/autodiff/AutoDiffMatrix.hpp @@ -24,6 +24,7 @@ #include #include +#include #include @@ -38,6 +39,9 @@ namespace Opm class AutoDiffMatrix { public: + typedef std::vector Diag; + typedef Eigen::SparseMatrix Sparse; + AutoDiffMatrix() : type_(Z), rows_(0), @@ -72,7 +76,7 @@ namespace Opm : type_(D), rows_(d.rows()), cols_(d.cols()), - d_(d.diagonal().array().data(), d.diagonal().array().data() + d.rows()) + data_(Diag(d.diagonal().array().data(), d.diagonal().array().data() + d.rows())) { } @@ -82,7 +86,7 @@ namespace Opm : type_(S), rows_(s.rows()), cols_(s.cols()), - s_(s) + data_(s) { } @@ -114,8 +118,7 @@ namespace Opm std::swap(type_, other.type_); std::swap(rows_, other.rows_); std::swap(cols_, other.cols_); - d_.swap(other.d_); - s_.swap(other.s_); + data_.swap(other.data_); } @@ -258,13 +261,14 @@ namespace Opm { AutoDiffMatrix retval(*this); retval.type_ = D; - retval.d_.assign(rows_, rhs); + retval.data_ = boost::any(Diag(rows_, rhs)); return retval; } case D: { AutoDiffMatrix retval(*this); - for (double& elem : retval.d_) { + Diag& d = boost::any_cast(retval.data_); + for (double& elem : d) { elem *= rhs; } return retval; @@ -272,7 +276,8 @@ namespace Opm case S: { AutoDiffMatrix retval(*this); - retval.s_ *= rhs; + Sparse& s = boost::any_cast(retval.data_); + s *= rhs; return retval; } default: @@ -294,13 +299,14 @@ namespace Opm { AutoDiffMatrix retval(*this); retval.type_ = D; - retval.d_.assign(rows_, 1.0/rhs); + boost::any_cast(retval.data_).assign(rows_, 1.0/rhs); return retval; } case D: { AutoDiffMatrix retval(*this); - for (double& elem : retval.d_) { + Diag& d = boost::any_cast(retval.data_); + for (double& elem : d) { elem /= rhs; } return retval; @@ -308,7 +314,8 @@ namespace Opm case S: { AutoDiffMatrix retval(*this); - retval.s_ /= rhs; + Sparse& s = boost::any_cast(retval.data_); + s /= rhs; return retval; } default: @@ -330,9 +337,15 @@ namespace Opm case I: return rhs; case D: - return Eigen::Map(d_.data(), rows_) * rhs; + { + const Diag& d = boost::any_cast(data_); + return Eigen::Map(d.data(), rows_) * rhs; + } case S: - return s_ * rhs; + { + const Sparse& s = boost::any_cast(data_); + return s * rhs; + } default: OPM_THROW(std::logic_error, "Invalid AutoDiffMatrix type encountered: " << type_); } @@ -351,7 +364,7 @@ namespace Opm retval.type_ = D; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - retval.d_.assign(lhs.rows_, 2.0); + retval.data_ = boost::any(Diag(lhs.rows_, 2.0)); return retval; } @@ -361,8 +374,9 @@ namespace Opm assert(lhs.type_ == D); assert(rhs.type_ == I); AutoDiffMatrix retval = lhs; + Diag& d = boost::any_cast(retval.data_); for (int r = 0; r < lhs.rows_; ++r) { - retval.d_[r] += 1.0; + d[r] += 1.0; } return retval; } @@ -372,8 +386,10 @@ namespace Opm assert(lhs.type_ == D); assert(rhs.type_ == D); AutoDiffMatrix retval = lhs; + Diag& d_lhs = boost::any_cast(retval.data_); + const Diag& d_rhs = boost::any_cast(rhs.data_); for (int r = 0; r < lhs.rows_; ++r) { - retval.d_[r] += rhs.d_[r]; + d_lhs[r] += d_rhs[r]; } return retval; } @@ -387,7 +403,8 @@ namespace Opm retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - retval.s_ = lhs.s_ + ident; + const Sparse& s = boost::any_cast(lhs.data_); + retval.data_ = static_cast(s + ident); return retval; } @@ -396,11 +413,12 @@ namespace Opm assert(lhs.type_ == S); assert(rhs.type_ == D); AutoDiffMatrix retval; - Eigen::SparseMatrix diag = spdiag(rhs.d_); + Sparse diag = spdiag(boost::any_cast(rhs.data_)); retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - retval.s_ = lhs.s_ + diag; + const Sparse& s = boost::any_cast(lhs.data_); + retval.data_ = static_cast(s + diag); return retval; } @@ -412,7 +430,9 @@ namespace Opm retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - retval.s_ = lhs.s_ + rhs.s_; + const Sparse& s_lhs = boost::any_cast(lhs.data_); + const Sparse& s_rhs = boost::any_cast(rhs.data_); + retval.data_ = static_cast(s_lhs + s_rhs); return retval; } @@ -428,9 +448,12 @@ namespace Opm retval.type_ = D; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - retval.d_.resize(lhs.rows_); + retval.data_ = boost::any(Diag(lhs.rows_)); + Diag& d = boost::any_cast(retval.data_); + const Diag& d_lhs = boost::any_cast(lhs.data_); + const Diag& d_rhs = boost::any_cast(rhs.data_); for (int r = 0; r < lhs.rows_; ++r) { - retval.d_[r] = lhs.d_[r] * rhs.d_[r]; + d[r] = d_lhs[r] * d_rhs[r]; } return retval; } @@ -443,7 +466,11 @@ namespace Opm retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - fastDiagSparseProduct(lhs.d_, rhs.s_, retval.s_); + retval.data_ = boost::any(Sparse(retval.rows_, retval.cols_)); //FIXME: Superfluous? + const Diag& a = boost::any_cast(lhs.data_); + const Sparse& b = boost::any_cast(rhs.data_); + Sparse& c = boost::any_cast(retval.data_); + fastDiagSparseProduct(a, b, c); return retval; } @@ -455,7 +482,11 @@ namespace Opm retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - fastSparseDiagProduct(lhs.s_, rhs.d_, retval.s_); + retval.data_ = boost::any(Sparse(retval.rows_, retval.cols_)); //FIXME: Superfluous? + const Sparse& a = boost::any_cast(lhs.data_); + const Diag& b = boost::any_cast(rhs.data_); + Sparse& c = boost::any_cast(retval.data_); + fastSparseDiagProduct(a, b, c); return retval; } @@ -467,7 +498,11 @@ namespace Opm retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - fastSparseProduct(lhs.s_, rhs.s_, retval.s_); + retval.data_ = boost::any(Sparse(retval.rows_, retval.cols_)); //FIXME: Superfluous? + const Sparse& a = boost::any_cast(lhs.data_); + const Sparse& b = boost::any_cast(rhs.data_); + Sparse& c = boost::any_cast(retval.data_); + fastSparseProduct(a, b, c); return retval; } @@ -483,10 +518,10 @@ namespace Opm s = spdiag(Eigen::VectorXd::Ones(rows_)); return; case D: - s = spdiag(d_); + s = spdiag(boost::any_cast(data_)); return; case S: - s = s_; + s = boost::any_cast(data_); return; } } @@ -512,7 +547,7 @@ namespace Opm case D: return rows_; case S: - return s_.nonZeros(); + return boost::any_cast(data_).nonZeros(); default: OPM_THROW(std::logic_error, "Invalid AutoDiffMatrix type encountered: " << type_); } @@ -527,9 +562,9 @@ namespace Opm case I: return (row == col) ? 1.0 : 0.0; case D: - return (row == col) ? d_[row] : 0.0; + return (row == col) ? boost::any_cast(data_)[row] : 0.0; case S: - return s_.coeff(row, col); + return boost::any_cast(data_).coeff(row, col); default: OPM_THROW(std::logic_error, "Invalid AutoDiffMatrix type encountered: " << type_); } @@ -540,8 +575,11 @@ namespace Opm MatrixType type_; int rows_; int cols_; + boost::any data_; + /* std::vector d_; Eigen::SparseMatrix s_; + */ template static inline