From d7ce8289ac4252f23195695af86c8adeaf2995c8 Mon Sep 17 00:00:00 2001 From: Mateusz Bencer Date: Fri, 19 Aug 2022 12:14:10 +0200 Subject: [PATCH] [MO] Fix SSliceComplex transformation (#12537) --- tools/mo/automation/package_BOM.txt | 1 - .../tools/mo/front/tf/CorrectRollAxes.py | 23 ----- .../openvino/tools/mo/middle/SSliceComplex.py | 19 ++-- .../mo/front/tf/CorrectRollAxes_test.py | 89 ----------------- .../mo/middle/SSliceComplex_test.py | 96 +++++++++++++++++++ 5 files changed, 107 insertions(+), 121 deletions(-) delete mode 100644 tools/mo/openvino/tools/mo/front/tf/CorrectRollAxes.py delete mode 100644 tools/mo/unit_tests/mo/front/tf/CorrectRollAxes_test.py diff --git a/tools/mo/automation/package_BOM.txt b/tools/mo/automation/package_BOM.txt index cb38db9ba97..96ae2b20db0 100644 --- a/tools/mo/automation/package_BOM.txt +++ b/tools/mo/automation/package_BOM.txt @@ -516,7 +516,6 @@ openvino/tools/mo/front/tf/concat_ext.py openvino/tools/mo/front/tf/const_ext.py openvino/tools/mo/front/tf/conv_ext.py openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py -openvino/tools/mo/front/tf/CorrectRollAxes.py openvino/tools/mo/front/tf/crop_and_resize_ext.py openvino/tools/mo/front/tf/CropAndResizeReplacement.py openvino/tools/mo/front/tf/CTCGreedyDecoder_ext.py diff --git a/tools/mo/openvino/tools/mo/front/tf/CorrectRollAxes.py b/tools/mo/openvino/tools/mo/front/tf/CorrectRollAxes.py deleted file mode 100644 index 4f6ecd92a40..00000000000 --- a/tools/mo/openvino/tools/mo/front/tf/CorrectRollAxes.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (C) 2018-2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -from openvino.tools.mo.front.common.partial_infer.utils import int64_array -from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph -from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values -from openvino.tools.mo.graph.graph import Graph - - -class CorrectRollAxes(FrontReplacementSubgraph): - """ - If the Roll node is a consumer of Complex node in the original TF model, then we have a real input tensor for Roll - instead of a complex. Negative axes values for the Roll operation should be updated to reflect the fact that the - rank of input tensor was increased by one (a new trailing dimension of size 2 containing real and imaginary part - of complex number is added). - """ - enabled = True - - def find_and_replace_pattern(self, graph: Graph): - for roll in graph.get_op_nodes(op='Roll', input_rank_changed=True): - add_constant_to_negative_values(roll, 2, int64_array(-1)) - del roll['input_rank_changed'] diff --git a/tools/mo/openvino/tools/mo/middle/SSliceComplex.py b/tools/mo/openvino/tools/mo/middle/SSliceComplex.py index 5b3a4907897..c24898ea628 100644 --- a/tools/mo/openvino/tools/mo/middle/SSliceComplex.py +++ b/tools/mo/openvino/tools/mo/middle/SSliceComplex.py @@ -7,7 +7,7 @@ from typing import Dict import numpy as np from openvino.tools.mo.front.common.partial_infer.utils import int64_array -from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs +from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values, create_op_with_const_inputs from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes from openvino.tools.mo.middle.replacement import MiddleReplacementPattern from openvino.tools.mo.ops.transpose import Transpose @@ -34,14 +34,14 @@ class SSliceComplex(MiddleReplacementPattern): and StridedSlice nodes have output shapes [N_0, ..., N_{k - 1}, N_{k +1}, ..., N_{r - 1}]. But MO and Inference Engine do not support complex tensors. Hence, we need to replace this sub-graph with. - If k == r - 1, then the replacement should be the subgraph + 1. If k == r - 1, then the replacement should be the subgraph SomeOp other inputs | | ... | ------------------- SomeOp1 - In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph + 2. In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph SomeOp | @@ -54,9 +54,6 @@ class SSliceComplex(MiddleReplacementPattern): SomeOp1 where the input_order is a Constant, and the value of input_order is [0, ..., k - 1, k + 1, ..., r - 1, k]. - - After this transformation we need to mark SomeOp1 operation that its input rank has changed because - its inputs/attributes should probably be updated. Currently we have such a case for a Roll operation. """ enabled = True @@ -109,10 +106,16 @@ class SSliceComplex(MiddleReplacementPattern): for dst in complex_node.out_port(0).get_connection().get_destinations(): after_complex_node = dst.node - after_complex_node['input_rank_changed'] = True + # TODO: now it does not support adjustment of `axis` inputs for other operations such Gather, Concat, etc. + # It does not traverse the full path affected by complex numbers for adjusting the corresponding operations. + # It can affect other models with complex numbers for which we can generate incorrect IRs or offline transformation fails. + if after_complex_node.type == 'Roll': + add_constant_to_negative_values(after_complex_node, 2, int64_array(emulated_complex_tensor_rank)) input_slices_have_ellipsis = len(np.argwhere(real_slices == Ellipsis).flatten()) != 0 + # If output of SomeOp is sliced on the last dimension on the last dimension (like described in 1 case), skipping Complex op is enough. + # Otherwise, (like described in 2 case) Transpose insertion is needed to align data arrangement. if slice_dim_for_real_part == emulated_complex_tensor_rank - 1 or input_slices_have_ellipsis: complex_node.out_port(0).get_connection().set_source(strided_slice_real.in_port(0).get_source()) else: @@ -124,4 +127,4 @@ class SSliceComplex(MiddleReplacementPattern): {'name': complex_node_name + '/cmplx'}) complex_node.out_port(0).get_connection().set_source(transpose.out_port(0)) strided_slice_real.in_port(0).get_source().connect(transpose.in_port(0)) - rename_nodes([(complex_node, complex_node_name + '/to_be_removed'), (transpose, complex_node_name)]) \ No newline at end of file + rename_nodes([(complex_node, complex_node_name + '/to_be_removed'), (transpose, complex_node_name)]) diff --git a/tools/mo/unit_tests/mo/front/tf/CorrectRollAxes_test.py b/tools/mo/unit_tests/mo/front/tf/CorrectRollAxes_test.py deleted file mode 100644 index 65db305e87f..00000000000 --- a/tools/mo/unit_tests/mo/front/tf/CorrectRollAxes_test.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (C) 2018-2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -import unittest - -from openvino.tools.mo.front.tf.CorrectRollAxes import CorrectRollAxes -from openvino.tools.mo.front.common.partial_infer.utils import int64_array -from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs -from unit_tests.utils.graph import build_graph - - -graph_node_attrs = { - 'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll', 'input_rank_changed': True}, - 'roll_shift': { - 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50]) - }, - 'roll_axes': { - 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1]) - }, - 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'}, - 'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'}, -} - -graph_edges = [ - ('placeholder', 'roll', {'in': 0}), - ('roll', 'abs'), - ('abs', 'output'), - ('roll_shift', 'roll', {'in': 1}), - ('roll_axes', 'roll', {'in': 2}), -] - - -ref_graph_node_attrs = { - 'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll'}, - 'roll_shift': { - 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50]) - }, - 'roll_axes': { - 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1]) - }, - 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'}, - 'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'}, - 'add': {'type': 'Add', 'kind': 'op', 'op': 'Add'}, - 'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'}, - 'less': {'type': 'Less', 'kind': 'op', 'op': 'Less'}, - 'zero': { - 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(0) - }, - 'minus_one': { - 'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(-1) - }, -} - -ref_graph_edges = [ - ('placeholder', 'roll', {'out': 0, 'in': 0}), - ('roll', 'abs'), - ('abs', 'output'), - ('roll_shift', 'roll', {'in': 1}), - ('mul', 'add', {'in': 1}), - ('add', 'roll', {'in': 2}), - ('zero', 'less', {'in': 1}), - ('minus_one', 'mul', {'in': 1}), - ('less', 'mul', {'in': 0}), - ('roll_axes', 'less', {'out': 0, 'in': 0}), - ('roll_axes', 'add', {'out': 0, 'in': 0}), -] - - -class CorrectRollAxesTest(unittest.TestCase): - def test_replacement(self): - graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges) - graph.stage = 'front' - CorrectRollAxes().find_and_replace_pattern(graph) - ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs, edges=ref_graph_edges) - (flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True) - self.assertTrue(flag, resp) - - def test_nonreplacement(self): - graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges, - update_attributes={'roll': {'input_rank_changed': False}}) - graph.stage = 'front' - CorrectRollAxes().find_and_replace_pattern(graph) - ref_graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges, - update_attributes={'roll': {'input_rank_changed': False}}) - (flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True) - self.assertTrue(flag, resp) diff --git a/tools/mo/unit_tests/mo/middle/SSliceComplex_test.py b/tools/mo/unit_tests/mo/middle/SSliceComplex_test.py index 7704a418254..233d1f5b538 100644 --- a/tools/mo/unit_tests/mo/middle/SSliceComplex_test.py +++ b/tools/mo/unit_tests/mo/middle/SSliceComplex_test.py @@ -176,6 +176,94 @@ ref_graph_edges_2 = [ ] +graph_node_attrs_3 = { + **regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]), + {'type': 'Parameter', 'op': 'Parameter'}), + **regular_op_with_shaped_data('strided_slice_real', int64_array([3, 100, 66, 34]), + { + 'type': 'StridedSlice', 'op': 'StridedSlice', + 'begin_mask': int64_array([0, 0, 1, 0, 0]), + 'end_mask': int64_array([0, 0, 1, 0, 0]), + 'ellipsis_mask': int64_array([0, 0, 0, 0, 0]), + 'new_axis_mask': int64_array([0, 0, 0, 0, 0]), + 'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]), + 'slices': np.array([slice(None, None, 1), + slice(None, None, 1), + 0, + slice(None, None, 1), + slice(None, None, 1)]) + }), + **regular_op_with_shaped_data('strided_slice_imag', int64_array([3, 100, 66, 34]), + { + 'type': 'StridedSlice', 'op': 'StridedSlice', + 'begin_mask': int64_array([0, 0, 1, 0, 0]), + 'end_mask': int64_array([0, 0, 1, 0, 0]), + 'ellipsis_mask': int64_array([0, 0, 0, 0, 0]), + 'new_axis_mask': int64_array([0, 0, 0, 0, 0]), + 'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]), + 'slices': np.array([slice(None, None, 1), + slice(None, None, 1), + 1, + slice(None, None, 1), + slice(None, None, 1)]) + }), + **regular_op_with_shaped_data('complex', int64_array([3, 100, 66, 34, 2]), {'op': 'Complex'}), + **regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}), + **valued_const_with_data('real_begin', int64_array([0, 0, 0, 0, 0])), + **valued_const_with_data('imag_begin', int64_array([0, 0, 1, 0, 0])), + **valued_const_with_data('real_end', int64_array([0, 0, 1, 0, 0])), + **valued_const_with_data('imag_end', int64_array([0, 0, 2, 0, 0])), + **valued_const_with_data('real_strides', int64_array([1, 1, 1, 1, 1])), + **valued_const_with_data('imag_strides', int64_array([1, 1, 1, 1, 1])), + **regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}), + **valued_const_with_data('shift', int64_array([20, 20])), + **valued_const_with_data('axis', int64_array([1, -2, -1])), + **result('output'), +} + +graph_edges_2 = [ + ('placeholder', 'placeholder_d', {'out': 0}), + ('placeholder_d', 'strided_slice_real', {'out': 0, 'in': 0}), + ('placeholder_d', 'strided_slice_imag', {'out': 0, 'in': 0}), + *connect('strided_slice_real:0', '0:complex'), + *connect('strided_slice_imag:0', '1:complex'), + *connect('real_begin:0', '1:strided_slice_real'), + *connect('imag_begin:0', '1:strided_slice_imag'), + *connect('real_end:0', '2:strided_slice_real'), + *connect('imag_end:0', '2:strided_slice_imag'), + *connect('real_strides:0', '3:strided_slice_real'), + *connect('imag_strides:0', '3:strided_slice_imag'), + *connect('complex:0', '0:roll'), + *connect('shift:0', '1:roll'), + *connect('axis:0', '2:roll'), + *connect('roll:0', '0:abs'), + *connect('abs:0', 'output'), +] + +ref_graph_node_attrs_3 = { + **regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]), + {'type': 'Parameter', 'op': 'Parameter'}), + **valued_const_with_data('perm', int64_array([0, 1, 3, 4, 2])), + **regular_op_with_shaped_data('transpose', int64_array([3, 100, 66, 34, 2]), + {'type': 'Transpose', 'op': 'Transpose'}), + **regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}), + **valued_const_with_data('shift', int64_array([20, 20])), + **valued_const_with_data('axis', int64_array([1, 3, 4])), + **regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}), + **result('output'), +} + +ref_graph_edges_3 = [ + *connect('placeholder:0', '0:transpose'), + *connect('perm:0', '1:transpose'), + *connect('transpose:0', '0:roll'), + *connect('shift:0', '1:roll'), + *connect('axis:0', '2:roll'), + *connect('roll:0', '0:abs'), + *connect('abs:0', 'output'), +] + + class SSliceComplexMiddleStageTest(unittest.TestCase): def test_replacement_for_the_last_axis(self): graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges) @@ -199,3 +287,11 @@ class SSliceComplexMiddleStageTest(unittest.TestCase): ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_2, edges=ref_graph_edges_2) (flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True) self.assertTrue(flag, resp) + + def test_replacement_with_update_roll_axes(self): + graph = build_graph(nodes_attrs=graph_node_attrs_3, edges=graph_edges_2) + SSliceComplex().find_and_replace_pattern(graph) + graph.clean_up() + ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_3, edges=ref_graph_edges_3) + (flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True) + self.assertTrue(flag, resp)