Files
openvino/ngraph/test/util/all_close_f.cpp
2020-08-17 06:05:08 +03:00

622 lines
22 KiB
C++

//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <climits>
#include <cmath>
#include "ngraph/env_util.hpp"
#include "ngraph/util.hpp"
#include "util/all_close_f.hpp"
using namespace std;
using namespace ngraph;
// Bit-level views of IEEE-754 values, used below to compute distances in ULPs.
// NOTE(review): reading the member of a union that was not last written is formally
// undefined behaviour in ISO C++; this code relies on the compiler supporting
// union-based type punning (GCC documents it as supported).
union FloatUnion {
    float f;    // value view; aggregate init (FloatUnion{x}) writes this member
    uint32_t i; // raw 32-bit pattern view
};
union DoubleUnion {
    double d;   // value view; aggregate init (DoubleUnion{x}) writes this member
    uint64_t i; // raw 64-bit pattern view
};
// Distance sentinels. The all-ones value is reserved to mean "both operands were below
// min_signal"; the largest real distance that can be reported is one less than that.
constexpr uint32_t FLOAT_BELOW_MIN_SIGNAL = UINT_MAX;
constexpr uint32_t FLOAT_MAX_DIFF = FLOAT_BELOW_MIN_SIGNAL - 1;
constexpr uint64_t DOUBLE_BELOW_MIN_SIGNAL = ULLONG_MAX;
constexpr uint64_t DOUBLE_MAX_DIFF = DOUBLE_BELOW_MIN_SIGNAL - 1;
// Distance between two floats measured in ULPs (units in the last place).
//
// Returns:
//   0                       both NaN, or both infinities of the same sign
//   FLOAT_MAX_DIFF          exactly one operand is NaN, or infinities of opposite sign
//   FLOAT_BELOW_MIN_SIGNAL  both magnitudes are below |min_signal| (too small to compare)
//   otherwise               the number of representable floats between a and b,
//                           capped at FLOAT_MAX_DIFF
uint32_t test::float_distance(float a, float b, float min_signal)
{
    const bool a_is_nan = std::isnan(a);
    const bool b_is_nan = std::isnan(b);
    if (a_is_nan && b_is_nan)
    {
        return 0;
    }
    if (a_is_nan || b_is_nan)
    {
        // Exactly one NaN: never close to anything. Without this check the bit-pattern
        // distance between a quiet NaN and +inf is only 2^22 (a signalling NaN can be
        // 1 ULP from +inf), so a NaN could pass at loose tolerances.
        return FLOAT_MAX_DIFF;
    }
    if (std::isinf(a) && std::isinf(b))
    {
        // Same-signed infinities are identical; opposite signs are maximally far apart.
        return ((a > 0) == (b > 0)) ? 0 : FLOAT_MAX_DIFF;
    }

    FloatUnion a_fu{a};
    FloatUnion b_fu{b};
    FloatUnion min_signal_fu{min_signal};

    const uint32_t sign_mask = static_cast<uint32_t>(1U) << 31;
    const uint32_t abs_value_bits_mask = ~sign_mask;

    // Both magnitudes below min_signal: report the dedicated sentinel instead of a distance.
    const uint32_t a_uint_abs = abs_value_bits_mask & a_fu.i;
    const uint32_t b_uint_abs = abs_value_bits_mask & b_fu.i;
    const uint32_t min_signal_uint_abs = abs_value_bits_mask & min_signal_fu.i;
    if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
    {
        return FLOAT_BELOW_MIN_SIGNAL;
    }

    // Map the sign-magnitude bit patterns onto a monotonically ordered unsigned scale
    // (see https://goo.gl/YbdnFQ):
    // - If negative: convert to two's complement
    // - If positive: set the sign bit
    // so +0.0 and -0.0 both map to 0x80000000 and neighbouring floats differ by 1.
    const uint32_t a_uint = (sign_mask & a_fu.i) ? (~a_fu.i + 1) : (sign_mask | a_fu.i);
    const uint32_t b_uint = (sign_mask & b_fu.i) ? (~b_fu.i + 1) : (sign_mask | b_fu.i);

    uint32_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
    // UINT_MAX is reserved to mean FLOAT_BELOW_MIN_SIGNAL.
    if (distance == FLOAT_BELOW_MIN_SIGNAL)
    {
        distance = FLOAT_MAX_DIFF;
    }
    return distance;
}
// Distance between two doubles measured in ULPs (units in the last place).
//
// Returns:
//   0                        both NaN, or both infinities of the same sign
//   DOUBLE_MAX_DIFF          exactly one operand is NaN, or infinities of opposite sign
//   DOUBLE_BELOW_MIN_SIGNAL  both magnitudes are below |min_signal| (too small to compare)
//   otherwise                the number of representable doubles between a and b,
//                            capped at DOUBLE_MAX_DIFF
uint64_t test::float_distance(double a, double b, double min_signal)
{
    const bool a_is_nan = std::isnan(a);
    const bool b_is_nan = std::isnan(b);
    if (a_is_nan && b_is_nan)
    {
        return 0;
    }
    if (a_is_nan || b_is_nan)
    {
        // Exactly one NaN: never close to anything. Without this check the bit-pattern
        // distance between a quiet NaN and +inf is only 2^51, which passes at loose
        // tolerances.
        return DOUBLE_MAX_DIFF;
    }
    if (std::isinf(a) && std::isinf(b))
    {
        // Same-signed infinities are identical; opposite signs are maximally far apart.
        return ((a > 0) == (b > 0)) ? 0 : DOUBLE_MAX_DIFF;
    }

    DoubleUnion a_du{a};
    DoubleUnion b_du{b};
    DoubleUnion min_signal_du{min_signal};

    const uint64_t sign_mask = static_cast<uint64_t>(1U) << 63;
    const uint64_t abs_value_bits_mask = ~sign_mask;

    // Both magnitudes below min_signal: report the dedicated sentinel instead of a distance.
    const uint64_t a_uint_abs = abs_value_bits_mask & a_du.i;
    const uint64_t b_uint_abs = abs_value_bits_mask & b_du.i;
    const uint64_t min_signal_uint_abs = abs_value_bits_mask & min_signal_du.i;
    if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
    {
        return DOUBLE_BELOW_MIN_SIGNAL;
    }

    // Map the sign-magnitude bit patterns onto a monotonically ordered unsigned scale
    // (see https://goo.gl/YbdnFQ):
    // - If negative: convert to two's complement
    // - If positive: set the sign bit
    const uint64_t a_uint = (sign_mask & a_du.i) ? (~a_du.i + 1) : (sign_mask | a_du.i);
    const uint64_t b_uint = (sign_mask & b_du.i) ? (~b_du.i + 1) : (sign_mask | b_du.i);

    uint64_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
    // ULLONG_MAX is reserved to mean DOUBLE_BELOW_MIN_SIGNAL.
    if (distance == DOUBLE_BELOW_MIN_SIGNAL)
    {
        distance = DOUBLE_MAX_DIFF;
    }
    return distance;
}
bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
{
if (std::isnan(a) && std::isnan(b))
{
return true;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return true;
}
else if (a < 0 && b < 0)
{
return true;
}
return false;
}
uint32_t distance = float_distance(a, b, min_signal);
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp mantissa implicit 1 tolerance_bits
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
return (distance <= tolerance) || (distance == FLOAT_BELOW_MIN_SIGNAL);
}
// Returns true when a and b agree to within 2^tolerance_bits ULPs, or when both are
// below |min_signal|. Two NaNs, or two same-signed infinities, compare as close; a NaN
// never compares as close to a non-NaN.
//
// tolerance_bits is clamped to [0, DOUBLE_MANTISSA_BITS - 1], exactly as
// all_close_f(vector<double>) does, so both entry points agree. The clamp also
// prevents an out-of-range shift (undefined behaviour) when a caller passes a
// negative value or one >= 64.
bool test::close_f(double a, double b, int tolerance_bits, double min_signal)
{
    const bool a_is_nan = std::isnan(a);
    const bool b_is_nan = std::isnan(b);
    if (a_is_nan || b_is_nan)
    {
        return a_is_nan && b_is_nan;
    }
    if (std::isinf(a) && std::isinf(b))
    {
        return (a > 0) == (b > 0);
    }
    if (tolerance_bits < 0)
    {
        tolerance_bits = 0;
    }
    if (tolerance_bits >= DOUBLE_MANTISSA_BITS)
    {
        tolerance_bits = DOUBLE_MANTISSA_BITS - 1;
    }
    uint64_t distance = float_distance(a, b, min_signal);
    // e.g. for double with 53 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
    // tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
    // double_length sign exp mantissa implicit 1 tolerance_bits
    uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
    uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
    return (distance <= tolerance) || (distance == DOUBLE_BELOW_MIN_SIGNAL);
}
// Element-wise float_distance over two equally sized vectors.
// Throws ngraph_error when the vectors differ in length.
vector<uint32_t>
    test::float_distances(const vector<float>& a, const vector<float>& b, float min_signal)
{
    if (a.size() != b.size())
    {
        throw ngraph_error("a.size() != b.size() for float_distances comparison.");
    }
    vector<uint32_t> result;
    result.reserve(a.size());
    for (size_t idx = 0; idx != a.size(); ++idx)
    {
        result.push_back(float_distance(a[idx], b[idx], min_signal));
    }
    return result;
}
// Element-wise float_distance over two equally sized vectors of doubles.
// Throws ngraph_error when the vectors differ in length.
vector<uint64_t>
    test::float_distances(const vector<double>& a, const vector<double>& b, double min_signal)
{
    if (a.size() != b.size())
    {
        throw ngraph_error("a.size() != b.size() for float_distances comparison.");
    }
    vector<uint64_t> result;
    result.reserve(a.size());
    for (size_t idx = 0; idx != a.size(); ++idx)
    {
        result.push_back(float_distance(a[idx], b[idx], min_signal));
    }
    return result;
}
// Converts a float ULP distance into the number of leading mantissa bits (of 24,
// including the implicit 1) that the two compared values are guaranteed to share.
// Returns 0 when the distance is too large for any mantissa bits to match.
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
    uint32_t tolerance_bit_shift = 0;
    uint32_t num_bits_on = 0;
    // Do some bit probing to find the most significant bit that's on,
    // as well as how many bits are on.
    for (uint32_t check_bit = 0; check_bit < 32; ++check_bit)
    {
        // Shift an unsigned one: `1 << 31` on a 32-bit signed int overflows, which is
        // undefined behaviour under C++11 rules. This also matches the 1ull used by the
        // uint64_t overload.
        if (distance & (static_cast<uint32_t>(1U) << check_bit))
        {
            tolerance_bit_shift = check_bit;
            ++num_bits_on;
        }
    }
    // all_close_f is <= test for tolerance (where tolerance is uint32_t with single bit on)
    // So if more than one bit is on we need the next higher tolerance
    if (num_bits_on > 1)
    {
        ++tolerance_bit_shift;
    }
    // clang-format off
    // all_close_f calculation of tolerance_bit_shift:
    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
    // float_length sign exp matching_mantissa_bits implicit 1 tolerance_bits
    //
    // Assuming 0 tolerance_bits and solving for matching_mantissa_bits yields:
    // tolerance_bit_shift = 32 - (1 + 8 + (matching_mantissa_bits - 1 ) - 0 )
    // tolerance_bit_shift = 32 - (1 + 8 + (matching_mantissa_bits - 1 ) )
    // matching_mantissa_bits = 32 - (1 + 8 + (tolerance_bit_shift - 1 ) )
    // clang-format on
    // When tolerance_bit_shift == 0 the unsigned subtraction wraps, and the modular
    // arithmetic still yields 24 (all mantissa bits match).
    uint32_t matching_bits =
        tolerance_bit_shift < 24 ? (32 - (1 + 8 + (tolerance_bit_shift - 1))) : 0;
    return matching_bits;
}
// Converts a double ULP distance into the number of leading mantissa bits (of 53,
// including the implicit 1) that the two compared values are guaranteed to share.
// Returns 0 when the distance is too large for any mantissa bits to match.
uint32_t test::matching_mantissa_bits(uint64_t distance)
{
    // Shift the distance out bit by bit, remembering the index of the highest set bit
    // and how many bits were set in total.
    uint32_t msb_index = 0;
    uint32_t set_bits = 0;
    uint32_t bit = 0;
    for (uint64_t remaining = distance; remaining != 0; remaining >>= 1, ++bit)
    {
        if ((remaining & 1ull) != 0)
        {
            msb_index = bit;
            ++set_bits;
        }
    }
    // all_close_f accepts distance <= tolerance, where tolerance is a single power of two;
    // a distance that is not itself a power of two needs the next larger tolerance.
    if (set_bits > 1)
    {
        ++msb_index;
    }
    // clang-format off
    // all_close_f derives the shift for doubles (53-bit mantissa incl. implicit 1,
    // 11-bit exponent) as
    //   tolerance_bit_shift = 64 - (1 + 11 + (53 - 1) - tolerance_bits)
    // Setting tolerance_bits = 0 and solving for the matching mantissa width gives
    //   matching_mantissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1))
    // clang-format on
    if (msb_index >= 53)
    {
        return 0;
    }
    // For msb_index == 0 the unsigned subtraction wraps; modular arithmetic still yields 53.
    return 64 - (1 + 11 + (msb_index - 1));
}
// Compares two float vectors element-wise, allowing each pair to differ by at most
// 2^tolerance_bits ULPs. Pairs whose magnitudes are both below |min_signal| are
// skipped. The returned AssertionResult carries a diagnostic message with the first
// few mismatches plus the tightest, loosest and median matches.
//
// tolerance_bits is clamped to [MIN_FLOAT_TOLERANCE_BITS, FLOAT_MANTISSA_BITS - 1].
// If the environment flag NGRAPH_GTEST_INFO is set (read via getenv_bool), a short
// summary is also printed to std::cout when the comparison passes.
::testing::AssertionResult test::all_close_f(const vector<float>& a,
                                             const vector<float>& b,
                                             int tolerance_bits,
                                             float min_signal)
{
    if (tolerance_bits < MIN_FLOAT_TOLERANCE_BITS)
    {
        tolerance_bits = MIN_FLOAT_TOLERANCE_BITS;
    }
    if (tolerance_bits >= FLOAT_MANTISSA_BITS)
    {
        tolerance_bits = FLOAT_MANTISSA_BITS - 1;
    }
    bool rc = true;
    stringstream msg;
    if (a.size() != b.size())
    {
        return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
    }
    if (a.size() == 0)
    {
        return ::testing::AssertionSuccess() << "No elements to compare";
    }
    vector<uint32_t> distances = float_distances(a, b, min_signal);
    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
    // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
    // float_length sign exp mantissa implicit 1 tolerance_bits
    uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
    uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
    uint32_t max_distance = 0;
    uint32_t min_distance = FLOAT_BELOW_MIN_SIGNAL;
    size_t max_distance_index = 0;
    size_t min_distance_index = 0;
    size_t diff_count = 0;
    size_t below_min_count = 0;
    for (size_t i = 0; i < a.size(); ++i)
    {
        if (distances[i] == FLOAT_BELOW_MIN_SIGNAL)
        {
            // Special value that indicates both values were below min_signal
            below_min_count++;
            continue;
        }
        if (distances[i] > max_distance)
        {
            max_distance = distances[i];
            max_distance_index = i;
        }
        if (distances[i] < min_distance)
        {
            min_distance = distances[i];
            min_distance_index = i;
        }
        bool is_close_f = distances[i] <= tolerance;
        if (!is_close_f)
        {
            // Only the first few mismatches are itemised to keep the message readable.
            if (diff_count < 5)
            {
                msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
                    << " is not close to " << b[i] << " at index " << i << std::endl;
            }
            rc = false;
            diff_count++;
        }
    }
    if (!rc)
    {
        msg << "diff count: " << diff_count << " out of " << a.size() << std::endl;
    }
    bool all_below_min_signal = below_min_count == a.size();

    // Median of the distances that were actually compared. Entries equal to
    // FLOAT_BELOW_MIN_SIGNAL are sentinels, not distances; they are removed first so
    // they cannot skew (or, when they are the majority, replace) the reported median.
    // The median is only reported when at least one pair was compared.
    uint32_t median_distance = 0;
    if (!all_below_min_signal)
    {
        auto signal_end =
            std::remove(distances.begin(), distances.end(), FLOAT_BELOW_MIN_SIGNAL);
        size_t signal_count = static_cast<size_t>(signal_end - distances.begin());
        // Find median value via partial sorting
        size_t middle = signal_count / 2;
        std::nth_element(distances.begin(), distances.begin() + middle, signal_end);
        median_distance = distances[middle];
        if (signal_count % 2 == 0)
        {
            // Even count (>= 2 here): average with the middle-1 value; sum in 64 bits
            // to avoid overflow.
            uint64_t median_sum = static_cast<uint64_t>(median_distance) +
                                  *max_element(distances.begin(), distances.begin() + middle);
            median_distance = static_cast<uint32_t>(median_sum / 2);
        }
    }

    if (rc && (getenv_bool("NGRAPH_GTEST_INFO")))
    {
        // Short unobtrusive message when passing
        std::cout << "[ INFO ] Verifying match of >= " << (FLOAT_MANTISSA_BITS - tolerance_bits)
                  << " mantissa bits (" << FLOAT_MANTISSA_BITS << " bits precision - "
                  << tolerance_bits << " tolerance). ";
        if (all_below_min_signal)
        {
            std::cout << "All values below min_signal: " << min_signal << std::endl;
        }
        else
        {
            std::cout << below_min_count << " value(s) below min_signal: " << min_signal
                      << " Loosest match found is " << matching_mantissa_bits(max_distance)
                      << " mantissa bits.\n";
        }
    }
    msg << "passing criteria - mismatch allowed @ mantissa bit: "
        << (FLOAT_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
        << " tolerance bits)\n";
    if (all_below_min_signal)
    {
        msg << "All values below min_signal: " << min_signal << std::endl;
    }
    else
    {
        msg << below_min_count << " value(s) below min_signal: " << min_signal << std::endl;
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "tightest match - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
            << " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "loosest match - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
            << " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
        msg << "median match - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(median_distance) << " or next bit\n";
    }
    ::testing::AssertionResult res =
        rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
    res << msg.str();
    return res;
}
// Compares two double vectors element-wise, allowing each pair to differ by at most
// 2^tolerance_bits ULPs. Pairs whose magnitudes are both below |min_signal| are
// skipped. The returned AssertionResult carries a diagnostic message with the first
// few mismatches plus the tightest, loosest and median matches.
//
// tolerance_bits is clamped to [0, DOUBLE_MANTISSA_BITS - 1].
// If the environment flag NGRAPH_GTEST_INFO is set (read via getenv_bool), a short
// summary is also printed to std::cout when the comparison passes.
::testing::AssertionResult test::all_close_f(const vector<double>& a,
                                             const vector<double>& b,
                                             int tolerance_bits,
                                             double min_signal)
{
    if (tolerance_bits < 0)
    {
        tolerance_bits = 0;
    }
    if (tolerance_bits >= DOUBLE_MANTISSA_BITS)
    {
        tolerance_bits = DOUBLE_MANTISSA_BITS - 1;
    }
    bool rc = true;
    stringstream msg;
    if (a.size() != b.size())
    {
        return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
    }
    if (a.size() == 0)
    {
        return ::testing::AssertionSuccess() << "No elements to compare";
    }
    vector<uint64_t> distances = float_distances(a, b, min_signal);
    // e.g. for double with 53 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
    // tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
    // double_length sign exp mantissa implicit 1 tolerance_bits
    uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
    uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
    uint64_t max_distance = 0;
    uint64_t min_distance = DOUBLE_BELOW_MIN_SIGNAL;
    size_t max_distance_index = 0;
    size_t min_distance_index = 0;
    size_t diff_count = 0;
    size_t below_min_count = 0;
    for (size_t i = 0; i < a.size(); ++i)
    {
        if (distances[i] == DOUBLE_BELOW_MIN_SIGNAL)
        {
            // Special value that indicates both values were below min_signal
            below_min_count++;
            continue;
        }
        if (distances[i] > max_distance)
        {
            max_distance = distances[i];
            max_distance_index = i;
        }
        if (distances[i] < min_distance)
        {
            min_distance = distances[i];
            min_distance_index = i;
        }
        bool is_close_f = distances[i] <= tolerance;
        if (!is_close_f)
        {
            // Only the first few mismatches are itemised. Print with full precision (as the
            // float overload does); at the default 6 digits, near-equal doubles that fail
            // the comparison would print identically.
            if (diff_count < 5)
            {
                msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
                    << " is not close to " << b[i] << " at index " << i << std::endl;
            }
            rc = false;
            diff_count++;
        }
    }
    if (!rc)
    {
        msg << "diff count: " << diff_count << " out of " << a.size() << std::endl;
    }
    bool all_below_min_signal = below_min_count == a.size();

    // Median of the distances that were actually compared. Entries equal to
    // DOUBLE_BELOW_MIN_SIGNAL are sentinels, not distances; they are removed first so
    // they cannot skew (or, when they are the majority, replace) the reported median.
    // The median is only reported when at least one pair was compared.
    uint64_t median_distance = 0;
    if (!all_below_min_signal)
    {
        auto signal_end =
            std::remove(distances.begin(), distances.end(), DOUBLE_BELOW_MIN_SIGNAL);
        size_t signal_count = static_cast<size_t>(signal_end - distances.begin());
        // Find median value via partial sorting
        size_t middle = signal_count / 2;
        std::nth_element(distances.begin(), distances.begin() + middle, signal_end);
        median_distance = distances[middle];
        if (signal_count % 2 == 0)
        {
            // Even count (>= 2 here): average with the middle-1 value without
            // overflowing uint64_t.
            uint64_t median_distance2 =
                *max_element(distances.begin(), distances.begin() + middle);
            uint64_t remainder1 = median_distance % 2;
            uint64_t remainder2 = median_distance2 % 2;
            median_distance =
                (median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
        }
    }

    if (rc && (getenv_bool("NGRAPH_GTEST_INFO")))
    {
        // Short unobtrusive message when passing
        std::cout << "[ INFO ] Verifying match of >= "
                  << (DOUBLE_MANTISSA_BITS - tolerance_bits) << " mantissa bits ("
                  << DOUBLE_MANTISSA_BITS << " bits precision - " << tolerance_bits
                  << " tolerance). ";
        if (all_below_min_signal)
        {
            std::cout << "All values below min_signal: " << min_signal << std::endl;
        }
        else
        {
            std::cout << below_min_count << " value(s) below min_signal: " << min_signal
                      << " Loosest match found is " << matching_mantissa_bits(max_distance)
                      << " mantissa bits.\n";
        }
    }
    msg << "passing criteria - mismatch allowed @ mantissa bit: "
        << (DOUBLE_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
        << " tolerance bits)\n";
    if (all_below_min_signal)
    {
        msg << "All values below min_signal: " << min_signal << std::endl;
    }
    else
    {
        msg << below_min_count << " value(s) below min_signal: " << min_signal << std::endl;
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "tightest match - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
            << " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
        msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
            << "loosest match - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
            << " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
        msg << "median match - mismatch occurred @ mantissa bit: "
            << matching_mantissa_bits(median_distance) << " or next bit\n";
    }
    ::testing::AssertionResult res =
        rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
    res << msg.str();
    return res;
}
// Tensor overload: reads both tensors as float vectors and compares them with the
// vector<float> overload. Tensors of different shapes fail immediately.
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
                                             const std::shared_ptr<runtime::Tensor>& b,
                                             int tolerance_bits,
                                             float min_signal)
{
    const bool same_shape = a->get_shape() == b->get_shape();
    if (!same_shape)
    {
        return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
    }
    const auto a_values = read_float_vector(a);
    const auto b_values = read_float_vector(b);
    return test::all_close_f(a_values, b_values, tolerance_bits, min_signal);
}
// Pairwise overload: compares as[k] against bs[k] for every k and returns the first
// failing AssertionResult, or success when every pair is close. Lists of different
// lengths fail immediately.
::testing::AssertionResult
    test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
                      const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
                      int tolerance_bits,
                      float min_signal)
{
    if (as.size() != bs.size())
    {
        return ::testing::AssertionFailure() << "Cannot compare tensors with different sizes";
    }
    auto b_it = bs.begin();
    for (const auto& a_tensor : as)
    {
        ::testing::AssertionResult pair_result =
            test::all_close_f(a_tensor, *b_it, tolerance_bits, min_signal);
        if (!pair_result)
        {
            return pair_result;
        }
        ++b_it;
    }
    return ::testing::AssertionSuccess();
}