From 9615f33c303e2b9216cba4e2e8d26034fb8439a4 Mon Sep 17 00:00:00 2001 From: Arne Morten Kvarving Date: Fri, 18 Nov 2016 14:34:27 +0100 Subject: [PATCH] added: utl:matrix::multiply() exposing the full gemv API alpha, beta, strides and offsets --- src/LinAlg/matrix.h | 68 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/LinAlg/matrix.h b/src/LinAlg/matrix.h index 22e8b746..94afc2f8 100644 --- a/src/LinAlg/matrix.h +++ b/src/LinAlg/matrix.h @@ -626,6 +626,16 @@ namespace utl //! General utility classes and functions. bool multiply(const std::vector& X, std::vector& Y, bool transA = false, char addTo = 0) const; + /*! \brief Matrix-vector multiplication. + \details Performs the following operations (\b A = \a *this): + -# \f$ {\bf Y} = {\alpha}{\bf A} {\bf X} + {\beta}{\bf Y}\f$ + -# \f$ {\bf Y} = {\alpha}{\bf A}^T {\bf X} + {\beta}{\bf Y}\f$ + */ + bool multiply(const std::vector& X, std::vector& Y, + double alpha, double beta, bool transA = false, + int stridex = 1, int stridey = 1, + int ofsx = 0, int ofsy = 0) const; + //! \brief Outer product between two vectors. bool outer_product(const std::vector& X, const std::vector& Y, bool addTo = false, T alpha = T(1)) @@ -1049,6 +1059,25 @@ namespace utl //! General utility classes and functions. return true; } + template<> inline + bool matrix::multiply(const std::vector& X, + std::vector& Y, + double alpha, double beta, bool transA, + int stridex, int stridey, int ofsx, + int ofsy) const + { + if (!this->compatible(X,transA)) return false; + + cblas_sgemv(CblasColMajor, + transA ? CblasTrans : CblasNoTrans, + nrow, ncol, alpha, + this->ptr(), nrow, + &X[ofsx], stridex, beta, + &Y[ofsy], stridey); + + return true; + } + template<> inline bool matrix::multiply(const std::vector& X, std::vector& Y, @@ -1067,6 +1096,25 @@ namespace utl //! General utility classes and functions. return true; } + template<> inline + bool matrix::multiply(const std::vector& X, + std::vector& Y, + double alpha, double beta, bool transA, + int stridex, int stridey, + int ofsx, int ofsy) const + { + if (!this->compatible(X,transA)) return false; + + cblas_dgemv(CblasColMajor, + transA ? CblasTrans : CblasNoTrans, + nrow, ncol, alpha, + this->ptr(), nrow, + &X[ofsx], stridex, beta, + &Y[ofsy], stridey); + + return true; + } + template<> inline matrix& matrix::multiply(const matrix& A, const matrix& B, @@ -1359,6 +1407,26 @@ namespace utl //! General utility classes and functions. return true; } + template inline + bool matrix::multiply(const std::vector& X, std::vector& Y, + double alpha, double beta, bool transA, + int stridex, int stridey, + int ofsx, int ofsy) const + { + if (!this->compatible(X,transA)) return false; + + for (size_t i = ofsy; i < Y.size(); i += stridey) { + Y[i] *= beta; + for (size_t j = ofsx; j < X.size(); j += stridex) + if (transA) + Y[i] += alpha * THIS(j+1,i+1) * X[j]; + else + Y[i] += alpha * THIS(i+1,j+1) * X[j]; + } + + return true; + } + template inline matrix& matrix::multiply(const matrix& A, const matrix& B,