added: utl:matrix::multiply() exposing the full gemv API
alpha, beta, strides and offsets
This commit is contained in:
parent
92b01e54a4
commit
9615f33c30
@ -626,6 +626,16 @@ namespace utl //! General utility classes and functions.
|
||||
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||
bool transA = false, char addTo = 0) const;
|
||||
|
||||
/*! \brief Matrix-vector multiplication.
|
||||
\details Performs the following operations (\b A = \a *this):
|
||||
-# \f$ {\bf Y} = {\alpha}{\bf A} {\bf X} + {\beta}{\bf Y}\f$
|
||||
-# \f$ {\bf Y} = {\alpha}{\bf A}^T {\bf X} + {\beta}{\bf Y}\f$
|
||||
*/
|
||||
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||
double alpha, double beta, bool transA = false,
|
||||
int stridex = 1, int stridey = 1,
|
||||
int ofsx = 0, int ofsy = 0) const;
|
||||
|
||||
//! \brief Outer product between two vectors.
|
||||
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
|
||||
bool addTo = false, T alpha = T(1))
|
||||
@ -1049,6 +1059,25 @@ namespace utl //! General utility classes and functions.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix<float>::multiply(const std::vector<float>& X,
|
||||
std::vector<float>& Y,
|
||||
double alpha, double beta, bool transA,
|
||||
int stridex, int stridey, int ofsx,
|
||||
int ofsy) const
|
||||
{
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
|
||||
cblas_sgemv(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
nrow, ncol, alpha,
|
||||
this->ptr(), nrow,
|
||||
&X[ofsx], stridex, beta,
|
||||
&Y[ofsy], stridey);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix<double>::multiply(const std::vector<double>& X,
|
||||
std::vector<double>& Y,
|
||||
@ -1067,6 +1096,25 @@ namespace utl //! General utility classes and functions.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix<double>::multiply(const std::vector<double>& X,
|
||||
std::vector<double>& Y,
|
||||
double alpha, double beta, bool transA,
|
||||
int stridex, int stridey,
|
||||
int ofsx, int ofsy) const
|
||||
{
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
|
||||
cblas_dgemv(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
nrow, ncol, alpha,
|
||||
this->ptr(), nrow,
|
||||
&X[ofsx], stridex, beta,
|
||||
&Y[ofsy], stridey);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
|
||||
const matrix<float>& B,
|
||||
@ -1359,6 +1407,26 @@ namespace utl //! General utility classes and functions.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class T> inline
|
||||
bool matrix<T>::multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||
double alpha, double beta, bool transA,
|
||||
int stridex, int stridey,
|
||||
int ofsx, int ofsy) const
|
||||
{
|
||||
if (!this->compatible(X,transA)) return false;
|
||||
|
||||
for (size_t i = ofsy; i < Y.size(); i += stridey) {
|
||||
Y[i] *= beta;
|
||||
for (size_t j = ofsx; j < X.size(); j += stridex)
|
||||
if (transA)
|
||||
Y[i] += alpha * THIS(j+1,i+1) * X[j];
|
||||
else
|
||||
Y[i] += alpha * THIS(i+1,j+1) * X[j];
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class T> inline
|
||||
matrix<T>& matrix<T>::multiply(const matrix<T>& A,
|
||||
const matrix<T>& B,
|
||||
|
Loading…
Reference in New Issue
Block a user