From 42cdfa27a04057a7f9e2aa21f9373dd72f1b46c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Atgeirr=20Fl=C3=B8=20Rasmussen?= Date: Tue, 30 Apr 2013 11:05:59 +0200 Subject: [PATCH] Use named constructor pattern. Made AutoDiff::Forward constructor private. Added static methods constant(), variable(x) and function(x, dx). --- AutoDiff.hpp | 48 ++++++++++++++++++++++++++---------------------- find_zero.cpp | 5 +++-- test_ad.cpp | 10 ++++++---- test_syntax.cpp | 21 +++++++++++++-------- 4 files changed, 48 insertions(+), 36 deletions(-) diff --git a/AutoDiff.hpp b/AutoDiff.hpp index e1bb75fc6..c0b33eece 100644 --- a/AutoDiff.hpp +++ b/AutoDiff.hpp @@ -41,14 +41,22 @@ namespace AutoDiff { template class Forward { + private: public: - explicit Forward(const Scalar& x) - : val_(x), der_(Scalar(1)) - {} + static Forward constant(const Scalar x) + { + return Forward(x, Scalar(0)); + } - Forward(const Scalar x, const Scalar dx) - : val_(x), der_(dx) - {} + static Forward variable(const Scalar x) + { + return Forward(x, Scalar(1)); + } + + static Forward function(const Scalar x, const Scalar dx) + { + return Forward(x, dx); + } Forward& operator +=(const Scalar& rhs) @@ -133,11 +141,14 @@ namespace AutoDiff { const Scalar der() const { return der_; } private: + Forward(const Scalar x, const Scalar dx) + : val_(x), der_(dx) + {} Scalar val_ ; Scalar der_; - }; + template Ostream& operator<<(Ostream& os, const Forward& fw) @@ -194,9 +205,7 @@ namespace AutoDiff { operator -(const T lhs, const Forward& rhs) { - Forward ret(Scalar(lhs) - rhs.val(), -rhs.der()); - - return ret; + return Forward::function(Scalar(lhs) - rhs.val(), -rhs.der()); } template @@ -262,9 +271,7 @@ namespace AutoDiff { Scalar a = Scalar(lhs) / rhs.val(); Scalar b = -Scalar(lhs) / (rhs.val() * rhs.val()); - Forward ret(a, b); - - return ret; + return Forward::function(a, b); } template @@ -275,19 +282,17 @@ namespace AutoDiff { Scalar a = lhs.val() / Scalar(rhs); Scalar b = lhs.der() / Scalar(rhs); - Forward ret(a, b); - - return ret; + return Forward::function(a, b); } template Forward cos(const Forward& x) { - Forward ret( std::cos(x.val()), - -std::sin(x.val()) * x.der()); + Scalar a = std::cos(x.val()); + Scalar b = -std::sin(x.val()) * x.der(); - return ret; + return Forward::function(a, b); } template @@ -295,10 +300,9 @@ namespace AutoDiff { sqrt(const Forward& x) { Scalar a = std::sqrt(x.val()); - Scalar b = Scalar(1.0) / (Scalar(2.0) * a); - Forward ret(a, b * x.der()); + Scalar b = (Scalar(1.0) / (Scalar(2.0) * a)) * x.der(); - return ret; + return Forward::function(a, b); } } // namespace AutoDiff diff --git a/find_zero.cpp b/find_zero.cpp index fea87f60b..e428c5ebf 100644 --- a/find_zero.cpp +++ b/find_zero.cpp @@ -78,9 +78,10 @@ public: { double x = initial_guess; iterations_used = 0; + typedef AutoDiff::Forward AD; while (std::abs(f(x)) > tolerance && ++iterations_used < max_iter) { - AutoDiff::Forward xfad(x); - AutoDiff::Forward rfad = f(xfad); + AD xfad = AD::variable(x); + AD rfad = f(xfad); x = x - rfad.val()/rfad.der(); } return x; diff --git a/test_ad.cpp b/test_ad.cpp index 4b7a6f575..06f97505e 100644 --- a/test_ad.cpp +++ b/test_ad.cpp @@ -40,8 +40,10 @@ int main() { - AutoDiff::Forward a(0.0); - AutoDiff::Forward b(1.0); + typedef AutoDiff::Forward AdFW; + + AdFW a = AdFW::variable(0.0); + AdFW b = AdFW::variable(1.0); std::cout << "a: " << a << '\n'; std::cout << "b: " << b << '\n'; @@ -52,7 +54,7 @@ main() std::cout << "b: " << b << '\n'; std::cout << "a + b: " << a + b << '\n'; - a = AutoDiff::Forward(0.0); + a = AdFW::variable(0.0); std::cout << "a: " << a << '\n'; a += 1; @@ -60,7 +62,7 @@ main() std::cout << "a + 1: " << (a + 1) << '\n'; std::cout << "1 + a: " << (1.0f + a) << '\n'; - a = AutoDiff::Forward(1); + a = AdFW::variable(1); std::cout << "a: " << a << '\n'; std::cout << "a - 1: " << (a - 1) << '\n'; std::cout << "a - b: " << (a - b) << '\n'; diff --git a/test_syntax.cpp b/test_syntax.cpp index 48f9a946d..7886677f9 100644 --- a/test_syntax.cpp +++ b/test_syntax.cpp @@ -15,7 +15,8 @@ BOOST_AUTO_TEST_CASE(Initialisation) const double atol = 1.0e-14; - AdFW a(0.0), b(1.0); + AdFW a = AdFW::variable(0.0); + AdFW b = AdFW::variable(1.0); BOOST_CHECK_CLOSE(a.val(), 0.0, atol); BOOST_CHECK_CLOSE(b.val(), 1.0, atol); @@ -34,7 +35,8 @@ BOOST_AUTO_TEST_CASE(Addition) const double atol = 1.0e-14; - AdFW a(0.0), b(1.0); + AdFW a = AdFW::variable(0.0); + AdFW b = AdFW::variable(1.0); AdFW two_a = a + a; BOOST_CHECK_CLOSE(two_a.val(), 2*a.val(), atol); @@ -68,7 +70,8 @@ BOOST_AUTO_TEST_CASE(Subtraction) const double atol = 1.0e-14; - AdFW a(0.0), b(1.0); + AdFW a = AdFW::variable(0.0); + AdFW b = AdFW::variable(1.0); AdFW no_a = a - a; BOOST_CHECK_CLOSE(no_a.val(), 0.0, atol); @@ -106,7 +109,8 @@ BOOST_AUTO_TEST_CASE(Multiplication) const double atol = 1.0e-14; - AdFW a(1.0), b(1.0); + AdFW a = AdFW::variable(0.0); + AdFW b = AdFW::variable(1.0); AdFW no_a = a * 0; BOOST_CHECK_CLOSE(no_a.val(), 0.0, atol); @@ -144,7 +148,8 @@ BOOST_AUTO_TEST_CASE(Division) const double atol = 1.0e-14; - AdFW a(10.0), b(1.0); + AdFW a = AdFW::variable(10.0); + AdFW b = AdFW::variable(1.0); AdFW aob = a / b; BOOST_CHECK_CLOSE(aob.val(), a.val() * b.val(), atol); @@ -180,7 +185,7 @@ BOOST_AUTO_TEST_CASE(Polynomial) const double atol = 1.0e-14; - const AdFW x(1.234e-1); + const AdFW x = AdFW::variable(1.234e-1); const AdFW p0 = x * x; BOOST_CHECK_CLOSE(p0.val(), x.val() * x.val(), atol); @@ -198,7 +203,7 @@ BOOST_AUTO_TEST_CASE(Cosine) const double atol = 1.0e-14; - const AdFW x(3.14159265358979323846264338327950288); + const AdFW x = AdFW::variable(3.14159265358979323846264338327950288); const AdFW f = std::cos(x); BOOST_CHECK_CLOSE(f.val(), std::cos(x.val()), atol); @@ -217,7 +222,7 @@ BOOST_AUTO_TEST_CASE(SquareRoot) const double atol = 1.0e-14; - const AdFW x(1.234e-5); + const AdFW x = AdFW::variable(1.234e-5); const AdFW x2 = x * x; const AdFW g = std::cos(x2) + x;