diff --git a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp index 8b9a1c3de9f..14dede1e6f8 100644 --- a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp @@ -20,6 +20,9 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedShapeOf, "SharedShapeOf", 0); +static constexpr size_t index_for_int32 = 0; +static constexpr size_t index_for_int64 = 1; + bool ngraph::pass::SharedShapeOf::run_on_model(const std::shared_ptr& f) { RUN_ON_FUNCTION_SCOPE(SharedShapeOf); bool graph_rewritten = false; @@ -38,10 +41,21 @@ bool ngraph::pass::SharedShapeOf::run_on_model(const std::shared_ptrget_instance_id() != child_ss->get_instance_id() && root_ss->get_output_element_type(0) == root_ss->get_output_element_type(0)) - graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0)); + + NodeVector nodes_for_different_types[2]; + for (const auto& child : pair.second) { + const auto& type_of_output = child->get_output_element_type(0); + size_t index = (type_of_output == element::i32) ? index_for_int32 : index_for_int64; + nodes_for_different_types[index].push_back(child); + } + for (const auto& v : nodes_for_different_types) { + if (v.empty()) + continue; + const auto& root_ss = v[0]; + for (const auto& child_ss : v) + if (root_ss->get_instance_id() != child_ss->get_instance_id()) + graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0)); + } } return graph_rewritten; } diff --git a/src/tests/functional/inference_engine/transformations/shared_shapeof_test.cpp b/src/tests/functional/inference_engine/transformations/shared_shapeof_test.cpp new file mode 100644 index 00000000000..a6585b6cb92 --- /dev/null +++ b/src/tests/functional/inference_engine/transformations/shared_shapeof_test.cpp @@ -0,0 +1,178 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +TEST_F(TransformationTestsF, SharedShapeOfTest) { + ngraph::Shape input_shape { 120, 4 }; + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + + auto shapeof1_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof2_i64 = std::make_shared(input, ngraph::element::i64); + auto shapeof3_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof4_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof5_i64 = std::make_shared(input, ngraph::element::i64); + auto shapeof6_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof7_i32 = std::make_shared(input, ngraph::element::i32); + + auto shapeof1_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof3_i32_convert = std::make_shared(shapeof3_i32, ngraph::element::i64); + auto shapeof4_i32_convert = std::make_shared(shapeof4_i32, ngraph::element::i64); + auto shapeof6_i32_convert = std::make_shared(shapeof6_i32, ngraph::element::i64); + auto shapeof7_i32_convert = std::make_shared(shapeof7_i32, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i64, shapeof3_i32_convert, shapeof4_i32_convert, + shapeof5_i64, shapeof6_i32_convert, shapeof7_i32_convert}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + manager.register_pass(); + } + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + + auto shapeof1_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof2_i64 = std::make_shared(input, ngraph::element::i64); + + auto shapeof1_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof3_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof4_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof6_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof7_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i64, shapeof3_i32_convert, shapeof4_i32_convert, + shapeof2_i64, shapeof6_i32_convert, shapeof7_i32_convert}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function_ref = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + } +} + +TEST_F(TransformationTestsF, SharedShapeOfTestI64Only) { + ngraph::Shape input_shape { 120, 4 }; + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + + auto shapeof1_i64 = std::make_shared(input, ngraph::element::i64); + auto shapeof2_i64 = std::make_shared(input, ngraph::element::i64); + auto shapeof3_i64 = std::make_shared(input, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1_i64, shapeof2_i64, shapeof3_i64}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + manager.register_pass(); + } + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + auto shapeof1_i64 = std::make_shared(input, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1_i64, shapeof1_i64, shapeof1_i64}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function_ref = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + } +} + +TEST_F(TransformationTestsF, SharedShapeOfTestI32Only) { + ngraph::Shape input_shape { 120, 4 }; + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + + auto shapeof1_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof2_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof3_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof4_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof5_i32 = std::make_shared(input, ngraph::element::i32); + + auto shapeof1_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof2_i32_convert = std::make_shared(shapeof2_i32, ngraph::element::i64); + auto shapeof3_i32_convert = std::make_shared(shapeof3_i32, ngraph::element::i64); + auto shapeof4_i32_convert = std::make_shared(shapeof4_i32, ngraph::element::i64); + auto shapeof5_i32_convert = std::make_shared(shapeof5_i32, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i32_convert, shapeof3_i32_convert, + shapeof4_i32_convert, shapeof5_i32_convert}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + manager.register_pass(); + } + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + + auto shapeof1_i32 = std::make_shared(input, ngraph::element::i32); + + auto shapeof1_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof2_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof3_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof4_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + auto shapeof5_i32_convert = std::make_shared(shapeof1_i32, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1_i32_convert, shapeof2_i32_convert, shapeof3_i32_convert, + shapeof4_i32_convert, shapeof5_i32_convert}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function_ref = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + } +} + +TEST_F(TransformationTestsF, SharedShapeOfTestMixed) { + ngraph::Shape input_shape { 120, 4 }; + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + + auto shapeof1 = std::make_shared(input); + auto shapeof2_i64 = std::make_shared(input, ngraph::element::i64); + auto shapeof3_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof4 = std::make_shared(input); + auto shapeof5_i64 = std::make_shared(input, ngraph::element::i64); + auto shapeof6_i32 = std::make_shared(input, ngraph::element::i32); + auto shapeof7_i32 = std::make_shared(input, ngraph::element::i32); + + auto shapeof3_i32_convert = std::make_shared(shapeof3_i32, ngraph::element::i64); + auto shapeof6_i32_convert = std::make_shared(shapeof6_i32, ngraph::element::i64); + auto shapeof7_i32_convert = std::make_shared(shapeof7_i32, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1, shapeof2_i64, shapeof3_i32_convert, shapeof4, + shapeof5_i64, shapeof6_i32_convert, shapeof7_i32_convert}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + manager.register_pass(); + } + { + auto input = std::make_shared(ngraph::element::f32, input_shape); + + auto shapeof1 = std::make_shared(input); + auto shapeof2_i32 = std::make_shared(input, ngraph::element::i32); + + auto shapeof3_i32_convert = std::make_shared(shapeof2_i32, ngraph::element::i64); + auto shapeof6_i32_convert = std::make_shared(shapeof2_i32, ngraph::element::i64); + auto shapeof7_i32_convert = std::make_shared(shapeof2_i32, ngraph::element::i64); + + ngraph::OutputVector inputs_of_concat {shapeof1, shapeof1, shapeof3_i32_convert, shapeof1, + shapeof1, shapeof6_i32_convert, shapeof7_i32_convert}; + + auto concat = std::make_shared(inputs_of_concat, 0); + function_ref = std::make_shared(ngraph::NodeVector{ concat }, ngraph::ParameterVector{ input }); + } +} diff --git a/tools/mo/openvino/tools/mo/ops/select.py b/tools/mo/openvino/tools/mo/ops/select.py index 3d3520638a1..fa27e1f1c7e 100644 --- a/tools/mo/openvino/tools/mo/ops/select.py +++ b/tools/mo/openvino/tools/mo/ops/select.py @@ -50,6 +50,8 @@ class Select(Op): output_shape = bi_directional_shape_broadcasting(a_shape, b_shape) assert output_shape is not None, msg + output_is_scalar = len(output_shape) == 0 + # if Select was created from TF Where operations then 1D condition must have the same size # as 0-index dimension of output_shape. This condition is different from being numpy compatible # but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py @@ -61,9 +63,10 @@ class Select(Op): node_name, condition_shape, a_shape, b_shape) # check equality only if both values non-dynamic - if is_fully_defined(condition_shape[0]) and is_fully_defined(output_shape[0]): + if is_fully_defined(condition_shape[0]) and not output_is_scalar and is_fully_defined(output_shape[0]): assert condition_shape[0] == output_shape[0], msg_tf - condition_shape = np.concatenate((condition_shape, np.ones(len(output_shape) - 1))) + ones_shape = len(output_shape) if output_is_scalar else len(output_shape) - 1 + condition_shape = np.concatenate((condition_shape, np.ones(ones_shape))) output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape) assert output_shape is not None, msg diff --git a/tools/mo/unit_tests/mo/ops/select_test.py b/tools/mo/unit_tests/mo/ops/select_test.py index 0b6e5671405..f5e49420dd4 100644 --- a/tools/mo/unit_tests/mo/ops/select_test.py +++ b/tools/mo/unit_tests/mo/ops/select_test.py @@ -100,6 +100,34 @@ class TestSelect(unittest.TestCase): out_value=np.ones([15, 3, 5], dtype=np.float)) self.assertTrue(flag, msg) + def test_select_infer_condition_true_then_and_else_are_scalars(self): + flag, msg = self.build_select_graph_and_infer(condition_value=np.array([True], dtype=bool), + then_value=np.array(3, dtype=np.float), + else_value=np.array(1, dtype=np.float), + out_value=np.array([3], dtype=np.float)) + self.assertTrue(flag, msg) + + def test_select_infer_condition_true_then_and_else_are_scalars_2(self): + flag, msg = self.build_select_graph_and_infer(condition_value=np.array(True, dtype=bool), + then_value=np.array(3, dtype=np.float), + else_value=np.array(1, dtype=np.float), + out_value=np.array(3, dtype=np.float)) + self.assertTrue(flag, msg) + + def test_select_infer_condition_false_then_and_else_are_scalars(self): + flag, msg = self.build_select_graph_and_infer(condition_value=np.array([False], dtype=bool), + then_value=np.array(3, dtype=np.float), + else_value=np.array(1, dtype=np.float), + out_value=np.array([1], dtype=np.float)) + self.assertTrue(flag, msg) + + def test_select_infer_condition_false_then_and_else_are_scalars_2(self): + flag, msg = self.build_select_graph_and_infer(condition_value=np.array(False, dtype=bool), + then_value=np.array(3, dtype=np.float), + else_value=np.array(1, dtype=np.float), + out_value=np.array(1, dtype=np.float)) + self.assertTrue(flag, msg) + def test_select_infer_condition_false_2(self): flag, msg = self.build_select_graph_and_infer(condition_value=np.array([False], dtype=bool), then_value=np.ones([15, 3, 5], dtype=np.float),