Fix MO and nGraph to support the model context_rcnn_resnet101_snapshot_serengeti (#9255)
* Fixes in the infer function of MO operation Select. * Fixes in the nGraph transformation SharedShapeOf. * Deleted commented code. * Added more tests for the infer function of the MO operation Select. * Started to write tests for the transformation SharedShapeOf. * Added more tests. * Now the transformation can correctly process a mix of opset1::ShapeOf and opset8::ShapeOf. * Small change. * Used opset1 and opset3 instead of opset1 and opset8. * Used get_output_element_type(0) instead of checking the version of ShapeOf.
This commit is contained in:
parent
42350a705e
commit
20ee7fd242
@ -20,6 +20,9 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedShapeOf, "SharedShapeOf", 0);
|
||||
|
||||
static constexpr size_t index_for_int32 = 0;
|
||||
static constexpr size_t index_for_int64 = 1;
|
||||
|
||||
bool ngraph::pass::SharedShapeOf::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(SharedShapeOf);
|
||||
bool graph_rewritten = false;
|
||||
@ -38,10 +41,21 @@ bool ngraph::pass::SharedShapeOf::run_on_model(const std::shared_ptr<ngraph::Fun
|
||||
for (const auto& pair : source_to_shape_of) {
|
||||
if (pair.second.size() < 2)
|
||||
continue;
|
||||
const auto& root_ss = pair.second[0];
|
||||
for (const auto& child_ss : pair.second)
|
||||
if (root_ss->get_instance_id() != child_ss->get_instance_id() && root_ss->get_output_element_type(0) == root_ss->get_output_element_type(0))
|
||||
graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0));
|
||||
|
||||
NodeVector nodes_for_different_types[2];
|
||||
for (const auto& child : pair.second) {
|
||||
const auto& type_of_output = child->get_output_element_type(0);
|
||||
size_t index = (type_of_output == element::i32) ? index_for_int32 : index_for_int64;
|
||||
nodes_for_different_types[index].push_back(child);
|
||||
}
|
||||
for (const auto& v : nodes_for_different_types) {
|
||||
if (v.empty())
|
||||
continue;
|
||||
const auto& root_ss = v[0];
|
||||
for (const auto& child_ss : v)
|
||||
if (root_ss->get_instance_id() != child_ss->get_instance_id())
|
||||
graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0));
|
||||
}
|
||||
}
|
||||
return graph_rewritten;
|
||||
}
|
||||
|
@ -0,0 +1,178 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/visualize_tree.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST_F(TransformationTestsF, SharedShapeOfTest) {
|
||||
ngraph::Shape input_shape { 120, 4 };
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
auto shapeof3_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof4_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof5_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
auto shapeof6_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof7_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
|
||||
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof3_i32, ngraph::element::i64);
|
||||
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof4_i32, ngraph::element::i64);
|
||||
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof6_i32, ngraph::element::i64);
|
||||
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof7_i32, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i64, shapeof3_i32_convert, shapeof4_i32_convert,
|
||||
shapeof5_i64, shapeof6_i32_convert, shapeof7_i32_convert};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::SharedShapeOf>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
|
||||
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i64, shapeof3_i32_convert, shapeof4_i32_convert,
|
||||
shapeof2_i64, shapeof6_i32_convert, shapeof7_i32_convert};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SharedShapeOfTestI64Only) {
|
||||
ngraph::Shape input_shape { 120, 4 };
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto shapeof1_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
auto shapeof3_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1_i64, shapeof2_i64, shapeof3_i64};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::SharedShapeOf>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto shapeof1_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1_i64, shapeof1_i64, shapeof1_i64};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SharedShapeOfTestI32Only) {
|
||||
ngraph::Shape input_shape { 120, 4 };
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof2_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof3_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof4_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof5_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
|
||||
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof2_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
|
||||
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof3_i32, ngraph::element::i64);
|
||||
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof4_i32, ngraph::element::i64);
|
||||
auto shapeof5_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof5_i32, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i32_convert, shapeof3_i32_convert,
|
||||
shapeof4_i32_convert, shapeof5_i32_convert};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::SharedShapeOf>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto shapeof1_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
|
||||
auto shapeof1_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof2_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof4_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
auto shapeof5_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof1_i32, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i32_convert, shapeof3_i32_convert,
|
||||
shapeof4_i32_convert, shapeof5_i32_convert};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SharedShapeOfTestMixed) {
|
||||
ngraph::Shape input_shape { 120, 4 };
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto shapeof1 = std::make_shared<ngraph::opset1::ShapeOf>(input);
|
||||
auto shapeof2_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
auto shapeof3_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof4 = std::make_shared<ngraph::opset1::ShapeOf>(input);
|
||||
auto shapeof5_i64 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i64);
|
||||
auto shapeof6_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
auto shapeof7_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
|
||||
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof3_i32, ngraph::element::i64);
|
||||
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof6_i32, ngraph::element::i64);
|
||||
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof7_i32, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1, shapeof2_i64, shapeof3_i32_convert, shapeof4,
|
||||
shapeof5_i64, shapeof6_i32_convert, shapeof7_i32_convert};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::SharedShapeOf>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, input_shape);
|
||||
|
||||
auto shapeof1 = std::make_shared<ngraph::opset1::ShapeOf>(input);
|
||||
auto shapeof2_i32 = std::make_shared<ngraph::opset8::ShapeOf>(input, ngraph::element::i32);
|
||||
|
||||
auto shapeof3_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
|
||||
auto shapeof6_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
|
||||
auto shapeof7_i32_convert = std::make_shared<ngraph::opset8::Convert>(shapeof2_i32, ngraph::element::i64);
|
||||
|
||||
ngraph::OutputVector inputs_of_concat {shapeof1, shapeof1, shapeof3_i32_convert, shapeof1,
|
||||
shapeof1, shapeof6_i32_convert, shapeof7_i32_convert};
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset8::Concat>(inputs_of_concat, 0);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
@ -50,6 +50,8 @@ class Select(Op):
|
||||
output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
|
||||
assert output_shape is not None, msg
|
||||
|
||||
output_is_scalar = len(output_shape) == 0
|
||||
|
||||
# if Select was created from TF Where operations then 1D condition must have the same size
|
||||
# as 0-index dimension of output_shape. This condition is different from being numpy compatible
|
||||
# but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
|
||||
@ -61,9 +63,10 @@ class Select(Op):
|
||||
node_name, condition_shape, a_shape, b_shape)
|
||||
|
||||
# check equality only if both values non-dynamic
|
||||
if is_fully_defined(condition_shape[0]) and is_fully_defined(output_shape[0]):
|
||||
if is_fully_defined(condition_shape[0]) and not output_is_scalar and is_fully_defined(output_shape[0]):
|
||||
assert condition_shape[0] == output_shape[0], msg_tf
|
||||
condition_shape = np.concatenate((condition_shape, np.ones(len(output_shape) - 1)))
|
||||
ones_shape = len(output_shape) if output_is_scalar else len(output_shape) - 1
|
||||
condition_shape = np.concatenate((condition_shape, np.ones(ones_shape)))
|
||||
|
||||
output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)
|
||||
assert output_shape is not None, msg
|
||||
|
@ -100,6 +100,34 @@ class TestSelect(unittest.TestCase):
|
||||
out_value=np.ones([15, 3, 5], dtype=np.float))
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_select_infer_condition_true_then_and_else_are_scalars(self):
|
||||
flag, msg = self.build_select_graph_and_infer(condition_value=np.array([True], dtype=bool),
|
||||
then_value=np.array(3, dtype=np.float),
|
||||
else_value=np.array(1, dtype=np.float),
|
||||
out_value=np.array([3], dtype=np.float))
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_select_infer_condition_true_then_and_else_are_scalars_2(self):
|
||||
flag, msg = self.build_select_graph_and_infer(condition_value=np.array(True, dtype=bool),
|
||||
then_value=np.array(3, dtype=np.float),
|
||||
else_value=np.array(1, dtype=np.float),
|
||||
out_value=np.array(3, dtype=np.float))
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_select_infer_condition_false_then_and_else_are_scalars(self):
|
||||
flag, msg = self.build_select_graph_and_infer(condition_value=np.array([False], dtype=bool),
|
||||
then_value=np.array(3, dtype=np.float),
|
||||
else_value=np.array(1, dtype=np.float),
|
||||
out_value=np.array([1], dtype=np.float))
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_select_infer_condition_false_then_and_else_are_scalars_2(self):
|
||||
flag, msg = self.build_select_graph_and_infer(condition_value=np.array(False, dtype=bool),
|
||||
then_value=np.array(3, dtype=np.float),
|
||||
else_value=np.array(1, dtype=np.float),
|
||||
out_value=np.array(1, dtype=np.float))
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_select_infer_condition_false_2(self):
|
||||
flag, msg = self.build_select_graph_and_infer(condition_value=np.array([False], dtype=bool),
|
||||
then_value=np.ones([15, 3, 5], dtype=np.float),
|
||||
|
Loading…
Reference in New Issue
Block a user