added: use BLAS in outer_product
This commit is contained in:
parent
85a7024501
commit
83073f06f2
@ -638,32 +638,7 @@ namespace utl //! General utility classes and functions.
|
||||
|
||||
//! \brief Outer product between two vectors.
|
||||
bool outer_product(const std::vector<T>& X, const std::vector<T>& Y,
|
||||
bool addTo = false, T alpha = T(1))
|
||||
{
|
||||
if (!addTo)
|
||||
this->resize(X.size(),Y.size());
|
||||
else if (X.size() != nrow || Y.size() != ncol)
|
||||
{
|
||||
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
||||
<< nrow <<','<< ncol <<"), X("
|
||||
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
||||
<<" when computing A += X*Y^t"
|
||||
<< std::endl;
|
||||
ABORT_ON_INDEX_CHECK;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (addTo)
|
||||
for (size_t j = 0; j < ncol; j++)
|
||||
for (size_t i = 0; i < nrow; i++)
|
||||
elem[i+nrow*j] += alpha*X[i]*Y[j];
|
||||
else
|
||||
for (size_t j = 0; j < ncol; j++)
|
||||
for (size_t i = 0; i < nrow; i++)
|
||||
elem[i+nrow*j] = alpha*X[i]*Y[j];
|
||||
|
||||
return true;
|
||||
}
|
||||
bool addTo = false, T alpha = T(1));
|
||||
|
||||
//! \brief Return the Euclidean norm of the matrix.
|
||||
T norm2() const { return elem.norm2(); }
|
||||
@ -1259,6 +1234,68 @@ namespace utl //! General utility classes and functions.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix<float>::outer_product(const std::vector<float>& X,
|
||||
const std::vector<float>& Y,
|
||||
bool addTo, float alpha)
|
||||
{
|
||||
if (!addTo)
|
||||
this->resize(X.size(),Y.size());
|
||||
else if (X.size() != nrow || Y.size() != ncol)
|
||||
{
|
||||
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
||||
<< nrow <<','<< ncol <<"), X("
|
||||
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
||||
<<" when computing A += X*Y^t"
|
||||
<< std::endl;
|
||||
ABORT_ON_INDEX_CHECK;
|
||||
return false;
|
||||
}
|
||||
size_t M = X.size();
|
||||
size_t N = Y.size();
|
||||
size_t K = 1;
|
||||
cblas_sgemm(CblasColMajor,
|
||||
CblasNoTrans, CblasTrans,
|
||||
M, N, K, alpha,
|
||||
X.data(), M,
|
||||
Y.data(), N,
|
||||
addTo ? 1.0f : 0.0f,
|
||||
this->ptr(), nrow);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix<double>::outer_product(const std::vector<double>& X,
|
||||
const std::vector<double>& Y,
|
||||
bool addTo, double alpha)
|
||||
{
|
||||
if (!addTo)
|
||||
this->resize(X.size(),Y.size());
|
||||
else if (X.size() != nrow || Y.size() != ncol)
|
||||
{
|
||||
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
||||
<< nrow <<','<< ncol <<"), X("
|
||||
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
||||
<<" when computing A += X*Y^t"
|
||||
<< std::endl;
|
||||
ABORT_ON_INDEX_CHECK;
|
||||
return false;
|
||||
}
|
||||
size_t M = X.size();
|
||||
size_t N = Y.size();
|
||||
size_t K = 1;
|
||||
cblas_dgemm(CblasColMajor,
|
||||
CblasNoTrans, CblasTrans,
|
||||
M, N, K, alpha,
|
||||
X.data(), M,
|
||||
Y.data(), N,
|
||||
addTo ? 1.0 : 0.0,
|
||||
this->ptr(), nrow);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<> inline
|
||||
bool matrix3d<double>::multiply(const matrix<double>& A,
|
||||
const matrix3d<double>& B,
|
||||
@ -1490,6 +1527,36 @@ namespace utl //! General utility classes and functions.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class T> inline
|
||||
bool matrix<T>::outer_product(const std::vector<T>& X,
|
||||
const std::vector<T>& Y,
|
||||
bool addTo = false, T alpha = T(1))
|
||||
{
|
||||
if (!addTo)
|
||||
this->resize(X.size(),Y.size());
|
||||
else if (X.size() != nrow || Y.size() != ncol)
|
||||
{
|
||||
std::cerr <<"matrix::outer_product: Incompatible matrix and vectors: A("
|
||||
<< nrow <<','<< ncol <<"), X("
|
||||
<< X.size() <<"), Y("<< Y.size() <<")\n"
|
||||
<<" when computing A += X*Y^t"
|
||||
<< std::endl;
|
||||
ABORT_ON_INDEX_CHECK;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (addTo)
|
||||
for (size_t j = 0; j < ncol; j++)
|
||||
for (size_t i = 0; i < nrow; i++)
|
||||
elem[i+nrow*j] += alpha*X[i]*Y[j];
|
||||
else
|
||||
for (size_t j = 0; j < ncol; j++)
|
||||
for (size_t i = 0; i < nrow; i++)
|
||||
elem[i+nrow*j] = alpha*X[i]*Y[j];
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class T> inline
|
||||
bool matrix3d<T>::multiply(const matrix<T>& A, const matrix3d<T>& B,
|
||||
bool transA, bool addTo)
|
||||
|
Loading…
Reference in New Issue
Block a user