Interpolate v11 -> v4 downgrade transformation (#16448)
This commit is contained in:
parent
c23a1170ba
commit
8eb142ca6e
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Converts Interpolate version 11 to Interpolate version 4 if the new op uses any of the v4 allowed
|
||||
* interpolation modes.
|
||||
*/
|
||||
class TRANSFORMATIONS_API ConvertInterpolate11ToInterpolate4 : public MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ConvertInterpolate11ToInterpolate4", "0");
|
||||
ConvertInterpolate11ToInterpolate4();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
@ -76,6 +76,7 @@
|
||||
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_gather_upgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_gelu.hpp"
|
||||
#include "transformations/op_conversions/convert_interpolate11_downgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp"
|
||||
#include "transformations/op_conversions/convert_maxpool_downgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_maxpool_upgrade.hpp"
|
||||
@ -211,6 +212,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
|
||||
REGISTER_PASS(manager, ConvertMulticlassNms8ToMulticlassNms9)
|
||||
REGISTER_PASS(manager, ConvertXorToLogicalXor)
|
||||
REGISTER_PASS(manager, ConvertTopK11ToTopK3)
|
||||
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
|
||||
|
||||
auto fq_fusions = manager.register_pass<GraphRewrite>();
|
||||
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
|
||||
|
@ -0,0 +1,75 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_interpolate11_downgrade.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/opsets/opset11.hpp>
|
||||
#include <openvino/opsets/opset4.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
ov::pass::ConvertInterpolate11ToInterpolate4::ConvertInterpolate11ToInterpolate4() {
|
||||
MATCHER_SCOPE(ConvertInterpolate11ToInterpolate4);
|
||||
|
||||
const auto interpolate_v11_pattern = pattern::wrap_type<opset11::Interpolate>();
|
||||
|
||||
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto v4_compatible_interpolation_mode = [](const op::util::InterpolateBase::InterpolateMode mode) {
|
||||
constexpr std::array<op::util::InterpolateBase::InterpolateMode, 4> allowed_modes = {
|
||||
op::util::InterpolateBase::InterpolateMode::NEAREST,
|
||||
op::util::InterpolateBase::InterpolateMode::LINEAR,
|
||||
op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX,
|
||||
op::util::InterpolateBase::InterpolateMode::CUBIC};
|
||||
|
||||
return std::find(std::begin(allowed_modes), std::end(allowed_modes), mode) != std::end(allowed_modes);
|
||||
};
|
||||
|
||||
const auto interpolate_v11 = std::dynamic_pointer_cast<opset11::Interpolate>(m.get_match_root());
|
||||
if (!interpolate_v11 || !v4_compatible_interpolation_mode(interpolate_v11->get_attrs().mode) ||
|
||||
transformation_callback(interpolate_v11)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// downgrade only if the interpolation mode used to create v11 is supported by v4
|
||||
std::shared_ptr<ov::opset4::Interpolate> interpolate_v4;
|
||||
ov::Output<ov::Node> v4_input_output_shape;
|
||||
ov::Output<ov::Node> v4_input_scales;
|
||||
|
||||
if (interpolate_v11->get_attrs().shape_calculation_mode ==
|
||||
ov::op::util::InterpolateBase::ShapeCalcMode::SCALES) {
|
||||
v4_input_scales = interpolate_v11->input_value(1);
|
||||
v4_input_output_shape = opset4::Constant::create(element::i32, Shape{1}, {1});
|
||||
copy_runtime_info(interpolate_v11, v4_input_output_shape.get_node_shared_ptr());
|
||||
} else {
|
||||
v4_input_output_shape = interpolate_v11->input_value(1);
|
||||
v4_input_scales = opset4::Constant::create(element::f32, Shape{1}, {1.0f});
|
||||
copy_runtime_info(interpolate_v11, v4_input_scales.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
if (interpolate_v11->get_input_size() == 3) { // with axes input
|
||||
interpolate_v4 = std::make_shared<ov::opset4::Interpolate>(interpolate_v11->input_value(0),
|
||||
v4_input_output_shape,
|
||||
v4_input_scales,
|
||||
interpolate_v11->input_value(2),
|
||||
interpolate_v11->get_attrs());
|
||||
} else {
|
||||
interpolate_v4 = std::make_shared<ov::opset4::Interpolate>(interpolate_v11->input_value(0),
|
||||
v4_input_output_shape,
|
||||
v4_input_scales,
|
||||
interpolate_v11->get_attrs());
|
||||
}
|
||||
|
||||
interpolate_v4->set_friendly_name(interpolate_v11->get_friendly_name());
|
||||
copy_runtime_info(interpolate_v11, interpolate_v4);
|
||||
replace_node(interpolate_v11, interpolate_v4);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(interpolate_v11_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/opsets/opset11.hpp>
|
||||
#include <openvino/opsets/opset4.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/convert_interpolate11_downgrade.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
namespace {
|
||||
constexpr bool WITH_AXES = true;
|
||||
constexpr bool WITHOUT_AXES = false;
|
||||
|
||||
std::shared_ptr<ov::Model> create_v11_model(const bool with_axes,
|
||||
const ov::opset11::Interpolate::ShapeCalcMode shape_calc_mode) {
|
||||
auto attributes = ov::opset11::Interpolate::InterpolateAttrs{};
|
||||
attributes.shape_calculation_mode = shape_calc_mode;
|
||||
attributes.pads_begin = {0, 0};
|
||||
attributes.pads_end = {0, 0};
|
||||
|
||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
|
||||
std::shared_ptr<ov::opset11::Parameter> scales_or_sizes;
|
||||
std::shared_ptr<ov::opset11::Interpolate> interpolate;
|
||||
|
||||
const size_t num_scales_or_sizes = with_axes ? 2 : 4;
|
||||
if (shape_calc_mode == ov::opset11::Interpolate::ShapeCalcMode::SCALES) {
|
||||
scales_or_sizes = std::make_shared<ov::opset11::Parameter>(ov::element::f32, ov::Shape{num_scales_or_sizes});
|
||||
} else {
|
||||
scales_or_sizes = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{num_scales_or_sizes});
|
||||
}
|
||||
|
||||
ov::ParameterVector model_params;
|
||||
model_params.push_back(input);
|
||||
model_params.push_back(scales_or_sizes);
|
||||
if (with_axes) {
|
||||
const auto axes = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2});
|
||||
model_params.push_back(axes);
|
||||
interpolate = std::make_shared<ov::opset11::Interpolate>(input, scales_or_sizes, axes, attributes);
|
||||
} else {
|
||||
interpolate = std::make_shared<ov::opset11::Interpolate>(input, scales_or_sizes, attributes);
|
||||
}
|
||||
interpolate->set_friendly_name("interpolate11");
|
||||
|
||||
return std::make_shared<ov::Model>(interpolate->outputs(), model_params);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> create_v4_model(const bool with_axes,
|
||||
const ov::opset4::Interpolate::ShapeCalcMode shape_calc_mode) {
|
||||
auto attributes = ov::opset4::Interpolate::InterpolateAttrs{};
|
||||
attributes.shape_calculation_mode = shape_calc_mode;
|
||||
attributes.pads_begin = {0, 0};
|
||||
attributes.pads_end = {0, 0};
|
||||
|
||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
|
||||
std::shared_ptr<ov::Node> output_shape;
|
||||
std::shared_ptr<ov::Node> scales;
|
||||
std::shared_ptr<ov::opset4::Interpolate> interpolate;
|
||||
|
||||
ov::ParameterVector model_params;
|
||||
model_params.push_back(input);
|
||||
|
||||
const size_t num_scales_or_sizes = with_axes ? 2 : 4;
|
||||
if (shape_calc_mode == ov::opset4::Interpolate::ShapeCalcMode::SCALES) {
|
||||
scales = std::make_shared<ov::opset4::Parameter>(ov::element::f32, ov::Shape{num_scales_or_sizes});
|
||||
model_params.push_back(std::dynamic_pointer_cast<ov::opset4::Parameter>(scales));
|
||||
output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{1}, {1});
|
||||
|
||||
} else {
|
||||
output_shape = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{num_scales_or_sizes});
|
||||
model_params.push_back(std::dynamic_pointer_cast<ov::opset4::Parameter>(output_shape));
|
||||
scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f});
|
||||
}
|
||||
|
||||
if (with_axes) {
|
||||
const auto axes = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{2});
|
||||
model_params.push_back(axes);
|
||||
interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, axes, attributes);
|
||||
} else {
|
||||
interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, attributes);
|
||||
}
|
||||
interpolate->set_friendly_name("interpolate11");
|
||||
|
||||
return std::make_shared<ov::Model>(interpolate->outputs(), model_params);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales) {
|
||||
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
|
||||
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
|
||||
function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes) {
|
||||
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
|
||||
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
|
||||
function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales_no_axes) {
|
||||
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
|
||||
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
|
||||
function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes_no_axes) {
|
||||
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
|
||||
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
|
||||
function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES);
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::shared_ptr<ov::Model> create_non_downgradeable_model(const ov::opset11::Interpolate::InterpolateMode mode) {
|
||||
auto attributes = ov::opset11::Interpolate::InterpolateAttrs{};
|
||||
attributes.mode = mode;
|
||||
attributes.shape_calculation_mode = ov::opset11::Interpolate::ShapeCalcMode::SCALES;
|
||||
attributes.pads_begin = {0, 0};
|
||||
attributes.pads_end = {0, 0};
|
||||
|
||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
|
||||
const auto scales = std::make_shared<ov::opset11::Parameter>(ov::element::f32, ov::Shape{2});
|
||||
const auto axes = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2});
|
||||
|
||||
const auto interpolate = std::make_shared<ov::opset11::Interpolate>(input, scales, axes, attributes);
|
||||
interpolate->set_friendly_name("interpolate11");
|
||||
|
||||
return std::make_shared<ov::Model>(interpolate->outputs(), ov::ParameterVector{input, scales, axes});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_bicubic_pillow) {
|
||||
function = create_non_downgradeable_model(ov::opset11::Interpolate::InterpolateMode::BICUBIC_PILLOW);
|
||||
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_bilinear_pillow) {
|
||||
function = create_non_downgradeable_model(ov::opset11::Interpolate::InterpolateMode::BILINEAR_PILLOW);
|
||||
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
|
||||
}
|
@ -186,6 +186,21 @@ void ov::op::v4::Interpolate::validate_and_infer_types() {
|
||||
input_shapes = {input_shape, target_spatial_shape, scales, axes};
|
||||
}
|
||||
|
||||
const auto interpolation_mode_check = [](const op::util::InterpolateBase::InterpolateMode mode) {
|
||||
constexpr std::array<op::util::InterpolateBase::InterpolateMode, 4> allowed_modes = {
|
||||
op::util::InterpolateBase::InterpolateMode::NEAREST,
|
||||
op::util::InterpolateBase::InterpolateMode::LINEAR,
|
||||
op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX,
|
||||
op::util::InterpolateBase::InterpolateMode::CUBIC};
|
||||
|
||||
return std::find(std::begin(allowed_modes), std::end(allowed_modes), mode) != std::end(allowed_modes);
|
||||
};
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
interpolation_mode_check(m_attrs.mode),
|
||||
"Unsupported interpolation mode used with version 4 of the Interpolate op: ",
|
||||
as_string(m_attrs.mode));
|
||||
|
||||
util::correct_pads_attr(this, m_attrs.pads_begin, m_attrs.pads_end, input_shapes);
|
||||
shape_infer(this, m_attrs.pads_begin, m_attrs.pads_end, input_shapes, output_shapes, {});
|
||||
set_output_type(0, get_input_element_type(0), output_shapes[0]);
|
||||
|
@ -214,6 +214,28 @@ TEST(type_prop, interpolate_v4_interval_logic) {
|
||||
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, interpolate_v4_incorrect_mode) {
|
||||
const auto image = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 30, 60});
|
||||
const auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{2});
|
||||
const auto scales = op::Constant::create<float>(element::f32, Shape{2}, {6.f, 12.f});
|
||||
const auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
|
||||
|
||||
ov::op::util::InterpolateBase::InterpolateAttrs attrs;
|
||||
attrs.shape_calculation_mode = ov::op::util::InterpolateBase::ShapeCalcMode::SCALES;
|
||||
attrs.mode = ov::op::util::InterpolateBase::InterpolateMode::BICUBIC_PILLOW;
|
||||
attrs.pads_begin = {0, 0, 0, 0};
|
||||
attrs.pads_end = {0, 0, 0, 0};
|
||||
|
||||
OV_EXPECT_THROW(auto interp = std::make_shared<ov::op::v4::Interpolate>(image, target_shape, scales, axes, attrs),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("Unsupported interpolation mode used with version 4 of the Interpolate op"));
|
||||
|
||||
attrs.mode = ov::op::util::InterpolateBase::InterpolateMode::BILINEAR_PILLOW;
|
||||
OV_EXPECT_THROW(auto interp = std::make_shared<ov::op::v4::Interpolate>(image, target_shape, scales, axes, attrs),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("Unsupported interpolation mode used with version 4 of the Interpolate op"));
|
||||
}
|
||||
|
||||
TEST(type_prop, interpolate_v11_scales) {
|
||||
const auto image = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 30, 60});
|
||||
const auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.2f, 0.2f});
|
||||
|
Loading…
Reference in New Issue
Block a user