added: alpha in multiply and outer_product functions
This commit is contained in:
parent
ad1a22c51e
commit
4bd4208a63
@ -582,7 +582,7 @@ namespace utl //! General utility classes and functions.
|
||||
*/
|
||||
matrix<T>& multiply(const matrix<T>& A, const matrix<T>& B,
|
||||
bool transA = false, bool transB = false,
|
||||
bool addTo = false);
|
||||
bool addTo = false, T alpha = T(1));
|
||||
|
||||
/*! \brief Matrix-matrix multiplication.
|
||||
\details Performs one of the following operations (\b C = \a *this):
|
||||
@ -628,7 +628,7 @@ namespace utl //! General utility classes and functions.
|
||||
|
||||
//! \brief Outer product between two vectors.
|
||||
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
|
||||
bool addTo = false)
|
||||
bool addTo = false, T alpha = T(1))
|
||||
{
|
||||
if (!addTo)
|
||||
this->resize(X.size(),Y.size());
|
||||
@ -646,11 +646,11 @@ namespace utl //! General utility classes and functions.
|
||||
if (addTo)
|
||||
for (size_t j = 0; j < ncol; j++)
|
||||
for (size_t i = 0; i < nrow; i++)
|
||||
elem[i+nrow*j] += X[i]*Y[j];
|
||||
elem[i+nrow*j] += alpha*X[i]*Y[j];
|
||||
else
|
||||
for (size_t j = 0; j < ncol; j++)
|
||||
for (size_t i = 0; i < nrow; i++)
|
||||
elem[i+nrow*j] = X[i]*Y[j];
|
||||
elem[i+nrow*j] = alpha*X[i]*Y[j];
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -1070,7 +1070,8 @@ namespace utl //! General utility classes and functions.
|
||||
template<> inline
|
||||
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
|
||||
const matrix<float>& B,
|
||||
bool transA, bool transB, bool addTo)
|
||||
bool transA, bool transB,
|
||||
bool addTo, float alpha)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
|
||||
@ -1079,7 +1080,7 @@ namespace utl //! General utility classes and functions.
|
||||
cblas_sgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
transB ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, 1.0f,
|
||||
M, N, K, alpha,
|
||||
A.ptr(), A.nrow,
|
||||
B.ptr(), B.nrow,
|
||||
addTo ? 1.0f : 0.0f,
|
||||
@ -1091,7 +1092,8 @@ namespace utl //! General utility classes and functions.
|
||||
template<> inline
|
||||
matrix<double>& matrix<double>::multiply(const matrix<double>& A,
|
||||
const matrix<double>& B,
|
||||
bool transA, bool transB, bool addTo)
|
||||
bool transA, bool transB,
|
||||
bool addTo, double alpha)
|
||||
{
|
||||
size_t M, N, K;
|
||||
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
|
||||
@ -1100,7 +1102,7 @@ namespace utl //! General utility classes and functions.
|
||||
cblas_dgemm(CblasColMajor,
|
||||
transA ? CblasTrans : CblasNoTrans,
|
||||
transB ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, 1.0,
|
||||
M, N, K, alpha,
|
||||
A.ptr(), A.nrow,
|
||||
B.ptr(), B.nrow,
|
||||
addTo ? 1.0 : 0.0,
|
||||
|
Loading…
Reference in New Issue
Block a user