Add sign member to UDQASTNode and implement operator*

This commit is contained in:
Joakim Hove
2020-10-25 20:45:09 +01:00
parent 524e211247
commit 0994a1a2fe
2 changed files with 36 additions and 14 deletions

View File

@@ -40,7 +40,7 @@ public:
UDQASTNode();
explicit UDQASTNode(UDQTokenType type_arg);
explicit UDQASTNode(double scalar_value);
UDQASTNode(UDQTokenType type_arg, const std::variant<std::string, double>& value_arg, const UDQASTNode& arg);
UDQASTNode(UDQTokenType type_arg, const std::variant<std::string, double>& value_arg, const UDQASTNode& left_arg);
UDQASTNode(UDQTokenType type_arg, const std::variant<std::string, double>& value_arg, const UDQASTNode& left, const UDQASTNode& right);
UDQASTNode(UDQTokenType type_arg, const std::variant<std::string, double>& value_arg);
UDQASTNode(UDQTokenType type_arg, const std::variant<std::string, double>& value_arg, const std::vector<std::string>& selector);
@@ -57,6 +57,7 @@ public:
void set_right(const UDQASTNode& arg);
UDQASTNode* get_left() const;
UDQASTNode* get_right() const;
void scale(double sign_factor);
bool operator==(const UDQASTNode& data) const;
void required_summary(std::unordered_set<std::string>& summary_keys) const;
@@ -67,6 +68,7 @@ public:
serializer(var_type);
serializer(type);
serializer(value);
serializer(sign);
serializer(selector);
serializer(left);
serializer(right);
@@ -77,11 +79,15 @@ private:
void func_tokens(std::set<UDQTokenType>& tokens) const;
std::variant<std::string, double> value;
double sign = 1.0;
std::vector<std::string> selector;
std::shared_ptr<UDQASTNode> left;
std::shared_ptr<UDQASTNode> right;
};
UDQASTNode operator*(const UDQASTNode&lhs, double rhs);
UDQASTNode operator*(double lhs, const UDQASTNode& rhs);
}
#endif

View File

@@ -111,6 +111,7 @@ UDQASTNode UDQASTNode::serializeObject()
result.type = UDQTokenType::error;
result.value = "test1";
result.selector = {"test2"};
result.sign = -1;
UDQASTNode left = result;
result.left = std::make_shared<UDQASTNode>(left);
@@ -149,7 +150,7 @@ UDQSet UDQASTNode::eval(UDQVarType target_type, const UDQContext& context) const
if (this->selector.size() > 0) {
const std::string& well_pattern = this->selector[0];
if (well_pattern.find("*") == std::string::npos)
return UDQSet::scalar(string_value, context.get_well_var(well_pattern, string_value));
return this->sign * UDQSet::scalar(string_value, context.get_well_var(well_pattern, string_value));
else {
auto res = UDQSet::wells(string_value, wells);
int fnmatch_flags = 0;
@@ -157,14 +158,14 @@ UDQSet UDQASTNode::eval(UDQVarType target_type, const UDQContext& context) const
if (fnmatch(well_pattern.c_str(), well.c_str(), fnmatch_flags) == 0)
res.assign(well, context.get_well_var(well, string_value));
}
return res;
return this->sign * res;
}
} else {
auto res = UDQSet::wells(string_value, wells);
for (const auto& well : wells)
res.assign(well, context.get_well_var(well, string_value));
return res;
return this->sign * res;
}
}
@@ -180,16 +181,16 @@ UDQSet UDQASTNode::eval(UDQVarType target_type, const UDQContext& context) const
auto res = UDQSet::groups(string_value, groups);
for (const auto& group : groups)
res.assign(group, context.get_group_var(group, string_value));
return res;
return this->sign * res;
}
}
if (data_type == UDQVarType::FIELD_VAR)
return UDQSet::scalar(string_value, context.get(string_value));
return this->sign * UDQSet::scalar(string_value, context.get(string_value));
auto scalar = context.get(string_value);
if (scalar.has_value())
return UDQSet::scalar(string_value, scalar.value());
return this->sign * UDQSet::scalar(string_value, scalar.value());
throw std::logic_error("Should not be here: var_type: " + UDQ::typeName(data_type) + " stringvalue:" + string_value);
}
@@ -199,7 +200,7 @@ UDQSet UDQASTNode::eval(UDQVarType target_type, const UDQContext& context) const
const auto& string_value = std::get<std::string>( this->value );
const auto& udqft = context.function_table();
const UDQScalarFunction& func = dynamic_cast<const UDQScalarFunction&>(udqft.get(string_value));
return func.eval( this->left->eval(target_type, context) );
return this->sign * func.eval( this->left->eval(target_type, context) );
}
@@ -209,7 +210,7 @@ UDQSet UDQASTNode::eval(UDQVarType target_type, const UDQContext& context) const
const auto& udqft = context.function_table();
const UDQUnaryElementalFunction& func = dynamic_cast<const UDQUnaryElementalFunction&>(udqft.get(string_value));
return func.eval(func_arg);
return this->sign * func.eval(func_arg);
}
if (UDQ::binaryFunc(this->type)) {
@@ -220,7 +221,7 @@ UDQSet UDQASTNode::eval(UDQVarType target_type, const UDQContext& context) const
const auto& udqft = context.function_table();
const UDQBinaryFunction& func = dynamic_cast<const UDQBinaryFunction&>(udqft.get(string_value));
auto res = func.eval(left_arg, right_arg);
return res;
return this->sign * res;
}
if (this->type == UDQTokenType::number) {
@@ -228,13 +229,13 @@ UDQSet UDQASTNode::eval(UDQVarType target_type, const UDQContext& context) const
double numeric_value = std::get<double>(this->value);
switch(target_type) {
case UDQVarType::WELL_VAR:
return UDQSet::wells(dummy_name, context.wells(), numeric_value);
return this->sign * UDQSet::wells(dummy_name, context.wells(), numeric_value);
case UDQVarType::GROUP_VAR:
return UDQSet::groups(dummy_name, context.groups(), numeric_value);
return this->sign * UDQSet::groups(dummy_name, context.groups(), numeric_value);
case UDQVarType::SCALAR:
return UDQSet::scalar(dummy_name, numeric_value);
return this->sign * UDQSet::scalar(dummy_name, numeric_value);
case UDQVarType::FIELD_VAR:
return UDQSet::field(dummy_name, numeric_value);
return this->sign * UDQSet::field(dummy_name, numeric_value);
default:
throw std::invalid_argument("Unsupported target_type: " + std::to_string(static_cast<int>(target_type)));
}
@@ -290,6 +291,10 @@ void UDQASTNode::set_right(const UDQASTNode& arg) {
this->update_type(arg);
}
void UDQASTNode::scale(double sign_factor) {
this->sign *= sign_factor;
}
bool UDQASTNode::operator==(const UDQASTNode& data) const {
if ((this->left && !data.left) ||
(!this->left && data.left))
@@ -341,4 +346,15 @@ void UDQASTNode::required_summary(std::unordered_set<std::string>& summary_keys)
this->right->required_summary(summary_keys);
}
UDQASTNode operator*(const UDQASTNode&lhs, double sign_factor) {
UDQASTNode prod = lhs;
prod.scale(sign_factor);
return prod;
}
UDQASTNode operator*(double lhs, const UDQASTNode& rhs) {
return rhs * lhs;
}
}