Reverse infer deduce second Squeeze input (#19923)

* Reverse infer deduce second Squeeze input

* Fix build

* Fix tf tests

* Small optimization

* Fix type

* Fix tests
Maxim Vafin, 2023-09-21 17:37:33 +02:00, committed by GitHub
parent 02f5b7a7e3
commit d0ef28e541
3 changed files with 63 additions and 8 deletions
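
For context, a minimal sketch of the scenario this commit targets, mirrored from the new SqueezeAxesReverseInfer test further down: a Squeeze built without an axes input, whose data rank and element type are unknown, feeding a Convolution that fixes the output rank to 4. With this change, ReverseShapeAndTypeInfer can deduce the axes constant ({0, 2} here) because the statically known 1-dimensions match the required rank reduction. The includes and the main() wrapper are assumptions about a typical developer setup, not part of the commit.

// Hedged sketch (not part of the commit): includes and main() wrapper are assumptions.
#include <memory>
#include <vector>

#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"

int main() {
    using namespace ov;
    auto dyn = Dimension::dynamic();
    // Squeeze has no axes input; its input rank and element type are unknown at this point.
    auto data = std::make_shared<opset10::Parameter>(element::dynamic, PartialShape{1, dyn, 1, dyn, dyn, dyn});
    auto squeeze = std::make_shared<opset10::Squeeze>(data);
    // Convolution fixes the output rank to 4, so the Squeeze must drop exactly two dimensions.
    auto weights =
        opset10::Constant::create(element::f32, Shape{64, 3, 7, 7}, std::vector<float>(64 * 3 * 7 * 7, 0.1f));
    auto conv = std::make_shared<opset10::Convolution>(squeeze,
                                                       weights,
                                                       Strides{2, 2},
                                                       CoordinateDiff{3, 3},
                                                       CoordinateDiff{3, 3},
                                                       Strides{1, 1});
    auto result = std::make_shared<opset10::Result>(conv);
    auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});

    // Dimensions 0 and 2 are the only static 1s and 6 - 4 == 2, so the pass is expected to
    // rebuild the Squeeze with an explicit i64 axes constant {0, 2} (see the diff below).
    pass::Manager manager;
    manager.register_pass<pass::ReverseShapeAndTypeInfer>();
    manager.run_passes(model);
    return 0;
}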


@@ -5,6 +5,7 @@
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert_like.hpp"
@@ -176,14 +177,32 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptr<ov::
             is_changed |= inherit_output_rank(op, {0});
             is_changed |= inherit_output_type(op, {0});
         } else if (std::dynamic_pointer_cast<ov::op::v0::Squeeze>(op)) {
-            auto in0_rank = op->get_input_partial_shape(0).rank();
-            if (output_shape.rank().is_static() && in0_rank.is_dynamic() && op->get_input_size() > 1) {
+            auto in0_pshape = op->get_input_partial_shape(0);
+            auto in0_rank = in0_pshape.rank();
+            if (output_shape.rank().is_static()) {
+                if (in0_rank.is_dynamic() && op->get_input_size() > 1) {
+                    auto in1_pshape = op->get_input_partial_shape(1);
+                    if (in1_pshape.is_static()) {
+                        auto num_dims = in1_pshape.size() == 0 ? 1 : in1_pshape[0].get_length();
+                        op->get_input_tensor(0).m_partial_shape =
+                            PartialShape::dynamic(output_shape.rank().get_length() + num_dims);
+                    }
+                } else if (in0_rank.is_static() && op->get_input_size() == 1) {
+                    // attempt to create second input
+                    std::vector<int64_t> in1_data;
+                    for (size_t i = 0; i < in0_pshape.size(); i++) {
+                        if (in0_pshape[i] == 1) {
+                            in1_data.push_back(i);
+                        }
+                    }
+                    int64_t num_ones = in1_data.size();
+                    if (num_ones == in0_rank.get_length() - output_shape.rank().get_length()) {
+                        auto axes = ov::op::v0::Constant::create(element::i64, Shape{in1_data.size()}, in1_data);
+                        auto new_squeeze = std::make_shared<ov::op::v0::Squeeze>(op->get_input_source_output(0), axes);
+                        op->output(0).replace(new_squeeze->output(0));
+                        copy_runtime_info(op, new_squeeze);
+                    }
+                }
+            }
             is_changed |= inherit_output_type(op, {0});
         } else if (std::dynamic_pointer_cast<ov::op::v0::Unsqueeze>(op)) {


@@ -412,6 +412,41 @@ TEST_F(TransformationTestsF, SqueezeReverseInfer) {
     }
 }
+TEST_F(TransformationTestsF, SqueezeAxesReverseInfer) {
+    auto dyn = Dimension::dynamic();
+    {
+        auto data = std::make_shared<opset10::Parameter>(element::dynamic, PartialShape{1, dyn, 1, dyn, dyn, dyn});
+        auto squeeze = std::make_shared<opset10::Squeeze>(data);
+        // Convolution is needed to produce static rank
+        auto weights =
+            opset10::Constant::create(element::f32, Shape{64, 3, 7, 7}, std::vector<float>(64 * 3 * 7 * 7, 0.1f));
+        auto conv = std::make_shared<opset10::Convolution>(squeeze,
+                                                           weights,
+                                                           Strides{2, 2},
+                                                           CoordinateDiff{3, 3},
+                                                           CoordinateDiff{3, 3},
+                                                           Strides{1, 1});
+        auto result = std::make_shared<opset10::Result>(conv);
+        model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
+        manager.register_pass<pass::ReverseShapeAndTypeInfer>();
+    }
+    {
+        auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{1, dyn, 1, dyn, dyn, dyn});
+        auto axes = opset10::Constant::create(element::i64, Shape{2}, {0, 2});
+        auto squeeze = std::make_shared<opset10::Squeeze>(data, axes);
+        auto weights =
+            opset10::Constant::create(element::f32, Shape{64, 3, 7, 7}, std::vector<float>(64 * 3 * 7 * 7, 0.1f));
+        auto conv = std::make_shared<opset10::Convolution>(squeeze,
+                                                           weights,
+                                                           Strides{2, 2},
+                                                           CoordinateDiff{3, 3},
+                                                           CoordinateDiff{3, 3},
+                                                           Strides{1, 1});
+        auto result = std::make_shared<opset10::Result>(conv);
+        model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
+    }
+}
 TEST_F(TransformationTestsF, UnsqueezeReverseInfer) {
     {
         auto data = std::make_shared<opset10::Parameter>(element::dynamic, PartialShape::dynamic());


@@ -664,7 +664,8 @@ TEST_F(FrontEndConversionWithReferenceTestsF, NonMaxSuppressionWithNamedOutputs)
         selected_scores = make_shared<Convert>(selected_scores, i32);
         // compute the third output - valid_outputs
-        Output<Node> valid_outputs = make_shared<Squeeze>(nms->output(2));
+        auto squeeze_axes = make_shared<Constant>(i64, Shape{1}, 0);
+        Output<Node> valid_outputs = make_shared<Squeeze>(nms->output(2), squeeze_axes);
         // make post-processing before the concatenation
         auto const_minus_one = make_shared<Constant>(i32, Shape{1}, -1);