diff --git a/src/core/shape_inference/include/interpolate_shape_inference.hpp b/src/core/shape_inference/include/interpolate_shape_inference.hpp index aeb7e43c6f6..bbf76fa4e5c 100644 --- a/src/core/shape_inference/include/interpolate_shape_inference.hpp +++ b/src/core/shape_inference/include/interpolate_shape_inference.hpp @@ -73,7 +73,7 @@ inline void input_elements_num(const Node* const op, size_t element_count, size_t exp_count) { NODE_VALIDATION_CHECK(op, - element_count == exp_count, + element_count >= exp_count, "The number of elements in the '", input_name, "' input does not match the number of axes ", diff --git a/src/core/tests/type_prop/interpolate.cpp b/src/core/tests/type_prop/interpolate.cpp index 610838120e2..3029bac62b3 100644 --- a/src/core/tests/type_prop/interpolate.cpp +++ b/src/core/tests/type_prop/interpolate.cpp @@ -327,6 +327,38 @@ TEST(type_prop, interpolate_v4_use_scales_interval_shapes) { ElementsAre(10, 11, ov::no_label, ov::no_label, ov::no_label)); } +TEST(type_prop, interpolate_v4_target_shapes_gt_axes_number) { + const auto image = std::make_shared(element::f32, Shape{1, 3, 30, 60}); + const auto target_shape = op::Constant::create(element::i32, Shape{3}, {10, 12, 20}); + const auto scales = op::Constant::create(element::f32, Shape{1}, {0.3f}); + const auto axes = op::Constant::create(element::i64, Shape{2}, {0, 3}); + + ov::op::util::InterpolateBase::InterpolateAttrs attrs; + attrs.shape_calculation_mode = ov::op::util::InterpolateBase::ShapeCalcMode::SIZES; + attrs.pads_begin = {0, 0, 0, 0}; + attrs.pads_end = {0, 0, 0, 0}; + auto interp = std::make_shared(image, target_shape, scales, axes, attrs); + + EXPECT_EQ(interp->get_element_type(), element::f32); + EXPECT_EQ(interp->get_output_partial_shape(0), PartialShape({10, 3, 30, 12})); +} + +TEST(type_prop, interpolate_v4_scales_gt_axes_number) { + const auto image = std::make_shared(element::f32, Shape{1, 3, 30, 60}); + const auto target_shape = std::make_shared(element::i32, Shape{3}); + const auto scales = op::Constant::create(element::f32, Shape{3}, {0.2f, 0.2f, 0.3f}); + const auto axes = op::Constant::create(element::i64, Shape{2}, {2, 3}); + + ov::op::util::InterpolateBase::InterpolateAttrs attrs; + attrs.shape_calculation_mode = ov::op::util::InterpolateBase::ShapeCalcMode::SCALES; + attrs.pads_begin = {0, 0, 0, 0}; + attrs.pads_end = {0, 0, 0, 0}; + auto interp = std::make_shared(image, target_shape, scales, axes, attrs); + + EXPECT_EQ(interp->get_element_type(), element::f32); + EXPECT_EQ(interp->get_output_partial_shape(0), PartialShape({1, 3, 6, 12})); +} + TEST(type_prop, interpolate_v4_incorrect_mode) { const auto image = std::make_shared(element::f32, Shape{1, 3, 30, 60}); const auto target_shape = std::make_shared(element::i32, Shape{2});