Tabulated functions: fix the bisection code

seems like somebody can't properly implement an interval-halving
algorithm. Maybe I should consider to give back my degree...
This commit is contained in:
Andreas Lauser 2015-01-15 18:12:03 +01:00
parent 0b675bfa38
commit f1cb777fb0
2 changed files with 44 additions and 28 deletions

View File

@ -438,24 +438,29 @@ public:
private:
int findSegmentIndex_(Scalar x) const
{
int n = xValues_.size() - 1;
assert(n >= 1); // we need at least two sampling points!
if (xValues_[n] < x)
return n - 1;
else if (xValues_[0] > x)
// we need at least two sampling points!
assert(xValues_.size() >= 2);
if (x <= xValues_[1])
return 0;
else if (x >= xValues_[xValues_.size() - 2])
return xValues_.size() - 2;
else {
// bisection
int segmentIdx = 1;
int upperIdx = xValues_.size() - 2;
while (segmentIdx + 1 < upperIdx) {
int pivotIdx = (segmentIdx + upperIdx) / 2;
if (x < xValues_[pivotIdx])
upperIdx = pivotIdx;
else
segmentIdx = pivotIdx;
}
// bisection
int lowIdx = 0, highIdx = n;
while (lowIdx + 1 < highIdx) {
int curIdx = (lowIdx + highIdx)/2;
if (xValues_[curIdx] < x)
lowIdx = curIdx;
else
highIdx = curIdx;
assert(xValues_[segmentIdx] <= x);
assert(x <= xValues_[segmentIdx + 1]);
return segmentIdx;
}
return lowIdx;
}
Scalar evalDerivative_(Scalar x, int segIdx) const

View File

@ -140,22 +140,33 @@ public:
{
assert(extrapolate || (xMin() <= x && x <= xMax()));
// interval halving
int lowerIdx = 0;
int upperIdx = xPos_.size() - 2;
int pivotIdx = (lowerIdx + upperIdx) / 2;
while (lowerIdx + 1 < upperIdx) {
if (x < xPos_[pivotIdx])
upperIdx = pivotIdx;
else
lowerIdx = pivotIdx;
// we need at least two sampling points!
assert(xPos_.size() >= 2);
pivotIdx = (lowerIdx + upperIdx) / 2;
int segmentIdx;
if (x <= xPos_[1])
segmentIdx = 0;
else if (x >= xPos_[xPos_.size() - 2])
segmentIdx = xPos_.size() - 2;
else {
// bisection
segmentIdx = 1;
int upperIdx = xPos_.size() - 2;
while (segmentIdx + 1 < upperIdx) {
int pivotIdx = (segmentIdx + upperIdx) / 2;
if (x < xPos_[pivotIdx])
upperIdx = pivotIdx;
else
segmentIdx = pivotIdx;
}
assert(xPos_[segmentIdx] <= x);
assert(x <= xPos_[segmentIdx + 1]);
}
Scalar x1 = xPos_[lowerIdx];
Scalar x2 = xPos_[lowerIdx + 1];
return lowerIdx + (x - x1)/(x2 - x1);
Scalar x1 = xPos_[segmentIdx];
Scalar x2 = xPos_[segmentIdx + 1];
return segmentIdx + (x - x1)/(x2 - x1);
}
/*!