Addressing review comments (changed assert to throw, put chunked broadcast in private method)

This commit is contained in:
Vegard Kippe 2024-07-09 19:15:50 +02:00
parent 99d5a147b1
commit fae49636cc
3 changed files with 22 additions and 41 deletions

View File

@ -70,7 +70,7 @@ packSize(const std::string& data, Parallel::MPIComm comm)
MPI_Pack_size(1, Dune::MPITraits<std::size_t>::getType(), comm, &size); MPI_Pack_size(1, Dune::MPITraits<std::size_t>::getType(), comm, &size);
int totalSize = size; int totalSize = size;
MPI_Pack_size(data.size(), MPI_CHAR, comm, &size); MPI_Pack_size(data.size(), MPI_CHAR, comm, &size);
return static_cast<std::size_t>(totalSize + size); return totalSize + size;
} }
void Packing<false,std::string>:: void Packing<false,std::string>::

View File

@ -35,7 +35,9 @@ namespace Mpi {
namespace detail { namespace detail {
static std::size_t mpi_buffer_size(const std::size_t bufsize, const std::size_t position) { static std::size_t mpi_buffer_size(const std::size_t bufsize, const std::size_t position) {
assert (bufsize >= position); if (bufsize < position)
throw std::invalid_argument("Buffer size should never be less than position!");
return static_cast<int>(std::min(bufsize-position, return static_cast<int>(std::min(bufsize-position,
static_cast<std::size_t>(std::numeric_limits<int>::max()))); static_cast<std::size_t>(std::numeric_limits<int>::max())));
} }
@ -69,11 +71,10 @@ struct Packing<true,T>
{ {
// For now we do not handle the situation where a a single call to packSize/pack/unpack // For now we do not handle the situation where a a single call to packSize/pack/unpack
// is likely to require an MPI_Pack_size value larger than intmax // is likely to require an MPI_Pack_size value larger than intmax
assert ( n*sizeof(T) <= std::numeric_limits<int>::max() ); if (n*sizeof(T) > std::numeric_limits<int>::max())
throw std::invalid_argument("packSize will be larger than max integer - this is not supported.");
int size = 0; int size = 0;
MPI_Pack_size(n, Dune::MPITraits<T>::getType(), comm, &size); MPI_Pack_size(n, Dune::MPITraits<T>::getType(), comm, &size);
assert (size >= 0);
assert (size < std::numeric_limits<int>::max() );
return static_cast<std::size_t>(size); return static_cast<std::size_t>(size);
} }

View File

@ -52,15 +52,7 @@ public:
try { try {
this->pack(data); this->pack(data);
m_comm.broadcast(&m_packSize, 1, root); m_comm.broadcast(&m_packSize, 1, root);
const int maxChunkSize = std::numeric_limits<int>::max(); broadcast_chunked(root);
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
} catch (...) { } catch (...) {
m_packSize = std::numeric_limits<size_t>::max(); m_packSize = std::numeric_limits<size_t>::max();
m_comm.broadcast(&m_packSize, 1, root); m_comm.broadcast(&m_packSize, 1, root);
@ -72,15 +64,7 @@ public:
throw std::runtime_error("Error detected in parallel serialization"); throw std::runtime_error("Error detected in parallel serialization");
} }
m_buffer.resize(m_packSize); m_buffer.resize(m_packSize);
const int maxChunkSize = std::numeric_limits<int>::max(); broadcast_chunked(root);
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
this->unpack(data); this->unpack(data);
} }
} }
@ -95,15 +79,7 @@ public:
try { try {
this->pack(std::forward<Args>(args)...); this->pack(std::forward<Args>(args)...);
m_comm.broadcast(&m_packSize, 1, root); m_comm.broadcast(&m_packSize, 1, root);
const int maxChunkSize = std::numeric_limits<int>::max(); broadcast_chunked(root);
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
} catch (...) { } catch (...) {
m_packSize = std::numeric_limits<size_t>::max(); m_packSize = std::numeric_limits<size_t>::max();
m_comm.broadcast(&m_packSize, 1, root); m_comm.broadcast(&m_packSize, 1, root);
@ -115,15 +91,7 @@ public:
throw std::runtime_error("Error detected in parallel serialization"); throw std::runtime_error("Error detected in parallel serialization");
} }
m_buffer.resize(m_packSize); m_buffer.resize(m_packSize);
const int maxChunkSize = std::numeric_limits<int>::max(); broadcast_chunked(root);
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
this->unpack(std::forward<Args>(args)...); this->unpack(std::forward<Args>(args)...);
} }
} }
@ -149,6 +117,18 @@ public:
} }
private: private:
void broadcast_chunked(int root) {
const int maxChunkSize = std::numeric_limits<int>::max();
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
}
const Mpi::Packer m_packer; //!< Packer instance const Mpi::Packer m_packer; //!< Packer instance
Parallel::Communication m_comm; //!< Communicator to use Parallel::Communication m_comm; //!< Communicator to use
}; };