added: alpha in multiply and outer_product functions

This commit is contained in:
Arne Morten Kvarving 2016-07-13 10:38:45 +02:00
parent ad1a22c51e
commit 4bd4208a63

View File

@ -582,7 +582,7 @@ namespace utl //! General utility classes and functions.
*/
matrix<T>& multiply(const matrix<T>& A, const matrix<T>& B,
bool transA = false, bool transB = false,
bool addTo = false);
bool addTo = false, T alpha = T(1));
/*! \brief Matrix-matrix multiplication.
\details Performs one of the following operations (\b C = \a *this):
@ -628,7 +628,7 @@ namespace utl //! General utility classes and functions.
//! \brief Outer product between two vectors.
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
bool addTo = false)
bool addTo = false, T alpha = T(1))
{
if (!addTo)
this->resize(X.size(),Y.size());
@ -646,11 +646,11 @@ namespace utl //! General utility classes and functions.
if (addTo)
for (size_t j = 0; j < ncol; j++)
for (size_t i = 0; i < nrow; i++)
elem[i+nrow*j] += X[i]*Y[j];
elem[i+nrow*j] += alpha*X[i]*Y[j];
else
for (size_t j = 0; j < ncol; j++)
for (size_t i = 0; i < nrow; i++)
elem[i+nrow*j] = X[i]*Y[j];
elem[i+nrow*j] = alpha*X[i]*Y[j];
return true;
}
@ -1070,7 +1070,8 @@ namespace utl //! General utility classes and functions.
template<> inline
matrix<float>& matrix<float>::multiply(const matrix<float>& A,
const matrix<float>& B,
bool transA, bool transB, bool addTo)
bool transA, bool transB,
bool addTo, float alpha)
{
size_t M, N, K;
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
@ -1079,7 +1080,7 @@ namespace utl //! General utility classes and functions.
cblas_sgemm(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
M, N, K, 1.0f,
M, N, K, alpha,
A.ptr(), A.nrow,
B.ptr(), B.nrow,
addTo ? 1.0f : 0.0f,
@ -1091,7 +1092,8 @@ namespace utl //! General utility classes and functions.
template<> inline
matrix<double>& matrix<double>::multiply(const matrix<double>& A,
const matrix<double>& B,
bool transA, bool transB, bool addTo)
bool transA, bool transB,
bool addTo, double alpha)
{
size_t M, N, K;
if (!this->compatible(A,B,transA,transB,M,N,K)) return *this;
@ -1100,7 +1102,7 @@ namespace utl //! General utility classes and functions.
cblas_dgemm(CblasColMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
M, N, K, 1.0,
M, N, K, alpha,
A.ptr(), A.nrow,
B.ptr(), B.nrow,
addTo ? 1.0 : 0.0,