added: template EvalFunc over a Scalar type
rename to EvalFuncImpl and make EvalFunc a type alias for EvalFuncImpl<Real> to avoid breaking existing code. add deriv autodiff::var, thus adding automatic differentation
This commit is contained in:
parent
943044538d
commit
38f7265dc1
@ -19,8 +19,10 @@
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
#include <autodiff/reverse/var.hpp>
|
||||
|
||||
int EvalFunc::numError = 0;
|
||||
|
||||
template<> int EvalFunc::numError = 0;
|
||||
|
||||
|
||||
namespace {
|
||||
@ -198,7 +200,8 @@ int voigtIdx (int d1, int d2)
|
||||
}
|
||||
|
||||
|
||||
EvalFunc::EvalFunc (const char* function, const char* x, Real eps)
|
||||
template<class Scalar>
|
||||
EvalFuncImpl<Scalar>::EvalFuncImpl (const char* function, const char* x, Real eps)
|
||||
: dx(eps)
|
||||
{
|
||||
try {
|
||||
@ -231,17 +234,20 @@ EvalFunc::EvalFunc (const char* function, const char* x, Real eps)
|
||||
}
|
||||
|
||||
|
||||
EvalFunc::~EvalFunc () = default;
|
||||
template<class Scalar>
|
||||
EvalFuncImpl<Scalar>::~EvalFuncImpl () = default;
|
||||
|
||||
|
||||
void EvalFunc::addDerivative (const std::string& function, const char* x)
|
||||
template<class Scalar>
|
||||
void EvalFuncImpl<Scalar>::addDerivative (const std::string& function, const char* x)
|
||||
{
|
||||
if (!gradient)
|
||||
gradient = std::make_unique<EvalFunc>(function.c_str(),x);
|
||||
gradient = std::make_unique<FuncType>(function.c_str(),x);
|
||||
}
|
||||
|
||||
|
||||
Real EvalFunc::evaluate (const Real& x) const
|
||||
template<class Scalar>
|
||||
Real EvalFuncImpl<Scalar>::evaluate (const Real& x) const
|
||||
{
|
||||
Real result = Real(0);
|
||||
size_t i = 0;
|
||||
@ -252,7 +258,10 @@ Real EvalFunc::evaluate (const Real& x) const
|
||||
return result;
|
||||
try {
|
||||
*arg[i] = x;
|
||||
result = expr[i]->Evaluate();
|
||||
if constexpr (std::is_same_v<Scalar,Real>)
|
||||
result = expr[i]->Evaluate();
|
||||
else
|
||||
result = expr[i]->Evaluate().expr->val;
|
||||
}
|
||||
catch (ExprEval::Exception& e) {
|
||||
ExprException(e,"evaluating expression");
|
||||
@ -262,6 +271,7 @@ Real EvalFunc::evaluate (const Real& x) const
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
Real EvalFunc::deriv (Real x) const
|
||||
{
|
||||
if (gradient)
|
||||
@ -272,6 +282,29 @@ Real EvalFunc::deriv (Real x) const
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
Real EvalFuncImpl<autodiff::var>::deriv (Real x) const
|
||||
{
|
||||
if (gradient)
|
||||
return gradient->evaluate(x);
|
||||
|
||||
size_t i = 0;
|
||||
#ifdef USE_OPENMP
|
||||
i = omp_get_thread_num();
|
||||
#endif
|
||||
try {
|
||||
*arg[i] = x;
|
||||
return derivativesx(expr[i]->Evaluate(),
|
||||
autodiff::wrt(*this->arg[i]))[0].expr->val;
|
||||
}
|
||||
catch (ExprEval::Exception& e) {
|
||||
ExprException(e,"evaluating expression");
|
||||
}
|
||||
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
|
||||
EvalFunction::EvalFunction (const char* function, Real epsX, Real epsT)
|
||||
: dx(epsX), dt(epsT)
|
||||
{
|
||||
@ -533,6 +566,8 @@ EvalMultiFunction<ParentFunc, Ret>::evalTimeDerivative (const Vec3& X) const
|
||||
}
|
||||
|
||||
|
||||
template class EvalFuncImpl<Real>;
|
||||
template class EvalFuncImpl<autodiff::var>;
|
||||
template class EvalMultiFunction<VecFunc,Vec3>;
|
||||
template class EvalMultiFunction<TensorFunc,Tensor>;
|
||||
template class EvalMultiFunction<STensorFunc,SymmTensor>;
|
||||
|
@ -33,18 +33,20 @@ namespace ExprEval {
|
||||
\brief A scalar-valued function, general expression.
|
||||
*/
|
||||
|
||||
class EvalFunc : public ScalarFunc
|
||||
template<class Scalar>
|
||||
class EvalFuncImpl : public ScalarFunc
|
||||
{
|
||||
using Expression = ExprEval::Expression<Real>; //!< Type alias for expression tree
|
||||
using FunctionList = ExprEval::FunctionList<Real>; //!< Type alias for function list
|
||||
using ValueList = ExprEval::ValueList<Real>; //!< Type alias for value list
|
||||
using Expression = ExprEval::Expression<Scalar>; //!< Type alias for expression tree
|
||||
using FunctionList = ExprEval::FunctionList<Scalar>; //!< Type alias for function list
|
||||
using ValueList = ExprEval::ValueList<Scalar>; //!< Type alias for value list
|
||||
using FuncType = EvalFuncImpl<Scalar>; //!< Type alias for function
|
||||
std::vector<std::unique_ptr<Expression>> expr; //!< Roots of the expression tree
|
||||
std::vector<std::unique_ptr<FunctionList>> f; //!< Lists of functions
|
||||
std::vector<std::unique_ptr<ValueList>> v; //!< Lists of variables and constants
|
||||
|
||||
std::vector<Real*> arg; //!< Function argument values
|
||||
std::vector<Scalar*> arg; //!< Function argument values
|
||||
|
||||
std::unique_ptr<EvalFunc> gradient; //!< First derivative expression
|
||||
std::unique_ptr<FuncType> gradient; //!< First derivative expression
|
||||
|
||||
Real dx; //!< Domain increment for calculation of numerical derivative
|
||||
|
||||
@ -52,12 +54,12 @@ public:
|
||||
static int numError; //!< Error counter - set by the exception handler
|
||||
|
||||
//! \brief The constructor parses the expression string.
|
||||
explicit EvalFunc(const char* function, const char* x = "x",
|
||||
Real eps = Real(1.0e-8));
|
||||
explicit EvalFuncImpl(const char* function, const char* x = "x",
|
||||
Real eps = Real(1.0e-8));
|
||||
//! \brief Defaulted destructor.
|
||||
//! \details The implementation needs to be in compile unit so we have the
|
||||
//! definition for the types of the unique_ptr's.
|
||||
virtual ~EvalFunc();
|
||||
virtual ~EvalFuncImpl();
|
||||
|
||||
//! \brief Adds an expression function for a first derivative.
|
||||
void addDerivative(const std::string& function, const char* x = "x");
|
||||
@ -69,10 +71,6 @@ public:
|
||||
Real deriv(Real x) const override;
|
||||
|
||||
protected:
|
||||
//! \brief Non-implemented copy constructor to disallow copying.
|
||||
EvalFunc(const EvalFunc&) = delete;
|
||||
//! \brief Non-implemented assignment operator to disallow copying.
|
||||
EvalFunc& operator=(const EvalFunc&) = delete;
|
||||
//! \brief Evaluates the function expression.
|
||||
Real evaluate(const Real& x) const override;
|
||||
};
|
||||
@ -234,6 +232,8 @@ protected:
|
||||
std::vector<Real> evalTimeDerivative(const Vec3& X) const override;
|
||||
};
|
||||
|
||||
//! Scalar-valued function expression
|
||||
using EvalFunc = EvalFuncImpl<Real>;
|
||||
//! Vector-valued function expression
|
||||
using VecFuncExpr = EvalMultiFunction<VecFunc,Vec3>;
|
||||
//! Tensor-valued function expression
|
||||
|
@ -17,6 +17,8 @@
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <autodiff/reverse/var.hpp>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
|
||||
@ -29,6 +31,8 @@ TEST(TestScalarFunc, ParseDerivative)
|
||||
ScalarFunc* f1 = utl::parseTimeFunc(func1,"expression");
|
||||
ScalarFunc* f2 = utl::parseTimeFunc(func2,"expression");
|
||||
|
||||
EvalFuncImpl<autodiff::var> f3(func1,"t");
|
||||
|
||||
ASSERT_TRUE(f1 != nullptr);
|
||||
ASSERT_TRUE(f2 != nullptr);
|
||||
|
||||
@ -45,6 +49,7 @@ TEST(TestScalarFunc, ParseDerivative)
|
||||
EXPECT_FLOAT_EQ((*f2)(t),sin(1.5*t)*t);
|
||||
EXPECT_FLOAT_EQ(f1->deriv(t),1.5*cos(1.5*t)*t+sin(1.5*t));
|
||||
EXPECT_FLOAT_EQ(f2->deriv(t),1.5*cos(1.5*t)*t+sin(1.5*t));
|
||||
EXPECT_FLOAT_EQ(f3.deriv(t),1.5*cos(1.5*t)*t+sin(1.5*t));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user