diff --git a/src/LinAlg/matrix.h b/src/LinAlg/matrix.h
index 94afc2f8..1a1b1ebb 100644
--- a/src/LinAlg/matrix.h
+++ b/src/LinAlg/matrix.h
@@ -638,32 +638,12 @@ namespace utl //! General utility classes and functions.
 
     //! \brief Outer product between two vectors.
+    //! \param[in] X Vector whose size defines the number of rows
+    //! \param[in] Y Vector whose size defines the number of columns
+    //! \param[in] addTo If \e true, add to the current matrix content
+    //! \param[in] alpha Scaling factor applied to the product
+    //! \return \e false if \a addTo is \e true and the dimensions mismatch
     bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
-                       bool addTo = false, T alpha = T(1))
-    {
-      if (!addTo)
-        this->resize(X.size(),Y.size());
-      else if (X.size() != nrow || Y.size() != ncol)
-      {
-        std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
-                  << nrow <<','<< ncol <<"), X("
-                  << X.size() <<"), Y("<< Y.size() <<")\n"
-                  <<" when computing A += X*Y^t"
-                  << std::endl;
-        ABORT_ON_INDEX_CHECK;
-        return false;
-      }
-
-      if (addTo)
-        for (size_t j = 0; j < ncol; j++)
-          for (size_t i = 0; i < nrow; i++)
-            elem[i+nrow*j] += alpha*X[i]*Y[j];
-      else
-        for (size_t j = 0; j < ncol; j++)
-          for (size_t i = 0; i < nrow; i++)
-            elem[i+nrow*j] = alpha*X[i]*Y[j];
-
-      return true;
-    }
+                       bool addTo = false, T alpha = T(1));
 
     //! \brief Return the Euclidean norm of the matrix.
     T norm2() const { return elem.norm2(); }
@@ -1259,6 +1239,82 @@ namespace utl //! General utility classes and functions.
     return true;
   }
 
+  //! \brief Outer product between two vectors (single precision, BLAS).
+  //! \details Computes A = alpha*X*Y^t, or A += alpha*X*Y^t if \a addTo,
+  //! as a GEMM operation with inner dimension 1.
+  template<> inline
+  bool matrix<float>::outer_product(const std::vector<float>& X,
+                                    const std::vector<float>& Y,
+                                    bool addTo, float alpha)
+  {
+    if (!addTo)
+      this->resize(X.size(),Y.size());
+    else if (X.size() != nrow || Y.size() != ncol)
+    {
+      std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
+                << nrow <<','<< ncol <<"), X("
+                << X.size() <<"), Y("<< Y.size() <<")\n"
+                <<" when computing A += X*Y^t"
+                << std::endl;
+      ABORT_ON_INDEX_CHECK;
+      return false;
+    }
+    size_t M = X.size();
+    size_t N = Y.size();
+    size_t K = 1;
+    // Nothing to compute, and BLAS rejects leading dimensions less than 1
+    if (M == 0 || N == 0)
+      return true;
+
+    cblas_sgemm(CblasColMajor,
+                CblasNoTrans, CblasTrans,
+                M, N, K, alpha,
+                X.data(), M,
+                Y.data(), N,
+                addTo ? 1.0f : 0.0f,
+                this->ptr(), nrow);
+
+    return true;
+  }
+
+  //! \brief Outer product between two vectors (double precision, BLAS).
+  //! \details Computes A = alpha*X*Y^t, or A += alpha*X*Y^t if \a addTo,
+  //! as a GEMM operation with inner dimension 1.
+  template<> inline
+  bool matrix<double>::outer_product(const std::vector<double>& X,
+                                     const std::vector<double>& Y,
+                                     bool addTo, double alpha)
+  {
+    if (!addTo)
+      this->resize(X.size(),Y.size());
+    else if (X.size() != nrow || Y.size() != ncol)
+    {
+      std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
+                << nrow <<','<< ncol <<"), X("
+                << X.size() <<"), Y("<< Y.size() <<")\n"
+                <<" when computing A += X*Y^t"
+                << std::endl;
+      ABORT_ON_INDEX_CHECK;
+      return false;
+    }
+    size_t M = X.size();
+    size_t N = Y.size();
+    size_t K = 1;
+    // Nothing to compute, and BLAS rejects leading dimensions less than 1
+    if (M == 0 || N == 0)
+      return true;
+
+    cblas_dgemm(CblasColMajor,
+                CblasNoTrans, CblasTrans,
+                M, N, K, alpha,
+                X.data(), M,
+                Y.data(), N,
+                addTo ? 1.0 : 0.0,
+                this->ptr(), nrow);
+
+    return true;
+  }
+
   template<> inline
   bool matrix3d<float>::multiply(const matrix<float>& A,
                                  const matrix3d<float>& B,
@@ -1490,6 +1546,38 @@ namespace utl //! General utility classes and functions.
     return true;
   }
 
+  //! \brief Outer product between two vectors (generic implementation).
+  //! \note Default arguments are specified on the in-class declaration only.
+  template<class T> inline
+  bool matrix<T>::outer_product(const std::vector<T>& X,
+                                const std::vector<T>& Y,
+                                bool addTo, T alpha)
+  {
+    if (!addTo)
+      this->resize(X.size(),Y.size());
+    else if (X.size() != nrow || Y.size() != ncol)
+    {
+      std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
+                << nrow <<','<< ncol <<"), X("
+                << X.size() <<"), Y("<< Y.size() <<")\n"
+                <<" when computing A += X*Y^t"
+                << std::endl;
+      ABORT_ON_INDEX_CHECK;
+      return false;
+    }
+
+    if (addTo)
+      for (size_t j = 0; j < ncol; j++)
+        for (size_t i = 0; i < nrow; i++)
+          elem[i+nrow*j] += alpha*X[i]*Y[j];
+    else
+      for (size_t j = 0; j < ncol; j++)
+        for (size_t i = 0; i < nrow; i++)
+          elem[i+nrow*j] = alpha*X[i]*Y[j];
+
+    return true;
+  }
+
   template<class T> inline
   bool matrix3d<T>::multiply(const matrix<T>& A, const matrix3d<T>& B,
                              bool transA, bool addTo)