expreval: make all classes a template over a scalar

instance for double and autodiff::Variable<double>
This commit is contained in:
Arne Morten Kvarving 2023-09-20 07:05:23 +02:00
parent c9c13d888c
commit 327a855d1b
15 changed files with 1327 additions and 1032 deletions

View File

@ -15,7 +15,7 @@
namespace ExprEval
{
// Forward declarations
class Expression;
template<class Value> class Expression;
// Exception class
//--------------------------------------------------------------------------

View File

@ -16,57 +16,66 @@
using namespace std;
using namespace ExprEval;
#include <autodiff/reverse/var.hpp>
// Expression object
//------------------------------------------------------------------------------
// Constructor
Expression::Expression() : m_vlist(0), m_flist(0), m_expr(0)
template<class Value>
Expression<Value>::Expression() : m_vlist(0), m_flist(0), m_expr(0)
{
m_abortcount = 200000;
m_abortreset = 200000;
}
// Destructor
Expression::~Expression()
template<class Value>
Expression<Value>::~Expression()
{
// Delete expression nodes
delete m_expr;
}
// Set value list
void Expression::SetValueList(ValueList *vlist)
template<class Value>
void Expression<Value>::SetValueList(ValueList<Value>* vlist)
{
m_vlist = vlist;
}
// Get value list
ValueList *Expression::GetValueList() const
template<class Value>
ValueList<Value>* Expression<Value>::GetValueList() const
{
return m_vlist;
}
// Set function list
void Expression::SetFunctionList(FunctionList *flist)
template<class Value>
void Expression<Value>::SetFunctionList(FunctionList<Value> *flist)
{
m_flist = flist;
}
// Get function list
FunctionList *Expression::GetFunctionList() const
template<class Value>
FunctionList<Value>* Expression<Value>::GetFunctionList() const
{
return m_flist;
}
// Test for an abort
bool Expression::DoTestAbort()
template<class Value>
bool Expression<Value>::DoTestAbort()
{
// Derive a class to test abort
return false;
}
// Test for an abort
void Expression::TestAbort(bool force)
template<class Value>
void Expression<Value>::TestAbort(bool force)
{
if(force)
{
@ -99,7 +108,8 @@ void Expression::TestAbort(bool force)
}
// Set test abort count
void Expression::SetTestAbortCount(unsigned long count)
template<class Value>
void Expression<Value>::SetTestAbortCount(unsigned long count)
{
m_abortreset = count;
if(m_abortcount > count)
@ -107,28 +117,31 @@ void Expression::SetTestAbortCount(unsigned long count)
}
// Parse expression
void Expression::Parse(const string &exstr)
template<class Value>
void Expression<Value>::Parse(const string &exstr)
{
// Clear the expression if needed
if(m_expr)
Clear();
// Create parser
aptr(Parser) p(new Parser(this));
aptr(Parser<Value>) p(new Parser<Value>(this));
// Parse the expression
m_expr = p->Parse(exstr);
}
// Clear the expression
void Expression::Clear()
template<class Value>
void Expression<Value>::Clear()
{
delete m_expr;
m_expr = 0;
}
// Evaluate an expression
double Expression::Evaluate()
template<class Value>
Value Expression<Value>::Evaluate()
{
if(m_expr)
{
@ -139,3 +152,8 @@ double Expression::Evaluate()
throw(EmptyExpressionException());
}
}
namespace ExprEval {
template class Expression<double>;
template class Expression<autodiff::var>;
}

View File

@ -14,12 +14,13 @@
namespace ExprEval
{
// Forward declarations
class ValueList;
class FunctionList;
class Node;
template<class Value> class ValueList;
template<class Value> class FunctionList;
template<class Value> class Node;
// Expression class
//--------------------------------------------------------------------------
template<class Value>
class Expression
{
public:
@ -27,12 +28,12 @@ namespace ExprEval
virtual ~Expression();
// Variable list
void SetValueList(ValueList *vlist);
ValueList *GetValueList() const;
void SetValueList(ValueList<Value> *vlist);
ValueList<Value> *GetValueList() const;
// Function list
void SetFunctionList(FunctionList *flist);
FunctionList *GetFunctionList() const;
void SetFunctionList(FunctionList<Value> *flist);
FunctionList<Value> *GetFunctionList() const;
// Abort control
virtual bool DoTestAbort();
@ -46,12 +47,12 @@ namespace ExprEval
void Clear();
// Evaluate expression
double Evaluate();
Value Evaluate();
protected:
ValueList *m_vlist;
FunctionList *m_flist;
Node *m_expr;
ValueList<Value> *m_vlist;
FunctionList<Value> *m_flist;
Node<Value> *m_expr;
unsigned long m_abortcount;
unsigned long m_abortreset;
};

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,8 @@
#include "except.h"
#include "node.h"
#include <autodiff/reverse/var.hpp>
using namespace std;
using namespace ExprEval;
@ -18,19 +20,22 @@ using namespace ExprEval;
//------------------------------------------------------------------------------
// Constructor
FunctionFactory::FunctionFactory()
template<class Value>
FunctionFactory<Value>::FunctionFactory()
{
}
// Destructor
FunctionFactory::~FunctionFactory()
template<class Value>
FunctionFactory<Value>::~FunctionFactory()
{
}
// Create
FunctionNode *FunctionFactory::Create(Expression *expr)
template<class Value>
FunctionNode<Value> *FunctionFactory<Value>::Create(Expression<Value> *expr)
{
FunctionNode *n = DoCreate(expr);
FunctionNode<Value> *n = DoCreate(expr);
if(n)
n->m_factory = this;
@ -42,19 +47,22 @@ FunctionNode *FunctionFactory::Create(Expression *expr)
//------------------------------------------------------------------------------
// Constructor
FunctionList::FunctionList()
template<class Value>
FunctionList<Value>::FunctionList()
{
}
// Destructor
FunctionList::~FunctionList()
template<class Value>
FunctionList<Value>::~FunctionList()
{
// Free function factories
Clear();
}
// Add factory to list
void FunctionList::Add(FunctionFactory *factory)
template<class Value>
void FunctionList<Value>::Add(FunctionFactory<Value> *factory)
{
// Check it
if(factory == 0)
@ -73,7 +81,8 @@ void FunctionList::Add(FunctionFactory *factory)
}
// Create a node for a function
FunctionNode *FunctionList::Create(const string &name, Expression *expr)
template<class Value>
FunctionNode<Value> *FunctionList<Value>::Create(const string &name, Expression<Value> *expr)
{
// Make sure pointer exists
if(expr == 0)
@ -98,7 +107,8 @@ FunctionNode *FunctionList::Create(const string &name, Expression *expr)
// along with the default function factories
// Free function list
void FunctionList::Clear()
template<class Value>
void FunctionList<Value>::Clear()
{
size_type pos;
@ -108,6 +118,11 @@ void FunctionList::Clear()
}
}
#define INSTANCE(...) \
template class FunctionFactory<__VA_ARGS__>; \
template class FunctionList<__VA_ARGS__>;
namespace ExprEval {
INSTANCE(double)
INSTANCE(autodiff::var)
}

View File

@ -15,11 +15,12 @@
namespace ExprEval
{
// Forward declarations
class FunctionNode;
class Expression;
template<class Value> class FunctionNode;
template<class Value> class Expression;
// Function factory
//--------------------------------------------------------------------------
template<class Value>
class FunctionFactory
{
public:
@ -27,26 +28,28 @@ namespace ExprEval
virtual ~FunctionFactory();
virtual ::std::string GetName() const = 0;
virtual FunctionNode *DoCreate(Expression *expr) = 0;
virtual FunctionNode<Value> *DoCreate(Expression<Value> *expr) = 0;
FunctionNode *Create(Expression *expr);
FunctionNode<Value> *Create(Expression<Value> *expr);
};
// Function list
//--------------------------------------------------------------------------
template<class Value>
class FunctionList
{
public:
typedef ::std::vector<FunctionFactory*>::size_type size_type;
using FactoryVec = std::vector<FunctionFactory<Value>*>;
using size_type = typename FactoryVec::size_type;
FunctionList();
~FunctionList();
// Add a function factory to the list
void Add(FunctionFactory *factory);
void Add(FunctionFactory<Value> *factory);
// Create a node for a function
FunctionNode *Create(const ::std::string &name, Expression *expr);
FunctionNode<Value> *Create(const ::std::string &name, Expression<Value> *expr);
// Initialize default functions
void AddDefaultFunctions();
@ -55,7 +58,7 @@ namespace ExprEval
void Clear();
private:
::std::vector<FunctionFactory*> m_functions;
::std::vector<FunctionFactory<Value>*> m_functions;
};
}

View File

@ -17,6 +17,8 @@
#include "funclist.h"
#include "except.h"
#include <autodiff/reverse/var.hpp>
using namespace std;
using namespace ExprEval;
@ -24,19 +26,22 @@ using namespace ExprEval;
//------------------------------------------------------------------------------
// Constructor
Node::Node(Expression *expr) : m_expr(expr)
template<class Value>
Node<Value>::Node(Expression<Value>* expr) : m_expr(expr)
{
if(expr == 0)
throw(NullPointerException("Node::Node"));
}
// Destructor
Node::~Node()
template<class Value>
Node<Value>::~Node()
{
}
// Evaluate
double Node::Evaluate()
template<class Value>
Value Node<Value>::Evaluate()
{
m_expr->TestAbort();
@ -47,31 +52,30 @@ double Node::Evaluate()
//------------------------------------------------------------------------------
// Constructor
FunctionNode::FunctionNode(Expression *expr) : Node(expr),
template<class Value>
FunctionNode<Value>::FunctionNode(Expression<Value> *expr) : Node<Value>(expr),
m_argMin(0), m_argMax(0), m_refMin(0), m_refMax(0)
{
}
// Destructor
FunctionNode::~FunctionNode()
template<class Value>
FunctionNode<Value>::~FunctionNode()
{
// Delete child nodes
vector<Node*>::size_type pos;
for(pos = 0; pos < m_nodes.size(); pos++)
{
delete m_nodes[pos];
}
for (auto& node : m_nodes)
delete node;
}
// Get name
string FunctionNode::GetName() const
template<class Value>
string FunctionNode<Value>::GetName() const
{
return m_factory->GetName();
}
// Set argument count
void FunctionNode::SetArgumentCount(long argMin, long argMax, long refMin, long refMax)
template<class Value>
void FunctionNode<Value>::SetArgumentCount(long argMin, long argMax, long refMin, long refMax)
{
m_argMin = argMin;
m_argMax = argMax;
@ -80,10 +84,13 @@ void FunctionNode::SetArgumentCount(long argMin, long argMax, long refMin, long
}
// Parse expression
void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void FunctionNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
Parser::size_type pos, last;
typename Parser<Value>::size_type pos, last;
int plevel = 0;
// Sanity check (start/end are function parenthesis)
@ -135,7 +142,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
if(last == pos - 2 && parser[last + 1].GetType() == Token::TypeIdentifier)
{
// Get value list
ValueList *vlist = m_expr->GetValueList();
ValueList<Value> *vlist = this->m_expr->GetValueList();
if(vlist == 0)
{
NoValueListException e;
@ -159,7 +166,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
}
// Get address
double *vaddr = vlist->GetAddress(ident);
Value *vaddr = vlist->GetAddress(ident);
if(vaddr == 0)
{
// Try to add it and get again
@ -192,7 +199,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
else
{
// Create node
aptr(Node) n(parser.ParseRegion(last, pos - 1));
aptr(Node<Value>) n(parser.ParseRegion(last, pos - 1));
m_nodes.push_back(n.get());
n.release();
}
@ -229,7 +236,7 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
// Check argument count
if(m_argMin != -1 && m_nodes.size() < (vector<Node*>::size_type)m_argMin)
if(m_argMin != -1 && m_nodes.size() < static_cast<size_t>(m_argMin))
{
InvalidArgumentCountException e(GetName());
@ -237,8 +244,8 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
e.SetEnd(parser[end].GetEnd());
throw(e);
}
if(m_argMax != -1 && m_nodes.size() > (vector<Node*>::size_type)m_argMax)
if(m_argMax != -1 && m_nodes.size() > static_cast<size_t>(m_argMax))
{
InvalidArgumentCountException e(GetName());
@ -246,8 +253,8 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
e.SetEnd(parser[end].GetEnd());
throw(e);
}
if(m_refMin != -1 && m_refs.size() < (vector<double*>::size_type)m_refMin)
if(m_refMin != -1 && m_refs.size() < static_cast<size_t>(m_refMin))
{
InvalidArgumentCountException e(GetName());
@ -255,8 +262,8 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
e.SetEnd(parser[end].GetEnd());
throw(e);
}
if(m_refMax != -1 && m_refs.size() > (vector<double*>::size_type)m_refMax)
if(m_refMax != -1 && m_refs.size() > static_cast<size_t>(m_refMax))
{
InvalidArgumentCountException e(GetName());
@ -270,41 +277,43 @@ void FunctionNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
//------------------------------------------------------------------------------
// Constructor
MultiNode::MultiNode(Expression *expr) : Node(expr)
template<class Value>
MultiNode<Value>::MultiNode(Expression<Value> *expr) : Node<Value>(expr)
{
}
// Destructor
MultiNode::~MultiNode()
template<class Value>
MultiNode<Value>::~MultiNode()
{
// Free child nodes
vector<Node*>::size_type pos;
for(pos = 0; pos < m_nodes.size(); pos++)
for(auto& node : m_nodes)
{
delete m_nodes[pos];
delete node;
}
}
// Evaluate
double MultiNode::DoEvaluate()
template<class Value>
Value MultiNode<Value>::DoEvaluate()
{
vector<Node*>::size_type pos;
double result = 0.0;
Value result = 0.0;
for(pos = 0; pos < m_nodes.size(); pos++)
for(auto& node : m_nodes)
{
result = m_nodes[pos]->Evaluate();
result = node->Evaluate();
}
return result;
}
// Parse
void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void MultiNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
Parser::size_type pos, last;
typename Parser<Value>::size_type pos, last;
int plevel = 0;
// Sanity check
@ -346,7 +355,7 @@ void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type
if(pos > last)
{
// Everything from last to pos - 1
aptr(Node) n(parser.ParseRegion(last, pos - 1));
aptr(Node<Value>) n(parser.ParseRegion(last, pos - 1));
m_nodes.push_back(n.get());
n.release();
}
@ -382,7 +391,7 @@ void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type
// If the end was not a semicolon, test it as well
if(last < end + 1)
{
aptr(Node) n(parser.ParseRegion(last, end));
aptr(Node<Value>) n(parser.ParseRegion(last, end));
m_nodes.push_back(n.get());
n.release();
}
@ -392,26 +401,32 @@ void MultiNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type
//------------------------------------------------------------------------------
// Constructor
AssignNode::AssignNode(Expression *expr) : Node(expr), m_var(0), m_rhs(0)
template<class Value>
AssignNode<Value>::AssignNode(Expression<Value> *expr) : Node<Value>(expr), m_var(0), m_rhs(0)
{
}
// Destructor
AssignNode::~AssignNode()
template<class Value>
AssignNode<Value>::~AssignNode()
{
// Free child node
delete m_rhs;
}
// Evaluate
double AssignNode::DoEvaluate()
template<class Value>
Value AssignNode<Value>::DoEvaluate()
{
return (*m_var = m_rhs->Evaluate());
}
// Parse
void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void AssignNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(v1 != start + 1 || v1 >= end)
@ -424,7 +439,7 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
}
// Get the value list
ValueList *vlist = m_expr->GetValueList();
ValueList<Value> *vlist = this->m_expr->GetValueList();
if(vlist == 0)
{
NoValueListException e;
@ -448,7 +463,7 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
}
// Get address
double *vaddr = vlist->GetAddress(ident);
Value *vaddr = vlist->GetAddress(ident);
if(vaddr == 0)
{
@ -467,7 +482,7 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
}
// Parse the node (will throw if it can not parse)
aptr(Node) n(parser.ParseRegion(v1 + 1, end));
aptr(Node<Value>) n(parser.ParseRegion(v1 + 1, end));
// Set data
m_var = vaddr;
@ -479,12 +494,14 @@ void AssignNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
//------------------------------------------------------------------------------
// Constructor
AddNode::AddNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0)
template<class Value>
AddNode<Value>::AddNode(Expression<Value> *expr) : Node<Value>(expr), m_lhs(0), m_rhs(0)
{
}
// Destructor
AddNode::~AddNode()
template<class Value>
AddNode<Value>::~AddNode()
{
// Free child nodes
delete m_lhs;
@ -492,14 +509,18 @@ AddNode::~AddNode()
}
// Evaluate
double AddNode::DoEvaluate()
template<class Value>
Value AddNode<Value>::DoEvaluate()
{
return m_lhs->Evaluate() + m_rhs->Evaluate();
}
// Parse
void AddNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void AddNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(v1 <= start || v1 >= end)
@ -512,8 +533,8 @@ void AddNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type e
}
// Parse sides
aptr(Node) left(parser.ParseRegion(start, v1 - 1));
aptr(Node) right(parser.ParseRegion(v1 + 1, end));
aptr(Node<Value>) left(parser.ParseRegion(start, v1 - 1));
aptr(Node<Value>) right(parser.ParseRegion(v1 + 1, end));
m_lhs = left.release();
m_rhs = right.release();
@ -523,12 +544,14 @@ void AddNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type e
//------------------------------------------------------------------------------
// Constructor
SubtractNode::SubtractNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0)
template<class Value>
SubtractNode<Value>::SubtractNode(Expression<Value> *expr) : Node<Value>(expr), m_lhs(0), m_rhs(0)
{
}
// Destructor
SubtractNode::~SubtractNode()
template<class Value>
SubtractNode<Value>::~SubtractNode()
{
// Free child nodes
delete m_lhs;
@ -536,14 +559,18 @@ SubtractNode::~SubtractNode()
}
// Evaluate
double SubtractNode::DoEvaluate()
template<class Value>
Value SubtractNode<Value>::DoEvaluate()
{
return m_lhs->Evaluate() - m_rhs->Evaluate();
}
// Parse
void SubtractNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void SubtractNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(v1 <= start || v1 >= end)
@ -556,8 +583,8 @@ void SubtractNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
}
// Parse sides
aptr(Node) left(parser.ParseRegion(start, v1 - 1));
aptr(Node) right(parser.ParseRegion(v1 + 1, end));
aptr(Node<Value>) left(parser.ParseRegion(start, v1 - 1));
aptr(Node<Value>) right(parser.ParseRegion(v1 + 1, end));
m_lhs = left.release();
m_rhs = right.release();
@ -567,12 +594,14 @@ void SubtractNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
//------------------------------------------------------------------------------
// Constructor
MultiplyNode::MultiplyNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0)
template<class Value>
MultiplyNode<Value>::MultiplyNode(Expression<Value> *expr) : Node<Value>(expr), m_lhs(0), m_rhs(0)
{
}
// Destructor
MultiplyNode::~MultiplyNode()
template<class Value>
MultiplyNode<Value>::~MultiplyNode()
{
// Free child nodes
delete m_lhs;
@ -580,14 +609,18 @@ MultiplyNode::~MultiplyNode()
}
// Evaluate
double MultiplyNode::DoEvaluate()
template<class Value>
Value MultiplyNode<Value>::DoEvaluate()
{
return m_lhs->Evaluate() * m_rhs->Evaluate();
}
// Parse
void MultiplyNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void MultiplyNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(v1 <= start || v1 >= end)
@ -600,8 +633,8 @@ void MultiplyNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
}
// Parse sides
aptr(Node) left(parser.ParseRegion(start, v1 - 1));
aptr(Node) right(parser.ParseRegion(v1 + 1, end));
aptr(Node<Value>) left(parser.ParseRegion(start, v1 - 1));
aptr(Node<Value>) right(parser.ParseRegion(v1 + 1, end));
m_lhs = left.release();
m_rhs = right.release();
@ -611,12 +644,14 @@ void MultiplyNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
//------------------------------------------------------------------------------
// Constructor
DivideNode::DivideNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0)
template<class Value>
DivideNode<Value>::DivideNode(Expression<Value> *expr) : Node<Value>(expr), m_lhs(0), m_rhs(0)
{
}
// Destructor
DivideNode::~DivideNode()
template<class Value>
DivideNode<Value>::~DivideNode()
{
// Free child nodes
delete m_lhs;
@ -624,11 +659,12 @@ DivideNode::~DivideNode()
}
// Evaluate
double DivideNode::DoEvaluate()
template<class Value>
Value DivideNode<Value>::DoEvaluate()
{
double r2 = m_rhs->Evaluate();
Value r2 = m_rhs->Evaluate();
if(r2 != 0.0)
if(r2 != Value(0.0))
{
return m_lhs->Evaluate() / r2;
}
@ -639,8 +675,11 @@ double DivideNode::DoEvaluate()
}
// Parse
void DivideNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void DivideNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(v1 <= start || v1 >= end)
@ -653,8 +692,8 @@ void DivideNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
}
// Parse sides
aptr(Node) left(parser.ParseRegion(start, v1 - 1));
aptr(Node) right(parser.ParseRegion(v1 + 1, end));
aptr(Node<Value>) left(parser.ParseRegion(start, v1 - 1));
aptr(Node<Value>) right(parser.ParseRegion(v1 + 1, end));
m_lhs = left.release();
m_rhs = right.release();
@ -664,26 +703,32 @@ void DivideNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
//------------------------------------------------------------------------------
// Constructor
NegateNode::NegateNode(Expression *expr) : Node(expr), m_rhs(0)
template<class Value>
NegateNode<Value>::NegateNode(Expression<Value> *expr) : Node<Value>(expr), m_rhs(0)
{
}
// Destructor
NegateNode::~NegateNode()
template<class Value>
NegateNode<Value>::~NegateNode()
{
// Free child nodes
delete m_rhs;
}
// Evaluate
double NegateNode::DoEvaluate()
template<class Value>
Value NegateNode<Value>::DoEvaluate()
{
return -(m_rhs->Evaluate());
}
// Parse
void NegateNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void NegateNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(start != v1 || v1 >= end)
@ -696,7 +741,7 @@ void NegateNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
}
// Parse sides
aptr(Node) right(parser.ParseRegion(v1 + 1, end));
aptr(Node<Value>) right(parser.ParseRegion(v1 + 1, end));
m_rhs = right.release();
}
@ -705,12 +750,14 @@ void NegateNode::Parse(Parser &parser, Parser::size_type start, Parser::size_typ
//------------------------------------------------------------------------------
// Constructor
ExponentNode::ExponentNode(Expression *expr) : Node(expr), m_lhs(0), m_rhs(0)
template<class Value>
ExponentNode<Value>::ExponentNode(Expression<Value> *expr) : Node<Value>(expr), m_lhs(0), m_rhs(0)
{
}
// Destructor
ExponentNode::~ExponentNode()
template<class Value>
ExponentNode<Value>::~ExponentNode()
{
// Free child nodes
delete m_lhs;
@ -718,21 +765,27 @@ ExponentNode::~ExponentNode()
}
// Evaluate
double ExponentNode::DoEvaluate()
template<class Value>
Value ExponentNode<Value>::DoEvaluate()
{
errno = 0;
double result = pow(m_lhs->Evaluate(), m_rhs->Evaluate());
if(errno)
throw(MathException("^"));
Value result = pow(m_lhs->Evaluate(), m_rhs->Evaluate());
if constexpr (std::is_same_v<Value, double>) {
if(errno)
throw(MathException("^"));
}
return result;
}
// Parse
void ExponentNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void ExponentNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(v1 <= start || v1 >= end)
@ -745,8 +798,8 @@ void ExponentNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
}
// Parse sides
aptr(Node) left(parser.ParseRegion(start, v1 - 1));
aptr(Node) right(parser.ParseRegion(v1 + 1, end));
aptr(Node<Value>) left(parser.ParseRegion(start, v1 - 1));
aptr(Node<Value>) right(parser.ParseRegion(v1 + 1, end));
m_lhs = left.release();
m_rhs = right.release();
@ -756,24 +809,30 @@ void ExponentNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
//------------------------------------------------------------------------------
// Constructor
VariableNode::VariableNode(Expression *expr) : Node(expr), m_var(0)
template<class Value>
VariableNode<Value>::VariableNode(Expression<Value> *expr) : Node<Value>(expr), m_var(0)
{
}
// Destructor
VariableNode::~VariableNode()
template<class Value>
VariableNode<Value>::~VariableNode()
{
}
// Evaluate
double VariableNode::DoEvaluate()
template<class Value>
Value VariableNode<Value>::DoEvaluate()
{
return *m_var;
}
// Parse
void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void VariableNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check some basic syntax
if(start != end)
@ -786,7 +845,7 @@ void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
}
// Get value list
ValueList *vlist = m_expr->GetValueList();
ValueList<Value> *vlist = this->m_expr->GetValueList();
if(vlist == 0)
{
NoValueListException e;
@ -800,7 +859,7 @@ void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
string ident = parser[start].GetIdentifier();
// Get address
double *vaddr = vlist->GetAddress(ident);
Value *vaddr = vlist->GetAddress(ident);
if(vaddr == 0)
{
@ -826,24 +885,30 @@ void VariableNode::Parse(Parser &parser, Parser::size_type start, Parser::size_t
//------------------------------------------------------------------------------
// Constructor
ValueNode::ValueNode(Expression *expr) : Node(expr), m_val(0)
template<class Value>
ValueNode<Value>::ValueNode(Expression<Value> *expr) : Node<Value>(expr), m_val(0)
{
}
// Destructor
ValueNode::~ValueNode()
template<class Value>
ValueNode<Value>::~ValueNode()
{
}
// Evaluate
double ValueNode::DoEvaluate()
template<class Value>
Value ValueNode<Value>::DoEvaluate()
{
return m_val;
}
// Parse
void ValueNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1)
template<class Value>
void ValueNode<Value>::Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1)
{
// Check basic syntax
if(start != end)
@ -857,7 +922,23 @@ void ValueNode::Parse(Parser &parser, Parser::size_type start, Parser::size_type
// Set info
m_val = parser[start].GetValue();
}
}
#define INSTANCE(T) \
template class T<double>; \
template class T<autodiff::var>;
namespace ExprEval {
INSTANCE(AddNode)
INSTANCE(AssignNode)
INSTANCE(DivideNode)
INSTANCE(ExponentNode)
INSTANCE(FunctionNode)
INSTANCE(MultiNode)
INSTANCE(MultiplyNode)
INSTANCE(NegateNode)
INSTANCE(Node)
INSTANCE(SubtractNode)
INSTANCE(ValueNode)
INSTANCE(VariableNode)
}

View File

@ -16,44 +16,49 @@
namespace ExprEval
{
// Forward declarations
class Expression;
class FunctionFactory;
template<class Value> class Expression;
template<class Value> class FunctionFactory;
// Node class
//--------------------------------------------------------------------------
template<class Value>
class Node
{
public:
explicit Node(Expression *expr);
explicit Node(Expression<Value>* expr);
virtual ~Node();
virtual double DoEvaluate() = 0;
virtual void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) = 0;
virtual Value DoEvaluate() = 0;
virtual void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) = 0;
double Evaluate(); // Calls Expression::TestAbort, then DoEvaluate
Value Evaluate(); // Calls Expression::TestAbort, then DoEvaluate
protected:
Expression *m_expr;
Expression<Value> *m_expr;
};
// General function node class
//--------------------------------------------------------------------------
class FunctionNode : public Node
template<class Value>
class FunctionNode : public Node<Value>
{
public:
explicit FunctionNode(Expression *expr);
explicit FunctionNode(Expression<Value> *expr);
~FunctionNode();
// Parse nodes and references
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
// Function factory
FunctionFactory *m_factory;
FunctionFactory<Value> *m_factory;
// Argument count
long m_argMin;
@ -70,176 +75,206 @@ namespace ExprEval
::std::string GetName() const;
// Normal, reference, and data parameters
::std::vector<Node*> m_nodes;
::std::vector<double*> m_refs;
::std::vector<Node<Value>*> m_nodes;
::std::vector<Value*> m_refs;
friend class FunctionFactory;
friend class FunctionFactory<Value>;
};
// Mulit-expression node
//--------------------------------------------------------------------------
class MultiNode : public Node
template<class Value>
class MultiNode : public Node<Value>
{
public:
explicit MultiNode(Expression *expr);
explicit MultiNode(Expression<Value> *expr);
~MultiNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value>& parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
::std::vector<Node*> m_nodes;
::std::vector<Node<Value>*> m_nodes;
};
// Assign node
//--------------------------------------------------------------------------
class AssignNode : public Node
template<class Value>
class AssignNode : public Node<Value>
{
public:
explicit AssignNode(Expression *expr);
explicit AssignNode(Expression<Value> *expr);
~AssignNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
double *m_var;
Node *m_rhs;
Value *m_var;
Node<Value> *m_rhs;
};
// Add node
//--------------------------------------------------------------------------
class AddNode : public Node
template<class Value>
class AddNode : public Node<Value>
{
public:
explicit AddNode(Expression *expr);
explicit AddNode(Expression<Value> *expr);
~AddNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
Node *m_lhs;
Node *m_rhs;
Node<Value> *m_lhs;
Node<Value> *m_rhs;
};
// Subtract node
//--------------------------------------------------------------------------
class SubtractNode : public Node
template<class Value>
class SubtractNode : public Node<Value>
{
public:
explicit SubtractNode(Expression *expr);
explicit SubtractNode(Expression<Value> *expr);
~SubtractNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
Node *m_lhs;
Node *m_rhs;
Node<Value> *m_lhs;
Node<Value> *m_rhs;
};
// Multiply node
//--------------------------------------------------------------------------
class MultiplyNode : public Node
template<class Value>
class MultiplyNode : public Node<Value>
{
public:
explicit MultiplyNode(Expression *expr);
explicit MultiplyNode(Expression<Value> *expr);
~MultiplyNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
Node *m_lhs;
Node *m_rhs;
Node<Value> *m_lhs;
Node<Value> *m_rhs;
};
// Divide node
//--------------------------------------------------------------------------
class DivideNode : public Node
template<class Value>
class DivideNode : public Node<Value>
{
public:
explicit DivideNode(Expression *expr);
explicit DivideNode(Expression<Value> *expr);
~DivideNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
Node *m_lhs;
Node *m_rhs;
Node<Value> *m_lhs;
Node<Value> *m_rhs;
};
// Negate node
//--------------------------------------------------------------------------
class NegateNode : public Node
template<class Value>
class NegateNode : public Node<Value>
{
public:
explicit NegateNode(Expression *expr);
explicit NegateNode(Expression<Value> *expr);
~NegateNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
Node *m_rhs;
Node<Value> *m_rhs;
};
// Exponent node
//--------------------------------------------------------------------------
class ExponentNode : public Node
template<class Value>
class ExponentNode : public Node<Value>
{
public:
explicit ExponentNode(Expression *expr);
explicit ExponentNode(Expression<Value> *expr);
~ExponentNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
Node *m_lhs;
Node *m_rhs;
Node<Value> *m_lhs;
Node<Value> *m_rhs;
};
// Variable node (also used for constants)
//--------------------------------------------------------------------------
class VariableNode : public Node
template<class Value>
class VariableNode : public Node<Value>
{
public:
explicit VariableNode(Expression *expr);
explicit VariableNode(Expression<Value> *expr);
~VariableNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
double *m_var;
Value *m_var;
};
// Value node
//--------------------------------------------------------------------------
class ValueNode : public Node
template<class Value>
class ValueNode : public Node<Value>
{
public:
explicit ValueNode(Expression *expr);
explicit ValueNode(Expression<Value> *expr);
~ValueNode();
double DoEvaluate() override;
void Parse(Parser &parser, Parser::size_type start, Parser::size_type end,
Parser::size_type v1 = 0) override;
Value DoEvaluate() override;
void Parse(Parser<Value> &parser,
typename Parser<Value>::size_type start,
typename Parser<Value>::size_type end,
typename Parser<Value>::size_type v1 = 0) override;
private:
double m_val;
Value m_val;
};
} // namespace ExprEval

