From c712d0070df61bc91fb3202b365d21d7400b8179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Atgeirr=20Fl=C3=B8=20Rasmussen?= Date: Tue, 25 Aug 2015 16:15:13 +0200 Subject: [PATCH] Bugfixes for AutoDiffMatrix. --- opm/autodiff/AutoDiffMatrix.hpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/opm/autodiff/AutoDiffMatrix.hpp b/opm/autodiff/AutoDiffMatrix.hpp index 314d7bdc4..330e22e6c 100644 --- a/opm/autodiff/AutoDiffMatrix.hpp +++ b/opm/autodiff/AutoDiffMatrix.hpp @@ -87,6 +87,8 @@ namespace Opm AutoDiffMatrix operator+(const AutoDiffMatrix& rhs) const { + assert(rows_ == rhs.rows_); + assert(cols_ == rhs.cols_); switch (type_) { case Z: return rhs; @@ -128,9 +130,10 @@ namespace Opm AutoDiffMatrix operator*(const AutoDiffMatrix& rhs) const { + assert(cols_ == rhs.rows_); switch (type_) { case Z: - return *this; + return AutoDiffMatrix(rows_, rhs.cols_); case I: switch (rhs.type_) { case Z: @@ -145,7 +148,7 @@ namespace Opm case D: switch (rhs.type_) { case Z: - return rhs; + return AutoDiffMatrix(rows_, rhs.cols_); case I: return *this; case D: @@ -156,7 +159,7 @@ namespace Opm case S: switch (rhs.type_) { case Z: - return rhs; + return AutoDiffMatrix(rows_, rhs.cols_); case I: return *this; case D: @@ -336,14 +339,14 @@ namespace Opm static AutoDiffMatrix prodDS(const AutoDiffMatrix& lhs, const AutoDiffMatrix& rhs) { - assert(lhs.type_ == S); - assert(rhs.type_ == D); + assert(lhs.type_ == D); + assert(rhs.type_ == S); AutoDiffMatrix retval; - Eigen::SparseMatrix diag = spdiag(rhs.d_.diagonal()); + Eigen::SparseMatrix diag = spdiag(lhs.d_.diagonal()); retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - retval.s_ = lhs.s_ * diag; + fastSparseProduct(diag, rhs.s_, retval.s_); return retval; } @@ -356,7 +359,7 @@ namespace Opm retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - retval.s_ = diag * lhs.s_; + fastSparseProduct(lhs.s_, diag, retval.s_); return retval; } @@ -368,7 +371,6 @@ namespace Opm retval.type_ = S; retval.rows_ = lhs.rows_; retval.cols_ = rhs.cols_; - // retval.s_ = lhs.s_ * rhs.s_; fastSparseProduct(lhs.s_, rhs.s_, retval.s_); return retval; }