added: use BLAS in outer_product
This commit is contained in:
parent
85a7024501
commit
83073f06f2
@ -638,32 +638,7 @@ namespace utl //! General utility classes and functions.
|
|||||||
|
|
||||||
//! \brief Outer product between two vectors.
|
//! \brief Outer product between two vectors.
|
||||||
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
|
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
|
||||||
bool addTo = false, T alpha = T(1))
|
bool addTo = false, T alpha = T(1));
|
||||||
{
|
|
||||||
if (!addTo)
|
|
||||||
this->resize(X.size(),Y.size());
|
|
||||||
else if (X.size() != nrow || Y.size() != ncol)
|
|
||||||
{
|
|
||||||
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
|
||||||
<< nrow <<','<< ncol <<"), X("
|
|
||||||
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
|
||||||
<<" when computing A += X*Y^t"
|
|
||||||
<< std::endl;
|
|
||||||
ABORT_ON_INDEX_CHECK;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (addTo)
|
|
||||||
for (size_t j = 0; j < ncol; j++)
|
|
||||||
for (size_t i = 0; i < nrow; i++)
|
|
||||||
elem[i+nrow*j] += alpha*X[i]*Y[j];
|
|
||||||
else
|
|
||||||
for (size_t j = 0; j < ncol; j++)
|
|
||||||
for (size_t i = 0; i < nrow; i++)
|
|
||||||
elem[i+nrow*j] = alpha*X[i]*Y[j];
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
//! \brief Return the Euclidean norm of the matrix.
|
//! \brief Return the Euclidean norm of the matrix.
|
||||||
T norm2() const { return elem.norm2(); }
|
T norm2() const { return elem.norm2(); }
|
||||||
@ -1259,6 +1234,68 @@ namespace utl //! General utility classes and functions.
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> inline
|
||||||
|
bool matrix<float>::outer_product(const std::vector<float>& X,
|
||||||
|
const std::vector<float>& Y,
|
||||||
|
bool addTo, float alpha)
|
||||||
|
{
|
||||||
|
if (!addTo)
|
||||||
|
this->resize(X.size(),Y.size());
|
||||||
|
else if (X.size() != nrow || Y.size() != ncol)
|
||||||
|
{
|
||||||
|
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
||||||
|
<< nrow <<','<< ncol <<"), X("
|
||||||
|
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
||||||
|
<<" when computing A += X*Y^t"
|
||||||
|
<< std::endl;
|
||||||
|
ABORT_ON_INDEX_CHECK;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t M = X.size();
|
||||||
|
size_t N = Y.size();
|
||||||
|
size_t K = 1;
|
||||||
|
cblas_sgemm(CblasColMajor,
|
||||||
|
CblasNoTrans, CblasTrans,
|
||||||
|
M, N, K, alpha,
|
||||||
|
X.data(), M,
|
||||||
|
Y.data(), N,
|
||||||
|
addTo ? 1.0f : 0.0f,
|
||||||
|
this->ptr(), nrow);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> inline
|
||||||
|
bool matrix<double>::outer_product(const std::vector<double>& X,
|
||||||
|
const std::vector<double>& Y,
|
||||||
|
bool addTo, double alpha)
|
||||||
|
{
|
||||||
|
if (!addTo)
|
||||||
|
this->resize(X.size(),Y.size());
|
||||||
|
else if (X.size() != nrow || Y.size() != ncol)
|
||||||
|
{
|
||||||
|
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
||||||
|
<< nrow <<','<< ncol <<"), X("
|
||||||
|
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
||||||
|
<<" when computing A += X*Y^t"
|
||||||
|
<< std::endl;
|
||||||
|
ABORT_ON_INDEX_CHECK;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t M = X.size();
|
||||||
|
size_t N = Y.size();
|
||||||
|
size_t K = 1;
|
||||||
|
cblas_dgemm(CblasColMajor,
|
||||||
|
CblasNoTrans, CblasTrans,
|
||||||
|
M, N, K, alpha,
|
||||||
|
X.data(), M,
|
||||||
|
Y.data(), N,
|
||||||
|
addTo ? 1.0 : 0.0,
|
||||||
|
this->ptr(), nrow);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template<> inline
|
template<> inline
|
||||||
bool matrix3d<double>::multiply(const matrix<double>& A,
|
bool matrix3d<double>::multiply(const matrix<double>& A,
|
||||||
const matrix3d<double>& B,
|
const matrix3d<double>& B,
|
||||||
@ -1490,6 +1527,36 @@ namespace utl //! General utility classes and functions.
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<class T> inline
|
||||||
|
bool matrix<T>::outer_product(const std::vector<T>& X,
|
||||||
|
const std::vector<T>& Y,
|
||||||
|
bool addTo = false, T alpha = T(1))
|
||||||
|
{
|
||||||
|
if (!addTo)
|
||||||
|
this->resize(X.size(),Y.size());
|
||||||
|
else if (X.size() != nrow || Y.size() != ncol)
|
||||||
|
{
|
||||||
|
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
||||||
|
<< nrow <<','<< ncol <<"), X("
|
||||||
|
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
||||||
|
<<" when computing A += X*Y^t"
|
||||||
|
<< std::endl;
|
||||||
|
ABORT_ON_INDEX_CHECK;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (addTo)
|
||||||
|
for (size_t j = 0; j < ncol; j++)
|
||||||
|
for (size_t i = 0; i < nrow; i++)
|
||||||
|
elem[i+nrow*j] += alpha*X[i]*Y[j];
|
||||||
|
else
|
||||||
|
for (size_t j = 0; j < ncol; j++)
|
||||||
|
for (size_t i = 0; i < nrow; i++)
|
||||||
|
elem[i+nrow*j] = alpha*X[i]*Y[j];
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template<class T> inline
|
template<class T> inline
|
||||||
bool matrix3d<T>::multiply(const matrix<T>& A, const matrix3d<T>& B,
|
bool matrix3d<T>::multiply(const matrix<T>& A, const matrix3d<T>& B,
|
||||||
bool transA, bool addTo)
|
bool transA, bool addTo)
|
||||||
|
Loading…
Reference in New Issue
Block a user