View File

@ -9,6 +9,7 @@
#include <memory>
#include <cstdlib>
#include "autodiff/reverse/var/var.hpp"
#include "defs.h"
#include "parser.h"
#include "node.h"
@ -16,7 +17,6 @@
#include "funclist.h"
#include "expr.h"
using namespace std;
using namespace ExprEval;
@ -158,19 +158,22 @@ string::size_type Token::GetEnd() const
//------------------------------------------------------------------------------
// Constructor
Parser::Parser(Expression *expr) : m_expr(expr)
template<class Value>
Parser<Value>::Parser(Expression<Value>* expr) : m_expr(expr)
{
if(expr == 0)
throw(NullPointerException("Parser::Parser"));
}
// Destructor
Parser::~Parser()
template<class Value>
Parser<Value>::~Parser()
{
}
// Parse an expression string
Node *Parser::Parse(const string &exstr)
template<class Value>
Node<Value> *Parser<Value>::Parse(const string &exstr)
{
BuildTokens(exstr);
@ -183,7 +186,8 @@ Node *Parser::Parse(const string &exstr)
}
// Parse a region of tokens
Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
template<class Value>
Node<Value> *Parser<Value>::ParseRegion(Parser::size_type start, Parser::size_type end)
{
size_type pos;
size_type fgopen = (size_type)-1;
@ -335,14 +339,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
// Multi-expression first
if(multiexpr)
{
aptr(Node) n(new MultiNode(m_expr));
aptr(Node<Value>) n(new MultiNode<Value>(m_expr));
n->Parse(*this, start, end);
return n.release();
}
else if(assignindex != (size_type)-1)
{
// Assignment next
aptr(Node) n(new AssignNode(m_expr));
aptr(Node<Value>) n(new AssignNode<Value>(m_expr));
n->Parse(*this, start, end, assignindex);
return n.release();
}
@ -352,14 +356,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
if(m_tokens[addsubindex].GetType() == Token::TypePlus)
{
// Addition
aptr(Node) n(new AddNode(m_expr));
aptr(Node<Value>) n(new AddNode<Value>(m_expr));
n->Parse(*this, start, end, addsubindex);
return n.release();
}
else
{
// Subtraction
aptr(Node) n(new SubtractNode(m_expr));
aptr(Node<Value>) n(new SubtractNode<Value>(m_expr));
n->Parse(*this, start, end, addsubindex);
return n.release();
}
@ -371,14 +375,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
if(m_tokens[muldivindex].GetType() == Token::TypeAsterisk)
{
// Multiplication
aptr(Node) n(new MultiplyNode(m_expr));
aptr(Node<Value>) n(new MultiplyNode<Value>(m_expr));
n->Parse(*this, start, end, muldivindex);
return n.release();
}
else
{
// Division
aptr(Node) n(new DivideNode(m_expr));
aptr(Node<Value>) n(new DivideNode<Value>(m_expr));
n->Parse(*this, start, end, muldivindex);
return n.release();
}
@ -393,7 +397,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
}
else
{
aptr(Node) n(new NegateNode(m_expr));
aptr(Node<Value>) n(new NegateNode<Value>(m_expr));
n->Parse(*this, start, end, posnegindex);
return n.release();
}
@ -401,7 +405,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
else if(expindex != (size_type)-1)
{
// Exponent
aptr(Node) n(new ExponentNode(m_expr));
aptr(Node<Value>) n(new ExponentNode<Value>(m_expr));
n->Parse(*this, start, end, expindex);
return n.release();
}
@ -441,7 +445,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
if(fgclose == end)
{
// Find function list
FunctionList *flist = m_expr->GetFunctionList();
FunctionList<Value> *flist = m_expr->GetFunctionList();
if(flist == 0)
{
@ -456,7 +460,7 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
string ident = m_tokens[start].GetIdentifier();
// Create function node
aptr(FunctionNode) n(flist->Create(ident, m_expr));
aptr(FunctionNode<Value>) n(flist->Create(ident, m_expr));
if(n.get())
{
@ -493,14 +497,14 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
if(m_tokens[start].GetType() == Token::TypeIdentifier)
{
// Variable/constant
aptr(Node) n(new VariableNode(m_expr));
aptr(Node<Value>) n(new VariableNode<Value>(m_expr));
n->Parse(*this, start, end);
return n.release();
}
else
{
// Value
aptr(Node) n(new ValueNode(m_expr));
aptr(Node<Value>) n(new ValueNode<Value>(m_expr));
n->Parse(*this, start, end);
return n.release();
}
@ -518,13 +522,15 @@ Node *Parser::ParseRegion(Parser::size_type start, Parser::size_type end)
}
// Get a token
const Token &Parser::operator[] (Parser::size_type pos) const
template<class Value>
const Token &Parser<Value>::operator[] (Parser::size_type pos) const
{
return m_tokens[pos];
}
// Build tokens
void Parser::BuildTokens(const string &exstr)
template<class Value>
void Parser<Value>::BuildTokens(const string &exstr)
{
m_tokens.clear();
@ -774,6 +780,10 @@ void Parser::BuildTokens(const string &exstr)
}
}
}
}
}
namespace ExprEval {
template class Parser<double>;
template class Parser<autodiff::var>;
}

