added: utl:matrix::multiply() exposing the full gemv API
alpha, beta, strides and offsets
This commit is contained in:
parent
92b01e54a4
commit
9615f33c30
@ -626,6 +626,16 @@ namespace utl //! General utility classes and functions.
|
|||||||
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
|
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||||
bool transA = false, char addTo = 0) const;
|
bool transA = false, char addTo = 0) const;
|
||||||
|
|
||||||
|
/*! \brief Matrix-vector multiplication.
|
||||||
|
\details Performs the following operations (\b A = \a *this):
|
||||||
|
-# \f$ {\bf Y} = {\alpha}{\bf A} {\bf X} + {\beta}{\bf Y}\f$
|
||||||
|
-# \f$ {\bf Y} = {\alpha}{\bf A}^T {\bf X} + {\beta}{\bf Y}\f$
|
||||||
|
*/
|
||||||
|
bool multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||||
|
double alpha, double beta, bool transA = false,
|
||||||
|
int stridex = 1, int stridey = 1,
|
||||||
|
int ofsx = 0, int ofsy = 0) const;
|
||||||
|
|
||||||
//! \brief Outer product between two vectors.
|
//! \brief Outer product between two vectors.
|
||||||
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
|
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
|
||||||
bool addTo = false, T alpha = T(1))
|
bool addTo = false, T alpha = T(1))
|
||||||
@ -1049,6 +1059,25 @@ namespace utl //! General utility classes and functions.
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> inline
|
||||||
|
bool matrix<float>::multiply(const std::vector<float>& X,
|
||||||
|
std::vector<float>& Y,
|
||||||
|
double alpha, double beta, bool transA,
|
||||||
|
int stridex, int stridey, int ofsx,
|
||||||
|
int ofsy) const
|
||||||
|
{
|
||||||
|
if (!this->compatible(X,transA)) return false;
|
||||||
|
|
||||||
|
cblas_sgemv(CblasColMajor,
|
||||||
|
transA ? CblasTrans : CblasNoTrans,
|
||||||
|
nrow, ncol, alpha,
|
||||||
|
this->ptr(), nrow,
|
||||||
|
&X[ofsx], stridex, beta,
|
||||||
|
&Y[ofsy], stridey);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template<> inline
|
template<> inline
|
||||||
bool matrix<double>::multiply(const std::vector<double>& X,
|
bool matrix<double>::multiply(const std::vector<double>& X,
|
||||||
std::vector<double>& Y,
|
std::vector<double>& Y,
|
||||||
@ -1067,6 +1096,25 @@ namespace utl //! General utility classes and functions.
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> inline
|
||||||
|
bool matrix<double>::multiply(const std::vector<double>& X,
|
||||||
|
std::vector<double>& Y,
|
||||||
|
double alpha, double beta, bool transA,
|
||||||
|
int stridex, int stridey,
|
||||||
|
int ofsx, int ofsy) const
|
||||||
|
{
|
||||||
|
if (!this->compatible(X,transA)) return false;
|
||||||
|
|
||||||
|
cblas_dgemv(CblasColMajor,
|
||||||
|
transA ? CblasTrans : CblasNoTrans,
|
||||||
|
nrow, ncol, alpha,
|
||||||
|
this->ptr(), nrow,
|
||||||
|
&X[ofsx], stridex, beta,
|
||||||
|
&Y[ofsy], stridey);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template<> inline
|
template<> inline
|
||||||
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
|
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
|
||||||
const matrix<float>& B,
|
const matrix<float>& B,
|
||||||
@ -1359,6 +1407,26 @@ namespace utl //! General utility classes and functions.
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<class T> inline
|
||||||
|
bool matrix<T>::multiply(const std::vector<T>& X, std::vector<T>& Y,
|
||||||
|
double alpha, double beta, bool transA,
|
||||||
|
int stridex, int stridey,
|
||||||
|
int ofsx, int ofsy) const
|
||||||
|
{
|
||||||
|
if (!this->compatible(X,transA)) return false;
|
||||||
|
|
||||||
|
for (size_t i = ofsy; i < Y.size(); i += stridey) {
|
||||||
|
Y[i] *= beta;
|
||||||
|
for (size_t j = ofsx; j < X.size(); j += stridex)
|
||||||
|
if (transA)
|
||||||
|
Y[i] += alpha * THIS(j+1,i+1) * X[j];
|
||||||
|
else
|
||||||
|
Y[i] += alpha * THIS(i+1,j+1) * X[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template<class T> inline
|
template<class T> inline
|
||||||
matrix<T>& matrix<T>::multiply(const matrix<T>& A,
|
matrix<T>& matrix<T>::multiply(const matrix<T>& A,
|
||||||
const matrix<T>& B,
|
const matrix<T>& B,
|
||||||
|
Loading…
Reference in New Issue
Block a user