From 8eb142ca6ead9ac08e15741d554e7bb061339e0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Do=C5=82bniak?= Date: Wed, 22 Mar 2023 17:00:53 +0100 Subject: [PATCH] Interpolate v11 -> v4 downgrade transformation (#16448) --- .../convert_interpolate11_downgrade.hpp | 24 +++ .../common_optimizations.cpp | 2 + .../convert_interpolate11_downgrade.cpp | 75 +++++++++ .../convert_interpolate11_downgrade_test.cpp | 147 ++++++++++++++++++ src/core/src/op/interpolate.cpp | 15 ++ src/core/tests/type_prop/interpolate.cpp | 22 +++ 6 files changed, 285 insertions(+) create mode 100644 src/common/transformations/include/transformations/op_conversions/convert_interpolate11_downgrade.hpp create mode 100644 src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp create mode 100644 src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp diff --git a/src/common/transformations/include/transformations/op_conversions/convert_interpolate11_downgrade.hpp b/src/common/transformations/include/transformations/op_conversions/convert_interpolate11_downgrade.hpp new file mode 100644 index 00000000000..b112c5d8abd --- /dev/null +++ b/src/common/transformations/include/transformations/op_conversions/convert_interpolate11_downgrade.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace ov { +namespace pass { +/** + * @ingroup ie_transformation_common_api + * @brief Converts Interpolate version 11 to Interpolate version 4 if the new op uses any of the v4 allowed + * interpolation modes. + */ +class TRANSFORMATIONS_API ConvertInterpolate11ToInterpolate4 : public MatcherPass { +public: + OPENVINO_RTTI("ConvertInterpolate11ToInterpolate4", "0"); + ConvertInterpolate11ToInterpolate4(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 8b43dcfc8d2..6064effe880 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -76,6 +76,7 @@ #include "transformations/op_conversions/convert_gather_downgrade.hpp" #include "transformations/op_conversions/convert_gather_upgrade.hpp" #include "transformations/op_conversions/convert_gelu.hpp" +#include "transformations/op_conversions/convert_interpolate11_downgrade.hpp" #include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp" #include "transformations/op_conversions/convert_maxpool_downgrade.hpp" #include "transformations/op_conversions/convert_maxpool_upgrade.hpp" @@ -211,6 +212,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr(); ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion) diff --git a/src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp b/src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp new file mode 100644 index 00000000000..c9b2e15dd4c --- /dev/null +++ b/src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp @@ -0,0 +1,75 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/convert_interpolate11_downgrade.hpp" + +#include +#include +#include +#include +#include + +#include "itt.hpp" + +ov::pass::ConvertInterpolate11ToInterpolate4::ConvertInterpolate11ToInterpolate4() { + MATCHER_SCOPE(ConvertInterpolate11ToInterpolate4); + + const auto interpolate_v11_pattern = pattern::wrap_type(); + + const matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto v4_compatible_interpolation_mode = [](const op::util::InterpolateBase::InterpolateMode mode) { + constexpr std::array allowed_modes = { + op::util::InterpolateBase::InterpolateMode::NEAREST, + op::util::InterpolateBase::InterpolateMode::LINEAR, + op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX, + op::util::InterpolateBase::InterpolateMode::CUBIC}; + + return std::find(std::begin(allowed_modes), std::end(allowed_modes), mode) != std::end(allowed_modes); + }; + + const auto interpolate_v11 = std::dynamic_pointer_cast(m.get_match_root()); + if (!interpolate_v11 || !v4_compatible_interpolation_mode(interpolate_v11->get_attrs().mode) || + transformation_callback(interpolate_v11)) { + return false; + } + + // downgrade only if the interpolation mode used to create v11 is supported by v4 + std::shared_ptr interpolate_v4; + ov::Output v4_input_output_shape; + ov::Output v4_input_scales; + + if (interpolate_v11->get_attrs().shape_calculation_mode == + ov::op::util::InterpolateBase::ShapeCalcMode::SCALES) { + v4_input_scales = interpolate_v11->input_value(1); + v4_input_output_shape = opset4::Constant::create(element::i32, Shape{1}, {1}); + copy_runtime_info(interpolate_v11, v4_input_output_shape.get_node_shared_ptr()); + } else { + v4_input_output_shape = interpolate_v11->input_value(1); + v4_input_scales = opset4::Constant::create(element::f32, Shape{1}, {1.0f}); + copy_runtime_info(interpolate_v11, v4_input_scales.get_node_shared_ptr()); + } + + if (interpolate_v11->get_input_size() == 3) { // with axes input + interpolate_v4 = std::make_shared(interpolate_v11->input_value(0), + v4_input_output_shape, + v4_input_scales, + interpolate_v11->input_value(2), + interpolate_v11->get_attrs()); + } else { + interpolate_v4 = std::make_shared(interpolate_v11->input_value(0), + v4_input_output_shape, + v4_input_scales, + interpolate_v11->get_attrs()); + } + + interpolate_v4->set_friendly_name(interpolate_v11->get_friendly_name()); + copy_runtime_info(interpolate_v11, interpolate_v4); + replace_node(interpolate_v11, interpolate_v4); + + return true; + }; + + auto m = std::make_shared(interpolate_v11_pattern, matcher_name); + register_matcher(m, callback); +} diff --git a/src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp b/src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp new file mode 100644 index 00000000000..7504cd378eb --- /dev/null +++ b/src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp @@ -0,0 +1,147 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +namespace { +constexpr bool WITH_AXES = true; +constexpr bool WITHOUT_AXES = false; + +std::shared_ptr create_v11_model(const bool with_axes, + const ov::opset11::Interpolate::ShapeCalcMode shape_calc_mode) { + auto attributes = ov::opset11::Interpolate::InterpolateAttrs{}; + attributes.shape_calculation_mode = shape_calc_mode; + attributes.pads_begin = {0, 0}; + attributes.pads_end = {0, 0}; + + const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); + std::shared_ptr scales_or_sizes; + std::shared_ptr interpolate; + + const size_t num_scales_or_sizes = with_axes ? 2 : 4; + if (shape_calc_mode == ov::opset11::Interpolate::ShapeCalcMode::SCALES) { + scales_or_sizes = std::make_shared(ov::element::f32, ov::Shape{num_scales_or_sizes}); + } else { + scales_or_sizes = std::make_shared(ov::element::i32, ov::Shape{num_scales_or_sizes}); + } + + ov::ParameterVector model_params; + model_params.push_back(input); + model_params.push_back(scales_or_sizes); + if (with_axes) { + const auto axes = std::make_shared(ov::element::i32, ov::Shape{2}); + model_params.push_back(axes); + interpolate = std::make_shared(input, scales_or_sizes, axes, attributes); + } else { + interpolate = std::make_shared(input, scales_or_sizes, attributes); + } + interpolate->set_friendly_name("interpolate11"); + + return std::make_shared(interpolate->outputs(), model_params); +} + +std::shared_ptr create_v4_model(const bool with_axes, + const ov::opset4::Interpolate::ShapeCalcMode shape_calc_mode) { + auto attributes = ov::opset4::Interpolate::InterpolateAttrs{}; + attributes.shape_calculation_mode = shape_calc_mode; + attributes.pads_begin = {0, 0}; + attributes.pads_end = {0, 0}; + + const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); + std::shared_ptr output_shape; + std::shared_ptr scales; + std::shared_ptr interpolate; + + ov::ParameterVector model_params; + model_params.push_back(input); + + const size_t num_scales_or_sizes = with_axes ? 2 : 4; + if (shape_calc_mode == ov::opset4::Interpolate::ShapeCalcMode::SCALES) { + scales = std::make_shared(ov::element::f32, ov::Shape{num_scales_or_sizes}); + model_params.push_back(std::dynamic_pointer_cast(scales)); + output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{1}, {1}); + + } else { + output_shape = std::make_shared(ov::element::i32, ov::Shape{num_scales_or_sizes}); + model_params.push_back(std::dynamic_pointer_cast(output_shape)); + scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}); + } + + if (with_axes) { + const auto axes = std::make_shared(ov::element::i32, ov::Shape{2}); + model_params.push_back(axes); + interpolate = std::make_shared(input, output_shape, scales, axes, attributes); + } else { + interpolate = std::make_shared(input, output_shape, scales, attributes); + } + interpolate->set_friendly_name("interpolate11"); + + return std::make_shared(interpolate->outputs(), model_params); +} + +} // namespace + +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales) { + manager.register_pass(); + function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES); + function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES); +} + +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes) { + manager.register_pass(); + function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES); + function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES); +} + +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales_no_axes) { + manager.register_pass(); + function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES); + function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES); +} + +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes_no_axes) { + manager.register_pass(); + function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES); + function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES); +} + +namespace { +std::shared_ptr create_non_downgradeable_model(const ov::opset11::Interpolate::InterpolateMode mode) { + auto attributes = ov::opset11::Interpolate::InterpolateAttrs{}; + attributes.mode = mode; + attributes.shape_calculation_mode = ov::opset11::Interpolate::ShapeCalcMode::SCALES; + attributes.pads_begin = {0, 0}; + attributes.pads_end = {0, 0}; + + const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); + const auto scales = std::make_shared(ov::element::f32, ov::Shape{2}); + const auto axes = std::make_shared(ov::element::i32, ov::Shape{2}); + + const auto interpolate = std::make_shared(input, scales, axes, attributes); + interpolate->set_friendly_name("interpolate11"); + + return std::make_shared(interpolate->outputs(), ov::ParameterVector{input, scales, axes}); +} +} // namespace + +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_bicubic_pillow) { + function = create_non_downgradeable_model(ov::opset11::Interpolate::InterpolateMode::BICUBIC_PILLOW); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_bilinear_pillow) { + function = create_non_downgradeable_model(ov::opset11::Interpolate::InterpolateMode::BILINEAR_PILLOW); + manager.register_pass(); +} diff --git a/src/core/src/op/interpolate.cpp b/src/core/src/op/interpolate.cpp index 6bfd961fc35..b34d39bc60e 100644 --- a/src/core/src/op/interpolate.cpp +++ b/src/core/src/op/interpolate.cpp @@ -186,6 +186,21 @@ void ov::op::v4::Interpolate::validate_and_infer_types() { input_shapes = {input_shape, target_spatial_shape, scales, axes}; } + const auto interpolation_mode_check = [](const op::util::InterpolateBase::InterpolateMode mode) { + constexpr std::array allowed_modes = { + op::util::InterpolateBase::InterpolateMode::NEAREST, + op::util::InterpolateBase::InterpolateMode::LINEAR, + op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX, + op::util::InterpolateBase::InterpolateMode::CUBIC}; + + return std::find(std::begin(allowed_modes), std::end(allowed_modes), mode) != std::end(allowed_modes); + }; + + NODE_VALIDATION_CHECK(this, + interpolation_mode_check(m_attrs.mode), + "Unsupported interpolation mode used with version 4 of the Interpolate op: ", + as_string(m_attrs.mode)); + util::correct_pads_attr(this, m_attrs.pads_begin, m_attrs.pads_end, input_shapes); shape_infer(this, m_attrs.pads_begin, m_attrs.pads_end, input_shapes, output_shapes, {}); set_output_type(0, get_input_element_type(0), output_shapes[0]); diff --git a/src/core/tests/type_prop/interpolate.cpp b/src/core/tests/type_prop/interpolate.cpp index b220ecd8a8f..7f0f5ff3a5b 100644 --- a/src/core/tests/type_prop/interpolate.cpp +++ b/src/core/tests/type_prop/interpolate.cpp @@ -214,6 +214,28 @@ TEST(type_prop, interpolate_v4_interval_logic) { ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(out_shape)); } +TEST(type_prop, interpolate_v4_incorrect_mode) { + const auto image = std::make_shared(element::f32, Shape{1, 3, 30, 60}); + const auto target_shape = std::make_shared(element::i32, Shape{2}); + const auto scales = op::Constant::create(element::f32, Shape{2}, {6.f, 12.f}); + const auto axes = op::Constant::create(element::i64, Shape{2}, {2, 3}); + + ov::op::util::InterpolateBase::InterpolateAttrs attrs; + attrs.shape_calculation_mode = ov::op::util::InterpolateBase::ShapeCalcMode::SCALES; + attrs.mode = ov::op::util::InterpolateBase::InterpolateMode::BICUBIC_PILLOW; + attrs.pads_begin = {0, 0, 0, 0}; + attrs.pads_end = {0, 0, 0, 0}; + + OV_EXPECT_THROW(auto interp = std::make_shared(image, target_shape, scales, axes, attrs), + ov::NodeValidationFailure, + HasSubstr("Unsupported interpolation mode used with version 4 of the Interpolate op")); + + attrs.mode = ov::op::util::InterpolateBase::InterpolateMode::BILINEAR_PILLOW; + OV_EXPECT_THROW(auto interp = std::make_shared(image, target_shape, scales, axes, attrs), + ov::NodeValidationFailure, + HasSubstr("Unsupported interpolation mode used with version 4 of the Interpolate op")); +} + TEST(type_prop, interpolate_v11_scales) { const auto image = std::make_shared(element::f32, Shape{1, 3, 30, 60}); const auto scales = op::Constant::create(element::f32, Shape{2}, {0.2f, 0.2f});