Addressing review comments (changed assert to throw, put chunked broadcast in private method)

This commit is contained in:
Vegard Kippe 2024-07-09 19:15:50 +02:00
parent 99d5a147b1
commit fae49636cc
3 changed files with 22 additions and 41 deletions

View File

@ -70,7 +70,7 @@ packSize(const std::string& data, Parallel::MPIComm comm)
MPI_Pack_size(1, Dune::MPITraits<std::size_t>::getType(), comm, &size);
int totalSize = size;
MPI_Pack_size(data.size(), MPI_CHAR, comm, &size);
return static_cast<std::size_t>(totalSize + size);
return totalSize + size;
}
void Packing<false,std::string>::

View File

@ -35,7 +35,9 @@ namespace Mpi {
namespace detail {
static std::size_t mpi_buffer_size(const std::size_t bufsize, const std::size_t position) {
assert (bufsize >= position);
if (bufsize < position)
throw std::invalid_argument("Buffer size should never be less than position!");
return static_cast<int>(std::min(bufsize-position,
static_cast<std::size_t>(std::numeric_limits<int>::max())));
}
@ -69,11 +71,10 @@ struct Packing<true,T>
{
// For now we do not handle the situation where a a single call to packSize/pack/unpack
// is likely to require an MPI_Pack_size value larger than intmax
assert ( n*sizeof(T) <= std::numeric_limits<int>::max() );
if (n*sizeof(T) > std::numeric_limits<int>::max())
throw std::invalid_argument("packSize will be larger than max integer - this is not supported.");
int size = 0;
MPI_Pack_size(n, Dune::MPITraits<T>::getType(), comm, &size);
assert (size >= 0);
assert (size < std::numeric_limits<int>::max() );
return static_cast<std::size_t>(size);
}

View File

@ -52,15 +52,7 @@ public:
try {
this->pack(data);
m_comm.broadcast(&m_packSize, 1, root);
const int maxChunkSize = std::numeric_limits<int>::max();
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
broadcast_chunked(root);
} catch (...) {
m_packSize = std::numeric_limits<size_t>::max();
m_comm.broadcast(&m_packSize, 1, root);
@ -72,15 +64,7 @@ public:
throw std::runtime_error("Error detected in parallel serialization");
}
m_buffer.resize(m_packSize);
const int maxChunkSize = std::numeric_limits<int>::max();
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
broadcast_chunked(root);
this->unpack(data);
}
}
@ -95,15 +79,7 @@ public:
try {
this->pack(std::forward<Args>(args)...);
m_comm.broadcast(&m_packSize, 1, root);
const int maxChunkSize = std::numeric_limits<int>::max();
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
broadcast_chunked(root);
} catch (...) {
m_packSize = std::numeric_limits<size_t>::max();
m_comm.broadcast(&m_packSize, 1, root);
@ -115,15 +91,7 @@ public:
throw std::runtime_error("Error detected in parallel serialization");
}
m_buffer.resize(m_packSize);
const int maxChunkSize = std::numeric_limits<int>::max();
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
broadcast_chunked(root);
this->unpack(std::forward<Args>(args)...);
}
}
@ -149,6 +117,18 @@ public:
}
private:
void broadcast_chunked(int root) {
const int maxChunkSize = std::numeric_limits<int>::max();
std::size_t remainingSize = m_packSize;
std::size_t pos = 0;
while (remainingSize > maxChunkSize) {
m_comm.broadcast(m_buffer.data()+pos, maxChunkSize, root);
pos += maxChunkSize;
remainingSize -= maxChunkSize;
}
m_comm.broadcast(m_buffer.data()+pos, static_cast<int>(remainingSize), root);
}
const Mpi::Packer m_packer; //!< Packer instance
Parallel::Communication m_comm; //!< Communicator to use
};