diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp index 44d5b124341..d4f61e7e728 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp @@ -17,7 +17,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThr ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() { MATCHER_SCOPE(PullTransposeThroughFQUp); - auto m_fq = pattern::wrap_type({pattern::any_input(pattern::has_static_rank()), + const auto weights = ngraph::pattern::wrap_type(); + auto m_fq = pattern::wrap_type({weights, pattern::any_input(pattern::has_static_shape()), pattern::any_input(pattern::has_static_shape()), pattern::any_input(pattern::has_static_shape()), diff --git a/inference-engine/tests/functional/inference_engine/transformations/ngraph_fq_transpose_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/ngraph_fq_transpose_test.cpp index 6b5f2d82b6a..822c575ac11 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/ngraph_fq_transpose_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/ngraph_fq_transpose_test.cpp @@ -56,28 +56,28 @@ TEST_F(TransformationTestsF, FQTransposeTest1) { } } -TEST(TransformationTests, FQTransposeDynamic) { - auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); - auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2}); - auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3}); - auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2}); - auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3}); - auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); +TEST_F(TransformationTestsF, FQTransposeNegativeCase) { + auto create_graph = []() -> std::shared_ptr { + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape{1, 3, 1}); + auto sigmoid = std::make_shared(data); + auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2}); + auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3}); + auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2}); + auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3}); + auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); - std::shared_ptr f(nullptr); - { - auto fq = std::make_shared(data, input_low, input_high, output_low, output_high, 1); + auto fq = std::make_shared(sigmoid, input_low, input_high, output_low, output_high, 1); auto transpose = std::make_shared(fq, transpose_order); - f = std::make_shared(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data}); + return std::make_shared(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data}); + }; + function = create_graph(); - ngraph::pass::Manager manager; - manager.register_pass(); - manager.register_pass(); - manager.register_pass([](std::shared_ptr f) { - check_rt_info(f); - }); - manager.register_pass(); - ASSERT_NO_THROW(manager.run_passes(f)); - } + manager.register_pass(); + manager.register_pass(); + manager.register_pass([](std::shared_ptr f) { + check_rt_info(f); + }); + + function_ref = create_graph(); } diff --git a/model-optimizer/extensions/back/MatMulNormalizer.py b/model-optimizer/extensions/back/MatMulNormalizer.py index fb860f86b56..96221640025 100644 --- a/model-optimizer/extensions/back/MatMulNormalizer.py +++ b/model-optimizer/extensions/back/MatMulNormalizer.py @@ -73,14 +73,13 @@ class MatMulConstTransposesExtraction(BackReplacementPattern): class PullTransposeThroughFQUp(BackReplacementPattern): r""" BEFORE AFTER - T T T T T - \ \ | / / \ \ | / / - FakeQuantize FakeQuantize + Const Const + \ \ | / / | + FakeQuantize T T T T T + | \ \ | / / + Transpose FakeQuantize | | - Transpose next_op - | - next_op - + next_op next_op `T` is Transpose for short """ enabled = True @@ -94,13 +93,17 @@ class PullTransposeThroughFQUp(BackReplacementPattern): def pattern(): return dict( nodes=[ + ('fq_const_input', dict(kind='op', type='Const')), + ('fq_const_input_d', dict()), ('fq', dict(kind='op', type='FakeQuantize')), - ('data', dict()), + ('fq_d', dict()), ('transpose', dict(kind='op', type='Transpose')), ], edges=[ - ('fq', 'data'), - ('data', 'transpose'), + ('fq_const_input', 'fq_const_input_d'), + ('fq_const_input_d', 'fq', {'in': 0}), + ('fq', 'fq_d'), + ('fq_d', 'transpose'), ] ) diff --git a/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py b/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py index 30805aa19aa..e6b98ce3b91 100644 --- a/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py +++ b/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py @@ -15,7 +15,7 @@ from mo.front.common.partial_infer.utils import int64_array from mo.ops.reshape import Reshape from mo.utils.ir_engine.compare_graphs import compare_graphs from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \ - result, connect, connect_data + shaped_const_with_data, result, connect, connect_data from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data @@ -101,9 +101,8 @@ class SmartReshape_HC_Reshape_MatMulTest(unittest.TestCase): class FQTransposePullerTest(unittest.TestCase): - def nodes(self, input_shape, transpose_shape, fq_shape): - return { - **regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter')), + def nodes(self, input_shape, transpose_shape, fq_shape, is_input_const): + nodes = { **valued_const_with_data('il', np.array([[[[0]]]])), **valued_const_with_data('ih', np.array([[[[255]]]])), **valued_const_with_data('ol', np.array([[[[0]]]])), @@ -116,8 +115,16 @@ class FQTransposePullerTest(unittest.TestCase): **result(), } + if is_input_const: + input_node = shaped_const_with_data('input', input_shape) + else: + input_node = regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter')) + + nodes.update(input_node) + return nodes + def test_positive(self): - nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224]) + nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True) edges = [ *connect('input', '0:FQ'), *connect('il', '1:FQ'), @@ -132,7 +139,7 @@ class FQTransposePullerTest(unittest.TestCase): PullTransposeThroughFQUp().find_and_replace_pattern(graph) graph.clean_up() - nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3]) + nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3], True) edges = [ *connect('input', '0:transpose'), *connect('order:0', '1:transpose'), @@ -148,8 +155,8 @@ class FQTransposePullerTest(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) self.assertTrue(flag, resp) - def test_negative(self): - nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224]) + def test_negative_1(self): + nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True) edges = [ *connect('input', '0:FQ'), *connect('il', '1:FQ'), @@ -168,3 +175,21 @@ class FQTransposePullerTest(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) self.assertTrue(flag, resp) + def test_negative_2(self): + nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], False) + edges = [ + *connect('input', '0:FQ'), + *connect('il', '1:FQ'), + *connect('ih', '2:FQ'), + *connect('ol', '3:FQ'), + *connect('oh', '4:FQ'), + *connect('FQ:0', '0:transpose'), + *connect('order:0', '1:transpose'), + *connect('transpose:0', 'output'), + ] + graph = build_graph(nodes_attrs=nodes, edges=edges, nodes_with_edges_only=True) + graph_ref = graph.copy() + PullTransposeThroughFQUp().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp)