Support dynamic Broadcast and new pattern for TI condition (#9735)

* Support dynamic Broadcast and new pattern for TI condition

* Apply review feedback

* Fix broadcast if statement
This commit is contained in:
Maxim Vafin
2022-02-22 16:46:48 +03:00
committed by GitHub
parent 487bb67995
commit 435584bb91
4 changed files with 243 additions and 163 deletions

View File

@@ -5,6 +5,7 @@ import logging as log
import numpy as np
from openvino.tools.mo.middle.pattern_match import apply_pattern
from openvino.tools.mo.middle.TensorIterator_utils import delete_selects_from
from openvino.tools.mo.ops.TensorIterator_ops import TensorIteratorCondition, TensorIteratorBackEdge
from openvino.tools.mo.ops.identity import Identity
@@ -66,134 +67,140 @@ Shape -> StridedSlice -> Enter -| LogicalAnd --> LoopCond (data)
return [TensorIteratorMerge]
@staticmethod
def pattern():
def pattern(variation):
log.debug('+++++++++++++++ ConditionMatching ++++++++++++++++')
return dict(
nodes=[
('Enter_1_less', dict(kind='op', op='Enter')),
('Strided_slice', dict(kind='op', op='StridedSlice')),
('Strided_slice_data', dict(kind='data')),
('Enter_1_less_data', dict(kind='data')),
nodes = [
('Enter_1_less', dict(kind='op', op='Enter')),
('Strided_slice', dict(kind='op', op='StridedSlice')),
('Strided_slice_data', dict(kind='data')),
('Enter_1_less_data', dict(kind='data')),
('Less_1', dict(kind='op', op='Less')),
('Merge_1', dict(kind='op', op='Merge')),
('Merge_1_data', dict(kind='data')),
('Less_1_data', dict(kind='data')),
('Less_1', dict(kind='op', op='Less')),
('Merge_1', dict(kind='op', op='Merge')),
('Merge_1_data', dict(kind='data')),
('Less_1_data', dict(kind='data')),
('Less_2', dict(kind='op', op='Less')),
('Merge_2', dict(kind='op', op='Merge')),
('Merge_2_data', dict(kind='data')),
('Less_2_data', dict(kind='data')),
('Less_2', dict(kind='op', op='Less')),
('Merge_2', dict(kind='op', op='Merge')),
('Merge_2_data', dict(kind='data')),
('Less_2_data', dict(kind='data')),
('and', dict(kind='op', op='LogicalAnd')),
('and_data', dict(kind='data')),
('loop_cond', dict(kind='op', op='LoopCond')),
('loop_cond_data', dict(kind='data')),
('init_1', dict(kind='op', op='Const')),
('init_1_data', dict(kind='data')),
('Enter_1', dict(kind='op', op='Enter')),
('Enter_1_data', dict(kind='data')),
('init_2', dict(kind='op', op='Const')),
('init_2_data', dict(kind='data')),
('Enter_2', dict(kind='op', op='Enter')),
('Enter_2_data', dict(kind='data')),
('Switch_1', dict(kind='op', op='Switch')),
('Switch_1_data', dict(kind='data')),
('Identity_1', dict(kind='op', op='Identity')),
('Identity_1_data', dict(kind='data')),
('add_1', dict(kind='op', op='Add')),
('add_1_y', dict(kind='op', op='Const')),
('add_1_y_data', dict(kind='data')),
('add_1_data', dict(kind='data')),
('NextIteration_1', dict(kind='op', op='NextIteration')),
('Switch_2', dict(kind='op', op='Switch')),
('Switch_2_data', dict(kind='data')),
('Identity_2', dict(kind='op', op='Identity')),
('Identity_2_data', dict(kind='data')),
('add_2', dict(kind='op', op='Add')),
('add_2_y', dict(kind='op', op='Const')),
('add_2_y_data', dict(kind='data')),
('add_2_data', dict(kind='data')),
('NextIteration_2', dict(kind='op', op='NextIteration')),
]
edges = [
('Strided_slice', 'Strided_slice_data'),
('Strided_slice_data', 'Enter_1_less'),
('Enter_1_less', 'Enter_1_less_data'),
('Enter_1_less_data', 'Less_1'),
('Less_1', 'Less_1_data'),
('Less_1_data', 'and'),
('and', 'and_data'),
('and_data', 'loop_cond'),
('loop_cond', 'loop_cond_data'),
('loop_cond_data', 'Switch_1'),
('loop_cond_data', 'Switch_2'),
('init_1', 'init_1_data'),
('init_1_data', 'Enter_1'),
('Enter_1', 'Enter_1_data'),
('Enter_1_data', 'Merge_1'),
('Merge_1', 'Merge_1_data'),
('Merge_1_data', 'Less_1'),
('Merge_1_data', 'Switch_1'),
('Switch_1', 'Switch_1_data'),
('Switch_1_data', 'Identity_1'),
('Identity_1', 'Identity_1_data'),
('Identity_1_data', 'add_1'),
('add_1_y', 'add_1_y_data'),
('add_1_y_data', 'add_1'),
('add_1', 'add_1_data'),
('add_1_data', 'NextIteration_1'),
('Merge_2_data', 'Switch_2'),
('Switch_2', 'Switch_2_data'),
('Switch_2_data', 'Identity_2'),
('Identity_2', 'Identity_2_data'),
('Identity_2_data', 'add_2'),
('add_2_y', 'add_2_y_data'),
('add_2_y_data', 'add_2'),
('add_2', 'add_2_data'),
('add_2_data', 'NextIteration_2'),
('init_2', 'init_2_data'),
('init_2_data', 'Enter_2'),
('Enter_2', 'Enter_2_data'),
('Enter_2_data', 'Merge_2'),
('Merge_2', 'Merge_2_data'),
('Merge_2_data', 'Less_2'),
('Less_2', 'Less_2_data'),
('Less_2_data', 'and'),
]
if variation == 1:
nodes.extend([
('Enter_2_less', dict(kind='op', op='Enter')),
('Enter_2_less_data', dict(kind='data')),
('minimum_data', dict(kind='data')),
('and', dict(kind='op', op='LogicalAnd')),
('and_data', dict(kind='data')),
('loop_cond', dict(kind='op', op='LoopCond')),
('loop_cond_data', dict(kind='data')),
('init_1', dict(kind='op', op='Const')),
('init_1_data', dict(kind='data')),
('Enter_1', dict(kind='op', op='Enter')),
('Enter_1_data', dict(kind='data')),
('init_2', dict(kind='op', op='Const')),
('init_2_data', dict(kind='data')),
('Enter_2', dict(kind='op', op='Enter')),
('Enter_2_data', dict(kind='data')),
('Switch_1', dict(kind='op', op='Switch')),
('Switch_1_data', dict(kind='data')),
('Identity_1', dict(kind='op', op='Identity')),
('Identity_1_data', dict(kind='data')),
('add_1', dict(kind='op', op='Add')),
('add_1_y', dict(kind='op', op='Const')),
('add_1_y_data', dict(kind='data')),
('add_1_data', dict(kind='data')),
('NextIteration_1', dict(kind='op', op='NextIteration')),
('Switch_2', dict(kind='op', op='Switch')),
('Switch_2_data', dict(kind='data')),
('Identity_2', dict(kind='op', op='Identity')),
('Identity_2_data', dict(kind='data')),
('add_2', dict(kind='op', op='Add')),
('add_2_y', dict(kind='op', op='Const')),
('add_2_y_data', dict(kind='data')),
('add_2_data', dict(kind='data')),
('NextIteration_2', dict(kind='op', op='NextIteration')),
],
edges=[
('Strided_slice', 'Strided_slice_data'),
('Strided_slice_data', 'Enter_1_less'),
('Enter_1_less', 'Enter_1_less_data'),
('Enter_1_less_data', 'Less_1'),
('Less_1', 'Less_1_data'),
('Less_1_data', 'and'),
('and', 'and_data'),
('and_data', 'loop_cond'),
('loop_cond', 'loop_cond_data'),
('loop_cond_data', 'Switch_1'),
('loop_cond_data', 'Switch_2'),
('init_1', 'init_1_data'),
('init_1_data', 'Enter_1'),
('Enter_1', 'Enter_1_data'),
('Enter_1_data', 'Merge_1'),
('Merge_1', 'Merge_1_data'),
('Merge_1_data', 'Less_1'),
('Merge_1_data', 'Switch_1'),
('Switch_1', 'Switch_1_data'),
('Switch_1_data', 'Identity_1'),
('Identity_1', 'Identity_1_data'),
('Identity_1_data', 'add_1'),
('add_1_y', 'add_1_y_data'),
('add_1_y_data', 'add_1'),
('add_1', 'add_1_data'),
('add_1_data', 'NextIteration_1'),
('Merge_2_data', 'Switch_2'),
('Switch_2', 'Switch_2_data'),
('Switch_2_data', 'Identity_2'),
('Identity_2', 'Identity_2_data'),
('Identity_2_data', 'add_2'),
('add_2_y', 'add_2_y_data'),
('add_2_y_data', 'add_2'),
('add_2', 'add_2_data'),
('add_2_data', 'NextIteration_2'),
('minimum_data', dict(kind='data'))
])
edges.extend([
('minimum_data', 'Enter_2_less'),
('Enter_2_less', 'Enter_2_less_data'),
('Enter_2_less_data', 'Less_2'),
('init_2', 'init_2_data'),
('init_2_data', 'Enter_2'),
('Enter_2', 'Enter_2_data'),
('Enter_2_data', 'Merge_2'),
('Merge_2', 'Merge_2_data'),
('Merge_2_data', 'Less_2'),
('Less_2', 'Less_2_data'),
('Less_2_data', 'and'),
],
)
])
elif variation == 2:
edges.append(('Enter_1_less_data', 'Less_2'))
else:
raise Exception('Wrong pattern variation')
return dict(nodes=nodes, edges=edges)
@staticmethod
def looking_for_iteration_counter(graph: Graph, match: dict):
types = ['TensorIteratorInput', 'TensorIteratorOutput']
candidates = mo_array([match['Identity_1_data'], match['Identity_2_data']])
results = mo_array([False for i in range(len(candidates))])
for i, candidat in enumerate(candidates):
for node in candidat.out_nodes():
candidates = [match['Identity_1_data'], match['Identity_2_data']]
results = []
for candidate in candidates:
for node in candidate.out_nodes():
if node['op'] in types:
results[i] = True
assert not np.all(results)
assert sum(results) == 1
return candidates[results == True][0]
results.append(candidate)
break
assert len(results) == 1
return results[0]
@staticmethod
def check_dynamic_seq_len(graph: Graph, match: dict):
@@ -201,11 +208,17 @@ Shape -> StridedSlice -> Enter -| LogicalAnd --> LoopCond (data)
Cycle is dynamic if at least one of the boundaries isn't constant OR this boundaries is different from tensor
shape.
"""
dynamic_seq_len = match['Enter_1_less_data'].value is None or match['Enter_2_less_data'].value is None or \
not np.array_equal(match['Enter_1_less_data'].value, match['Enter_2_less_data'].value)
dynamic_seq_len = match['Enter_1_less_data'].value is None
if 'Enter_2_less_data' in match:
dynamic_seq_len = dynamic_seq_len or match['Enter_2_less_data'].value is None or \
not np.array_equal(match['Enter_1_less_data'].value, match['Enter_2_less_data'].value)
return dynamic_seq_len
def find_and_replace_pattern(self, graph: Graph):
apply_pattern(graph, **self.pattern(1), action=self.replace_pattern) # pylint: disable=no-member
apply_pattern(graph, **self.pattern(2), action=self.replace_pattern) # pylint: disable=no-member
def replace_pattern(self, graph: Graph, match: dict):
log.debug('================== ConditionFind ===============')
# init_1
@@ -235,7 +248,11 @@ Shape -> StridedSlice -> Enter -| LogicalAnd --> LoopCond (data)
condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
name=match['loop_cond'].name + '/TensorIteratorCondition_')
condition = TensorIteratorCondition(graph, attrs=condition_attrs)
condition_data = condition.create_node_with_data(inputs=[match['Strided_slice_data'], match['minimum_data']],
if 'minimum_data' in match:
condition_inp = [match['Strided_slice_data'], match['minimum_data']]
else:
condition_inp = [match['Strided_slice_data']]
condition_data = condition.create_node_with_data(inputs=condition_inp,
data_nodes=[loop_condition, iterator_data])
safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',

View File

@@ -1,7 +1,7 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, shape_array, undefined_shape_of_rank
from openvino.tools.mo.graph.graph import Node, Graph
from openvino.tools.mo.graph.perm_inputs import PermuteInputs
from openvino.tools.mo.ops.op import Op
@@ -46,9 +46,16 @@ class Broadcast(Op):
input_shape = node.in_port(0).data.get_shape()
input_value = node.in_port(0).data.get_value()
target_shape_shape = node.in_port(1).data.get_shape()
target_shape = node.in_port(1).data.get_value()
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
# Dynamic target shape is possible to infer only if shape of target shape is static and 1D
if target_shape is None and len(target_shape_shape) == 1 and (len(input_shape) <= 1 or node.mode == 'explicit'):
assert is_fully_defined(target_shape_shape)
new_shape = undefined_shape_of_rank(target_shape_shape.item(0))
node.out_port(0).data.set_shape(new_shape)
return
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

View File

@@ -7,55 +7,85 @@ import numpy as np
from openvino.tools.mo.middle.TensorIteratorCondition import LoopConditionMatcher
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph_with_attrs
from unit_tests.utils.graph import build_graph_with_attrs, regular_op_with_empty_data, connect, build_graph
class TensorIteratorConditionTests(unittest.TestCase):
def test_not_dynamic(self):
def test_not_dynamic_1(self):
pattern_matcher = LoopConditionMatcher()
pattern = pattern_matcher.pattern()
pattern = pattern_matcher.pattern(1)
graph = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'], edges_with_attrs=pattern['edges'],
new_nodes_with_attrs=[('maximum', {'kind': 'op', 'op': 'Maximum'}),
('maximum_data', {'kind': 'data'}),
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})],
new_edges_with_attrs=[('maximum', 'maximum_data'),
('Identity_1_data', 'TensorIteratorInput')],
update_nodes_attributes=[('init_1_data', {'value': np.array([0])}),
('init_2_data', {'value': np.array([0])}),
('add_1_y_data', {'value': np.array(1)}),
('add_2_y_data', {'value': np.array(1)}),
('loop_cond_data', {'value': None}),
('Identity_2_data', {'value': None}, ),
('Enter_1_less_data', {'value': None},),
('Enter_2_less_data', {'value': None},),
])
new_nodes_with_attrs=[
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})],
new_edges_with_attrs=[
('Identity_1_data', 'TensorIteratorInput')],
update_nodes_attributes=[
('init_1_data', {'value': np.array([0])}),
('init_2_data', {'value': np.array([0])}),
('add_1_y_data', {'value': np.array(1)}),
('add_2_y_data', {'value': np.array(1)}),
('loop_cond_data', {'value': None}),
('Identity_2_data', {'value': None},),
('Enter_1_less_data', {'value': None},),
])
pattern_matcher.find_and_replace_pattern(graph)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=[('TensorIteratorCondition', {'kind': 'op', 'op': 'TensorIteratorCondition'}),
('loop_cond_data', {'kind': 'data'}),
('identity_data', {'kind': 'data'}),
('StridedSlice', {'kind': 'op', 'op':'StridedSlice'}),
('StridedSlice_data', {'kind': 'data'}),
('Maximum', {'kind': 'op', 'op': 'Maximum'}),
('Maximum_data', {'kind': 'data'}),
('minimum_data', {'kind': 'data'}),
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})
],
edges_with_attrs=[('Maximum', 'Maximum_data'),
('StridedSlice', 'StridedSlice_data'),
('StridedSlice_data', 'TensorIteratorCondition', {'in':0}),
('minimum_data', 'TensorIteratorCondition', {'in':1}),
('TensorIteratorCondition', 'loop_cond_data'),
('TensorIteratorCondition', 'identity_data'),
('identity_data', 'TensorIteratorInput'),
],
update_edge_attrs=None,
new_nodes_with_attrs=[],
new_edges_with_attrs=[],
)
nodes_attributes = {
**regular_op_with_empty_data('StridedSlice', {'op': 'StridedSlice', 'type': None}),
'TensorIteratorCondition': {'kind': 'op', 'op': 'TensorIteratorCondition'},
'loop_cond_data': {'kind': 'data'},
'identity_data': {'kind': 'data'},
'minimum_data': {'kind': 'data'},
'TensorIteratorInput': {'kind': 'op', 'op': 'TensorIteratorInput'}
}
edges = [
*connect('StridedSlice', '0:TensorIteratorCondition'),
('minimum_data', 'TensorIteratorCondition', {'in':1}),
('TensorIteratorCondition', 'loop_cond_data'),
('TensorIteratorCondition', 'identity_data'),
('identity_data', 'TensorIteratorInput')
]
graph_ref = build_graph(nodes_attributes, edges)
(flag, resp) = compare_graphs(graph, graph_ref, 'loop_cond_data', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_not_dynamic_2(self):
pattern_matcher = LoopConditionMatcher()
pattern = pattern_matcher.pattern(2)
graph = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'], edges_with_attrs=pattern['edges'],
new_nodes_with_attrs=[
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'}),
('some_op', {'kind': 'op', 'op': 'Add'})],
new_edges_with_attrs=[
('Identity_1_data', 'TensorIteratorInput'),
('loop_cond_data', 'some_op'),
],
update_nodes_attributes=[
('init_1_data', {'value': np.array([0])}),
('init_2_data', {'value': np.array([0])}),
('add_1_y_data', {'value': np.array(1)}),
('add_2_y_data', {'value': np.array(1)}),
('loop_cond_data', {'value': None}),
('Identity_2_data', {'value': None},),
('Enter_1_less_data', {'value': None},),
])
pattern_matcher.find_and_replace_pattern(graph)
nodes_attributes = {
**regular_op_with_empty_data('loop_cond', {'op': 'TensorIteratorCondition', 'type': None}),
**regular_op_with_empty_data('StridedSlice', {'op': 'StridedSlice', 'type': None}),
'some_op': {'kind': 'op', 'op': 'Add'},
'identity_data': {'kind': 'data'},
'TensorIteratorInput': {'kind': 'op', 'op': 'TensorIteratorInput'}
}
edges = [
*connect('StridedSlice', 'loop_cond'),
*connect('loop_cond', 'some_op'),
('loop_cond', 'identity_data'),
('identity_data', 'TensorIteratorInput')
]
graph_ref = build_graph(nodes_attributes, edges)
(flag, resp) = compare_graphs(graph, graph_ref, 'some_op', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@@ -6,7 +6,7 @@ import unittest
import numpy as np
from generator import generator, generate
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, undefined_shape_of_rank
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.broadcast import Broadcast
from unit_tests.utils.graph import build_graph, valued_const_with_data, regular_op_with_empty_data, \
@@ -34,10 +34,10 @@ class BroadcastTest(unittest.TestCase):
([[3, 1]], [2, 1, 2], [-2, -1], 'explicit', [[[3, 1]], [[3, 1]]]), # ref_shape (2, 1, 2)
([[[9, 5, 7]], [[9, 5, 7]]], [2, 2, 1, 3], [1, 2, 3], 'explicit', # in_shape (2, 1, 3)
[[[[9, 5, 7]], [[9, 5, 7]]], [[[9, 5, 7]], [[9, 5, 7]]]]), # ref_out_shape (2, 2, 1, 3)
[[[[9, 5, 7]], [[9, 5, 7]]], [[[9, 5, 7]], [[9, 5, 7]]]]), # ref_out_shape (2, 2, 1, 3)
([[[9, 5, 7]], [[3, 4, 8]]], [2, 1, 3, 3], [0, 1, 2], 'explicit', # in_shape (2, 1, 3)
[[[[9, 9, 9], [5, 5, 5], [7, 7, 7]]], [[[3, 3, 3], [4, 4, 4], [8, 8, 8]]]]), # ref_out_shape (2, 1, 3, 3)
([[[9, 5, 7]], [[3, 4, 8]]], [2, 1, 3, 3], [0, 1, 2], 'explicit', # in_shape (2, 1, 3)
[[[[9, 9, 9], [5, 5, 5], [7, 7, 7]]], [[[3, 3, 3], [4, 4, 4], [8, 8, 8]]]]), # ref_out_shape (2, 1, 3, 3)
# negative tests
([1], [2, 2], [0], 'explicit', None, True),
@@ -76,3 +76,29 @@ class BroadcastTest(unittest.TestCase):
self.assertTrue(np.array_equal(broadcast_node.out_node().value, np.array(ref_out)))
else:
self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape)))
@generate(*[
([1], [3], 'numpy', undefined_shape_of_rank(3)),
([1], [3], 'explicit', undefined_shape_of_rank(3)),
([1, 2], [3], 'numpy', None, True),
])
def test_broadcast_dynamic(self, data, target_shape_shape, mode='numpy', ref_out_shape=None, test_raising=False):
nodes = {
**shaped_data('data', int64_array(data)),
**shaped_data('target_shape', int64_array(target_shape_shape)),
**regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode}),
}
edges = [('data', 'broadcast'),
('target_shape', 'broadcast'),
('broadcast', 'broadcast_d')]
graph = build_graph(nodes, edges)
broadcast_node = Node(graph, 'broadcast')
if test_raising:
self.assertRaises(AssertionError, Broadcast.infer, broadcast_node)
return
Broadcast.infer(broadcast_node)
self.assertTrue(np.array_equal(broadcast_node.out_node().shape, ref_out_shape))