/* Copyright 2013--2018 James E. McClure, Virginia Polytechnic & State University Copyright Equnior ASA This file is part of the Open Porous Media project (OPM). OPM is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. OPM is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OPM. If not, see . */ #ifndef included_FunctionTable_hpp #define included_FunctionTable_hpp #include "common/FunctionTable.h" #include "common/UtilityMacros.h" #include #include #include //#include /******************************************************** * Random number initialization * ********************************************************/ /*template TYPE genRand(); template inline void FunctionTable::rand( Array &x ) { for ( size_t i = 0; i < x.length(); i++ ) x( i ) = genRand(); } */ /******************************************************** * Reduction * ********************************************************/ template inline TYPE FunctionTable::reduce(LAMBDA &op, const Array &A, const TYPE &initialValue) { if (A.length() == 0) return TYPE(); const TYPE *x = A.data(); TYPE y = initialValue; for (size_t i = 0; i < A.length(); i++) y = op(x[i], y); return y; } template inline TYPE FunctionTable::reduce(LAMBDA &op, const Array &A, const Array &B, const TYPE &initialValue) { ARRAY_ASSERT(A.length() == B.length()); if (A.length() == 0) return TYPE(); const TYPE *x = A.data(); const TYPE *y = B.data(); TYPE z = initialValue; for (size_t i = 0; i < A.length(); i++) z = op(x[i], y[i], z); return z; } /******************************************************** * Unary transformation * ********************************************************/ template inline void FunctionTable::transform(LAMBDA &fun, const Array &x, Array &y) { y.resize(x.size()); const size_t N = x.length(); for (size_t i = 0; i < N; i++) y(i) = fun(x(i)); } template inline void FunctionTable::transform(LAMBDA &fun, const Array &x, const Array &y, Array &z) { if (x.size() != y.size()) throw std::logic_error("Sizes of x and y do not match"); z.resize(x.size()); const size_t N = x.length(); for (size_t i = 0; i < N; i++) z(i) = fun(x(i), y(i)); } /******************************************************** * axpy * ********************************************************/ template void call_axpy(size_t N, const TYPE alpha, const TYPE *x, TYPE *y); template <> void call_axpy(size_t N, const float alpha, const float *x, float *y); template <> void call_axpy(size_t N, const double alpha, const double *x, double *y); template void call_axpy(size_t N, const TYPE alpha, const TYPE *x, TYPE *y) { for (size_t i = 0; i < N; i++) y[i] += alpha * x[i]; } template void FunctionTable::axpy(const TYPE alpha, const Array &x, Array &y) { if (x.size() != y.size()) throw std::logic_error("Array sizes do not match"); call_axpy(x.length(), alpha, x.data(), y.data()); } /******************************************************** * Multiply two arrays * ********************************************************/ template void call_gemv(size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *x, TYPE *y); template <> void call_gemv(size_t M, size_t N, double alpha, double beta, const double *A, const double *x, double *y); template <> void call_gemv(size_t M, size_t N, float alpha, float beta, const float *A, const float *x, float *y); template void call_gemv(size_t M, size_t N, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *x, TYPE *y) { for (size_t i = 0; i < M; i++) y[i] = beta * y[i]; for (size_t j = 0; j < N; j++) { for (size_t i = 0; i < M; i++) y[i] += alpha * A[i + j * M] * x[j]; } } template void call_gemm(size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *B, TYPE *C); template <> void call_gemm(size_t M, size_t N, size_t K, double alpha, double beta, const double *A, const double *B, double *C); template <> void call_gemm(size_t M, size_t N, size_t K, float alpha, float beta, const float *A, const float *B, float *C); template void call_gemm(size_t M, size_t N, size_t K, TYPE alpha, TYPE beta, const TYPE *A, const TYPE *B, TYPE *C) { for (size_t i = 0; i < K * M; i++) C[i] = beta * C[i]; for (size_t k = 0; k < K; k++) { for (size_t j = 0; j < N; j++) { for (size_t i = 0; i < M; i++) C[i + k * M] += alpha * A[i + j * M] * B[j + k * N]; } } } template void FunctionTable::gemm(const TYPE alpha, const Array &a, const Array &b, const TYPE beta, Array &c) { if (a.size(1) != b.size(0)) throw std::logic_error("Inner dimensions must match"); if (a.ndim() == 2 && b.ndim() == 1) { call_gemv(a.size(0), a.size(1), alpha, beta, a.data(), b.data(), c.data()); } else if (a.ndim() <= 2 && b.ndim() <= 2) { call_gemm(a.size(0), a.size(1), b.size(1), alpha, beta, a.data(), b.data(), c.data()); } else { throw std::logic_error("Not finished yet"); } } template void FunctionTable::multiply(const Array &a, const Array &b, Array &c) { if (a.size(1) != b.size(0)) throw std::logic_error("Inner dimensions must match"); if (a.ndim() == 2 && b.ndim() == 1) { c.resize(a.size(0)); call_gemv(a.size(0), a.size(1), 1, 0, a.data(), b.data(), c.data()); } else if (a.ndim() <= 2 && b.ndim() <= 2) { c.resize(a.size(0), b.size(1)); call_gemm(a.size(0), a.size(1), b.size(1), 1, 0, a.data(), b.data(), c.data()); } else { throw std::logic_error("Not finished yet"); } } /******************************************************** * Check if two arrays are equal * ********************************************************/ template inline typename std::enable_if::value, bool>::type FunctionTableCompare(const Array &a, const Array &b, TYPE) { bool pass = true; if (a.size() != b.size()) throw std::logic_error("Sizes of x and y do not match"); for (size_t i = 0; i < a.length(); i++) pass = pass && a(i) == b(i); return pass; } template inline typename std::enable_if::value, bool>::type FunctionTableCompare(const Array &a, const Array &b, TYPE tol) { bool pass = true; if (a.size() != b.size()) throw std::logic_error("Sizes of x and y do not match"); for (size_t i = 0; i < a.length(); i++) pass = pass && (std::abs(a(i) - b(i)) < tol); return pass; } template bool FunctionTable::equals(const Array &a, const Array &b, TYPE tol) { return FunctionTableCompare(a, b, tol); } /******************************************************** * Specialized Functions * ********************************************************/ template void FunctionTable::transformReLU(const Array &A, Array &B) { const auto &fun = [](const TYPE &a) { return std::max(a, static_cast(0)); }; transform(fun, A, B); } template void FunctionTable::transformAbs(const Array &A, Array &B) { B.resize(A.size()); const auto &fun = [](const TYPE &a) { return std::abs(a); }; transform(fun, A, B); } template void FunctionTable::transformTanh(const Array &A, Array &B) { B.resize(A.size()); const auto &fun = [](const TYPE &a) { return tanh(a); }; transform(fun, A, B); } template void FunctionTable::transformHardTanh(const Array &A, Array &B) { B.resize(A.size()); const auto &fun = [](const TYPE &a) { return std::max(-static_cast(1.0), std::min(static_cast(1.0), a)); }; transform(fun, A, B); } template void FunctionTable::transformSigmoid(const Array &A, Array &B) { B.resize(A.size()); const auto &fun = [](const TYPE &a) { return 1.0 / (1.0 + exp(-a)); }; transform(fun, A, B); } template void FunctionTable::transformSoftPlus(const Array &A, Array &B) { B.resize(A.size()); const auto &fun = [](const TYPE &a) { return log1p(exp(a)); }; transform(fun, A, B); } template TYPE FunctionTable::sum(const Array &A) { const auto &fun = [](const TYPE &a, const TYPE &b) { return a + b; }; return reduce(fun, A, (TYPE)0); } template inline void FunctionTable::gemmWrapper(char, char, int, int, int, TYPE, const TYPE *, int, const TYPE *, int, TYPE, TYPE *, int) { ERROR("Not finished"); } #endif