* Add test for negative axes, preliminary solution to solve uncorrect results * Normalize axes in operation NormalizeL2 * Add test for negative axes * Add EOF
This commit is contained in:
@@ -68,7 +68,10 @@ void op::v0::NormalizeL2::validate_and_infer_types() {
|
||||
AxisSet op::v0::NormalizeL2::get_reduction_axes() const {
|
||||
AxisSet axes;
|
||||
if (auto const_op = get_constant_from_source(input_value(1))) {
|
||||
axes = const_op->get_axis_set_val();
|
||||
const auto const_data = const_op->cast_vector<int64_t>();
|
||||
const auto input_data_rank = get_input_partial_shape(0).rank();
|
||||
const auto normalized_axes = ov::normalize_axes(get_friendly_name(), const_data, input_data_rank);
|
||||
axes = AxisSet{normalized_axes};
|
||||
}
|
||||
return axes;
|
||||
}
|
||||
|
||||
@@ -80,9 +80,22 @@ TEST(type_prop, normalize_l2_axes_out_of_bounds) {
|
||||
auto normalize = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Invalid input tensor rank.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Reduction axis ("));
|
||||
} catch (const ov::AssertFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("(axis_range_min <= axis) && (axis <= axis_range_max)"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, normalize_l2_negative_axes) {
|
||||
PartialShape data_shape{1, 2, 3, 4};
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
const auto axes = make_shared<op::Constant>(element::i32, Shape{1}, vector<int64_t>{-1});
|
||||
float eps{1e-6f};
|
||||
auto eps_mode = op::EpsMode::ADD;
|
||||
auto normalize = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode);
|
||||
|
||||
EXPECT_EQ(normalize->get_element_type(), element::f32);
|
||||
EXPECT_EQ(normalize->get_reduction_axes(), ov::AxisSet{3});
|
||||
EXPECT_EQ(normalize->get_output_partial_shape(0), data_shape);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user