[NormalizeL2] normalization of reduction axes (#15841) (#15879)

* Add test for negative axes, preliminary solution to solve uncorrect
results

* Normalize axes in operation NormalizeL2

* Add test for negative axes

* Add EOF
This commit is contained in:
Artur Kulikowski
2023-02-23 09:05:03 +01:00
committed by GitHub
parent 0a5ca53752
commit 174be7c8bd
2 changed files with 19 additions and 3 deletions

View File

@@ -68,7 +68,10 @@ void op::v0::NormalizeL2::validate_and_infer_types() {
AxisSet op::v0::NormalizeL2::get_reduction_axes() const {
AxisSet axes;
if (auto const_op = get_constant_from_source(input_value(1))) {
axes = const_op->get_axis_set_val();
const auto const_data = const_op->cast_vector<int64_t>();
const auto input_data_rank = get_input_partial_shape(0).rank();
const auto normalized_axes = ov::normalize_axes(get_friendly_name(), const_data, input_data_rank);
axes = AxisSet{normalized_axes};
}
return axes;
}

View File

@@ -80,9 +80,22 @@ TEST(type_prop, normalize_l2_axes_out_of_bounds) {
auto normalize = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Reduction axis ("));
} catch (const ov::AssertFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("(axis_range_min <= axis) && (axis <= axis_range_max)"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, normalize_l2_negative_axes) {
PartialShape data_shape{1, 2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto axes = make_shared<op::Constant>(element::i32, Shape{1}, vector<int64_t>{-1});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode);
EXPECT_EQ(normalize->get_element_type(), element::f32);
EXPECT_EQ(normalize->get_reduction_axes(), ov::AxisSet{3});
EXPECT_EQ(normalize->get_output_partial_shape(0), data_shape);
}