mirror of
https://github.com/OPM/opm-simulators.git
synced 2025-02-25 18:55:30 -06:00
Use named constructor pattern.
Made AutoDiff::Forward constructor private. Added static methods constant(), variable(x) and function(x, dx).
This commit is contained in:
parent
0ff3b721ee
commit
42cdfa27a0
48
AutoDiff.hpp
48
AutoDiff.hpp
@ -41,14 +41,22 @@
|
||||
namespace AutoDiff {
|
||||
template <typename Scalar>
|
||||
class Forward {
|
||||
private:
|
||||
public:
|
||||
explicit Forward(const Scalar& x)
|
||||
: val_(x), der_(Scalar(1))
|
||||
{}
|
||||
static Forward constant(const Scalar x)
|
||||
{
|
||||
return Forward(x, Scalar(0));
|
||||
}
|
||||
|
||||
Forward(const Scalar x, const Scalar dx)
|
||||
: val_(x), der_(dx)
|
||||
{}
|
||||
static Forward variable(const Scalar x)
|
||||
{
|
||||
return Forward(x, Scalar(1));
|
||||
}
|
||||
|
||||
static Forward function(const Scalar x, const Scalar dx)
|
||||
{
|
||||
return Forward(x, dx);
|
||||
}
|
||||
|
||||
Forward&
|
||||
operator +=(const Scalar& rhs)
|
||||
@ -133,11 +141,14 @@ namespace AutoDiff {
|
||||
const Scalar der() const { return der_; }
|
||||
|
||||
private:
|
||||
Forward(const Scalar x, const Scalar dx)
|
||||
: val_(x), der_(dx)
|
||||
{}
|
||||
Scalar val_ ;
|
||||
Scalar der_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
template <class Ostream, typename Scalar>
|
||||
Ostream&
|
||||
operator<<(Ostream& os, const Forward<Scalar>& fw)
|
||||
@ -194,9 +205,7 @@ namespace AutoDiff {
|
||||
operator -(const T lhs,
|
||||
const Forward<Scalar>& rhs)
|
||||
{
|
||||
Forward<Scalar> ret(Scalar(lhs) - rhs.val(), -rhs.der());
|
||||
|
||||
return ret;
|
||||
return Forward<Scalar>::function(Scalar(lhs) - rhs.val(), -rhs.der());
|
||||
}
|
||||
|
||||
template <typename Scalar, typename T>
|
||||
@ -262,9 +271,7 @@ namespace AutoDiff {
|
||||
Scalar a = Scalar(lhs) / rhs.val();
|
||||
Scalar b = -Scalar(lhs) / (rhs.val() * rhs.val());
|
||||
|
||||
Forward<Scalar> ret(a, b);
|
||||
|
||||
return ret;
|
||||
return Forward<Scalar>::function(a, b);
|
||||
}
|
||||
|
||||
template <typename Scalar, typename T>
|
||||
@ -275,19 +282,17 @@ namespace AutoDiff {
|
||||
Scalar a = lhs.val() / Scalar(rhs);
|
||||
Scalar b = lhs.der() / Scalar(rhs);
|
||||
|
||||
Forward<Scalar> ret(a, b);
|
||||
|
||||
return ret;
|
||||
return Forward<Scalar>::function(a, b);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
Forward<Scalar>
|
||||
cos(const Forward<Scalar>& x)
|
||||
{
|
||||
Forward<Scalar> ret( std::cos(x.val()),
|
||||
-std::sin(x.val()) * x.der());
|
||||
Scalar a = std::cos(x.val());
|
||||
Scalar b = -std::sin(x.val()) * x.der();
|
||||
|
||||
return ret;
|
||||
return Forward<Scalar>::function(a, b);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
@ -295,10 +300,9 @@ namespace AutoDiff {
|
||||
sqrt(const Forward<Scalar>& x)
|
||||
{
|
||||
Scalar a = std::sqrt(x.val());
|
||||
Scalar b = Scalar(1.0) / (Scalar(2.0) * a);
|
||||
Forward<Scalar> ret(a, b * x.der());
|
||||
Scalar b = (Scalar(1.0) / (Scalar(2.0) * a)) * x.der();
|
||||
|
||||
return ret;
|
||||
return Forward<Scalar>::function(a, b);
|
||||
}
|
||||
} // namespace AutoDiff
|
||||
|
||||
|
@ -78,9 +78,10 @@ public:
|
||||
{
|
||||
double x = initial_guess;
|
||||
iterations_used = 0;
|
||||
typedef AutoDiff::Forward<double> AD;
|
||||
while (std::abs(f(x)) > tolerance && ++iterations_used < max_iter) {
|
||||
AutoDiff::Forward<double> xfad(x);
|
||||
AutoDiff::Forward<double> rfad = f(xfad);
|
||||
AD xfad = AD::variable(x);
|
||||
AD rfad = f(xfad);
|
||||
x = x - rfad.val()/rfad.der();
|
||||
}
|
||||
return x;
|
||||
|
10
test_ad.cpp
10
test_ad.cpp
@ -40,8 +40,10 @@
|
||||
int
|
||||
main()
|
||||
{
|
||||
AutoDiff::Forward<double> a(0.0);
|
||||
AutoDiff::Forward<double> b(1.0);
|
||||
typedef AutoDiff::Forward<double> AdFW;
|
||||
|
||||
AdFW a = AdFW::variable(0.0);
|
||||
AdFW b = AdFW::variable(1.0);
|
||||
|
||||
std::cout << "a: " << a << '\n';
|
||||
std::cout << "b: " << b << '\n';
|
||||
@ -52,7 +54,7 @@ main()
|
||||
std::cout << "b: " << b << '\n';
|
||||
std::cout << "a + b: " << a + b << '\n';
|
||||
|
||||
a = AutoDiff::Forward<double>(0.0);
|
||||
a = AdFW::variable(0.0);
|
||||
std::cout << "a: " << a << '\n';
|
||||
|
||||
a += 1;
|
||||
@ -60,7 +62,7 @@ main()
|
||||
std::cout << "a + 1: " << (a + 1) << '\n';
|
||||
std::cout << "1 + a: " << (1.0f + a) << '\n';
|
||||
|
||||
a = AutoDiff::Forward<double>(1);
|
||||
a = AdFW::variable(1);
|
||||
std::cout << "a: " << a << '\n';
|
||||
std::cout << "a - 1: " << (a - 1) << '\n';
|
||||
std::cout << "a - b: " << (a - b) << '\n';
|
||||
|
@ -15,7 +15,8 @@ BOOST_AUTO_TEST_CASE(Initialisation)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
AdFW a(0.0), b(1.0);
|
||||
AdFW a = AdFW::variable(0.0);
|
||||
AdFW b = AdFW::variable(1.0);
|
||||
|
||||
BOOST_CHECK_CLOSE(a.val(), 0.0, atol);
|
||||
BOOST_CHECK_CLOSE(b.val(), 1.0, atol);
|
||||
@ -34,7 +35,8 @@ BOOST_AUTO_TEST_CASE(Addition)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
AdFW a(0.0), b(1.0);
|
||||
AdFW a = AdFW::variable(0.0);
|
||||
AdFW b = AdFW::variable(1.0);
|
||||
|
||||
AdFW two_a = a + a;
|
||||
BOOST_CHECK_CLOSE(two_a.val(), 2*a.val(), atol);
|
||||
@ -68,7 +70,8 @@ BOOST_AUTO_TEST_CASE(Subtraction)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
AdFW a(0.0), b(1.0);
|
||||
AdFW a = AdFW::variable(0.0);
|
||||
AdFW b = AdFW::variable(1.0);
|
||||
|
||||
AdFW no_a = a - a;
|
||||
BOOST_CHECK_CLOSE(no_a.val(), 0.0, atol);
|
||||
@ -106,7 +109,8 @@ BOOST_AUTO_TEST_CASE(Multiplication)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
AdFW a(1.0), b(1.0);
|
||||
AdFW a = AdFW::variable(0.0);
|
||||
AdFW b = AdFW::variable(1.0);
|
||||
|
||||
AdFW no_a = a * 0;
|
||||
BOOST_CHECK_CLOSE(no_a.val(), 0.0, atol);
|
||||
@ -144,7 +148,8 @@ BOOST_AUTO_TEST_CASE(Division)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
AdFW a(10.0), b(1.0);
|
||||
AdFW a = AdFW::variable(10.0);
|
||||
AdFW b = AdFW::variable(1.0);
|
||||
|
||||
AdFW aob = a / b;
|
||||
BOOST_CHECK_CLOSE(aob.val(), a.val() * b.val(), atol);
|
||||
@ -180,7 +185,7 @@ BOOST_AUTO_TEST_CASE(Polynomial)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
const AdFW x(1.234e-1);
|
||||
const AdFW x = AdFW::variable(1.234e-1);
|
||||
|
||||
const AdFW p0 = x * x;
|
||||
BOOST_CHECK_CLOSE(p0.val(), x.val() * x.val(), atol);
|
||||
@ -198,7 +203,7 @@ BOOST_AUTO_TEST_CASE(Cosine)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
const AdFW x(3.14159265358979323846264338327950288);
|
||||
const AdFW x = AdFW::variable(3.14159265358979323846264338327950288);
|
||||
|
||||
const AdFW f = std::cos(x);
|
||||
BOOST_CHECK_CLOSE(f.val(), std::cos(x.val()), atol);
|
||||
@ -217,7 +222,7 @@ BOOST_AUTO_TEST_CASE(SquareRoot)
|
||||
|
||||
const double atol = 1.0e-14;
|
||||
|
||||
const AdFW x(1.234e-5);
|
||||
const AdFW x = AdFW::variable(1.234e-5);
|
||||
|
||||
const AdFW x2 = x * x;
|
||||
const AdFW g = std::cos(x2) + x;
|
||||
|
Loading…
Reference in New Issue
Block a user