[MO] Fix SSliceComplex transformation (#12537)

This commit is contained in:
Mateusz Bencer 2022-08-19 12:14:10 +02:00 committed by GitHub
parent 190d692c4d
commit d7ce8289ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 121 deletions

View File

@ -516,7 +516,6 @@ openvino/tools/mo/front/tf/concat_ext.py
openvino/tools/mo/front/tf/const_ext.py openvino/tools/mo/front/tf/const_ext.py
openvino/tools/mo/front/tf/conv_ext.py openvino/tools/mo/front/tf/conv_ext.py
openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py
openvino/tools/mo/front/tf/CorrectRollAxes.py
openvino/tools/mo/front/tf/crop_and_resize_ext.py openvino/tools/mo/front/tf/crop_and_resize_ext.py
openvino/tools/mo/front/tf/CropAndResizeReplacement.py openvino/tools/mo/front/tf/CropAndResizeReplacement.py
openvino/tools/mo/front/tf/CTCGreedyDecoder_ext.py openvino/tools/mo/front/tf/CTCGreedyDecoder_ext.py

View File

@ -1,23 +0,0 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values
from openvino.tools.mo.graph.graph import Graph
class CorrectRollAxes(FrontReplacementSubgraph):
"""
If the Roll node is a consumer of Complex node in the original TF model, then we have a real input tensor for Roll
instead of a complex. Negative axes values for the Roll operation should be updated to reflect the fact that the
rank of input tensor was increased by one (a new trailing dimension of size 2 containing real and imaginary part
of complex number is added).
"""
enabled = True
def find_and_replace_pattern(self, graph: Graph):
for roll in graph.get_op_nodes(op='Roll', input_rank_changed=True):
add_constant_to_negative_values(roll, 2, int64_array(-1))
del roll['input_rank_changed']

View File

@ -7,7 +7,7 @@ from typing import Dict
import numpy as np import numpy as np
from openvino.tools.mo.front.common.partial_infer.utils import int64_array from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values, create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
from openvino.tools.mo.ops.transpose import Transpose from openvino.tools.mo.ops.transpose import Transpose
@ -34,14 +34,14 @@ class SSliceComplex(MiddleReplacementPattern):
and StridedSlice nodes have output shapes [N_0, ..., N_{k - 1}, N_{k +1}, ..., N_{r - 1}]. and StridedSlice nodes have output shapes [N_0, ..., N_{k - 1}, N_{k +1}, ..., N_{r - 1}].
But MO and Inference Engine do not support complex tensors. Hence, we need to replace this sub-graph with. But MO and Inference Engine do not support complex tensors. Hence, we need to replace this sub-graph with.
If k == r - 1, then the replacement should be the subgraph 1. If k == r - 1, then the replacement should be the subgraph
SomeOp other inputs SomeOp other inputs
| | ... | | | ... |
------------------- -------------------
SomeOp1 SomeOp1
In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph 2. In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph
SomeOp SomeOp
| |
@ -54,9 +54,6 @@ class SSliceComplex(MiddleReplacementPattern):
SomeOp1 SomeOp1
where the input_order is a Constant, and the value of input_order is [0, ..., k - 1, k + 1, ..., r - 1, k]. where the input_order is a Constant, and the value of input_order is [0, ..., k - 1, k + 1, ..., r - 1, k].
After this transformation we need to mark SomeOp1 operation that its input rank has changed because
its inputs/attributes should probably be updated. Currently we have such a case for a Roll operation.
""" """
enabled = True enabled = True
@ -109,10 +106,16 @@ class SSliceComplex(MiddleReplacementPattern):
for dst in complex_node.out_port(0).get_connection().get_destinations(): for dst in complex_node.out_port(0).get_connection().get_destinations():
after_complex_node = dst.node after_complex_node = dst.node
after_complex_node['input_rank_changed'] = True # TODO: now it does not support adjustment of `axis` inputs for other operations such Gather, Concat, etc.
# It does not traverse the full path affected by complex numbers for adjusting the corresponding operations.
# It can affect other models with complex numbers for which we can generate incorrect IRs or offline transformation fails.
if after_complex_node.type == 'Roll':
add_constant_to_negative_values(after_complex_node, 2, int64_array(emulated_complex_tensor_rank))
input_slices_have_ellipsis = len(np.argwhere(real_slices == Ellipsis).flatten()) != 0 input_slices_have_ellipsis = len(np.argwhere(real_slices == Ellipsis).flatten()) != 0
# If output of SomeOp is sliced on the last dimension on the last dimension (like described in 1 case), skipping Complex op is enough.
# Otherwise, (like described in 2 case) Transpose insertion is needed to align data arrangement.
if slice_dim_for_real_part == emulated_complex_tensor_rank - 1 or input_slices_have_ellipsis: if slice_dim_for_real_part == emulated_complex_tensor_rank - 1 or input_slices_have_ellipsis:
complex_node.out_port(0).get_connection().set_source(strided_slice_real.in_port(0).get_source()) complex_node.out_port(0).get_connection().set_source(strided_slice_real.in_port(0).get_source())
else: else:

View File

@ -1,89 +0,0 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from openvino.tools.mo.front.tf.CorrectRollAxes import CorrectRollAxes
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph
graph_node_attrs = {
'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll', 'input_rank_changed': True},
'roll_shift': {
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50])
},
'roll_axes': {
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1])
},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
}
graph_edges = [
('placeholder', 'roll', {'in': 0}),
('roll', 'abs'),
('abs', 'output'),
('roll_shift', 'roll', {'in': 1}),
('roll_axes', 'roll', {'in': 2}),
]
ref_graph_node_attrs = {
'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll'},
'roll_shift': {
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50])
},
'roll_axes': {
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1])
},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
'add': {'type': 'Add', 'kind': 'op', 'op': 'Add'},
'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'},
'less': {'type': 'Less', 'kind': 'op', 'op': 'Less'},
'zero': {
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(0)
},
'minus_one': {
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(-1)
},
}
ref_graph_edges = [
('placeholder', 'roll', {'out': 0, 'in': 0}),
('roll', 'abs'),
('abs', 'output'),
('roll_shift', 'roll', {'in': 1}),
('mul', 'add', {'in': 1}),
('add', 'roll', {'in': 2}),
('zero', 'less', {'in': 1}),
('minus_one', 'mul', {'in': 1}),
('less', 'mul', {'in': 0}),
('roll_axes', 'less', {'out': 0, 'in': 0}),
('roll_axes', 'add', {'out': 0, 'in': 0}),
]
class CorrectRollAxesTest(unittest.TestCase):
def test_replacement(self):
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
graph.stage = 'front'
CorrectRollAxes().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs, edges=ref_graph_edges)
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_nonreplacement(self):
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges,
update_attributes={'roll': {'input_rank_changed': False}})
graph.stage = 'front'
CorrectRollAxes().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges,
update_attributes={'roll': {'input_rank_changed': False}})
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -176,6 +176,94 @@ ref_graph_edges_2 = [
] ]
graph_node_attrs_3 = {
**regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]),
{'type': 'Parameter', 'op': 'Parameter'}),
**regular_op_with_shaped_data('strided_slice_real', int64_array([3, 100, 66, 34]),
{
'type': 'StridedSlice', 'op': 'StridedSlice',
'begin_mask': int64_array([0, 0, 1, 0, 0]),
'end_mask': int64_array([0, 0, 1, 0, 0]),
'ellipsis_mask': int64_array([0, 0, 0, 0, 0]),
'new_axis_mask': int64_array([0, 0, 0, 0, 0]),
'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]),
'slices': np.array([slice(None, None, 1),
slice(None, None, 1),
0,
slice(None, None, 1),
slice(None, None, 1)])
}),
**regular_op_with_shaped_data('strided_slice_imag', int64_array([3, 100, 66, 34]),
{
'type': 'StridedSlice', 'op': 'StridedSlice',
'begin_mask': int64_array([0, 0, 1, 0, 0]),
'end_mask': int64_array([0, 0, 1, 0, 0]),
'ellipsis_mask': int64_array([0, 0, 0, 0, 0]),
'new_axis_mask': int64_array([0, 0, 0, 0, 0]),
'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]),
'slices': np.array([slice(None, None, 1),
slice(None, None, 1),
1,
slice(None, None, 1),
slice(None, None, 1)])
}),
**regular_op_with_shaped_data('complex', int64_array([3, 100, 66, 34, 2]), {'op': 'Complex'}),
**regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}),
**valued_const_with_data('real_begin', int64_array([0, 0, 0, 0, 0])),
**valued_const_with_data('imag_begin', int64_array([0, 0, 1, 0, 0])),
**valued_const_with_data('real_end', int64_array([0, 0, 1, 0, 0])),
**valued_const_with_data('imag_end', int64_array([0, 0, 2, 0, 0])),
**valued_const_with_data('real_strides', int64_array([1, 1, 1, 1, 1])),
**valued_const_with_data('imag_strides', int64_array([1, 1, 1, 1, 1])),
**regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}),
**valued_const_with_data('shift', int64_array([20, 20])),
**valued_const_with_data('axis', int64_array([1, -2, -1])),
**result('output'),
}
graph_edges_2 = [
('placeholder', 'placeholder_d', {'out': 0}),
('placeholder_d', 'strided_slice_real', {'out': 0, 'in': 0}),
('placeholder_d', 'strided_slice_imag', {'out': 0, 'in': 0}),
*connect('strided_slice_real:0', '0:complex'),
*connect('strided_slice_imag:0', '1:complex'),
*connect('real_begin:0', '1:strided_slice_real'),
*connect('imag_begin:0', '1:strided_slice_imag'),
*connect('real_end:0', '2:strided_slice_real'),
*connect('imag_end:0', '2:strided_slice_imag'),
*connect('real_strides:0', '3:strided_slice_real'),
*connect('imag_strides:0', '3:strided_slice_imag'),
*connect('complex:0', '0:roll'),
*connect('shift:0', '1:roll'),
*connect('axis:0', '2:roll'),
*connect('roll:0', '0:abs'),
*connect('abs:0', 'output'),
]
ref_graph_node_attrs_3 = {
**regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]),
{'type': 'Parameter', 'op': 'Parameter'}),
**valued_const_with_data('perm', int64_array([0, 1, 3, 4, 2])),
**regular_op_with_shaped_data('transpose', int64_array([3, 100, 66, 34, 2]),
{'type': 'Transpose', 'op': 'Transpose'}),
**regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}),
**valued_const_with_data('shift', int64_array([20, 20])),
**valued_const_with_data('axis', int64_array([1, 3, 4])),
**regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}),
**result('output'),
}
ref_graph_edges_3 = [
*connect('placeholder:0', '0:transpose'),
*connect('perm:0', '1:transpose'),
*connect('transpose:0', '0:roll'),
*connect('shift:0', '1:roll'),
*connect('axis:0', '2:roll'),
*connect('roll:0', '0:abs'),
*connect('abs:0', 'output'),
]
class SSliceComplexMiddleStageTest(unittest.TestCase): class SSliceComplexMiddleStageTest(unittest.TestCase):
def test_replacement_for_the_last_axis(self): def test_replacement_for_the_last_axis(self):
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges) graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
@ -199,3 +287,11 @@ class SSliceComplexMiddleStageTest(unittest.TestCase):
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_2, edges=ref_graph_edges_2) ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_2, edges=ref_graph_edges_2)
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True) (flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
self.assertTrue(flag, resp) self.assertTrue(flag, resp)
def test_replacement_with_update_roll_axes(self):
graph = build_graph(nodes_attrs=graph_node_attrs_3, edges=graph_edges_2)
SSliceComplex().find_and_replace_pattern(graph)
graph.clean_up()
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_3, edges=ref_graph_edges_3)
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)