2018-02-06 10:50:43 -05:00
|
|
|
#ifndef included_FunctionTable_hpp
|
|
|
|
|
#define included_FunctionTable_hpp
|
|
|
|
|
|
|
|
|
|
#include "common/FunctionTable.h"
|
2021-06-02 15:36:44 -04:00
|
|
|
#include "common/UtilityMacros.h"
|
2018-02-06 10:50:43 -05:00
|
|
|
|
|
|
|
|
#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <stdexcept>
#include <type_traits>
|
2021-09-02 08:19:54 -04:00
|
|
|
//#include <random>
|
2018-02-06 10:50:43 -05:00
|
|
|
|
2021-06-02 15:36:44 -04:00
|
|
|
|
2018-02-06 10:50:43 -05:00
|
|
|
/********************************************************
|
|
|
|
|
* Random number initialization *
|
|
|
|
|
********************************************************/
|
2021-09-02 08:19:54 -04:00
|
|
|
/*template<class TYPE> TYPE genRand();
|
2018-09-17 13:03:00 -04:00
|
|
|
template<class TYPE, class FUN>
|
2021-06-02 15:36:44 -04:00
|
|
|
inline void FunctionTable::rand( Array<TYPE, FUN> &x )
|
2018-02-06 10:50:43 -05:00
|
|
|
{
|
2021-06-02 15:36:44 -04:00
|
|
|
for ( size_t i = 0; i < x.length(); i++ )
|
|
|
|
|
x( i ) = genRand<TYPE>();
|
2018-02-06 10:50:43 -05:00
|
|
|
}
|
2021-09-02 08:19:54 -04:00
|
|
|
*/
|
2018-02-06 10:50:43 -05:00
|
|
|
|
|
|
|
|
/********************************************************
|
|
|
|
|
* Reduction *
|
|
|
|
|
********************************************************/
|
|
|
|
|
template<class TYPE, class FUN, typename LAMBDA>
|
2021-06-02 15:36:44 -04:00
|
|
|
inline TYPE FunctionTable::reduce( LAMBDA &op, const Array<TYPE, FUN> &A, const TYPE &initialValue )
|
2018-02-06 10:50:43 -05:00
|
|
|
{
|
|
|
|
|
if ( A.length() == 0 )
|
|
|
|
|
return TYPE();
|
2021-06-02 15:36:44 -04:00
|
|
|
const TYPE *x = A.data();
|
|
|
|
|
TYPE y = initialValue;
|
|
|
|
|
for ( size_t i = 0; i < A.length(); i++ )
|
2018-02-06 10:50:43 -05:00
|
|
|
y = op( x[i], y );
|
|
|
|
|
return y;
|
|
|
|
|
}
|
2021-06-02 15:36:44 -04:00
|
|
|
template<class TYPE, class FUN, typename LAMBDA>
|
|
|
|
|
inline TYPE FunctionTable::reduce( LAMBDA &op,
|
|
|
|
|
const Array<TYPE, FUN> &A,
|
|
|
|
|
const Array<TYPE, FUN> &B,
|
|
|
|
|
const TYPE &initialValue )
|
|
|
|
|
{
|
|
|
|
|
ARRAY_ASSERT( A.length() == B.length() );
|
|
|
|
|
if ( A.length() == 0 )
|
|
|
|
|
return TYPE();
|
|
|
|
|
const TYPE *x = A.data();
|
|
|
|
|
const TYPE *y = B.data();
|
|
|
|
|
TYPE z = initialValue;
|
|
|
|
|
for ( size_t i = 0; i < A.length(); i++ )
|
|
|
|
|
z = op( x[i], y[i], z );
|
|
|
|
|
return z;
|
|
|
|
|
}
|
2018-02-06 10:50:43 -05:00
|
|
|
|
|
|
|
|
|
|
|
|
|
/********************************************************
|
|
|
|
|
* Unary transformation *
|
|
|
|
|
********************************************************/
|
|
|
|
|
template<class TYPE, class FUN, typename LAMBDA>
|
2021-06-02 15:36:44 -04:00
|
|
|
inline void FunctionTable::transform( LAMBDA &fun, const Array<TYPE, FUN> &x, Array<TYPE, FUN> &y )
|
2018-02-06 10:50:43 -05:00
|
|
|
{
|
|
|
|
|
y.resize( x.size() );
|
|
|
|
|
const size_t N = x.length();
|
|
|
|
|
for ( size_t i = 0; i < N; i++ )
|
|
|
|
|
y( i ) = fun( x( i ) );
|
|
|
|
|
}
|
|
|
|
|
template<class TYPE, class FUN, typename LAMBDA>
|
2021-06-02 15:36:44 -04:00
|
|
|
inline void FunctionTable::transform( LAMBDA &fun,
|
|
|
|
|
const Array<TYPE, FUN> &x,
|
|
|
|
|
const Array<TYPE, FUN> &y,
|
|
|
|
|
Array<TYPE, FUN> &z )
|
2018-02-06 10:50:43 -05:00
|
|
|
{
|
2018-09-17 13:03:00 -04:00
|
|
|
if ( x.size() != y.size() )
|
2018-02-06 10:50:43 -05:00
|
|
|
throw std::logic_error( "Sizes of x and y do not match" );
|
|
|
|
|
z.resize( x.size() );
|
|
|
|
|
const size_t N = x.length();
|
|
|
|
|
for ( size_t i = 0; i < N; i++ )
|
|
|
|
|
z( i ) = fun( x( i ), y( i ) );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
2018-09-17 13:03:00 -04:00
|
|
|
/********************************************************
|
|
|
|
|
* axpy *
|
|
|
|
|
********************************************************/
|
|
|
|
|
/**
 * y[i] += alpha * x[i] for i in [0, N).
 * The generic template below is the fallback.  The float and double
 * specializations are only declared in this header; their definitions live
 * elsewhere (presumably BLAS-backed — not visible here).
 */
template<class TYPE>
void call_axpy( size_t N, const TYPE alpha, const TYPE *x, TYPE *y );
template<>
void call_axpy<float>( size_t N, const float alpha, const float *x, float *y );
template<>
void call_axpy<double>( size_t N, const double alpha, const double *x, double *y );
template<class TYPE>
void call_axpy( size_t N, const TYPE alpha, const TYPE *x, TYPE *y )
{
    // Walk both buffers in lock-step, front to back
    TYPE *dst = y;
    for ( const TYPE *src = x, *stop = x + N; src != stop; ++src, ++dst )
        *dst += alpha * *src;
}
|
|
|
|
|
template<class TYPE, class FUN>
|
2021-06-02 15:36:44 -04:00
|
|
|
void FunctionTable::axpy( const TYPE alpha, const Array<TYPE, FUN> &x, Array<TYPE, FUN> &y )
|
2018-09-17 13:03:00 -04:00
|
|
|
{
|
|
|
|
|
if ( x.size() != y.size() )
|
|
|
|
|
throw std::logic_error( "Array sizes do not match" );
|
|
|
|
|
call_axpy( x.length(), alpha, x.data(), y.data() );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
2018-02-06 10:50:43 -05:00
|
|
|
/********************************************************
|
|
|
|
|
* Multiply two arrays *
|
|
|
|
|
********************************************************/
|
2018-09-17 13:03:00 -04:00
|
|
|
/**
 * Dense matrix-vector product y := alpha * A * x + beta * y.
 * A is an M x N matrix stored column-major (element (i, j) at A[i + j * M]),
 * x has N entries and y has M entries.
 * The generic template below is the fallback.  The double and float
 * specializations are only declared in this header; their definitions live
 * elsewhere (presumably BLAS-backed — not visible here).
 * As in BLAS, y is write-only when beta is zero: its incoming contents are
 * ignored, so NaN/Inf left in a freshly resized output cannot leak through
 * 0 * y[i].
 */
template<class TYPE>
void call_gemv( size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *x, TYPE *y );
template<>
void call_gemv<double>(
    size_t M, size_t N, double alpha, double beta, const double *A, const double *x, double *y );
template<>
void call_gemv<float>(
    size_t M, size_t N, float alpha, float beta, const float *A, const float *x, float *y );
template<class TYPE>
void call_gemv( size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *x, TYPE *y )
{
    const TYPE zero = static_cast<TYPE>( 0 );
    // Scale the output, or clear it outright when beta == 0 (0 * NaN would stay NaN)
    for ( size_t i = 0; i < M; i++ )
        y[i] = ( beta == zero ) ? zero : beta * y[i];
    // Accumulate one column of A at a time (contiguous in column-major storage)
    for ( size_t j = 0; j < N; j++ ) {
        const TYPE *Aj = A + j * M;
        for ( size_t i = 0; i < M; i++ )
            y[i] += alpha * Aj[i] * x[j];
    }
}
|
|
|
|
|
/**
 * Dense matrix-matrix product C := alpha * A * B + beta * C.
 * All matrices are stored column-major: A is M x N (A[i + j * M]), B is N x K
 * (B[j + k * N]) and C is M x K (C[i + k * M]).
 * The generic template below is the fallback.  The double and float
 * specializations are only declared in this header; their definitions live
 * elsewhere (presumably BLAS-backed — not visible here).
 * As in BLAS, C is write-only when beta is zero: its incoming contents are
 * ignored, so NaN/Inf left in a freshly resized output cannot leak through
 * 0 * C[i].
 */
template<class TYPE>
void call_gemm(
    size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *B, TYPE *C );
template<>
void call_gemm<double>( size_t M,
                        size_t N,
                        size_t K,
                        double alpha,
                        double beta,
                        const double *A,
                        const double *B,
                        double *C );
template<>
void call_gemm<float>( size_t M,
                       size_t N,
                       size_t K,
                       float alpha,
                       float beta,
                       const float *A,
                       const float *B,
                       float *C );
template<class TYPE>
void call_gemm(
    size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *B, TYPE *C )
{
    const TYPE zero = static_cast<TYPE>( 0 );
    // Scale the output, or clear it outright when beta == 0 (0 * NaN would stay NaN)
    for ( size_t i = 0; i < M * K; i++ )
        C[i] = ( beta == zero ) ? zero : beta * C[i];
    for ( size_t k = 0; k < K; k++ ) {
        TYPE *Ck       = C + k * M; // column k of C
        const TYPE *Bk = B + k * N; // column k of B
        for ( size_t j = 0; j < N; j++ ) {
            const TYPE *Aj = A + j * M; // column j of A
            for ( size_t i = 0; i < M; i++ )
                Ck[i] += alpha * Aj[i] * Bk[j];
        }
    }
}
|
|
|
|
|
template<class TYPE, class FUN>
|
2021-06-02 15:36:44 -04:00
|
|
|
void FunctionTable::gemm( const TYPE alpha,
|
|
|
|
|
const Array<TYPE, FUN> &a,
|
|
|
|
|
const Array<TYPE, FUN> &b,
|
|
|
|
|
const TYPE beta,
|
|
|
|
|
Array<TYPE, FUN> &c )
|
2018-09-17 13:03:00 -04:00
|
|
|
{
|
2021-06-02 15:36:44 -04:00
|
|
|
if ( a.size( 1 ) != b.size( 0 ) )
|
|
|
|
|
throw std::logic_error( "Inner dimensions must match" );
|
2018-09-17 13:03:00 -04:00
|
|
|
if ( a.ndim() == 2 && b.ndim() == 1 ) {
|
|
|
|
|
call_gemv<TYPE>( a.size( 0 ), a.size( 1 ), alpha, beta, a.data(), b.data(), c.data() );
|
|
|
|
|
} else if ( a.ndim() <= 2 && b.ndim() <= 2 ) {
|
|
|
|
|
call_gemm<TYPE>(
|
|
|
|
|
a.size( 0 ), a.size( 1 ), b.size( 1 ), alpha, beta, a.data(), b.data(), c.data() );
|
|
|
|
|
} else {
|
|
|
|
|
throw std::logic_error( "Not finished yet" );
|
|
|
|
|
}
|
|
|
|
|
}
|
2018-02-06 10:50:43 -05:00
|
|
|
template<class TYPE, class FUN>
|
2021-06-02 15:36:44 -04:00
|
|
|
void FunctionTable::multiply( const Array<TYPE, FUN> &a,
|
|
|
|
|
const Array<TYPE, FUN> &b,
|
|
|
|
|
Array<TYPE, FUN> &c )
|
2018-02-06 10:50:43 -05:00
|
|
|
{
|
2021-06-02 15:36:44 -04:00
|
|
|
if ( a.size( 1 ) != b.size( 0 ) )
|
|
|
|
|
throw std::logic_error( "Inner dimensions must match" );
|
2018-09-17 13:03:00 -04:00
|
|
|
if ( a.ndim() == 2 && b.ndim() == 1 ) {
|
|
|
|
|
c.resize( a.size( 0 ) );
|
|
|
|
|
call_gemv<TYPE>( a.size( 0 ), a.size( 1 ), 1, 0, a.data(), b.data(), c.data() );
|
|
|
|
|
} else if ( a.ndim() <= 2 && b.ndim() <= 2 ) {
|
2018-02-06 10:50:43 -05:00
|
|
|
c.resize( a.size( 0 ), b.size( 1 ) );
|
2018-09-17 13:03:00 -04:00
|
|
|
call_gemm<TYPE>(
|
|
|
|
|
a.size( 0 ), a.size( 1 ), b.size( 1 ), 1, 0, a.data(), b.data(), c.data() );
|
2018-02-06 10:50:43 -05:00
|
|
|
} else {
|
|
|
|
|
throw std::logic_error( "Not finished yet" );
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
2018-09-17 13:03:00 -04:00
|
|
|
/********************************************************
|
|
|
|
|
* Check if two arrays are equal *
|
|
|
|
|
********************************************************/
|
|
|
|
|
template<class TYPE, class FUN>
|
2021-06-02 15:36:44 -04:00
|
|
|
inline typename std::enable_if<std::is_integral<TYPE>::value, bool>::type
|
|
|
|
|
FunctionTableCompare( const Array<TYPE, FUN> &a, const Array<TYPE, FUN> &b, TYPE )
|
2018-09-17 13:03:00 -04:00
|
|
|
{
|
|
|
|
|
bool pass = true;
|
|
|
|
|
if ( a.size() != b.size() )
|
|
|
|
|
throw std::logic_error( "Sizes of x and y do not match" );
|
|
|
|
|
for ( size_t i = 0; i < a.length(); i++ )
|
|
|
|
|
pass = pass && a( i ) == b( i );
|
|
|
|
|
return pass;
|
|
|
|
|
}
|
|
|
|
|
template<class TYPE, class FUN>
|
|
|
|
|
inline typename std::enable_if<std::is_floating_point<TYPE>::value, bool>::type
|
2021-06-02 15:36:44 -04:00
|
|
|
FunctionTableCompare( const Array<TYPE, FUN> &a, const Array<TYPE, FUN> &b, TYPE tol )
|
2018-09-17 13:03:00 -04:00
|
|
|
{
|
|
|
|
|
bool pass = true;
|
|
|
|
|
if ( a.size() != b.size() )
|
|
|
|
|
throw std::logic_error( "Sizes of x and y do not match" );
|
|
|
|
|
for ( size_t i = 0; i < a.length(); i++ )
|
|
|
|
|
pass = pass && ( std::abs( a( i ) - b( i ) ) < tol );
|
|
|
|
|
return pass;
|
|
|
|
|
}
|
|
|
|
|
template<class TYPE, class FUN>
|
2021-06-02 15:36:44 -04:00
|
|
|
bool FunctionTable::equals( const Array<TYPE, FUN> &a, const Array<TYPE, FUN> &b, TYPE tol )
|
2018-09-17 13:03:00 -04:00
|
|
|
{
|
|
|
|
|
return FunctionTableCompare( a, b, tol );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
2021-06-02 15:36:44 -04:00
|
|
|
/********************************************************
|
|
|
|
|
* Specialized Functions *
|
|
|
|
|
********************************************************/
|
|
|
|
|
template<class TYPE, class FUN, class ALLOC>
|
|
|
|
|
void FunctionTable::transformReLU( const Array<TYPE, FUN, ALLOC> &A, Array<TYPE, FUN, ALLOC> &B )
|
|
|
|
|
{
|
|
|
|
|
const auto &fun = []( const TYPE &a ) { return std::max( a, static_cast<TYPE>( 0 ) ); };
|
|
|
|
|
transform( fun, A, B );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<class TYPE, class FUN, class ALLOC>
|
|
|
|
|
void FunctionTable::transformAbs( const Array<TYPE, FUN, ALLOC> &A, Array<TYPE, FUN, ALLOC> &B )
|
|
|
|
|
{
|
|
|
|
|
B.resize( A.size() );
|
|
|
|
|
const auto &fun = []( const TYPE &a ) { return std::abs( a ); };
|
|
|
|
|
transform( fun, A, B );
|
|
|
|
|
}
|
|
|
|
|
template<class TYPE, class FUN, class ALLOC>
|
|
|
|
|
void FunctionTable::transformTanh( const Array<TYPE, FUN, ALLOC> &A, Array<TYPE, FUN, ALLOC> &B )
|
|
|
|
|
{
|
|
|
|
|
B.resize( A.size() );
|
|
|
|
|
const auto &fun = []( const TYPE &a ) { return tanh( a ); };
|
|
|
|
|
transform( fun, A, B );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<class TYPE, class FUN, class ALLOC>
|
|
|
|
|
void FunctionTable::transformHardTanh( const Array<TYPE, FUN, ALLOC> &A,
|
|
|
|
|
Array<TYPE, FUN, ALLOC> &B )
|
|
|
|
|
{
|
|
|
|
|
B.resize( A.size() );
|
|
|
|
|
const auto &fun = []( const TYPE &a ) {
|
|
|
|
|
return std::max( -static_cast<TYPE>( 1.0 ), std::min( static_cast<TYPE>( 1.0 ), a ) );
|
|
|
|
|
};
|
|
|
|
|
transform( fun, A, B );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<class TYPE, class FUN, class ALLOC>
|
|
|
|
|
void FunctionTable::transformSigmoid( const Array<TYPE, FUN, ALLOC> &A, Array<TYPE, FUN, ALLOC> &B )
|
|
|
|
|
{
|
|
|
|
|
B.resize( A.size() );
|
|
|
|
|
const auto &fun = []( const TYPE &a ) { return 1.0 / ( 1.0 + exp( -a ) ); };
|
|
|
|
|
transform( fun, A, B );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<class TYPE, class FUN, class ALLOC>
|
|
|
|
|
void FunctionTable::transformSoftPlus( const Array<TYPE, FUN, ALLOC> &A,
|
|
|
|
|
Array<TYPE, FUN, ALLOC> &B )
|
|
|
|
|
{
|
|
|
|
|
B.resize( A.size() );
|
|
|
|
|
const auto &fun = []( const TYPE &a ) { return log1p( exp( a ) ); };
|
|
|
|
|
transform( fun, A, B );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<class TYPE, class FUN, class ALLOC>
|
|
|
|
|
TYPE FunctionTable::sum( const Array<TYPE, FUN, ALLOC> &A )
|
|
|
|
|
{
|
|
|
|
|
const auto &fun = []( const TYPE &a, const TYPE &b ) { return a + b; };
|
|
|
|
|
return reduce( fun, A, (TYPE) 0 );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<class TYPE>
|
|
|
|
|
inline void FunctionTable::gemmWrapper( char TRANSA,
|
|
|
|
|
char TRANSB,
|
|
|
|
|
int M,
|
|
|
|
|
int N,
|
|
|
|
|
int K,
|
|
|
|
|
TYPE alpha,
|
|
|
|
|
const TYPE *A,
|
|
|
|
|
int LDA,
|
|
|
|
|
const TYPE *B,
|
|
|
|
|
int LDB,
|
|
|
|
|
TYPE beta,
|
|
|
|
|
TYPE *C,
|
|
|
|
|
int LDC )
|
|
|
|
|
{
|
|
|
|
|
ERROR("Not finished");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
2018-02-06 10:50:43 -05:00
|
|
|
#endif
|