Add MPI Support to Saturation Function Consistency Checks

This commit adds a new public member function

    SatfuncConsistencyChecks<>::collectFailures(root, comm)

which aggregates consistency check violations from all ranks in the
MPI communication object 'comm' onto rank 'root' of 'comm'.  This
amounts to summing the total number of violations from all ranks and
potentially resampling the failure points for reporting purposes.

To this end, extract the body of the function processViolation()
into a general helper which performs reservoir sampling, records
point IDs, and uses a call-back function to populate the check
values associated to a single failed check.  Re-implement the original
function in terms of this helper by wrapping exportCheckValues() in
a lambda function.  Extract similar helpers for numPoints() and
anyFailedChecks(), and add a new helper function

    SatfuncConsistencyChecks<>::incorporateRankViolations()

which brings sampled points from an MPI rank into the 'root' rank's
internal data structures.

One caveat applies here.  Our current approach to collecting check
failures implies that calling member function reportFailures() is
safe only on the 'root' process in a parallel run.  On the other
hand, the functions anyFailedChecks() and anyFailedCriticalChecks() are
safe, and guaranteed to return the same answer, on all MPI ranks.

On a final note, the internal helper functions are at present mostly
implemented in terms of non-owning pointers.  I intend to switch to
using 'std::span<>' once we enable C++20 mode.
Bård Skaflestad 2024-06-25 18:57:09 +02:00
parent ce7d415e4d
commit 0c71d0701c
6 changed files with 2449 additions and 16 deletions


