// LBPM/common/FunctionTable.hpp
// 2023-10-23 04:18:20 -04:00
//
// 291 lines
// 11 KiB
// C++

/*
Copyright 2013--2018 James E. McClure, Virginia Polytechnic & State University
Copyright Equinor ASA
This file is part of the Open Porous Media project (OPM).
OPM is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
OPM is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with OPM. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef included_FunctionTable_hpp
#define included_FunctionTable_hpp
#include "common/FunctionTable.h"
#include "common/UtilityMacros.h"
#include <algorithm>
#include <cstring>
#include <limits>
//#include <random>
/********************************************************
* Random number initialization *
********************************************************/
/*template<class TYPE> TYPE genRand();
template<class TYPE, class FUN>
inline void FunctionTable::rand( Array<TYPE, FUN> &x )
{
for ( size_t i = 0; i < x.length(); i++ )
x( i ) = genRand<TYPE>();
}
*/
/********************************************************
* Reduction *
********************************************************/
// Left fold over the elements of A in storage order: starting from
// initialValue, the accumulator is replaced by op(element, accumulator) for
// every element.  Note: an empty array yields a value-initialized TYPE, not
// initialValue.
template <class TYPE, class FUN, typename LAMBDA>
inline TYPE FunctionTable::reduce(LAMBDA &op, const Array<TYPE, FUN> &A,
                                  const TYPE &initialValue) {
    const size_t count = A.length();
    if (count == 0)
        return TYPE();
    const TYPE *const first = A.data();
    const TYPE *const last = first + count;
    TYPE acc = initialValue;
    for (const TYPE *it = first; it != last; ++it)
        acc = op(*it, acc);
    return acc;
}
// Lock-step left fold over two arrays of equal length: starting from
// initialValue, the accumulator is replaced by op(a_i, b_i, accumulator) for
// every index.  A length mismatch trips ARRAY_ASSERT.  Note: empty input
// yields a value-initialized TYPE, not initialValue.
template <class TYPE, class FUN, typename LAMBDA>
inline TYPE FunctionTable::reduce(LAMBDA &op, const Array<TYPE, FUN> &A,
                                  const Array<TYPE, FUN> &B,
                                  const TYPE &initialValue) {
    ARRAY_ASSERT(A.length() == B.length());
    const size_t count = A.length();
    if (count == 0)
        return TYPE();
    const TYPE *lhs = A.data();
    const TYPE *rhs = B.data();
    TYPE acc = initialValue;
    for (size_t k = 0; k != count; ++k, ++lhs, ++rhs)
        acc = op(*lhs, *rhs, acc);
    return acc;
}
/********************************************************
* Unary transformation *
********************************************************/
// Element-wise map: y(i) = fun(x(i)) for every element of x.  y is resized to
// the shape of x before being written.
template <class TYPE, class FUN, typename LAMBDA>
inline void FunctionTable::transform(LAMBDA &fun, const Array<TYPE, FUN> &x,
                                     Array<TYPE, FUN> &y) {
    y.resize(x.size());
    const size_t count = x.length();
    size_t idx = 0;
    while (idx < count) {
        y(idx) = fun(x(idx));
        ++idx;
    }
}
// Element-wise binary map: z(i) = fun(x(i), y(i)).  x and y must have the
// same shape (std::logic_error otherwise); z is resized to that shape.
template <class TYPE, class FUN, typename LAMBDA>
inline void FunctionTable::transform(LAMBDA &fun, const Array<TYPE, FUN> &x,
                                     const Array<TYPE, FUN> &y,
                                     Array<TYPE, FUN> &z) {
    const bool shapeMismatch = x.size() != y.size();
    if (shapeMismatch)
        throw std::logic_error("Sizes of x and y do not match");
    z.resize(x.size());
    const size_t count = x.length();
    size_t idx = 0;
    while (idx < count) {
        z(idx) = fun(x(idx), y(idx));
        ++idx;
    }
}
/********************************************************
* axpy *
********************************************************/
// y[i] += alpha * x[i] for i in [0, N).  The float and double
// specializations are only declared here; their definitions live elsewhere
// (NOTE(review): presumably BLAS-backed — not visible in this file).  All
// other types use the generic loop below.
template <class TYPE>
void call_axpy(size_t N, const TYPE alpha, const TYPE *x, TYPE *y);
template <>
void call_axpy<float>(size_t N, const float alpha, const float *x, float *y);
template <>
void call_axpy<double>(size_t N, const double alpha, const double *x,
                       double *y);
