Interpolate v11 -> v4 downgrade transformation (#16448)

This commit is contained in:
Tomasz Dołbniak 2023-03-22 17:00:53 +01:00 committed by GitHub
parent c23a1170ba
commit 8eb142ca6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 285 additions and 0 deletions

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
namespace ov {
namespace pass {
/**
* @ingroup ie_transformation_common_api
* @brief Converts Interpolate version 11 to Interpolate version 4 if the new op uses any of the v4 allowed
* interpolation modes.
*/
class TRANSFORMATIONS_API ConvertInterpolate11ToInterpolate4 : public MatcherPass {
public:
OPENVINO_RTTI("ConvertInterpolate11ToInterpolate4", "0");
ConvertInterpolate11ToInterpolate4();
};
} // namespace pass
} // namespace ov

View File

@ -76,6 +76,7 @@
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
#include "transformations/op_conversions/convert_gather_upgrade.hpp"
#include "transformations/op_conversions/convert_gelu.hpp"
#include "transformations/op_conversions/convert_interpolate11_downgrade.hpp"
#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp"
#include "transformations/op_conversions/convert_maxpool_downgrade.hpp"
#include "transformations/op_conversions/convert_maxpool_upgrade.hpp"
@ -211,6 +212,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertMulticlassNms8ToMulticlassNms9)
REGISTER_PASS(manager, ConvertXorToLogicalXor)
REGISTER_PASS(manager, ConvertTopK11ToTopK3)
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)

View File

@ -0,0 +1,75 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_interpolate11_downgrade.hpp"
#include <array>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset11.hpp>
#include <openvino/opsets/opset4.hpp>
#include "itt.hpp"
ov::pass::ConvertInterpolate11ToInterpolate4::ConvertInterpolate11ToInterpolate4() {
MATCHER_SCOPE(ConvertInterpolate11ToInterpolate4);
const auto interpolate_v11_pattern = pattern::wrap_type<opset11::Interpolate>();
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto v4_compatible_interpolation_mode = [](const op::util::InterpolateBase::InterpolateMode mode) {
constexpr std::array<op::util::InterpolateBase::InterpolateMode, 4> allowed_modes = {
op::util::InterpolateBase::InterpolateMode::NEAREST,
op::util::InterpolateBase::InterpolateMode::LINEAR,
op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX,
op::util::InterpolateBase::InterpolateMode::CUBIC};
return std::find(std::begin(allowed_modes), std::end(allowed_modes), mode) != std::end(allowed_modes);
};
const auto interpolate_v11 = std::dynamic_pointer_cast<opset11::Interpolate>(m.get_match_root());
if (!interpolate_v11 || !v4_compatible_interpolation_mode(interpolate_v11->get_attrs().mode) ||
transformation_callback(interpolate_v11)) {
return false;
}
// downgrade only if the interpolation mode used to create v11 is supported by v4
std::shared_ptr<ov::opset4::Interpolate> interpolate_v4;
ov::Output<ov::Node> v4_input_output_shape;
ov::Output<ov::Node> v4_input_scales;
if (interpolate_v11->get_attrs().shape_calculation_mode ==
ov::op::util::InterpolateBase::ShapeCalcMode::SCALES) {
v4_input_scales = interpolate_v11->input_value(1);
v4_input_output_shape = opset4::Constant::create(element::i32, Shape{1}, {1});
copy_runtime_info(interpolate_v11, v4_input_output_shape.get_node_shared_ptr());
} else {
v4_input_output_shape = interpolate_v11->input_value(1);
v4_input_scales = opset4::Constant::create(element::f32, Shape{1}, {1.0f});
copy_runtime_info(interpolate_v11, v4_input_scales.get_node_shared_ptr());
}
if (interpolate_v11->get_input_size() == 3) { // with axes input
interpolate_v4 = std::make_shared<ov::opset4::Interpolate>(interpolate_v11->input_value(0),
v4_input_output_shape,
v4_input_scales,
interpolate_v11->input_value(2),
interpolate_v11->get_attrs());
} else {
interpolate_v4 = std::make_shared<ov::opset4::Interpolate>(interpolate_v11->input_value(0),
v4_input_output_shape,
v4_input_scales,
interpolate_v11->get_attrs());
}
interpolate_v4->set_friendly_name(interpolate_v11->get_friendly_name());
copy_runtime_info(interpolate_v11, interpolate_v4);
replace_node(interpolate_v11, interpolate_v4);
return true;
};
auto m = std::make_shared<pattern::Matcher>(interpolate_v11_pattern, matcher_name);
register_matcher(m, callback);
}

