Merge pull request #2126 from joakim-hove/udq-type-cast

Udq type cast
This commit is contained in:
Joakim Hove 2020-11-24 14:28:25 +01:00 committed by GitHub
commit 9b0b03d6d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 99 additions and 27 deletions

View File

@ -65,6 +65,7 @@ namespace Opm {
void eval(std::size_t report_step, const WellMatcher& wm, SummaryState& st, UDQState& udq_state) const; void eval(std::size_t report_step, const WellMatcher& wm, SummaryState& st, UDQState& udq_state) const;
const UDQDefine& define(const std::string& key) const; const UDQDefine& define(const std::string& key) const;
const UDQAssign& assign(const std::string& key) const;
std::vector<UDQDefine> definitions() const; std::vector<UDQDefine> definitions() const;
std::vector<UDQDefine> definitions(UDQVarType var_type) const; std::vector<UDQDefine> definitions(UDQVarType var_type) const;
std::vector<UDQInput> input() const; std::vector<UDQInput> input() const;

View File

@ -159,6 +159,10 @@ namespace Opm {
} }
} }
const UDQAssign& UDQConfig::assign(const std::string& key) const {
return this->m_assignments.at(key);
}
const UDQDefine& UDQConfig::define(const std::string& key) const { const UDQDefine& UDQConfig::define(const std::string& key) const {
return this->m_definitions.at(key); return this->m_definitions.at(key);
} }

View File

