Fixed: stride > 1 and ofs > 0 were not correctly accounted for in the matrix-vector multiply method.
The non-BLAS version of this multiply method was also faulty. Changed: Use const references only for those parameters that are passed into BLAS subroutines (for consistency). Changed: Clear the matrix if a dimension mismatch is detected in the multiply methods that return a reference to the matrix itself (for error handling).
This commit is contained in:
@@ -63,6 +63,35 @@ TEST(TestMatrix, AddRows)
|
||||
}
|
||||
|
||||
|
||||
// Checks the matrix-vector multiply with non-unit strides and non-zero offsets.
// Argument order of multiply: (X, Y, alpha, beta, transA, stridex, stridey, ofsx, ofsy).
TEST(TestMatrix, Multiply)
|
||||
{
|
||||
utl::vector<double> u(14), v(9), x, y;
|
||||
utl::matrix<double> A(3,5);
|
||||
|
||||
// Fill u, v and the storage of A with the consecutive values 1, 2, 3, ...
std::iota(u.begin(),u.end(),1.0);
|
||||
std::iota(v.begin(),v.end(),1.0);
|
||||
std::iota(A.begin(),A.end(),1.0);
|
||||
|
||||
// x = A*u with alpha=1, beta=0, stridex=3, stridey=4, ofsx=1, ofsy=2.
ASSERT_TRUE(A.multiply(u,x,1.0,0.0,false,3,4,1,2));
|
||||
// y = A^T*v with alpha=1, beta=0, stridex=4, stridey=2 and default (zero) offsets.
ASSERT_TRUE(A.multiply(v,y,1.0,0.0,true,4,2));
|
||||
|
||||
// Expected results at the strided output positions
// (operator() on the vectors is presumably 1-based - TODO confirm).
EXPECT_FLOAT_EQ(x(3),370.0);
|
||||
EXPECT_FLOAT_EQ(x(7),410.0);
|
||||
EXPECT_FLOAT_EQ(x(11),450.0);
|
||||
EXPECT_FLOAT_EQ(y(1),38.0);
|
||||
EXPECT_FLOAT_EQ(y(3),83.0);
|
||||
EXPECT_FLOAT_EQ(y(5),128.0);
|
||||
EXPECT_FLOAT_EQ(y(7),173.0);
|
||||
EXPECT_FLOAT_EQ(y(9),218.0);
|
||||
|
||||
// Repeat both products with beta=-1, i.e. new result = product - previous result,
// so the entries of x and y are expected to cancel.
ASSERT_TRUE(A.multiply(u,x,1.0,-1.0,false,3,4,1,2));
|
||||
ASSERT_TRUE(A.multiply(v,y,1.0,-1.0,true,4,2));
|
||||
|
||||
EXPECT_FLOAT_EQ(x.sum(),0.0);
|
||||
EXPECT_FLOAT_EQ(y.sum(),0.0);
|
||||
}
|
||||
|
||||
|
||||
TEST(TestMatrix, Norm)
|
||||
{
|
||||
utl::matrix<double> a(4,5);
|
||||
|
||||
@@ -92,7 +92,7 @@ namespace utl //! General utility classes and functions.
|
||||
}
|
||||
|
||||
//! \brief Fill the vector with a scalar value.
|
||||
void fill(const T& s) { std::fill(this->begin(),this->end(),s); }
|
||||
void fill(T s) { std::fill(this->begin(),this->end(),s); }
|
||||
//! \brief Fill the vector with data from an array.
|
||||
void fill(const T* values, size_t n = 0)
|
||||
{
|
||||
@@ -125,7 +125,7 @@ namespace utl //! General utility classes and functions.
|
||||
//! \brief Subtract the given vector \b X from \a *this.
|
||||
vector<T>& operator-=(const vector<T>& X) { return this->add(X,T(-1)); }
|
||||
//! \brief Add the given vector \b X scaled by \a alfa to \a *this.
|
||||
vector<T>& add(const std::vector<T>& X, T alfa = T(1));
|
||||
vector<T>& add(const std::vector<T>& X, const T& alfa = T(1));
|
||||
|
||||
//! \brief Perform \f${\bf Y} = \alpha{\bf Y} + (1-\alpha){\bf X} \f$
|
||||
//! where \b Y = \a *this.
|
||||
@@ -317,7 +317,7 @@ namespace utl //! General utility classes and functions.
|
||||
void clear() { n[0] = n[1] = n[2] = 0; elem.clear(); }
|
||||
|
||||
//! \brief Fill the matrix with a scalar value.
|
||||
void fill(const T& s) { std::fill(elem.begin(),elem.end(),s); }
|
||||
void fill(T s) { std::fill(elem.begin(),elem.end(),s); }
|
||||
//! \brief Fill the matrix with data from an array.
|
||||
void fill(const T* values, size_t n = 0) { elem.fill(values,n); }
|
||||
|
||||
@@ -519,7 +519,7 @@ namespace utl //! General utility classes and functions.
|
||||
}
|
||||
|
||||
//! \brief Add a scalar multiple of another matrix to a block of the matrix.
|
||||
void addBlock(const matrix<T>& block, const T& s, size_t r, size_t c,
|
||||
void addBlock(const matrix<T>& block, T s, size_t r, size_t c,
|
||||
bool transposed = false)
|
||||
{
|
||||
size_t nr = transposed ? block.cols() : block.rows();
|
||||
@@ -533,7 +533,7 @@ namespace utl //! General utility classes and functions.
|
||||
}
|
||||
|
||||
//! \brief Create a diagonal matrix.
|
||||
matrix<T>& diag(const T& d, size_t dim = 0)
|
||||
matrix<T>& diag(T d, size_t dim = 0)
|
||||
{
|
||||
if (dim > 0)
|
||||
this->resize(dim,dim,true);
|
||||
@@ -586,7 +586,7 @@ namespace utl //! General utility classes and functions.
|
||||
//! \brief Compute the inverse of a square matrix.
|
||||
//! \param[in] tol Division by zero tolerance
|
||||
//! \return Determinant of the matrix
|
||||
T inverse(const T& tol = T(0))
|
||||
T inverse(T tol = T(0))
|
||||
{
|
||||
T det = this->det();
|
||||
if (det == T(-999))
|
||||
@@ -626,7 +626,7 @@ namespace utl //! General utility classes and functions.
|
||||
|
||||
//! \brief Check for symmetry.
|
||||
//! \param[in] tol Comparison tolerance
|
||||
bool isSymmetric(const T& tol = T(0)) const
|
||||
bool isSymmetric(T tol = T(0)) const
|
||||
{
|
||||
if (nrow != ncol) return false;
|
||||
|
||||
@@ -673,7 +673,7 @@ namespace utl //! General utility classes and functions.
|
||||
*/
|
||||
matrix<T>& multiply(const matrix<T>& A, const matrix<T>& B,
|
||||
bool transA = false, bool transB = false,
|
||||
bool addTo = false, T alpha = T(1));
|
||||
bool addTo = false, const T& alpha = T(1));
|
||||
|
||||
/*! \brief Matrix-matrix multiplication.
|
||||
\details Performs one of the following operations (\b C = \a *this):
|
||||
@@ -723,8 +723,8 @@ namespace utl //! General utility classes and functions.
|
||||
-# \f$ {\bf Y} = {\alpha}{\bf A}^T {\bf X} + {\beta}{\bf Y}\f$
|
||||
*/
|
||||
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||
T alpha, T beta, bool transA = false,
|
||||
int stridex = 1, int stridey = 1,
|
||||
const T& alpha, const T& beta = T(0),
|
||||
bool transA = false, int stridex = 1, int stridey = 1,
|
||||
int ofsx = 0, int ofsy = 0) const;
|
||||
|
||||
//! \brief Outer product between two vectors.
|
||||
@@ -1111,7 +1111,8 @@ namespace utl //! General utility classes and functions.
|
||||
}
|
||||
|
||||
template<> inline
|
||||
vector<float>& vector<float>::add(const std::vector<float>& X, float alfa)
|
||||
vector<float>& vector<float>::add(const std::vector<float>& X,
|
||||
const float& alfa)
|
||||
{
|
||||
size_t n = this->size() < X.size() ? this->size() : X.size();
|
||||
cblas_saxpy(n,alfa,&X.front(),1,this->ptr(),1);
|
||||
@@ -1119,7 +1120,8 @@ namespace utl //! General utility classes and functions.
|
||||
}
|
||||
|
||||
template<> inline
|
||||
vector<double>& vector<double>::add(const std::vector<double>& X, double alfa)
|
||||
vector<double>& vector<double>::add(const std::vector<double>& X,
|
||||
const double& alfa)
|
||||
{
|
||||
size_t n = this->size() < X.size() ? this->size() : X.size();
|
||||
cblas_daxpy(n,alfa,&X.front(),1,this->ptr(),1);
|
||||
@@ -1197,11 +1199,18 @@ namespace utl //! General utility classes and functions.
|
||||
template<> inline
|
||||
bool matrix<float>::multiply(const std::vector<float>& X,
|
||||
std::vector<float>& Y,
|
||||
float alpha, float beta, bool transA,
|
||||
int stridex, int stridey, int ofsx,
|
||||
int ofsy) const
|
||||
const float& alpha, const float& beta,
|
||||
bool transA, int stridex, int stridey,
|
||||
int ofsx, int ofsy) const
|
||||
{
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
if (ofsx == 0 && stridex == 1)
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
|
||||
if (beta == 0.0f)
|
||||
{
|
||||
Y.resize(ofsy + 1 + ((transA ? ncol : nrow)-1)*stridey);
|
||||
if (stridey > 0) std::fill(Y.begin(),Y.end(),0.0f);
|
||||
}
|
||||
|
||||
cblas_sgemv(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
@@ -1216,11 +1225,18 @@ namespace utl //! General utility classes and functions.
|
||||
template<> inline
|
||||
bool matrix<double>::multiply(const std::vector<double>& X,
|
||||
std::vector<double>& Y,
|
||||
double alpha, double beta, bool transA,
|
||||
int stridex, int stridey,
|
||||
const double& alpha, const double& beta,
|
||||
bool transA, int stridex, int stridey,
|
||||
int ofsx, int ofsy) const
|
||||
{
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
if (ofsx == 0 && stridex == 1)
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
|
||||
if (beta == 0.0)
|
||||
{
|
||||
Y.resize(ofsy + 1 + ((transA ? ncol : nrow)-1)*stridey);
|
||||
if (stridey > 0) std::fill(Y.begin(),Y.end(),0.0);
|
||||
}
|
||||
|
||||
cblas_dgemv(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
@@ -1236,20 +1252,23 @@ namespace utl //! General utility classes and functions.
|
||||
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
|
||||
const matrix<float>& B,
|
||||
bool transA, bool transB,
|
||||
bool addTo, float alpha)
|
||||
bool addTo, const float& alpha)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
|
||||
if (!addTo) this->resize(M,N);
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K))
|
||||
this->clear();
|
||||
else if (!addTo)
|
||||
this->resize(M,N);
|
||||
|
||||
cblas_sgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
transB ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, alpha,
|
||||
A.ptr(), A.nrow,
|
||||
B.ptr(), B.nrow,
|
||||
addTo ? 1.0f : 0.0f,
|
||||
this->ptr(), nrow);
|
||||
if (!this->empty())
|
||||
cblas_sgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
transB ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, alpha,
|
||||
A.ptr(), A.nrow,
|
||||
B.ptr(), B.nrow,
|
||||
addTo ? 1.0f : 0.0f,
|
||||
this->ptr(), nrow);
|
||||
|
||||
return *this;
|
||||
}
|
||||
@@ -1258,20 +1277,23 @@ namespace utl //! General utility classes and functions.
|
||||
matrix<double>& matrix<double>::multiply(const matrix<double>& A,
|
||||
const matrix<double>& B,
|
||||
bool transA, bool transB,
|
||||
bool addTo, double alpha)
|
||||
bool addTo, const double& alpha)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
|
||||
if (!addTo) this->resize(M,N);
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K))
|
||||
this->clear();
|
||||
else if (!addTo)
|
||||
this->resize(M,N);
|
||||
|
||||
cblas_dgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
transB ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, alpha,
|
||||
A.ptr(), A.nrow,
|
||||
B.ptr(), B.nrow,
|
||||
addTo ? 1.0 : 0.0,
|
||||
this->ptr(), nrow);
|
||||
if (!this->empty())
|
||||
cblas_dgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
transB ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, alpha,
|
||||
A.ptr(), A.nrow,
|
||||
B.ptr(), B.nrow,
|
||||
addTo ? 1.0 : 0.0,
|
||||
this->ptr(), nrow);
|
||||
|
||||
return *this;
|
||||
}
|
||||
@@ -1416,7 +1438,7 @@ namespace utl //! General utility classes and functions.
|
||||
|
||||
template<> inline
|
||||
bool matrix3d<double>::multiplyMat(const std::vector<double>& A,
|
||||
const matrix3d<double>& B, bool addTo)
|
||||
const matrix3d<double>& B, bool addTo)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,M,N,K)) return false;
|
||||
@@ -1546,7 +1568,7 @@ namespace utl //! General utility classes and functions.
|
||||
}
|
||||
|
||||
template<class T> inline
|
||||
vector<T>& vector<T>::add(const std::vector<T>& X, T alfa)
|
||||
vector<T>& vector<T>::add(const std::vector<T>& X, const T& alfa)
|
||||
{
|
||||
T* p = this->ptr();
|
||||
const T* q = &X.front();
|
||||
@@ -1556,7 +1578,7 @@ namespace utl //! General utility classes and functions.
|
||||
}
|
||||
|
||||
template<class T> inline
|
||||
matrixBase<T>& matrixBase<T>::add(const matrixBase<T>& A, T alfa)
|
||||
matrixBase<T>& matrixBase<T>::add(const matrixBase<T>& A, const T& alfa)
|
||||
{
|
||||
T* p = this->ptr();
|
||||
const T* q = A.ptr();
|
||||
@@ -1596,20 +1618,28 @@ namespace utl //! General utility classes and functions.
|
||||
|
||||
template<class T> inline
|
||||
bool matrix<T>::multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||
T alpha, T beta, bool transA,
|
||||
int stridex, int stridey,
|
||||
const T& alpha, const T& beta,
|
||||
bool transA, int stridex, int stridey,
|
||||
int ofsx, int ofsy) const
|
||||
{
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
if (ofsx == 0 && stridex == 1)
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
|
||||
for (size_t i = ofsy; i < Y.size(); i += stridey) {
|
||||
Y[i] *= beta;
|
||||
for (size_t j = ofsx; j < X.size(); j += stridex)
|
||||
if (transA)
|
||||
Y[i] += alpha * THIS(j+1,i+1) * X[j];
|
||||
else
|
||||
Y[i] += alpha * THIS(i+1,j+1) * X[j];
|
||||
if (beta == T(0))
|
||||
{
|
||||
Y.resize(ofsy + 1 + ((transA ? ncol : nrow)-1)*stridey);
|
||||
std::fill(Y.begin(),Y.end(),T(0));
|
||||
}
|
||||
else for (size_t i = ofsy; i < Y.size(); i += stridey)
|
||||
Y[i] *= beta;
|
||||
|
||||
size_t a, b, i, j;
|
||||
for (a = 1, i = ofsy; i < Y.size(); a++, i += stridey)
|
||||
for (b = 1, j = ofsx; j < X.size(); b++, j += stridex)
|
||||
if (transA)
|
||||
Y[i] += alpha * THIS(b,a) * X[j];
|
||||
else
|
||||
Y[i] += alpha * THIS(a,b) * X[j];
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -1618,11 +1648,16 @@ namespace utl //! General utility classes and functions.
|
||||
matrix<T>& matrix<T>::multiply(const matrix<T>& A,
|
||||
const matrix<T>& B,
|
||||
bool transA, bool transB, bool addTo,
|
||||
T alpha)
|
||||
const T& alpha)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
|
||||
if (!addTo) this->resize(M,N,true);
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K))
|
||||
{
|
||||
this->clear();
|
||||
M = 0;
|
||||
}
|
||||
else if (!addTo)
|
||||
this->resize(M,N,true);
|
||||
|
||||
for (size_t i = 1; i <= M; i++)
|
||||
for (size_t j = 1; j <= N; j++)
|
||||
@@ -1749,7 +1784,7 @@ namespace utl //! General utility classes and functions.
|
||||
//! matrices when they contain terms that are numerically zero, except for
|
||||
//! some round-off noise. The value of the global variable \a zero_print_tol
|
||||
//! is used as a tolerance in this method.
|
||||
template<class T> inline T trunc(const T& v)
|
||||
template<class T> inline T trunc(T v)
|
||||
{
|
||||
return v > T(zero_print_tol) || v < T(-zero_print_tol) || std::isnan(v) ?
|
||||
v : T(0);
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
|
||||
#ifdef HAS_BLAS
|
||||
TEST(TestLegendre, GL)
|
||||
{
|
||||
std::vector<Real> w, p;
|
||||
@@ -60,6 +61,7 @@ TEST(TestLegendre, GLL)
|
||||
ASSERT_FLOAT_EQ(p[2], 0.447213595499958);
|
||||
ASSERT_FLOAT_EQ(p[3], 1.0);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
TEST(TestLegendre, Eval)
|
||||
@@ -84,14 +86,15 @@ TEST(TestLegendre, Derivative)
|
||||
}
|
||||
|
||||
|
||||
#ifdef HAS_BLAS
|
||||
TEST(TestLegendre, BasisDerivatives)
|
||||
{
|
||||
utl::matrix<Real> result1, result2;
|
||||
|
||||
Legendre::basisDerivatives(3, result1);
|
||||
const double ref3[3][3] = {{-1.5, 2.0, -0.5},
|
||||
{-0.5, 0.0, 0.5},
|
||||
{ 0.5, -2.0, 1.5}};
|
||||
const Real ref3[3][3] = {{-1.5, 2.0, -0.5},
|
||||
{-0.5, 0.0, 0.5},
|
||||
{ 0.5, -2.0, 1.5}};
|
||||
|
||||
for (size_t i = 0; i < 3; ++i)
|
||||
for (size_t j = 0; j < 3; ++j)
|
||||
@@ -99,16 +102,16 @@ TEST(TestLegendre, BasisDerivatives)
|
||||
|
||||
Legendre::basisDerivatives(10, result2);
|
||||
|
||||
const double ref10[10][10] =
|
||||
{{-22.5, 30.43814502928187, -12.17794670742983, 6.943788485133953,
|
||||
-4.599354761103132, 3.294643033749184, -2.452884175442686, 1.829563931903247,
|
||||
const Real ref10[10][10] =
|
||||
{{-22.5, 30.43814502928187, -12.17794670742983, 6.943788485133953,
|
||||
-4.599354761103132, 3.294643033749184, -2.452884175442686, 1.829563931903247,
|
||||
-1.275954836092663, 0.5},
|
||||
{-5.074064702978062, 0.0, 7.185502869705838, -3.351663862746774,
|
||||
2.078207994036418, -1.444948448751456, 1.059154463645442, -0.783239293137909,
|
||||
0.5437537382357057, -0.2127027580091889},
|
||||
{ 1.203351992852207, -4.259297354965217, 0.0, 4.368674557010181,
|
||||
-2.104350179413155, 1.334915483878251, -0.9366032131394465,
|
||||
0.6767970871960859, -0.4642749589081571, 0.1807865854892499},
|
||||
0.6767970871960859, -0.4642749589081571, 0.1807865854892499},
|
||||
{-0.528369376820273, 1.529902638181603, -3.364125868297819, 0.0,
|
||||
3.387318101202445, -1.64649408398706, 1.046189365502494,
|
||||
-0.7212373127216044, 0.4834623263339481, -0.1866457893937361},
|
||||
@@ -125,7 +128,7 @@ TEST(TestLegendre, BasisDerivatives)
|
||||
0.9366032131394465, -1.334915483878251, 2.104350179413155,
|
||||
-4.368674557010181, 0.0, 4.259297354965217,
|
||||
-1.203351992852207},
|
||||
{0.2127027580091889, -0.5437537382357057, 0.783239293137909,
|
||||
{ 0.2127027580091889, -0.5437537382357057, 0.783239293137909,
|
||||
-1.059154463645442, 1.444948448751456, -2.078207994036418,
|
||||
3.351663862746774, -7.185502869705838, 0.0, 5.074064702978062},
|
||||
{-0.5, 1.275954836092663, -1.829563931903247,
|
||||
@@ -136,3 +139,4 @@ TEST(TestLegendre, BasisDerivatives)
|
||||
for (size_t j = 0; j < 10; ++j)
|
||||
EXPECT_FLOAT_EQ(result2(i+1, j+1), ref10[i][j]);
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user