diff --git a/3rdparty/expreval/except.h b/3rdparty/expreval/except.h index d3d3a025..b9818cec 100644 --- a/3rdparty/expreval/except.h +++ b/3rdparty/expreval/except.h @@ -15,7 +15,7 @@ namespace ExprEval { // Forward declarations - class Expression; + template class Expression; // Exception class //-------------------------------------------------------------------------- diff --git a/3rdparty/expreval/expr.cpp b/3rdparty/expreval/expr.cpp index b96eaca4..aa0bcb06 100644 --- a/3rdparty/expreval/expr.cpp +++ b/3rdparty/expreval/expr.cpp @@ -16,57 +16,66 @@ using namespace std; using namespace ExprEval; +#include // Expression object //------------------------------------------------------------------------------ // Constructor -Expression::Expression() : m_vlist(0), m_flist(0), m_expr(0) +template +Expression::Expression() : m_vlist(0), m_flist(0), m_expr(0) { m_abortcount = 200000; m_abortreset = 200000; } // Destructor -Expression::~Expression() +template +Expression::~Expression() { // Delete expression nodes delete m_expr; } // Set value list -void Expression::SetValueList(ValueList *vlist) +template +void Expression::SetValueList(ValueList* vlist) { m_vlist = vlist; } // Get value list -ValueList *Expression::GetValueList() const +template +ValueList* Expression::GetValueList() const { return m_vlist; } // Set function list -void Expression::SetFunctionList(FunctionList *flist) +template +void Expression::SetFunctionList(FunctionList *flist) { m_flist = flist; } // Get function list -FunctionList *Expression::GetFunctionList() const +template +FunctionList* Expression::GetFunctionList() const { return m_flist; } // Test for an abort -bool Expression::DoTestAbort() +template +bool Expression::DoTestAbort() { // Derive a class to test abort return false; } // Test for an abort -void Expression::TestAbort(bool force) +template +void Expression::TestAbort(bool force) { if(force) { @@ -99,7 +108,8 @@ void Expression::TestAbort(bool force) } // Set test abort count -void Expression::SetTestAbortCount(unsigned long count) +template +void Expression::SetTestAbortCount(unsigned long count) { m_abortreset = count; if(m_abortcount > count) @@ -107,28 +117,31 @@ void Expression::SetTestAbortCount(unsigned long count) } // Parse expression -void Expression::Parse(const string &exstr) +template +void Expression::Parse(const string &exstr) { // Clear the expression if needed if(m_expr) Clear(); // Create parser - aptr(Parser) p(new Parser(this)); + aptr(Parser) p(new Parser(this)); // Parse the expression m_expr = p->Parse(exstr); } // Clear the expression -void Expression::Clear() +template +void Expression::Clear() { delete m_expr; m_expr = 0; } // Evaluate an expression -double Expression::Evaluate() +template +Value Expression::Evaluate() { if(m_expr) { @@ -139,3 +152,8 @@ double Expression::Evaluate() throw(EmptyExpressionException()); } } + +namespace ExprEval { +template class Expression; +template class Expression; +} diff --git a/3rdparty/expreval/expr.h b/3rdparty/expreval/expr.h index b272d540..743355ff 100644 --- a/3rdparty/expreval/expr.h +++ b/3rdparty/expreval/expr.h @@ -14,12 +14,13 @@ namespace ExprEval { // Forward declarations - class ValueList; - class FunctionList; - class Node; + template class ValueList; + template class FunctionList; + template class Node; // Expression class //-------------------------------------------------------------------------- + template class Expression { public: @@ -27,12 +28,12 @@ namespace ExprEval virtual ~Expression(); // Variable list - void SetValueList(ValueList *vlist); - ValueList *GetValueList() const; + void SetValueList(ValueList *vlist); + ValueList *GetValueList() const; // Function list - void SetFunctionList(FunctionList *flist); - FunctionList *GetFunctionList() const; + void SetFunctionList(FunctionList *flist); + FunctionList *GetFunctionList() const; // Abort control virtual bool DoTestAbort(); @@ -46,12 +47,12 @@ namespace ExprEval void Clear(); // Evaluate expression - double Evaluate(); + Value Evaluate(); protected: - ValueList *m_vlist; - FunctionList *m_flist; - Node *m_expr; + ValueList *m_vlist; + FunctionList *m_flist; + Node *m_expr; unsigned long m_abortcount; unsigned long m_abortreset; }; diff --git a/3rdparty/expreval/func.cpp b/3rdparty/expreval/func.cpp index 69e0f11a..aba34ca1 100644 --- a/3rdparty/expreval/func.cpp +++ b/3rdparty/expreval/func.cpp @@ -11,6 +11,7 @@ #include #include +#include "autodiff/reverse/var/var.hpp" #include "defs.h" #include "funclist.h" #include "node.h" @@ -23,58 +24,66 @@ namespace { // Absolute value //-------------------------------------------------------------------------- - class abs_FunctionNode : public FunctionNode + template + class abs_FunctionNode : public FunctionNode { public: - explicit abs_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit abs_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - return fabs(m_nodes[0]->Evaluate()); + if constexpr (std::is_same_v) + return fabs(this->m_nodes[0]->Evaluate()); + else + return abs(this->m_nodes[0]->Evaluate()); } }; - - class abs_FunctionFactory : public FunctionFactory + + template + class abs_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "abs"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new abs_FunctionNode(expr); + return new abs_FunctionNode(expr); } }; - + // Modulus //-------------------------------------------------------------------------- - class mod_FunctionNode : public FunctionNode + template + class mod_FunctionNode : public FunctionNode { public: - explicit mod_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit mod_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = fmod(m_nodes[0]->Evaluate(), m_nodes[1]->Evaluate()); - + + Value result = fmod(this->m_nodes[0]->Evaluate(), + this->m_nodes[1]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class mod_FunctionFactory : public FunctionFactory + + template + class mod_FunctionFactory : public FunctionFactory { public: std::string GetName() const override @@ -82,744 +91,783 @@ namespace return "mod"; } - FunctionNode *DoCreate(Expression *expr) override + FunctionNode *DoCreate(Expression *expr) override { - return new mod_FunctionNode(expr); + return new mod_FunctionNode(expr); } }; - + // Integer part //-------------------------------------------------------------------------- - class ipart_FunctionNode : public FunctionNode + template + class ipart_FunctionNode : public FunctionNode { public: - explicit ipart_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit ipart_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double result; - - modf(m_nodes[0]->Evaluate(), &result); - + Value result; + + modf(this->m_nodes[0]->Evaluate(), &result); + return result; } }; - - class ipart_FunctionFactory : public FunctionFactory + + template + class ipart_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "ipart"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new ipart_FunctionNode(expr); + return new ipart_FunctionNode(expr); } }; - + // Fraction part //-------------------------------------------------------------------------- - class fpart_FunctionNode : public FunctionNode + template + class fpart_FunctionNode : public FunctionNode { public: - explicit fpart_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit fpart_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double dummy; - - return modf(m_nodes[0]->Evaluate(), &dummy); + Value dummy; + + return modf(this->m_nodes[0]->Evaluate(), &dummy); } }; - - class fpart_FunctionFactory : public FunctionFactory + + template + class fpart_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "fpart"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new fpart_FunctionNode(expr); + return new fpart_FunctionNode(expr); } - }; - + }; + // Minimum //-------------------------------------------------------------------------- - class min_FunctionNode : public FunctionNode + template + class min_FunctionNode : public FunctionNode { public: - explicit min_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit min_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, -1, 0, 0); + this->SetArgumentCount(2, -1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - std::vector::size_type pos; - - double result = m_nodes[0]->Evaluate(); - - for(pos = 1; pos < m_nodes.size(); pos++) - { - double tmp = m_nodes[pos]->Evaluate(); - if(tmp < result) - result = tmp; + Value result = 0; + size_t pos = 0; + for (const auto& node : this->m_nodes) { + Value tmp = node->Evaluate(); + if (pos == 0 || tmp < result) + result = tmp; + ++pos; } - + return result; } }; - - class min_FunctionFactory : public FunctionFactory + + template + class min_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "min"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new min_FunctionNode(expr); + return new min_FunctionNode(expr); } - }; - + }; + // Maximum //-------------------------------------------------------------------------- - class max_FunctionNode : public FunctionNode + template + class max_FunctionNode : public FunctionNode { public: - explicit max_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit max_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, -1, 0, 0); + this->SetArgumentCount(2, -1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - std::vector::size_type pos; - - double result = m_nodes[0]->Evaluate(); - - for(pos = 1; pos < m_nodes.size(); pos++) + Value result = 0; + size_t pos = 0; + for (const auto& node : this->m_nodes) { - double tmp = m_nodes[pos]->Evaluate(); - if(tmp > result) - result = tmp; + Value tmp = node->Evaluate(); + if (pos == 0 || tmp > result) + result = tmp; + ++pos; } - + return result; } }; - - class max_FunctionFactory : public FunctionFactory + + template + class max_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "max"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new max_FunctionNode(expr); + return new max_FunctionNode(expr); } - }; - + }; + // Square root //-------------------------------------------------------------------------- - class sqrt_FunctionNode : public FunctionNode + template + class sqrt_FunctionNode : public FunctionNode { public: - explicit sqrt_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit sqrt_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = sqrt(m_nodes[0]->Evaluate()); - + + Value result = sqrt(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class sqrt_FunctionFactory : public FunctionFactory + + template + class sqrt_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "sqrt"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new sqrt_FunctionNode(expr); + return new sqrt_FunctionNode(expr); } }; - + // Sine //-------------------------------------------------------------------------- - class sin_FunctionNode : public FunctionNode + template + class sin_FunctionNode : public FunctionNode { public: - explicit sin_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit sin_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = sin(m_nodes[0]->Evaluate()); - + + Value result = sin(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class sin_FunctionFactory : public FunctionFactory + + template + class sin_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "sin"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new sin_FunctionNode(expr); + return new sin_FunctionNode(expr); } }; - + // Cosine //-------------------------------------------------------------------------- - class cos_FunctionNode : public FunctionNode + template + class cos_FunctionNode : public FunctionNode { public: - explicit cos_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit cos_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = cos(m_nodes[0]->Evaluate()); - + + Value result = cos(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class cos_FunctionFactory : public FunctionFactory + + template + class cos_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "cos"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new cos_FunctionNode(expr); + return new cos_FunctionNode(expr); } }; - + // Tangent //-------------------------------------------------------------------------- - class tan_FunctionNode : public FunctionNode + template + class tan_FunctionNode : public FunctionNode { public: - explicit tan_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit tan_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = tan(m_nodes[0]->Evaluate()); - + + Value result = tan(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class tan_FunctionFactory : public FunctionFactory + + template + class tan_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "tan"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new tan_FunctionNode(expr); + return new tan_FunctionNode(expr); } }; - + // Hyperbolic Sine //-------------------------------------------------------------------------- - class sinh_FunctionNode : public FunctionNode + template + class sinh_FunctionNode : public FunctionNode { public: - explicit sinh_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit sinh_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = sinh(m_nodes[0]->Evaluate()); - + + auto result = sinh(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class sinh_FunctionFactory : public FunctionFactory + + template + class sinh_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "sinh"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new sinh_FunctionNode(expr); + return new sinh_FunctionNode(expr); } }; - + // Hyperbolic Cosine //-------------------------------------------------------------------------- - class cosh_FunctionNode : public FunctionNode + template + class cosh_FunctionNode : public FunctionNode { public: - explicit cosh_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit cosh_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = cosh(m_nodes[0]->Evaluate()); - + + Value result = cosh(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class cosh_FunctionFactory : public FunctionFactory + + template + class cosh_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "cosh"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new cosh_FunctionNode(expr); + return new cosh_FunctionNode(expr); } }; - + // Hyperbolic Tangent //-------------------------------------------------------------------------- - class tanh_FunctionNode : public FunctionNode + template + class tanh_FunctionNode : public FunctionNode { public: - explicit tanh_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit tanh_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = tanh(m_nodes[0]->Evaluate()); - + + Value result = tanh(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class tanh_FunctionFactory : public FunctionFactory + + template + class tanh_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "tanh"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new tanh_FunctionNode(expr); + return new tanh_FunctionNode(expr); } }; - + // Arc Sine //-------------------------------------------------------------------------- - class asin_FunctionNode : public FunctionNode + template + class asin_FunctionNode : public FunctionNode { public: - explicit asin_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit asin_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = asin(m_nodes[0]->Evaluate()); - + + Value result = asin(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class asin_FunctionFactory : public FunctionFactory + + template + class asin_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "asin"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new asin_FunctionNode(expr); + return new asin_FunctionNode(expr); } }; - + // Arc Cosine //-------------------------------------------------------------------------- - class acos_FunctionNode : public FunctionNode + template + class acos_FunctionNode : public FunctionNode { public: - explicit acos_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit acos_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = acos(m_nodes[0]->Evaluate()); - + + Value result = acos(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class acos_FunctionFactory : public FunctionFactory + + template + class acos_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "acos"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new acos_FunctionNode(expr); + return new acos_FunctionNode(expr); } }; - + // Arc Tangent //-------------------------------------------------------------------------- - class atan_FunctionNode : public FunctionNode + template + class atan_FunctionNode : public FunctionNode { public: - explicit atan_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit atan_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = atan(m_nodes[0]->Evaluate()); - + + Value result = atan(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class atan_FunctionFactory : public FunctionFactory + + template + class atan_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "atan"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new atan_FunctionNode(expr); + return new atan_FunctionNode(expr); } }; - + // Arc Tangent 2: atan2(y, x) //-------------------------------------------------------------------------- - class atan2_FunctionNode : public FunctionNode + template + class atan2_FunctionNode : public FunctionNode { public: - explicit atan2_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit atan2_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = atan2(m_nodes[0]->Evaluate(), m_nodes[1]->Evaluate()); - + + Value result = atan2(this->m_nodes[0]->Evaluate(), this->m_nodes[1]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class atan2_FunctionFactory : public FunctionFactory + + template + class atan2_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "atan2"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new atan2_FunctionNode(expr); + return new atan2_FunctionNode(expr); } }; - + // Log //-------------------------------------------------------------------------- - class log_FunctionNode : public FunctionNode + template + class log_FunctionNode : public FunctionNode { public: - explicit log_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit log_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = log10(m_nodes[0]->Evaluate()); - + + Value result = log10(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class log_FunctionFactory : public FunctionFactory + + template + class log_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "log"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new log_FunctionNode(expr); + return new log_FunctionNode(expr); } }; - + // Ln //-------------------------------------------------------------------------- - class ln_FunctionNode : public FunctionNode + template + class ln_FunctionNode : public FunctionNode { public: - explicit ln_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit ln_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double result = log(m_nodes[0]->Evaluate()); - + + Value result = log(this->m_nodes[0]->Evaluate()); + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class ln_FunctionFactory : public FunctionFactory + + template + class ln_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "ln"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new ln_FunctionNode(expr); + return new ln_FunctionNode(expr); } }; - + // Exp //-------------------------------------------------------------------------- - class exp_FunctionNode : public FunctionNode + template + class exp_FunctionNode : public FunctionNode { public: - explicit exp_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit exp_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - double x = m_nodes[0]->Evaluate(); - double result = exp(x); + Value x = this->m_nodes[0]->Evaluate(); + Value result = exp(x); - if (errno && x > 0.0) // Fixed: No exception on underflow - throw(MathException(GetName())); + if (errno && x > Value(0.0)) // Fixed: No exception on underflow + throw(MathException(this->GetName())); return result; } }; - - class exp_FunctionFactory : public FunctionFactory + + template + class exp_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "exp"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new exp_FunctionNode(expr); + return new exp_FunctionNode(expr); } }; - + // Logn //-------------------------------------------------------------------------- - class logn_FunctionNode : public FunctionNode + template + class logn_FunctionNode : public FunctionNode { public: - explicit logn_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit logn_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - + // Check for division by zero - double tmp = log(m_nodes[1]->Evaluate()); - + Value tmp = log(this->m_nodes[1]->Evaluate()); + if(tmp == 0.0) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + // Calculate result - double result = log(m_nodes[0]->Evaluate()) / tmp; - + Value result = log(this->m_nodes[0]->Evaluate()) / tmp; + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return result; } }; - - class logn_FunctionFactory : public FunctionFactory + + template + class logn_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "logn"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new logn_FunctionNode(expr); + return new logn_FunctionNode(expr); } }; - class pow_FunctionNode : public FunctionNode + template + class pow_FunctionNode : public FunctionNode { public: - explicit pow_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit pow_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - double DoEvaluate() override + Value DoEvaluate() override { errno = 0; - double result = pow(m_nodes[0]->Evaluate(), m_nodes[1]->Evaluate()); + Value result = pow(this->m_nodes[0]->Evaluate(), this->m_nodes[1]->Evaluate()); - if(errno) - throw(MathException(GetName())); + if constexpr (std::is_same_v) { // disable exception with autodiff + if(errno) + throw(MathException(this->GetName())); + } return result; } }; - class pow_FunctionFactory : public FunctionFactory + template + class pow_FunctionFactory : public FunctionFactory { public: std::string GetName() const override @@ -827,523 +875,552 @@ namespace return "pow"; } - FunctionNode *DoCreate(Expression *expr) override + FunctionNode *DoCreate(Expression *expr) override { - return new pow_FunctionNode(expr); + return new pow_FunctionNode(expr); } }; // Ceil //-------------------------------------------------------------------------- - class ceil_FunctionNode : public FunctionNode + template + class ceil_FunctionNode : public FunctionNode { public: - explicit ceil_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit ceil_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - return ceil(m_nodes[0]->Evaluate()); + return ceil(this->m_nodes[0]->Evaluate()); } }; - - class ceil_FunctionFactory : public FunctionFactory + + template + class ceil_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "ceil"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new ceil_FunctionNode(expr); + return new ceil_FunctionNode(expr); } }; - + // Floor //-------------------------------------------------------------------------- - class floor_FunctionNode : public FunctionNode + template + class floor_FunctionNode : public FunctionNode { public: - explicit floor_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit floor_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - return floor(m_nodes[0]->Evaluate()); + return floor(this->m_nodes[0]->Evaluate()); } }; - - class floor_FunctionFactory : public FunctionFactory + + template + class floor_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "floor"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new floor_FunctionNode(expr); + return new floor_FunctionNode(expr); } }; - + // Rand //-------------------------------------------------------------------------- - + // Get next random (0,1) inline double NextRandom(double *seed) { long a = (long)(*seed) * 214013L + 2531011L; *seed = (double)a; - + return (double)((a >> 16) & 0x7FFF) / 32767.0; } - - class rand_FunctionNode : public FunctionNode + + template + class rand_FunctionNode : public FunctionNode { public: - explicit rand_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit rand_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(0, 0, 1, 1); + this->SetArgumentCount(0, 0, 1, 1); } - - double DoEvaluate() override + + Value DoEvaluate() override { - return NextRandom(m_refs[0]); + return NextRandom(this->m_refs[0]); } }; - - class rand_FunctionFactory : public FunctionFactory + + template + class rand_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "rand"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new rand_FunctionNode(expr); + return new rand_FunctionNode(expr); } }; - + // Random //-------------------------------------------------------------------------- - class random_FunctionNode : public FunctionNode + template + class random_FunctionNode : public FunctionNode { public: - explicit random_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit random_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 1, 1); + this->SetArgumentCount(2, 2, 1, 1); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double a = m_nodes[0]->Evaluate(); - double b = m_nodes[1]->Evaluate(); - - return NextRandom(m_refs[0]) * (b - a) + a; + Value a = this->m_nodes[0]->Evaluate(); + Value b = this->m_nodes[1]->Evaluate(); + + return NextRandom(this->m_refs[0]) * (b - a) + a; } }; - - class random_FunctionFactory : public FunctionFactory + + template + class random_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "random"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new random_FunctionNode(expr); + return new random_FunctionNode(expr); } - }; - + }; + // Randomize //-------------------------------------------------------------------------- - class randomize_FunctionNode : public FunctionNode + template + class randomize_FunctionNode : public FunctionNode { public: - explicit randomize_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit randomize_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(0, 0, 1, 1); + this->SetArgumentCount(0, 0, 1, 1); } - - double DoEvaluate() override + + Value DoEvaluate() override { static long curcall = 1; - - *m_refs[0] = (double)(((clock() + 1024u + curcall) * time(NULL)) % 2176971487u); + + *this->m_refs[0] = (Value)(((clock() + 1024u + curcall) * time(NULL)) % 2176971487u); curcall++; - + return 0.0; } }; - - class randomize_FunctionFactory : public FunctionFactory + + template + class randomize_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "randomize"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new randomize_FunctionNode(expr); + return new randomize_FunctionNode(expr); } }; - + // Radians to degrees //-------------------------------------------------------------------------- - class deg_FunctionNode : public FunctionNode + template + class deg_FunctionNode : public FunctionNode { public: - explicit deg_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit deg_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - return (m_nodes[0]->Evaluate() * 180.0) / M_PI; + return (this->m_nodes[0]->Evaluate() * 180.0) / M_PI; } }; - - class deg_FunctionFactory : public FunctionFactory + + template + class deg_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "deg"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new deg_FunctionNode(expr); + return new deg_FunctionNode(expr); } }; - + // Degrees to radians //-------------------------------------------------------------------------- - class rad_FunctionNode : public FunctionNode + template + class rad_FunctionNode : public FunctionNode { public: - explicit rad_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit rad_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - return (m_nodes[0]->Evaluate() * M_PI) / 180.0; + return (this->m_nodes[0]->Evaluate() * M_PI) / 180.0; } }; - - class rad_FunctionFactory : public FunctionFactory + + template + class rad_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "rad"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new rad_FunctionNode(expr); + return new rad_FunctionNode(expr); } }; - + // Rectangular to polar: rect2pol(x, y, &distance, &angle) //-------------------------------------------------------------------------- - class rect2pol_FunctionNode : public FunctionNode + template + class rect2pol_FunctionNode : public FunctionNode { public: - explicit rect2pol_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit rect2pol_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 2, 2); + this->SetArgumentCount(2, 2, 2, 2); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double x = m_nodes[0]->Evaluate(); - double y = m_nodes[1]->Evaluate(); - - double d = hypot(x, y); - double a = atan2(y, x); - + + Value x = this->m_nodes[0]->Evaluate(); + Value y = this->m_nodes[1]->Evaluate(); + + Value d = hypot(x, y); + Value a = atan2(y, x); + if(errno) - throw(MathException(GetName())); - - *m_refs[0] = d; + throw(MathException(this->GetName())); + + *this->m_refs[0] = d; if(a < 0.0) - *m_refs[1] = a + (2.0 * M_PI); + *this->m_refs[1] = a + (2.0 * M_PI); else - *m_refs[1] = a; - + *this->m_refs[1] = a; + return d; } }; - - class rect2pol_FunctionFactory : public FunctionFactory + + template + class rect2pol_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "rect2pol"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new rect2pol_FunctionNode(expr); + return new rect2pol_FunctionNode(expr); } }; - + // Polar to rectangular: pol2rect(distance, angle, &x, &y) //-------------------------------------------------------------------------- - class pol2rect_FunctionNode : public FunctionNode + template + class pol2rect_FunctionNode : public FunctionNode { public: - explicit pol2rect_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit pol2rect_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 2, 2); + this->SetArgumentCount(2, 2, 2, 2); } - - double DoEvaluate() override + + Value DoEvaluate() override { errno = 0; - - double d = m_nodes[0]->Evaluate(); - double a = m_nodes[1]->Evaluate(); - - double x = d * cos(a); - double y = d * sin(a); - + + Value d = this->m_nodes[0]->Evaluate(); + Value a = this->m_nodes[1]->Evaluate(); + + Value x = d * cos(a); + Value y = d * sin(a); + if(errno) - throw(MathException(GetName())); - - *m_refs[0] = x; - *m_refs[1] = y; - + throw(MathException(this->GetName())); + + *this->m_refs[0] = x; + *this->m_refs[1] = y; + return x; } }; - - class pol2rect_FunctionFactory : public FunctionFactory + + template + class pol2rect_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "pol2rect"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new pol2rect_FunctionNode(expr); + return new pol2rect_FunctionNode(expr); } }; - + // If //-------------------------------------------------------------------------- - class if_FunctionNode : public FunctionNode + template + class if_FunctionNode : public FunctionNode { public: - explicit if_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit if_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(3, 3, 0, 0); + this->SetArgumentCount(3, 3, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double c = m_nodes[0]->Evaluate(); - + Value c = this->m_nodes[0]->Evaluate(); + if(c == 0.0) - return m_nodes[2]->Evaluate(); + return this->m_nodes[2]->Evaluate(); else - return m_nodes[1]->Evaluate(); + return this->m_nodes[1]->Evaluate(); } }; - - class if_FunctionFactory : public FunctionFactory + + template + class if_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "if"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new if_FunctionNode(expr); + return new if_FunctionNode(expr); } }; - + // Select //-------------------------------------------------------------------------- - class select_FunctionNode : public FunctionNode + template + class select_FunctionNode : public FunctionNode { public: - explicit select_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit select_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(3, 4, 0, 0); + this->SetArgumentCount(3, 4, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double c = m_nodes[0]->Evaluate(); - + Value c = this->m_nodes[0]->Evaluate(); + if(c < 0.0) - return m_nodes[1]->Evaluate(); + return this->m_nodes[1]->Evaluate(); else if(c == 0.0) - return m_nodes[2]->Evaluate(); + return this->m_nodes[2]->Evaluate(); else { - if(m_nodes.size() == 3) - return m_nodes[2]->Evaluate(); + if(this->m_nodes.size() == 3) + return this->m_nodes[2]->Evaluate(); else - return m_nodes[3]->Evaluate(); + return this->m_nodes[3]->Evaluate(); } } }; - - class select_FunctionFactory : public FunctionFactory + + template + class select_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "select"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new select_FunctionNode(expr); + return new select_FunctionNode(expr); } }; - + // Equal //-------------------------------------------------------------------------- - class equal_FunctionNode : public FunctionNode + template + class equal_FunctionNode : public FunctionNode { public: - explicit equal_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit equal_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - if(m_nodes[0]->Evaluate() == m_nodes[1]->Evaluate()) + if(this->m_nodes[0]->Evaluate() == this->m_nodes[1]->Evaluate()) return 1.0; else return 0.0; } }; - - class equal_FunctionFactory : public FunctionFactory + + template + class equal_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "equal"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new equal_FunctionNode(expr); + return new equal_FunctionNode(expr); } }; - + // Above //-------------------------------------------------------------------------- - class above_FunctionNode : public FunctionNode + template + class above_FunctionNode : public FunctionNode { public: - explicit above_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit above_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - if(m_nodes[0]->Evaluate() > m_nodes[1]->Evaluate()) + if(this->m_nodes[0]->Evaluate() > this->m_nodes[1]->Evaluate()) return 1.0; else return 0.0; } }; - - class above_FunctionFactory : public FunctionFactory + + template + class above_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "above"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new above_FunctionNode(expr); + return new above_FunctionNode(expr); } }; - + // Below //-------------------------------------------------------------------------- - class below_FunctionNode : public FunctionNode + template + class below_FunctionNode : public FunctionNode { public: - explicit below_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit below_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - if(m_nodes[0]->Evaluate() < m_nodes[1]->Evaluate()) + if(this->m_nodes[0]->Evaluate() < this->m_nodes[1]->Evaluate()) return 1.0; else return 0.0; } }; - - class below_FunctionFactory : public FunctionFactory + + template + class below_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "below"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new below_FunctionNode(expr); + return new below_FunctionNode(expr); } }; - + // Clip //-------------------------------------------------------------------------- - class clip_FunctionNode : public FunctionNode + template + class clip_FunctionNode : public FunctionNode { public: - explicit clip_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit clip_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(3, 3, 0, 0); + this->SetArgumentCount(3, 3, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double v = m_nodes[0]->Evaluate(); - double a = m_nodes[1]->Evaluate(); - double b = m_nodes[2]->Evaluate(); - + Value v = this->m_nodes[0]->Evaluate(); + Value a = this->m_nodes[1]->Evaluate(); + Value b = this->m_nodes[2]->Evaluate(); + if(v < a) return a; else if(v > b) @@ -1352,43 +1429,45 @@ namespace return v; } }; - - class clip_FunctionFactory : public FunctionFactory + + template + class clip_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "clip"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new clip_FunctionNode(expr); + return new clip_FunctionNode(expr); } }; - + // Clamp //-------------------------------------------------------------------------- - class clamp_FunctionNode : public FunctionNode + template + class clamp_FunctionNode : public FunctionNode { public: - explicit clamp_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit clamp_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(3, 3, 0, 0); + this->SetArgumentCount(3, 3, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double v = m_nodes[0]->Evaluate(); - double a = m_nodes[1]->Evaluate(); - double b = m_nodes[2]->Evaluate(); - + Value v = this->m_nodes[0]->Evaluate(); + Value a = this->m_nodes[1]->Evaluate(); + Value b = this->m_nodes[2]->Evaluate(); + if(a == b) return a; else { - double tmp = fmod(v - a, b - a); - + Value tmp = fmod(v - a, b - a); + if(tmp < 0) return tmp + b; else @@ -1396,40 +1475,42 @@ namespace } } }; - - class clamp_FunctionFactory : public FunctionFactory + + template + class clamp_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "clamp"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new clamp_FunctionNode(expr); + return new clamp_FunctionNode(expr); } }; - + // Rescale //-------------------------------------------------------------------------- - class rescale_FunctionNode : public FunctionNode + template + class rescale_FunctionNode : public FunctionNode { public: - explicit rescale_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit rescale_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(5, 5, 0, 0); + this->SetArgumentCount(5, 5, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double pnt = m_nodes[0]->Evaluate(); - double o1 = m_nodes[1]->Evaluate(); - double o2 = m_nodes[2]->Evaluate(); - double n1 = m_nodes[3]->Evaluate(); - double n2 = m_nodes[4]->Evaluate(); - - double odiff = o2 - o1; + Value pnt = this->m_nodes[0]->Evaluate(); + Value o1 = this->m_nodes[1]->Evaluate(); + Value o2 = this->m_nodes[2]->Evaluate(); + Value n1 = this->m_nodes[3]->Evaluate(); + Value n2 = this->m_nodes[4]->Evaluate(); + + Value odiff = o2 - o1; if(odiff == 0.0) return n1; else @@ -1438,239 +1519,257 @@ namespace } } }; - - class rescale_FunctionFactory : public FunctionFactory + + template + class rescale_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "rescale"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new rescale_FunctionNode(expr); + return new rescale_FunctionNode(expr); } }; - + // Poly: poly(x, c3, c2, c1, c0): c3*x^3 + c2*x^2 + c1*x + c0 //-------------------------------------------------------------------------- - class poly_FunctionNode : public FunctionNode + template + class poly_FunctionNode : public FunctionNode { public: - explicit poly_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit poly_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, -1, 0, 0); + this->SetArgumentCount(2, -1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - double total = 0.0; + Value total = 0.0; double curpow; - - std::vector::size_type pos, count = m_nodes.size(); - - curpow = (double)(count - 2); - + + curpow = (double)(this->m_nodes.size() - 2); + // Value of x - double x = m_nodes[0]->Evaluate(); - + Value x = this->m_nodes[0]->Evaluate(); + errno = 0; - - for(pos = 1; pos < count; pos++) + + for (const auto& node : this->m_nodes) { - total += (m_nodes[pos]->Evaluate() * pow(x, curpow)); + total += node->Evaluate() * pow(x, curpow); curpow -= 1.0; } - + if(errno) - throw(MathException(GetName())); - + throw(MathException(this->GetName())); + return total; } }; - - class poly_FunctionFactory : public FunctionFactory + + template + class poly_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "poly"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new poly_FunctionNode(expr); + return new poly_FunctionNode(expr); } }; - + // And //-------------------------------------------------------------------------- - class and_FunctionNode : public FunctionNode + template + class and_FunctionNode : public FunctionNode { public: - explicit and_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit and_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - if(m_nodes[0]->Evaluate() == 0.0 || m_nodes[1]->Evaluate() == 0.0) + if(this->m_nodes[0]->Evaluate() == 0.0 || this->m_nodes[1]->Evaluate() == 0.0) return 0.0; else return 1.0; } }; - - class and_FunctionFactory : public FunctionFactory + + template + class and_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "and"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new and_FunctionNode(expr); + return new and_FunctionNode(expr); } }; - + // Or //-------------------------------------------------------------------------- - class or_FunctionNode : public FunctionNode + template + class or_FunctionNode : public FunctionNode { public: - explicit or_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit or_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(2, 2, 0, 0); + this->SetArgumentCount(2, 2, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - if(m_nodes[0]->Evaluate() == 0.0 && m_nodes[1]->Evaluate() == 0.0) + if(this->m_nodes[0]->Evaluate() == 0.0 && this->m_nodes[1]->Evaluate() == 0.0) return 0.0; else return 1.0; } }; - - class or_FunctionFactory : public FunctionFactory + + template + class or_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "or"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new or_FunctionNode(expr); + return new or_FunctionNode(expr); } }; - + // Not //-------------------------------------------------------------------------- - class not_FunctionNode : public FunctionNode + template + class not_FunctionNode : public FunctionNode { public: - explicit not_FunctionNode(Expression *expr) : FunctionNode(expr) + explicit not_FunctionNode(Expression *expr) : FunctionNode(expr) { - SetArgumentCount(1, 1, 0, 0); + this->SetArgumentCount(1, 1, 0, 0); } - - double DoEvaluate() override + + Value DoEvaluate() override { - if(m_nodes[0]->Evaluate() == 0.0) + if(this->m_nodes[0]->Evaluate() == 0.0) return 1.0; else return 0.0; } }; - - class not_FunctionFactory : public FunctionFactory + + template + class not_FunctionFactory : public FunctionFactory { public: std::string GetName() const override { return "not"; } - - FunctionNode *DoCreate(Expression *expr) override + + FunctionNode *DoCreate(Expression *expr) override { - return new not_FunctionNode(expr); + return new not_FunctionNode(expr); } }; - + } // namespace // Initialize default functions -void FunctionList::AddDefaultFunctions() +template +void FunctionList::AddDefaultFunctions() { #define ADDFUNCTION(name) \ - aptr(FunctionFactory) name ## _func(new name ## _FunctionFactory()); \ + aptr(FunctionFactory) name ## _func(new name ## _FunctionFactory()); \ m_functions.push_back(name ## _func.get()); \ name ## _func.release(); - + ADDFUNCTION(abs); - ADDFUNCTION(mod); - - ADDFUNCTION(ipart); - ADDFUNCTION(fpart); - - ADDFUNCTION(min); - ADDFUNCTION(max); + ADDFUNCTION(sqrt); - + ADDFUNCTION(sin); ADDFUNCTION(cos); ADDFUNCTION(tan); - + ADDFUNCTION(sinh); ADDFUNCTION(cosh); ADDFUNCTION(tanh); - + ADDFUNCTION(asin); ADDFUNCTION(acos); ADDFUNCTION(atan); ADDFUNCTION(atan2); - + ADDFUNCTION(log); ADDFUNCTION(ln); ADDFUNCTION(exp); ADDFUNCTION(logn); ADDFUNCTION(pow); - - ADDFUNCTION(ceil); - ADDFUNCTION(floor); - - ADDFUNCTION(rand); - ADDFUNCTION(random); - ADDFUNCTION(randomize); - + ADDFUNCTION(deg); ADDFUNCTION(rad); - + ADDFUNCTION(rect2pol); ADDFUNCTION(pol2rect); - + ADDFUNCTION(if); // Preprocess will take care of this beforehand ADDFUNCTION(select); - - ADDFUNCTION(equal) + + ADDFUNCTION(equal); ADDFUNCTION(above); ADDFUNCTION(below); - + ADDFUNCTION(clip); - ADDFUNCTION(clamp); ADDFUNCTION(rescale); - + ADDFUNCTION(poly); - + ADDFUNCTION(and); ADDFUNCTION(or); ADDFUNCTION(not); + + if constexpr (std::is_same_v) { + ADDFUNCTION(mod); + + ADDFUNCTION(ipart); + ADDFUNCTION(fpart); + + ADDFUNCTION(min); + ADDFUNCTION(max); + + ADDFUNCTION(ceil); + ADDFUNCTION(floor); + + ADDFUNCTION(rand); + ADDFUNCTION(random); + ADDFUNCTION(randomize); + + ADDFUNCTION(clamp); + } +} + +namespace ExprEval { +template void FunctionList::AddDefaultFunctions(); +template class FunctionList; } diff --git a/3rdparty/expreval/funclist.cpp b/3rdparty/expreval/funclist.cpp index f9127d89..2c5207b6 100644 --- a/3rdparty/expreval/funclist.cpp +++ b/3rdparty/expreval/funclist.cpp @@ -11,6 +11,8 @@ #include "except.h" #include "node.h" +#include + using namespace std; using namespace ExprEval; @@ -18,19 +20,22 @@ using namespace ExprEval; //------------------------------------------------------------------------------ // Constructor -FunctionFactory::FunctionFactory() +template +FunctionFactory::FunctionFactory() { } // Destructor -FunctionFactory::~FunctionFactory() +template +FunctionFactory::~FunctionFactory() { } // Create -FunctionNode *FunctionFactory::Create(Expression *expr) +template +FunctionNode *FunctionFactory::Create(Expression *expr) { - FunctionNode *n = DoCreate(expr); + FunctionNode *n = DoCreate(expr); if(n) n->m_factory = this; @@ -42,19 +47,22 @@ FunctionNode *FunctionFactory::Create(Expression *expr) //------------------------------------------------------------------------------ // Constructor -FunctionList::FunctionList() +template +FunctionList::FunctionList() { } // Destructor -FunctionList::~FunctionList() +template +FunctionList::~FunctionList() { // Free function factories Clear(); } // Add factory to list -void FunctionList::Add(FunctionFactory *factory) +template +void FunctionList::Add(FunctionFactory *factory) { // Check it if(factory == 0) @@ -73,7 +81,8 @@ void FunctionList::Add(FunctionFactory *factory) } // Create a node for a function -FunctionNode *FunctionList::Create(const string &name, Expression *expr) +template +FunctionNode *FunctionList::Create(const string &name, Expression *expr) { // Make sure pointer exists if(expr == 0) @@ -98,7 +107,8 @@ FunctionNode *FunctionList::Create(const string &name, Expression *expr) // along with the default function factories // Free function list -void FunctionList::Clear() +template +void FunctionList::Clear() { size_type pos; @@ -108,6 +118,11 @@ void FunctionList::Clear() } } +#define INSTANCE(...) \ +template class FunctionFactory<__VA_ARGS__>; \ +template class FunctionList<__VA_ARGS__>; - - +namespace ExprEval { +INSTANCE(double) +INSTANCE(autodiff::var) +} diff --git a/3rdparty/expreval/funclist.h b/3rdparty/expreval/funclist.h index 31d39e3b..89a293f1 100644 --- a/3rdparty/expreval/funclist.h +++ b/3rdparty/expreval/funclist.h @@ -15,11 +15,12 @@ namespace ExprEval { // Forward declarations - class FunctionNode; - class Expression; + template class FunctionNode; + template class Expression; // Function factory //-------------------------------------------------------------------------- + template class FunctionFactory { public: @@ -27,26 +28,28 @@ namespace ExprEval virtual ~FunctionFactory(); virtual ::std::string GetName() const = 0; - virtual FunctionNode *DoCreate(Expression *expr) = 0; + virtual FunctionNode *DoCreate(Expression *expr) = 0; - FunctionNode *Create(Expression *expr); + FunctionNode *Create(Expression *expr); }; // Function list //-------------------------------------------------------------------------- + template class FunctionList { public: - typedef ::std::vector::size_type size_type; + using FactoryVec = std::vector*>; + using size_type = typename FactoryVec::size_type; FunctionList(); ~FunctionList(); // Add a function factory to the list - void Add(FunctionFactory *factory); + void Add(FunctionFactory *factory); // Create a node for a function - FunctionNode *Create(const ::std::string &name, Expression *expr); + FunctionNode *Create(const ::std::string &name, Expression *expr); // Initialize default functions void AddDefaultFunctions(); @@ -55,7 +58,7 @@ namespace ExprEval void Clear(); private: - ::std::vector m_functions; + ::std::vector*> m_functions; }; } diff --git a/3rdparty/expreval/node.cpp b/3rdparty/expreval/node.cpp index 3394e118..07e9efcf 100644 --- a/3rdparty/expreval/node.cpp +++ b/3rdparty/expreval/node.cpp @@ -17,6 +17,8 @@ #include "funclist.h" #include "except.h" +#include + using namespace std; using namespace ExprEval; @@ -24,19 +26,22 @@ using namespace ExprEval; //------------------------------------------------------------------------------ // Constructor -Node::Node(Expression *expr) : m_expr(expr) +template +Node::Node(Expression* expr) : m_expr(expr) { if(expr == 0) throw(NullPointerException("Node::Node")); } // Destructor -Node::~Node() +template +Node::~Node() { } // Evaluate -double Node::Evaluate() +template +Value Node::Evaluate() { m_expr->TestAbort(); @@ -47,31 +52,30 @@ double Node::Evaluate() //------------------------------------------------------------------------------ // Constructor -FunctionNode::FunctionNode(Expression *expr) : Node(expr), +template +FunctionNode::FunctionNode(Expression *expr) : Node(expr), m_argMin(0), m_argMax(0), m_refMin(0), m_refMax(0) { } // Destructor -FunctionNode::~FunctionNode() +template +FunctionNode::~FunctionNode() { - // Delete child nodes - vector::size_type pos; - - for(pos = 0; pos < m_nodes.size(); pos++) - { - delete m_nodes[pos]; - } + for (auto& node : m_nodes) + delete node; } // Get name -string FunctionNode::GetName() const +template +string FunctionNode::GetName() const { return m_factory->GetName(); } // Set argument count -void FunctionNode::SetArgumentCount(long argMin, long argMax, long refMin, long refMax) +template +void FunctionNode::SetArgumentCount(long argMin, long argMax, long refMin, long refMax) { m_argMin = argMin; m_argMax = argMax; @@ -80,10 +84,13 @@ void FunctionNode::SetArgumentCount(long argMin, long argMax, long refMin, long } // Parse expression -void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void FunctionNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { - Parser::size_type pos, last; + typename Parser::size_type pos, last; int plevel = 0; // Sanity check (start/end are function parenthesis) @@ -135,7 +142,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t if(last == pos - 2 && parser[last + 1].GetType() == Token::TypeIdentifier) { // Get value list - ValueList *vlist = m_expr->GetValueList(); + ValueList *vlist = this->m_expr->GetValueList(); if(vlist == 0) { NoValueListException e; @@ -159,7 +166,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t } // Get address - double *vaddr = vlist->GetAddress(ident); + Value *vaddr = vlist->GetAddress(ident); if(vaddr == 0) { // Try to add it and get again @@ -192,7 +199,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t else { // Create node - aptr(Node) n(parser.ParseRegion(last, pos - 1)); + aptr(Node) n(parser.ParseRegion(last, pos - 1)); m_nodes.push_back(n.get()); n.release(); } @@ -229,7 +236,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t // Check argument count - if(m_argMin != -1 && m_nodes.size() < (vector::size_type)m_argMin) + if(m_argMin != -1 && m_nodes.size() < static_cast(m_argMin)) { InvalidArgumentCountException e(GetName()); @@ -237,8 +244,8 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t e.SetEnd(parser[end].GetEnd()); throw(e); } - - if(m_argMax != -1 && m_nodes.size() > (vector::size_type)m_argMax) + + if(m_argMax != -1 && m_nodes.size() > static_cast(m_argMax)) { InvalidArgumentCountException e(GetName()); @@ -246,8 +253,8 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t e.SetEnd(parser[end].GetEnd()); throw(e); } - - if(m_refMin != -1 && m_refs.size() < (vector::size_type)m_refMin) + + if(m_refMin != -1 && m_refs.size() < static_cast(m_refMin)) { InvalidArgumentCountException e(GetName()); @@ -255,8 +262,8 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t e.SetEnd(parser[end].GetEnd()); throw(e); } - - if(m_refMax != -1 && m_refs.size() > (vector::size_type)m_refMax) + + if(m_refMax != -1 && m_refs.size() > static_cast(m_refMax)) { InvalidArgumentCountException e(GetName()); @@ -270,41 +277,43 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t //------------------------------------------------------------------------------ // Constructor -MultiNode::MultiNode(Expression *expr) : Node(expr) +template +MultiNode::MultiNode(Expression *expr) : Node(expr) { } // Destructor -MultiNode::~MultiNode() +template +MultiNode::~MultiNode() { - // Free child nodes - vector::size_type pos; - - for(pos = 0; pos < m_nodes.size(); pos++) + for(auto& node : m_nodes) { - delete m_nodes[pos]; + delete node; } } // Evaluate -double MultiNode::DoEvaluate() +template +Value MultiNode::DoEvaluate() { - vector::size_type pos; - double result = 0.0; + Value result = 0.0; - for(pos = 0; pos < m_nodes.size(); pos++) + for(auto& node : m_nodes) { - result = m_nodes[pos]->Evaluate(); + result = node->Evaluate(); } return result; } // Parse -void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void MultiNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { - Parser::size_type pos, last; + typename Parser::size_type pos, last; int plevel = 0; // Sanity check @@ -346,7 +355,7 @@ void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type if(pos > last) { // Everything from last to pos - 1 - aptr(Node) n(parser.ParseRegion(last, pos - 1)); + aptr(Node) n(parser.ParseRegion(last, pos - 1)); m_nodes.push_back(n.get()); n.release(); } @@ -382,7 +391,7 @@ void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type // If the end was not a semicolon, test it as well if(last < end + 1) { - aptr(Node) n(parser.ParseRegion(last, end)); + aptr(Node) n(parser.ParseRegion(last, end)); m_nodes.push_back(n.get()); n.release(); } @@ -392,26 +401,32 @@ void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type //------------------------------------------------------------------------------ // Constructor -AssignNode::AssignNode(Expression *expr) : Node(expr), m_var(0), m_rhs(0) +template +AssignNode::AssignNode(Expression *expr) : Node(expr), m_var(0), m_rhs(0) { } // Destructor -AssignNode::~AssignNode() +template +AssignNode::~AssignNode() { // Free child node delete m_rhs; } // Evaluate -double AssignNode::DoEvaluate() +template +Value AssignNode::DoEvaluate() { return (*m_var = m_rhs->Evaluate()); } // Parse -void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void AssignNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(v1 != start + 1 || v1 >= end) @@ -424,7 +439,7 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ } // Get the value list - ValueList *vlist = m_expr->GetValueList(); + ValueList *vlist = this->m_expr->GetValueList(); if(vlist == 0) { NoValueListException e; @@ -448,7 +463,7 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ } // Get address - double *vaddr = vlist->GetAddress(ident); + Value *vaddr = vlist->GetAddress(ident); if(vaddr == 0) { @@ -467,7 +482,7 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ } // Parse the node (will throw if it can not parse) - aptr(Node) n(parser.ParseRegion(v1 + 1, end)); + aptr(Node) n(parser.ParseRegion(v1 + 1, end)); // Set data m_var = vaddr; @@ -479,12 +494,14 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ //------------------------------------------------------------------------------ // Constructor -AddNode::AddNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) +template +AddNode::AddNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) { } // Destructor -AddNode::~AddNode() +template +AddNode::~AddNode() { // Free child nodes delete m_lhs; @@ -492,14 +509,18 @@ AddNode::~AddNode() } // Evaluate -double AddNode::DoEvaluate() +template +Value AddNode::DoEvaluate() { return m_lhs->Evaluate() + m_rhs->Evaluate(); } // Parse -void AddNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void AddNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(v1 <= start || v1 >= end) @@ -512,8 +533,8 @@ void AddNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type e } // Parse sides - aptr(Node) left(parser.ParseRegion(start, v1 - 1)); - aptr(Node) right(parser.ParseRegion(v1 + 1, end)); + aptr(Node) left(parser.ParseRegion(start, v1 - 1)); + aptr(Node) right(parser.ParseRegion(v1 + 1, end)); m_lhs = left.release(); m_rhs = right.release(); @@ -523,12 +544,14 @@ void AddNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type e //------------------------------------------------------------------------------ // Constructor -SubtractNode::SubtractNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) +template +SubtractNode::SubtractNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) { } // Destructor -SubtractNode::~SubtractNode() +template +SubtractNode::~SubtractNode() { // Free child nodes delete m_lhs; @@ -536,14 +559,18 @@ SubtractNode::~SubtractNode() } // Evaluate -double SubtractNode::DoEvaluate() +template +Value SubtractNode::DoEvaluate() { return m_lhs->Evaluate() - m_rhs->Evaluate(); } // Parse -void SubtractNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void SubtractNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(v1 <= start || v1 >= end) @@ -556,8 +583,8 @@ void SubtractNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t } // Parse sides - aptr(Node) left(parser.ParseRegion(start, v1 - 1)); - aptr(Node) right(parser.ParseRegion(v1 + 1, end)); + aptr(Node) left(parser.ParseRegion(start, v1 - 1)); + aptr(Node) right(parser.ParseRegion(v1 + 1, end)); m_lhs = left.release(); m_rhs = right.release(); @@ -567,12 +594,14 @@ void SubtractNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t //------------------------------------------------------------------------------ // Constructor -MultiplyNode::MultiplyNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) +template +MultiplyNode::MultiplyNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) { } // Destructor -MultiplyNode::~MultiplyNode() +template +MultiplyNode::~MultiplyNode() { // Free child nodes delete m_lhs; @@ -580,14 +609,18 @@ MultiplyNode::~MultiplyNode() } // Evaluate -double MultiplyNode::DoEvaluate() +template +Value MultiplyNode::DoEvaluate() { return m_lhs->Evaluate() * m_rhs->Evaluate(); } // Parse -void MultiplyNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void MultiplyNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(v1 <= start || v1 >= end) @@ -600,8 +633,8 @@ void MultiplyNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t } // Parse sides - aptr(Node) left(parser.ParseRegion(start, v1 - 1)); - aptr(Node) right(parser.ParseRegion(v1 + 1, end)); + aptr(Node) left(parser.ParseRegion(start, v1 - 1)); + aptr(Node) right(parser.ParseRegion(v1 + 1, end)); m_lhs = left.release(); m_rhs = right.release(); @@ -611,12 +644,14 @@ void MultiplyNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t //------------------------------------------------------------------------------ // Constructor -DivideNode::DivideNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) +template +DivideNode::DivideNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) { } // Destructor -DivideNode::~DivideNode() +template +DivideNode::~DivideNode() { // Free child nodes delete m_lhs; @@ -624,11 +659,12 @@ DivideNode::~DivideNode() } // Evaluate -double DivideNode::DoEvaluate() +template +Value DivideNode::DoEvaluate() { - double r2 = m_rhs->Evaluate(); + Value r2 = m_rhs->Evaluate(); - if(r2 != 0.0) + if(r2 != Value(0.0)) { return m_lhs->Evaluate() / r2; } @@ -639,8 +675,11 @@ double DivideNode::DoEvaluate() } // Parse -void DivideNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void DivideNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(v1 <= start || v1 >= end) @@ -653,8 +692,8 @@ void DivideNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ } // Parse sides - aptr(Node) left(parser.ParseRegion(start, v1 - 1)); - aptr(Node) right(parser.ParseRegion(v1 + 1, end)); + aptr(Node) left(parser.ParseRegion(start, v1 - 1)); + aptr(Node) right(parser.ParseRegion(v1 + 1, end)); m_lhs = left.release(); m_rhs = right.release(); @@ -664,26 +703,32 @@ void DivideNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ //------------------------------------------------------------------------------ // Constructor -NegateNode::NegateNode(Expression *expr) : Node(expr), m_rhs(0) +template +NegateNode::NegateNode(Expression *expr) : Node(expr), m_rhs(0) { } // Destructor -NegateNode::~NegateNode() +template +NegateNode::~NegateNode() { // Free child nodes delete m_rhs; } // Evaluate -double NegateNode::DoEvaluate() +template +Value NegateNode::DoEvaluate() { return -(m_rhs->Evaluate()); } // Parse -void NegateNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void NegateNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(start != v1 || v1 >= end) @@ -696,7 +741,7 @@ void NegateNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ } // Parse sides - aptr(Node) right(parser.ParseRegion(v1 + 1, end)); + aptr(Node) right(parser.ParseRegion(v1 + 1, end)); m_rhs = right.release(); } @@ -705,12 +750,14 @@ void NegateNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ //------------------------------------------------------------------------------ // Constructor -ExponentNode::ExponentNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) +template +ExponentNode::ExponentNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0) { } // Destructor -ExponentNode::~ExponentNode() +template +ExponentNode::~ExponentNode() { // Free child nodes delete m_lhs; @@ -718,21 +765,27 @@ ExponentNode::~ExponentNode() } // Evaluate -double ExponentNode::DoEvaluate() +template +Value ExponentNode::DoEvaluate() { errno = 0; - - double result = pow(m_lhs->Evaluate(), m_rhs->Evaluate()); - - if(errno) - throw(MathException("^")); + + Value result = pow(m_lhs->Evaluate(), m_rhs->Evaluate()); + + if constexpr (std::is_same_v) { + if(errno) + throw(MathException("^")); + } return result; } // Parse -void ExponentNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void ExponentNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(v1 <= start || v1 >= end) @@ -745,8 +798,8 @@ void ExponentNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t } // Parse sides - aptr(Node) left(parser.ParseRegion(start, v1 - 1)); - aptr(Node) right(parser.ParseRegion(v1 + 1, end)); + aptr(Node) left(parser.ParseRegion(start, v1 - 1)); + aptr(Node) right(parser.ParseRegion(v1 + 1, end)); m_lhs = left.release(); m_rhs = right.release(); @@ -756,24 +809,30 @@ void ExponentNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t //------------------------------------------------------------------------------ // Constructor -VariableNode::VariableNode(Expression *expr) : Node(expr), m_var(0) +template +VariableNode::VariableNode(Expression *expr) : Node(expr), m_var(0) { } // Destructor -VariableNode::~VariableNode() +template +VariableNode::~VariableNode() { } // Evaluate -double VariableNode::DoEvaluate() +template +Value VariableNode::DoEvaluate() { return *m_var; } // Parse -void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void VariableNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check some basic syntax if(start != end) @@ -786,7 +845,7 @@ void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t } // Get value list - ValueList *vlist = m_expr->GetValueList(); + ValueList *vlist = this->m_expr->GetValueList(); if(vlist == 0) { NoValueListException e; @@ -800,7 +859,7 @@ void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t string ident = parser[start].GetIdentifier(); // Get address - double *vaddr = vlist->GetAddress(ident); + Value *vaddr = vlist->GetAddress(ident); if(vaddr == 0) { @@ -826,24 +885,30 @@ void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t //------------------------------------------------------------------------------ // Constructor -ValueNode::ValueNode(Expression *expr) : Node(expr), m_val(0) +template +ValueNode::ValueNode(Expression *expr) : Node(expr), m_val(0) { } // Destructor -ValueNode::~ValueNode() +template +ValueNode::~ValueNode() { } // Evaluate -double ValueNode::DoEvaluate() +template +Value ValueNode::DoEvaluate() { return m_val; } // Parse -void ValueNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1) +template +void ValueNode::Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1) { // Check basic syntax if(start != end) @@ -857,7 +922,23 @@ void ValueNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type // Set info m_val = parser[start].GetValue(); -} - - - +} + +#define INSTANCE(T) \ + template class T; \ + template class T; + +namespace ExprEval { +INSTANCE(AddNode) +INSTANCE(AssignNode) +INSTANCE(DivideNode) +INSTANCE(ExponentNode) +INSTANCE(FunctionNode) +INSTANCE(MultiNode) +INSTANCE(MultiplyNode) +INSTANCE(NegateNode) +INSTANCE(Node) +INSTANCE(SubtractNode) +INSTANCE(ValueNode) +INSTANCE(VariableNode) +} diff --git a/3rdparty/expreval/node.h b/3rdparty/expreval/node.h index c67b7c2e..b58bf568 100644 --- a/3rdparty/expreval/node.h +++ b/3rdparty/expreval/node.h @@ -16,44 +16,49 @@ namespace ExprEval { // Forward declarations - class Expression; - class FunctionFactory; - + template class Expression; + template class FunctionFactory; // Node class //-------------------------------------------------------------------------- + template class Node { public: - explicit Node(Expression *expr); + explicit Node(Expression* expr); virtual ~Node(); - virtual double DoEvaluate() = 0; - virtual void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) = 0; + virtual Value DoEvaluate() = 0; + virtual void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) = 0; - double Evaluate(); // Calls Expression::TestAbort, then DoEvaluate + Value Evaluate(); // Calls Expression::TestAbort, then DoEvaluate protected: - Expression *m_expr; + Expression *m_expr; }; // General function node class //-------------------------------------------------------------------------- - class FunctionNode : public Node + template + class FunctionNode : public Node { public: - explicit FunctionNode(Expression *expr); + explicit FunctionNode(Expression *expr); ~FunctionNode(); // Parse nodes and references - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: // Function factory - FunctionFactory *m_factory; + FunctionFactory *m_factory; // Argument count long m_argMin; @@ -70,176 +75,206 @@ namespace ExprEval ::std::string GetName() const; // Normal, reference, and data parameters - ::std::vector m_nodes; - ::std::vector m_refs; + ::std::vector*> m_nodes; + ::std::vector m_refs; - friend class FunctionFactory; + friend class FunctionFactory; }; // Mulit-expression node //-------------------------------------------------------------------------- - class MultiNode : public Node + template + class MultiNode : public Node { public: - explicit MultiNode(Expression *expr); + explicit MultiNode(Expression *expr); ~MultiNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser& parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - ::std::vector m_nodes; + ::std::vector*> m_nodes; }; // Assign node //-------------------------------------------------------------------------- - class AssignNode : public Node + template + class AssignNode : public Node { public: - explicit AssignNode(Expression *expr); + explicit AssignNode(Expression *expr); ~AssignNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - double *m_var; - Node *m_rhs; + Value *m_var; + Node *m_rhs; }; // Add node //-------------------------------------------------------------------------- - class AddNode : public Node + template + class AddNode : public Node { public: - explicit AddNode(Expression *expr); + explicit AddNode(Expression *expr); ~AddNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - Node *m_lhs; - Node *m_rhs; + Node *m_lhs; + Node *m_rhs; }; // Subtract node //-------------------------------------------------------------------------- - class SubtractNode : public Node + template + class SubtractNode : public Node { public: - explicit SubtractNode(Expression *expr); + explicit SubtractNode(Expression *expr); ~SubtractNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - Node *m_lhs; - Node *m_rhs; + Node *m_lhs; + Node *m_rhs; }; // Multiply node //-------------------------------------------------------------------------- - class MultiplyNode : public Node + template + class MultiplyNode : public Node { public: - explicit MultiplyNode(Expression *expr); + explicit MultiplyNode(Expression *expr); ~MultiplyNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - Node *m_lhs; - Node *m_rhs; + Node *m_lhs; + Node *m_rhs; }; // Divide node //-------------------------------------------------------------------------- - class DivideNode : public Node + template + class DivideNode : public Node { public: - explicit DivideNode(Expression *expr); + explicit DivideNode(Expression *expr); ~DivideNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - Node *m_lhs; - Node *m_rhs; + Node *m_lhs; + Node *m_rhs; }; // Negate node //-------------------------------------------------------------------------- - class NegateNode : public Node + template + class NegateNode : public Node { public: - explicit NegateNode(Expression *expr); + explicit NegateNode(Expression *expr); ~NegateNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - Node *m_rhs; + Node *m_rhs; }; // Exponent node //-------------------------------------------------------------------------- - class ExponentNode : public Node + template + class ExponentNode : public Node { public: - explicit ExponentNode(Expression *expr); + explicit ExponentNode(Expression *expr); ~ExponentNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - Node *m_lhs; - Node *m_rhs; + Node *m_lhs; + Node *m_rhs; }; // Variable node (also used for constants) //-------------------------------------------------------------------------- - class VariableNode : public Node + template + class VariableNode : public Node { public: - explicit VariableNode(Expression *expr); + explicit VariableNode(Expression *expr); ~VariableNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - double *m_var; + Value *m_var; }; // Value node //-------------------------------------------------------------------------- - class ValueNode : public Node + template + class ValueNode : public Node { public: - explicit ValueNode(Expression *expr); + explicit ValueNode(Expression *expr); ~ValueNode(); - double DoEvaluate() override; - void Parse(Parser &parser, Parser::size_type start, Parser::size_type end, - Parser::size_type v1 = 0) override; + Value DoEvaluate() override; + void Parse(Parser &parser, + typename Parser::size_type start, + typename Parser::size_type end, + typename Parser::size_type v1 = 0) override; private: - double m_val; + Value m_val; }; } // namespace ExprEval diff --git a/3rdparty/expreval/parser.cpp b/3rdparty/expreval/parser.cpp index 49023cdd..fa28a48d 100644 --- a/3rdparty/expreval/parser.cpp +++ b/3rdparty/expreval/parser.cpp @@ -9,6 +9,7 @@ #include #include +#include "autodiff/reverse/var/var.hpp" #include "defs.h" #include "parser.h" #include "node.h" @@ -16,7 +17,6 @@ #include "funclist.h" #include "expr.h" - using namespace std; using namespace ExprEval; @@ -158,19 +158,22 @@ string::size_type Token::GetEnd() const //------------------------------------------------------------------------------ // Constructor -Parser::Parser(Expression *expr) : m_expr(expr) +template +Parser::Parser(Expression* expr) : m_expr(expr) { if(expr == 0) throw(NullPointerException("Parser::Parser")); } // Destructor -Parser::~Parser() +template +Parser::~Parser() { } // Parse an expression string -Node *Parser::Parse(const string &exstr) +template +Node *Parser::Parse(const string &exstr) { BuildTokens(exstr); @@ -183,7 +186,8 @@ Node *Parser::Parse(const string &exstr) } // Parse a region of tokens -Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) +template +Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) { size_type pos; size_type fgopen = (size_type)-1; @@ -335,14 +339,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) // Multi-expression first if(multiexpr) { - aptr(Node) n(new MultiNode(m_expr)); + aptr(Node) n(new MultiNode(m_expr)); n->Parse(*this, start, end); return n.release(); } else if(assignindex != (size_type)-1) { // Assignment next - aptr(Node) n(new AssignNode(m_expr)); + aptr(Node) n(new AssignNode(m_expr)); n->Parse(*this, start, end, assignindex); return n.release(); } @@ -352,14 +356,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) if(m_tokens[addsubindex].GetType() == Token::TypePlus) { // Addition - aptr(Node) n(new AddNode(m_expr)); + aptr(Node) n(new AddNode(m_expr)); n->Parse(*this, start, end, addsubindex); return n.release(); } else { // Subtraction - aptr(Node) n(new SubtractNode(m_expr)); + aptr(Node) n(new SubtractNode(m_expr)); n->Parse(*this, start, end, addsubindex); return n.release(); } @@ -371,14 +375,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) if(m_tokens[muldivindex].GetType() == Token::TypeAsterisk) { // Multiplication - aptr(Node) n(new MultiplyNode(m_expr)); + aptr(Node) n(new MultiplyNode(m_expr)); n->Parse(*this, start, end, muldivindex); return n.release(); } else { // Division - aptr(Node) n(new DivideNode(m_expr)); + aptr(Node) n(new DivideNode(m_expr)); n->Parse(*this, start, end, muldivindex); return n.release(); } @@ -393,7 +397,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) } else { - aptr(Node) n(new NegateNode(m_expr)); + aptr(Node) n(new NegateNode(m_expr)); n->Parse(*this, start, end, posnegindex); return n.release(); } @@ -401,7 +405,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) else if(expindex != (size_type)-1) { // Exponent - aptr(Node) n(new ExponentNode(m_expr)); + aptr(Node) n(new ExponentNode(m_expr)); n->Parse(*this, start, end, expindex); return n.release(); } @@ -441,7 +445,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) if(fgclose == end) { // Find function list - FunctionList *flist = m_expr->GetFunctionList(); + FunctionList *flist = m_expr->GetFunctionList(); if(flist == 0) { @@ -456,7 +460,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) string ident = m_tokens[start].GetIdentifier(); // Create function node - aptr(FunctionNode) n(flist->Create(ident, m_expr)); + aptr(FunctionNode) n(flist->Create(ident, m_expr)); if(n.get()) { @@ -493,14 +497,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) if(m_tokens[start].GetType() == Token::TypeIdentifier) { // Variable/constant - aptr(Node) n(new VariableNode(m_expr)); + aptr(Node) n(new VariableNode(m_expr)); n->Parse(*this, start, end); return n.release(); } else { // Value - aptr(Node) n(new ValueNode(m_expr)); + aptr(Node) n(new ValueNode(m_expr)); n->Parse(*this, start, end); return n.release(); } @@ -518,13 +522,15 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end) } // Get a token -const Token &Parser::operator[] (Parser::size_type pos) const +template +const Token &Parser::operator[] (Parser::size_type pos) const { return m_tokens[pos]; } // Build tokens -void Parser::BuildTokens(const string &exstr) +template +void Parser::BuildTokens(const string &exstr) { m_tokens.clear(); @@ -774,6 +780,10 @@ void Parser::BuildTokens(const string &exstr) } } } -} - - +} + + +namespace ExprEval { +template class Parser; +template class Parser; +} diff --git a/3rdparty/expreval/parser.h b/3rdparty/expreval/parser.h index e7253821..606faf69 100644 --- a/3rdparty/expreval/parser.h +++ b/3rdparty/expreval/parser.h @@ -15,8 +15,8 @@ namespace ExprEval { // Forward declarations - class Expression; - class Node; + template class Expression; + template class Node; // Token class //-------------------------------------------------------------------------- @@ -70,24 +70,25 @@ namespace ExprEval // Parser class //-------------------------------------------------------------------------- + template class Parser { public: typedef ::std::vector::size_type size_type; - - explicit Parser(Expression *expr); + + explicit Parser(Expression *expr); ~Parser(); - - Node *Parse(const ::std::string &exstr); - Node *ParseRegion(size_type start, size_type end); + + Node *Parse(const ::std::string &exstr); + Node *ParseRegion(size_type start, size_type end); const Token &operator[] (size_type pos) const; private: void BuildTokens(const std::string &exstr); - - Expression *m_expr; + + Expression* m_expr; ::std::vector m_tokens; // Store token and not pointer to token }; } diff --git a/3rdparty/expreval/vallist.cpp b/3rdparty/expreval/vallist.cpp index 28794231..6919c384 100644 --- a/3rdparty/expreval/vallist.cpp +++ b/3rdparty/expreval/vallist.cpp @@ -8,6 +8,7 @@ #include #include +#include "autodiff/reverse/var/var.hpp" #include "defs.h" #include "vallist.h" #include "except.h" @@ -19,7 +20,8 @@ using namespace ExprEval; //------------------------------------------------------------------------------ // Constructor for internal value -ValueListItem::ValueListItem(const string &name, double def, bool constant) +template +ValueListItem::ValueListItem(const string &name, Value def, bool constant) { m_name = name; m_constant = constant; @@ -27,8 +29,9 @@ ValueListItem::ValueListItem(const string &name, double def, bool constant) m_ptr = 0; } -// Constructor for external value -ValueListItem::ValueListItem(const string &name, double *ptr, double def, bool constant) +// Constructor for external value +template +ValueListItem::ValueListItem(const string &name, Value *ptr, Value def, bool constant) { m_name = name; m_constant = constant; @@ -42,25 +45,29 @@ ValueListItem::ValueListItem(const string &name, double *ptr, double def, bool c } // Get the name -const string &ValueListItem::GetName() const +template +const string& ValueListItem::GetName() const { return m_name; } // Return if it is constant -bool ValueListItem::IsConstant() const +template +bool ValueListItem::IsConstant() const { return m_constant; } // Get value address -double *ValueListItem::GetAddress() +template +Value* ValueListItem::GetAddress() { return m_ptr ? m_ptr : &m_value; } // Reset to default value -void ValueListItem::Reset() +template +void ValueListItem::Reset() { if(m_ptr) *m_ptr = m_def; @@ -73,18 +80,21 @@ void ValueListItem::Reset() //------------------------------------------------------------------------------ // Constructor -ValueList::ValueList() +template +ValueList::ValueList() { } // Destructor -ValueList::~ValueList() +template +ValueList::~ValueList() { Clear(); } // Add value to list -void ValueList::Add(const string &name, double def, bool constant) +template +void ValueList::Add(const string &name, Value def, bool constant) { // Ensure value does not already exist if(GetAddress(name)) @@ -94,7 +104,7 @@ void ValueList::Add(const string &name, double def, bool constant) else { // Create value - aptr(ValueListItem) i(new ValueListItem(name, def ,constant)); + aptr(ValueListItem) i(new ValueListItem(name, def, constant)); // Add value to list m_values.push_back(i.get()); @@ -103,7 +113,8 @@ void ValueList::Add(const string &name, double def, bool constant) } // Add an external value to the list -void ValueList::AddAddress(const string &name, double *ptr, double def, bool constant) +template +void ValueList::AddAddress(const string &name, Value* ptr, Value def, bool constant) { if(GetAddress(name)) { @@ -116,7 +127,7 @@ void ValueList::AddAddress(const string &name, double *ptr, double def, bool con else { // Create value - aptr(ValueListItem) i(new ValueListItem(name, ptr, def, constant)); + aptr(ValueListItem) i(new ValueListItem(name, ptr, def, constant)); // Add value to list m_values.push_back(i.get()); @@ -125,7 +136,8 @@ void ValueList::AddAddress(const string &name, double *ptr, double def, bool con } // Get the address of the value, internal or external -double *ValueList::GetAddress(const string &name) const +template +Value* ValueList::GetAddress(const string &name) const { size_type pos; @@ -143,7 +155,8 @@ double *ValueList::GetAddress(const string &name) const } // Is the value a constant -bool ValueList::IsConstant(const string &name) const +template +bool ValueList::IsConstant(const string &name) const { size_type pos; @@ -159,13 +172,15 @@ bool ValueList::IsConstant(const string &name) const } // Number of values in the list -ValueList::size_type ValueList::Count() const +template +typename ValueList::size_type ValueList::Count() const { return m_values.size(); } // Get an item -void ValueList::Item(size_type pos, string *name, double *value) const +template +void ValueList::Item(size_type pos, string *name, Value* value) const { if(name) *name = m_values[pos]->GetName(); @@ -175,7 +190,8 @@ void ValueList::Item(size_type pos, string *name, double *value) const } // Add some default values -void ValueList::AddDefaultValues() +template +void ValueList::AddDefaultValues() { // Math constant 'e' Add("E", EXPREVAL_E, true); @@ -185,7 +201,8 @@ void ValueList::AddDefaultValues() } // Reset values -void ValueList::Reset() +template +void ValueList::Reset() { size_type pos; @@ -196,7 +213,8 @@ void ValueList::Reset() } // Free values -void ValueList::Clear() +template +void ValueList::Clear() { size_type pos; @@ -204,5 +222,9 @@ void ValueList::Clear() { delete m_values[pos]; } -} - +} + +namespace ExprEval { +template class ValueList; +template class ValueList; +} diff --git a/3rdparty/expreval/vallist.h b/3rdparty/expreval/vallist.h index 98d4de7e..ba6d8e62 100644 --- a/3rdparty/expreval/vallist.h +++ b/3rdparty/expreval/vallist.h @@ -16,55 +16,58 @@ namespace ExprEval { // Value list item //-------------------------------------------------------------------------- + template class ValueListItem { public: - ValueListItem(const ::std::string &name, double def = 0.0, bool constant = false); - ValueListItem(const ::std::string &name, double *ptr, double def = 0.0, bool constant = false); + ValueListItem(const ::std::string &name, Value def = 0.0, bool constant = false); + ValueListItem(const ::std::string &name, Value *ptr, Value def = 0.0, bool constant = false); const ::std::string &GetName() const; bool IsConstant() const; - double *GetAddress(); + Value *GetAddress(); void Reset(); private: ::std::string m_name; // Name of value bool m_constant; // Value is constant - double m_value; // Internal value (if ptr == 0) - double *m_ptr; // Pointer to extern value if not 0 + Value m_value; // Internal value (if ptr == 0) + Value *m_ptr; // Pointer to extern value if not 0 - double m_def; // Default value when reset + Value m_def; // Default value when reset }; // Value list //-------------------------------------------------------------------------- + template class ValueList { public: - typedef ::std::vector::size_type size_type; + using ValueVector = std::vector*>; + using size_type = typename ValueVector::size_type; ValueList(); ~ValueList(); // Add variable or constant to the list - void Add(const ::std::string &name, double def = 0.0, bool constant = false); + void Add(const ::std::string &name, Value def = 0.0, bool constant = false); // Add an external variable or constant to the list - void AddAddress(const ::std::string &name, double *ptr, double def = 0.0, bool constant = false); + void AddAddress(const ::std::string &name, Value *ptr, Value def = 0.0, bool constant = false); // Get the address of a variable or constant, internal or external - double *GetAddress(const ::std::string &name) const; + Value *GetAddress(const ::std::string &name) const; // Is the value constant bool IsConstant(const ::std::string &name) const; // Enumerate values size_type Count() const; - void Item(size_type pos, ::std::string *name = 0, double *value = 0) const; + void Item(size_type pos, ::std::string *name = 0, Value *value = 0) const; // Initialize some default values (math constants) void AddDefaultValues(); @@ -76,7 +79,7 @@ namespace ExprEval void Clear(); private: - ::std::vector m_values; + ::std::vector*> m_values; }; }; diff --git a/CMakeLists.txt b/CMakeLists.txt index 71151acd..dab64a27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,7 @@ set(IFEM_INCLUDES ${PROJECT_SOURCE_DIR}/src/Utility ${PROJECT_SOURCE_DIR}/3rdparty ${PROJECT_SOURCE_DIR}/3rdparty/expreval + ${PROJECT_SOURCE_DIR}/3rdparty/autodiff ${PROJECT_BINARY_DIR} ) diff --git a/src/Utility/ExprFunctions.C b/src/Utility/ExprFunctions.C index 459dde37..87fb5e1f 100644 --- a/src/Utility/ExprFunctions.C +++ b/src/Utility/ExprFunctions.C @@ -105,9 +105,9 @@ EvalFunc::EvalFunc (const char* function, const char* x, Real eps) expr.resize(nalloc); arg.resize(nalloc); for (size_t i = 0; i < nalloc; ++i) { - expr[i] = new ExprEval::Expression; - f[i] = new ExprEval::FunctionList; - v[i] = new ExprEval::ValueList; + expr[i] = new Expression; + f[i] = new FunctionList; + v[i] = new ValueList; f[i]->AddDefaultFunctions(); v[i]->AddDefaultValues(); v[i]->Add(x,0.0,false); @@ -132,11 +132,11 @@ EvalFunc::~EvalFunc () void EvalFunc::cleanup () { - for (ExprEval::Expression* it : expr) + for (Expression* it : expr) delete it; - for (ExprEval::FunctionList* it : f) + for (FunctionList* it : f) delete it; - for (ExprEval::ValueList* it : v) + for (ValueList* it : v) delete it; delete gradient; expr.clear(); @@ -199,9 +199,9 @@ EvalFunction::EvalFunction (const char* function, Real epsX, Real epsT) expr.resize(nalloc); arg.resize(nalloc); for (size_t i = 0; i < nalloc; ++i) { - expr[i] = new ExprEval::Expression; - f[i] = new ExprEval::FunctionList; - v[i] = new ExprEval::ValueList; + expr[i] = new Expression; + f[i] = new FunctionList; + v[i] = new ValueList; f[i]->AddDefaultFunctions(); v[i]->AddDefaultValues(); v[i]->Add("x",0.0,false); @@ -242,11 +242,11 @@ EvalFunction::~EvalFunction () void EvalFunction::cleanup () { - for (ExprEval::Expression* it : expr) + for (Expression* it : expr) delete it; - for (ExprEval::FunctionList* it : f) + for (FunctionList* it : f) delete it; - for (ExprEval::ValueList* it : v) + for (ValueList* it : v) delete it; for (EvalFunction* it : gradient) delete it; @@ -378,7 +378,7 @@ Real EvalFunction::dderiv (const Vec3& X, int d1, int d2) const void EvalFunction::setParam (const std::string& name, double value) { - for (ExprEval::ValueList* v1 : v) { + for (ValueList* v1 : v) { double* address = v1->GetAddress(name); if (!address) v1->Add(name,value,false); diff --git a/src/Utility/ExprFunctions.h b/src/Utility/ExprFunctions.h index 15abbebb..8b5362c2 100644 --- a/src/Utility/ExprFunctions.h +++ b/src/Utility/ExprFunctions.h @@ -21,9 +21,9 @@ #include namespace ExprEval { - class Expression; - class FunctionList; - class ValueList; + template class Expression; + template class FunctionList; + template class ValueList; } @@ -33,9 +33,12 @@ namespace ExprEval { class EvalFunc : public ScalarFunc { - std::vector expr; //!< Roots of the expression tree - std::vector f; //!< Lists of functions - std::vector v; //!< Lists of variables and constants + using Expression = ExprEval::Expression; //!< Type alias for expression tree + using FunctionList = ExprEval::FunctionList; //!< Type alias for function list + using ValueList = ExprEval::ValueList; //!< Type alias for value list + std::vector expr; //!< Roots of the expression tree + std::vector f; //!< Lists of functions + std::vector v; //!< Lists of variables and constants std::vector arg; //!< Function argument values @@ -80,9 +83,12 @@ protected: class EvalFunction : public RealFunc { - std::vector expr; //!< Roots of the expression tree - std::vector f; //!< Lists of functions - std::vector v; //!< Lists of variables and constants + using Expression = ExprEval::Expression; //!< Type alias for expression tree + using FunctionList = ExprEval::FunctionList; //!< Type alias for function list + using ValueList = ExprEval::ValueList; //!< Type alias for value list + std::vector expr; //!< Roots of the expression tree + std::vector f; //!< Lists of functions + std::vector v; //!< Lists of variables and constants //! \brief A struct representing a spatial function argument. struct Arg