changed: move variadic broadcast into EclMpiSerializer

Arne Morten Kvarving 2022-09-07 13:43:23 +02:00
parent 9e6574115a
commit 6e83e349d6
5 changed files with 56 additions and 66 deletions
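
The diffs below replace the free function Opm::Mpi::broadcast from MPIPacker.hpp (and its variadic_packsize/variadic_pack/variadic_unpack helpers) with a variadic broadcast member on EclMpiSerializer, so packing reuses the serializer's internal buffer and state machine. A sketch of the call-site migration, taken from the test diff at the bottom:

    // before: free function, one helper per phase
    Opm::Mpi::broadcast(comm, /*root=*/0, d, i, d1, i1);

    // after: member function driving the serializer's
    // PACKSIZE/PACK/UNPACK phases over one shared buffer
    Opm::EclMpiSerializer ser(comm);
    ser.broadcast(/*root=*/0, d, i, d1, i1);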

View File

@@ -34,6 +34,8 @@
#include <variant>
#include <vector>
namespace Opm {
namespace detail
{
@@ -66,8 +68,6 @@ using remove_cvr_t = std::remove_const_t<std::remove_reference_t<T>>;
} // namespace detail
namespace Opm {
/*! \brief Class for (de-)serializing and broadcasting data in parallel.
*! \details If the class has a serializeOp member, this is used;
* if not, it is passed on to the underlying primitive serializer.
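
The dispatch described above is the usual OPM pattern: a type opts in by exposing a serializeOp template that feeds each member back into the serializer. A minimal sketch of such a type (the struct and its members are illustrative, not part of this commit):

    #include <vector>

    struct Example {
        double value{};
        std::vector<int> items;

        // Called for all phases alike (PACKSIZE, PACK, UNPACK); the
        // serializer's operator() dispatches on each member's type.
        template<class Serializer>
        void serializeOp(Serializer& serializer)
        {
            serializer(value);
            serializer(items);
        }
    };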
@@ -418,12 +418,49 @@ public:
if (m_packSize == std::numeric_limits<size_t>::max()) {
throw std::runtime_error("Error detected in parallel serialization");
}
m_buffer.resize(m_packSize);
m_comm.broadcast(m_buffer.data(), m_packSize, root);
unpack(data);
}
}
template<typename... Args>
void broadcast(int root, Args&&... args)
{
if (m_comm.size() == 1)
return;
if (m_comm.rank() == root) {
try {
m_op = Operation::PACKSIZE;
m_packSize = 0;
variadic_call(args...);
m_position = 0;
m_buffer.resize(m_packSize);
m_op = Operation::PACK;
variadic_call(args...);
m_packSize = m_position;
m_comm.broadcast(&m_packSize, 1, root);
m_comm.broadcast(m_buffer.data(), m_position, root);
} catch (...) {
m_packSize = std::numeric_limits<size_t>::max();
m_comm.broadcast(&m_packSize, 1, root);
throw;
}
} else {
m_comm.broadcast(&m_packSize, 1, root);
if (m_packSize == std::numeric_limits<size_t>::max()) {
throw std::runtime_error("Error detected in parallel serialization");
}
m_buffer.resize(m_packSize);
m_comm.broadcast(m_buffer.data(), m_packSize, root);
m_position = 0;
m_op = Operation::UNPACK;
variadic_call(std::forward<Args>(args)...);
}
}
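
On the root this runs a PACKSIZE pass to size the buffer and a PACK pass to fill it, then ships size and payload as exactly two MPI broadcasts; std::numeric_limits<size_t>::max() doubles as an error sentinel so that an exception thrown while packing on the root also unwinds the receiving ranks. Since all arguments share one buffer, a group of objects costs two broadcasts in total instead of two per object. A usage sketch, assuming the single-object overload keeps its existing signature:

    EclMpiSerializer ser(comm);

    // one buffer, two MPI broadcasts for the whole argument list:
    ser.broadcast(0, eclState, schedule, summaryConfig);

    // versus two broadcasts (size + payload) per object:
    ser.broadcast(eclState);
    ser.broadcast(schedule);
    ser.broadcast(summaryConfig);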
//! \brief Serialize and broadcast on root process, de-serialize and append on
//! others.
//!
@@ -457,6 +494,15 @@ public:
}
protected:
template<typename T, typename... Args>
void variadic_call(T& first,
Args&&... args)
{
(*this)(first);
if constexpr (sizeof...(args) > 0)
variadic_call(std::forward<Args>(args)...);
}
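
variadic_call simply applies the serializer's operator() to each argument in order. A C++17 fold expression over the comma operator is an equivalent, non-recursive formulation; a sketch of the alternative, not what the commit does:

    template<typename... Args>
    void variadic_call(Args&&... args)
    {
        // comma folds evaluate left to right, so the arguments are
        // visited in the same order as in the recursion above
        ((*this)(std::forward<Args>(args)), ...);
    }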
//! \brief Enumeration of operations.
enum class Operation {
PACKSIZE, //!< Calculating serialization buffer size

View File

@@ -237,59 +237,6 @@ void unpack(std::bitset<Size>& data, std::vector<char>& buffer, int& position,
ADD_PACK_PROTOTYPES(std::string)
ADD_PACK_PROTOTYPES(time_point)
template<typename T, typename... Args>
void variadic_packsize(size_t& size, Parallel::Communication comm, T& first, Args&&... args)
{
size += packSize(first, comm);
if constexpr (sizeof...(args) > 0)
variadic_packsize(size, comm, std::forward<Args>(args)...);
}
template<typename T, typename... Args>
void variadic_pack(int& pos, std::vector<char>& buffer, Parallel::Communication comm, T& first, Args&&... args)
{
pack(first, buffer, pos, comm);
if constexpr (sizeof...(args) > 0)
variadic_pack(pos, buffer, comm, std::forward<Args>(args)...);
}
template<typename T, typename... Args>
void variadic_unpack(int& pos, std::vector<char>& buffer, Parallel::Communication comm, T& first, Args&&... args)
{
unpack(first, buffer, pos, comm);
if constexpr (sizeof...(args) > 0)
variadic_unpack(pos, buffer, comm, std::forward<Args>(args)...);
}
#if HAVE_MPI
template<typename... Args>
void broadcast(Parallel::Communication comm, int root, Args&&... args)
{
if (comm.size() == 1)
return;
size_t size = 0;
if (comm.rank() == root)
variadic_packsize(size, comm, args...);
comm.broadcast(&size, 1, root);
std::vector<char> buffer(size);
if (comm.rank() == root) {
int pos = 0;
variadic_pack(pos, buffer, comm, args...);
}
comm.broadcast(buffer.data(), size, root);
if (comm.rank() != root) {
int pos = 0;
variadic_unpack(pos, buffer, comm, std::forward<Args>(args)...);
}
}
#else
template<typename... Args>
void broadcast(Parallel::Communication, int, Args&&...)
{}
#endif
} // end namespace Mpi
} // end namespace Opm

View File

@@ -42,7 +42,6 @@
namespace Opm {
void eclStateBroadcast(Parallel::Communication comm, EclipseState& eclState, Schedule& schedule,
SummaryConfig& summaryConfig,
UDQState& udqState,
@@ -50,12 +49,7 @@ void eclStateBroadcast(Parallel::Communication comm, EclipseState& eclState, Sch
WellTestState& wtestState)
{
Opm::EclMpiSerializer ser(comm);
ser.broadcast(eclState);
ser.broadcast(schedule);
ser.broadcast(summaryConfig);
ser.broadcast(udqState);
ser.broadcast(actionState);
ser.broadcast(wtestState);
ser.broadcast(0, eclState, schedule, summaryConfig, udqState, actionState, wtestState);
}
template <class T>
@@ -65,7 +59,6 @@ void eclBroadcast(Parallel::Communication comm, T& data)
ser.broadcast(data);
}
template void eclBroadcast<TransMult>(Parallel::Communication, TransMult&);
template void eclBroadcast<Schedule>(Parallel::Communication, Schedule&);

View File

@@ -27,6 +27,7 @@
#include <opm/simulators/wells/VFPProperties.hpp>
#include <opm/simulators/utils/MPIPacker.hpp>
#include <ebos/eclmpiserializer.hh>
#include <algorithm>
#include <utility>
@@ -1058,8 +1059,9 @@ namespace Opm {
// data if they are going to check the group rates in stage1
// Another similar idea is to only communicate the rates to
// process j = i + 1
Mpi::broadcast(comm, i, group_indexes, group_oil_rates,
group_gas_rates, group_water_rates, group_alq_rates);
EclMpiSerializer ser(comm);
ser.broadcast(i, group_indexes, group_oil_rates,
group_gas_rates, group_water_rates, group_alq_rates);
if (comm.rank() != i) {
for (int j=0; j<num_rates_to_sync; j++) {
group_info.updateRate(group_indexes[j],

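Here the broadcast root is the loop index i, so every rank takes one turn as sender: rank i packs the rates it computed locally, the others unpack them into the same vectors, and after the loop each rank holds every rank's rates. A condensed sketch of the pattern around the hunk above (loop structure assumed from the surrounding code, not shown in this diff):

    for (int i = 0; i < comm.size(); ++i) {
        // rank i is the root; all other ranks receive and unpack
        EclMpiSerializer ser(comm);
        ser.broadcast(i, group_indexes, group_oil_rates,
                      group_gas_rates, group_water_rates, group_alq_rates);
    }
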
View File

@@ -26,6 +26,7 @@
#include <boost/test/unit_test.hpp>
#include <opm/simulators/utils/MPIPacker.hpp>
#include <ebos/eclmpiserializer.hh>
#include <dune/common/parallel/mpihelper.hh>
#include <numeric>
@@ -70,7 +71,8 @@ BOOST_AUTO_TEST_CASE(BroadCast)
double d1 = cc.rank() == 0 ? 7.0 : 0.0;
size_t i1 = cc.rank() == 0 ? 8 : 0;
Opm::Mpi::broadcast(cc, 0, d, i, d1, i1);
Opm::EclMpiSerializer ser(cc);
ser.broadcast(0, d, i, d1, i1);
for (size_t c = 0; c < 3; ++c) {
BOOST_CHECK_EQUAL(d[c], 1.0+c);