template <class TYPE>
void call_axpy(size_t N, const TYPE alpha, const TYPE *x, TYPE *y) {
    const TYPE *src = x;
    TYPE *const stop = y + N;
    for (TYPE *dst = y; dst != stop; ++dst, ++src)
        *dst += alpha * *src;
}
// In-place y += alpha * x.  Both arrays must have the same shape
// (std::logic_error otherwise); the work is delegated to call_axpy.
template <class TYPE, class FUN>
void FunctionTable::axpy(const TYPE alpha, const Array<TYPE, FUN> &x,
                         Array<TYPE, FUN> &y) {
    const bool mismatch = x.size() != y.size();
    if (mismatch)
        throw std::logic_error("Array sizes do not match");
    const size_t count = x.length();
    call_axpy(count, alpha, x.data(), y.data());
}
/********************************************************
* Multiply two arrays *
********************************************************/
// Matrix-vector product y = alpha * A * x + beta * y, where A is an M x N
// matrix stored column-major (element (i, j) at A[i + j * M]), x has N
// entries and y has M entries.  The float and double specializations are only
// declared here and defined elsewhere; every other TYPE uses the loop below.
template <class TYPE>
void call_gemv(size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE *A,
               const TYPE *x, TYPE *y);
template <>
void call_gemv<double>(size_t M, size_t N, double alpha, double beta,
                       const double *A, const double *x, double *y);
template <>
void call_gemv<float>(size_t M, size_t N, float alpha, float beta,
                      const float *A, const float *x, float *y);
template <class TYPE>
void call_gemv(size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE *A,
               const TYPE *x, TYPE *y) {
    // Follow the BLAS convention: when beta is zero, y is output-only.
    // Multiplying instead would let stale NaN/Inf (or uninitialized memory,
    // e.g. right after a resize) leak into the result, since 0 * NaN == NaN.
    const TYPE zero = static_cast<TYPE>(0);
    const bool overwrite = (beta == zero);
    for (size_t i = 0; i < M; i++)
        y[i] = overwrite ? zero : beta * y[i];
    // Column-oriented accumulation: walk A contiguously column by column.
    for (size_t j = 0; j < N; j++) {
        for (size_t i = 0; i < M; i++)
            y[i] += alpha * A[i + j * M] * x[j];
    }
}
// Matrix-matrix product C = alpha * A * B + beta * C with column-major
// storage: A is M x N (A[i + j * M]), B is N x K (B[j + k * N]) and C is
// M x K (C[i + k * M]).  The float and double specializations are only
// declared here and defined elsewhere; every other TYPE uses the loops below.
template <class TYPE>
void call_gemm(size_t M, size_t N, size_t K, TYPE alpha, TYPE beta,
               const TYPE *A, const TYPE *B, TYPE *C);
template <>
void call_gemm<double>(size_t M, size_t N, size_t K, double alpha, double beta,
                       const double *A, const double *B, double *C);
template <>
void call_gemm<float>(size_t M, size_t N, size_t K, float alpha, float beta,
                      const float *A, const float *B, float *C);
