mirror of
https://github.com/OPM/opm-simulators.git
synced 2024-07-04 11:33:06 -05:00
Add MPI Support to Saturation Function Consistency Checks
This commit adds a new public member function SatfuncConsistencyChecks<>::collectFailures(root, comm) which aggregates consistency check violations from all ranks in the MPI communication object 'comm' onto rank 'root' of 'comm'. This amounts to summing the total number of violations from all ranks and potentially resampling the failure points for reporting purposes. To this end, extract the body of function processViolation() into a general helper which performs reservoir sampling and records point IDs and which uses a call-back function to populate the check values associated to a single failed check. Re-implement the original function in terms of this helper by wrapping exportCheckValues() in a lambda function. Extract similar helpers for numPoints() and anyFailedChecks(), and add a new helper function SatfuncConsistencyChecks<>::incorporateRankViolations() which brings sampled points from an MPI rank into the 'root's internal data structures. One caveat applies here. Our current approach to collecting check failures implies that calling member function reportFailures() is safe only on the 'root' process in a parallel run. On the other hand functions anyFailedChecks() and anyFailedCriticalChecks() are safe, and guaranteed to return the same answer, on all MPI ranks. On a final note, the internal helper functions are at present mostly implemented in terms of non-owning pointers. I intend to switch to using 'std::span<>' once we enable C++20 mode.
This commit is contained in:
parent
ce7d415e4d
commit
0c71d0701c
|
@ -542,6 +542,45 @@ opm_add_test(test_parallel_region_phase_pvaverage_np4
|
|||
4
|
||||
)
|
||||
|
||||
opm_add_test(test_parallel_satfunc_consistency_checks_np2
|
||||
EXE_NAME
|
||||
test_SatfuncConsistencyChecks_parallel
|
||||
CONDITION
|
||||
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
|
||||
DRIVER_ARGS
|
||||
-n 2
|
||||
-b ${PROJECT_BINARY_DIR}
|
||||
NO_COMPILE
|
||||
PROCESSORS
|
||||
2
|
||||
)
|
||||
|
||||
opm_add_test(test_parallel_satfunc_consistency_checks_np3
|
||||
EXE_NAME
|
||||
test_SatfuncConsistencyChecks_parallel
|
||||
CONDITION
|
||||
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
|
||||
DRIVER_ARGS
|
||||
-n 3
|
||||
-b ${PROJECT_BINARY_DIR}
|
||||
NO_COMPILE
|
||||
PROCESSORS
|
||||
3
|
||||
)
|
||||
|
||||
opm_add_test(test_parallel_satfunc_consistency_checks_np4
|
||||
EXE_NAME
|
||||
test_SatfuncConsistencyChecks_parallel
|
||||
CONDITION
|
||||
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
|
||||
DRIVER_ARGS
|
||||
-n 4
|
||||
-b ${PROJECT_BINARY_DIR}
|
||||
NO_COMPILE
|
||||
PROCESSORS
|
||||
4
|
||||
)
|
||||
|
||||
opm_add_test(test_broadcast
|
||||
DEPENDS "opmsimulators"
|
||||
LIBRARIES opmsimulators ${Boost_UNIT_TEST_FRAMEWORK_LIBRARY}
|
||||
|
|
|
@ -340,6 +340,7 @@ if (HAVE_ECL_INPUT)
|
|||
list(APPEND TEST_SOURCE_FILES
|
||||
tests/test_nonnc.cpp
|
||||
tests/test_SatfuncConsistencyChecks.cpp
|
||||
tests/test_SatfuncConsistencyChecks_parallel.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
|
|
|
@ -21,6 +21,10 @@
|
|||
|
||||
#include <opm/simulators/utils/satfunc/SatfuncConsistencyChecks.hpp>
|
||||
|
||||
#include <opm/simulators/utils/ParallelCommunication.hpp>
|
||||
|
||||
#include <opm/grid/common/CommunicationUtils.hpp>
|
||||
|
||||
#include <opm/material/fluidmatrixinteractions/EclEpsScalingPoints.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -141,6 +145,22 @@ checkEndpoints(const std::size_t pointID,
|
|||
});
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
collectFailures(const int root,
|
||||
const Parallel::Communication& comm)
|
||||
{
|
||||
if (comm.size() == 1) {
|
||||
// Not a parallel run. Violation structure complete without
|
||||
// exchanging additional information, so nothing to do.
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto& violation : this->violations_) {
|
||||
this->collectFailures(root, comm, violation);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
bool Opm::SatfuncConsistencyChecks<Scalar>::anyFailedChecks() const
|
||||
{
|
||||
|
@ -194,9 +214,77 @@ void Opm::SatfuncConsistencyChecks<Scalar>::ViolationSample::clear()
|
|||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
bool anyFailedChecks(const std::vector<std::size_t>& count)
|
||||
{
|
||||
return std::any_of(count.begin(), count.end(),
|
||||
[](const std::size_t n) { return n > 0; });
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void
|
||||
Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
void Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
collectFailures(const int root,
|
||||
const Parallel::Communication& comm,
|
||||
ViolationSample& violation)
|
||||
{
|
||||
// Count total number of violations of each check across all ranks.
|
||||
// This should be the final number emitted in reportFailures() on the
|
||||
// root process.
|
||||
auto totalCount = violation.count;
|
||||
comm.sum(totalCount.data(), violation.count.size());
|
||||
|
||||
if (! ::anyFailedChecks(totalCount)) {
|
||||
// No failed checks on any rank for this severity level.
|
||||
//
|
||||
// No additional work needed, since every rank will have zero
|
||||
// failure counts for all checks.
|
||||
return;
|
||||
}
|
||||
|
||||
// CSR-like structures for the failure counts, sampled point IDs, and
|
||||
// sampled check values from all ranks. One set of all-to-one messages
|
||||
// for each quantity. If this stage becomes a bottleneck we must devise
|
||||
// a better communication structure that reduces the number of messages.
|
||||
const auto& [rankCount, startRankCount] =
|
||||
gatherv(violation.count, comm, root);
|
||||
|
||||
const auto& [rankPointID, startRankPointID] =
|
||||
gatherv(violation.pointID, comm, root);
|
||||
|
||||
const auto& [rankCheckValues, startRankCheckValues] =
|
||||
gatherv(violation.checkValues, comm, root);
|
||||
|
||||
if (comm.rank() == root) {
|
||||
// Re-initialise this violation sample to prepare for incorporating
|
||||
// contributions from all MPI ranks--including the current rank.
|
||||
violation.clear();
|
||||
this->buildStructure(violation);
|
||||
|
||||
const auto numRanks = comm.size();
|
||||
for (auto rank = 0*numRanks; rank < numRanks; ++rank) {
|
||||
this->incorporateRankViolations
|
||||
(rankCount.data() + startRankCount[rank],
|
||||
rankPointID.data() + startRankPointID[rank],
|
||||
rankCheckValues.data() + startRankCheckValues[rank],
|
||||
violation);
|
||||
}
|
||||
}
|
||||
|
||||
// The final violation counts for reporting purposes should be the sum
|
||||
// of the per-rank counts. This ensures that all ranks give the same
|
||||
// answer to the anyFailedChecks() predicate, although the particular
|
||||
// sample points will differ across the ranks.
|
||||
violation.count.swap(totalCount);
|
||||
|
||||
// Ensure that all ranks are synchronised here before proceeding. We
|
||||
// don't want to end up in a situation where the ranks have a different
|
||||
// notion of what to send/receive.
|
||||
comm.barrier();
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
buildStructure(ViolationSample& violation)
|
||||
{
|
||||
violation.count.assign(this->battery_.size(), 0);
|
||||
|
@ -212,13 +300,13 @@ buildStructure(ViolationSample& violation)
|
|||
}
|
||||
|
||||
template <typename Scalar>
|
||||
template <typename PopulateCheckValues>
|
||||
void Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
processViolation(const ViolationLevel level,
|
||||
processViolation(ViolationSample& violation,
|
||||
const std::size_t checkIx,
|
||||
const std::size_t pointID)
|
||||
const std::size_t pointID,
|
||||
PopulateCheckValues&& populateCheckValues)
|
||||
{
|
||||
auto& violation = this->violations_[this->index(level)];
|
||||
|
||||
const auto nViol = ++violation.count[checkIx];
|
||||
|
||||
// Special case handling for number of violations not exceeding number
|
||||
|
@ -240,12 +328,59 @@ processViolation(const ViolationLevel level,
|
|||
// reported violations. Record the pointID and the corresponding check
|
||||
// values in their appropriate locations.
|
||||
|
||||
violation.pointID[checkIx*this->numSamplePoints_ + sampleIx] = pointID;
|
||||
violation.pointID[this->violationPointIDStart(checkIx) + sampleIx] = pointID;
|
||||
|
||||
auto* exportedCheckValues = violation.checkValues.data()
|
||||
auto* const checkValues = violation.checkValues.data()
|
||||
+ this->violationValueStart(checkIx, sampleIx);
|
||||
|
||||
populateCheckValues(checkValues);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
processViolation(const ViolationLevel level,
|
||||
const std::size_t checkIx,
|
||||
const std::size_t pointID)
|
||||
{
|
||||
this->processViolation(this->violations_[this->index(level)], checkIx, pointID,
|
||||
[this, checkIx](Scalar* const exportedCheckValues)
|
||||
{
|
||||
this->battery_[checkIx]->exportCheckValues(exportedCheckValues);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
incorporateRankViolations(const std::size_t* const count,
|
||||
const std::size_t* const pointID,
|
||||
const Scalar* const checkValues,
|
||||
ViolationSample& violation)
|
||||
{
|
||||
this->checkLoop([this, count, pointID, checkValues, &violation]
|
||||
(const Check* currentCheck,
|
||||
const std::size_t checkIx)
|
||||
{
|
||||
if (count[checkIx] == 0) {
|
||||
// No violations of this check on this rank. Nothing to do.
|
||||
return;
|
||||
}
|
||||
|
||||
const auto* const srcPointID = pointID
|
||||
+ this->violationPointIDStart(checkIx);
|
||||
|
||||
const auto numCheckValues = currentCheck->numExportedCheckValues();
|
||||
const auto numSrcSamples = this->numPoints(count[checkIx]);
|
||||
|
||||
for (auto srcSampleIx = 0*numSrcSamples; srcSampleIx < numSrcSamples; ++srcSampleIx) {
|
||||
this->processViolation(violation, checkIx, srcPointID[srcSampleIx],
|
||||
[numCheckValues,
|
||||
srcCheckValues = checkValues + this->violationValueStart(checkIx, srcSampleIx)]
|
||||
(Scalar* const destCheckValues)
|
||||
{
|
||||
std::copy_n(srcCheckValues, numCheckValues, destCheckValues);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -471,8 +606,15 @@ Opm::SatfuncConsistencyChecks<Scalar>::
|
|||
numPoints(const ViolationSample& violation,
|
||||
const std::size_t checkIx) const
|
||||
{
|
||||
return std::min(this->numSamplePoints_,
|
||||
violation.count[checkIx]);
|
||||
return this->numPoints(violation.count[checkIx]);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
std::size_t
|
||||
Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
numPoints(const std::size_t violationCount) const
|
||||
{
|
||||
return std::min(this->numSamplePoints_, violationCount);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
|
@ -507,6 +649,14 @@ void Opm::SatfuncConsistencyChecks<Scalar>::ensureRandomBitGeneratorIsInitialise
|
|||
this->urbg_ = std::make_unique<RandomBitGenerator>(seeds);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
std::vector<std::size_t>::size_type
|
||||
Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
violationPointIDStart(const std::size_t checkIx) const
|
||||
{
|
||||
return checkIx * this->numSamplePoints_;
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
typename std::vector<Scalar>::size_type
|
||||
Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
|
@ -522,10 +672,7 @@ bool
|
|||
Opm::SatfuncConsistencyChecks<Scalar>::
|
||||
anyFailedChecks(const ViolationLevel level) const
|
||||
{
|
||||
const auto& violation = this->violations_[this->index(level)];
|
||||
|
||||
return std::any_of(violation.count.begin(), violation.count.end(),
|
||||
[](const std::size_t n) { return n > 0; });
|
||||
return ::anyFailedChecks(this->violations_[this->index(level)].count);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#ifndef OPM_SATFUNC_CONSISTENCY_CHECK_MODULE_HPP
|
||||
#define OPM_SATFUNC_CONSISTENCY_CHECK_MODULE_HPP
|
||||
|
||||
#include <opm/simulators/utils/ParallelCommunication.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
@ -198,6 +200,19 @@ namespace Opm {
|
|||
void checkEndpoints(const std::size_t pointID,
|
||||
const EclEpsScalingPointsInfo<Scalar>& endPoints);
|
||||
|
||||
/// Collect consistency violations from all ranks in MPI communicator.
|
||||
///
|
||||
/// Incorporates violation counts and sampled failure points into
|
||||
/// the internal structures on each rank. Aggregate results useful
|
||||
/// for subsequent call to reportFailures() on root process.
|
||||
///
|
||||
/// \param[in] root MPI root process. This is the process onto
|
||||
/// which the counts and samples will be collected. Typically
|
||||
/// the index of the IO rank.
|
||||
///
|
||||
/// \param[in] comm MPI communication object.
|
||||
void collectFailures(int root, const Parallel::Communication& comm);
|
||||
|
||||
/// Whether or not any checks failed at the \c Standard level.
|
||||
bool anyFailedChecks() const;
|
||||
|
||||
|
@ -210,6 +225,10 @@ namespace Opm {
|
|||
/// Reports only those conditions/checks for which there is at least
|
||||
/// one violation.
|
||||
///
|
||||
/// In a parallel run it is only safe to call this function on the
|
||||
/// MPI process to which the consistency check violations were
|
||||
/// collected in a previous call to collectFailures().
|
||||
///
|
||||
/// \param[in] level Report's severity level.
|
||||
///
|
||||
/// \param[in] emitReportRecord Call-back function for outputting a
|
||||
|
@ -299,6 +318,26 @@ namespace Opm {
|
|||
/// is a common case in production runs.
|
||||
std::unique_ptr<RandomBitGenerator> urbg_{};
|
||||
|
||||
/// Collect violations of single severity level from all ranks in
|
||||
/// MPI communicator.
|
||||
///
|
||||
/// Incorporates violation counts and sampled failure points into
|
||||
/// the internal structures on each rank. Aggregate results useful
|
||||
/// for subsequent call to reportFailures().
|
||||
///
|
||||
/// \param[in] root MPI root process. This is the process/rank onto
|
||||
/// which the counts and samples will be collected. Typically
|
||||
/// the index of the IO rank.
|
||||
///
|
||||
/// \param[in] comm MPI communication object.
|
||||
///
|
||||
/// \param[in, out] violation Current rank's violation structure for
|
||||
/// a single severity level. Holds aggregate values across all
|
||||
/// ranks, including updated sample points, on return.
|
||||
void collectFailures(int root,
|
||||
const Parallel::Communication& comm,
|
||||
ViolationSample& violation);
|
||||
|
||||
/// Allocate and initialise backing storage for a single set of
|
||||
/// sampled consistency check violations.
|
||||
///
|
||||
|
@ -306,6 +345,44 @@ namespace Opm {
|
|||
/// violation sample of proper size.
|
||||
void buildStructure(ViolationSample& violation);
|
||||
|
||||
/// Internalise a single violation into internal data structures.
|
||||
///
|
||||
/// Counts the violation and uses "reservoir sampling"
|
||||
/// (https://en.wikipedia.org/wiki/Reservoir_sampling) to determine
|
||||
/// whether or not to include the specific point into the reporting
|
||||
/// sample.
|
||||
///
|
||||
/// \tparam PopulateCheckValues Call-back function type
|
||||
/// encapsulating block of code populate sequence of check values
|
||||
/// for a single, failed consistency check. Expected to be a
|
||||
/// callable type with a function call operator of the form
|
||||
/// \code
|
||||
/// void operator()(Scalar* checkValues) const
|
||||
/// \endcode
|
||||
/// in which the \c checkValues points the start of a sequence of
|
||||
/// values associated to particular check. The call-back function
|
||||
/// is expected to know how many values are in a valid sequence and
|
||||
/// to fill in exactly this many values.
|
||||
///
|
||||
/// \param[in, out] violation Current rank's violation sample at
|
||||
/// particular severity level.
|
||||
///
|
||||
/// \param[in] checkIx Numerical check index in the range
|
||||
/// [0..battery_.size()).
|
||||
///
|
||||
/// \param[in] pointID Numeric identifier for this particular set of
|
||||
/// end-points. Typically a saturation region or a cell ID.
|
||||
///
|
||||
/// \param[in] populateCheckValues Call-back function to populate a
|
||||
/// sequence of values pertaining to specified check. Typically
|
||||
/// \code Check::exportCheckValues() \endcode or a copy routine
|
||||
/// to incorporate samples from multiple MPI ranks.
|
||||
template <typename PopulateCheckValues>
|
||||
void processViolation(ViolationSample& violation,
|
||||
const std::size_t checkIx,
|
||||
const std::size_t pointID,
|
||||
PopulateCheckValues&& populateCheckValues);
|
||||
|
||||
/// Internalise a single violation into internal data structures.
|
||||
///
|
||||
/// Counts the violation and uses "reservoir sampling"
|
||||
|
@ -324,6 +401,24 @@ namespace Opm {
|
|||
const std::size_t checkIx,
|
||||
const std::size_t pointID);
|
||||
|
||||
/// Incorporate single severity level's set of violations from
|
||||
/// single MPI rank into current rank's internal data structures.
|
||||
///
|
||||
/// \param[in] count Start of sequence of failure counts for all
|
||||
/// checks from single MPI rank.
|
||||
///
|
||||
/// \param[in] pointID Start of sequence of sampled point IDs for
|
||||
/// all checks from a single MPI rank.
|
||||
///
|
||||
/// \param[in] checkValues Start of sequence of sampled check values
|
||||
/// for all checks from a single MPI rank.
|
||||
///
|
||||
/// \param[in, out] violation
|
||||
void incorporateRankViolations(const std::size_t* count,
|
||||
const std::size_t* pointID,
|
||||
const Scalar* checkValues,
|
||||
ViolationSample& violation);
|
||||
|
||||
/// Generate random index in the sample size.
|
||||
///
|
||||
/// \param[in] sampleSize Total number of violations of a particular
|
||||
|
@ -337,6 +432,16 @@ namespace Opm {
|
|||
/// initialised.
|
||||
void ensureRandomBitGeneratorIsInitialised();
|
||||
|
||||
/// Compute start offset into ViolationSample::pointID for
|
||||
/// particular check.
|
||||
///
|
||||
/// \param[in] checkIx Numerical check index in the range
|
||||
/// [0..battery_.size()).
|
||||
///
|
||||
/// \return Start offset into ViolationSample::pointID.
|
||||
std::vector<std::size_t>::size_type
|
||||
violationPointIDStart(const std::size_t checkIx) const;
|
||||
|
||||
/// Compute start offset into ViolationSample::checkValues for
|
||||
/// particular check and sample index.
|
||||
///
|
||||
|
@ -442,6 +547,17 @@ namespace Opm {
|
|||
std::size_t numPoints(const ViolationSample& violation,
|
||||
const std::size_t checkIx) const;
|
||||
|
||||
/// Compute number of sample points for a single check's violations.
|
||||
///
|
||||
/// Effectively the minimum of the number of violations of that
|
||||
/// check and the maximum number of sample points (\code
|
||||
/// this->numSamplePoints_ \endcode).
|
||||
///
|
||||
/// \param[in] violationCount Total number of check violations.
|
||||
///
|
||||
/// \return Number of active sample points.
|
||||
std::size_t numPoints(const std::size_t violationCount) const;
|
||||
|
||||
/// Whether or not any checks failed at specified severity level.
|
||||
///
|
||||
/// \param[in] level Violation severity level.
|
||||
|
|
|
@ -259,6 +259,9 @@ BOOST_AUTO_TEST_CASE(Critical_Violation)
|
|||
|
||||
checker.checkEndpoints(42, makePoints());
|
||||
|
||||
BOOST_CHECK_MESSAGE(! checker.anyFailedChecks(),
|
||||
"There must be no failed standard level checks");
|
||||
|
||||
BOOST_CHECK_MESSAGE(checker.anyFailedCriticalChecks(),
|
||||
"There must be at least one failed Critical check");
|
||||
|
||||
|
|
2127
tests/test_SatfuncConsistencyChecks_parallel.cpp
Normal file
2127
tests/test_SatfuncConsistencyChecks_parallel.cpp
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user