diff --git a/src/LinAlg/Test/TestMatrix.C b/src/LinAlg/Test/TestMatrix.C index e9921da8..a4ca8e13 100644 --- a/src/LinAlg/Test/TestMatrix.C +++ b/src/LinAlg/Test/TestMatrix.C @@ -6,7 +6,7 @@ //! //! \author Eivind Fonn / SINTEF //! -//! \brief Unit tests for matrix +//! \brief Unit tests for matrix and matrix3d. //! //============================================================================== @@ -122,3 +122,22 @@ TEST(TestMatrix3D, DumpRead) B -= A; ASSERT_NEAR(B.norm2(), 0.0, 1.0e-13); } + + +TEST(TestMatrix3D, Multiply) +{ + std::vector a(10); + utl::matrix A(2,5); + utl::matrix3d B(5,4,3), C, D; + + std::iota(a.begin(),a.end(),1.0); + std::iota(A.begin(),A.end(),1.0); + std::iota(B.begin(),B.end(),1.0); + + C.multiply(A,B); + ASSERT_TRUE(D.multiplyMat(a,B)); + + std::vector::const_iterator c = C.begin(); + for (double d : D) + EXPECT_EQ(d,*(c++)); +} diff --git a/src/LinAlg/matrix.h b/src/LinAlg/matrix.h index 8ff71b37..ddba3922 100644 --- a/src/LinAlg/matrix.h +++ b/src/LinAlg/matrix.h @@ -370,6 +370,8 @@ namespace utl //! General utility classes and functions. else this->elem.fill(mat.elem.ptr()); } + //! \brief Empty destructor. + virtual ~matrix() {} //! \brief Resize the matrix to dimension \f$r \times c\f$. //! \details Will erase the previous content, but only if both @@ -480,7 +482,7 @@ namespace utl //! General utility classes and functions. //! \brief Fill a column of the matrix. void fillColumn(size_t c, const std::vector& data) { - CHECK_INDEX("matrix::getColumn: Column-index ",c,ncol); + CHECK_INDEX("matrix::fillColumn: Column-index ",c,ncol); size_t ndata = nrow > data.size() ? data.size() : nrow; memcpy(this->ptr(c-1),&data.front(),ndata*sizeof(T)); } @@ -912,13 +914,22 @@ namespace utl //! General utility classes and functions. vector getColumn(size_t i2, size_t i3) const { CHECK_INDEX("matrix3d::getColumn: Second index ",i2,this->n[1]); - CHECK_INDEX("matrix3d::getColumn: Third index ",i3,this->n[2]); + CHECK_INDEX("matrix3d::getColumn: Third index " ,i3,this->n[2]); if (this->n[1] < 2 && this->n[2] < 2) return this->elem; vector col(this->n[0]); memcpy(col.ptr(),this->ptr(i2-1+this->n[1]*(i3-1)),col.size()*sizeof(T)); return col; } + //! \brief Fill a column of the matrix. + void fillColumn(size_t i2, size_t i3, const std::vector& data) + { + CHECK_INDEX("matrix3d::fillColumn: Second index ",i2,this->n[1]); + CHECK_INDEX("matrix3d::fillColumn: Third index " ,i3,this->n[2]); + size_t ndata = this->n[0] > data.size() ? data.size() : this->n[0]; + memcpy(this->ptr(i2-1+this->n[1]*(i3-1)),&data.front(),ndata*sizeof(T)); + } + //! \brief Return the trace of the \a i1'th sub-matrix. T trace(size_t i1) const { @@ -952,6 +963,19 @@ namespace utl //! General utility classes and functions. bool multiply(const matrix& A, const matrix3d& B, bool transA = false, bool addTo = false); + /*! \brief Matrix-matrix multiplication. + \details Performs one of the following operations (\b C = \a *this): + -# \f$ {\bf C} = {\bf A} {\bf B} \f$ + -# \f$ {\bf C} = {\bf C} + {\bf A} {\bf B} \f$ + + The matrix \b A is here represented by a one-dimensional vector, and its + number of columns is assumed to match the first dimension of \b B and its + number of rows is then the total vector length divided by + the number of columns. + */ + bool multiplyMat(const std::vector& A, const matrix3d& B, + bool addTo = false); + private: //! \brief Check dimension compatibility for matrix-matrix multiplication. bool compatible(const matrix& A, const matrix3d& B, @@ -971,6 +995,23 @@ namespace utl //! General utility classes and functions. return false; } + //! \brief Check dimension compatibility for matrix-matrix multiplication, + //! when the matrix A is represented by a one-dimensional vector. + bool compatible(const std::vector& A, const matrix3d& B, + size_t& M, size_t& N, size_t& K) + { + N = B.n[1]*B.n[2]; + K = B.n[0]; + M = K > 0 ? A.size() / K : 0; + if (M*K == A.size() && !A.empty()) return true; + + std::cerr <<"matrix3d::multiply: Incompatible matrices: A(r*c="<< A.size() + <<"), B("<< B.n[0] <<","<< B.n[1] <<","<< B.n[2] <<")\n" + <<" when computing C = A * B"<< std::endl; + ABORT_ON_INDEX_CHECK; + return false; + } + protected: //! \brief Clears the content if any of the first two dimensions changed. virtual void clearIfNrowChanged(size_t n1, size_t n2) @@ -1335,6 +1376,62 @@ namespace utl //! General utility classes and functions. return true; } + template<> inline + bool matrix3d::multiply(const matrix& A, + const matrix3d& B, + bool transA, bool addTo) + { + size_t M, N, K; + if (!this->compatible(A,B,transA,M,N,K)) return false; + if (!addTo) this->resize(M,B.n[1],B.n[2]); + + cblas_dgemm(CblasColMajor, + transA ? CblasTrans : CblasNoTrans, CblasNoTrans, + M, N, K, 1.0, + A.ptr(), A.rows(), + B.ptr(), B.n[0], + addTo ? 1.0 : 0.0, + this->ptr(), n[0]); + + return true; + } + + template<> inline + bool matrix3d::multiplyMat(const std::vector& A, + const matrix3d& B, bool addTo) + { + size_t M, N, K; + if (!this->compatible(A,B,M,N,K)) return false; + if (!addTo) this->resize(M,B.n[1],B.n[2]); + + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, + M, N, K, 1.0f, + A.data(), M, + B.ptr(), B.n[0], + addTo ? 1.0f : 0.0f, + this->ptr(), n[0]); + + return true; + } + + template<> inline + bool matrix3d::multiplyMat(const std::vector& A, + const matrix3d& B, bool addTo) + { + size_t M, N, K; + if (!this->compatible(A,B,M,N,K)) return false; + if (!addTo) this->resize(M,B.n[1],B.n[2]); + + cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, + M, N, K, 1.0, + A.data(), M, + B.ptr(), B.n[0], + addTo ? 1.0 : 0.0, + this->ptr(), n[0]); + + return true; + } + template<> inline bool matrix::outer_product(const std::vector& X, const std::vector& Y, @@ -1377,26 +1474,6 @@ namespace utl //! General utility classes and functions. return true; } - template<> inline - bool matrix3d::multiply(const matrix& A, - const matrix3d& B, - bool transA, bool addTo) - { - size_t M, N, K; - if (!this->compatible(A,B,transA,M,N,K)) return false; - if (!addTo) this->resize(M,B.n[1],B.n[2]); - - cblas_dgemm(CblasColMajor, - transA ? CblasTrans : CblasNoTrans, CblasNoTrans, - M, N, K, 1.0, - A.ptr(), A.rows(), - B.ptr(), B.n[0], - addTo ? 1.0 : 0.0, - this->ptr(), n[0]); - - return true; - } - #else //============================================================================ //=== Non-BLAS inlined implementations (slow...) ========================= @@ -1641,6 +1718,23 @@ namespace utl //! General utility classes and functions. return true; } + + template inline + bool matrix3d::multiplyMat(const std::vector& A, const matrix3d& B, + bool addTo) + { + size_t M, N, K; + if (!this->compatible(A,B,M,N,K)) return false; + if (!addTo) this->resize(M,B.n[1],B.n[2],true); + + for (size_t i = 1; i <= this->n[0]; i++) + for (size_t j = 1; j <= this->n[1]; j++) + for (size_t k = 1; k <= K; k++) + for (size_t l = 1; l <= this->n[2]; l++) + this->operator()(i,j,l) += A[i-1+M*(k-1)]*B(k,j,l); + + return true; + } #endif //============================================================================ diff --git a/src/Utility/Vec3Oper.C b/src/Utility/Vec3Oper.C index 58f6d718..ce69fead 100644 --- a/src/Utility/Vec3Oper.C +++ b/src/Utility/Vec3Oper.C @@ -35,6 +35,35 @@ Vec3 operator* (const utl::matrix& A, const std::vector& x) } +Vec3 operator* (const utl::matrix& A, const Vec3& x) +{ + return A * x.vec(); +} + + +/*! + \brief Multiplication of a vector and a matrix, + \return \f$ {\bf y} = {\bf A}^T {\bf x} \f$ + + Special version for computation of 3D point vectors. + The number of columns in \b A must be (at least) 3. + Extra columns (if any) in \b A are ignored. +*/ + +Vec3 operator* (const std::vector& x, const utl::matrix& A) +{ + std::vector y(A.cols(),0.0); + A.multiply(x,y,true); + return Vec3(y); +} + + +Vec3 operator* (const Vec3& x, const utl::matrix& A) +{ + return x.vec() * A; +} + + Vec3 operator* (const Vec3& a, Real value) { return Vec3(a.x*value, a.y*value, a.z*value); diff --git a/src/Utility/Vec3Oper.h b/src/Utility/Vec3Oper.h index c360cfb9..382c6e64 100644 --- a/src/Utility/Vec3Oper.h +++ b/src/Utility/Vec3Oper.h @@ -21,6 +21,13 @@ class Vec3; //! \brief Multiplication of a matrix and a vector. Vec3 operator*(const utl::matrix& A, const std::vector& x); +//! \brief Multiplication of a matrix and a vector. +Vec3 operator*(const utl::matrix& A, const Vec3& x); + +//! \brief Multiplication of a vector and a matrix. +Vec3 operator*(const std::vector& x, const utl::matrix& A); +//! \brief Multiplication of a vector and a matrix. +Vec3 operator*(const Vec3& x, const utl::matrix& A); //! \brief Multiplication of a vector and a scalar. Vec3 operator*(const Vec3& a, Real value);