View File

@ -0,0 +1,147 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <openvino/opsets/opset11.hpp>
#include <openvino/opsets/opset4.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/op_conversions/convert_interpolate11_downgrade.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
namespace {
constexpr bool WITH_AXES = true;
constexpr bool WITHOUT_AXES = false;
std::shared_ptr<ov::Model> create_v11_model(const bool with_axes,
const ov::opset11::Interpolate::ShapeCalcMode shape_calc_mode) {
auto attributes = ov::opset11::Interpolate::InterpolateAttrs{};
attributes.shape_calculation_mode = shape_calc_mode;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
std::shared_ptr<ov::opset11::Parameter> scales_or_sizes;
std::shared_ptr<ov::opset11::Interpolate> interpolate;
const size_t num_scales_or_sizes = with_axes ? 2 : 4;
if (shape_calc_mode == ov::opset11::Interpolate::ShapeCalcMode::SCALES) {
scales_or_sizes = std::make_shared<ov::opset11::Parameter>(ov::element::f32, ov::Shape{num_scales_or_sizes});
} else {
scales_or_sizes = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{num_scales_or_sizes});
}
ov::ParameterVector model_params;
model_params.push_back(input);
model_params.push_back(scales_or_sizes);
if (with_axes) {
const auto axes = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2});
model_params.push_back(axes);
interpolate = std::make_shared<ov::opset11::Interpolate>(input, scales_or_sizes, axes, attributes);
} else {
interpolate = std::make_shared<ov::opset11::Interpolate>(input, scales_or_sizes, attributes);
}
interpolate->set_friendly_name("interpolate11");
return std::make_shared<ov::Model>(interpolate->outputs(), model_params);
}
std::shared_ptr<ov::Model> create_v4_model(const bool with_axes,
const ov::opset4::Interpolate::ShapeCalcMode shape_calc_mode) {
auto attributes = ov::opset4::Interpolate::InterpolateAttrs{};
attributes.shape_calculation_mode = shape_calc_mode;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
std::shared_ptr<ov::Node> output_shape;
std::shared_ptr<ov::Node> scales;
std::shared_ptr<ov::opset4::Interpolate> interpolate;
ov::ParameterVector model_params;
model_params.push_back(input);
const size_t num_scales_or_sizes = with_axes ? 2 : 4;
if (shape_calc_mode == ov::opset4::Interpolate::ShapeCalcMode::SCALES) {
scales = std::make_shared<ov::opset4::Parameter>(ov::element::f32, ov::Shape{num_scales_or_sizes});
model_params.push_back(std::dynamic_pointer_cast<ov::opset4::Parameter>(scales));
output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{1}, {1});
} else {
output_shape = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{num_scales_or_sizes});
model_params.push_back(std::dynamic_pointer_cast<ov::opset4::Parameter>(output_shape));
scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f});
}
if (with_axes) {
const auto axes = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{2});
model_params.push_back(axes);
interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, axes, attributes);
} else {
interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, attributes);
}
interpolate->set_friendly_name("interpolate11");
return std::make_shared<ov::Model>(interpolate->outputs(), model_params);
}
} // namespace
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales) {
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES);
}
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes) {
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES);
}
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales_no_axes) {
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES);
}
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes_no_axes) {
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES);
}
namespace {
std::shared_ptr<ov::Model> create_non_downgradeable_model(const ov::opset11::Interpolate::InterpolateMode mode) {
auto attributes = ov::opset11::Interpolate::InterpolateAttrs{};
attributes.mode = mode;
attributes.shape_calculation_mode = ov::opset11::Interpolate::ShapeCalcMode::SCALES;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
const auto scales = std::make_shared<ov::opset11::Parameter>(ov::element::f32, ov::Shape{2});
const auto axes = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2});
const auto interpolate = std::make_shared<ov::opset11::Interpolate>(input, scales, axes, attributes);
interpolate->set_friendly_name("interpolate11");
return std::make_shared<ov::Model>(interpolate->outputs(), ov::ParameterVector{input, scales, axes});
}
} // namespace
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_bicubic_pillow) {
function = create_non_downgradeable_model(ov::opset11::Interpolate::InterpolateMode::BICUBIC_PILLOW);
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
}
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_bilinear_pillow) {
function = create_non_downgradeable_model(ov::opset11::Interpolate::InterpolateMode::BILINEAR_PILLOW);
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
}

