Static Shape constraints removed from Interpolate 1->4 transformation (#10732)
* Static Shape constraints removed from Interpolate 1->4 transformation * Dynamic tests added
This commit is contained in:
committed by
GitHub
parent
bea352f272
commit
4b55ef9911
@@ -9,40 +9,37 @@
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <openvino/core/core.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolate1ToInterpolate4, "ConvertInterpolate1ToInterpolate4", 0);
|
||||
|
||||
ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolate4() {
|
||||
MATCHER_SCOPE(ConvertInterpolate1ToInterpolate4);
|
||||
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_shape()), pattern::any_input()});
|
||||
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_rank()), pattern::any_input()});
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto interpolationV0 = std::dynamic_pointer_cast<ngraph::opset1::Interpolate>(m.get_match_root());
|
||||
if (!interpolationV0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& inp_partial_shape = interpolationV0->get_input_partial_shape(0);
|
||||
auto& out_shape = interpolationV0->get_output_shape(0);
|
||||
auto attrsV0 = interpolationV0->get_attrs();
|
||||
std::vector<size_t> axes{attrsV0.axes.begin(), attrsV0.axes.end()};
|
||||
const auto& out_dims = std::make_shared<opset1::Convert>(interpolationV0->input_value(1), element::f32);
|
||||
const auto& in_dims = std::make_shared<opset1::Convert>(ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(
|
||||
interpolationV0->input_value(0), axes), element::f32);
|
||||
|
||||
std::vector<float> scales(attrsV0.axes.size(), 1.0f);
|
||||
auto inp_shape = inp_partial_shape.to_shape();
|
||||
size_t i = 0;
|
||||
for (std::size_t axis : attrsV0.axes) {
|
||||
scales[i] = static_cast<float>(out_shape.at(axis))/inp_shape.at(axis);
|
||||
i++;
|
||||
}
|
||||
|
||||
auto input_shape_rank = inp_partial_shape.rank().get_length();
|
||||
auto scalesConstant = ngraph::op::Constant::create(ngraph::element::f32, {scales.size()}, scales);
|
||||
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {attrsV0.axes.size()},
|
||||
std::vector<std::size_t>{attrsV0.axes.begin(), attrsV0.axes.end()});
|
||||
std::shared_ptr<Node> scales = std::make_shared<opset1::Divide>(out_dims, in_dims);
|
||||
if (const auto& constant = ov::get_constant_from_source(scales))
|
||||
scales = constant;
|
||||
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {axes.size()}, axes);
|
||||
|
||||
ngraph::opset4::Interpolate::InterpolateAttrs attrsV4;
|
||||
|
||||
auto input_shape_rank = interpolationV0->get_input_partial_shape(0).rank().get_length();
|
||||
if (attrsV0.mode == "nearest") {
|
||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::NEAREST;
|
||||
} else if (attrsV0.mode == "linear") {
|
||||
@@ -85,7 +82,7 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
|
||||
}
|
||||
|
||||
auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),
|
||||
scalesConstant, axisConstant, attrsV4);
|
||||
scales, axisConstant, attrsV4);
|
||||
|
||||
interpolateV4->set_friendly_name(interpolationV0->get_friendly_name());
|
||||
ngraph::copy_runtime_info(interpolationV0, interpolateV4);
|
||||
|
||||
@@ -56,6 +56,7 @@ TEST_F(TransformationTestsF, ConvertInterpolate1ToInterpolate4) {
|
||||
|
||||
function_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertInterpolate1ToInterpolate4_1) {
|
||||
@@ -93,4 +94,28 @@ TEST_F(TransformationTestsF, ConvertInterpolate1ToInterpolate4_1) {
|
||||
|
||||
function_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, DynamiShapeInterpolate1To4) {
|
||||
auto data_node = std::make_shared<opset1::Parameter>(element::f32, PartialShape{-1, 5, {1, 10}, -1});
|
||||
auto out_shape_node = std::make_shared<opset1::Parameter>(element::i32, Shape{2});
|
||||
|
||||
auto interpolate1_attr = op::v0::InterpolateAttrs();
|
||||
interpolate1_attr.axes = AxisSet(std::vector<size_t>{2, 3});
|
||||
interpolate1_attr.mode = "linear";
|
||||
interpolate1_attr.align_corners = false;
|
||||
interpolate1_attr.antialias = true;
|
||||
interpolate1_attr.pads_begin = std::vector<size_t>{0, 0, 0, 0};
|
||||
interpolate1_attr.pads_end = std::vector<size_t>{0, 0, 0, 0};
|
||||
|
||||
auto interpolate1 = std::make_shared<opset1::Interpolate>(data_node, out_shape_node, interpolate1_attr);
|
||||
auto f = std::make_shared<Function>(NodeVector{interpolate1}, ParameterVector{data_node, out_shape_node});
|
||||
|
||||
auto manager = ov::pass::Manager();
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertInterpolate1ToInterpolate4>();
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_TRUE(ngraph::op::util::has_op_with_type<opset4::Interpolate>(f));
|
||||
}
|
||||
Reference in New Issue
Block a user