Fix MO and nGraph to support the model context_rcnn_resnet101_snapshot_serengeti (#9255)

* Fixes in the infer function of MO operation Select.

* Fixes in the nGraph transformation SharedShapeOf.

* Deleted commented code.

* Added more tests for the infer function of the MO operation Select.

* Started to write tests for the transformation SharedShapeOf.

* Added more tests.

* Now the transformation can correctly process a mix of opset1::ShapeOf and opset8::ShapeOf.

* Small change.

* Used opset1 and opset3 instead of opset1 and opset8.

* Used get_output_element_type(0) instead of checking the version of ShapeOf.
This commit is contained in:
Vladimir Gavrilov 2021-12-23 09:44:47 +03:00 committed by GitHub
parent 42350a705e
commit 20ee7fd242
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 229 additions and 6 deletions

View File

@ -20,6 +20,9 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedShapeOf, "SharedShapeOf", 0);
static constexpr size_t index_for_int32 = 0;
static constexpr size_t index_for_int64 = 1;
bool ngraph::pass::SharedShapeOf::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
RUN_ON_FUNCTION_SCOPE(SharedShapeOf);
bool graph_rewritten = false;
@ -38,10 +41,21 @@ bool ngraph::pass::SharedShapeOf::run_on_model(const std::shared_ptr<ngraph::Fun
for (const auto& pair : source_to_shape_of) {
if (pair.second.size() < 2)
continue;
const auto& root_ss = pair.second[0];
for (const auto& child_ss : pair.second)
if (root_ss->get_instance_id() != child_ss->get_instance_id() && root_ss->get_output_element_type(0) == root_ss->get_output_element_type(0))
graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0));
NodeVector nodes_for_different_types[2];
for (const auto& child : pair.second) {
const auto& type_of_output = child->get_output_element_type(0);
size_t index = (type_of_output == element::i32) ? index_for_int32 : index_for_int64;
nodes_for_different_types[index].push_back(child);
}
for (const auto& v : nodes_for_different_types) {
if (v.empty())
continue;
const auto& root_ss = v[0];
for (const auto& child_ss : v)
if (root_ss->get_instance_id() != child_ss->get_instance_id())
graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0));
}
}
return graph_rewritten;
}

View File

@ -0,0 +1,178 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST_F(TransformationTestsF, SharedShapeOfTest) {
ngraph::Shape input_shape { 120, 4 };
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
auto shapeof3_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof4_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof5_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
auto shapeof6_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof7_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof3_i32, ngraph::element::i64);
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof4_i32, ngraph::element::i64);
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof6_i32, ngraph::element::i64);
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof7_i32, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i64, shapeof3_i32_convert, shapeof4_i32_convert,
shapeof5_i64, shapeof6_i32_convert, shapeof7_i32_convert};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SharedShapeOf>();
}
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i64, shapeof3_i32_convert, shapeof4_i32_convert,
shapeof2_i64, shapeof6_i32_convert, shapeof7_i32_convert};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SharedShapeOfTestI64Only) {
ngraph::Shape input_shape { 120, 4 };
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
auto shapeof3_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1_i64, shapeof2_i64, shapeof3_i64};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SharedShapeOf>();
}
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1_i64, shapeof1_i64, shapeof1_i64};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SharedShapeOfTestI32Only) {
ngraph::Shape input_shape { 120, 4 };
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof2_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof3_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof4_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof5_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof2_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof3_i32, ngraph::element::i64);
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof4_i32, ngraph::element::i64);
auto shapeof5_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof5_i32, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i32_convert, shapeof3_i32_convert,
shapeof4_i32_convert, shapeof5_i32_convert};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SharedShapeOf>();
}
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof2_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
auto shapeof5_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i32_convert, shapeof3_i32_convert,
shapeof4_i32_convert, shapeof5_i32_convert};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, SharedShapeOfTestMixed) {
ngraph::Shape input_shape { 120, 4 };
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1 = std::make_shared<ngraph::opset1::ShapeOf>(input);
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
auto shapeof3_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof4 = std::make_shared<ngraph::opset1::ShapeOf>(input);
auto shapeof5_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
auto shapeof6_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof7_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof3_i32, ngraph::element::i64);
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof6_i32, ngraph::element::i64);
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof7_i32, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1, shapeof2_i64, shapeof3_i32_convert, shapeof4,
shapeof5_i64, shapeof6_i32_convert, shapeof7_i32_convert};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::SharedShapeOf>();
}
{
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
auto shapeof1 = std::make_shared<ngraph::opset1::ShapeOf>(input);
auto shapeof2_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
ngraph::OutputVector inputs_of_concat {shapeof1, shapeof1, shapeof3_i32_convert, shapeof1,
shapeof1, shapeof6_i32_convert, shapeof7_i32_convert};
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
}
}

View File

@ -50,6 +50,8 @@ class Select(Op):
output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
assert output_shape is not None, msg
output_is_scalar = len(output_shape) == 0
# if Select was created from TF Where operations then 1D condition must have the same size
# as 0-index dimension of output_shape. This condition is different from being numpy compatible
# but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
@ -61,9 +63,10 @@ class Select(Op):
node_name, condition_shape, a_shape, b_shape)
# check equality only if both values non-dynamic
if is_fully_defined(condition_shape[0]) and is_fully_defined(output_shape[0]):
if is_fully_defined(condition_shape[0]) and not output_is_scalar and is_fully_defined(output_shape[0]):
assert condition_shape[0] == output_shape[0], msg_tf
condition_shape = np.concatenate((condition_shape, np.ones(len(output_shape) - 1)))
ones_shape = len(output_shape) if output_is_scalar else len(output_shape) - 1
condition_shape = np.concatenate((condition_shape, np.ones(ones_shape)))
output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)
assert output_shape is not None, msg

View File

@ -100,6 +100,34 @@ class TestSelect(unittest.TestCase):
out_value=np.ones([15, 3, 5], dtype=np.float))
self.assertTrue(flag, msg)
def test_select_infer_condition_true_then_and_else_are_scalars(self):
flag, msg = self.build_select_graph_and_infer(condition_value=np.array([True], dtype=bool),
then_value=np.array(3, dtype=np.float),
else_value=np.array(1, dtype=np.float),
out_value=np.array([3], dtype=np.float))
self.assertTrue(flag, msg)
def test_select_infer_condition_true_then_and_else_are_scalars_2(self):
flag, msg = self.build_select_graph_and_infer(condition_value=np.array(True, dtype=bool),
then_value=np.array(3, dtype=np.float),
else_value=np.array(1, dtype=np.float),
out_value=np.array(3, dtype=np.float))
self.assertTrue(flag, msg)
def test_select_infer_condition_false_then_and_else_are_scalars(self):
flag, msg = self.build_select_graph_and_infer(condition_value=np.array([False], dtype=bool),
then_value=np.array(3, dtype=np.float),
else_value=np.array(1, dtype=np.float),
out_value=np.array([1], dtype=np.float))
self.assertTrue(flag, msg)
def test_select_infer_condition_false_then_and_else_are_scalars_2(self):
flag, msg = self.build_select_graph_and_infer(condition_value=np.array(False, dtype=bool),
then_value=np.array(3, dtype=np.float),
else_value=np.array(1, dtype=np.float),
out_value=np.array(1, dtype=np.float))
self.assertTrue(flag, msg)
def test_select_infer_condition_false_2(self):
flag, msg = self.build_select_graph_and_infer(condition_value=np.array([False], dtype=bool),
then_value=np.ones([15, 3, 5], dtype=np.float),