Static Shape constraints removed from Interpolate 1->4 transformation (#10732)

* Static Shape constraints removed from Interpolate 1->4 transformation

* Dynamic tests added
This commit is contained in:
Evgenya Stepyreva
2022-03-02 19:16:34 +03:00
committed by GitHub
parent bea352f272
commit 4b55ef9911
2 changed files with 39 additions and 17 deletions

View File

@@ -9,40 +9,37 @@
#include <vector>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
#include <openvino/core/core.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolate1ToInterpolate4, "ConvertInterpolate1ToInterpolate4", 0);
ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolate4() {
MATCHER_SCOPE(ConvertInterpolate1ToInterpolate4);
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_shape()), pattern::any_input()});
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_rank()), pattern::any_input()});
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto interpolationV0 = std::dynamic_pointer_cast<ngraph::opset1::Interpolate>(m.get_match_root());
if (!interpolationV0) {
return false;
}
auto& inp_partial_shape = interpolationV0->get_input_partial_shape(0);
auto& out_shape = interpolationV0->get_output_shape(0);
auto attrsV0 = interpolationV0->get_attrs();
std::vector<size_t> axes{attrsV0.axes.begin(), attrsV0.axes.end()};
const auto& out_dims = std::make_shared<opset1::Convert>(interpolationV0->input_value(1), element::f32);
const auto& in_dims = std::make_shared<opset1::Convert>(ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(
interpolationV0->input_value(0), axes), element::f32);
std::vector<float> scales(attrsV0.axes.size(), 1.0f);
auto inp_shape = inp_partial_shape.to_shape();
size_t i = 0;
for (std::size_t axis : attrsV0.axes) {
scales[i] = static_cast<float>(out_shape.at(axis))/inp_shape.at(axis);
i++;
}
auto input_shape_rank = inp_partial_shape.rank().get_length();
auto scalesConstant = ngraph::op::Constant::create(ngraph::element::f32, {scales.size()}, scales);
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {attrsV0.axes.size()},
std::vector<std::size_t>{attrsV0.axes.begin(), attrsV0.axes.end()});
std::shared_ptr<Node> scales = std::make_shared<opset1::Divide>(out_dims, in_dims);
if (const auto& constant = ov::get_constant_from_source(scales))
scales = constant;
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {axes.size()}, axes);
ngraph::opset4::Interpolate::InterpolateAttrs attrsV4;
auto input_shape_rank = interpolationV0->get_input_partial_shape(0).rank().get_length();
if (attrsV0.mode == "nearest") {
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::NEAREST;
} else if (attrsV0.mode == "linear") {
@@ -85,7 +82,7 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
}
auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),
scalesConstant, axisConstant, attrsV4);
scales, axisConstant, attrsV4);
interpolateV4->set_friendly_name(interpolationV0->get_friendly_name());
ngraph::copy_runtime_info(interpolationV0, interpolateV4);

View File

@@ -56,6 +56,7 @@ TEST_F(TransformationTestsF, ConvertInterpolate1ToInterpolate4) {
function_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertInterpolate1ToInterpolate4_1) {
@@ -93,4 +94,28 @@ TEST_F(TransformationTestsF, ConvertInterpolate1ToInterpolate4_1) {
function_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST(TransformationTests, DynamiShapeInterpolate1To4) {
auto data_node = std::make_shared<opset1::Parameter>(element::f32, PartialShape{-1, 5, {1, 10}, -1});
auto out_shape_node = std::make_shared<opset1::Parameter>(element::i32, Shape{2});
auto interpolate1_attr = op::v0::InterpolateAttrs();
interpolate1_attr.axes = AxisSet(std::vector<size_t>{2, 3});
interpolate1_attr.mode = "linear";
interpolate1_attr.align_corners = false;
interpolate1_attr.antialias = true;
interpolate1_attr.pads_begin = std::vector<size_t>{0, 0, 0, 0};
interpolate1_attr.pads_end = std::vector<size_t>{0, 0, 0, 0};
auto interpolate1 = std::make_shared<opset1::Interpolate>(data_node, out_shape_node, interpolate1_attr);
auto f = std::make_shared<Function>(NodeVector{interpolate1}, ParameterVector{data_node, out_shape_node});
auto manager = ov::pass::Manager();
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertInterpolate1ToInterpolate4>();
manager.run_passes(f);
ASSERT_TRUE(ngraph::op::util::has_op_with_type<opset4::Interpolate>(f));
}