238 lines
8.2 KiB
C++
238 lines
8.2 KiB
C++
#ifndef included_FunctionTable_hpp
|
|
#define included_FunctionTable_hpp
|
|
|
|
#include "common/FunctionTable.h"
|
|
#include "common/Utilities.h"
|
|
|
|
#include <algorithm>
|
|
#include <cstring>
|
|
#include <limits>
|
|
#include <random>
|
|
|
|
|
|
/********************************************************
|
|
* Random number initialization *
|
|
********************************************************/
|
|
template<class TYPE>
|
|
static inline typename std::enable_if<std::is_integral<TYPE>::value>::type genRand(
|
|
size_t N, TYPE* x )
|
|
{
|
|
std::random_device rd;
|
|
std::mt19937 gen( rd() );
|
|
std::uniform_int_distribution<TYPE> dis;
|
|
for ( size_t i = 0; i < N; i++ )
|
|
x[i] = dis( gen );
|
|
}
|
|
template<class TYPE>
|
|
static inline typename std::enable_if<std::is_floating_point<TYPE>::value>::type genRand(
|
|
size_t N, TYPE* x )
|
|
{
|
|
std::random_device rd;
|
|
std::mt19937 gen( rd() );
|
|
std::uniform_real_distribution<TYPE> dis( 0, 1 );
|
|
for ( size_t i = 0; i < N; i++ )
|
|
x[i] = dis( gen );
|
|
}
|
|
template<class TYPE, class FUN>
|
|
inline void FunctionTable::rand( Array<TYPE, FUN>& x )
|
|
{
|
|
genRand<TYPE>( x.length(), x.data() );
|
|
}
|
|
|
|
|
|
/********************************************************
|
|
* Reduction *
|
|
********************************************************/
|
|
template<class TYPE, class FUN, typename LAMBDA>
|
|
inline TYPE FunctionTable::reduce( LAMBDA& op, const Array<TYPE, FUN>& A )
|
|
{
|
|
if ( A.length() == 0 )
|
|
return TYPE();
|
|
const TYPE* x = A.data();
|
|
TYPE y = x[0];
|
|
const size_t N = A.length();
|
|
for ( size_t i = 1; i < N; i++ )
|
|
y = op( x[i], y );
|
|
return y;
|
|
}
|
|
|
|
|
|
/********************************************************
|
|
* Unary transformation *
|
|
********************************************************/
|
|
template<class TYPE, class FUN, typename LAMBDA>
|
|
inline void FunctionTable::transform( LAMBDA& fun, const Array<TYPE, FUN>& x, Array<TYPE, FUN>& y )
|
|
{
|
|
y.resize( x.size() );
|
|
const size_t N = x.length();
|
|
for ( size_t i = 0; i < N; i++ )
|
|
y( i ) = fun( x( i ) );
|
|
}
|
|
template<class TYPE, class FUN, typename LAMBDA>
|
|
inline void FunctionTable::transform(
|
|
LAMBDA& fun, const Array<TYPE, FUN>& x, const Array<TYPE, FUN>& y, Array<TYPE, FUN>& z )
|
|
{
|
|
if ( x.size() != y.size() )
|
|
throw std::logic_error( "Sizes of x and y do not match" );
|
|
z.resize( x.size() );
|
|
const size_t N = x.length();
|
|
for ( size_t i = 0; i < N; i++ )
|
|
z( i ) = fun( x( i ), y( i ) );
|
|
}
|
|
|
|
|
|
/********************************************************
|
|
* axpy *
|
|
********************************************************/
|
|
template<class TYPE>
|
|
inline void call_axpy( size_t N, const TYPE alpha, const TYPE* x, TYPE* y );
|
|
template<>
|
|
inline void call_axpy<float>( size_t, const float, const float*, float* )
|
|
{
|
|
throw std::logic_error( "LapackWrappers not configured" );
|
|
}
|
|
template<>
|
|
inline void call_axpy<double>( size_t, const double, const double*, double* )
|
|
{
|
|
throw std::logic_error( "LapackWrappers not configured" );
|
|
}
|
|
template<class TYPE>
|
|
inline void call_axpy( size_t N, const TYPE alpha, const TYPE* x, TYPE* y )
|
|
{
|
|
for ( size_t i = 0; i < N; i++ )
|
|
y[i] += alpha * x[i];
|
|
}
|
|
template<class TYPE, class FUN>
|
|
void FunctionTable::axpy( const TYPE alpha, const Array<TYPE, FUN>& x, Array<TYPE, FUN>& y )
|
|
{
|
|
if ( x.size() != y.size() )
|
|
throw std::logic_error( "Array sizes do not match" );
|
|
call_axpy( x.length(), alpha, x.data(), y.data() );
|
|
}
|
|
|
|
|
|
/********************************************************
|
|
* Multiply two arrays *
|
|
********************************************************/
|
|
template<class TYPE>
|
|
inline void call_gemv( size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* x, TYPE* y );
|
|
template<>
|
|
inline void call_gemv<double>(
|
|
size_t, size_t, double, double, const double*, const double*, double* )
|
|
{
|
|
throw std::logic_error( "LapackWrappers not configured" );
|
|
}
|
|
template<>
|
|
inline void call_gemv<float>( size_t, size_t, float, float, const float*, const float*, float* )
|
|
{
|
|
throw std::logic_error( "LapackWrappers not configured" );
|
|
}
|
|
template<class TYPE>
|
|
inline void call_gemv(
|
|
size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* x, TYPE* y )
|
|
{
|
|
for ( size_t i = 0; i < M; i++ )
|
|
y[i] = beta * y[i];
|
|
for ( size_t j = 0; j < N; j++ ) {
|
|
for ( size_t i = 0; i < M; i++ )
|
|
y[i] += alpha * A[i + j * M] * x[j];
|
|
}
|
|
}
|
|
template<class TYPE>
|
|
inline void call_gemm(
|
|
size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* B, TYPE* C );
|
|
template<>
|
|
inline void call_gemm<double>( size_t, size_t, size_t, double, double, const double*, const double*, double* )
|
|
{
|
|
throw std::logic_error( "LapackWrappers not configured" );
|
|
}
|
|
template<>
|
|
inline void call_gemm<float>( size_t, size_t, size_t, float, float, const float*, const float*, float* )
|
|
{
|
|
throw std::logic_error( "LapackWrappers not configured" );
|
|
}
|
|
template<class TYPE>
|
|
inline void call_gemm(
|
|
size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE* A, const TYPE* B, TYPE* C )
|
|
{
|
|
for ( size_t i = 0; i < K * M; i++ )
|
|
C[i] = beta * C[i];
|
|
for ( size_t k = 0; k < K; k++ ) {
|
|
for ( size_t j = 0; j < N; j++ ) {
|
|
for ( size_t i = 0; i < M; i++ )
|
|
C[i + k * M] += alpha * A[i + j * M] * B[j + k * N];
|
|
}
|
|
}
|
|
}
|
|
template<class TYPE, class FUN>
|
|
void FunctionTable::gemm( const TYPE alpha, const Array<TYPE, FUN>& a, const Array<TYPE, FUN>& b,
|
|
const TYPE beta, Array<TYPE, FUN>& c )
|
|
{
|
|
if ( a.ndim() == 2 && b.ndim() == 1 ) {
|
|
if ( a.size( 1 ) != b.size( 0 ) )
|
|
throw std::logic_error( "Inner dimensions must match" );
|
|
call_gemv<TYPE>( a.size( 0 ), a.size( 1 ), alpha, beta, a.data(), b.data(), c.data() );
|
|
} else if ( a.ndim() <= 2 && b.ndim() <= 2 ) {
|
|
if ( a.size( 1 ) != b.size( 0 ) )
|
|
throw std::logic_error( "Inner dimensions must match" );
|
|
call_gemm<TYPE>(
|
|
a.size( 0 ), a.size( 1 ), b.size( 1 ), alpha, beta, a.data(), b.data(), c.data() );
|
|
} else {
|
|
throw std::logic_error( "Not finished yet" );
|
|
}
|
|
}
|
|
template<class TYPE, class FUN>
|
|
void FunctionTable::multiply(
|
|
const Array<TYPE, FUN>& a, const Array<TYPE, FUN>& b, Array<TYPE, FUN>& c )
|
|
{
|
|
if ( a.ndim() == 2 && b.ndim() == 1 ) {
|
|
if ( a.size( 1 ) != b.size( 0 ) )
|
|
throw std::logic_error( "Inner dimensions must match" );
|
|
c.resize( a.size( 0 ) );
|
|
call_gemv<TYPE>( a.size( 0 ), a.size( 1 ), 1, 0, a.data(), b.data(), c.data() );
|
|
} else if ( a.ndim() <= 2 && b.ndim() <= 2 ) {
|
|
if ( a.size( 1 ) != b.size( 0 ) )
|
|
throw std::logic_error( "Inner dimensions must match" );
|
|
c.resize( a.size( 0 ), b.size( 1 ) );
|
|
call_gemm<TYPE>(
|
|
a.size( 0 ), a.size( 1 ), b.size( 1 ), 1, 0, a.data(), b.data(), c.data() );
|
|
} else {
|
|
throw std::logic_error( "Not finished yet" );
|
|
}
|
|
}
|
|
|
|
|
|
/********************************************************
|
|
* Check if two arrays are equal *
|
|
********************************************************/
|
|
template<class TYPE, class FUN>
|
|
inline typename std::enable_if<!std::is_floating_point<TYPE>::value, bool>::type
|
|
FunctionTableCompare( const Array<TYPE, FUN>& a, const Array<TYPE, FUN>& b, TYPE )
|
|
{
|
|
bool pass = true;
|
|
if ( a.size() != b.size() )
|
|
throw std::logic_error( "Sizes of x and y do not match" );
|
|
for ( size_t i = 0; i < a.length(); i++ )
|
|
pass = pass && a( i ) == b( i );
|
|
return pass;
|
|
}
|
|
template<class TYPE, class FUN>
|
|
inline typename std::enable_if<std::is_floating_point<TYPE>::value, bool>::type
|
|
FunctionTableCompare( const Array<TYPE, FUN>& a, const Array<TYPE, FUN>& b, TYPE tol )
|
|
{
|
|
bool pass = true;
|
|
if ( a.size() != b.size() )
|
|
throw std::logic_error( "Sizes of x and y do not match" );
|
|
for ( size_t i = 0; i < a.length(); i++ )
|
|
pass = pass && ( std::abs( a( i ) - b( i ) ) < tol );
|
|
return pass;
|
|
}
|
|
template<class TYPE, class FUN>
|
|
bool FunctionTable::equals( const Array<TYPE, FUN>& a, const Array<TYPE, FUN>& b, TYPE tol )
|
|
{
|
|
return FunctionTableCompare( a, b, tol );
|
|
}
|
|
|
|
|
|
#endif
|