dense AD: introduce convenience functions
this PR was inspired by one of @atgeirr's recent comments. it allows
to get rid if the `MathToolbox<Eval>` detour for the dense AD code in
most cases: e.g. instead of writing
```c++
template <class Eval>
Eval f(const Eval& val)
{ return Opm::MathToolbox<Eval>::sqrt(val); }
```
one can simply write
```c++
template <class Eval>
Eval f(const Eval& val)
{ return Opm::sqrt(val); }
```
and it will work transparently with DenseAD `Evaluation` objects and
primitive floating point values.
one complication of this is that the type of the `Evaluation` object
does not need to be explicitly defined. for functions which take more
than one argument (like e.g. `pow()`, `min()` and `max()`), it is thus
assumed that the type of the result is the same as the type of the
first argument.
another drawback is that when both, the contents of the `Opm::` and
the `std::` namespaces are implicitly available, it is not clear to me
what's used for primitive floating point values. (it seems to compile
but IMO it could lead to surprising behaviour. thus, please only merge
if you consider the benefits of these convenience functions to be
greater than their drawbacks.)
This commit is contained in:
@@ -203,6 +203,72 @@ public:
|
||||
{ return std::pow(base, exp); }
|
||||
};
|
||||
|
||||
// these are convenience functions for not having to type MathToolbox<Scalar>::foo()
|
||||
template <class Evaluation, class Scalar>
|
||||
Evaluation constant(const Scalar& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::createConstant(value); }
|
||||
|
||||
template <class Evaluation, class Scalar>
|
||||
Evaluation variable(const Scalar& value, int idx)
|
||||
{ return Opm::MathToolbox<Evaluation>::createVariable(value, idx); }
|
||||
|
||||
template <class Evaluation1, class Evaluation2>
|
||||
Evaluation1 max(const Evaluation1& arg1, const Evaluation2& arg2)
|
||||
{ return Opm::MathToolbox<Evaluation1>::max(arg1, arg2); }
|
||||
|
||||
template <class Evaluation1, class Evaluation2>
|
||||
Evaluation1 min(const Evaluation1& arg1, const Evaluation2& arg2)
|
||||
{ return Opm::MathToolbox<Evaluation1>::min(arg1, arg2); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation abs(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::abs(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation tan(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::tan(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation atan(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::atan(value); }
|
||||
|
||||
template <class Evaluation1, class Evaluation2>
|
||||
Evaluation1 atan2(const Evaluation1& value1, const Evaluation2& value2)
|
||||
{ return Opm::MathToolbox<Evaluation1>::atan2(value1, value2); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation sin(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::sin(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation asin(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::asin(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation cos(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::cos(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation acos(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::acos(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation sqrt(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::sqrt(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation exp(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::exp(value); }
|
||||
|
||||
template <class Evaluation>
|
||||
Evaluation log(const Evaluation& value)
|
||||
{ return Opm::MathToolbox<Evaluation>::log(value); }
|
||||
|
||||
template <class Evaluation1, class Evaluation2>
|
||||
Evaluation1 pow(const Evaluation1& base, const Evaluation2& exp)
|
||||
{ return Opm::MathToolbox<Evaluation1>::pow(base, exp); }
|
||||
|
||||
} // namespace Opm
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -29,6 +29,8 @@
|
||||
#ifndef OPM_LOCAL_AD_EVALUATION_HPP
|
||||
#define OPM_LOCAL_AD_EVALUATION_HPP
|
||||
|
||||
#include "Math.hpp"
|
||||
|
||||
#include <opm/material/common/Valgrind.hpp>
|
||||
|
||||
#include <dune/common/version.hh>
|
||||
|
||||
@@ -38,6 +38,10 @@
|
||||
|
||||
namespace Opm {
|
||||
namespace DenseAd {
|
||||
// forward declaration of the Evaluation template class
|
||||
template <class ValueT, int numVars>
|
||||
class Evaluation;
|
||||
|
||||
// provide some algebraic functions
|
||||
template <class ValueType, int numVars>
|
||||
Evaluation<ValueType, numVars> abs(const Evaluation<ValueType, numVars>& x)
|
||||
|
||||
@@ -566,6 +566,49 @@ inline void testAll()
|
||||
test1DFunction<Scalar>(Opm::DenseAd::log<Scalar, numVars>,
|
||||
static_cast<Scalar (*)(Scalar)>(std::log),
|
||||
1e-6, 1e9);
|
||||
|
||||
while (false) {
|
||||
// make sure that the convenince functions work (i.e., that everything can be
|
||||
// accessed without the MathToolbox<Scalar> detour.)
|
||||
Scalar val1(0.0), val2(1.0), resultVal;
|
||||
resultVal = Opm::constant<Scalar>(val1);
|
||||
resultVal = Opm::variable<Scalar>(val1, /*idx=*/0);
|
||||
resultVal = Opm::min(val1, val2);
|
||||
resultVal = Opm::max(val1, val2);
|
||||
resultVal = Opm::atan2(val1, val2);
|
||||
resultVal = Opm::pow(val1, val2);
|
||||
resultVal = Opm::abs(val1);
|
||||
resultVal = Opm::atan(val1);
|
||||
resultVal = Opm::sin(val1);
|
||||
resultVal = Opm::asin(val1);
|
||||
resultVal = Opm::cos(val1);
|
||||
resultVal = Opm::acos(val1);
|
||||
resultVal = Opm::sqrt(val1);
|
||||
resultVal = Opm::exp(val1);
|
||||
resultVal = Opm::log(val1);
|
||||
|
||||
typedef Opm::DenseAd::Evaluation<Scalar, numVars> TmpEval;
|
||||
TmpEval eval1, eval2, resultEval;
|
||||
resultEval = Opm::constant<TmpEval>(val1);
|
||||
resultEval = Opm::variable<TmpEval>(val1, /*idx=*/0);
|
||||
resultEval = Opm::min(eval1, eval2);
|
||||
resultEval = Opm::min(eval1, val2);
|
||||
resultEval = Opm::max(eval1, eval2);
|
||||
resultEval = Opm::max(eval1, val2);
|
||||
resultEval = Opm::atan2(eval1, eval2);
|
||||
resultEval = Opm::atan2(eval1, val2);
|
||||
resultEval = Opm::pow(eval1, eval2);
|
||||
resultEval = Opm::pow(eval1, val2);
|
||||
resultEval = Opm::abs(eval1);
|
||||
resultEval = Opm::atan(eval1);
|
||||
resultEval = Opm::sin(eval1);
|
||||
resultEval = Opm::asin(eval1);
|
||||
resultEval = Opm::cos(eval1);
|
||||
resultEval = Opm::acos(eval1);
|
||||
resultEval = Opm::sqrt(eval1);
|
||||
resultEval = Opm::exp(eval1);
|
||||
resultEval = Opm::log(eval1);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
|
||||
Reference in New Issue
Block a user