Added: matrix3d::multiplyMat and matrix3d::fillColumn.
Added: More matrix-vector multiplication operators. Added: Empty virtual destructor in class matrix since it has a parent class.
This commit is contained in:
parent
094fa4a689
commit
86265d9e50
@ -6,7 +6,7 @@
|
|||||||
//!
|
//!
|
||||||
//! \author Eivind Fonn / SINTEF
|
//! \author Eivind Fonn / SINTEF
|
||||||
//!
|
//!
|
||||||
//! \brief Unit tests for matrix
|
//! \brief Unit tests for matrix and matrix3d.
|
||||||
//!
|
//!
|
||||||
//==============================================================================
|
//==============================================================================
|
||||||
|
|
||||||
@ -122,3 +122,22 @@ TEST(TestMatrix3D, DumpRead)
|
|||||||
B -= A;
|
B -= A;
|
||||||
ASSERT_NEAR(B.norm2(), 0.0, 1.0e-13);
|
ASSERT_NEAR(B.norm2(), 0.0, 1.0e-13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST(TestMatrix3D, Multiply)
|
||||||
|
{
|
||||||
|
std::vector<double> a(10);
|
||||||
|
utl::matrix<double> A(2,5);
|
||||||
|
utl::matrix3d<double> B(5,4,3), C, D;
|
||||||
|
|
||||||
|
std::iota(a.begin(),a.end(),1.0);
|
||||||
|
std::iota(A.begin(),A.end(),1.0);
|
||||||
|
std::iota(B.begin(),B.end(),1.0);
|
||||||
|
|
||||||
|
C.multiply(A,B);
|
||||||
|
ASSERT_TRUE(D.multiplyMat(a,B));
|
||||||
|
|
||||||
|
std::vector<double>::const_iterator c = C.begin();
|
||||||
|
for (double d : D)
|
||||||
|
EXPECT_EQ(d,*(c++));
|
||||||
|
}
|
||||||
|
@ -370,6 +370,8 @@ namespace utl //! General utility classes and functions.
|
|||||||
else
|
else
|
||||||
this->elem.fill(mat.elem.ptr());
|
this->elem.fill(mat.elem.ptr());
|
||||||
}
|
}
|
||||||
|
//! \brief Empty destructor.
|
||||||
|
virtual ~matrix() {}
|
||||||
|
|
||||||
//! \brief Resize the matrix to dimension \f$r \times c\f$.
|
//! \brief Resize the matrix to dimension \f$r \times c\f$.
|
||||||
//! \details Will erase the previous content, but only if both
|
//! \details Will erase the previous content, but only if both
|
||||||
@ -480,7 +482,7 @@ namespace utl //! General utility classes and functions.
|
|||||||
//! \brief Fill a column of the matrix.
|
//! \brief Fill a column of the matrix.
|
||||||
void fillColumn(size_t c, const std::vector<T>& data)
|
void fillColumn(size_t c, const std::vector<T>& data)
|
||||||
{
|
{
|
||||||
CHECK_INDEX("matrix::getColumn: Column-index ",c,ncol);
|
CHECK_INDEX("matrix::fillColumn: Column-index ",c,ncol);
|
||||||
size_t ndata = nrow > data.size() ? data.size() : nrow;
|
size_t ndata = nrow > data.size() ? data.size() : nrow;
|
||||||
memcpy(this->ptr(c-1),&data.front(),ndata*sizeof(T));
|
memcpy(this->ptr(c-1),&data.front(),ndata*sizeof(T));
|
||||||
}
|
}
|
||||||
@ -912,13 +914,22 @@ namespace utl //! General utility classes and functions.
|
|||||||
vector<T> getColumn(size_t i2, size_t i3) const
|
vector<T> getColumn(size_t i2, size_t i3) const
|
||||||
{
|
{
|
||||||
CHECK_INDEX("matrix3d::getColumn: Second index ",i2,this->n[1]);
|
CHECK_INDEX("matrix3d::getColumn: Second index ",i2,this->n[1]);
|
||||||
CHECK_INDEX("matrix3d::getColumn: Third index ",i3,this->n[2]);
|
CHECK_INDEX("matrix3d::getColumn: Third index " ,i3,this->n[2]);
|
||||||
if (this->n[1] < 2 && this->n[2] < 2) return this->elem;
|
if (this->n[1] < 2 && this->n[2] < 2) return this->elem;
|
||||||
vector<T> col(this->n[0]);
|
vector<T> col(this->n[0]);
|
||||||
memcpy(col.ptr(),this->ptr(i2-1+this->n[1]*(i3-1)),col.size()*sizeof(T));
|
memcpy(col.ptr(),this->ptr(i2-1+this->n[1]*(i3-1)),col.size()*sizeof(T));
|
||||||
return col;
|
return col;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//! \brief Fill a column of the matrix.
|
||||||
|
void fillColumn(size_t i2, size_t i3, const std::vector<T>& data)
|
||||||
|
{
|
||||||
|
CHECK_INDEX("matrix3d::fillColumn: Second index ",i2,this->n[1]);
|
||||||
|
CHECK_INDEX("matrix3d::fillColumn: Third index " ,i3,this->n[2]);
|
||||||
|
size_t ndata = this->n[0] > data.size() ? data.size() : this->n[0];
|
||||||
|
memcpy(this->ptr(i2-1+this->n[1]*(i3-1)),&data.front(),ndata*sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
//! \brief Return the trace of the \a i1'th sub-matrix.
|
//! \brief Return the trace of the \a i1'th sub-matrix.
|
||||||
T trace(size_t i1) const
|
T trace(size_t i1) const
|
||||||
{
|
{
|
||||||
@ -952,6 +963,19 @@ namespace utl //! General utility classes and functions.
|
|||||||
bool multiply(const matrix<T>& A, const matrix3d<T>& B,
|
bool multiply(const matrix<T>& A, const matrix3d<T>& B,
|
||||||
bool transA = false, bool addTo = false);
|
bool transA = false, bool addTo = false);
|
||||||
|
|
||||||
|
/*! \brief Matrix-matrix multiplication.
|
||||||
|
\details Performs one of the following operations (\b C = \a *this):
|
||||||
|
-# \f$ {\bf C} = {\bf A} {\bf B} \f$
|
||||||
|
-# \f$ {\bf C} = {\bf C} + {\bf A} {\bf B} \f$
|
||||||
|
|
||||||
|
The matrix \b A is here represented by a one-dimensional vector, and its
|
||||||
|
number of columns is assumed to match the first dimension of \b B and its
|
||||||
|
number of rows is then the total vector length divided by
|
||||||
|
the number of columns.
|
||||||
|
*/
|
||||||
|
bool multiplyMat(const std::vector<T>& A, const matrix3d<T>& B,
|
||||||
|
bool addTo = false);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//! \brief Check dimension compatibility for matrix-matrix multiplication.
|
//! \brief Check dimension compatibility for matrix-matrix multiplication.
|
||||||
bool compatible(const matrix<T>& A, const matrix3d<T>& B,
|
bool compatible(const matrix<T>& A, const matrix3d<T>& B,
|
||||||
@ -971,6 +995,23 @@ namespace utl //! General utility classes and functions.
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//! \brief Check dimension compatibility for matrix-matrix multiplication,
|
||||||
|
//! when the matrix A is represented by a one-dimensional vector.
|
||||||
|
bool compatible(const std::vector<T>& A, const matrix3d<T>& B,
|
||||||
|
size_t& M, size_t& N, size_t& K)
|
||||||
|
{
|
||||||
|
N = B.n[1]*B.n[2];
|
||||||
|
K = B.n[0];
|
||||||
|
M = K > 0 ? A.size() / K : 0;
|
||||||
|
if (M*K == A.size() && !A.empty()) return true;
|
||||||
|
|
||||||
|
std::cerr <<"matrix3d::multiply: Incompatible matrices: A(r*c="<< A.size()
|
||||||
|
<<"), B("<< B.n[0] <<","<< B.n[1] <<","<< B.n[2] <<")\n"
|
||||||
|
<<" when computing C = A * B"<< std::endl;
|
||||||
|
ABORT_ON_INDEX_CHECK;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
//! \brief Clears the content if any of the first two dimensions changed.
|
//! \brief Clears the content if any of the first two dimensions changed.
|
||||||
virtual void clearIfNrowChanged(size_t n1, size_t n2)
|
virtual void clearIfNrowChanged(size_t n1, size_t n2)
|
||||||
@ -1335,6 +1376,62 @@ namespace utl //! General utility classes and functions.
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> inline
|
||||||
|
bool matrix3d<double>::multiply(const matrix<double>& A,
|
||||||
|
const matrix3d<double>& B,
|
||||||
|
bool transA, bool addTo)
|
||||||
|
{
|
||||||
|
size_t M, N, K;
|
||||||
|
if (!this->compatible(A,B,transA,M,N,K)) return false;
|
||||||
|
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
||||||
|
|
||||||
|
cblas_dgemm(CblasColMajor,
|
||||||
|
transA ? CblasTrans : CblasNoTrans, CblasNoTrans,
|
||||||
|
M, N, K, 1.0,
|
||||||
|
A.ptr(), A.rows(),
|
||||||
|
B.ptr(), B.n[0],
|
||||||
|
addTo ? 1.0 : 0.0,
|
||||||
|
this->ptr(), n[0]);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> inline
|
||||||
|
bool matrix3d<float>::multiplyMat(const std::vector<float>& A,
|
||||||
|
const matrix3d<float>& B, bool addTo)
|
||||||
|
{
|
||||||
|
size_t M, N, K;
|
||||||
|
if (!this->compatible(A,B,M,N,K)) return false;
|
||||||
|
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
||||||
|
|
||||||
|
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
|
||||||
|
M, N, K, 1.0f,
|
||||||
|
A.data(), M,
|
||||||
|
B.ptr(), B.n[0],
|
||||||
|
addTo ? 1.0f : 0.0f,
|
||||||
|
this->ptr(), n[0]);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> inline
|
||||||
|
bool matrix3d<double>::multiplyMat(const std::vector<double>& A,
|
||||||
|
const matrix3d<double>& B, bool addTo)
|
||||||
|
{
|
||||||
|
size_t M, N, K;
|
||||||
|
if (!this->compatible(A,B,M,N,K)) return false;
|
||||||
|
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
||||||
|
|
||||||
|
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
|
||||||
|
M, N, K, 1.0,
|
||||||
|
A.data(), M,
|
||||||
|
B.ptr(), B.n[0],
|
||||||
|
addTo ? 1.0 : 0.0,
|
||||||
|
this->ptr(), n[0]);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template<> inline
|
template<> inline
|
||||||
bool matrix<float>::outer_product(const std::vector<float>& X,
|
bool matrix<float>::outer_product(const std::vector<float>& X,
|
||||||
const std::vector<float>& Y,
|
const std::vector<float>& Y,
|
||||||
@ -1377,26 +1474,6 @@ namespace utl //! General utility classes and functions.
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> inline
|
|
||||||
bool matrix3d<double>::multiply(const matrix<double>& A,
|
|
||||||
const matrix3d<double>& B,
|
|
||||||
bool transA, bool addTo)
|
|
||||||
{
|
|
||||||
size_t M, N, K;
|
|
||||||
if (!this->compatible(A,B,transA,M,N,K)) return false;
|
|
||||||
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
|
||||||
|
|
||||||
cblas_dgemm(CblasColMajor,
|
|
||||||
transA ? CblasTrans : CblasNoTrans, CblasNoTrans,
|
|
||||||
M, N, K, 1.0,
|
|
||||||
A.ptr(), A.rows(),
|
|
||||||
B.ptr(), B.n[0],
|
|
||||||
addTo ? 1.0 : 0.0,
|
|
||||||
this->ptr(), n[0]);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
#else
|
||||||
//============================================================================
|
//============================================================================
|
||||||
//=== Non-BLAS inlined implementations (slow...) =========================
|
//=== Non-BLAS inlined implementations (slow...) =========================
|
||||||
@ -1641,6 +1718,23 @@ namespace utl //! General utility classes and functions.
|
|||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<class T> inline
|
||||||
|
bool matrix3d<T>::multiplyMat(const std::vector<T>& A, const matrix3d<T>& B,
|
||||||
|
bool addTo)
|
||||||
|
{
|
||||||
|
size_t M, N, K;
|
||||||
|
if (!this->compatible(A,B,M,N,K)) return false;
|
||||||
|
if (!addTo) this->resize(M,B.n[1],B.n[2],true);
|
||||||
|
|
||||||
|
for (size_t i = 1; i <= this->n[0]; i++)
|
||||||
|
for (size_t j = 1; j <= this->n[1]; j++)
|
||||||
|
for (size_t k = 1; k <= K; k++)
|
||||||
|
for (size_t l = 1; l <= this->n[2]; l++)
|
||||||
|
this->operator()(i,j,l) += A[i-1+M*(k-1)]*B(k,j,l);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
//============================================================================
|
//============================================================================
|
||||||
|
@ -35,6 +35,35 @@ Vec3 operator* (const utl::matrix<Real>& A, const std::vector<Real>& x)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Vec3 operator* (const utl::matrix<Real>& A, const Vec3& x)
|
||||||
|
{
|
||||||
|
return A * x.vec();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*!
|
||||||
|
\brief Multiplication of a vector and a matrix,
|
||||||
|
\return \f$ {\bf y} = {\bf A}^T {\bf x} \f$
|
||||||
|
|
||||||
|
Special version for computation of 3D point vectors.
|
||||||
|
The number of columns in \b A must be (at least) 3.
|
||||||
|
Extra columns (if any) in \b A are ignored.
|
||||||
|
*/
|
||||||
|
|
||||||
|
Vec3 operator* (const std::vector<Real>& x, const utl::matrix<Real>& A)
|
||||||
|
{
|
||||||
|
std::vector<Real> y(A.cols(),0.0);
|
||||||
|
A.multiply(x,y,true);
|
||||||
|
return Vec3(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Vec3 operator* (const Vec3& x, const utl::matrix<Real>& A)
|
||||||
|
{
|
||||||
|
return x.vec() * A;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Vec3 operator* (const Vec3& a, Real value)
|
Vec3 operator* (const Vec3& a, Real value)
|
||||||
{
|
{
|
||||||
return Vec3(a.x*value, a.y*value, a.z*value);
|
return Vec3(a.x*value, a.y*value, a.z*value);
|
||||||
|
@ -21,6 +21,13 @@ class Vec3;
|
|||||||
|
|
||||||
//! \brief Multiplication of a matrix and a vector.
|
//! \brief Multiplication of a matrix and a vector.
|
||||||
Vec3 operator*(const utl::matrix<Real>& A, const std::vector<Real>& x);
|
Vec3 operator*(const utl::matrix<Real>& A, const std::vector<Real>& x);
|
||||||
|
//! \brief Multiplication of a matrix and a vector.
|
||||||
|
Vec3 operator*(const utl::matrix<Real>& A, const Vec3& x);
|
||||||
|
|
||||||
|
//! \brief Multiplication of a vector and a matrix.
|
||||||
|
Vec3 operator*(const std::vector<Real>& x, const utl::matrix<Real>& A);
|
||||||
|
//! \brief Multiplication of a vector and a matrix.
|
||||||
|
Vec3 operator*(const Vec3& x, const utl::matrix<Real>& A);
|
||||||
|
|
||||||
//! \brief Multiplication of a vector and a scalar.
|
//! \brief Multiplication of a vector and a scalar.
|
||||||
Vec3 operator*(const Vec3& a, Real value);
|
Vec3 operator*(const Vec3& a, Real value);
|
||||||
|
Loading…
Reference in New Issue
Block a user