Merge pull request #2792 from atgeirr/fewer-mpi-sums

Rewrite communicateGroupRates() to do a single sum().
This commit is contained in:
Markus Blatt 2020-09-22 07:51:34 +02:00 committed by GitHub
commit a4ea6e9658
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1007,29 +1007,73 @@ namespace Opm
}
template<class Comm>
void communicateGroupRates(const Comm& comm) {
// sum over all nodes
for (auto& x : injection_group_rein_rates) {
comm.sum(x.second.data(), x.second.size());
void communicateGroupRates(const Comm& comm)
{
// Note that injection_group_vrep_rates is handled separate from
// the forAllGroupData() function, since it contains single doubles,
// not vectors.
// Create a function that calls some function
// for all the individual data items to simplify
// the further code.
auto forAllGroupData = [&](auto& func) {
for (auto& x : injection_group_rein_rates) {
func(x.second);
}
for (auto& x : production_group_reduction_rates) {
func(x.second);
}
for (auto& x : injection_group_reduction_rates) {
func(x.second);
}
for (auto& x : injection_group_reservoir_rates) {
func(x.second);
}
for (auto& x : production_group_rates) {
func(x.second);
}
for (auto& x : well_rates) {
func(x.second);
}
};
// Compute the size of the data.
std::size_t sz = 0;
auto computeSize = [&sz](const auto& v) {
sz += v.size();
};
forAllGroupData(computeSize);
sz += injection_group_vrep_rates.size();
// Make a vector and collect all data into it.
std::vector<double> data(sz);
std::size_t pos = 0;
auto collect = [&data, &pos](const auto& v) {
for (const auto& x : v) {
data[pos++] = x;
}
};
forAllGroupData(collect);
for (const auto& x : injection_group_vrep_rates) {
data[pos++] = x.second;
}
assert(pos == sz);
// Communicate it with a single sum() call.
comm.sum(data.data(), data.size());
// Distribute the summed vector to the data items.
pos = 0;
auto distribute = [&data, &pos](auto& v) {
for (auto& x : v) {
x = data[pos++];
}
};
forAllGroupData(distribute);
for (auto& x : injection_group_vrep_rates) {
x.second = comm.sum(x.second);
}
for (auto& x : production_group_reduction_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : injection_group_reduction_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : injection_group_reservoir_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : production_group_rates) {
comm.sum(x.second.data(), x.second.size());
}
for (auto& x : well_rates) {
comm.sum(x.second.data(), x.second.size());
x.second = data[pos++];
}
assert(pos == sz);
}
template<class Comm>