[MO] Fix SSliceComplex transformation (#12537)
This commit is contained in:
parent
190d692c4d
commit
d7ce8289ac
@ -516,7 +516,6 @@ openvino/tools/mo/front/tf/concat_ext.py
|
||||
openvino/tools/mo/front/tf/const_ext.py
|
||||
openvino/tools/mo/front/tf/conv_ext.py
|
||||
openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py
|
||||
openvino/tools/mo/front/tf/CorrectRollAxes.py
|
||||
openvino/tools/mo/front/tf/crop_and_resize_ext.py
|
||||
openvino/tools/mo/front/tf/CropAndResizeReplacement.py
|
||||
openvino/tools/mo/front/tf/CTCGreedyDecoder_ext.py
|
||||
|
@ -1,23 +0,0 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values
|
||||
from openvino.tools.mo.graph.graph import Graph
|
||||
|
||||
|
||||
class CorrectRollAxes(FrontReplacementSubgraph):
|
||||
"""
|
||||
If the Roll node is a consumer of Complex node in the original TF model, then we have a real input tensor for Roll
|
||||
instead of a complex. Negative axes values for the Roll operation should be updated to reflect the fact that the
|
||||
rank of input tensor was increased by one (a new trailing dimension of size 2 containing real and imaginary part
|
||||
of complex number is added).
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for roll in graph.get_op_nodes(op='Roll', input_rank_changed=True):
|
||||
add_constant_to_negative_values(roll, 2, int64_array(-1))
|
||||
del roll['input_rank_changed']
|
@ -7,7 +7,7 @@ from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values, create_op_with_const_inputs
|
||||
from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes
|
||||
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
||||
from openvino.tools.mo.ops.transpose import Transpose
|
||||
@ -34,14 +34,14 @@ class SSliceComplex(MiddleReplacementPattern):
|
||||
and StridedSlice nodes have output shapes [N_0, ..., N_{k - 1}, N_{k +1}, ..., N_{r - 1}].
|
||||
|
||||
But MO and Inference Engine do not support complex tensors. Hence, we need to replace this sub-graph with.
|
||||
If k == r - 1, then the replacement should be the subgraph
|
||||
1. If k == r - 1, then the replacement should be the subgraph
|
||||
|
||||
SomeOp other inputs
|
||||
| | ... |
|
||||
-------------------
|
||||
SomeOp1
|
||||
|
||||
In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph
|
||||
2. In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph
|
||||
|
||||
SomeOp
|
||||
|
|
||||
@ -54,9 +54,6 @@ class SSliceComplex(MiddleReplacementPattern):
|
||||
SomeOp1
|
||||
|
||||
where the input_order is a Constant, and the value of input_order is [0, ..., k - 1, k + 1, ..., r - 1, k].
|
||||
|
||||
After this transformation we need to mark SomeOp1 operation that its input rank has changed because
|
||||
its inputs/attributes should probably be updated. Currently we have such a case for a Roll operation.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
@ -109,10 +106,16 @@ class SSliceComplex(MiddleReplacementPattern):
|
||||
|
||||
for dst in complex_node.out_port(0).get_connection().get_destinations():
|
||||
after_complex_node = dst.node
|
||||
after_complex_node['input_rank_changed'] = True
|
||||
# TODO: now it does not support adjustment of `axis` inputs for other operations such Gather, Concat, etc.
|
||||
# It does not traverse the full path affected by complex numbers for adjusting the corresponding operations.
|
||||
# It can affect other models with complex numbers for which we can generate incorrect IRs or offline transformation fails.
|
||||
if after_complex_node.type == 'Roll':
|
||||
add_constant_to_negative_values(after_complex_node, 2, int64_array(emulated_complex_tensor_rank))
|
||||
|
||||
input_slices_have_ellipsis = len(np.argwhere(real_slices == Ellipsis).flatten()) != 0
|
||||
|
||||
# If output of SomeOp is sliced on the last dimension on the last dimension (like described in 1 case), skipping Complex op is enough.
|
||||
# Otherwise, (like described in 2 case) Transpose insertion is needed to align data arrangement.
|
||||
if slice_dim_for_real_part == emulated_complex_tensor_rank - 1 or input_slices_have_ellipsis:
|
||||
complex_node.out_port(0).get_connection().set_source(strided_slice_real.in_port(0).get_source())
|
||||
else:
|
||||
@ -124,4 +127,4 @@ class SSliceComplex(MiddleReplacementPattern):
|
||||
{'name': complex_node_name + '/cmplx'})
|
||||
complex_node.out_port(0).get_connection().set_source(transpose.out_port(0))
|
||||
strided_slice_real.in_port(0).get_source().connect(transpose.in_port(0))
|
||||
rename_nodes([(complex_node, complex_node_name + '/to_be_removed'), (transpose, complex_node_name)])
|
||||
rename_nodes([(complex_node, complex_node_name + '/to_be_removed'), (transpose, complex_node_name)])
|
||||
|
@ -1,89 +0,0 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from openvino.tools.mo.front.tf.CorrectRollAxes import CorrectRollAxes
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
|
||||
graph_node_attrs = {
|
||||
'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll', 'input_rank_changed': True},
|
||||
'roll_shift': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50])
|
||||
},
|
||||
'roll_axes': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1])
|
||||
},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
|
||||
}
|
||||
|
||||
graph_edges = [
|
||||
('placeholder', 'roll', {'in': 0}),
|
||||
('roll', 'abs'),
|
||||
('abs', 'output'),
|
||||
('roll_shift', 'roll', {'in': 1}),
|
||||
('roll_axes', 'roll', {'in': 2}),
|
||||
]
|
||||
|
||||
|
||||
ref_graph_node_attrs = {
|
||||
'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll'},
|
||||
'roll_shift': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50])
|
||||
},
|
||||
'roll_axes': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1])
|
||||
},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
|
||||
'add': {'type': 'Add', 'kind': 'op', 'op': 'Add'},
|
||||
'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'},
|
||||
'less': {'type': 'Less', 'kind': 'op', 'op': 'Less'},
|
||||
'zero': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(0)
|
||||
},
|
||||
'minus_one': {
|
||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(-1)
|
||||
},
|
||||
}
|
||||
|
||||
ref_graph_edges = [
|
||||
('placeholder', 'roll', {'out': 0, 'in': 0}),
|
||||
('roll', 'abs'),
|
||||
('abs', 'output'),
|
||||
('roll_shift', 'roll', {'in': 1}),
|
||||
('mul', 'add', {'in': 1}),
|
||||
('add', 'roll', {'in': 2}),
|
||||
('zero', 'less', {'in': 1}),
|
||||
('minus_one', 'mul', {'in': 1}),
|
||||
('less', 'mul', {'in': 0}),
|
||||
('roll_axes', 'less', {'out': 0, 'in': 0}),
|
||||
('roll_axes', 'add', {'out': 0, 'in': 0}),
|
||||
]
|
||||
|
||||
|
||||
class CorrectRollAxesTest(unittest.TestCase):
|
||||
def test_replacement(self):
|
||||
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
|
||||
graph.stage = 'front'
|
||||
CorrectRollAxes().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs, edges=ref_graph_edges)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_nonreplacement(self):
|
||||
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges,
|
||||
update_attributes={'roll': {'input_rank_changed': False}})
|
||||
graph.stage = 'front'
|
||||
CorrectRollAxes().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges,
|
||||
update_attributes={'roll': {'input_rank_changed': False}})
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -176,6 +176,94 @@ ref_graph_edges_2 = [
|
||||
]
|
||||
|
||||
|
||||
graph_node_attrs_3 = {
|
||||
**regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]),
|
||||
{'type': 'Parameter', 'op': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('strided_slice_real', int64_array([3, 100, 66, 34]),
|
||||
{
|
||||
'type': 'StridedSlice', 'op': 'StridedSlice',
|
||||
'begin_mask': int64_array([0, 0, 1, 0, 0]),
|
||||
'end_mask': int64_array([0, 0, 1, 0, 0]),
|
||||
'ellipsis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||
'new_axis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||
'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]),
|
||||
'slices': np.array([slice(None, None, 1),
|
||||
slice(None, None, 1),
|
||||
0,
|
||||
slice(None, None, 1),
|
||||
slice(None, None, 1)])
|
||||
}),
|
||||
**regular_op_with_shaped_data('strided_slice_imag', int64_array([3, 100, 66, 34]),
|
||||
{
|
||||
'type': 'StridedSlice', 'op': 'StridedSlice',
|
||||
'begin_mask': int64_array([0, 0, 1, 0, 0]),
|
||||
'end_mask': int64_array([0, 0, 1, 0, 0]),
|
||||
'ellipsis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||
'new_axis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||
'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]),
|
||||
'slices': np.array([slice(None, None, 1),
|
||||
slice(None, None, 1),
|
||||
1,
|
||||
slice(None, None, 1),
|
||||
slice(None, None, 1)])
|
||||
}),
|
||||
**regular_op_with_shaped_data('complex', int64_array([3, 100, 66, 34, 2]), {'op': 'Complex'}),
|
||||
**regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}),
|
||||
**valued_const_with_data('real_begin', int64_array([0, 0, 0, 0, 0])),
|
||||
**valued_const_with_data('imag_begin', int64_array([0, 0, 1, 0, 0])),
|
||||
**valued_const_with_data('real_end', int64_array([0, 0, 1, 0, 0])),
|
||||
**valued_const_with_data('imag_end', int64_array([0, 0, 2, 0, 0])),
|
||||
**valued_const_with_data('real_strides', int64_array([1, 1, 1, 1, 1])),
|
||||
**valued_const_with_data('imag_strides', int64_array([1, 1, 1, 1, 1])),
|
||||
**regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}),
|
||||
**valued_const_with_data('shift', int64_array([20, 20])),
|
||||
**valued_const_with_data('axis', int64_array([1, -2, -1])),
|
||||
**result('output'),
|
||||
}
|
||||
|
||||
graph_edges_2 = [
|
||||
('placeholder', 'placeholder_d', {'out': 0}),
|
||||
('placeholder_d', 'strided_slice_real', {'out': 0, 'in': 0}),
|
||||
('placeholder_d', 'strided_slice_imag', {'out': 0, 'in': 0}),
|
||||
*connect('strided_slice_real:0', '0:complex'),
|
||||
*connect('strided_slice_imag:0', '1:complex'),
|
||||
*connect('real_begin:0', '1:strided_slice_real'),
|
||||
*connect('imag_begin:0', '1:strided_slice_imag'),
|
||||
*connect('real_end:0', '2:strided_slice_real'),
|
||||
*connect('imag_end:0', '2:strided_slice_imag'),
|
||||
*connect('real_strides:0', '3:strided_slice_real'),
|
||||
*connect('imag_strides:0', '3:strided_slice_imag'),
|
||||
*connect('complex:0', '0:roll'),
|
||||
*connect('shift:0', '1:roll'),
|
||||
*connect('axis:0', '2:roll'),
|
||||
*connect('roll:0', '0:abs'),
|
||||
*connect('abs:0', 'output'),
|
||||
]
|
||||
|
||||
ref_graph_node_attrs_3 = {
|
||||
**regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]),
|
||||
{'type': 'Parameter', 'op': 'Parameter'}),
|
||||
**valued_const_with_data('perm', int64_array([0, 1, 3, 4, 2])),
|
||||
**regular_op_with_shaped_data('transpose', int64_array([3, 100, 66, 34, 2]),
|
||||
{'type': 'Transpose', 'op': 'Transpose'}),
|
||||
**regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}),
|
||||
**valued_const_with_data('shift', int64_array([20, 20])),
|
||||
**valued_const_with_data('axis', int64_array([1, 3, 4])),
|
||||
**regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}),
|
||||
**result('output'),
|
||||
}
|
||||
|
||||
ref_graph_edges_3 = [
|
||||
*connect('placeholder:0', '0:transpose'),
|
||||
*connect('perm:0', '1:transpose'),
|
||||
*connect('transpose:0', '0:roll'),
|
||||
*connect('shift:0', '1:roll'),
|
||||
*connect('axis:0', '2:roll'),
|
||||
*connect('roll:0', '0:abs'),
|
||||
*connect('abs:0', 'output'),
|
||||
]
|
||||
|
||||
|
||||
class SSliceComplexMiddleStageTest(unittest.TestCase):
|
||||
def test_replacement_for_the_last_axis(self):
|
||||
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
|
||||
@ -199,3 +287,11 @@ class SSliceComplexMiddleStageTest(unittest.TestCase):
|
||||
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_2, edges=ref_graph_edges_2)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_replacement_with_update_roll_axes(self):
|
||||
graph = build_graph(nodes_attrs=graph_node_attrs_3, edges=graph_edges_2)
|
||||
SSliceComplex().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_3, edges=ref_graph_edges_3)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
Loading…
Reference in New Issue
Block a user