Use 'if constexpr' to simplify MultiRate specialization

This commit is contained in:
Ray Speth 2023-03-05 18:02:42 -05:00 committed by Ingmar Schoegl
parent 56c3a077f7
commit 184f4ddccf
3 changed files with 89 additions and 181 deletions

View File

@ -187,17 +187,13 @@ const U& getValue(const std::map<T, U>& m, const T& key, const U& default_val) {
}
//! A macro for generating member function detectors, which can then be used in
//! combination with `std::enable_if` to allow selection of a particular template
//! specialization based on the presence of that member function. See MultiRate for
//! examples of use.
#define CT_DEFINE_HAS_MEMBER(detector_name, func_name) \
template <typename T> \
struct detector_name { \
typedef char (& yes)[1], (& no)[2]; \
template <typename C> static yes check(decltype(&C::func_name)); \
template <typename> static no check(...); \
static bool const value = sizeof(check<T>(0)) == sizeof(yes); \
};
//! combination with `if constexpr` to condition behavior on the availability of that
//! member function. See MultiRate for examples of use.
#define CT_DEFINE_HAS_MEMBER(detector_name, func_name) \
template<class T, class=void> \
struct detector_name : std::false_type {}; \
template<class T> \
struct detector_name<T, std::void_t<decltype(&T::func_name)>> : std::true_type {};
}

View File

@ -402,7 +402,9 @@ public:
//! Update reaction rate parameters
//! @param shared_data data shared by all reactions of a given type
void updateFromStruct(const DataType& shared_data) {
_update(shared_data);
if constexpr (has_update<RateType>::value) {
RateType::updateFromStruct(shared_data);
}
InterfaceRateBase::updateFromStruct(shared_data);
}
@ -432,22 +434,6 @@ public:
virtual double activationEnergy() const override {
return RateType::activationEnergy() + m_ecov * GasConstant;
}
protected:
//! Helper function to process updates for rate types that implement the
//! `updateFromStruct` method.
template <typename T=RateType,
typename std::enable_if<has_update<T>::value, bool>::type = true>
void _update(const DataType& shared_data) {
T::updateFromStruct(shared_data);
}
//! Helper function for rate types that do not implement `updateFromStruct`.
//! Does nothing, but exists to allow generic implementations of update().
template <typename T=RateType,
typename std::enable_if<!has_update<T>::value, bool>::type = true>
void _update(const DataType& shared_data) {
}
};
using InterfaceArrheniusRate = InterfaceRate<ArrheniusRate, InterfaceData>;
@ -539,7 +525,9 @@ public:
//! Update reaction rate parameters
//! @param shared_data data shared by all reactions of a given type
void updateFromStruct(const DataType& shared_data) {
_update(shared_data);
if constexpr (has_update<RateType>::value) {
RateType::updateFromStruct(shared_data);
}
InterfaceRateBase::updateFromStruct(shared_data);
m_factor = pow(m_siteDensity, -m_surfaceOrder);
}
@ -575,22 +563,6 @@ public:
virtual double activationEnergy() const override {
return RateType::activationEnergy() + m_ecov * GasConstant;
}
protected:
//! Helper function to process updates for rate types that implement the
//! `updateFromStruct` method.
template <typename T=RateType,
typename std::enable_if<has_update<T>::value, bool>::type = true>
void _update(const DataType& shared_data) {
T::updateFromStruct(shared_data);
}
//! Helper function for rate types that do not implement `updateFromStruct`.
//! Does nothing, but exists to allow generic implementations of update().
template <typename T=RateType,
typename std::enable_if<!has_update<T>::value, bool>::type = true>
void _update(const DataType& shared_data) {
}
};
using StickingArrheniusRate = StickingRate<ArrheniusRate, InterfaceData>;

View File

