Added: matrix3d::multiplyMat and matrix3d::fillColumn.
Added: More matrix-vector multiplication operators. Added: Empty virtual destructor in class matrix since it has a parent class.
This commit is contained in:
parent
094fa4a689
commit
86265d9e50
@ -6,7 +6,7 @@
|
||||
//!
|
||||
//! \author Eivind Fonn / SINTEF
|
||||
//!
|
||||
//! \brief Unit tests for matrix
|
||||
//! \brief Unit tests for matrix and matrix3d.
|
||||
//!
|
||||
//==============================================================================
|
||||
|
||||
@ -122,3 +122,22 @@ TEST(TestMatrix3D, DumpRead)
|
||||
B -= A;
|
||||
ASSERT_NEAR(B.norm2(), 0.0, 1.0e-13);
|
||||
}
|
||||
|
||||
|
||||
TEST(TestMatrix3D, Multiply)
|
||||
{
|
||||
std::vector<double> a(10);
|
||||
utl::matrix<double> A(2,5);
|
||||
utl::matrix3d<double> B(5,4,3), C, D;
|
||||
|
||||
std::iota(a.begin(),a.end(),1.0);
|
||||
std::iota(A.begin(),A.end(),1.0);
|
||||
std::iota(B.begin(),B.end(),1.0);
|
||||
|
||||
C.multiply(A,B);
|
||||
ASSERT_TRUE(D.multiplyMat(a,B));
|
||||
|
||||
std::vector<double>::const_iterator c = C.begin();
|
||||
for (double d : D)
|
||||
EXPECT_EQ(d,*(c++));
|
||||
}
|
||||
|
@ -370,6 +370,8 @@ namespace utl //! General utility classes and functions.
|
||||
else
|
||||
this->elem.fill(mat.elem.ptr());
|
||||
}
|
||||
//! \brief Empty destructor.
|
||||
virtual ~matrix() {}
|
||||
|
||||
//! \brief Resize the matrix to dimension \f$r \times c\f$.
|
||||
//! \details Will erase the previous content, but only if both
|
||||
@ -480,7 +482,7 @@ namespace utl //! General utility classes and functions.
|
||||
//! \brief Fill a column of the matrix.
|
||||
void fillColumn(size_t c, const std::vector<T>& data)
|
||||
{
|
||||
CHECK_INDEX("matrix::getColumn: Column-index ",c,ncol);
|
||||
CHECK_INDEX("matrix::fillColumn: Column-index ",c,ncol);
|
||||
size_t ndata = nrow > data.size() ? data.size() : nrow;
|
||||
memcpy(this->ptr(c-1),&data.front(),ndata*sizeof(T));
|
||||
}
|
||||
@ -919,6 +921,15 @@ namespace utl //! General utility classes and functions.
|
||||
return col;
|
||||
}
|
||||
|
||||
//! \brief Fill a column of the matrix.
|
||||
void fillColumn(size_t i2, size_t i3, const std::vector<T>& data)
|
||||
{
|
||||
CHECK_INDEX("matrix3d::fillColumn: Second index ",i2,this->n[1]);
|
||||
CHECK_INDEX("matrix3d::fillColumn: Third index " ,i3,this->n[2]);
|
||||
size_t ndata = this->n[0] > data.size() ? data.size() : this->n[0];
|
||||
memcpy(this->ptr(i2-1+this->n[1]*(i3-1)),&data.front(),ndata*sizeof(T));
|
||||
}
|
||||
|
||||
//! \brief Return the trace of the \a i1'th sub-matrix.
|
||||
T trace(size_t i1) const
|
||||
{
|
||||
@ -952,6 +963,19 @@ namespace utl //! General utility classes and functions.
|
||||
bool multiply(const matrix<T>& A, const matrix3d<T>& B,
|
||||
bool transA = false, bool addTo = false);
|
||||
|
||||
/*! \brief Matrix-matrix multiplication.
|
||||
\details Performs one of the following operations (\b C = \a *this):
|
||||
-# \f$ {\bf C} = {\bf A} {\bf B} \f$
|
||||
-# \f$ {\bf C} = {\bf C} + {\bf A} {\bf B} \f$
|
||||
|
||||
The matrix \b A is here represented by a one-dimensional vector, and its
|
||||
number of columns is assumed to match the first dimension of \b B and its
|
||||
number of rows is then the total vector length divided by
|
||||
the number of columns.
|
||||
*/
|
||||
bool multiplyMat(const std::vector<T>& A, const matrix3d<T>& B,
|
||||
bool addTo = false);
|
||||
|
||||
private:
|
||||
//! \brief Check dimension compatibility for matrix-matrix multiplication.
|
||||
bool compatible(const matrix<T>& A, const matrix3d<T>& B,
|
||||
@ -971,6 +995,23 @@ namespace utl //! General utility classes and functions.
|
||||
return false;
|
||||
}
|
||||
|
||||
//! \brief Check dimension compatibility for matrix-matrix multiplication,
|
||||
//! when the matrix A is represented by a one-dimensional vector.
|
||||
bool compatible(const std::vector<T>& A, const matrix3d<T>& B,
|
||||
size_t& M, size_t& N, size_t& K)
|
||||
{
|
||||
N = B.n[1]*B.n[2];
|
||||
K = B.n[0];
|
||||
M = K > 0 ? A.size() / K : 0;
|
||||
if (M*K == A.size() && !A.empty()) return true;
|
||||
|
||||
std::cerr <<"matrix3d::multiply: Incompatible matrices: A(r*c="<< A.size()
|
||||
<<"), B("<< B.n[0] <<","<< B.n[1] <<","<< B.n[2] <<")\n"
|
||||
<<" when computing C = A * B"<< std::endl;
|
||||
ABORT_ON_INDEX_CHECK;
|
||||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
//! \brief Clears the content if any of the first two dimensions changed.
|
||||
virtual void clearIfNrowChanged(size_t n1, size_t n2)
|
||||
@ -1335,6 +1376,62 @@ namespace utl //! General utility classes and functions.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix3d<double>::multiply(const matrix<double>& A,
|
||||
const matrix3d<double>& B,
|
||||
bool transA, bool addTo)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,transA,M,N,K)) return false;
|
||||
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
||||
|
||||
cblas_dgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans, CblasNoTrans,
|
||||
M, N, K, 1.0,
|
||||
A.ptr(), A.rows(),
|
||||
B.ptr(), B.n[0],
|
||||
addTo ? 1.0 : 0.0,
|
||||
this->ptr(), n[0]);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix3d<float>::multiplyMat(const std::vector<float>& A,
|
||||
const matrix3d<float>& B, bool addTo)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,M,N,K)) return false;
|
||||
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
||||
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
|
||||
M, N, K, 1.0f,
|
||||
A.data(), M,
|
||||
B.ptr(), B.n[0],
|
||||
addTo ? 1.0f : 0.0f,
|
||||
this->ptr(), n[0]);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix3d<double>::multiplyMat(const std::vector<double>& A,
|
||||
const matrix3d<double>& B, bool addTo)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,M,N,K)) return false;
|
||||
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
||||
|
||||
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
|
||||
M, N, K, 1.0,
|
||||
A.data(), M,
|
||||
B.ptr(), B.n[0],
|
||||
addTo ? 1.0 : 0.0,
|
||||
this->ptr(), n[0]);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix<float>::outer_product(const std::vector<float>& X,
|
||||
const std::vector<float>& Y,
|
||||
@ -1377,26 +1474,6 @@ namespace utl //! General utility classes and functions.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix3d<double>::multiply(const matrix<double>& A,
|
||||
const matrix3d<double>& B,
|
||||
bool transA, bool addTo)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,transA,M,N,K)) return false;
|
||||
if (!addTo) this->resize(M,B.n[1],B.n[2]);
|
||||
|
||||
cblas_dgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans, CblasNoTrans,
|
||||
M, N, K, 1.0,
|
||||
A.ptr(), A.rows(),
|
||||
B.ptr(), B.n[0],
|
||||
addTo ? 1.0 : 0.0,
|
||||
this->ptr(), n[0]);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
//============================================================================
|
||||
//=== Non-BLAS inlined implementations (slow...) =========================
|
||||
@ -1641,6 +1718,23 @@ namespace utl //! General utility classes and functions.
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class T> inline
|
||||
bool matrix3d<T>::multiplyMat(const std::vector<T>& A, const matrix3d<T>& B,
|
||||
bool addTo)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,M,N,K)) return false;
|
||||
if (!addTo) this->resize(M,B.n[1],B.n[2],true);
|
||||
|
||||
for (size_t i = 1; i <= this->n[0]; i++)
|
||||
for (size_t j = 1; j <= this->n[1]; j++)
|
||||
for (size_t k = 1; k <= K; k++)
|
||||
for (size_t l = 1; l <= this->n[2]; l++)
|
||||
this->operator()(i,j,l) += A[i-1+M*(k-1)]*B(k,j,l);
|
||||
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
//============================================================================
|
||||
|
@ -35,6 +35,35 @@ Vec3 operator* (const utl::matrix<Real>& A, const std::vector<Real>& x)
|
||||
}
|
||||
|
||||
|
||||
Vec3 operator* (const utl::matrix<Real>& A, const Vec3& x)
|
||||
{
|
||||
return A * x.vec();
|
||||
}
|
||||
|
||||
|
||||
/*!
|
||||
\brief Multiplication of a vector and a matrix,
|
||||
\return \f$ {\bf y} = {\bf A}^T {\bf x} \f$
|
||||
|
||||
Special version for computation of 3D point vectors.
|
||||
The number of columns in \b A must be (at least) 3.
|
||||
Extra columns (if any) in \b A are ignored.
|
||||
*/
|
||||
|
||||
Vec3 operator* (const std::vector<Real>& x, const utl::matrix<Real>& A)
|
||||
{
|
||||
std::vector<Real> y(A.cols(),0.0);
|
||||
A.multiply(x,y,true);
|
||||
return Vec3(y);
|
||||
}
|
||||
|
||||
|
||||
Vec3 operator* (const Vec3& x, const utl::matrix<Real>& A)
|
||||
{
|
||||
return x.vec() * A;
|
||||
}
|
||||
|
||||
|
||||
Vec3 operator* (const Vec3& a, Real value)
|
||||
{
|
||||
return Vec3(a.x*value, a.y*value, a.z*value);
|
||||
|
@ -21,6 +21,13 @@ class Vec3;
|
||||
|
||||
//! \brief Multiplication of a matrix and a vector.
|
||||
Vec3 operator*(const utl::matrix<Real>& A, const std::vector<Real>& x);
|
||||
//! \brief Multiplication of a matrix and a vector.
|
||||
Vec3 operator*(const utl::matrix<Real>& A, const Vec3& x);
|
||||
|
||||
//! \brief Multiplication of a vector and a matrix.
|
||||
Vec3 operator*(const std::vector<Real>& x, const utl::matrix<Real>& A);
|
||||
//! \brief Multiplication of a vector and a matrix.
|
||||
Vec3 operator*(const Vec3& x, const utl::matrix<Real>& A);
|
||||
|
||||
//! \brief Multiplication of a vector and a scalar.
|
||||
Vec3 operator*(const Vec3& a, Real value);
|
||||
|
Loading…
Reference in New Issue
Block a user