added: utl:matrix::multiply() exposing the full gemv API

alpha, beta, strides and offsets
This commit is contained in:
Arne Morten Kvarving 2016-11-18 14:34:27 +01:00
parent 92b01e54a4
commit 9615f33c30

View File

@ -626,6 +626,16 @@ namespace utl //! General utility classes and functions.
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
bool transA = false, char addTo = 0) const;
/*! \brief Matrix-vector multiplication.
\details Performs the following operations (\b A = \a *this):
-# \f$ {\bf Y} = {\alpha}{\bf A} {\bf X} + {\beta}{\bf Y}\f$
-# \f$ {\bf Y} = {\alpha}{\bf A}^T {\bf X} + {\beta}{\bf Y}\f$
*/
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
double alpha, double beta, bool transA = false,
int stridex = 1, int stridey = 1,
int ofsx = 0, int ofsy = 0) const;
//! \brief Outer product between two vectors.
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
bool addTo = false, T alpha = T(1))
@ -1049,6 +1059,25 @@ namespace utl //! General utility classes and functions.
return true;
}
template<> inline
bool matrix<float>::multiply(const std::vector<float>& X,
std::vector<float>& Y,
double alpha, double beta, bool transA,
int stridex, int stridey, int ofsx,
int ofsy) const
{
if (!this->compatible(X,transA)) return false;
cblas_sgemv(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
nrow, ncol, alpha,
this->ptr(), nrow,
&X[ofsx], stridex, beta,
&Y[ofsy], stridey);
return true;
}
template<> inline
bool matrix<double>::multiply(const std::vector<double>& X,
std::vector<double>& Y,
@ -1067,6 +1096,25 @@ namespace utl //! General utility classes and functions.
return true;
}
template<> inline
bool matrix<double>::multiply(const std::vector<double>& X,
std::vector<double>& Y,
double alpha, double beta, bool transA,
int stridex, int stridey,
int ofsx, int ofsy) const
{
if (!this->compatible(X,transA)) return false;
cblas_dgemv(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
nrow, ncol, alpha,
this->ptr(), nrow,
&X[ofsx], stridex, beta,
&Y[ofsy], stridey);
return true;
}
template<> inline
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
const matrix<float>& B,
@ -1359,6 +1407,26 @@ namespace utl //! General utility classes and functions.
return true;
}
template<class T> inline
bool matrix<T>::multiply(const std::vector<T>& X, std::vector<T>& Y,
double alpha, double beta, bool transA,
int stridex, int stridey,
int ofsx, int ofsy) const
{
if (!this->compatible(X,transA)) return false;
for (size_t i = ofsy; i < Y.size(); i += stridey) {
Y[i] *= beta;
for (size_t j = ofsx; j < X.size(); j += stridex)
if (transA)
Y[i] += alpha * THIS(j+1,i+1) * X[j];
else
Y[i] += alpha * THIS(i+1,j+1) * X[j];
}
return true;
}
template<class T> inline
matrix<T>& matrix<T>::multiply(const matrix<T>& A,
const matrix<T>& B,