Add power operator to AutoDiffBlock

This commit is contained in:
Tor Harald Sanve 2015-12-10 10:30:28 +01:00 committed by Tor Harald Sandve
parent cb3ed9aa4f
commit 650fef5bc2
2 changed files with 66 additions and 0 deletions

View File

@ -622,6 +622,30 @@ namespace Opm
return rhs * lhs; // Commutative operation.
}
/**
* @brief Computes the value of base raised to the power of exp
*
* @param base The AD forward block
* @param exp exponent
* @return The value of base raised to the power of exp
*/
template <typename Scalar>
AutoDiffBlock<Scalar> pow(const AutoDiffBlock<Scalar>& base,
const double exp)
{
const typename AutoDiffBlock<Scalar>::V val = base.value().pow(exp);
const typename AutoDiffBlock<Scalar>::V derivative = exp * base.value().pow(exp-1.0);
const typename AutoDiffBlock<Scalar>::M derivative_diag(derivative.matrix().asDiagonal());
std::vector< typename AutoDiffBlock<Scalar>::M > jac (base.numBlocks());
for (int block = 0; block < base.numBlocks(); block++) {
fastSparseProduct(derivative_diag, base.derivative()[block], jac[block]);
}
return AutoDiffBlock<Scalar>::function( std::move(val), std::move(jac) );
}
} // namespace Opm

View File

@ -283,3 +283,45 @@ BOOST_AUTO_TEST_CASE(AssignAddSubtractOperators)
z += x;
checkClose(z, yconst, tolerance);
}
BOOST_AUTO_TEST_CASE(Pow)
{
typedef AutoDiffBlock<double> ADB;
// Basic testing of derivatives
ADB::V vx(3);
vx << 0.2, 1.2, 13.4;
ADB::V vy(3);
vy << 1.0, 2.2, 3.4;
std::vector<ADB::V> vals{ vx, vy };
std::vector<ADB> vars = ADB::variables(vals);
const ADB x = vars[0];
const ADB y = vars[1];
const double tolerance = 1e-14;
const ADB xx = x * x;
ADB xxpow2 = Opm::pow(x,2.0);
checkClose(xxpow2, xx, tolerance);
const ADB xy = x * y;
const ADB xxyy = xy * xy;
ADB xypow2 = Opm::pow(xy,2.0);
checkClose(xypow2, xxyy, tolerance);
const ADB xxx = x * x * x;
ADB xpow3 = Opm::pow(x,3.0);
checkClose(xpow3, xxx, tolerance);
ADB xpowhalf = Opm::pow(x,0.5);
ADB::V x_sqrt(3);
x_sqrt << 0.447214 , 1.095445 , 3.6606;
for (int i = 0 ; i < 3; ++i){
BOOST_CHECK_CLOSE(xpowhalf.value()[i], x_sqrt[i], 1e-4);
}
}