/*
Copyright 2024 Equinor AS
This file is part of the Open Porous Media project (OPM).
OPM is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
OPM is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with OPM. If not, see <http://www.gnu.org/licenses/>.
*/
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// ===========================================================================
// Public member functions for SatfuncConsistencyChecks Template
// ===========================================================================
// Constructor.
//
// pointName:       Name of the kind of point being checked.  Used as the
//                  heading of the point ID column in tabulated reports.
// numSamplePoints: Upper bound on the number of violating points retained
//                  for each individual check.
//
// The point ID formatter defaults to printing the numeric ID as is.
template
Opm::SatfuncConsistencyChecks::
SatfuncConsistencyChecks(std::string_view pointName,
                         const std::size_t numSamplePoints)
    : pointName_       { pointName }
    , numSamplePoints_ { numSamplePoints }
    , formatPointID_   { [](const std::size_t pointID)
                         {
                             return fmt::format("{}", pointID);
                         } }
{}
// Move constructor.
//
// Takes over the point name, point ID formatter, check value offsets,
// violation samples, and check battery of 'rhs', and copies its sample
// point limit.  The random bit generator 'urbg_' does not appear in the
// member initialiser list, so it is not transferred from 'rhs'.
template
Opm::SatfuncConsistencyChecks::
SatfuncConsistencyChecks(SatfuncConsistencyChecks&& rhs)
: pointName_ { std::move(rhs.pointName_) }
, numSamplePoints_ { rhs.numSamplePoints_ }
, formatPointID_ { std::move(rhs.formatPointID_) }
, startCheckValues_ { std::move(rhs.startCheckValues_) }
, violations_ { std::move(rhs.violations_) }
, battery_ { std::move(rhs.battery_) }
{}
// Move assignment operator.
//
// Adopts the state of 'rhs' member by member.  The random bit generator of
// the assigned-to object is discarded rather than taken from 'rhs'.
//
// Returns *this.
template
Opm::SatfuncConsistencyChecks&
Opm::SatfuncConsistencyChecks::operator=(SatfuncConsistencyChecks&& rhs)
{
    // Forget any generator owned by this object.
    this->urbg_.reset();

    // Plain value, copied.
    this->numSamplePoints_ = rhs.numSamplePoints_;

    // Remaining state is moved out of 'rhs'.
    this->battery_          = std::move(rhs.battery_);
    this->violations_       = std::move(rhs.violations_);
    this->startCheckValues_ = std::move(rhs.startCheckValues_);
    this->formatPointID_    = std::move(rhs.formatPointID_);
    this->pointName_        = std::move(rhs.pointName_);

    return *this;
}
// Return the object to a state with no registered checks.
//
// The offset table is reduced to its single leading zero, all recorded
// violations are dropped, the check battery is emptied, and the random bit
// generator is released.
template
void Opm::SatfuncConsistencyChecks::resetCheckSet()
{
    // Offset table always starts with a zero entry.
    this->startCheckValues_.clear();
    this->startCheckValues_.push_back(0);

    std::for_each(std::begin(this->violations_),
                  std::end  (this->violations_),
                  [](auto& sample) { sample.clear(); });

    this->battery_.clear();
    this->urbg_.reset();
}
// Append a check to the battery.
//
// check: Check to take ownership of.
//
// Also records how many check values the new check needs in total, i.e.,
// its number of exported values for each of the numSamplePoints_ sample
// slots.  These per-check sizes are converted to offsets in
// finaliseCheckSet().
template
void Opm::SatfuncConsistencyChecks::addCheck(std::unique_ptr check)
{
    this->battery_.push_back(std::move(check));

    const auto& newestCheck = *this->battery_.back();

    this->startCheckValues_.push_back
        (this->numSamplePoints_ * newestCheck.numExportedCheckValues());
}
// Complete the registration of checks.
//
// Converts the per-check value counts accumulated by addCheck() into
// start offsets by means of an in-place running sum, and then sizes the
// violation storage of every severity level accordingly.
template
void Opm::SatfuncConsistencyChecks::finaliseCheckSet()
{
    auto& offsets = this->startCheckValues_;

    std::partial_sum(offsets.begin(), offsets.end(), offsets.begin());

    std::for_each(std::begin(this->violations_),
                  std::end  (this->violations_),
                  [this](auto& sample) { this->buildStructure(sample); });
}
// Run every registered check against one set of end-points.
//
// pointID:   Numeric ID of the point to which the end-points belong.
// endPoints: End-points to test.
//
// Each violated check is recorded at either the critical or the standard
// level, depending on what the check itself reports.
template
void Opm::SatfuncConsistencyChecks::
checkEndpoints(const std::size_t pointID,
               const EclEpsScalingPointsInfo& endPoints)
{
    this->checkLoop([this, pointID, &endPoints]
                    (Check* check, const std::size_t checkIx)
    {
        check->test(endPoints);

        if (check->isViolated()) {
            // The check does not hold for this set of end-points.
            // Record the violation at the level prescribed by the check.
            // Critical violations typically end the run whereas standard
            // level violations typically generate warnings only.
            const auto severity = check->isCritical()
                ? ViolationLevel::Critical
                : ViolationLevel::Standard;

            this->processViolation(severity, checkIx, pointID);
        }

        // Otherwise the check holds and there is nothing to record.
    });
}
// Collect violations from all ranks for every severity level.
//
// root: Rank on which the collected violations are assembled.
// comm: Communication object.
//
// Does nothing when the communicator holds exactly one rank, since the
// local violation structure is then already complete.
template
void Opm::SatfuncConsistencyChecks::
collectFailures(const int root,
                const Parallel::Communication& comm)
{
    if (comm.size() != 1) {
        for (auto& sample : this->violations_) {
            this->collectFailures(root, comm, sample);
        }
    }
}
// Whether any check has been violated at the standard level.
template
bool Opm::SatfuncConsistencyChecks::anyFailedChecks() const
{
    const auto level = ViolationLevel::Standard;

    return this->anyFailedChecks(level);
}
// Whether any check has been violated at the critical level.
template
bool Opm::SatfuncConsistencyChecks::anyFailedCriticalChecks() const
{
    const auto level = ViolationLevel::Critical;

    return this->anyFailedChecks(level);
}
// Emit a report of all violated checks at a given severity level.
//
// level:            Severity level to report.
// emitReportRecord: Call-back which receives each piece of report text.
//
// For every check with a non-zero violation count, a descriptive header
// is emitted followed by a table of the sampled violations.
template
void Opm::SatfuncConsistencyChecks::
reportFailures(const ViolationLevel level,
               const ReportRecordOutput& emitReportRecord) const
{
    // Number of characters needed to display a single check value.
    const auto nValueChar = fmt::formatted_size("{:> 8.6e}", 1.0);

    const auto& violation = this->violations_[this->index(level)];

    this->checkLoop([this, &emitReportRecord, &violation, nValueChar]
                    (const Check* check, const std::size_t checkIx)
    {
        const auto numViolations = violation.count[checkIx];

        if (numViolations != 0) {
            this->writeReportHeader(check, numViolations, emitReportRecord);

            this->writeTabulatedReportSample(nValueChar, check, violation,
                                             checkIx, emitReportRecord);
        }
    });
}
// ===========================================================================
// Private member functions for SatfuncConsistencyChecks Template
// ===========================================================================
// Empty all three containers of the violation sample: the sampled check
// values, the sampled point IDs, and the per-check violation counts.
template
void Opm::SatfuncConsistencyChecks::ViolationSample::clear()
{
    this->checkValues.clear();
    this->pointID.clear();
    this->count.clear();
}
// ---------------------------------------------------------------------------
namespace {
    // Whether at least one check has a recorded violation.
    //
    // count: Number of recorded violations for each check.
    //
    // Returns true if any entry of 'count' is strictly positive, and false
    // if all entries are zero or if 'count' is empty.
    bool anyFailedChecks(const std::vector<std::size_t>& count)
    {
        return std::any_of(count.begin(), count.end(),
                           [](const std::size_t n) { return n > 0; });
    }
}
// Collect the violations of a single severity level from all ranks.
//
// root:      Rank on which the gathered samples are merged.
// comm:      Communication object used for the sum, the gathers, and the
//            final barrier.
// violation: Violation sample of one severity level.  On the root rank
//            its sampled point IDs and check values are rebuilt from the
//            contributions of all ranks.  Unless the function returns
//            early because no rank has any violation, its counts are
//            replaced by the summed counts on every rank.
template
void Opm::SatfuncConsistencyChecks::
collectFailures(const int root,
const Parallel::Communication& comm,
ViolationSample& violation)
{
// Count total number of violations of each check across all ranks.
// This should be the final number emitted in reportFailures() on the
// root process.
auto totalCount = violation.count;
comm.sum(totalCount.data(), violation.count.size());
if (! ::anyFailedChecks(totalCount)) {
// No failed checks on any rank for this severity level.
//
// No additional work needed, since every rank will have zero
// failure counts for all checks.
return;
}
// CSR-like structures for the failure counts, sampled point IDs, and
// sampled check values from all ranks. One set of all-to-one messages
// for each quantity. If this stage becomes a bottleneck we must devise
// a better communication structure that reduces the number of messages.
const auto& [rankCount, startRankCount] =
gatherv(violation.count, comm, root);
const auto& [rankPointID, startRankPointID] =
gatherv(violation.pointID, comm, root);
const auto& [rankCheckValues, startRankCheckValues] =
gatherv(violation.checkValues, comm, root);
if (comm.rank() == root) {
// Re-initialise this violation sample to prepare for incorporating
// contributions from all MPI ranks--including the current rank.
violation.clear();
this->buildStructure(violation);
const auto numRanks = comm.size();
// The start arrays give the offset of each rank's contribution within
// the corresponding gathered data array.
for (auto rank = 0*numRanks; rank < numRanks; ++rank) {
this->incorporateRankViolations
(rankCount.data() + startRankCount[rank],
rankPointID.data() + startRankPointID[rank],
rankCheckValues.data() + startRankCheckValues[rank],
violation);
}
}
// The final violation counts for reporting purposes should be the sum
// of the per-rank counts. This ensures that all ranks give the same
// answer to the anyFailedChecks() predicate, although the particular
// sample points will differ across the ranks.
violation.count.swap(totalCount);
// Ensure that all ranks are synchronised here before proceeding. We
// don't want to end up in a situation where the ranks have a different
// notion of what to send/receive.
comm.barrier();
}
// Size and initialise the storage of a violation sample.
//
// violation: Violation sample to prepare.
//
// On return the sample holds one zero count for each check in the
// battery, numSamplePoints_ point ID slots for each check, and as many
// check value slots as the last entry of startCheckValues_ prescribes.
template
void Opm::SatfuncConsistencyChecks::
buildStructure(ViolationSample& violation)
{
violation.count.assign(this->battery_.size(), 0);
// Note: The fill value of resize() applies to newly appended elements
// only.  Those get a recognisable sentinel value.
violation.pointID.resize(this->battery_.size() * this->numSamplePoints_,
static_cast(0xdeadc0deUL));
violation.checkValues.resize(this->startCheckValues_.back());
// Mark all check values as "not yet assigned" if the value type is able
// to represent a quiet NaN.
if constexpr (std::numeric_limits::has_quiet_NaN) {
std::fill(violation.checkValues.begin(),
violation.checkValues.end(),
std::numeric_limits::quiet_NaN());
}
}
// Record a single violation of a check in a violation sample.
//
// violation:           Violation sample to update.
// checkIx:             Index of the violated check in the battery.
// pointID:             Numeric ID of the point at which the check fails.
// populateCheckValues: Call-back which is invoked with a pointer to the
//                      first check value slot of the selected sample
//                      entry.  Invoked only if the violation is retained
//                      in the sample.
//
// The violation count of the check is always incremented.  The first
// numSamplePoints_ violations fill the sample in order.  For later
// violations, getSampleIndex() draws an index and the violation replaces
// the entry at that index only if the index is within the sample.
template
template
void Opm::SatfuncConsistencyChecks::
processViolation(ViolationSample& violation,
const std::size_t checkIx,
const std::size_t pointID,
PopulateCheckValues&& populateCheckValues)
{
const auto nViol = ++violation.count[checkIx];
// Special case handling for number of violations not exceeding number
// of sample points. Needed in order to guarantee that the full table
// is populated before starting the random replacement stage.
const auto sampleIx = (nViol <= this->numSamplePoints_)
? (nViol - 1)
: this->getSampleIndex(nViol);
if (sampleIx >= this->numSamplePoints_) {
// Reservoir sampling algorithm
// (https://en.wikipedia.org/wiki/Reservoir_sampling) says that this
// particular set of end-points should *not* be included in the
// reported violations. No more work needed in this case.
return;
}
// If we get here, then this set of end-points should be included in the
// reported violations. Record the pointID and the corresponding check
// values in their appropriate locations.
violation.pointID[this->violationPointIDStart(checkIx) + sampleIx] = pointID;
auto* const checkValues = violation.checkValues.data()
+ this->violationValueStart(checkIx, sampleIx);
populateCheckValues(checkValues);
}
// Record a violation detected on the current rank.
//
// level:   Severity level of the violation.
// checkIx: Index of the violated check in the battery.
// pointID: Numeric ID of the point at which the check fails.
//
// The check values are exported by the violated check itself.
template
void Opm::SatfuncConsistencyChecks::
processViolation(const ViolationLevel level,
                 const std::size_t checkIx,
                 const std::size_t pointID)
{
    auto& sample = this->violations_[this->index(level)];

    auto exportFromCheck = [this, checkIx](Scalar* const destination)
    {
        this->battery_[checkIx]->exportCheckValues(destination);
    };

    this->processViolation(sample, checkIx, pointID,
                           std::move(exportFromCheck));
}
// Merge the sampled violations of a single rank into a violation sample.
//
// count:       Violation count of each check on the rank.
// pointID:     Sampled point IDs of the rank.
// checkValues: Sampled check values of the rank.
// violation:   Violation sample into which the rank's samples are merged.
//
// Every sampled point of the rank is fed through the regular violation
// processing, with its check values copied from the rank's data.
template
void Opm::SatfuncConsistencyChecks::
incorporateRankViolations(const std::size_t* const count,
                          const std::size_t* const pointID,
                          const Scalar* const checkValues,
                          ViolationSample& violation)
{
    this->checkLoop([this, count, pointID, checkValues, &violation]
                    (const Check* check, const std::size_t checkIx)
    {
        const auto numRankViolations = count[checkIx];

        if (numRankViolations == 0) {
            // This rank did not violate this check.
            return;
        }

        const auto  numValues    = check->numExportedCheckValues();
        const auto  numSamples   = this->numPoints(numRankViolations);
        const auto* rankPointIDs = pointID + this->violationPointIDStart(checkIx);

        for (auto sample = 0*numSamples; sample < numSamples; ++sample) {
            const auto* rankValues = checkValues
                + this->violationValueStart(checkIx, sample);

            this->processViolation(violation, checkIx, rankPointIDs[sample],
                                   [rankValues, numValues]
                                   (Scalar* const destination)
            {
                std::copy_n(rankValues, numValues, destination);
            });
        }
    });
}
namespace {
// Compute the display width of each value column.
//
// columnHeaders: Heading of each value column.
// minColWidth:   Smallest permitted column width.
//
// Returns one width for each column header: the larger of the header's
// length and 'minColWidth'.  Empty if 'columnHeaders' is empty.
std::vector<std::string::size_type>
computeFieldWidths(const std::vector<std::string>& columnHeaders,
                   const std::string::size_type    minColWidth)
{
    auto fieldWidths = std::vector<std::string::size_type>(columnHeaders.size());

    std::transform(columnHeaders.begin(), columnHeaders.end(),
                   fieldWidths.begin(),
                   [minColWidth](const std::string& header)
                   { return std::max(minColWidth, header.size()); });

    return fieldWidths;
}
// Build the horizontal rule which separates the parts of a report table.
//
// fwPointID:   Width of the point ID column.
// fieldWidths: Width of each value column.
//
// Returns a string of '+'-delimited runs of '-' characters, one run for
// each column, terminated by a final '+'.
std::string
createTableSeparator(const std::string::size_type fwPointID,
                     const std::vector& fieldWidths)
{
    using namespace fmt::literals;

    // Note: "+2" for one blank space on each side of the string value.
    const auto segment = [](const auto columnWidth)
    {
        return fmt::format("+{name:-<{width}}",
                           "name"_a = "",
                           "width"_a = columnWidth + 2);
    };

    auto separator = segment(fwPointID);

    for (const auto& fieldWidth : fieldWidths) {
        separator += segment(fieldWidth);
    }

    separator += '+';

    return separator;
}
// Emit the heading row of a report table.
//
// fwPointID:     Width of the point ID column.
// pointName:     Heading of the point ID column.
// fieldWidths:   Width of each value column.
// columnHeaders: Heading of each value column.
// emitRecord:    Call-back which receives the finished row.
//
// All headings are left-aligned within their columns.
template
void writeTableHeader(const std::string_view::size_type fwPointID,
                      std::string_view pointName,
                      const std::vector& fieldWidths,
                      const std::vector& columnHeaders,
                      EmitRecord&& emitRecord)
{
    using namespace fmt::literals;

    const auto formatCell = [](const auto& heading, const auto& columnWidth)
    {
        return fmt::format("| {name:<{width}} ",
                           "name"_a = heading,
                           "width"_a = columnWidth);
    };

    auto tableHeader = formatCell(pointName, fwPointID);

    const auto numColumns = columnHeaders.size();
    for (auto column = 0*numColumns; column < numColumns; ++column) {
        tableHeader += formatCell(columnHeaders[column], fieldWidths[column]);
    }

    emitRecord(tableHeader + '|');
}
// Emit a single value row of a report table.
//
// fwPointID:   Width of the point ID column.
// pointID:     Formatted point ID, left-aligned in its column.
// fieldWidths: Width of each value column.
// checkValues: Check values of the row, one for each value column.
// emitRecord:  Call-back which receives the finished row.
//
// Check values are right-aligned in scientific notation with six digits
// after the decimal point.
template
void writeTableRecord(const std::string_view::size_type fwPointID,
                      std::string_view pointID,
                      const std::vector& fieldWidths,
                      const Scalar* checkValues,
                      EmitRecord&& emitRecord)
{
    using namespace fmt::literals;

    auto record = fmt::format("| {pointID:<{width}} ",
                              "width"_a = fwPointID,
                              "pointID"_a = pointID);

    // Walk the value array in step with the column widths.
    const auto* value = checkValues;
    for (const auto& columnWidth : fieldWidths) {
        record += fmt::format("| {checkValue:>{width}.6e} ",
                              "width"_a = columnWidth,
                              "checkValue"_a = *value);
        ++value;
    }

    emitRecord(record + '|');
}
} // Anonymous namespace
// Emit the descriptive header of the report for a single violated check.
//
// currentCheck:     Violated check; supplies description and condition.
// violationCount:   Total number of violations of the check.
// emitReportRecord: Call-back which receives the header text.
//
// The table caption states whether the table holds all violations or a
// sample only, depending on whether the violation count exceeds the number
// of sample points.
template
void Opm::SatfuncConsistencyChecks::
writeReportHeader(const Check* currentCheck,
                  const std::size_t violationCount,
                  const ReportRecordOutput& emitReportRecord) const
{
    const char* sampleMsg = "List of Violations";
    if (violationCount > this->numSamplePoints_) {
        sampleMsg = "Sample Violations";
    }

    const auto header =
        fmt::format("Consistency Problem:\n"
                    " {}\n"
                    " {}\n"
                    " Total Violations: {}\n\n"
                    "{}",
                    currentCheck->description(),
                    currentCheck->condition(),
                    violationCount, sampleMsg);

    emitReportRecord(header);
}
// Emit the table of sampled violations for a single check.
//
// nValueChar:       Minimum width of a value column.
// currentCheck:     Violated check; supplies the column headings.
// violation:        Violation sample from which to read IDs and values.
// checkIx:          Index of the check in the battery.
// emitReportRecord: Call-back which receives each line of the table.
//
// Rows are emitted in order of increasing numeric point ID.
template
void Opm::SatfuncConsistencyChecks::
writeTabulatedReportSample(const std::size_t nValueChar,
                           const Check* currentCheck,
                           const ViolationSample& violation,
                           const std::size_t checkIx,
                           const ReportRecordOutput& emitReportRecord) const
{
    // Formatted point IDs along with the length of the longest of them.
    const auto [pointIDText, maxPointIDWidth] =
        this->formatPointIDs(violation, checkIx);

    const auto pointIDColumnWidth =
        std::max(maxPointIDWidth, this->pointName_.size());

    const auto headings = this->collectColumnHeaders(currentCheck);
    const auto widths   = computeFieldWidths(headings, nValueChar);
    const auto rule     = createTableSeparator(pointIDColumnWidth, widths);

    // Top rule, heading row, and rule between headings and values.
    emitReportRecord(rule);
    writeTableHeader(pointIDColumnWidth, this->pointName_,
                     widths, headings, emitReportRecord);
    emitReportRecord(rule);

    // Value rows.
    const auto* const sampleValues = violation.checkValues.data();
    for (const auto& sampleIx : this->sortedPointIndices(violation, checkIx)) {
        writeTableRecord(pointIDColumnWidth, pointIDText[sampleIx], widths,
                         sampleValues + this->violationValueStart(checkIx, sampleIx),
                         emitReportRecord);
    }

    // Bottom rule.  Two trailing blank lines generate some vertical space
    // for the case of multiple failing checks.
    emitReportRecord(fmt::format("{}\n\n", rule));
}
// Convert the sampled point IDs of a single check to text.
//
// violation: Violation sample from which to read the point IDs.
// checkIx:   Index of the check in the battery.
//
// Returns a pair whose first member holds the formatted point IDs, in
// the order in which they are stored in the sample, and whose second
// member is the length of the longest formatted ID (zero if the check has
// no sampled points).
template
std::pair, std::string::size_type>
Opm::SatfuncConsistencyChecks::
formatPointIDs(const ViolationSample& violation,
const std::size_t checkIx) const
{
auto formattedPointIDs = std::pair
,
std::string::size_type>
{
std::piecewise_construct,
std::forward_as_tuple(),
std::forward_as_tuple(std::string::size_type{0})
};
const auto nPoints = this->numPoints(violation, checkIx);
formattedPointIDs.first.reserve(nPoints);
// First point ID slot of this check.
const auto* pointIDs = violation.pointID.data()
+ (checkIx * this->numSamplePoints_);
for (auto point = 0*nPoints; point < nPoints; ++point) {
formattedPointIDs.first.push_back
(this->formatPointID_(pointIDs[point]));
// Track the longest formatted ID seen so far.
formattedPointIDs.second =
std::max(formattedPointIDs.second,
formattedPointIDs.first.back().size());
}
return formattedPointIDs;
}
// Retrieve the value column headings of a single check.
//
// currentCheck: Check from which to retrieve the headings.
//
// Returns a sequence with one element for each value the check exports,
// filled in by the check's columnNames() member function.
template
std::vector
Opm::SatfuncConsistencyChecks::
collectColumnHeaders(const Check* currentCheck) const
{
auto headers = std::vector
(currentCheck->numExportedCheckValues());
currentCheck->columnNames(headers.data());
return headers;
}
// Compute the order in which to report the sampled points of a check.
//
// violation: Violation sample from which to read the point IDs.
// checkIx:   Index of the check in the battery.
//
// Returns the sample indices 0, 1, ..., numPoints-1 arranged such that
// the corresponding numeric point IDs are in non-decreasing order.
template
std::vector
Opm::SatfuncConsistencyChecks::
sortedPointIndices(const ViolationSample& violation,
const std::size_t checkIx) const
{
auto sortedIdxs = std::vector
(this->numPoints(violation, checkIx));
std::iota(sortedIdxs.begin(), sortedIdxs.end(), std::size_t{0});
// Compare sample indices by the point IDs they refer to.
std::sort(sortedIdxs.begin(), sortedIdxs.end(),
[pointIDs = violation.pointID.data() + (checkIx * this->numSamplePoints_)]
(const std::size_t i1, const std::size_t i2)
{
return pointIDs[i1] < pointIDs[i2];
});
return sortedIdxs;
}
// Number of sampled points stored for a single check.
//
// violation: Violation sample from which to read the violation count.
// checkIx:   Index of the check in the battery.
template
std::size_t
Opm::SatfuncConsistencyChecks::
numPoints(const ViolationSample& violation,
          const std::size_t checkIx) const
{
    const auto numViolations = violation.count[checkIx];

    return this->numPoints(numViolations);
}
// Number of sampled points that correspond to a given violation count.
//
// violationCount: Total number of violations of a check.
//
// Returns the violation count itself, capped at the number of sample
// points.
template
std::size_t
Opm::SatfuncConsistencyChecks::
numPoints(const std::size_t violationCount) const
{
    return (violationCount < this->numSamplePoints_)
        ? violationCount
        : this->numSamplePoints_;
}
// Draw a random index for the replacement stage of the sampling.
//
// sampleSize: Number of violations seen so far.  Must be positive.
//
// Returns an integer drawn from a uniform distribution on the closed
// range 0 .. sampleSize - 1.  The random bit generator is created on first
// use.
template
std::size_t
Opm::SatfuncConsistencyChecks::
getSampleIndex(const std::size_t sampleSize)
{
assert (sampleSize > 0);
this->ensureRandomBitGeneratorIsInitialised();
return std::uniform_int_distribution
{ 0, sampleSize - 1 }(*this->urbg_);
}
// Create the random bit generator unless it already exists.
//
// The generator is seeded through a std::seed_seq whose input words are
// drawn from a std::random_device.
template
void Opm::SatfuncConsistencyChecks::ensureRandomBitGeneratorIsInitialised()
{
if (this->urbg_ != nullptr) {
// Generator already exists.  Nothing to do.
return;
}
// Number of 32-bit words needed to cover the generator's modulus.
const auto k = static_cast
(std::log2(RandomBitGenerator::modulus) / 32) + 1;
// Seed state holds k + 3 words from the random device.
auto state = std::vector(k + 3);
std::random_device rd{};
std::generate(state.begin(), state.end(), std::ref(rd));
std::seed_seq seeds(state.begin(), state.end());
this->urbg_ = std::make_unique(seeds);
}
// Index of the first point ID slot of a single check.
//
// checkIx: Index of the check in the battery.
//
// Every check owns numSamplePoints_ consecutive point ID slots.
template
std::vector::size_type
Opm::SatfuncConsistencyChecks::
violationPointIDStart(const std::size_t checkIx) const
{
    const auto slotsPerCheck = this->numSamplePoints_;

    return checkIx * slotsPerCheck;
}
// Index of the first check value of a single sample entry.
//
// checkIx:  Index of the check in the battery.
// sampleIx: Index of the sample entry within the check.
//
// Returns the start offset of the check's values plus 'sampleIx' times
// the number of values which the check exports.
template
typename std::vector::size_type
Opm::SatfuncConsistencyChecks::
violationValueStart(const std::size_t checkIx,
                    const std::size_t sampleIx) const
{
    const auto checkOffset     = this->startCheckValues_[checkIx];
    const auto valuesPerSample = this->battery_[checkIx]->numExportedCheckValues();

    return checkOffset + (sampleIx * valuesPerSample);
}
// Whether any check has been violated at a given severity level.
//
// level: Severity level to inspect.
template
bool
Opm::SatfuncConsistencyChecks::
anyFailedChecks(const ViolationLevel level) const
{
    const auto& sample = this->violations_[this->index(level)];

    return ::anyFailedChecks(sample.count);
}
// Invoke a call-back for every check in the battery, in battery order.
//
// body: Call-back.  Receives a pointer to the check and the index of the
//       check in the battery.
template
template
void Opm::SatfuncConsistencyChecks::checkLoop(Body&& body)
{
    auto checkIx = 0*this->battery_.size();

    for (auto& check : this->battery_) {
        body(check.get(), checkIx);
        ++checkIx;
    }
}
// Invoke a call-back for every check in the battery, in battery order.
// Overload for non-mutable objects.
//
// body: Call-back.  Receives a pointer to the check and the index of the
//       check in the battery.
template
template
void Opm::SatfuncConsistencyChecks::checkLoop(Body&& body) const
{
    auto checkIx = 0*this->battery_.size();

    for (const auto& check : this->battery_) {
        body(check.get(), checkIx);
        ++checkIx;
    }
}
// ===========================================================================
// Explicit Instantiations of SatfuncConsistencyChecks Template
//
// No other code below this separator
// ===========================================================================
// NOTE(review): Both lines below read identically and carry no template
// argument list.  Presumably each names a distinct scalar type -- confirm
// against the class template's declaration.
template class Opm::SatfuncConsistencyChecks;
template class Opm::SatfuncConsistencyChecks;