Rewrite communicateGroupRates() to do a single sum().

This commit is contained in:
Atgeirr Flø Rasmussen
2020-09-17 17:39:57 +02:00
parent 6a592a8c55
commit 8dffd39b0a

View File

@@ -1007,29 +1007,73 @@ namespace Opm
}
template<class Comm>
void communicateGroupRates(const Comm& comm) {
// sum over all nodes
for (auto& x : injection_group_rein_rates) {
comm.sum(x.second.data(), x.second.size());
void communicateGroupRates(const Comm& comm)
{
// Note that injection_group_vrep_rates is handled separate from
// the forAllGroupData() function, since it contains single doubles,
// not vectors.
// Create a function that calls some function
// for all the individual data items to simplify
// the further code.
auto forAllGroupData = [&](auto& func) {
for (auto& x : injection_group_rein_rates) {
func(x.second);
}
for (auto& x : production_group_reduction_rates) {
func(x.second);
}
for (auto& x : injection_group_reduction_rates) {
func(x.second);
}
for (auto& x : injection_group_reservoir_rates) {
func(x.second);
}
for (auto& x : production_group_rates) {
func(x.second);
}
for (auto& x : well_rates) {
func(x.second);
}
};
// Compute the size of the data.
std::size_t sz = 0;
auto computeSize = [&sz](const auto& v) {
sz += v.size();
};
forAllGroupData(computeSize);
sz += injection_group_vrep_rates.size();
// Make a vector and collect all data into it.
std::vector<double> data(sz);
std::size_t pos = 0;
auto collect = [&data, &pos](const auto& v) {
for (const auto& x : v) {
data[pos++] = x;
}
};
forAllGroupData(collect);
for (const auto& x : injection_group_vrep_rates) {
data[pos++] = x.second;
}
assert(pos == sz);
// Communicate it with a single sum() call.
comm.sum(data.data(), data.size());
// Distribute the summed vector to the data items.
pos = 0;
auto distribute = [&data, &pos](auto& v) {
for (auto& x : v) {
x = data[pos++];
}
};
forAllGroupData(distribute);
for (auto& x : injection_group_vrep_rates) {
x.second = comm.sum(x.second);
}
for (auto& x : production_group_reduction_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : injection_group_reduction_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : injection_group_reservoir_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : production_group_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : well_rates) {
comm.sum(x.second.data(), x.second.size());
x.second = data[pos++];
}
assert(pos == sz);
}
template<class Comm>