@ -74,18 +74,54 @@ public:
const double* kf,
double deltaT) override
{
// call helper function: implementation of derivative depends on whether
// ReactionRate::ddTFromStruct is defined
_process_ddT(rop, kf, deltaT);
if constexpr (has_ddT<RateType>::value) {
for (const auto& [iRxn, rate] : m_rxn_rates) {
rop[iRxn] *= rate.ddTScaledFromStruct(m_shared);
}
} else {
// perturb conditions
double dTinv = 1. / (m_shared.temperature * deltaT);
m_shared.perturbTemperature(deltaT);
_update();
// apply numerical derivative
for (auto& [iRxn, rate] : m_rxn_rates) {
if (kf[iRxn] != 0.) {
double k1 = rate.evalFromStruct(m_shared);
rop[iRxn] *= dTinv * (k1 / kf[iRxn] - 1.);
} // else not needed: derivative is already zero
}
// revert changes
m_shared.restore();
_update();
}
}
virtual void processRateConstants_ddP(double* rop,
const double* kf,
double deltaP) override
{
// call helper function: implementation of derivative depends on whether
// ReactionData::perturbPressure is defined
_process_ddP(rop, kf, deltaP);
if constexpr (has_ddP<DataType>::value) {
double dPinv = 1. / (m_shared.pressure * deltaP);
m_shared.perturbPressure(deltaP);
_update();
for (auto& [iRxn, rate] : m_rxn_rates) {
if (kf[iRxn] != 0.) {
double k1 = rate.evalFromStruct(m_shared);
rop[iRxn] *= dPinv * (k1 / kf[iRxn] - 1.);
} // else not needed: derivative is already zero
}
// revert changes
m_shared.restore();
_update();
} else {
for (const auto& [iRxn, rate] : m_rxn_rates) {
rop[iRxn] = 0.;
}
}
}
virtual void processRateConstants_ddM(double* rop,
@ -93,9 +129,33 @@ public:
double deltaM,
bool overwrite=true) override
{
// call helper function: implementation of derivative depends on whether
// ReactionRate::thirdBodyConcentration is defined
_process_ddM(rop, kf, deltaM, overwrite);
if constexpr (has_ddM<DataType>::value) {
double dMinv = 1. / deltaM;
m_shared.perturbThirdBodies(deltaM);
_update();
for (auto& [iRxn, rate] : m_rxn_rates) {
if (kf[iRxn] != 0. && m_shared.conc_3b[iRxn] > 0.) {
double k1 = rate.evalFromStruct(m_shared);
rop[iRxn] *= dMinv * (k1 / kf[iRxn] - 1.);
rop[iRxn] /= m_shared.conc_3b[iRxn];
} else {
rop[iRxn] = 0.;
}
}
// revert changes
m_shared.restore();
_update();
} else {
if (!overwrite) {
// do not overwrite existing entries
return;
}
for (const auto& [iRxn, rate] : m_rxn_rates) {
rop[iRxn] = 0.;
}
}
}
virtual void update(double T) override {
@ -125,7 +185,9 @@ public:
virtual double evalSingle(ReactionRate& rate) override {
RateType& R = static_cast<RateType&>(rate);
_updateRate(R);
if constexpr (has_update<RateType>::value) {
R.updateFromStruct(m_shared);
}
return R.evalFromStruct(m_shared);
}
@ -136,135 +198,13 @@ public:
}
protected:
//! Helper function to process updates for rate types that implement the
//! `updateFromStruct` method.
template <typename T=RateType,
typename std::enable_if<has_update<T>::value, bool>::type = true>
//! Helper function to process updates
void _update() {
for (auto& rxn : m_rxn_rates) {
rxn.second.updateFromStruct(m_shared);
}
}
//! Helper function for rate types that do not implement `updateFromStruct`.
//! Does nothing, but exists to allow generic implementations of update().
template <typename T=RateType,
typename std::enable_if<!has_update<T>::value, bool>::type = true>
void _update() {
}
//! Helper function to update a single rate that has an `updateFromStruct` method`.
template <typename T=RateType,
typename std::enable_if<has_update<T>::value, bool>::type = true>
void _updateRate(RateType& rate) {
rate.updateFromStruct(m_shared);
}
//! Helper function for single rate that does not implement `updateFromStruct`.
//! Exists to allow generic implementations of `evalSingle` and `ddTSingle`.
template <typename T=RateType,
typename std::enable_if<!has_update<T>::value, bool>::type = true>
void _updateRate(RateType& rate) {
}
//! Helper function to process temperature derivatives for rate types that
//! implement the `ddTScaledFromStruct` method.
template <typename T=RateType,
typename std::enable_if<has_ddT<T>::value, bool>::type = true>
void _process_ddT(double* rop, const double* kf, double deltaT) {
for (const auto& [iRxn, rate] : m_rxn_rates) {
rop[iRxn] *= rate.ddTScaledFromStruct(m_shared);
}
}
//! Helper function for rate types that do not implement `ddTScaledFromStruct`
template <typename T=RateType,
typename std::enable_if<!has_ddT<T>::value, bool>::type = true>
void _process_ddT(double* rop, const double* kf, double deltaT) {
// perturb conditions
double dTinv = 1. / (m_shared.temperature * deltaT);
m_shared.perturbTemperature(deltaT);
_update();
// apply numerical derivative
for (auto& [iRxn, rate] : m_rxn_rates) {
if (kf[iRxn] != 0.) {
double k1 = rate.evalFromStruct(m_shared);
rop[iRxn] *= dTinv * (k1 / kf[iRxn] - 1.);
} // else not needed: derivative is already zero
}
// revert changes
m_shared.restore();
_update();
}
//! Helper function to process third-body derivatives for rate data that
//! implement the `perturbThirdBodies` method.
template <typename T=RateType, typename D=DataType,
typename std::enable_if<has_ddM<D>::value, bool>::type = true>
void _process_ddM(double* rop, const double* kf, double deltaM, bool overwrite) {
double dMinv = 1. / deltaM;
m_shared.perturbThirdBodies(deltaM);
_update();
for (auto& [iRxn, rate] : m_rxn_rates) {
if (kf[iRxn] != 0. && m_shared.conc_3b[iRxn] > 0.) {
double k1 = rate.evalFromStruct(m_shared);
rop[iRxn] *= dMinv * (k1 / kf[iRxn] - 1.);
rop[iRxn] /= m_shared.conc_3b[iRxn];
} else {
rop[iRxn] = 0.;
if constexpr (has_update<RateType>::value) {
for (auto& [i, rxn] : m_rxn_rates) {
rxn.updateFromStruct(m_shared);
}
}
// revert changes
m_shared.restore();
_update();
}
//! Helper function for rate data that do not implement `perturbThirdBodies`
template <typename T=RateType, typename D=DataType,
typename std::enable_if<!has_ddM<D>::value, bool>::type = true>
void _process_ddM(double* rop, const double* kf, double deltaM, bool overwrite) {
if (!overwrite) {
// do not overwrite existing entries
return;
}
for (const auto& [iRxn, rate] : m_rxn_rates) {
rop[iRxn] = 0.;
}
}
//! Helper function to process pressure derivatives for rate data that
//! implement the `perturbPressure` method.
template <typename T=RateType, typename D=DataType,
typename std::enable_if<has_ddP<D>::value, bool>::type = true>
void _process_ddP(double* rop, const double* kf, double deltaP) {
double dPinv = 1. / (m_shared.pressure * deltaP);
m_shared.perturbPressure(deltaP);
_update();
for (auto& [iRxn, rate] : m_rxn_rates) {
if (kf[iRxn] != 0.) {
double k1 = rate.evalFromStruct(m_shared);
rop[iRxn] *= dPinv * (k1 / kf[iRxn] - 1.);
} // else not needed: derivative is already zero
}
// revert changes
m_shared.restore();
_update();
}
//! Helper function for rate data that do not implement `perturbPressure`
template <typename T=RateType, typename D=DataType,
typename std::enable_if<!has_ddP<D>::value, bool>::type = true>
void _process_ddP(double* rop, const double* kf, double deltaP) {
for (const auto& [iRxn, rate] : m_rxn_rates) {
rop[iRxn] = 0.;
}
}
//! Vector of pairs of reaction rates indices and reaction rates