Fixed: stride > 1 and ofs > 0 was not correctly accounted for.

Also the non-BLAS version of this multiply method was faulty.
Changed: Use const reference only for those parameters that
are passed into BLAS subroutines (consistency).
Changed: Clear the matrix if dimension mismatch is detected in the multiply
methods returning the reference to the matrix itself (for error handling).
This commit is contained in:
Knut Morten Okstad
2018-09-03 21:10:15 +02:00
parent 7fc18849ff
commit 3a6ef1d033
3 changed files with 134 additions and 66 deletions

View File

@@ -63,6 +63,35 @@ TEST(TestMatrix, AddRows)
}
TEST(TestMatrix, Multiply)
{
utl::vector<double> u(14), v(9), x, y;
utl::matrix<double> A(3,5);
std::iota(u.begin(),u.end(),1.0);
std::iota(v.begin(),v.end(),1.0);
std::iota(A.begin(),A.end(),1.0);
ASSERT_TRUE(A.multiply(u,x,1.0,0.0,false,3,4,1,2));
ASSERT_TRUE(A.multiply(v,y,1.0,0.0,true,4,2));
EXPECT_FLOAT_EQ(x(3),370.0);
EXPECT_FLOAT_EQ(x(7),410.0);
EXPECT_FLOAT_EQ(x(11),450.0);
EXPECT_FLOAT_EQ(y(1),38.0);
EXPECT_FLOAT_EQ(y(3),83.0);
EXPECT_FLOAT_EQ(y(5),128.0);
EXPECT_FLOAT_EQ(y(7),173.0);
EXPECT_FLOAT_EQ(y(9),218.0);
ASSERT_TRUE(A.multiply(u,x,1.0,-1.0,false,3,4,1,2));
ASSERT_TRUE(A.multiply(v,y,1.0,-1.0,true,4,2));
EXPECT_FLOAT_EQ(x.sum(),0.0);
EXPECT_FLOAT_EQ(y.sum(),0.0);
}
TEST(TestMatrix, Norm)
{
utl::matrix<double> a(4,5);

View File

@@ -92,7 +92,7 @@ namespace utl //! General utility classes and functions.
}
//! \brief Fill the vector with a scalar value.
void fill(const T& s) { std::fill(this->begin(),this->end(),s); }
void fill(T s) { std::fill(this->begin(),this->end(),s); }
//! \brief Fill the vector with data from an array.
void fill(const T* values, size_t n = 0)
{
@@ -125,7 +125,7 @@ namespace utl //! General utility classes and functions.
//! \brief Subtract the given vector \b X from \a *this.
vector<T>& operator-=(const vector<T>& X) { return this->add(X,T(-1)); }
//! \brief Add the given vector \b X scaled by \a alfa to \a *this.
vector<T>& add(const std::vector<T>& X, T alfa = T(1));
vector<T>& add(const std::vector<T>& X, const T& alfa = T(1));
//! \brief Perform \f${\bf Y} = \alpha{\bf Y} + (1-\alpha){\bf X} \f$
//! where \b Y = \a *this.
@@ -317,7 +317,7 @@ namespace utl //! General utility classes and functions.
void clear() { n[0] = n[1] = n[2] = 0; elem.clear(); }
//! \brief Fill the matrix with a scalar value.
void fill(const T& s) { std::fill(elem.begin(),elem.end(),s); }
void fill(T s) { std::fill(elem.begin(),elem.end(),s); }
//! \brief Fill the matrix with data from an array.
void fill(const T* values, size_t n = 0) { elem.fill(values,n); }
@@ -519,7 +519,7 @@ namespace utl //! General utility classes and functions.
}
//! \brief Add a scalar multiple of another matrix to a block of the matrix.
void addBlock(const matrix<T>& block, const T& s, size_t r, size_t c,
void addBlock(const matrix<T>& block, T s, size_t r, size_t c,
bool transposed = false)
{
size_t nr = transposed ? block.cols() : block.rows();
@@ -533,7 +533,7 @@ namespace utl //! General utility classes and functions.
}
//! \brief Create a diagonal matrix.
matrix<T>& diag(const T& d, size_t dim = 0)
matrix<T>& diag(T d, size_t dim = 0)
{
if (dim > 0)
this->resize(dim,dim,true);
@@ -586,7 +586,7 @@ namespace utl //! General utility classes and functions.
//! \brief Compute the inverse of a square matrix.
//! \param[in] tol Division by zero tolerance
//! \return Determinant of the matrix
T inverse(const T& tol = T(0))
T inverse(T tol = T(0))
{
T det = this->det();
if (det == T(-999))
@@ -626,7 +626,7 @@ namespace utl //! General utility classes and functions.
//! \brief Check for symmetry.
//! \param[in] tol Comparison tolerance
bool isSymmetric(const T& tol = T(0)) const
bool isSymmetric(T tol = T(0)) const
{
if (nrow != ncol) return false;
@@ -673,7 +673,7 @@ namespace utl //! General utility classes and functions.
*/
matrix<T>& multiply(const matrix<T>& A, const matrix<T>& B,
bool transA = false, bool transB = false,
bool addTo = false, T alpha = T(1));
bool addTo = false, const T& alpha = T(1));
/*! \brief Matrix-matrix multiplication.
\details Performs one of the following operations (\b C = \a *this):
@@ -723,8 +723,8 @@ namespace utl //! General utility classes and functions.
-# \f$ {\bf Y} = {\alpha}{\bf A}^T {\bf X} + {\beta}{\bf Y}\f$
*/
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
T alpha, T beta, bool transA = false,
int stridex = 1, int stridey = 1,
const T& alpha, const T& beta = T(0),
bool transA = false, int stridex = 1, int stridey = 1,
int ofsx = 0, int ofsy = 0) const;
//! \brief Outer product between two vectors.
@@ -1111,7 +1111,8 @@ namespace utl //! General utility classes and functions.
}
template<> inline
vector<float>& vector<float>::add(const std::vector<float>& X, float alfa)
vector<float>& vector<float>::add(const std::vector<float>& X,
const float& alfa)
{
size_t n = this->size() < X.size() ? this->size() : X.size();
cblas_saxpy(n,alfa,&X.front(),1,this->ptr(),1);
@@ -1119,7 +1120,8 @@ namespace utl //! General utility classes and functions.
}
template<> inline
vector<double>& vector<double>::add(const std::vector<double>& X, double alfa)
vector<double>& vector<double>::add(const std::vector<double>& X,
const double& alfa)
{
size_t n = this->size() < X.size() ? this->size() : X.size();
cblas_daxpy(n,alfa,&X.front(),1,this->ptr(),1);
@@ -1197,11 +1199,18 @@ namespace utl //! General utility classes and functions.
template<> inline
bool matrix<float>::multiply(const std::vector<float>& X,
std::vector<float>& Y,
float alpha, float beta, bool transA,
int stridex, int stridey, int ofsx,
int ofsy) const
const float& alpha, const float& beta,
bool transA, int stridex, int stridey,
int ofsx, int ofsy) const
{
if (!this->compatible(X,transA)) return false;
if (ofsx == 0 && stridex == 1)
if (!this->compatible(X,transA)) return false;
if (beta == 0.0f)
{
Y.resize(ofsy + 1 + ((transA ? ncol : nrow)-1)*stridey);
if (stridey > 0) std::fill(Y.begin(),Y.end(),0.0f);
}
cblas_sgemv(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
@@ -1216,11 +1225,18 @@ namespace utl //! General utility classes and functions.
template<> inline
bool matrix<double>::multiply(const std::vector<double>& X,
std::vector<double>& Y,
double alpha, double beta, bool transA,
int stridex, int stridey,
const double& alpha, const double& beta,
bool transA, int stridex, int stridey,
int ofsx, int ofsy) const
{
if (!this->compatible(X,transA)) return false;
if (ofsx == 0 && stridex == 1)
if (!this->compatible(X,transA)) return false;
if (beta == 0.0)
{
Y.resize(ofsy + 1 + ((transA ? ncol : nrow)-1)*stridey);
if (stridey > 0) std::fill(Y.begin(),Y.end(),0.0);
}
cblas_dgemv(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
@@ -1236,20 +1252,23 @@ namespace utl //! General utility classes and functions.
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
const matrix<float>& B,
bool transA, bool transB,
bool addTo, float alpha)
bool addTo, const float& alpha)
{
size_t M, N, K;
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
if (!addTo) this->resize(M,N);
if (!this->compatible(A,B,transA,transB,M,N,K))
this->clear();
else if (!addTo)
this->resize(M,N);
cblas_sgemm(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
M, N, K, alpha,
A.ptr(), A.nrow,
B.ptr(), B.nrow,
addTo ? 1.0f : 0.0f,
this->ptr(), nrow);
if (!this->empty())
cblas_sgemm(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
M, N, K, alpha,
A.ptr(), A.nrow,
B.ptr(), B.nrow,
addTo ? 1.0f : 0.0f,
this->ptr(), nrow);
return *this;
}
@@ -1258,20 +1277,23 @@ namespace utl //! General utility classes and functions.
matrix<double>& matrix<double>::multiply(const matrix<double>& A,
const matrix<double>& B,
bool transA, bool transB,
bool addTo, double alpha)
bool addTo, const double& alpha)
{
size_t M, N, K;
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
if (!addTo) this->resize(M,N);
if (!this->compatible(A,B,transA,transB,M,N,K))
this->clear();
else if (!addTo)
this->resize(M,N);
cblas_dgemm(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
M, N, K, alpha,
A.ptr(), A.nrow,
B.ptr(), B.nrow,
addTo ? 1.0 : 0.0,
this->ptr(), nrow);
if (!this->empty())
cblas_dgemm(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
M, N, K, alpha,
A.ptr(), A.nrow,
B.ptr(), B.nrow,
addTo ? 1.0 : 0.0,
this->ptr(), nrow);
return *this;
}
@@ -1416,7 +1438,7 @@ namespace utl //! General utility classes and functions.
template<> inline
bool matrix3d<double>::multiplyMat(const std::vector<double>& A,
const matrix3d<double>& B, bool addTo)
const matrix3d<double>& B, bool addTo)
{
size_t M, N, K;
if (!this->compatible(A,B,M,N,K)) return false;
@@ -1546,7 +1568,7 @@ namespace utl //! General utility classes and functions.
}
template<class T> inline
vector<T>& vector<T>::add(const std::vector<T>& X, T alfa)
vector<T>& vector<T>::add(const std::vector<T>& X, const T& alfa)
{
T* p = this->ptr();
const T* q = &X.front();
@@ -1556,7 +1578,7 @@ namespace utl //! General utility classes and functions.
}
template<class T> inline
matrixBase<T>& matrixBase<T>::add(const matrixBase<T>& A, T alfa)
matrixBase<T>& matrixBase<T>::add(const matrixBase<T>& A, const T& alfa)
{
T* p = this->ptr();
const T* q = A.ptr();
@@ -1596,20 +1618,28 @@ namespace utl //! General utility classes and functions.
template<class T> inline
bool matrix<T>::multiply(const std::vector<T>& X, std::vector<T>& Y,
T alpha, T beta, bool transA,
int stridex, int stridey,
const T& alpha, const T& beta,
bool transA, int stridex, int stridey,
int ofsx, int ofsy) const
{
if (!this->compatible(X,transA)) return false;
if (ofsx == 0 && stridex == 1)
if (!this->compatible(X,transA)) return false;
for (size_t i = ofsy; i < Y.size(); i += stridey) {
Y[i] *= beta;
for (size_t j = ofsx; j < X.size(); j += stridex)
if (transA)
Y[i] += alpha * THIS(j+1,i+1) * X[j];
else
Y[i] += alpha * THIS(i+1,j+1) * X[j];
if (beta == T(0))
{
Y.resize(ofsy + 1 + ((transA ? ncol : nrow)-1)*stridey);
std::fill(Y.begin(),Y.end(),T(0));
}
else for (size_t i = ofsy; i < Y.size(); i += stridey)
Y[i] *= beta;
size_t a, b, i, j;
for (a = 1, i = ofsy; i < Y.size(); a++, i += stridey)
for (b = 1, j = ofsx; j < X.size(); b++, j += stridex)
if (transA)
Y[i] += alpha * THIS(b,a) * X[j];
else
Y[i] += alpha * THIS(a,b) * X[j];
return true;
}
@@ -1618,11 +1648,16 @@ namespace utl //! General utility classes and functions.
matrix<T>& matrix<T>::multiply(const matrix<T>& A,
const matrix<T>& B,
bool transA, bool transB, bool addTo,
T alpha)
const T& alpha)
{
size_t M, N, K;
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
if (!addTo) this->resize(M,N,true);
if (!this->compatible(A,B,transA,transB,M,N,K))
{
this->clear();
M = 0;
}
else if (!addTo)
this->resize(M,N,true);
for (size_t i = 1; i <= M; i++)
for (size_t j = 1; j <= N; j++)
@@ -1749,7 +1784,7 @@ namespace utl //! General utility classes and functions.
//! matrices when they contain terms that are numerically zero, except for
//! some round-off noise. The value of the global variable \a zero_print_tol
//! is used as a tolerance in this method.
template<class T> inline T trunc(const T& v)
template<class T> inline T trunc(T v)
{
return v > T(zero_print_tol) || v < T(-zero_print_tol) || std::isnan(v) ?
v : T(0);

View File

@@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#ifdef HAS_BLAS
TEST(TestLegendre, GL)
{
std::vector<Real> w, p;
@@ -60,6 +61,7 @@ TEST(TestLegendre, GLL)
ASSERT_FLOAT_EQ(p[2], 0.447213595499958);
ASSERT_FLOAT_EQ(p[3], 1.0);
}
#endif
TEST(TestLegendre, Eval)
@@ -84,14 +86,15 @@ TEST(TestLegendre, Derivative)
}
#ifdef HAS_BLAS
TEST(TestLegendre, BasisDerivatives)
{
utl::matrix<Real> result1, result2;
Legendre::basisDerivatives(3, result1);
const double ref3[3][3] = {{-1.5, 2.0, -0.5},
{-0.5, 0.0, 0.5},
{ 0.5, -2.0, 1.5}};
const Real ref3[3][3] = {{-1.5, 2.0, -0.5},
{-0.5, 0.0, 0.5},
{ 0.5, -2.0, 1.5}};
for (size_t i = 0; i < 3; ++i)
for (size_t j = 0; j < 3; ++j)
@@ -99,16 +102,16 @@ TEST(TestLegendre, BasisDerivatives)
Legendre::basisDerivatives(10, result2);
const double ref10[10][10] =
{{-22.5, 30.43814502928187, -12.17794670742983, 6.943788485133953,
-4.599354761103132, 3.294643033749184, -2.452884175442686, 1.829563931903247,
const Real ref10[10][10] =
{{-22.5, 30.43814502928187, -12.17794670742983, 6.943788485133953,
-4.599354761103132, 3.294643033749184, -2.452884175442686, 1.829563931903247,
-1.275954836092663, 0.5},
{-5.074064702978062, 0.0, 7.185502869705838, -3.351663862746774,
2.078207994036418, -1.444948448751456, 1.059154463645442, -0.783239293137909,
0.5437537382357057, -0.2127027580091889},
{ 1.203351992852207, -4.259297354965217, 0.0, 4.368674557010181,
-2.104350179413155, 1.334915483878251, -0.9366032131394465,
0.6767970871960859, -0.4642749589081571, 0.1807865854892499},
0.6767970871960859, -0.4642749589081571, 0.1807865854892499},
{-0.528369376820273, 1.529902638181603, -3.364125868297819, 0.0,
3.387318101202445, -1.64649408398706, 1.046189365502494,
-0.7212373127216044, 0.4834623263339481, -0.1866457893937361},
@@ -125,7 +128,7 @@ TEST(TestLegendre, BasisDerivatives)
0.9366032131394465, -1.334915483878251, 2.104350179413155,
-4.368674557010181, 0.0, 4.259297354965217,
-1.203351992852207},
{0.2127027580091889, -0.5437537382357057, 0.783239293137909,
{ 0.2127027580091889, -0.5437537382357057, 0.783239293137909,
-1.059154463645442, 1.444948448751456, -2.078207994036418,
3.351663862746774, -7.185502869705838, 0.0, 5.074064702978062},
{-0.5, 1.275954836092663, -1.829563931903247,
@@ -136,3 +139,4 @@ TEST(TestLegendre, BasisDerivatives)
for (size_t j = 0; j < 10; ++j)
EXPECT_FLOAT_EQ(result2(i+1, j+1), ref10[i][j]);
}
#endif