#ifndef included_FunctionTable_hpp
#define included_FunctionTable_hpp

// Definitions of the FunctionTable member templates (rand, reduce, transform, axpy, gemm,
// multiply, equals) together with the free helper templates they call (genRand, call_axpy,
// call_gemv, call_gemm, FunctionTableCompare).
//
// NOTE(review): in this copy of the file every template parameter list, every enable_if
// condition, every Array<...> argument list and the names of four included headers are
// missing (it looks like text between '<' and '>' was stripped). The comments below describe
// only what the remaining tokens show; restore the lists from the declarations in
// common/FunctionTable.h before relying on this file.

#include "common/FunctionTable.h"
#include "common/Utilities.h"

// NOTE(review): the next four directives have no header name in this copy. The code below
// uses std::random_device / std::mt19937 / the uniform distributions, std::enable_if,
// std::logic_error and std::abs, so the standard headers providing those are presumably
// among them - TODO confirm against the upstream file.
#include
#include
#include
#include


/********************************************************
 *  Random number initialization                         *
 ********************************************************/
// Fill x[0..N) with values drawn from a default-constructed std::uniform_int_distribution.
// A new std::random_device and a std::mt19937 seeded from a single rd() call are created on
// every invocation.
// NOTE(review): the enable_if condition selecting this overload is not visible here;
// presumably it restricts TYPE to integral types - confirm.
template static inline typename std::enable_if::value>::type genRand( size_t N, TYPE* x )
{
    std::random_device rd;
    std::mt19937 gen( rd() );
    std::uniform_int_distribution dis;
    for ( size_t i = 0; i < N; i++ )
        x[i] = dis( gen );
}
// Fill x[0..N) with values drawn from std::uniform_real_distribution constructed with (0, 1).
// Same per-call engine construction and seeding as the overload above.
// NOTE(review): the enable_if condition is not visible here; presumably it restricts TYPE to
// floating-point types - confirm.
template static inline typename std::enable_if::value>::type genRand( size_t N, TYPE* x )
{
    std::random_device rd;
    std::mt19937 gen( rd() );
    std::uniform_real_distribution dis( 0, 1 );
    for ( size_t i = 0; i < N; i++ )
        x[i] = dis( gen );
}
// Overwrite all x.length() elements of x (through x.data()) with random values via genRand.
template inline void FunctionTable::rand( Array& x )
{
    genRand( x.length(), x.data() );
}


/********************************************************
 *  Reduction                                            *
 ********************************************************/
// Fold the elements of A into a single value using op.
// Returns TYPE() when A.length() == 0. Otherwise starts from the first element and, for each
// following element in storage order, evaluates y = op( element, y ) - i.e. the new element
// is passed as the FIRST argument and the running result as the SECOND, which matters for
// non-commutative operations.
template inline TYPE FunctionTable::reduce( LAMBDA& op, const Array& A )
{
    if ( A.length() == 0 )
        return TYPE();
    const TYPE* x = A.data();
    TYPE y        = x[0];
    const size_t N = A.length();
    for ( size_t i = 1; i < N; i++ )
        y = op( x[i], y );
    return y;
}


/********************************************************
 *  Unary / binary transformation                        *
 ********************************************************/
// Unary form: resize y to x.size(), then set y(i) = fun( x(i) ) for every linear index
// i in [0, x.length()).
template inline void FunctionTable::transform( LAMBDA& fun, const Array& x, Array& y )
{
    y.resize( x.size() );
    const size_t N = x.length();
    for ( size_t i = 0; i < N; i++ )
        y( i ) = fun( x( i ) );
}
// Binary form: throws std::logic_error if x.size() != y.size(); otherwise resizes z to
// x.size() and sets z(i) = fun( x(i), y(i) ) for every linear index i in [0, x.length()).
template inline void FunctionTable::transform( LAMBDA& fun, const Array& x, const Array& y, Array& z )
{
    if ( x.size() != y.size() )
        throw std::logic_error( "Sizes of x and y do not match" );
    z.resize( x.size() );
    const size_t N = x.length();
    for ( size_t i = 0; i < N; i++ )
        z( i ) = fun( x( i ), y( i ) );
}


/********************************************************
 *  axpy                                                 *
 ********************************************************/