@ -478,37 +478,59 @@ UDQScalar operator/(double lhs, const UDQScalar& rhs) {
namespace { namespace {
bool is_scalar(const UDQSet& udq_set) {
if (udq_set.var_type() == UDQVarType::SCALAR)
return true;
if (udq_set.var_type() == UDQVarType::FIELD_VAR)
return true;
return false;
}
/* /*
If one result set is scalar and the other represents a set of wells/groups, If one result set is scalar and the other represents a set of wells/groups,
the scalar result is promoted to a set of the right type. the scalar result is promoted to a set of the right type.
*/
UDQSet udq_cast(const UDQSet& lhs, const UDQSet& rhs)
{
if (lhs.size() != rhs.size()) {
if (lhs.var_type() != UDQVarType::SCALAR) {
auto msg = fmt::format("Type/size mismatch when combining UDQs {}(size={}, type={}) and {}(size={}, type={})",
lhs.name(), lhs.size(), lhs.var_type(),
rhs.name(), rhs.size(), rhs.var_type());
throw std::logic_error(msg);
}
This function is quite subconcious about FIELD / SCALAR.
*/
std::pair<UDQSet, UDQSet> udq_cast(const UDQSet& lhs, const UDQSet& rhs)
{
if (lhs.var_type() == rhs.var_type())
return std::make_pair(lhs,rhs);
if (lhs.size() == rhs.size())
return std::make_pair(lhs,rhs);
if (is_scalar(lhs)) {
if (rhs.var_type() == UDQVarType::WELL_VAR) if (rhs.var_type() == UDQVarType::WELL_VAR)
return UDQSet::wells(lhs.name(), rhs.wgnames(), lhs[0].get()); return std::make_pair(UDQSet::wells(lhs.name(), rhs.wgnames(), lhs[0].get()), rhs);
if (rhs.var_type() == UDQVarType::GROUP_VAR) if (rhs.var_type() == UDQVarType::GROUP_VAR)
return UDQSet::groups(lhs.name(), rhs.wgnames(), lhs[0].get()); return std::make_pair(UDQSet::groups(lhs.name(), rhs.wgnames(), lhs[0].get()), rhs);
}
throw std::logic_error("Don't have a clue"); if (is_scalar(rhs)) {
} else if (lhs.var_type() == UDQVarType::WELL_VAR)
return lhs; return std::make_pair(lhs, UDQSet::wells(rhs.name(), lhs.wgnames(), rhs[0].get()));
if (lhs.var_type() == UDQVarType::GROUP_VAR)
return std::make_pair(lhs, UDQSet::groups(rhs.name(), lhs.wgnames(), rhs[0].get()));
}
auto msg = fmt::format("Type/size mismatch when combining UDQs {}(size={}, type={}) and {}(size={}, type={})",
lhs.name(), lhs.size(), lhs.var_type(),
rhs.name(), rhs.size(), rhs.var_type());
throw std::logic_error(msg);
} }
} }
UDQSet operator+(const UDQSet&lhs, const UDQSet& rhs) { UDQSet operator+(const UDQSet&lhs, const UDQSet& rhs) {
UDQSet sum = udq_cast(lhs, rhs); auto [left,right] = udq_cast(lhs, rhs);
sum += rhs; left += right;
return sum; return left;
} }
UDQSet operator+(const UDQSet&lhs, double rhs) { UDQSet operator+(const UDQSet&lhs, double rhs) {
@ -524,9 +546,9 @@ UDQSet operator+(double lhs, const UDQSet& rhs) {
} }
UDQSet operator-(const UDQSet&lhs, const UDQSet& rhs) { UDQSet operator-(const UDQSet&lhs, const UDQSet& rhs) {
UDQSet diff = udq_cast(lhs, rhs); auto [left,right] = udq_cast(lhs, rhs);
diff -= rhs; left -= right;
return diff; return left;
} }
UDQSet operator-(const UDQSet&lhs, double rhs) { UDQSet operator-(const UDQSet&lhs, double rhs) {
@ -542,9 +564,9 @@ UDQSet operator-(double lhs, const UDQSet& rhs) {
} }
UDQSet operator*(const UDQSet&lhs, const UDQSet& rhs) { UDQSet operator*(const UDQSet&lhs, const UDQSet& rhs) {
UDQSet prod = udq_cast(lhs, rhs); auto [left,right] = udq_cast(lhs, rhs);
prod *= rhs; left *= right;
return prod; return left;
} }
UDQSet operator*(const UDQSet&lhs, double rhs) { UDQSet operator*(const UDQSet&lhs, double rhs) {
@ -560,9 +582,9 @@ UDQSet operator*(double lhs, const UDQSet& rhs) {
} }
UDQSet operator/(const UDQSet&lhs, const UDQSet& rhs) { UDQSet operator/(const UDQSet&lhs, const UDQSet& rhs) {
UDQSet frac = udq_cast(lhs, rhs); auto [left,right] = udq_cast(lhs, rhs);
frac /= rhs; left /= right;
return frac; return left;
} }
UDQSet operator/(const UDQSet&lhs, double rhs) { UDQSet operator/(const UDQSet&lhs, double rhs) {

View File

@ -2484,3 +2484,48 @@ TSTEP
BOOST_CHECK( !udq_state.define(def.keyword(), def.status())); BOOST_CHECK( !udq_state.define(def.keyword(), def.status()));
} }
} }
BOOST_AUTO_TEST_CASE(UDQ_TYPE_CAST) {
std::string valid = R"(
SCHEDULE
UDQ
ASSIGN FUBHPP1 100 /
/
TSTEP
10 /
UDQ
DEFINE FU_TIME TIME /
DEFINE WUDELTA WBHP '*' - FUBHPP1 /
DEFINE WU_TEST WUBHPINI '*' - (WGPR '*')/2000.0 /
/
)";
auto schedule = make_schedule(valid);
UDQState udq_state(0);
SummaryState st(std::chrono::system_clock::now());
UDQFunctionTable udqft;
UDQContext context(udqft, WellMatcher({"W1", "W2", "W3"}), st, udq_state);
st.update_well_var("W1", "WBHP", 400);
st.update_well_var("W2", "WBHP", 300);
st.update_well_var("W3", "WBHP", 200);
const auto& udq = schedule.getUDQConfig(1);
{
const auto& ass = udq.assign("FUBHPP1");
context.update_assign(1, "FUBHPP1", ass.eval());
}
const auto& def = udq.define("WUDELTA");
auto res = def.eval(context);
BOOST_CHECK_EQUAL(res["W1"].get(), 300);
BOOST_CHECK_EQUAL(res["W2"].get(), 200);
BOOST_CHECK_EQUAL(res["W3"].get(), 100);
}