View File

@ -186,6 +186,21 @@ void ov::op::v4::Interpolate::validate_and_infer_types() {
input_shapes = {input_shape, target_spatial_shape, scales, axes};
}
const auto interpolation_mode_check = [](const op::util::InterpolateBase::InterpolateMode mode) {
constexpr std::array<op::util::InterpolateBase::InterpolateMode, 4> allowed_modes = {
op::util::InterpolateBase::InterpolateMode::NEAREST,
op::util::InterpolateBase::InterpolateMode::LINEAR,
op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX,
op::util::InterpolateBase::InterpolateMode::CUBIC};
return std::find(std::begin(allowed_modes), std::end(allowed_modes), mode) != std::end(allowed_modes);
};
NODE_VALIDATION_CHECK(this,
interpolation_mode_check(m_attrs.mode),
"Unsupported interpolation mode used with version 4 of the Interpolate op: ",
as_string(m_attrs.mode));
util::correct_pads_attr(this, m_attrs.pads_begin, m_attrs.pads_end, input_shapes);
shape_infer(this, m_attrs.pads_begin, m_attrs.pads_end, input_shapes, output_shapes, {});
set_output_type(0, get_input_element_type(0), output_shapes[0]);

View File

@ -214,6 +214,28 @@ TEST(type_prop, interpolate_v4_interval_logic) {
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, interpolate_v4_incorrect_mode) {
const auto image = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 30, 60});
const auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{2});
const auto scales = op::Constant::create<float>(element::f32, Shape{2}, {6.f, 12.f});
const auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
ov::op::util::InterpolateBase::InterpolateAttrs attrs;
attrs.shape_calculation_mode = ov::op::util::InterpolateBase::ShapeCalcMode::SCALES;
attrs.mode = ov::op::util::InterpolateBase::InterpolateMode::BICUBIC_PILLOW;
attrs.pads_begin = {0, 0, 0, 0};
attrs.pads_end = {0, 0, 0, 0};
OV_EXPECT_THROW(auto interp = std::make_shared<ov::op::v4::Interpolate>(image, target_shape, scales, axes, attrs),
ov::NodeValidationFailure,
HasSubstr("Unsupported interpolation mode used with version 4 of the Interpolate op"));
attrs.mode = ov::op::util::InterpolateBase::InterpolateMode::BILINEAR_PILLOW;
OV_EXPECT_THROW(auto interp = std::make_shared<ov::op::v4::Interpolate>(image, target_shape, scales, axes, attrs),
ov::NodeValidationFailure,
HasSubstr("Unsupported interpolation mode used with version 4 of the Interpolate op"));
}
TEST(type_prop, interpolate_v11_scales) {
const auto image = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 30, 60});
const auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.2f, 0.2f});