[MO] Fix SSliceComplex transformation (#12537)
This commit is contained in:
parent
190d692c4d
commit
d7ce8289ac
@ -516,7 +516,6 @@ openvino/tools/mo/front/tf/concat_ext.py
|
|||||||
openvino/tools/mo/front/tf/const_ext.py
|
openvino/tools/mo/front/tf/const_ext.py
|
||||||
openvino/tools/mo/front/tf/conv_ext.py
|
openvino/tools/mo/front/tf/conv_ext.py
|
||||||
openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py
|
openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py
|
||||||
openvino/tools/mo/front/tf/CorrectRollAxes.py
|
|
||||||
openvino/tools/mo/front/tf/crop_and_resize_ext.py
|
openvino/tools/mo/front/tf/crop_and_resize_ext.py
|
||||||
openvino/tools/mo/front/tf/CropAndResizeReplacement.py
|
openvino/tools/mo/front/tf/CropAndResizeReplacement.py
|
||||||
openvino/tools/mo/front/tf/CTCGreedyDecoder_ext.py
|
openvino/tools/mo/front/tf/CTCGreedyDecoder_ext.py
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
# Copyright (C) 2018-2022 Intel Corporation
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
|
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
|
||||||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
|
|
||||||
from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values
|
|
||||||
from openvino.tools.mo.graph.graph import Graph
|
|
||||||
|
|
||||||
|
|
||||||
class CorrectRollAxes(FrontReplacementSubgraph):
|
|
||||||
"""
|
|
||||||
If the Roll node is a consumer of Complex node in the original TF model, then we have a real input tensor for Roll
|
|
||||||
instead of a complex. Negative axes values for the Roll operation should be updated to reflect the fact that the
|
|
||||||
rank of input tensor was increased by one (a new trailing dimension of size 2 containing real and imaginary part
|
|
||||||
of complex number is added).
|
|
||||||
"""
|
|
||||||
enabled = True
|
|
||||||
|
|
||||||
def find_and_replace_pattern(self, graph: Graph):
|
|
||||||
for roll in graph.get_op_nodes(op='Roll', input_rank_changed=True):
|
|
||||||
add_constant_to_negative_values(roll, 2, int64_array(-1))
|
|
||||||
del roll['input_rank_changed']
|
|
@ -7,7 +7,7 @@ from typing import Dict
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
from openvino.tools.mo.front.tf.graph_utils import add_constant_to_negative_values, create_op_with_const_inputs
|
||||||
from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes
|
from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes
|
||||||
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
||||||
from openvino.tools.mo.ops.transpose import Transpose
|
from openvino.tools.mo.ops.transpose import Transpose
|
||||||
@ -34,14 +34,14 @@ class SSliceComplex(MiddleReplacementPattern):
|
|||||||
and StridedSlice nodes have output shapes [N_0, ..., N_{k - 1}, N_{k +1}, ..., N_{r - 1}].
|
and StridedSlice nodes have output shapes [N_0, ..., N_{k - 1}, N_{k +1}, ..., N_{r - 1}].
|
||||||
|
|
||||||
But MO and Inference Engine do not support complex tensors. Hence, we need to replace this sub-graph with.
|
But MO and Inference Engine do not support complex tensors. Hence, we need to replace this sub-graph with.
|
||||||
If k == r - 1, then the replacement should be the subgraph
|
1. If k == r - 1, then the replacement should be the subgraph
|
||||||
|
|
||||||
SomeOp other inputs
|
SomeOp other inputs
|
||||||
| | ... |
|
| | ... |
|
||||||
-------------------
|
-------------------
|
||||||
SomeOp1
|
SomeOp1
|
||||||
|
|
||||||
In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph
|
2. In the other case, that is if 0 <= k and k < r - 1 the replacement should be the subgraph
|
||||||
|
|
||||||
SomeOp
|
SomeOp
|
||||||
|
|
|
|
||||||
@ -54,9 +54,6 @@ class SSliceComplex(MiddleReplacementPattern):
|
|||||||
SomeOp1
|
SomeOp1
|
||||||
|
|
||||||
where the input_order is a Constant, and the value of input_order is [0, ..., k - 1, k + 1, ..., r - 1, k].
|
where the input_order is a Constant, and the value of input_order is [0, ..., k - 1, k + 1, ..., r - 1, k].
|
||||||
|
|
||||||
After this transformation we need to mark SomeOp1 operation that its input rank has changed because
|
|
||||||
its inputs/attributes should probably be updated. Currently we have such a case for a Roll operation.
|
|
||||||
"""
|
"""
|
||||||
enabled = True
|
enabled = True
|
||||||
|
|
||||||
@ -109,10 +106,16 @@ class SSliceComplex(MiddleReplacementPattern):
|
|||||||
|
|
||||||
for dst in complex_node.out_port(0).get_connection().get_destinations():
|
for dst in complex_node.out_port(0).get_connection().get_destinations():
|
||||||
after_complex_node = dst.node
|
after_complex_node = dst.node
|
||||||
after_complex_node['input_rank_changed'] = True
|
# TODO: now it does not support adjustment of `axis` inputs for other operations such Gather, Concat, etc.
|
||||||
|
# It does not traverse the full path affected by complex numbers for adjusting the corresponding operations.
|
||||||
|
# It can affect other models with complex numbers for which we can generate incorrect IRs or offline transformation fails.
|
||||||
|
if after_complex_node.type == 'Roll':
|
||||||
|
add_constant_to_negative_values(after_complex_node, 2, int64_array(emulated_complex_tensor_rank))
|
||||||
|
|
||||||
input_slices_have_ellipsis = len(np.argwhere(real_slices == Ellipsis).flatten()) != 0
|
input_slices_have_ellipsis = len(np.argwhere(real_slices == Ellipsis).flatten()) != 0
|
||||||
|
|
||||||
|
# If output of SomeOp is sliced on the last dimension on the last dimension (like described in 1 case), skipping Complex op is enough.
|
||||||
|
# Otherwise, (like described in 2 case) Transpose insertion is needed to align data arrangement.
|
||||||
if slice_dim_for_real_part == emulated_complex_tensor_rank - 1 or input_slices_have_ellipsis:
|
if slice_dim_for_real_part == emulated_complex_tensor_rank - 1 or input_slices_have_ellipsis:
|
||||||
complex_node.out_port(0).get_connection().set_source(strided_slice_real.in_port(0).get_source())
|
complex_node.out_port(0).get_connection().set_source(strided_slice_real.in_port(0).get_source())
|
||||||
else:
|
else:
|
||||||
@ -124,4 +127,4 @@ class SSliceComplex(MiddleReplacementPattern):
|
|||||||
{'name': complex_node_name + '/cmplx'})
|
{'name': complex_node_name + '/cmplx'})
|
||||||
complex_node.out_port(0).get_connection().set_source(transpose.out_port(0))
|
complex_node.out_port(0).get_connection().set_source(transpose.out_port(0))
|
||||||
strided_slice_real.in_port(0).get_source().connect(transpose.in_port(0))
|
strided_slice_real.in_port(0).get_source().connect(transpose.in_port(0))
|
||||||
rename_nodes([(complex_node, complex_node_name + '/to_be_removed'), (transpose, complex_node_name)])
|
rename_nodes([(complex_node, complex_node_name + '/to_be_removed'), (transpose, complex_node_name)])
|
||||||
|
@ -1,89 +0,0 @@
|
|||||||
# Copyright (C) 2018-2022 Intel Corporation
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from openvino.tools.mo.front.tf.CorrectRollAxes import CorrectRollAxes
|
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
|
||||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
|
||||||
from unit_tests.utils.graph import build_graph
|
|
||||||
|
|
||||||
|
|
||||||
graph_node_attrs = {
|
|
||||||
'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
|
||||||
'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll', 'input_rank_changed': True},
|
|
||||||
'roll_shift': {
|
|
||||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50])
|
|
||||||
},
|
|
||||||
'roll_axes': {
|
|
||||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1])
|
|
||||||
},
|
|
||||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
|
||||||
'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
|
|
||||||
}
|
|
||||||
|
|
||||||
graph_edges = [
|
|
||||||
('placeholder', 'roll', {'in': 0}),
|
|
||||||
('roll', 'abs'),
|
|
||||||
('abs', 'output'),
|
|
||||||
('roll_shift', 'roll', {'in': 1}),
|
|
||||||
('roll_axes', 'roll', {'in': 2}),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
ref_graph_node_attrs = {
|
|
||||||
'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
|
||||||
'roll': {'kind': 'op', 'op': 'Roll', 'type': 'Roll'},
|
|
||||||
'roll_shift': {
|
|
||||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([50, 50])
|
|
||||||
},
|
|
||||||
'roll_axes': {
|
|
||||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([-2, -1])
|
|
||||||
},
|
|
||||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
|
||||||
'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
|
|
||||||
'add': {'type': 'Add', 'kind': 'op', 'op': 'Add'},
|
|
||||||
'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'},
|
|
||||||
'less': {'type': 'Less', 'kind': 'op', 'op': 'Less'},
|
|
||||||
'zero': {
|
|
||||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(0)
|
|
||||||
},
|
|
||||||
'minus_one': {
|
|
||||||
'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([]), 'value': int64_array(-1)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ref_graph_edges = [
|
|
||||||
('placeholder', 'roll', {'out': 0, 'in': 0}),
|
|
||||||
('roll', 'abs'),
|
|
||||||
('abs', 'output'),
|
|
||||||
('roll_shift', 'roll', {'in': 1}),
|
|
||||||
('mul', 'add', {'in': 1}),
|
|
||||||
('add', 'roll', {'in': 2}),
|
|
||||||
('zero', 'less', {'in': 1}),
|
|
||||||
('minus_one', 'mul', {'in': 1}),
|
|
||||||
('less', 'mul', {'in': 0}),
|
|
||||||
('roll_axes', 'less', {'out': 0, 'in': 0}),
|
|
||||||
('roll_axes', 'add', {'out': 0, 'in': 0}),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class CorrectRollAxesTest(unittest.TestCase):
|
|
||||||
def test_replacement(self):
|
|
||||||
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
|
|
||||||
graph.stage = 'front'
|
|
||||||
CorrectRollAxes().find_and_replace_pattern(graph)
|
|
||||||
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs, edges=ref_graph_edges)
|
|
||||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
|
||||||
self.assertTrue(flag, resp)
|
|
||||||
|
|
||||||
def test_nonreplacement(self):
|
|
||||||
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges,
|
|
||||||
update_attributes={'roll': {'input_rank_changed': False}})
|
|
||||||
graph.stage = 'front'
|
|
||||||
CorrectRollAxes().find_and_replace_pattern(graph)
|
|
||||||
ref_graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges,
|
|
||||||
update_attributes={'roll': {'input_rank_changed': False}})
|
|
||||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
|
||||||
self.assertTrue(flag, resp)
|
|
@ -176,6 +176,94 @@ ref_graph_edges_2 = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
graph_node_attrs_3 = {
|
||||||
|
**regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]),
|
||||||
|
{'type': 'Parameter', 'op': 'Parameter'}),
|
||||||
|
**regular_op_with_shaped_data('strided_slice_real', int64_array([3, 100, 66, 34]),
|
||||||
|
{
|
||||||
|
'type': 'StridedSlice', 'op': 'StridedSlice',
|
||||||
|
'begin_mask': int64_array([0, 0, 1, 0, 0]),
|
||||||
|
'end_mask': int64_array([0, 0, 1, 0, 0]),
|
||||||
|
'ellipsis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||||
|
'new_axis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||||
|
'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]),
|
||||||
|
'slices': np.array([slice(None, None, 1),
|
||||||
|
slice(None, None, 1),
|
||||||
|
0,
|
||||||
|
slice(None, None, 1),
|
||||||
|
slice(None, None, 1)])
|
||||||
|
}),
|
||||||
|
**regular_op_with_shaped_data('strided_slice_imag', int64_array([3, 100, 66, 34]),
|
||||||
|
{
|
||||||
|
'type': 'StridedSlice', 'op': 'StridedSlice',
|
||||||
|
'begin_mask': int64_array([0, 0, 1, 0, 0]),
|
||||||
|
'end_mask': int64_array([0, 0, 1, 0, 0]),
|
||||||
|
'ellipsis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||||
|
'new_axis_mask': int64_array([0, 0, 0, 0, 0]),
|
||||||
|
'shrink_axis_mask': int64_array([0, 0, 1, 0, 0]),
|
||||||
|
'slices': np.array([slice(None, None, 1),
|
||||||
|
slice(None, None, 1),
|
||||||
|
1,
|
||||||
|
slice(None, None, 1),
|
||||||
|
slice(None, None, 1)])
|
||||||
|
}),
|
||||||
|
**regular_op_with_shaped_data('complex', int64_array([3, 100, 66, 34, 2]), {'op': 'Complex'}),
|
||||||
|
**regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}),
|
||||||
|
**valued_const_with_data('real_begin', int64_array([0, 0, 0, 0, 0])),
|
||||||
|
**valued_const_with_data('imag_begin', int64_array([0, 0, 1, 0, 0])),
|
||||||
|
**valued_const_with_data('real_end', int64_array([0, 0, 1, 0, 0])),
|
||||||
|
**valued_const_with_data('imag_end', int64_array([0, 0, 2, 0, 0])),
|
||||||
|
**valued_const_with_data('real_strides', int64_array([1, 1, 1, 1, 1])),
|
||||||
|
**valued_const_with_data('imag_strides', int64_array([1, 1, 1, 1, 1])),
|
||||||
|
**regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}),
|
||||||
|
**valued_const_with_data('shift', int64_array([20, 20])),
|
||||||
|
**valued_const_with_data('axis', int64_array([1, -2, -1])),
|
||||||
|
**result('output'),
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_edges_2 = [
|
||||||
|
('placeholder', 'placeholder_d', {'out': 0}),
|
||||||
|
('placeholder_d', 'strided_slice_real', {'out': 0, 'in': 0}),
|
||||||
|
('placeholder_d', 'strided_slice_imag', {'out': 0, 'in': 0}),
|
||||||
|
*connect('strided_slice_real:0', '0:complex'),
|
||||||
|
*connect('strided_slice_imag:0', '1:complex'),
|
||||||
|
*connect('real_begin:0', '1:strided_slice_real'),
|
||||||
|
*connect('imag_begin:0', '1:strided_slice_imag'),
|
||||||
|
*connect('real_end:0', '2:strided_slice_real'),
|
||||||
|
*connect('imag_end:0', '2:strided_slice_imag'),
|
||||||
|
*connect('real_strides:0', '3:strided_slice_real'),
|
||||||
|
*connect('imag_strides:0', '3:strided_slice_imag'),
|
||||||
|
*connect('complex:0', '0:roll'),
|
||||||
|
*connect('shift:0', '1:roll'),
|
||||||
|
*connect('axis:0', '2:roll'),
|
||||||
|
*connect('roll:0', '0:abs'),
|
||||||
|
*connect('abs:0', 'output'),
|
||||||
|
]
|
||||||
|
|
||||||
|
ref_graph_node_attrs_3 = {
|
||||||
|
**regular_op_with_shaped_data('placeholder', int64_array([3, 100, 2, 66, 34]),
|
||||||
|
{'type': 'Parameter', 'op': 'Parameter'}),
|
||||||
|
**valued_const_with_data('perm', int64_array([0, 1, 3, 4, 2])),
|
||||||
|
**regular_op_with_shaped_data('transpose', int64_array([3, 100, 66, 34, 2]),
|
||||||
|
{'type': 'Transpose', 'op': 'Transpose'}),
|
||||||
|
**regular_op_with_shaped_data('roll', int64_array([3, 100, 66, 34, 2]), {'type': 'Roll', 'op': 'Roll'}),
|
||||||
|
**valued_const_with_data('shift', int64_array([20, 20])),
|
||||||
|
**valued_const_with_data('axis', int64_array([1, 3, 4])),
|
||||||
|
**regular_op_with_shaped_data('abs', int64_array([3, 100, 66, 34, 2]), {'type': 'Abs', 'op': 'Abs'}),
|
||||||
|
**result('output'),
|
||||||
|
}
|
||||||
|
|
||||||
|
ref_graph_edges_3 = [
|
||||||
|
*connect('placeholder:0', '0:transpose'),
|
||||||
|
*connect('perm:0', '1:transpose'),
|
||||||
|
*connect('transpose:0', '0:roll'),
|
||||||
|
*connect('shift:0', '1:roll'),
|
||||||
|
*connect('axis:0', '2:roll'),
|
||||||
|
*connect('roll:0', '0:abs'),
|
||||||
|
*connect('abs:0', 'output'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SSliceComplexMiddleStageTest(unittest.TestCase):
|
class SSliceComplexMiddleStageTest(unittest.TestCase):
|
||||||
def test_replacement_for_the_last_axis(self):
|
def test_replacement_for_the_last_axis(self):
|
||||||
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
|
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
|
||||||
@ -199,3 +287,11 @@ class SSliceComplexMiddleStageTest(unittest.TestCase):
|
|||||||
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_2, edges=ref_graph_edges_2)
|
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_2, edges=ref_graph_edges_2)
|
||||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_replacement_with_update_roll_axes(self):
|
||||||
|
graph = build_graph(nodes_attrs=graph_node_attrs_3, edges=graph_edges_2)
|
||||||
|
SSliceComplex().find_and_replace_pattern(graph)
|
||||||
|
graph.clean_up()
|
||||||
|
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_3, edges=ref_graph_edges_3)
|
||||||
|
(flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
Loading…
Reference in New Issue
Block a user