// Forward declaration of the element-wise kernel y[i] += alpha * x[i] for i in [0, N).
template inline void call_axpy( size_t N, const TYPE alpha, const TYPE* x, TYPE* y );
// Explicit specialization taking float arguments: always throws std::logic_error.
// NOTE(review): presumably a placeholder for a LAPACK/BLAS-backed implementation (per the
// message text); as written, axpy on float data can never succeed - confirm this is intended.
template<> inline void call_axpy( size_t, const float, const float*, float* )
{
    throw std::logic_error( "LapackWrappers not configured" );
}
// Explicit specialization taking double arguments: always throws std::logic_error
// (same caveat as the float specialization).
template<> inline void call_axpy( size_t, const double, const double*, double* )
{
    throw std::logic_error( "LapackWrappers not configured" );
}
// Generic implementation: plain loop computing y[i] += alpha * x[i].
template inline void call_axpy( size_t N, const TYPE alpha, const TYPE* x, TYPE* y )
{
    for ( size_t i = 0; i < N; i++ )
        y[i] += alpha * x[i];
}
// y += alpha * x over all x.length() elements.
// Throws std::logic_error if x.size() != y.size().
template void FunctionTable::axpy( const TYPE alpha, const Array& x, Array& y )
{
    if ( x.size() != y.size() )
        throw std::logic_error( "Array sizes do not match" );
    call_axpy( x.length(), alpha, x.data(), y.data() );
}


/********************************************************
 *  Multiply two arrays                                  *
 ********************************************************/
// Forward declaration of the matrix-vector kernel y = beta * y + alpha * A * x.
template inline void call_gemv( size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* x, TYPE* y );
// Explicit specialization taking double arguments: always throws std::logic_error.
template<> inline void call_gemv( size_t, size_t, double, double, const double*, const double*, double* )
{
    throw std::logic_error( "LapackWrappers not configured" );
}
// Explicit specialization taking float arguments: always throws std::logic_error.
template<> inline void call_gemv( size_t, size_t, float, float, const float*, const float*, float* )
{
    throw std::logic_error( "LapackWrappers not configured" );
}
// Generic implementation. A is read as A[i + j * M] for i in [0, M), j in [0, N); x has N
// entries and y has M entries.
// NOTE(review): y is multiplied by beta even when beta == 0, so a NaN/Inf already stored in
// y stays NaN (0 * NaN). multiply() passes beta = 0 on a freshly resized array; whether
// resize() zero-fills is not visible here - confirm, or assign zero explicitly for beta == 0.
template inline void call_gemv( size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* x, TYPE* y )
{
    // Scale the existing output first ...
    for ( size_t i = 0; i < M; i++ )
        y[i] = beta * y[i];
    // ... then accumulate one column of A at a time (inner loop walks contiguous A entries)
    for ( size_t j = 0; j < N; j++ ) {
        for ( size_t i = 0; i < M; i++ )
            y[i] += alpha * A[i + j * M] * x[j];
    }
}
// Forward declaration of the matrix-matrix kernel C = beta * C + alpha * A * B.
template inline void call_gemm( size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* B, TYPE* C );
// Explicit specialization taking double arguments: always throws std::logic_error.
template<> inline void call_gemm( size_t, size_t, size_t, double, double, const double*, const double*, double* )
{
    throw std::logic_error( "LapackWrappers not configured" );
}
// Explicit specialization taking float arguments: always throws std::logic_error.
template<> inline void call_gemm( size_t, size_t, size_t, float, float, const float*, const float*, float* )
{
    throw std::logic_error( "LapackWrappers not configured" );
}
// Generic implementation. Index patterns used: A[i + j * M], B[j + k * N], C[i + k * M] with
// i in [0, M), j in [0, N), k in [0, K) - i.e. N is the shared (inner) extent and C holds
// M * K entries.
// NOTE(review): as in call_gemv, C is scaled by beta even when beta == 0, so pre-existing
// NaN/Inf in C propagates into the result.
template inline void call_gemm( size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* B, TYPE* C )
{
    // Scale all M * K output entries
    for ( size_t i = 0; i < K * M; i++ )
        C[i] = beta * C[i];
    // Accumulate alpha * A * B into C
    for ( size_t k = 0; k < K; k++ ) {
        for ( size_t j = 0; j < N; j++ ) {
            for ( size_t i = 0; i < M; i++ )
                C[i + k * M] += alpha * A[i + j * M] * B[j + k * N];
        }
    }
}
// c = beta * c + alpha * a * b.
//  - a.ndim() == 2 and b.ndim() == 1: dispatches to call_gemv.
//  - otherwise, a.ndim() <= 2 and b.ndim() <= 2: dispatches to call_gemm.
//  - anything else: throws std::logic_error( "Not finished yet" ).
// Throws std::logic_error if a.size( 1 ) != b.size( 0 ).
// NOTE(review): unlike multiply(), c is neither resized nor size-checked here; the kernels
// write through c.data(), so the caller must pass a correctly sized c.
template void FunctionTable::gemm( const TYPE alpha, const Array& a, const Array& b, const TYPE beta, Array& c )
{
    if ( a.ndim() == 2 && b.ndim() == 1 ) {
        if ( a.size( 1 ) != b.size( 0 ) )
            throw std::logic_error( "Inner dimensions must match" );
        call_gemv( a.size( 0 ), a.size( 1 ), alpha, beta, a.data(), b.data(), c.data() );
    } else if ( a.ndim() <= 2 && b.ndim() <= 2 ) {
        if ( a.size( 1 ) != b.size( 0 ) )
            throw std::logic_error( "Inner dimensions must match" );
        call_gemm( a.size( 0 ), a.size( 1 ), b.size( 1 ), alpha, beta, a.data(), b.data(), c.data() );
    } else {
        throw std::logic_error( "Not finished yet" );
    }
}
// c = a * b. Same dispatch and dimension check as gemm(), but c is resized first
// (to a.size( 0 ) for the matrix-vector case, to a.size( 0 ) x b.size( 1 ) otherwise) and the
// kernels are called with alpha = 1 and beta = 0.
template void FunctionTable::multiply( const Array& a, const Array& b, Array& c )
{
    if ( a.ndim() == 2 && b.ndim() == 1 ) {
        if ( a.size( 1 ) != b.size( 0 ) )
            throw std::logic_error( "Inner dimensions must match" );
        c.resize( a.size( 0 ) );
        call_gemv( a.size( 0 ), a.size( 1 ), 1, 0, a.data(), b.data(), c.data() );
    } else if ( a.ndim() <= 2 && b.ndim() <= 2 ) {
        if ( a.size( 1 ) != b.size( 0 ) )
            throw std::logic_error( "Inner dimensions must match" );
        c.resize( a.size( 0 ), b.size( 1 ) );
        call_gemm( a.size( 0 ), a.size( 1 ), b.size( 1 ), 1, 0, a.data(), b.data(), c.data() );
    } else {
        throw std::logic_error( "Not finished yet" );
    }
}