@@ -542,6 +542,45 @@ opm_add_test(test_parallel_region_phase_pvaverage_np4
4
)
opm_add_test(test_parallel_satfunc_consistency_checks_np2
EXE_NAME
test_SatfuncConsistencyChecks_parallel
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 2
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
2
)
opm_add_test(test_parallel_satfunc_consistency_checks_np3
EXE_NAME
test_SatfuncConsistencyChecks_parallel
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 3
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
3
)
opm_add_test(test_parallel_satfunc_consistency_checks_np4
EXE_NAME
test_SatfuncConsistencyChecks_parallel
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 4
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
4
)
opm_add_test(test_broadcast
DEPENDS "opmsimulators"
LIBRARIES opmsimulators ${Boost_UNIT_TEST_FRAMEWORK_LIBRARY}


@@ -340,6 +340,7 @@ if (HAVE_ECL_INPUT)
list(APPEND TEST_SOURCE_FILES
tests/test_nonnc.cpp
tests/test_SatfuncConsistencyChecks.cpp
tests/test_SatfuncConsistencyChecks_parallel.cpp
)
endif()


@@ -21,6 +21,10 @@
#include <opm/simulators/utils/satfunc/SatfuncConsistencyChecks.hpp>
#include <opm/simulators/utils/ParallelCommunication.hpp>
#include <opm/grid/common/CommunicationUtils.hpp>
#include <opm/material/fluidmatrixinteractions/EclEpsScalingPoints.hpp>
#include <algorithm>
@@ -141,6 +145,22 @@ checkEndpoints(const std::size_t pointID,
});
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
collectFailures(const int root,
const Parallel::Communication& comm)
{
if (comm.size() == 1) {
// Not a parallel run. Violation structure complete without
// exchanging additional information, so nothing to do.
return;
}
for (auto& violation : this->violations_) {
this->collectFailures(root, comm, violation);
}
}
template <typename Scalar>
bool Opm::SatfuncConsistencyChecks<Scalar>::anyFailedChecks() const
{
@@ -194,9 +214,77 @@ void Opm::SatfuncConsistencyChecks<Scalar>::ViolationSample::clear()
// ---------------------------------------------------------------------------
namespace {
bool anyFailedChecks(const std::vector<std::size_t>& count)
{
return std::any_of(count.begin(), count.end(),
[](const std::size_t n) { return n > 0; });
}
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
collectFailures(const int root,
const Parallel::Communication& comm,
ViolationSample& violation)
{
// Count total number of violations of each check across all ranks.
// This should be the final number emitted in reportFailures() on the
// root process.
auto totalCount = violation.count;
comm.sum(totalCount.data(), violation.count.size());
if (! ::anyFailedChecks(totalCount)) {
// No failed checks on any rank for this severity level.
//
// No additional work needed, since every rank will have zero
// failure counts for all checks.
return;
}
// CSR-like structures for the failure counts, sampled point IDs, and
// sampled check values from all ranks. One set of all-to-one messages
// for each quantity. If this stage becomes a bottleneck we must devise
// a better communication structure that reduces the number of messages.
const auto& [rankCount, startRankCount] =
gatherv(violation.count, comm, root);
const auto& [rankPointID, startRankPointID] =
gatherv(violation.pointID, comm, root);
const auto& [rankCheckValues, startRankCheckValues] =
gatherv(violation.checkValues, comm, root);
if (comm.rank() == root) {
// Re-initialise this violation sample to prepare for incorporating
// contributions from all MPI ranks--including the current rank.
violation.clear();
this->buildStructure(violation);
const auto numRanks = comm.size();
for (auto rank = 0*numRanks; rank < numRanks; ++rank) {
this->incorporateRankViolations
(rankCount.data() + startRankCount[rank],
rankPointID.data() + startRankPointID[rank],
rankCheckValues.data() + startRankCheckValues[rank],
violation);
}
}
// The final violation counts for reporting purposes should be the sum
// of the per-rank counts. This ensures that all ranks give the same
// answer to the anyFailedChecks() predicate, although the particular
// sample points will differ across the ranks.
violation.count.swap(totalCount);
// Ensure that all ranks are synchronised here before proceeding. We
// don't want to end up in a situation where the ranks have a different
// notion of what to send/receive.
comm.barrier();
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
buildStructure(ViolationSample& violation)
{
violation.count.assign(this->battery_.size(), 0);
@@ -212,13 +300,13 @@ buildStructure(ViolationSample& violation)
}
template <typename Scalar>
template <typename PopulateCheckValues>
void Opm::SatfuncConsistencyChecks<Scalar>::
processViolation(ViolationSample& violation,
const std::size_t checkIx,
const std::size_t pointID,
PopulateCheckValues&& populateCheckValues)
{
const auto nViol = ++violation.count[checkIx];
// Special case handling for number of violations not exceeding number
@@ -240,12 +328,59 @@ processViolation(const ViolationLevel level,
// reported violations. Record the pointID and the corresponding check
// values in their appropriate locations.
violation.pointID[this->violationPointIDStart(checkIx) + sampleIx] = pointID;
auto* const checkValues = violation.checkValues.data()
+ this->violationValueStart(checkIx, sampleIx);
populateCheckValues(checkValues);
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
processViolation(const ViolationLevel level,
const std::size_t checkIx,
const std::size_t pointID)
{
this->processViolation(this->violations_[this->index(level)], checkIx, pointID,
[this, checkIx](Scalar* const exportedCheckValues)
{
this->battery_[checkIx]->exportCheckValues(exportedCheckValues);
});
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
incorporateRankViolations(const std::size_t* const count,
const std::size_t* const pointID,
const Scalar* const checkValues,
ViolationSample& violation)
{
this->checkLoop([this, count, pointID, checkValues, &violation]
(const Check* currentCheck,
const std::size_t checkIx)
{
if (count[checkIx] == 0) {
// No violations of this check on this rank. Nothing to do.
return;
}
const auto* const srcPointID = pointID
+ this->violationPointIDStart(checkIx);
const auto numCheckValues = currentCheck->numExportedCheckValues();
const auto numSrcSamples = this->numPoints(count[checkIx]);
for (auto srcSampleIx = 0*numSrcSamples; srcSampleIx < numSrcSamples; ++srcSampleIx) {
this->processViolation(violation, checkIx, srcPointID[srcSampleIx],
[numCheckValues,
srcCheckValues = checkValues + this->violationValueStart(checkIx, srcSampleIx)]
(Scalar* const destCheckValues)
{
std::copy_n(srcCheckValues, numCheckValues, destCheckValues);
});
}
});
}
namespace {
@@ -471,8 +606,15 @@ Opm::SatfuncConsistencyChecks<Scalar>::
numPoints(const ViolationSample& violation,
const std::size_t checkIx) const
{
return this->numPoints(violation.count[checkIx]);
}
template <typename Scalar>
std::size_t
Opm::SatfuncConsistencyChecks<Scalar>::
numPoints(const std::size_t violationCount) const
{
return std::min(this->numSamplePoints_, violationCount);
}
template <typename Scalar>
@@ -507,6 +649,14 @@ void Opm::SatfuncConsistencyChecks<Scalar>::ensureRandomBitGeneratorIsInitialise
this->urbg_ = std::make_unique<RandomBitGenerator>(seeds);
}
template <typename Scalar>
std::vector<std::size_t>::size_type
Opm::SatfuncConsistencyChecks<Scalar>::
violationPointIDStart(const std::size_t checkIx) const
{
return checkIx * this->numSamplePoints_;
}
template <typename Scalar>
typename std::vector<Scalar>::size_type
Opm::SatfuncConsistencyChecks<Scalar>::
@@ -522,10 +672,7 @@ bool
Opm::SatfuncConsistencyChecks<Scalar>::
anyFailedChecks(const ViolationLevel level) const
{
return ::anyFailedChecks(this->violations_[this->index(level)].count);
}
template <typename Scalar>


@@ -20,6 +20,8 @@
#ifndef OPM_SATFUNC_CONSISTENCY_CHECK_MODULE_HPP
#define OPM_SATFUNC_CONSISTENCY_CHECK_MODULE_HPP
#include <opm/simulators/utils/ParallelCommunication.hpp>
#include <cstddef>
#include <functional>
#include <memory>
@@ -198,6 +200,19 @@ namespace Opm {
void checkEndpoints(const std::size_t pointID,
const EclEpsScalingPointsInfo<Scalar>& endPoints);
/// Collect consistency violations from all ranks in MPI communicator.
///
/// Incorporates violation counts and sampled failure points into
/// the internal structures on each rank.  The aggregate results are
/// intended for a subsequent call to reportFailures() on the root
/// process.
///
/// \param[in] root MPI root process. This is the process onto
/// which the counts and samples will be collected. Typically
/// the index of the IO rank.
///
/// \param[in] comm MPI communication object.
void collectFailures(int root, const Parallel::Communication& comm);
/// Whether or not any checks failed at the \c Standard level.
bool anyFailedChecks() const;
@@ -210,6 +225,10 @@
/// Reports only those conditions/checks for which there is at least
/// one violation.
///
/// In a parallel run it is only safe to call this function on the
/// MPI process to which the consistency check violations were
/// collected in a previous call to collectFailures().
///
/// \param[in] level Report's severity level.
///
/// \param[in] emitReportRecord Call-back function for outputting a
@@ -299,6 +318,26 @@ namespace Opm {
/// is a common case in production runs.
std::unique_ptr<RandomBitGenerator> urbg_{};
/// Collect violations of single severity level from all ranks in
/// MPI communicator.
///
/// Incorporates violation counts and sampled failure points into
/// the internal structures on each rank.  The aggregate results are
/// intended for a subsequent call to reportFailures().
///
/// \param[in] root MPI root process. This is the process/rank onto
/// which the counts and samples will be collected. Typically
/// the index of the IO rank.
///
/// \param[in] comm MPI communication object.
///
/// \param[in, out] violation Current rank's violation structure for
/// a single severity level. Holds aggregate values across all
/// ranks, including updated sample points, on return.
void collectFailures(int root,
const Parallel::Communication& comm,
ViolationSample& violation);
/// Allocate and initialise backing storage for a single set of
/// sampled consistency check violations.
///
@@ -306,6 +345,44 @@ namespace Opm {
/// violation sample of proper size.
void buildStructure(ViolationSample& violation);
/// Internalise a single violation into internal data structures.
///
/// Counts the violation and uses "reservoir sampling"
/// (https://en.wikipedia.org/wiki/Reservoir_sampling) to determine
/// whether or not to include the specific point into the reporting
/// sample.
///
/// \tparam PopulateCheckValues Call-back function type
/// encapsulating a block of code that populates the sequence of
/// check values for a single failed consistency check.  Expected
/// to be a callable type with a function call operator of the form
/// \code
/// void operator()(Scalar* checkValues) const
/// \endcode
/// in which \c checkValues points to the start of the sequence of
/// values associated to a particular check.  The call-back function
/// is expected to know how many values are in a valid sequence and
/// to fill in exactly this many values.
///
/// \param[in, out] violation Current rank's violation sample at
/// particular severity level.
///
/// \param[in] checkIx Numerical check index in the range
/// [0..battery_.size()).
///
/// \param[in] pointID Numeric identifier for this particular set of
/// end-points. Typically a saturation region or a cell ID.
///
/// \param[in] populateCheckValues Call-back function to populate a
/// sequence of values pertaining to specified check. Typically
/// \code Check::exportCheckValues() \endcode or a copy routine
/// to incorporate samples from multiple MPI ranks.
template <typename PopulateCheckValues>
void processViolation(ViolationSample& violation,
const std::size_t checkIx,
const std::size_t pointID,
PopulateCheckValues&& populateCheckValues);
/// Internalise a single violation into internal data structures.
///
/// Counts the violation and uses "reservoir sampling"
@@ -324,6 +401,24 @@ namespace Opm {
const std::size_t checkIx,
const std::size_t pointID);
/// Incorporate single severity level's set of violations from
/// single MPI rank into current rank's internal data structures.
///
/// \param[in] count Start of sequence of failure counts for all
/// checks from single MPI rank.
///
/// \param[in] pointID Start of sequence of sampled point IDs for
/// all checks from a single MPI rank.
///
/// \param[in] checkValues Start of sequence of sampled check values
/// for all checks from a single MPI rank.
///
/// \param[in, out] violation Current rank's violation sample at a
/// single severity level, updated to incorporate the incoming
/// rank's counts and sampled points.
void incorporateRankViolations(const std::size_t* count,
const std::size_t* pointID,
const Scalar* checkValues,
ViolationSample& violation);
/// Generate random index in the sample size.
///
/// \param[in] sampleSize Total number of violations of a particular
@@ -337,6 +432,16 @@ namespace Opm {
/// initialised.
void ensureRandomBitGeneratorIsInitialised();
/// Compute start offset into ViolationSample::pointID for
/// particular check.
///
/// \param[in] checkIx Numerical check index in the range
/// [0..battery_.size()).
///
/// \return Start offset into ViolationSample::pointID.
std::vector<std::size_t>::size_type
violationPointIDStart(const std::size_t checkIx) const;
/// Compute start offset into ViolationSample::checkValues for
/// particular check and sample index.
///
@@ -442,6 +547,17 @@ namespace Opm {
std::size_t numPoints(const ViolationSample& violation,
const std::size_t checkIx) const;
/// Compute number of sample points for a single check's violations.
///
/// Effectively the minimum of the number of violations of that
/// check and the maximum number of sample points (\code
/// this->numSamplePoints_ \endcode).
///
/// \param[in] violationCount Total number of check violations.
///
/// \return Number of active sample points.
std::size_t numPoints(const std::size_t violationCount) const;
/// Whether or not any checks failed at specified severity level.
///
/// \param[in] level Violation severity level.


@@ -259,6 +259,9 @@ BOOST_AUTO_TEST_CASE(Critical_Violation)
checker.checkEndpoints(42, makePoints());
BOOST_CHECK_MESSAGE(! checker.anyFailedChecks(),
"There must be no failed standard level checks");
BOOST_CHECK_MESSAGE(checker.anyFailedCriticalChecks(),
"There must be at least one failed Critical check");

File diff suppressed because it is too large