Completed simple block AD class.

This commit is contained in:
Atgeirr Flø Rasmussen
2013-04-30 20:28:32 +02:00
parent d518b5c787
commit d046b9e97a
2 changed files with 57 additions and 11 deletions

View File

@@ -25,6 +25,7 @@
#include <Eigen/Sparse>
#include <vector>
#include <cassert>
#include <iostream>
namespace AutoDiff
{
@@ -75,38 +76,74 @@ namespace AutoDiff
return ForwardBlock(index, val, jac);
}
#if 0
/// Operators.
ForwardBlock operator+(const ForwardBlock& rhs)
{
return function(val_ + rhs.val_, jac_ + rhs.jac_);
assert(index() == rhs.index());
std::vector<M> jac = jac_;
assert(numBlocks() == rhs.numBlocks());
int num_blocks = numBlocks();
for (int block = 0; block < num_blocks; ++block) {
assert(jac[block].rows() == rhs.jac_[block].rows());
assert(jac[block].cols() == rhs.jac_[block].cols());
jac[block] += rhs.jac_[block];
}
return function(index(), val_ + rhs.val_, jac);
}
/// Operators.
ForwardBlock operator-(const ForwardBlock& rhs)
{
return function(val_ - rhs.val_, jac_ - rhs.jac_);
assert(index() == rhs.index());
std::vector<M> jac = jac_;
assert(numBlocks() == rhs.numBlocks());
int num_blocks = numBlocks();
for (int block = 0; block < num_blocks; ++block) {
assert(jac[block].rows() == rhs.jac_[block].rows());
assert(jac[block].cols() == rhs.jac_[block].cols());
jac[block] -= rhs.jac_[block];
}
return function(index(), val_ - rhs.val_, jac);
}
/// Operators.
ForwardBlock operator*(const ForwardBlock& rhs)
{
assert(index() == rhs.index());
int num_blocks = numBlocks();
std::vector<M> jac(num_blocks);
assert(numBlocks() == rhs.numBlocks());
typedef Eigen::DiagonalMatrix<Scalar, Eigen::Dynamic> D;
D D1 = val_.matrix().asDiagonal();
D D2 = rhs.val_.matrix().asDiagonal();
return function(val_ * rhs.val_, D2*jac_ + D1*rhs.jac_);
for (int block = 0; block < num_blocks; ++block) {
std::cout << jac[block].rows() << ' ' << rhs.jac_[block].rows() << std::endl;
assert(jac_[block].rows() == rhs.jac_[block].rows());
assert(jac_[block].cols() == rhs.jac_[block].cols());
jac[block] = D2*rhs.jac_[block] + D1*rhs.jac_[block];
}
return function(index(), val_ * rhs.val_, jac);
}
/// Operators.
ForwardBlock operator/(const ForwardBlock& rhs)
{
assert(index() == rhs.index());
int num_blocks = numBlocks();
std::vector<M> jac(num_blocks);
assert(numBlocks() == rhs.numBlocks());
typedef Eigen::DiagonalMatrix<Scalar, Eigen::Dynamic> D;
D D1 = val_.matrix().asDiagonal();
D D2 = rhs.val_.matrix().asDiagonal();
D D3 = std::pow(rhs.val_, -2).matrix().asDiagonal();
return function(val_ / rhs.val_, D3 * (D2*jac_ - D1*rhs.jac_));
for (int block = 0; block < num_blocks; ++block) {
std::cout << jac[block].rows() << ' ' << rhs.jac_[block].rows() << std::endl;
assert(jac_[block].rows() == rhs.jac_[block].rows());
assert(jac_[block].cols() == rhs.jac_[block].cols());
jac[block] = D3 * (D2*jac_[block] - D1*rhs.jac_[block]);
}
return function(index(), val_ / rhs.val_, jac);
}
#endif
/// I/O.
template <class Ostream>
Ostream&
@@ -120,6 +157,18 @@ namespace AutoDiff
return os;
}
/// Index of this variable.
int index() const
{
return index_;
}
/// Number of variables or Jacobian blocks.
int numBlocks() const
{
return jac_.size();
}
private:
ForwardBlock(const int index,
const V& val,

View File

@@ -38,13 +38,13 @@ int main()
jacs[i].insert(0,0) = -1.0;
}
ADV f = ADV::function(FirstVar, v2, jacs);
std::cout << a << x << f;
/*
ADV xpx = x + x;
std::cout << xpx;
ADV xpxpa = x + x + a;
std::cout << xpxpa;
std::cout << xpxpa - xpx;
ADV sqx = x * x;
@@ -54,7 +54,4 @@ int main()
ADV sqxdx = sqx / x;
std::cout << sqxdx;
// std::cout << a << "\n\n" << x << "\n\n" << f << std::endl;
*/
}