added: use BLAS in outer_product

This commit is contained in:
Arne Morten Kvarving 2016-12-02 13:31:34 +01:00
parent 85a7024501
commit 83073f06f2

View File

@ -638,32 +638,7 @@ namespace utl //! General utility classes and functions.
//! \brief Outer product between two vectors. //! \brief Outer product between two vectors.
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y, bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
bool addTo = false, T alpha = T(1)) bool addTo = false, T alpha = T(1));
{
if (!addTo)
this->resize(X.size(),Y.size());
else if (X.size() != nrow || Y.size() != ncol)
{
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
<< nrow <<','<< ncol <<"), X("
<< X.size() <<"), Y("<< Y.size() <<")\n"
<<" when computing A += X*Y^t"
<< std::endl;
ABORT_ON_INDEX_CHECK;
return false;
}
if (addTo)
for (size_t j = 0; j < ncol; j++)
for (size_t i = 0; i < nrow; i++)
elem[i+nrow*j] += alpha*X[i]*Y[j];
else
for (size_t j = 0; j < ncol; j++)
for (size_t i = 0; i < nrow; i++)
elem[i+nrow*j] = alpha*X[i]*Y[j];
return true;
}
//! \brief Return the Euclidean norm of the matrix. //! \brief Return the Euclidean norm of the matrix.
T norm2() const { return elem.norm2(); } T norm2() const { return elem.norm2(); }
@ -1259,6 +1234,68 @@ namespace utl //! General utility classes and functions.
return true; return true;
} }
template<> inline
bool matrix<float>::outer_product(const std::vector<float>& X,
const std::vector<float>& Y,
bool addTo, float alpha)
{
if (!addTo)
this->resize(X.size(),Y.size());
else if (X.size() != nrow || Y.size() != ncol)
{
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
<< nrow <<','<< ncol <<"), X("
<< X.size() <<"), Y("<< Y.size() <<")\n"
<<" when computing A += X*Y^t"
<< std::endl;
ABORT_ON_INDEX_CHECK;
return false;
}
size_t M = X.size();
size_t N = Y.size();
size_t K = 1;
cblas_sgemm(CblasColMajor,
CblasNoTrans, CblasTrans,
M, N, K, alpha,
X.data(), M,
Y.data(), N,
addTo ? 1.0f : 0.0f,
this->ptr(), nrow);
return true;
}
template<> inline
bool matrix<double>::outer_product(const std::vector<double>& X,
const std::vector<double>& Y,
bool addTo, double alpha)
{
if (!addTo)
this->resize(X.size(),Y.size());
else if (X.size() != nrow || Y.size() != ncol)
{
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
<< nrow <<','<< ncol <<"), X("
<< X.size() <<"), Y("<< Y.size() <<")\n"
<<" when computing A += X*Y^t"
<< std::endl;
ABORT_ON_INDEX_CHECK;
return false;
}
size_t M = X.size();
size_t N = Y.size();
size_t K = 1;
cblas_dgemm(CblasColMajor,
CblasNoTrans, CblasTrans,
M, N, K, alpha,
X.data(), M,
Y.data(), N,
addTo ? 1.0 : 0.0,
this->ptr(), nrow);
return true;
}
template<> inline template<> inline
bool matrix3d<double>::multiply(const matrix<double>& A, bool matrix3d<double>::multiply(const matrix<double>& A,
const matrix3d<double>& B, const matrix3d<double>& B,
@ -1490,6 +1527,36 @@ namespace utl //! General utility classes and functions.
return true; return true;
} }
template<class T> inline
bool matrix<T>::outer_product(const std::vector<T>& X,
const std::vector<T>& Y,
bool addTo = false, T alpha = T(1))
{
if (!addTo)
this->resize(X.size(),Y.size());
else if (X.size() != nrow || Y.size() != ncol)
{
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
<< nrow <<','<< ncol <<"), X("
<< X.size() <<"), Y("<< Y.size() <<")\n"
<<" when computing A += X*Y^t"
<< std::endl;
ABORT_ON_INDEX_CHECK;
return false;
}
if (addTo)
for (size_t j = 0; j < ncol; j++)
for (size_t i = 0; i < nrow; i++)
elem[i+nrow*j] += alpha*X[i]*Y[j];
else
for (size_t j = 0; j < ncol; j++)
for (size_t i = 0; i < nrow; i++)
elem[i+nrow*j] = alpha*X[i]*Y[j];
return true;
}
template<class T> inline template<class T> inline
bool matrix3d<T>::multiply(const matrix<T>& A, const matrix3d<T>& B, bool matrix3d<T>::multiply(const matrix<T>& A, const matrix3d<T>& B,
bool transA, bool addTo) bool transA, bool addTo)