Support dynamic Broadcast and new pattern for TI condition (#9735)
* Support dynamic Broadcast and new pattern for TI condition * Apply review feedback * Fix broadcast if statement
This commit is contained in:
@@ -5,6 +5,7 @@ import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.middle.pattern_match import apply_pattern
|
||||
from openvino.tools.mo.middle.TensorIterator_utils import delete_selects_from
|
||||
from openvino.tools.mo.ops.TensorIterator_ops import TensorIteratorCondition, TensorIteratorBackEdge
|
||||
from openvino.tools.mo.ops.identity import Identity
|
||||
@@ -66,134 +67,140 @@ Shape -> StridedSlice -> Enter -| LogicalAnd --> LoopCond (data)
|
||||
return [TensorIteratorMerge]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
def pattern(variation):
|
||||
log.debug('+++++++++++++++ ConditionMatching ++++++++++++++++')
|
||||
return dict(
|
||||
nodes=[
|
||||
('Enter_1_less', dict(kind='op', op='Enter')),
|
||||
('Strided_slice', dict(kind='op', op='StridedSlice')),
|
||||
('Strided_slice_data', dict(kind='data')),
|
||||
('Enter_1_less_data', dict(kind='data')),
|
||||
nodes = [
|
||||
('Enter_1_less', dict(kind='op', op='Enter')),
|
||||
('Strided_slice', dict(kind='op', op='StridedSlice')),
|
||||
('Strided_slice_data', dict(kind='data')),
|
||||
('Enter_1_less_data', dict(kind='data')),
|
||||
|
||||
('Less_1', dict(kind='op', op='Less')),
|
||||
('Merge_1', dict(kind='op', op='Merge')),
|
||||
('Merge_1_data', dict(kind='data')),
|
||||
('Less_1_data', dict(kind='data')),
|
||||
('Less_1', dict(kind='op', op='Less')),
|
||||
('Merge_1', dict(kind='op', op='Merge')),
|
||||
('Merge_1_data', dict(kind='data')),
|
||||
('Less_1_data', dict(kind='data')),
|
||||
|
||||
('Less_2', dict(kind='op', op='Less')),
|
||||
('Merge_2', dict(kind='op', op='Merge')),
|
||||
('Merge_2_data', dict(kind='data')),
|
||||
('Less_2_data', dict(kind='data')),
|
||||
('Less_2', dict(kind='op', op='Less')),
|
||||
('Merge_2', dict(kind='op', op='Merge')),
|
||||
('Merge_2_data', dict(kind='data')),
|
||||
('Less_2_data', dict(kind='data')),
|
||||
|
||||
('and', dict(kind='op', op='LogicalAnd')),
|
||||
('and_data', dict(kind='data')),
|
||||
('loop_cond', dict(kind='op', op='LoopCond')),
|
||||
('loop_cond_data', dict(kind='data')),
|
||||
|
||||
('init_1', dict(kind='op', op='Const')),
|
||||
('init_1_data', dict(kind='data')),
|
||||
('Enter_1', dict(kind='op', op='Enter')),
|
||||
('Enter_1_data', dict(kind='data')),
|
||||
|
||||
('init_2', dict(kind='op', op='Const')),
|
||||
('init_2_data', dict(kind='data')),
|
||||
('Enter_2', dict(kind='op', op='Enter')),
|
||||
('Enter_2_data', dict(kind='data')),
|
||||
|
||||
('Switch_1', dict(kind='op', op='Switch')),
|
||||
('Switch_1_data', dict(kind='data')),
|
||||
('Identity_1', dict(kind='op', op='Identity')),
|
||||
('Identity_1_data', dict(kind='data')),
|
||||
('add_1', dict(kind='op', op='Add')),
|
||||
('add_1_y', dict(kind='op', op='Const')),
|
||||
('add_1_y_data', dict(kind='data')),
|
||||
('add_1_data', dict(kind='data')),
|
||||
('NextIteration_1', dict(kind='op', op='NextIteration')),
|
||||
|
||||
('Switch_2', dict(kind='op', op='Switch')),
|
||||
('Switch_2_data', dict(kind='data')),
|
||||
('Identity_2', dict(kind='op', op='Identity')),
|
||||
('Identity_2_data', dict(kind='data')),
|
||||
('add_2', dict(kind='op', op='Add')),
|
||||
('add_2_y', dict(kind='op', op='Const')),
|
||||
('add_2_y_data', dict(kind='data')),
|
||||
('add_2_data', dict(kind='data')),
|
||||
('NextIteration_2', dict(kind='op', op='NextIteration')),
|
||||
|
||||
]
|
||||
edges = [
|
||||
('Strided_slice', 'Strided_slice_data'),
|
||||
('Strided_slice_data', 'Enter_1_less'),
|
||||
('Enter_1_less', 'Enter_1_less_data'),
|
||||
('Enter_1_less_data', 'Less_1'),
|
||||
('Less_1', 'Less_1_data'),
|
||||
('Less_1_data', 'and'),
|
||||
|
||||
('and', 'and_data'),
|
||||
('and_data', 'loop_cond'),
|
||||
('loop_cond', 'loop_cond_data'),
|
||||
('loop_cond_data', 'Switch_1'),
|
||||
('loop_cond_data', 'Switch_2'),
|
||||
|
||||
('init_1', 'init_1_data'),
|
||||
('init_1_data', 'Enter_1'),
|
||||
('Enter_1', 'Enter_1_data'),
|
||||
('Enter_1_data', 'Merge_1'),
|
||||
('Merge_1', 'Merge_1_data'),
|
||||
('Merge_1_data', 'Less_1'),
|
||||
|
||||
('Merge_1_data', 'Switch_1'),
|
||||
('Switch_1', 'Switch_1_data'),
|
||||
('Switch_1_data', 'Identity_1'),
|
||||
('Identity_1', 'Identity_1_data'),
|
||||
('Identity_1_data', 'add_1'),
|
||||
('add_1_y', 'add_1_y_data'),
|
||||
('add_1_y_data', 'add_1'),
|
||||
('add_1', 'add_1_data'),
|
||||
('add_1_data', 'NextIteration_1'),
|
||||
|
||||
('Merge_2_data', 'Switch_2'),
|
||||
('Switch_2', 'Switch_2_data'),
|
||||
('Switch_2_data', 'Identity_2'),
|
||||
('Identity_2', 'Identity_2_data'),
|
||||
('Identity_2_data', 'add_2'),
|
||||
('add_2_y', 'add_2_y_data'),
|
||||
('add_2_y_data', 'add_2'),
|
||||
('add_2', 'add_2_data'),
|
||||
('add_2_data', 'NextIteration_2'),
|
||||
|
||||
('init_2', 'init_2_data'),
|
||||
('init_2_data', 'Enter_2'),
|
||||
('Enter_2', 'Enter_2_data'),
|
||||
('Enter_2_data', 'Merge_2'),
|
||||
|
||||
('Merge_2', 'Merge_2_data'),
|
||||
('Merge_2_data', 'Less_2'),
|
||||
('Less_2', 'Less_2_data'),
|
||||
('Less_2_data', 'and'),
|
||||
]
|
||||
if variation == 1:
|
||||
nodes.extend([
|
||||
('Enter_2_less', dict(kind='op', op='Enter')),
|
||||
('Enter_2_less_data', dict(kind='data')),
|
||||
('minimum_data', dict(kind='data')),
|
||||
|
||||
('and', dict(kind='op', op='LogicalAnd')),
|
||||
('and_data', dict(kind='data')),
|
||||
('loop_cond', dict(kind='op', op='LoopCond')),
|
||||
('loop_cond_data', dict(kind='data')),
|
||||
|
||||
('init_1', dict(kind='op', op='Const')),
|
||||
('init_1_data', dict(kind='data')),
|
||||
('Enter_1', dict(kind='op', op='Enter')),
|
||||
('Enter_1_data', dict(kind='data')),
|
||||
|
||||
('init_2', dict(kind='op', op='Const')),
|
||||
('init_2_data', dict(kind='data')),
|
||||
('Enter_2', dict(kind='op', op='Enter')),
|
||||
('Enter_2_data', dict(kind='data')),
|
||||
|
||||
('Switch_1', dict(kind='op', op='Switch')),
|
||||
('Switch_1_data', dict(kind='data')),
|
||||
('Identity_1', dict(kind='op', op='Identity')),
|
||||
('Identity_1_data', dict(kind='data')),
|
||||
('add_1', dict(kind='op', op='Add')),
|
||||
('add_1_y', dict(kind='op', op='Const')),
|
||||
('add_1_y_data', dict(kind='data')),
|
||||
('add_1_data', dict(kind='data')),
|
||||
('NextIteration_1', dict(kind='op', op='NextIteration')),
|
||||
|
||||
('Switch_2', dict(kind='op', op='Switch')),
|
||||
('Switch_2_data', dict(kind='data')),
|
||||
('Identity_2', dict(kind='op', op='Identity')),
|
||||
('Identity_2_data', dict(kind='data')),
|
||||
('add_2', dict(kind='op', op='Add')),
|
||||
('add_2_y', dict(kind='op', op='Const')),
|
||||
('add_2_y_data', dict(kind='data')),
|
||||
('add_2_data', dict(kind='data')),
|
||||
('NextIteration_2', dict(kind='op', op='NextIteration')),
|
||||
|
||||
],
|
||||
edges=[
|
||||
('Strided_slice', 'Strided_slice_data'),
|
||||
('Strided_slice_data', 'Enter_1_less'),
|
||||
('Enter_1_less', 'Enter_1_less_data'),
|
||||
('Enter_1_less_data', 'Less_1'),
|
||||
('Less_1', 'Less_1_data'),
|
||||
('Less_1_data', 'and'),
|
||||
|
||||
('and', 'and_data'),
|
||||
('and_data', 'loop_cond'),
|
||||
('loop_cond', 'loop_cond_data'),
|
||||
('loop_cond_data', 'Switch_1'),
|
||||
('loop_cond_data', 'Switch_2'),
|
||||
|
||||
('init_1', 'init_1_data'),
|
||||
('init_1_data', 'Enter_1'),
|
||||
('Enter_1', 'Enter_1_data'),
|
||||
('Enter_1_data', 'Merge_1'),
|
||||
('Merge_1', 'Merge_1_data'),
|
||||
('Merge_1_data', 'Less_1'),
|
||||
|
||||
('Merge_1_data', 'Switch_1'),
|
||||
('Switch_1', 'Switch_1_data'),
|
||||
('Switch_1_data', 'Identity_1'),
|
||||
('Identity_1', 'Identity_1_data'),
|
||||
('Identity_1_data', 'add_1'),
|
||||
('add_1_y', 'add_1_y_data'),
|
||||
('add_1_y_data', 'add_1'),
|
||||
('add_1', 'add_1_data'),
|
||||
('add_1_data', 'NextIteration_1'),
|
||||
|
||||
('Merge_2_data', 'Switch_2'),
|
||||
('Switch_2', 'Switch_2_data'),
|
||||
('Switch_2_data', 'Identity_2'),
|
||||
('Identity_2', 'Identity_2_data'),
|
||||
('Identity_2_data', 'add_2'),
|
||||
('add_2_y', 'add_2_y_data'),
|
||||
('add_2_y_data', 'add_2'),
|
||||
('add_2', 'add_2_data'),
|
||||
('add_2_data', 'NextIteration_2'),
|
||||
|
||||
('minimum_data', dict(kind='data'))
|
||||
])
|
||||
edges.extend([
|
||||
('minimum_data', 'Enter_2_less'),
|
||||
('Enter_2_less', 'Enter_2_less_data'),
|
||||
('Enter_2_less_data', 'Less_2'),
|
||||
|
||||
('init_2', 'init_2_data'),
|
||||
('init_2_data', 'Enter_2'),
|
||||
('Enter_2', 'Enter_2_data'),
|
||||
('Enter_2_data', 'Merge_2'),
|
||||
|
||||
('Merge_2', 'Merge_2_data'),
|
||||
('Merge_2_data', 'Less_2'),
|
||||
('Less_2', 'Less_2_data'),
|
||||
('Less_2_data', 'and'),
|
||||
],
|
||||
)
|
||||
])
|
||||
elif variation == 2:
|
||||
edges.append(('Enter_1_less_data', 'Less_2'))
|
||||
else:
|
||||
raise Exception('Wrong pattern variation')
|
||||
return dict(nodes=nodes, edges=edges)
|
||||
|
||||
@staticmethod
|
||||
def looking_for_iteration_counter(graph: Graph, match: dict):
|
||||
types = ['TensorIteratorInput', 'TensorIteratorOutput']
|
||||
candidates = mo_array([match['Identity_1_data'], match['Identity_2_data']])
|
||||
results = mo_array([False for i in range(len(candidates))])
|
||||
for i, candidat in enumerate(candidates):
|
||||
for node in candidat.out_nodes():
|
||||
candidates = [match['Identity_1_data'], match['Identity_2_data']]
|
||||
results = []
|
||||
for candidate in candidates:
|
||||
for node in candidate.out_nodes():
|
||||
if node['op'] in types:
|
||||
results[i] = True
|
||||
assert not np.all(results)
|
||||
assert sum(results) == 1
|
||||
return candidates[results == True][0]
|
||||
results.append(candidate)
|
||||
break
|
||||
assert len(results) == 1
|
||||
return results[0]
|
||||
|
||||
@staticmethod
|
||||
def check_dynamic_seq_len(graph: Graph, match: dict):
|
||||
@@ -201,11 +208,17 @@ Shape -> StridedSlice -> Enter -| LogicalAnd --> LoopCond (data)
|
||||
Cycle is dynamic if at least one of the boundaries isn't constant OR this boundaries is different from tensor
|
||||
shape.
|
||||
"""
|
||||
dynamic_seq_len = match['Enter_1_less_data'].value is None or match['Enter_2_less_data'].value is None or \
|
||||
not np.array_equal(match['Enter_1_less_data'].value, match['Enter_2_less_data'].value)
|
||||
dynamic_seq_len = match['Enter_1_less_data'].value is None
|
||||
if 'Enter_2_less_data' in match:
|
||||
dynamic_seq_len = dynamic_seq_len or match['Enter_2_less_data'].value is None or \
|
||||
not np.array_equal(match['Enter_1_less_data'].value, match['Enter_2_less_data'].value)
|
||||
|
||||
return dynamic_seq_len
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
apply_pattern(graph, **self.pattern(1), action=self.replace_pattern) # pylint: disable=no-member
|
||||
apply_pattern(graph, **self.pattern(2), action=self.replace_pattern) # pylint: disable=no-member
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
log.debug('================== ConditionFind ===============')
|
||||
# init_1
|
||||
@@ -235,7 +248,11 @@ Shape -> StridedSlice -> Enter -| LogicalAnd --> LoopCond (data)
|
||||
condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
|
||||
name=match['loop_cond'].name + '/TensorIteratorCondition_')
|
||||
condition = TensorIteratorCondition(graph, attrs=condition_attrs)
|
||||
condition_data = condition.create_node_with_data(inputs=[match['Strided_slice_data'], match['minimum_data']],
|
||||
if 'minimum_data' in match:
|
||||
condition_inp = [match['Strided_slice_data'], match['minimum_data']]
|
||||
else:
|
||||
condition_inp = [match['Strided_slice_data']]
|
||||
condition_data = condition.create_node_with_data(inputs=condition_inp,
|
||||
data_nodes=[loop_condition, iterator_data])
|
||||
|
||||
safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, shape_array, undefined_shape_of_rank
|
||||
from openvino.tools.mo.graph.graph import Node, Graph
|
||||
from openvino.tools.mo.graph.perm_inputs import PermuteInputs
|
||||
from openvino.tools.mo.ops.op import Op
|
||||
@@ -46,9 +46,16 @@ class Broadcast(Op):
|
||||
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
input_value = node.in_port(0).data.get_value()
|
||||
target_shape_shape = node.in_port(1).data.get_shape()
|
||||
target_shape = node.in_port(1).data.get_value()
|
||||
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
|
||||
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
|
||||
# Dynamic target shape is possible to infer only if shape of target shape is static and 1D
|
||||
if target_shape is None and len(target_shape_shape) == 1 and (len(input_shape) <= 1 or node.mode == 'explicit'):
|
||||
assert is_fully_defined(target_shape_shape)
|
||||
new_shape = undefined_shape_of_rank(target_shape_shape.item(0))
|
||||
node.out_port(0).data.set_shape(new_shape)
|
||||
return
|
||||
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
|
||||
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
|
||||
|
||||
|
||||
@@ -7,55 +7,85 @@ import numpy as np
|
||||
|
||||
from openvino.tools.mo.middle.TensorIteratorCondition import LoopConditionMatcher
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph_with_attrs
|
||||
from unit_tests.utils.graph import build_graph_with_attrs, regular_op_with_empty_data, connect, build_graph
|
||||
|
||||
|
||||
class TensorIteratorConditionTests(unittest.TestCase):
|
||||
def test_not_dynamic(self):
|
||||
def test_not_dynamic_1(self):
|
||||
pattern_matcher = LoopConditionMatcher()
|
||||
pattern = pattern_matcher.pattern()
|
||||
pattern = pattern_matcher.pattern(1)
|
||||
|
||||
graph = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'], edges_with_attrs=pattern['edges'],
|
||||
new_nodes_with_attrs=[('maximum', {'kind': 'op', 'op': 'Maximum'}),
|
||||
('maximum_data', {'kind': 'data'}),
|
||||
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})],
|
||||
new_edges_with_attrs=[('maximum', 'maximum_data'),
|
||||
('Identity_1_data', 'TensorIteratorInput')],
|
||||
update_nodes_attributes=[('init_1_data', {'value': np.array([0])}),
|
||||
('init_2_data', {'value': np.array([0])}),
|
||||
('add_1_y_data', {'value': np.array(1)}),
|
||||
('add_2_y_data', {'value': np.array(1)}),
|
||||
('loop_cond_data', {'value': None}),
|
||||
('Identity_2_data', {'value': None}, ),
|
||||
('Enter_1_less_data', {'value': None},),
|
||||
('Enter_2_less_data', {'value': None},),
|
||||
])
|
||||
new_nodes_with_attrs=[
|
||||
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})],
|
||||
new_edges_with_attrs=[
|
||||
('Identity_1_data', 'TensorIteratorInput')],
|
||||
update_nodes_attributes=[
|
||||
('init_1_data', {'value': np.array([0])}),
|
||||
('init_2_data', {'value': np.array([0])}),
|
||||
('add_1_y_data', {'value': np.array(1)}),
|
||||
('add_2_y_data', {'value': np.array(1)}),
|
||||
('loop_cond_data', {'value': None}),
|
||||
('Identity_2_data', {'value': None},),
|
||||
('Enter_1_less_data', {'value': None},),
|
||||
])
|
||||
|
||||
pattern_matcher.find_and_replace_pattern(graph)
|
||||
graph_ref = build_graph_with_attrs(
|
||||
nodes_with_attrs=[('TensorIteratorCondition', {'kind': 'op', 'op': 'TensorIteratorCondition'}),
|
||||
('loop_cond_data', {'kind': 'data'}),
|
||||
('identity_data', {'kind': 'data'}),
|
||||
('StridedSlice', {'kind': 'op', 'op':'StridedSlice'}),
|
||||
('StridedSlice_data', {'kind': 'data'}),
|
||||
('Maximum', {'kind': 'op', 'op': 'Maximum'}),
|
||||
('Maximum_data', {'kind': 'data'}),
|
||||
('minimum_data', {'kind': 'data'}),
|
||||
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})
|
||||
],
|
||||
edges_with_attrs=[('Maximum', 'Maximum_data'),
|
||||
('StridedSlice', 'StridedSlice_data'),
|
||||
('StridedSlice_data', 'TensorIteratorCondition', {'in':0}),
|
||||
('minimum_data', 'TensorIteratorCondition', {'in':1}),
|
||||
('TensorIteratorCondition', 'loop_cond_data'),
|
||||
('TensorIteratorCondition', 'identity_data'),
|
||||
('identity_data', 'TensorIteratorInput'),
|
||||
],
|
||||
update_edge_attrs=None,
|
||||
new_nodes_with_attrs=[],
|
||||
new_edges_with_attrs=[],
|
||||
)
|
||||
nodes_attributes = {
|
||||
**regular_op_with_empty_data('StridedSlice', {'op': 'StridedSlice', 'type': None}),
|
||||
'TensorIteratorCondition': {'kind': 'op', 'op': 'TensorIteratorCondition'},
|
||||
'loop_cond_data': {'kind': 'data'},
|
||||
'identity_data': {'kind': 'data'},
|
||||
'minimum_data': {'kind': 'data'},
|
||||
'TensorIteratorInput': {'kind': 'op', 'op': 'TensorIteratorInput'}
|
||||
}
|
||||
edges = [
|
||||
*connect('StridedSlice', '0:TensorIteratorCondition'),
|
||||
('minimum_data', 'TensorIteratorCondition', {'in':1}),
|
||||
('TensorIteratorCondition', 'loop_cond_data'),
|
||||
('TensorIteratorCondition', 'identity_data'),
|
||||
('identity_data', 'TensorIteratorInput')
|
||||
]
|
||||
graph_ref = build_graph(nodes_attributes, edges)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'loop_cond_data', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_not_dynamic_2(self):
|
||||
pattern_matcher = LoopConditionMatcher()
|
||||
pattern = pattern_matcher.pattern(2)
|
||||
|
||||
graph = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'], edges_with_attrs=pattern['edges'],
|
||||
new_nodes_with_attrs=[
|
||||
('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'}),
|
||||
('some_op', {'kind': 'op', 'op': 'Add'})],
|
||||
new_edges_with_attrs=[
|
||||
('Identity_1_data', 'TensorIteratorInput'),
|
||||
('loop_cond_data', 'some_op'),
|
||||
],
|
||||
update_nodes_attributes=[
|
||||
('init_1_data', {'value': np.array([0])}),
|
||||
('init_2_data', {'value': np.array([0])}),
|
||||
('add_1_y_data', {'value': np.array(1)}),
|
||||
('add_2_y_data', {'value': np.array(1)}),
|
||||
('loop_cond_data', {'value': None}),
|
||||
('Identity_2_data', {'value': None},),
|
||||
('Enter_1_less_data', {'value': None},),
|
||||
])
|
||||
|
||||
pattern_matcher.find_and_replace_pattern(graph)
|
||||
nodes_attributes = {
|
||||
**regular_op_with_empty_data('loop_cond', {'op': 'TensorIteratorCondition', 'type': None}),
|
||||
**regular_op_with_empty_data('StridedSlice', {'op': 'StridedSlice', 'type': None}),
|
||||
'some_op': {'kind': 'op', 'op': 'Add'},
|
||||
'identity_data': {'kind': 'data'},
|
||||
'TensorIteratorInput': {'kind': 'op', 'op': 'TensorIteratorInput'}
|
||||
}
|
||||
edges = [
|
||||
*connect('StridedSlice', 'loop_cond'),
|
||||
*connect('loop_cond', 'some_op'),
|
||||
('loop_cond', 'identity_data'),
|
||||
('identity_data', 'TensorIteratorInput')
|
||||
]
|
||||
graph_ref = build_graph(nodes_attributes, edges)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'some_op', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
@@ -6,7 +6,7 @@ import unittest
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, undefined_shape_of_rank
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.broadcast import Broadcast
|
||||
from unit_tests.utils.graph import build_graph, valued_const_with_data, regular_op_with_empty_data, \
|
||||
@@ -34,10 +34,10 @@ class BroadcastTest(unittest.TestCase):
|
||||
([[3, 1]], [2, 1, 2], [-2, -1], 'explicit', [[[3, 1]], [[3, 1]]]), # ref_shape (2, 1, 2)
|
||||
|
||||
([[[9, 5, 7]], [[9, 5, 7]]], [2, 2, 1, 3], [1, 2, 3], 'explicit', # in_shape (2, 1, 3)
|
||||
[[[[9, 5, 7]], [[9, 5, 7]]], [[[9, 5, 7]], [[9, 5, 7]]]]), # ref_out_shape (2, 2, 1, 3)
|
||||
[[[[9, 5, 7]], [[9, 5, 7]]], [[[9, 5, 7]], [[9, 5, 7]]]]), # ref_out_shape (2, 2, 1, 3)
|
||||
|
||||
([[[9, 5, 7]], [[3, 4, 8]]], [2, 1, 3, 3], [0, 1, 2], 'explicit', # in_shape (2, 1, 3)
|
||||
[[[[9, 9, 9], [5, 5, 5], [7, 7, 7]]], [[[3, 3, 3], [4, 4, 4], [8, 8, 8]]]]), # ref_out_shape (2, 1, 3, 3)
|
||||
([[[9, 5, 7]], [[3, 4, 8]]], [2, 1, 3, 3], [0, 1, 2], 'explicit', # in_shape (2, 1, 3)
|
||||
[[[[9, 9, 9], [5, 5, 5], [7, 7, 7]]], [[[3, 3, 3], [4, 4, 4], [8, 8, 8]]]]), # ref_out_shape (2, 1, 3, 3)
|
||||
|
||||
# negative tests
|
||||
([1], [2, 2], [0], 'explicit', None, True),
|
||||
@@ -76,3 +76,29 @@ class BroadcastTest(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(broadcast_node.out_node().value, np.array(ref_out)))
|
||||
else:
|
||||
self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape)))
|
||||
|
||||
@generate(*[
|
||||
([1], [3], 'numpy', undefined_shape_of_rank(3)),
|
||||
([1], [3], 'explicit', undefined_shape_of_rank(3)),
|
||||
([1, 2], [3], 'numpy', None, True),
|
||||
])
|
||||
def test_broadcast_dynamic(self, data, target_shape_shape, mode='numpy', ref_out_shape=None, test_raising=False):
|
||||
nodes = {
|
||||
**shaped_data('data', int64_array(data)),
|
||||
**shaped_data('target_shape', int64_array(target_shape_shape)),
|
||||
**regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode}),
|
||||
}
|
||||
|
||||
edges = [('data', 'broadcast'),
|
||||
('target_shape', 'broadcast'),
|
||||
('broadcast', 'broadcast_d')]
|
||||
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
broadcast_node = Node(graph, 'broadcast')
|
||||
if test_raising:
|
||||
self.assertRaises(AssertionError, Broadcast.infer, broadcast_node)
|
||||
return
|
||||
|
||||
Broadcast.infer(broadcast_node)
|
||||
self.assertTrue(np.array_equal(broadcast_node.out_node().shape, ref_out_shape))
|
||||
|
||||
Reference in New Issue
Block a user