/********************************************************
 *  Check if two arrays are equal                        *
 ********************************************************/
// Exact comparison: returns true only if a(i) == b(i) for every linear index. The third
// (unnamed) parameter is ignored. Throws std::logic_error if a.size() != b.size().
// The loop always runs to a.length(); once pass is false the && short-circuits the element
// comparison but not the iteration.
// NOTE(review): the enable_if condition is not visible here; presumably this is the overload
// for integral TYPE - confirm.
template inline typename std::enable_if::value, bool>::type FunctionTableCompare( const Array& a, const Array& b, TYPE )
{
    bool pass = true;
    if ( a.size() != b.size() )
        throw std::logic_error( "Sizes of x and y do not match" );
    for ( size_t i = 0; i < a.length(); i++ )
        pass = pass && a( i ) == b( i );
    return pass;
}
// Tolerance comparison: returns true only if std::abs( a(i) - b(i) ) < tol (strictly less)
// for every linear index; consequently tol == 0 can never yield true for non-empty arrays.
// Throws std::logic_error if a.size() != b.size().
// NOTE(review): the enable_if condition is not visible here; presumably this is the overload
// for floating-point TYPE - confirm.
template inline typename std::enable_if::value, bool>::type FunctionTableCompare( const Array& a, const Array& b, TYPE tol )
{
    bool pass = true;
    if ( a.size() != b.size() )
        throw std::logic_error( "Sizes of x and y do not match" );
    for ( size_t i = 0; i < a.length(); i++ )
        pass = pass && ( std::abs( a( i ) - b( i ) ) < tol );
    return pass;
}
// Element-wise equality of a and b; forwards to the FunctionTableCompare overload selected
// for TYPE, passing tol through.
template bool FunctionTable::equals( const Array& a, const Array& b, TYPE tol )
{
    return FunctionTableCompare( a, b, tol );
}

#endif