Reverse infer deduce second Squeeze input (#19923)
* Reverse infer deduce second Squeeze input * Fix build * Fix tf tests * Small optimization * Fix type * Fix tests
This commit is contained in:
parent
02f5b7a7e3
commit
d0ef28e541
@ -5,6 +5,7 @@
|
|||||||
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
|
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
|
#include "openvino/core/rt_info.hpp"
|
||||||
#include "openvino/core/validation_util.hpp"
|
#include "openvino/core/validation_util.hpp"
|
||||||
#include "openvino/op/concat.hpp"
|
#include "openvino/op/concat.hpp"
|
||||||
#include "openvino/op/convert_like.hpp"
|
#include "openvino/op/convert_like.hpp"
|
||||||
@ -176,14 +177,32 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptr<ov::
|
|||||||
is_changed |= inherit_output_rank(op, {0});
|
is_changed |= inherit_output_rank(op, {0});
|
||||||
is_changed |= inherit_output_type(op, {0});
|
is_changed |= inherit_output_type(op, {0});
|
||||||
} else if (std::dynamic_pointer_cast<ov::op::v0::Squeeze>(op)) {
|
} else if (std::dynamic_pointer_cast<ov::op::v0::Squeeze>(op)) {
|
||||||
auto in0_rank = op->get_input_partial_shape(0).rank();
|
auto in0_pshape = op->get_input_partial_shape(0);
|
||||||
if (output_shape.rank().is_static() && in0_rank.is_dynamic() && op->get_input_size() > 1) {
|
auto in0_rank = in0_pshape.rank();
|
||||||
|
if (output_shape.rank().is_static()) {
|
||||||
|
if (in0_rank.is_dynamic() && op->get_input_size() > 1) {
|
||||||
auto in1_pshape = op->get_input_partial_shape(1);
|
auto in1_pshape = op->get_input_partial_shape(1);
|
||||||
if (in1_pshape.is_static()) {
|
if (in1_pshape.is_static()) {
|
||||||
auto num_dims = in1_pshape.size() == 0 ? 1 : in1_pshape[0].get_length();
|
auto num_dims = in1_pshape.size() == 0 ? 1 : in1_pshape[0].get_length();
|
||||||
op->get_input_tensor(0).m_partial_shape =
|
op->get_input_tensor(0).m_partial_shape =
|
||||||
PartialShape::dynamic(output_shape.rank().get_length() + num_dims);
|
PartialShape::dynamic(output_shape.rank().get_length() + num_dims);
|
||||||
}
|
}
|
||||||
|
} else if (in0_rank.is_static() && op->get_input_size() == 1) {
|
||||||
|
// attempt to create second input
|
||||||
|
std::vector<int64_t> in1_data;
|
||||||
|
for (size_t i = 0; i < in0_pshape.size(); i++) {
|
||||||
|
if (in0_pshape[i] == 1) {
|
||||||
|
in1_data.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int64_t num_ones = in1_data.size();
|
||||||
|
if (num_ones == in0_rank.get_length() - output_shape.rank().get_length()) {
|
||||||
|
auto axes = ov::op::v0::Constant::create(element::i64, Shape{in1_data.size()}, in1_data);
|
||||||
|
auto new_squeeze = std::make_shared<ov::op::v0::Squeeze>(op->get_input_source_output(0), axes);
|
||||||
|
op->output(0).replace(new_squeeze->output(0));
|
||||||
|
copy_runtime_info(op, new_squeeze);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
is_changed |= inherit_output_type(op, {0});
|
is_changed |= inherit_output_type(op, {0});
|
||||||
} else if (std::dynamic_pointer_cast<ov::op::v0::Unsqueeze>(op)) {
|
} else if (std::dynamic_pointer_cast<ov::op::v0::Unsqueeze>(op)) {
|
||||||
|
@ -412,6 +412,41 @@ TEST_F(TransformationTestsF, SqueezeReverseInfer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, SqueezeAxesReverseInfer) {
|
||||||
|
auto dyn = Dimension::dynamic();
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset10::Parameter>(element::dynamic, PartialShape{1, dyn, 1, dyn, dyn, dyn});
|
||||||
|
auto squeeze = std::make_shared<opset10::Squeeze>(data);
|
||||||
|
// Convolution is needed to produce static rank
|
||||||
|
auto weights =
|
||||||
|
opset10::Constant::create(element::f32, Shape{64, 3, 7, 7}, std::vector<float>(64 * 3 * 7 * 7, 0.1f));
|
||||||
|
auto conv = std::make_shared<opset10::Convolution>(squeeze,
|
||||||
|
weights,
|
||||||
|
Strides{2, 2},
|
||||||
|
CoordinateDiff{3, 3},
|
||||||
|
CoordinateDiff{3, 3},
|
||||||
|
Strides{1, 1});
|
||||||
|
auto result = std::make_shared<opset10::Result>(conv);
|
||||||
|
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
|
||||||
|
manager.register_pass<pass::ReverseShapeAndTypeInfer>();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{1, dyn, 1, dyn, dyn, dyn});
|
||||||
|
auto axes = opset10::Constant::create(element::i64, Shape{2}, {0, 2});
|
||||||
|
auto squeeze = std::make_shared<opset10::Squeeze>(data, axes);
|
||||||
|
auto weights =
|
||||||
|
opset10::Constant::create(element::f32, Shape{64, 3, 7, 7}, std::vector<float>(64 * 3 * 7 * 7, 0.1f));
|
||||||
|
auto conv = std::make_shared<opset10::Convolution>(squeeze,
|
||||||
|
weights,
|
||||||
|
Strides{2, 2},
|
||||||
|
CoordinateDiff{3, 3},
|
||||||
|
CoordinateDiff{3, 3},
|
||||||
|
Strides{1, 1});
|
||||||
|
auto result = std::make_shared<opset10::Result>(conv);
|
||||||
|
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, UnsqueezeReverseInfer) {
|
TEST_F(TransformationTestsF, UnsqueezeReverseInfer) {
|
||||||
{
|
{
|
||||||
auto data = std::make_shared<opset10::Parameter>(element::dynamic, PartialShape::dynamic());
|
auto data = std::make_shared<opset10::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||||
|
@ -664,7 +664,8 @@ TEST_F(FrontEndConversionWithReferenceTestsF, NonMaxSuppressionWithNamedOutputs)
|
|||||||
selected_scores = make_shared<Convert>(selected_scores, i32);
|
selected_scores = make_shared<Convert>(selected_scores, i32);
|
||||||
|
|
||||||
// compute the third output - valid_outputs
|
// compute the third output - valid_outputs
|
||||||
Output<Node> valid_outputs = make_shared<Squeeze>(nms->output(2));
|
auto squeeze_axes = make_shared<Constant>(i64, Shape{1}, 0);
|
||||||
|
Output<Node> valid_outputs = make_shared<Squeeze>(nms->output(2), squeeze_axes);
|
||||||
|
|
||||||
// make post-processing before the concatenation
|
// make post-processing before the concatenation
|
||||||
auto const_minus_one = make_shared<Constant>(i32, Shape{1}, -1);
|
auto const_minus_one = make_shared<Constant>(i32, Shape{1}, -1);
|
||||||
|
Loading…
Reference in New Issue
Block a user