View File

@ -15,8 +15,8 @@
namespace ExprEval
{
// Forward declarations
class Expression;
class Node;
template<class Value> class Expression;
template<class Value> class Node;
// Token class
//--------------------------------------------------------------------------
@ -70,24 +70,25 @@ namespace ExprEval
// Parser class
//--------------------------------------------------------------------------
template<class Value>
class Parser
{
public:
typedef ::std::vector<Token*>::size_type size_type;
explicit Parser(Expression *expr);
explicit Parser(Expression<Value> *expr);
~Parser();
Node *Parse(const ::std::string &exstr);
Node *ParseRegion(size_type start, size_type end);
Node<Value> *Parse(const ::std::string &exstr);
Node<Value> *ParseRegion(size_type start, size_type end);
const Token &operator[] (size_type pos) const;
private:
void BuildTokens(const std::string &exstr);
Expression *m_expr;
Expression<Value>* m_expr;
::std::vector<Token> m_tokens; // Store token and not pointer to token
};
}

View File

@ -8,6 +8,7 @@
#include <new>
#include <memory>
#include "autodiff/reverse/var/var.hpp"
#include "defs.h"
#include "vallist.h"
#include "except.h"
@ -19,7 +20,8 @@ using namespace ExprEval;
//------------------------------------------------------------------------------
// Constructor for internal value
ValueListItem::ValueListItem(const string &name, double def, bool constant)
template<class Value>
ValueListItem<Value>::ValueListItem(const string &name, Value def, bool constant)
{
m_name = name;
m_constant = constant;
@ -27,8 +29,9 @@ ValueListItem::ValueListItem(const string &name, double def, bool constant)
m_ptr = 0;
}
// Constructor for external value
ValueListItem::ValueListItem(const string &name, double *ptr, double def, bool constant)
// Constructor for external value
template<class Value>
ValueListItem<Value>::ValueListItem(const string &name, Value *ptr, Value def, bool constant)
{
m_name = name;
m_constant = constant;
@ -42,25 +45,29 @@ ValueListItem::ValueListItem(const string &name, double *ptr, double def, bool c
}
// Get the name
const string &ValueListItem::GetName() const
template<class Value>
const string& ValueListItem<Value>::GetName() const
{
return m_name;
}
// Return if it is constant
bool ValueListItem::IsConstant() const
template<class Value>
bool ValueListItem<Value>::IsConstant() const
{
return m_constant;
}
// Get value address
double *ValueListItem::GetAddress()
template<class Value>
Value* ValueListItem<Value>::GetAddress()
{
return m_ptr ? m_ptr : &m_value;
}
// Reset to default value
void ValueListItem::Reset()
template<class Value>
void ValueListItem<Value>::Reset()
{
if(m_ptr)
*m_ptr = m_def;
@ -73,18 +80,21 @@ void ValueListItem::Reset()
//------------------------------------------------------------------------------
// Constructor
ValueList::ValueList()
template<class Value>
ValueList<Value>::ValueList()
{
}
// Destructor
ValueList::~ValueList()
template<class Value>
ValueList<Value>::~ValueList()
{
Clear();
}
// Add value to list
void ValueList::Add(const string &name, double def, bool constant)
template<class Value>
void ValueList<Value>::Add(const string &name, Value def, bool constant)
{
// Ensure value does not already exist
if(GetAddress(name))
@ -94,7 +104,7 @@ void ValueList::Add(const string &name, double def, bool constant)
else
{
// Create value
aptr(ValueListItem) i(new ValueListItem(name, def ,constant));
aptr(ValueListItem<Value>) i(new ValueListItem<Value>(name, def, constant));
// Add value to list
m_values.push_back(i.get());
@ -103,7 +113,8 @@ void ValueList::Add(const string &name, double def, bool constant)
}
// Add an external value to the list
void ValueList::AddAddress(const string &name, double *ptr, double def, bool constant)
template<class Value>
void ValueList<Value>::AddAddress(const string &name, Value* ptr, Value def, bool constant)
{
if(GetAddress(name))
{
@ -116,7 +127,7 @@ void ValueList::AddAddress(const string &name, double *ptr, double def, bool con
else
{
// Create value
aptr(ValueListItem) i(new ValueListItem(name, ptr, def, constant));
aptr(ValueListItem<Value>) i(new ValueListItem<Value>(name, ptr, def, constant));
// Add value to list
m_values.push_back(i.get());
@ -125,7 +136,8 @@ void ValueList::AddAddress(const string &name, double *ptr, double def, bool con
}
// Get the address of the value, internal or external
double *ValueList::GetAddress(const string &name) const
template<class Value>
Value* ValueList<Value>::GetAddress(const string &name) const
{
size_type pos;
@ -143,7 +155,8 @@ double *ValueList::GetAddress(const string &name) const
}
// Is the value a constant
bool ValueList::IsConstant(const string &name) const
template<class Value>
bool ValueList<Value>::IsConstant(const string &name) const
{
size_type pos;
@ -159,13 +172,15 @@ bool ValueList::IsConstant(const string &name) const
}
// Number of values in the list
ValueList::size_type ValueList::Count() const
template<class Value>
typename ValueList<Value>::size_type ValueList<Value>::Count() const
{
return m_values.size();
}
// Get an item
void ValueList::Item(size_type pos, string *name, double *value) const
template<class Value>
void ValueList<Value>::Item(size_type pos, string *name, Value* value) const
{
if(name)
*name = m_values[pos]->GetName();
@ -175,7 +190,8 @@ void ValueList::Item(size_type pos, string *name, double *value) const
}
// Add some default values
void ValueList::AddDefaultValues()
template<class Value>
void ValueList<Value>::AddDefaultValues()
{
// Math constant 'e'
Add("E", EXPREVAL_E, true);
@ -185,7 +201,8 @@ void ValueList::AddDefaultValues()
}
// Reset values
void ValueList::Reset()
template<class Value>
void ValueList<Value>::Reset()
{
size_type pos;
@ -196,7 +213,8 @@ void ValueList::Reset()
}
// Free values
void ValueList::Clear()
template<class Value>
void ValueList<Value>::Clear()
{
size_type pos;
@ -204,5 +222,9 @@ void ValueList::Clear()
{
delete m_values[pos];
}
}
}
namespace ExprEval {
template class ValueList<double>;
template class ValueList<autodiff::var>;
}

View File

@ -16,55 +16,58 @@ namespace ExprEval
{
// Value list item
//--------------------------------------------------------------------------
template<class Value>
class ValueListItem
{
public:
ValueListItem(const ::std::string &name, double def = 0.0, bool constant = false);
ValueListItem(const ::std::string &name, double *ptr, double def = 0.0, bool constant = false);
ValueListItem(const ::std::string &name, Value def = 0.0, bool constant = false);
ValueListItem(const ::std::string &name, Value *ptr, Value def = 0.0, bool constant = false);
const ::std::string &GetName() const;
bool IsConstant() const;
double *GetAddress();
Value *GetAddress();
void Reset();
private:
::std::string m_name; // Name of value
bool m_constant; // Value is constant
double m_value; // Internal value (if ptr == 0)
double *m_ptr; // Pointer to extern value if not 0
Value m_value; // Internal value (if ptr == 0)
Value *m_ptr; // Pointer to extern value if not 0
double m_def; // Default value when reset
Value m_def; // Default value when reset
};
// Value list
//--------------------------------------------------------------------------
template<class Value>
class ValueList
{
public:
typedef ::std::vector<ValueListItem*>::size_type size_type;
using ValueVector = std::vector<ValueListItem<Value>*>;
using size_type = typename ValueVector::size_type;
ValueList();
~ValueList();
// Add variable or constant to the list
void Add(const ::std::string &name, double def = 0.0, bool constant = false);
void Add(const ::std::string &name, Value def = 0.0, bool constant = false);
// Add an external variable or constant to the list
void AddAddress(const ::std::string &name, double *ptr, double def = 0.0, bool constant = false);
void AddAddress(const ::std::string &name, Value *ptr, Value def = 0.0, bool constant = false);
// Get the address of a variable or constant, internal or external
double *GetAddress(const ::std::string &name) const;
Value *GetAddress(const ::std::string &name) const;
// Is the value constant
bool IsConstant(const ::std::string &name) const;
// Enumerate values
size_type Count() const;
void Item(size_type pos, ::std::string *name = 0, double *value = 0) const;
void Item(size_type pos, ::std::string *name = 0, Value *value = 0) const;
// Initialize some default values (math constants)
void AddDefaultValues();
@ -76,7 +79,7 @@ namespace ExprEval
void Clear();
private:
::std::vector<ValueListItem*> m_values;
::std::vector<ValueListItem<Value>*> m_values;
};
};

View File

@ -61,6 +61,7 @@ set(IFEM_INCLUDES
${PROJECT_SOURCE_DIR}/src/Utility
${PROJECT_SOURCE_DIR}/3rdparty
${PROJECT_SOURCE_DIR}/3rdparty/expreval
${PROJECT_SOURCE_DIR}/3rdparty/autodiff
${PROJECT_BINARY_DIR}
)

View File

@ -105,9 +105,9 @@ EvalFunc::EvalFunc (const char* function, const char* x, Real eps)
expr.resize(nalloc);
arg.resize(nalloc);
for (size_t i = 0; i < nalloc; ++i) {
expr[i] = new ExprEval::Expression;
f[i] = new ExprEval::FunctionList;
v[i] = new ExprEval::ValueList;
expr[i] = new Expression;
f[i] = new FunctionList;
v[i] = new ValueList;
f[i]->AddDefaultFunctions();
v[i]->AddDefaultValues();
v[i]->Add(x,0.0,false);
@ -132,11 +132,11 @@ EvalFunc::~EvalFunc ()
void EvalFunc::cleanup ()
{
for (ExprEval::Expression* it : expr)
for (Expression* it : expr)
delete it;
for (ExprEval::FunctionList* it : f)
for (FunctionList* it : f)
delete it;
for (ExprEval::ValueList* it : v)
for (ValueList* it : v)
delete it;
delete gradient;
expr.clear();
@ -199,9 +199,9 @@ EvalFunction::EvalFunction (const char* function, Real epsX, Real epsT)
expr.resize(nalloc);
arg.resize(nalloc);
for (size_t i = 0; i < nalloc; ++i) {
expr[i] = new ExprEval::Expression;
f[i] = new ExprEval::FunctionList;
v[i] = new ExprEval::ValueList;
expr[i] = new Expression;
f[i] = new FunctionList;
v[i] = new ValueList;
f[i]->AddDefaultFunctions();
v[i]->AddDefaultValues();
v[i]->Add("x",0.0,false);
@ -242,11 +242,11 @@ EvalFunction::~EvalFunction ()
void EvalFunction::cleanup ()
{
for (ExprEval::Expression* it : expr)
for (Expression* it : expr)
delete it;
for (ExprEval::FunctionList* it : f)
for (FunctionList* it : f)
delete it;
for (ExprEval::ValueList* it : v)
for (ValueList* it : v)
delete it;
for (EvalFunction* it : gradient)
delete it;
@ -378,7 +378,7 @@ Real EvalFunction::dderiv (const Vec3& X, int d1, int d2) const
void EvalFunction::setParam (const std::string& name, double value)
{
for (ExprEval::ValueList* v1 : v) {
for (ValueList* v1 : v) {
double* address = v1->GetAddress(name);
if (!address)
v1->Add(name,value,false);

View File

@ -21,9 +21,9 @@
#include <array>
namespace ExprEval {
class Expression;
class FunctionList;
class ValueList;
template<class ArgType> class Expression;
template<class ArgType> class FunctionList;
template<class ArgType> class ValueList;
}
@ -33,9 +33,12 @@ namespace ExprEval {
class EvalFunc : public ScalarFunc
{
std::vector<ExprEval::Expression*> expr; //!< Roots of the expression tree
std::vector<ExprEval::FunctionList*> f; //!< Lists of functions
std::vector<ExprEval::ValueList*> v; //!< Lists of variables and constants
using Expression = ExprEval::Expression<Real>; //!< Type alias for expression tree
using FunctionList = ExprEval::FunctionList<Real>; //!< Type alias for function list
using ValueList = ExprEval::ValueList<Real>; //!< Type alias for value list
std::vector<Expression*> expr; //!< Roots of the expression tree
std::vector<FunctionList*> f; //!< Lists of functions
std::vector<ValueList*> v; //!< Lists of variables and constants
std::vector<Real*> arg; //!< Function argument values
@ -80,9 +83,12 @@ protected:
class EvalFunction : public RealFunc
{
std::vector<ExprEval::Expression*> expr; //!< Roots of the expression tree
std::vector<ExprEval::FunctionList*> f; //!< Lists of functions
std::vector<ExprEval::ValueList*> v; //!< Lists of variables and constants
using Expression = ExprEval::Expression<Real>; //!< Type alias for expression tree
using FunctionList = ExprEval::FunctionList<Real>; //!< Type alias for function list
using ValueList = ExprEval::ValueList<Real>; //!< Type alias for value list
std::vector<Expression*> expr; //!< Roots of the expression tree
std::vector<FunctionList*> f; //!< Lists of functions
std::vector<ValueList*> v; //!< Lists of variables and constants
//! \brief A struct representing a spatial function argument.
struct Arg