Use named constructor pattern.

Made AutoDiff::Forward constructor private.
Added static methods constant(), variable(x) and function(x, dx).
This commit is contained in:
Atgeirr Flø Rasmussen 2013-04-30 11:05:59 +02:00
parent 0ff3b721ee
commit 42cdfa27a0
4 changed files with 48 additions and 36 deletions

View File

@ -41,14 +41,22 @@
namespace AutoDiff {
template <typename Scalar>
class Forward {
private:
public:
explicit Forward(const Scalar& x)
: val_(x), der_(Scalar(1))
{}
static Forward constant(const Scalar x)
{
return Forward(x, Scalar(0));
}
Forward(const Scalar x, const Scalar dx)
: val_(x), der_(dx)
{}
static Forward variable(const Scalar x)
{
return Forward(x, Scalar(1));
}
static Forward function(const Scalar x, const Scalar dx)
{
return Forward(x, dx);
}
Forward&
operator +=(const Scalar& rhs)
@ -133,11 +141,14 @@ namespace AutoDiff {
const Scalar der() const { return der_; }
private:
Forward(const Scalar x, const Scalar dx)
: val_(x), der_(dx)
{}
Scalar val_ ;
Scalar der_;
};
template <class Ostream, typename Scalar>
Ostream&
operator<<(Ostream& os, const Forward<Scalar>& fw)
@ -194,9 +205,7 @@ namespace AutoDiff {
operator -(const T lhs,
const Forward<Scalar>& rhs)
{
Forward<Scalar> ret(Scalar(lhs) - rhs.val(), -rhs.der());
return ret;
return Forward<Scalar>::function(Scalar(lhs) - rhs.val(), -rhs.der());
}
template <typename Scalar, typename T>
@ -262,9 +271,7 @@ namespace AutoDiff {
Scalar a = Scalar(lhs) / rhs.val();
Scalar b = -Scalar(lhs) / (rhs.val() * rhs.val());
Forward<Scalar> ret(a, b);
return ret;
return Forward<Scalar>::function(a, b);
}
template <typename Scalar, typename T>
@ -275,19 +282,17 @@ namespace AutoDiff {
Scalar a = lhs.val() / Scalar(rhs);
Scalar b = lhs.der() / Scalar(rhs);
Forward<Scalar> ret(a, b);
return ret;
return Forward<Scalar>::function(a, b);
}
template <typename Scalar>
Forward<Scalar>
cos(const Forward<Scalar>& x)
{
Forward<Scalar> ret( std::cos(x.val()),
-std::sin(x.val()) * x.der());
Scalar a = std::cos(x.val());
Scalar b = -std::sin(x.val()) * x.der();
return ret;
return Forward<Scalar>::function(a, b);
}
template <typename Scalar>
@ -295,10 +300,9 @@ namespace AutoDiff {
sqrt(const Forward<Scalar>& x)
{
Scalar a = std::sqrt(x.val());
Scalar b = Scalar(1.0) / (Scalar(2.0) * a);
Forward<Scalar> ret(a, b * x.der());
Scalar b = (Scalar(1.0) / (Scalar(2.0) * a)) * x.der();
return ret;
return Forward<Scalar>::function(a, b);
}
} // namespace AutoDiff

View File

@ -78,9 +78,10 @@ public:
{
double x = initial_guess;
iterations_used = 0;
typedef AutoDiff::Forward<double> AD;
while (std::abs(f(x)) > tolerance && ++iterations_used < max_iter) {
AutoDiff::Forward<double> xfad(x);
AutoDiff::Forward<double> rfad = f(xfad);
AD xfad = AD::variable(x);
AD rfad = f(xfad);
x = x - rfad.val()/rfad.der();
}
return x;

View File

@ -40,8 +40,10 @@
int
main()
{
AutoDiff::Forward<double> a(0.0);
AutoDiff::Forward<double> b(1.0);
typedef AutoDiff::Forward<double> AdFW;
AdFW a = AdFW::variable(0.0);
AdFW b = AdFW::variable(1.0);
std::cout << "a: " << a << '\n';
std::cout << "b: " << b << '\n';
@ -52,7 +54,7 @@ main()
std::cout << "b: " << b << '\n';
std::cout << "a + b: " << a + b << '\n';
a = AutoDiff::Forward<double>(0.0);
a = AdFW::variable(0.0);
std::cout << "a: " << a << '\n';
a += 1;
@ -60,7 +62,7 @@ main()
std::cout << "a + 1: " << (a + 1) << '\n';
std::cout << "1 + a: " << (1.0f + a) << '\n';
a = AutoDiff::Forward<double>(1);
a = AdFW::variable(1);
std::cout << "a: " << a << '\n';
std::cout << "a - 1: " << (a - 1) << '\n';
std::cout << "a - b: " << (a - b) << '\n';

View File

@ -15,7 +15,8 @@ BOOST_AUTO_TEST_CASE(Initialisation)
const double atol = 1.0e-14;
AdFW a(0.0), b(1.0);
AdFW a = AdFW::variable(0.0);
AdFW b = AdFW::variable(1.0);
BOOST_CHECK_CLOSE(a.val(), 0.0, atol);
BOOST_CHECK_CLOSE(b.val(), 1.0, atol);
@ -34,7 +35,8 @@ BOOST_AUTO_TEST_CASE(Addition)
const double atol = 1.0e-14;
AdFW a(0.0), b(1.0);
AdFW a = AdFW::variable(0.0);
AdFW b = AdFW::variable(1.0);
AdFW two_a = a + a;
BOOST_CHECK_CLOSE(two_a.val(), 2*a.val(), atol);
@ -68,7 +70,8 @@ BOOST_AUTO_TEST_CASE(Subtraction)
const double atol = 1.0e-14;
AdFW a(0.0), b(1.0);
AdFW a = AdFW::variable(0.0);
AdFW b = AdFW::variable(1.0);
AdFW no_a = a - a;
BOOST_CHECK_CLOSE(no_a.val(), 0.0, atol);
@ -106,7 +109,8 @@ BOOST_AUTO_TEST_CASE(Multiplication)
const double atol = 1.0e-14;
AdFW a(1.0), b(1.0);
AdFW a = AdFW::variable(0.0);
AdFW b = AdFW::variable(1.0);
AdFW no_a = a * 0;
BOOST_CHECK_CLOSE(no_a.val(), 0.0, atol);
@ -144,7 +148,8 @@ BOOST_AUTO_TEST_CASE(Division)
const double atol = 1.0e-14;
AdFW a(10.0), b(1.0);
AdFW a = AdFW::variable(10.0);
AdFW b = AdFW::variable(1.0);
AdFW aob = a / b;
BOOST_CHECK_CLOSE(aob.val(), a.val() * b.val(), atol);
@ -180,7 +185,7 @@ BOOST_AUTO_TEST_CASE(Polynomial)
const double atol = 1.0e-14;
const AdFW x(1.234e-1);
const AdFW x = AdFW::variable(1.234e-1);
const AdFW p0 = x * x;
BOOST_CHECK_CLOSE(p0.val(), x.val() * x.val(), atol);
@ -198,7 +203,7 @@ BOOST_AUTO_TEST_CASE(Cosine)
const double atol = 1.0e-14;
const AdFW x(3.14159265358979323846264338327950288);
const AdFW x = AdFW::variable(3.14159265358979323846264338327950288);
const AdFW f = std::cos(x);
BOOST_CHECK_CLOSE(f.val(), std::cos(x.val()), atol);
@ -217,7 +222,7 @@ BOOST_AUTO_TEST_CASE(SquareRoot)
const double atol = 1.0e-14;
const AdFW x(1.234e-5);
const AdFW x = AdFW::variable(1.234e-5);
const AdFW x2 = x * x;
const AdFW g = std::cos(x2) + x;