ALQState: template Scalar type

This commit is contained in:
Arne Morten Kvarving 2024-02-17 18:13:46 +01:00
parent ab0e696709
commit 58f334b264
6 changed files with 66 additions and 44 deletions

View File

@ -28,7 +28,8 @@
namespace Opm { namespace Opm {
ALQState ALQState::serializationTestObject() template<class Scalar>
ALQState<Scalar> ALQState<Scalar>::serializationTestObject()
{ {
ALQState result; ALQState result;
result.current_alq_ = {{"test1", 1.0}}; result.current_alq_ = {{"test1", 1.0}};
@ -40,7 +41,9 @@ ALQState ALQState::serializationTestObject()
return result; return result;
} }
double ALQState::get(const std::string& wname) const { template<class Scalar>
Scalar ALQState<Scalar>::get(const std::string& wname) const
{
auto iter = this->current_alq_.find(wname); auto iter = this->current_alq_.find(wname);
if (iter != this->current_alq_.end()) if (iter != this->current_alq_.end())
return iter->second; return iter->second;
@ -52,7 +55,9 @@ double ALQState::get(const std::string& wname) const {
throw std::logic_error("No ALQ value registered for well: " + wname); throw std::logic_error("No ALQ value registered for well: " + wname);
} }
void ALQState::update_default(const std::string& wname, double value) { template<class Scalar>
void ALQState<Scalar>::update_default(const std::string& wname, Scalar value)
{
auto default_iter = this->default_alq_.find(wname); auto default_iter = this->default_alq_.find(wname);
if (default_iter == this->default_alq_.end() || default_iter->second != value) { if (default_iter == this->default_alq_.end() || default_iter->second != value) {
this->default_alq_.insert_or_assign(wname, value); this->default_alq_.insert_or_assign(wname, value);
@ -60,20 +65,28 @@ void ALQState::update_default(const std::string& wname, double value) {
} }
} }
void ALQState::set(const std::string& wname, double value) { template<class Scalar>
void ALQState<Scalar>::set(const std::string& wname, Scalar value)
{
this->current_alq_[wname] = value; this->current_alq_[wname] = value;
} }
int ALQState::get_debug_counter() { template<class Scalar>
int ALQState<Scalar>::get_debug_counter()
{
return this->debug_counter_; return this->debug_counter_;
} }
int ALQState::update_debug_counter() { template<class Scalar>
int ALQState<Scalar>::update_debug_counter()
{
this->debug_counter_++; this->debug_counter_++;
return this->debug_counter_; return this->debug_counter_;
} }
void ALQState::set_debug_counter(int value) { template<class Scalar>
void ALQState<Scalar>::set_debug_counter(int value)
{
this->debug_counter_ = value; this->debug_counter_ = value;
} }
@ -88,7 +101,9 @@ int get_counter(const std::map<std::string, int>& count_map, const std::string&
} }
bool ALQState::oscillation(const std::string& wname) const { template<class Scalar>
bool ALQState<Scalar>::oscillation(const std::string& wname) const
{
auto inc_count = get_counter(this->alq_increase_count_, wname); auto inc_count = get_counter(this->alq_increase_count_, wname);
if (inc_count == 0) if (inc_count == 0)
return false; return false;
@ -97,35 +112,43 @@ bool ALQState::oscillation(const std::string& wname) const {
return dec_count >= 1; return dec_count >= 1;
} }
template<class Scalar>
void ALQState::update_count(const std::string& wname, bool increase) { void ALQState<Scalar>::update_count(const std::string& wname, bool increase)
{
if (increase) if (increase)
this->alq_increase_count_[wname] += 1; this->alq_increase_count_[wname] += 1;
else else
this->alq_decrease_count_[wname] += 1; this->alq_decrease_count_[wname] += 1;
} }
template<class Scalar>
void ALQState::reset_count() { void ALQState<Scalar>::reset_count()
{
this->alq_decrease_count_.clear(); this->alq_decrease_count_.clear();
this->alq_increase_count_.clear(); this->alq_increase_count_.clear();
} }
template<class Scalar>
int ALQState::get_increment_count(const std::string& wname) const { int ALQState<Scalar>::get_increment_count(const std::string& wname) const
{
return get_counter(this->alq_increase_count_, wname); return get_counter(this->alq_increase_count_, wname);
} }
int ALQState::get_decrement_count(const std::string& wname) const { template<class Scalar>
int ALQState<Scalar>::get_decrement_count(const std::string& wname) const
{
return get_counter(this->alq_decrease_count_, wname); return get_counter(this->alq_decrease_count_, wname);
} }
std::size_t ALQState::pack_size() const { template<class Scalar>
std::size_t ALQState<Scalar>::pack_size() const
{
return this->current_alq_.size(); return this->current_alq_.size();
} }
std::size_t ALQState::pack_data(double * data) const { template<class Scalar>
std::size_t ALQState<Scalar>::pack_data(Scalar* data) const
{
std::size_t index = 0; std::size_t index = 0;
for (const auto& [_, value] : this->current_alq_) { for (const auto& [_, value] : this->current_alq_) {
(void)_; (void)_;
@ -134,7 +157,9 @@ std::size_t ALQState::pack_data(double * data) const {
return index; return index;
} }
std::size_t ALQState::unpack_data(const double * data) { template<class Scalar>
std::size_t ALQState<Scalar>::unpack_data(const Scalar* data)
{
std::size_t index = 0; std::size_t index = 0;
for (auto& [_, value] : this->current_alq_) { for (auto& [_, value] : this->current_alq_) {
(void)_; (void)_;
@ -143,7 +168,8 @@ std::size_t ALQState::unpack_data(const double * data) {
return index; return index;
} }
bool ALQState::operator==(const ALQState& rhs) const template<class Scalar>
bool ALQState<Scalar>::operator==(const ALQState& rhs) const
{ {
return this->current_alq_ == rhs.current_alq_ && return this->current_alq_ == rhs.current_alq_ &&
this->default_alq_ == rhs.default_alq_ && this->default_alq_ == rhs.default_alq_ &&
@ -152,7 +178,6 @@ bool ALQState::operator==(const ALQState& rhs) const
this->debug_counter_ == rhs.debug_counter_; this->debug_counter_ == rhs.debug_counter_;
} }
template class ALQState<double>;
} }

View File

@ -22,22 +22,22 @@
#include <map> #include <map>
#include <string> #include <string>
#include <vector>
namespace Opm { namespace Opm {
class ALQState { template<class Scalar>
class ALQState
{
public: public:
static ALQState serializationTestObject(); static ALQState serializationTestObject();
std::size_t pack_size() const; std::size_t pack_size() const;
std::size_t unpack_data(const double * data); std::size_t unpack_data(const Scalar* data);
std::size_t pack_data(double * data) const; std::size_t pack_data(Scalar* data) const;
double get(const std::string& wname) const; Scalar get(const std::string& wname) const;
void update_default(const std::string& wname, double value); void update_default(const std::string& wname, Scalar value);
void set(const std::string& wname, double value); void set(const std::string& wname, Scalar value);
bool oscillation(const std::string& wname) const; bool oscillation(const std::string& wname) const;
void update_count(const std::string& wname, bool increase); void update_count(const std::string& wname, bool increase);
void reset_count(); void reset_count();
@ -60,14 +60,13 @@ public:
bool operator==(const ALQState&) const; bool operator==(const ALQState&) const;
private: private:
std::map<std::string, double> current_alq_; std::map<std::string, Scalar> current_alq_;
std::map<std::string, double> default_alq_; std::map<std::string, Scalar> default_alq_;
std::map<std::string, int> alq_increase_count_; std::map<std::string, int> alq_increase_count_;
std::map<std::string, int> alq_decrease_count_; std::map<std::string, int> alq_decrease_count_;
int debug_counter_ = 0; int debug_counter_ = 0;
}; };
} }
#endif #endif

View File

@ -136,7 +136,7 @@ WellState::WellState(const ParallelWellInfo& pinfo)
WellState WellState::serializationTestObject(const ParallelWellInfo& pinfo) WellState WellState::serializationTestObject(const ParallelWellInfo& pinfo)
{ {
WellState result(PhaseUsage{}); WellState result(PhaseUsage{});
result.alq_state = ALQState::serializationTestObject(); result.alq_state = ALQState<double>::serializationTestObject();
result.well_rates = {{"test2", {true, {1.0}}}, {"test3", {false, {2.0}}}}; result.well_rates = {{"test2", {true, {1.0}}}, {"test3", {false, {2.0}}}};
result.wells_.add("test4", SingleWellState<double>::serializationTestObject(pinfo)); result.wells_.add("test4", SingleWellState<double>::serializationTestObject(pinfo));

View File

@ -331,7 +331,7 @@ private:
// WellStateFullyImplicitBlackoil class should be default constructible, // WellStateFullyImplicitBlackoil class should be default constructible,
// whereas the GlobalWellInfo is not. // whereas the GlobalWellInfo is not.
std::optional<GlobalWellInfo> global_well_info; std::optional<GlobalWellInfo> global_well_info;
ALQState alq_state; ALQState<double> alq_state;
// The well_rates variable is defined for all wells on all processors. The // The well_rates variable is defined for all wells on all processors. The
// bool in the value pair is whether the current process owns the well or // bool in the value pair is whether the current process owns the well or

View File

@ -21,20 +21,17 @@
#include "config.h" #include "config.h"
#endif // HAVE_CONFIG_H #endif // HAVE_CONFIG_H
#include <stdexcept> #include <exception>
#include <opm/simulators/wells/ALQState.hpp> #include <opm/simulators/wells/ALQState.hpp>
#define BOOST_TEST_MODULE GroupStateTest #define BOOST_TEST_MODULE GroupStateTest
#include <boost/test/unit_test.hpp> #include <boost/test/unit_test.hpp>
using namespace Opm; using namespace Opm;
BOOST_AUTO_TEST_CASE(ALQStateCreate)
{
BOOST_AUTO_TEST_CASE(ALQStateCreate) { ALQState<double> alq_state;
ALQState alq_state;
alq_state.update_default("W1", 100); alq_state.update_default("W1", 100);
alq_state.update_default("W2", 200); alq_state.update_default("W2", 200);

View File

@ -101,7 +101,8 @@ BOOST_AUTO_TEST_CASE(NAME) \
#define TEST_FOR_TYPE(TYPE) \ #define TEST_FOR_TYPE(TYPE) \
TEST_FOR_TYPE_NAMED(TYPE, TYPE) TEST_FOR_TYPE_NAMED(TYPE, TYPE)
TEST_FOR_TYPE(ALQState) namespace Opm { using ALQS = ALQState<double>; }
TEST_FOR_TYPE_NAMED(ALQS, ALQState)
TEST_FOR_TYPE(GroupState) TEST_FOR_TYPE(GroupState)
TEST_FOR_TYPE(HardcodedTimeStepControl) TEST_FOR_TYPE(HardcodedTimeStepControl)
TEST_FOR_TYPE(Inplace) TEST_FOR_TYPE(Inplace)