[GroupNormalization-12] Reference implementation update to minimize fp16 error (#18760)
* Update mean calculation in group_normalization ref * Update tests * Update using namespaces
This commit is contained in:
parent
6be083d37e
commit
26d53eb1da
@ -98,18 +98,12 @@ TEST_F(TransformationTestsF, GroupNormalizationDecompositionF16) {
|
||||
const int64_t num_groups = 4;
|
||||
element::Type elem_type = element::f16;
|
||||
|
||||
model = gen_model(input_shapes, elem_type, num_groups, 1e-3f);
|
||||
model = gen_model(input_shapes, elem_type, num_groups, 1e-3);
|
||||
manager.register_pass<pass::GroupNormalizationDecomposition>();
|
||||
|
||||
model_ref = gen_model_ref(input_shapes, elem_type, num_groups, 1e-3f);
|
||||
|
||||
// Ticket number: 115063
|
||||
// abs_max < abs_threshold && rel_max < rel_threshold
|
||||
// abs_max: 0.03125
|
||||
// coordinate 220; abs errors count 384; abs mean 0.00505998; abs threshold 0.0005
|
||||
// rel_max: 0.0220588
|
||||
// coordinate 434; rel errors count 232; rel mean 0.0027481; rel threshold 0.001
|
||||
// comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GroupNormalizationDecomposition_num_groups) {
|
||||
|
@ -7,10 +7,10 @@
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/runtime/reference/mean.hpp"
|
||||
#include "ngraph/runtime/reference/sum.hpp"
|
||||
#include "openvino/core/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace ov {
|
||||
namespace reference {
|
||||
|
||||
@ -22,6 +22,9 @@ void group_normalization(const T* const data,
|
||||
const Shape& data_shape,
|
||||
const size_t num_groups,
|
||||
const double epsilon) {
|
||||
using namespace std;
|
||||
using namespace ngraph::runtime::reference;
|
||||
|
||||
const auto num_batches = data_shape[0];
|
||||
const auto num_channels = data_shape[1];
|
||||
const auto num_channels_in_group = num_channels / num_groups;
|
||||
@ -35,15 +38,15 @@ void group_normalization(const T* const data,
|
||||
for (size_t g = 0; g < num_groups; ++g) {
|
||||
const auto group_begin = data + n * batch_size + g * group_size;
|
||||
const auto group_end = group_begin + group_size;
|
||||
const auto mean = accumulate(group_begin, group_end, static_cast<T>(0)) / group_size;
|
||||
const auto variance = accumulate(group_begin,
|
||||
group_end,
|
||||
static_cast<T>(0),
|
||||
[mean](const T acc, const T d) {
|
||||
return acc + pow(d - mean, 2);
|
||||
}) /
|
||||
group_size;
|
||||
const auto standard_deviation = sqrt(variance + eps);
|
||||
std::vector<T> mean_value(1);
|
||||
mean(group_begin, mean_value.data(), Shape{group_size}, {0});
|
||||
T mean = mean_value[0];
|
||||
T variance = 0, err = 0;
|
||||
for_each(group_begin, group_end, [&](const T d) {
|
||||
return details::kahan_summation(static_cast<T>(pow(d - mean, 2)), err, variance);
|
||||
});
|
||||
variance /= group_size;
|
||||
const T standard_deviation = sqrt(variance + eps);
|
||||
|
||||
for (size_t s = 0; s < num_channels_in_group; ++s) {
|
||||
const auto c = g * num_channels_in_group + s;
|
||||
|
Loading…
Reference in New Issue
Block a user