Merge pull request #5449 from bska/sfunc-consistency-checks-parallel

Add MPI Support to Saturation Function Consistency Checks
This commit is contained in:
Bård Skaflestad 2024-06-28 14:16:31 +02:00 committed by GitHub
commit 1e831bab80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 2454 additions and 97 deletions

View File

@ -418,44 +418,20 @@ opm_add_test(test_parallelwellinfo_mpi
4
)
opm_add_test(test_parallel_wbp_sourcevalues_np2
EXE_NAME
test_parallel_wbp_sourcevalues
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 2
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
2
)
opm_add_test(test_parallel_wbp_sourcevalues_np3
EXE_NAME
test_parallel_wbp_sourcevalues
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 3
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
3
)
opm_add_test(test_parallel_wbp_sourcevalues_np4
EXE_NAME
test_parallel_wbp_sourcevalues
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 4
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
4
)
foreach(NPROC 2 3 4)
opm_add_test(test_parallel_wbp_sourcevalues_np${NPROC}
EXE_NAME
test_parallel_wbp_sourcevalues
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n ${NPROC}
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
${NPROC}
)
endforeach()
opm_add_test(test_parallel_wbp_calculation
SOURCES
@ -497,50 +473,37 @@ opm_add_test(test_parallel_wbp_calculation_well_openconns
2
)
opm_add_test(test_parallel_region_phase_pvaverage_np2
EXE_NAME
test_region_phase_pvaverage
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 2
-b ${PROJECT_BINARY_DIR}
TEST_ARGS
--run_test=Parallel/*
NO_COMPILE
PROCESSORS
2
)
foreach(NPROC 2 3 4)
opm_add_test(test_parallel_region_phase_pvaverage_np${NPROC}
EXE_NAME
test_region_phase_pvaverage
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n ${NPROC}
-b ${PROJECT_BINARY_DIR}
TEST_ARGS
--run_test=Parallel/*
NO_COMPILE
PROCESSORS
${NPROC}
)
endforeach()
opm_add_test(test_parallel_region_phase_pvaverage_np3
EXE_NAME
test_region_phase_pvaverage
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 3
-b ${PROJECT_BINARY_DIR}
TEST_ARGS
--run_test=Parallel/*
NO_COMPILE
PROCESSORS
3
)
opm_add_test(test_parallel_region_phase_pvaverage_np4
EXE_NAME
test_region_phase_pvaverage
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n 4
-b ${PROJECT_BINARY_DIR}
TEST_ARGS
--run_test=Parallel/*
NO_COMPILE
PROCESSORS
4
)
foreach(NPROC 2 3 4)
opm_add_test(test_parallel_satfunc_consistency_checks_np${NPROC}
EXE_NAME
test_SatfuncConsistencyChecks_parallel
CONDITION
MPI_FOUND AND Boost_UNIT_TEST_FRAMEWORK_FOUND
DRIVER_ARGS
-n ${NPROC}
-b ${PROJECT_BINARY_DIR}
NO_COMPILE
PROCESSORS
${NPROC}
)
endforeach()
opm_add_test(test_broadcast
DEPENDS "opmsimulators"

View File

@ -340,6 +340,7 @@ if (HAVE_ECL_INPUT)
list(APPEND TEST_SOURCE_FILES
tests/test_nonnc.cpp
tests/test_SatfuncConsistencyChecks.cpp
tests/test_SatfuncConsistencyChecks_parallel.cpp
)
endif()

View File

@ -21,6 +21,10 @@
#include <opm/simulators/utils/satfunc/SatfuncConsistencyChecks.hpp>
#include <opm/simulators/utils/ParallelCommunication.hpp>
#include <opm/grid/common/CommunicationUtils.hpp>
#include <opm/material/fluidmatrixinteractions/EclEpsScalingPoints.hpp>
#include <algorithm>
@ -141,6 +145,22 @@ checkEndpoints(const std::size_t pointID,
});
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
collectFailures(const int root,
const Parallel::Communication& comm)
{
if (comm.size() == 1) {
// Not a parallel run. Violation structure complete without
// exchanging additional information, so nothing to do.
return;
}
for (auto& violation : this->violations_) {
this->collectFailures(root, comm, violation);
}
}
template <typename Scalar>
bool Opm::SatfuncConsistencyChecks<Scalar>::anyFailedChecks() const
{
@ -194,9 +214,77 @@ void Opm::SatfuncConsistencyChecks<Scalar>::ViolationSample::clear()
// ---------------------------------------------------------------------------
namespace {
bool anyFailedChecks(const std::vector<std::size_t>& count)
{
return std::any_of(count.begin(), count.end(),
[](const std::size_t n) { return n > 0; });
}
}
template <typename Scalar>
void
Opm::SatfuncConsistencyChecks<Scalar>::
void Opm::SatfuncConsistencyChecks<Scalar>::
collectFailures(const int root,
const Parallel::Communication& comm,
ViolationSample& violation)
{
// Count total number of violations of each check across all ranks.
// This should be the final number emitted in reportFailures() on the
// root process.
auto totalCount = violation.count;
comm.sum(totalCount.data(), violation.count.size());
if (! ::anyFailedChecks(totalCount)) {
// No failed checks on any rank for this severity level.
//
// No additional work needed, since every rank will have zero
// failure counts for all checks.
return;
}
// CSR-like structures for the failure counts, sampled point IDs, and
// sampled check values from all ranks. One set of all-to-one messages
// for each quantity. If this stage becomes a bottleneck we must devise
// a better communication structure that reduces the number of messages.
const auto& [rankCount, startRankCount] =
gatherv(violation.count, comm, root);
const auto& [rankPointID, startRankPointID] =
gatherv(violation.pointID, comm, root);
const auto& [rankCheckValues, startRankCheckValues] =
gatherv(violation.checkValues, comm, root);
if (comm.rank() == root) {
// Re-initialise this violation sample to prepare for incorporating
// contributions from all MPI ranks--including the current rank.
violation.clear();
this->buildStructure(violation);
const auto numRanks = comm.size();
for (auto rank = 0*numRanks; rank < numRanks; ++rank) {
this->incorporateRankViolations
(rankCount.data() + startRankCount[rank],
rankPointID.data() + startRankPointID[rank],
rankCheckValues.data() + startRankCheckValues[rank],
violation);
}
}
// The final violation counts for reporting purposes should be the sum
// of the per-rank counts. This ensures that all ranks give the same
// answer to the anyFailedChecks() predicate, although the particular
// sample points will differ across the ranks.
violation.count.swap(totalCount);
// Ensure that all ranks are synchronised here before proceeding. We
// don't want to end up in a situation where the ranks have a different
// notion of what to send/receive.
comm.barrier();
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
buildStructure(ViolationSample& violation)
{
violation.count.assign(this->battery_.size(), 0);
@ -212,13 +300,13 @@ buildStructure(ViolationSample& violation)
}
template <typename Scalar>
template <typename PopulateCheckValues>
void Opm::SatfuncConsistencyChecks<Scalar>::
processViolation(const ViolationLevel level,
const std::size_t checkIx,
const std::size_t pointID)
processViolation(ViolationSample& violation,
const std::size_t checkIx,
const std::size_t pointID,
PopulateCheckValues&& populateCheckValues)
{
auto& violation = this->violations_[this->index(level)];
const auto nViol = ++violation.count[checkIx];
// Special case handling for number of violations not exceeding number
@ -240,12 +328,59 @@ processViolation(const ViolationLevel level,
// reported violations. Record the pointID and the corresponding check
// values in their appropriate locations.
violation.pointID[checkIx*this->numSamplePoints_ + sampleIx] = pointID;
violation.pointID[this->violationPointIDStart(checkIx) + sampleIx] = pointID;
auto* exportedCheckValues = violation.checkValues.data()
auto* const checkValues = violation.checkValues.data()
+ this->violationValueStart(checkIx, sampleIx);
this->battery_[checkIx]->exportCheckValues(exportedCheckValues);
populateCheckValues(checkValues);
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
processViolation(const ViolationLevel level,
const std::size_t checkIx,
const std::size_t pointID)
{
this->processViolation(this->violations_[this->index(level)], checkIx, pointID,
[this, checkIx](Scalar* const exportedCheckValues)
{
this->battery_[checkIx]->exportCheckValues(exportedCheckValues);
});
}
template <typename Scalar>
void Opm::SatfuncConsistencyChecks<Scalar>::
incorporateRankViolations(const std::size_t* const count,
const std::size_t* const pointID,
const Scalar* const checkValues,
ViolationSample& violation)
{
this->checkLoop([this, count, pointID, checkValues, &violation]
(const Check* currentCheck,
const std::size_t checkIx)
{
if (count[checkIx] == 0) {
// No violations of this check on this rank. Nothing to do.
return;
}
const auto* const srcPointID = pointID
+ this->violationPointIDStart(checkIx);
const auto numCheckValues = currentCheck->numExportedCheckValues();
const auto numSrcSamples = this->numPoints(count[checkIx]);
for (auto srcSampleIx = 0*numSrcSamples; srcSampleIx < numSrcSamples; ++srcSampleIx) {
this->processViolation(violation, checkIx, srcPointID[srcSampleIx],
[numCheckValues,
srcCheckValues = checkValues + this->violationValueStart(checkIx, srcSampleIx)]
(Scalar* const destCheckValues)
{
std::copy_n(srcCheckValues, numCheckValues, destCheckValues);
});
}
});
}
namespace {
@ -471,8 +606,15 @@ Opm::SatfuncConsistencyChecks<Scalar>::
numPoints(const ViolationSample& violation,
const std::size_t checkIx) const
{
return std::min(this->numSamplePoints_,
violation.count[checkIx]);
return this->numPoints(violation.count[checkIx]);
}
template <typename Scalar>
std::size_t
Opm::SatfuncConsistencyChecks<Scalar>::
numPoints(const std::size_t violationCount) const
{
return std::min(this->numSamplePoints_, violationCount);
}
template <typename Scalar>
@ -507,6 +649,14 @@ void Opm::SatfuncConsistencyChecks<Scalar>::ensureRandomBitGeneratorIsInitialise
this->urbg_ = std::make_unique<RandomBitGenerator>(seeds);
}
template <typename Scalar>
std::vector<std::size_t>::size_type
Opm::SatfuncConsistencyChecks<Scalar>::
violationPointIDStart(const std::size_t checkIx) const
{
return checkIx * this->numSamplePoints_;
}
template <typename Scalar>
typename std::vector<Scalar>::size_type
Opm::SatfuncConsistencyChecks<Scalar>::
@ -522,10 +672,7 @@ bool
Opm::SatfuncConsistencyChecks<Scalar>::
anyFailedChecks(const ViolationLevel level) const
{
const auto& violation = this->violations_[this->index(level)];
return std::any_of(violation.count.begin(), violation.count.end(),
[](const std::size_t n) { return n > 0; });
return ::anyFailedChecks(this->violations_[this->index(level)].count);
}
template <typename Scalar>

View File

@ -20,6 +20,8 @@
#ifndef OPM_SATFUNC_CONSISTENCY_CHECK_MODULE_HPP
#define OPM_SATFUNC_CONSISTENCY_CHECK_MODULE_HPP
#include <opm/simulators/utils/ParallelCommunication.hpp>
#include <cstddef>
#include <functional>
#include <memory>
@ -198,6 +200,19 @@ namespace Opm {
void checkEndpoints(const std::size_t pointID,
const EclEpsScalingPointsInfo<Scalar>& endPoints);
/// Collect consistency violations from all ranks in MPI communicator.
///
/// Incorporates violation counts and sampled failure points into
/// the internal structures on each rank. Aggregate results useful
/// for subsequent call to reportFailures() on root process.
///
/// \param[in] root MPI root process. This is the process onto
/// which the counts and samples will be collected. Typically
/// the index of the IO rank.
///
/// \param[in] comm MPI communication object.
void collectFailures(int root, const Parallel::Communication& comm);
/// Whether or not any checks failed at the \c Standard level.
bool anyFailedChecks() const;
@ -210,6 +225,10 @@ namespace Opm {
/// Reports only those conditions/checks for which there is at least
/// one violation.
///
/// In a parallel run it is only safe to call this function on the
/// MPI process to which the consistency check violations were
/// collected in a previous call to collectFailures().
///
/// \param[in] level Report's severity level.
///
/// \param[in] emitReportRecord Call-back function for outputting a
@ -299,6 +318,26 @@ namespace Opm {
/// is a common case in production runs.
std::unique_ptr<RandomBitGenerator> urbg_{};
/// Collect violations of single severity level from all ranks in
/// MPI communicator.
///
/// Incorporates violation counts and sampled failure points into
/// the internal structures on each rank. Aggregate results useful
/// for subsequent call to reportFailures().
///
/// \param[in] root MPI root process. This is the process/rank onto
/// which the counts and samples will be collected. Typically
/// the index of the IO rank.
///
/// \param[in] comm MPI communication object.
///
/// \param[in, out] violation Current rank's violation structure for
/// a single severity level. Holds aggregate values across all
/// ranks, including updated sample points, on return.
void collectFailures(int root,
const Parallel::Communication& comm,
ViolationSample& violation);
/// Allocate and initialise backing storage for a single set of
/// sampled consistency check violations.
///
@ -306,6 +345,44 @@ namespace Opm {
/// violation sample of proper size.
void buildStructure(ViolationSample& violation);
/// Internalise a single violation into internal data structures.
///
/// Counts the violation and uses "reservoir sampling"
/// (https://en.wikipedia.org/wiki/Reservoir_sampling) to determine
/// whether or not to include the specific point into the reporting
/// sample.
///
/// \tparam PopulateCheckValues Call-back function type
/// encapsulating block of code populate sequence of check values
/// for a single, failed consistency check. Expected to be a
/// callable type with a function call operator of the form
/// \code
/// void operator()(Scalar* checkValues) const
/// \endcode
/// in which the \c checkValues points the start of a sequence of
/// values associated to particular check. The call-back function
/// is expected to know how many values are in a valid sequence and
/// to fill in exactly this many values.
///
/// \param[in, out] violation Current rank's violation sample at
/// particular severity level.
///
/// \param[in] checkIx Numerical check index in the range
/// [0..battery_.size()).
///
/// \param[in] pointID Numeric identifier for this particular set of
/// end-points. Typically a saturation region or a cell ID.
///
/// \param[in] populateCheckValues Call-back function to populate a
/// sequence of values pertaining to specified check. Typically
/// \code Check::exportCheckValues() \endcode or a copy routine
/// to incorporate samples from multiple MPI ranks.
template <typename PopulateCheckValues>
void processViolation(ViolationSample& violation,
const std::size_t checkIx,
const std::size_t pointID,
PopulateCheckValues&& populateCheckValues);
/// Internalise a single violation into internal data structures.
///
/// Counts the violation and uses "reservoir sampling"
@ -324,6 +401,24 @@ namespace Opm {
const std::size_t checkIx,
const std::size_t pointID);
/// Incorporate single severity level's set of violations from
/// single MPI rank into current rank's internal data structures.
///
/// \param[in] count Start of sequence of failure counts for all
/// checks from single MPI rank.
///
/// \param[in] pointID Start of sequence of sampled point IDs for
/// all checks from a single MPI rank.
///
/// \param[in] checkValues Start of sequence of sampled check values
/// for all checks from a single MPI rank.
///
/// \param[in, out] violation
void incorporateRankViolations(const std::size_t* count,
const std::size_t* pointID,
const Scalar* checkValues,
ViolationSample& violation);
/// Generate random index in the sample size.
///
/// \param[in] sampleSize Total number of violations of a particular
@ -337,6 +432,16 @@ namespace Opm {
/// initialised.
void ensureRandomBitGeneratorIsInitialised();
/// Compute start offset into ViolationSample::pointID for
/// particular check.
///
/// \param[in] checkIx Numerical check index in the range
/// [0..battery_.size()).
///
/// \return Start offset into ViolationSample::pointID.
std::vector<std::size_t>::size_type
violationPointIDStart(const std::size_t checkIx) const;
/// Compute start offset into ViolationSample::checkValues for
/// particular check and sample index.
///
@ -442,6 +547,17 @@ namespace Opm {
std::size_t numPoints(const ViolationSample& violation,
const std::size_t checkIx) const;
/// Compute number of sample points for a single check's violations.
///
/// Effectively the minimum of the number of violations of that
/// check and the maximum number of sample points (\code
/// this->numSamplePoints_ \endcode).
///
/// \param[in] violationCount Total number of check violations.
///
/// \return Number of active sample points.
std::size_t numPoints(const std::size_t violationCount) const;
/// Whether or not any checks failed at specified severity level.
///
/// \param[in] level Violation severity level.

View File

@ -259,6 +259,9 @@ BOOST_AUTO_TEST_CASE(Critical_Violation)
checker.checkEndpoints(42, makePoints());
BOOST_CHECK_MESSAGE(! checker.anyFailedChecks(),
"There must be no failed standard level checks");
BOOST_CHECK_MESSAGE(checker.anyFailedCriticalChecks(),
"There must be at least one failed Critical check");

File diff suppressed because it is too large Load Diff