template <class TYPE>
void call_gemm(size_t M, size_t N, size_t K, TYPE alpha, TYPE beta,
               const TYPE *A, const TYPE *B, TYPE *C) {
    // Follow the BLAS convention: when beta is zero, C is output-only.
    // Multiplying instead would let stale NaN/Inf (or uninitialized memory,
    // e.g. right after a resize) leak into the result, since 0 * NaN == NaN.
    const TYPE zero = static_cast<TYPE>(0);
    const bool overwrite = (beta == zero);
    for (size_t i = 0; i < K * M; i++)
        C[i] = overwrite ? zero : beta * C[i];
    for (size_t k = 0; k < K; k++) {
        for (size_t j = 0; j < N; j++) {
            for (size_t i = 0; i < M; i++)
                C[i + k * M] += alpha * A[i + j * M] * B[j + k * N];
        }
    }
}
// General product c = alpha * a * b + beta * c.
//  - a 2-D and b 1-D: matrix-vector product via call_gemv; c must hold at
//    least a.size(0) entries.
//  - a and b at most 2-D otherwise: matrix-matrix product via call_gemm; c
//    must hold at least a.size(0) * b.size(1) entries.
// c is NOT resized (its contents feed the beta term), so the caller must
// supply it with the right size.  Throws std::logic_error when the inner
// dimensions differ, when c is too small, or for operands with more than two
// dimensions.
template <class TYPE, class FUN>
void FunctionTable::gemm(const TYPE alpha, const Array<TYPE, FUN> &a,
                         const Array<TYPE, FUN> &b, const TYPE beta,
                         Array<TYPE, FUN> &c) {
    if (a.size(1) != b.size(0))
        throw std::logic_error("Inner dimensions must match");
    if (a.ndim() == 2 && b.ndim() == 1) {
        // c.data() is handed to a raw-pointer kernel; refuse to write past its end.
        if (static_cast<size_t>(c.length()) < static_cast<size_t>(a.size(0)))
            throw std::logic_error("Output array is too small");
        call_gemv<TYPE>(a.size(0), a.size(1), alpha, beta, a.data(), b.data(),
                        c.data());
    } else if (a.ndim() <= 2 && b.ndim() <= 2) {
        const size_t needed =
            static_cast<size_t>(a.size(0)) * static_cast<size_t>(b.size(1));
        if (static_cast<size_t>(c.length()) < needed)
            throw std::logic_error("Output array is too small");
        call_gemm<TYPE>(a.size(0), a.size(1), b.size(1), alpha, beta, a.data(),
                        b.data(), c.data());
    } else {
        throw std::logic_error("Not finished yet");
    }
}
// Product c = a * b.  c is resized to the result shape (a.size(0) for a
// matrix-vector product, a.size(0) x b.size(1) otherwise) and receives the
// product (computed with alpha = 1, beta = 0).  Throws std::logic_error when
// the inner dimensions differ or either operand has more than two dimensions.
template <class TYPE, class FUN>
void FunctionTable::multiply(const Array<TYPE, FUN> &a,
                             const Array<TYPE, FUN> &b, Array<TYPE, FUN> &c) {
    if (a.size(1) != b.size(0))
        throw std::logic_error("Inner dimensions must match");
    const TYPE one = 1;
    const TYPE zero = 0;
    if (a.ndim() == 2 && b.ndim() == 1) {
        // Matrix times vector.
        c.resize(a.size(0));
        call_gemv<TYPE>(a.size(0), a.size(1), one, zero, a.data(), b.data(),
                        c.data());
        return;
    }
    if (a.ndim() > 2 || b.ndim() > 2)
        throw std::logic_error("Not finished yet");
    // Matrix (or vector) times matrix (or vector).
    c.resize(a.size(0), b.size(1));
    call_gemm<TYPE>(a.size(0), a.size(1), b.size(1), one, zero, a.data(),
                    b.data(), c.data());
}
/********************************************************
* Check if two arrays are equal *
********************************************************/
// Exact element-wise equality for integral element types; the tolerance
// argument is accepted for interface symmetry and ignored.  Arrays of
// different shape raise std::logic_error rather than comparing unequal.
template <class TYPE, class FUN>
inline typename std::enable_if<std::is_integral<TYPE>::value, bool>::type
FunctionTableCompare(const Array<TYPE, FUN> &a, const Array<TYPE, FUN> &b,
                     TYPE) {
    if (a.size() != b.size())
        throw std::logic_error("Sizes of x and y do not match");
    const size_t count = a.length();
    for (size_t k = 0; k < count; ++k) {
        if (!(a(k) == b(k)))
            return false;
    }
    return true;
}
// Element-wise comparison for floating-point element types: the arrays are
// equal when every |a(i) - b(i)| <= tol.  Arrays of different shape raise
// std::logic_error.  Any NaN difference makes the arrays compare unequal.
template <class TYPE, class FUN>
inline typename std::enable_if<std::is_floating_point<TYPE>::value, bool>::type
FunctionTableCompare(const Array<TYPE, FUN> &a, const Array<TYPE, FUN> &b,
                     TYPE tol) {
    if (a.size() != b.size())
        throw std::logic_error("Sizes of x and y do not match");
    const size_t count = a.length();
    for (size_t i = 0; i < count; i++) {
        // Inclusive bound: with a strict '<', tol == 0 rejected even
        // bit-identical arrays.  Written as !(x <= tol) so NaN fails.
        if (!(std::abs(a(i) - b(i)) <= tol))
            return false;
    }
    return true;
}
// Array comparison with tolerance; dispatches to the integral (exact) or
// floating-point (tolerance-based) FunctionTableCompare overload by TYPE.
template <class TYPE, class FUN>
bool FunctionTable::equals(const Array<TYPE, FUN> &a, const Array<TYPE, FUN> &b,
                           TYPE tol) {
    const bool same = FunctionTableCompare(a, b, tol);
    return same;
}
/********************************************************
* Specialized Functions *
********************************************************/
// Element-wise rectified linear unit: B(i) = max(A(i), 0).  B is resized to
// A's shape by transform().
template <class TYPE, class FUN, class ALLOC>
void FunctionTable::transformReLU(const Array<TYPE, FUN, ALLOC> &A,
                                  Array<TYPE, FUN, ALLOC> &B) {
    const TYPE zero = static_cast<TYPE>(0);
    const auto &relu = [zero](const TYPE &v) { return std::max(v, zero); };
    transform(relu, A, B);
}
// Element-wise absolute value: B(i) = |A(i)|, with B resized to A's shape.
template <class TYPE, class FUN, class ALLOC>
void FunctionTable::transformAbs(const Array<TYPE, FUN, ALLOC> &A,
                                 Array<TYPE, FUN, ALLOC> &B) {
    B.resize(A.size());
    const auto &magnitude = [](const TYPE &v) { return std::abs(v); };
    transform(magnitude, A, B);
}
// Element-wise hyperbolic tangent: B(i) = tanh(A(i)), with B resized to A's
// shape.
template <class TYPE, class FUN, class ALLOC>
void FunctionTable::transformTanh(const Array<TYPE, FUN, ALLOC> &A,
                                  Array<TYPE, FUN, ALLOC> &B) {
    B.resize(A.size());
    const auto &hyperbolic = [](const TYPE &v) { return tanh(v); };
    transform(hyperbolic, A, B);
}
// Element-wise hard tanh: B(i) = A(i) clamped to [-1, 1], with B resized to
// A's shape.  NaN inputs propagate as NaN.
template <class TYPE, class FUN, class ALLOC>
void FunctionTable::transformHardTanh(const Array<TYPE, FUN, ALLOC> &A,
                                      Array<TYPE, FUN, ALLOC> &B) {
    B.resize(A.size());
    // Explicit comparisons instead of std::max(-1, std::min(1, a)): that form
    // returns 1 for a NaN input (both comparisons are false), silently hiding
    // bad data.  Here a NaN fails both tests and is returned unchanged,
    // consistent with transformReLU.
    const auto &fun = [](const TYPE &a) -> TYPE {
        return (a < -static_cast<TYPE>(1.0))
                   ? -static_cast<TYPE>(1.0)
                   : ((static_cast<TYPE>(1.0) < a) ? static_cast<TYPE>(1.0)
                                                   : a);
    };
    transform(fun, A, B);
}
// Element-wise logistic sigmoid: B(i) = 1 / (1 + exp(-A(i))), with B resized
// to A's shape.  The double literals mean the expression is evaluated in at
// least double precision before being stored back as TYPE.
template <class TYPE, class FUN, class ALLOC>
void FunctionTable::transformSigmoid(const Array<TYPE, FUN, ALLOC> &A,
                                     Array<TYPE, FUN, ALLOC> &B) {
    B.resize(A.size());
    const auto &logistic = [](const TYPE &v) {
        return 1.0 / (1.0 + exp(-v));
    };
    transform(logistic, A, B);
}
// Element-wise softplus: B(i) = log(1 + exp(A(i))), with B resized to A's
// shape.
template <class TYPE, class FUN, class ALLOC>
void FunctionTable::transformSoftPlus(const Array<TYPE, FUN, ALLOC> &A,
                                      Array<TYPE, FUN, ALLOC> &B) {
    B.resize(A.size());
    // log1p(exp(a)) overflows to +inf as soon as exp(a) does (a > ~709 for
    // double), even though softplus(a) ~= a there.  For positive a use the
    // identity log1p(exp(a)) == a + log1p(exp(-a)), whose exp term stays in
    // (0, 1); for a <= 0 the direct form is already safe.
    const auto &fun = [](const TYPE &a) {
        return a > static_cast<TYPE>(0) ? a + log1p(exp(-a)) : log1p(exp(a));
    };
    transform(fun, A, B);
}
// Sum of all elements of A, accumulated in storage order starting from zero.
template <class TYPE, class FUN, class ALLOC>
TYPE FunctionTable::sum(const Array<TYPE, FUN, ALLOC> &A) {
    const auto &add = [](const TYPE &lhs, const TYPE &rhs) { return lhs + rhs; };
    const TYPE zero = static_cast<TYPE>(0);
    return reduce(add, A, zero);
}
// Unimplemented stub: every call goes straight to ERROR("Not finished") and
// all arguments are ignored (they are left unnamed).
// NOTE(review): the parameter list looks like the BLAS ?gemm interface
// (transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) — confirm
// against the declaration in FunctionTable.h before implementing.  ERROR is a
// project macro (presumably from common/UtilityMacros.h); whether it throws
// or aborts is not visible here.
template <class TYPE>
inline void FunctionTable::gemmWrapper(char, char, int, int, int, TYPE,
const TYPE *, int, const TYPE *, int,
TYPE, TYPE *, int) {
ERROR("Not finished");
}
#endif