Extend FIFOQueueDequeue replacer to support OOB case (#9428)
Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
@@ -6,14 +6,13 @@ from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph, FrontReplacementPattern
|
||||
from openvino.tools.mo.front.extractor import add_input_ops
|
||||
from openvino.tools.mo.front.output_cut import OutputCut
|
||||
from openvino.tools.mo.front.user_data_repack import UserDataRepack
|
||||
from openvino.tools.mo.middle.passes.convert_data_type import np_data_type_to_precision, SUPPORTED_DATA_TYPES
|
||||
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph, FrontReplacementPattern
|
||||
from openvino.tools.mo.graph.graph import Graph, Node
|
||||
from openvino.tools.mo.middle.passes.convert_data_type import np_data_type_to_precision, SUPPORTED_DATA_TYPES
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
from openvino.tools.mo.utils.error import Error
|
||||
|
||||
|
||||
@@ -61,8 +60,10 @@ class FIFOQueue(FrontReplacementSubgraph):
|
||||
true_placeholder_shape = match['placeholder'].shape
|
||||
placeholder_shape = match['fifo_queue'].shapes[0]
|
||||
placeholder_data_type = match['fifo_queue'].types[0]
|
||||
assert true_placeholder_shape.ndim <= 1
|
||||
if true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
|
||||
# in case OOB conversion batch_size placeholder shape is not required
|
||||
# so use a shape specified in FIFOQueueV2 shapes list attribute
|
||||
assert true_placeholder_shape is None or true_placeholder_shape.ndim <= 1
|
||||
if true_placeholder_shape is not None and true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
|
||||
log.warning(
|
||||
'Placeholder \'{}\' got non 0-dimensional shape {} in FIFOQueue pattern. Placeholder will have the '
|
||||
'same shape after folding the pattern instead of {} shape which is original for the network.'
|
||||
|
||||
@@ -11,25 +11,43 @@ from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph_with_edge_attrs
|
||||
|
||||
|
||||
def create_fifo_queue_graph(batch_size_shape: np.ndarray):
|
||||
nodes = {
|
||||
'placeholder': {'op': 'Parameter', 'data_type': np.int32, 'kind': 'op', 'shape': batch_size_shape},
|
||||
'batch_join/fifo_queue': {'op': 'FIFOQueueV2', 'name': 'batch_join/fifo_queue',
|
||||
'shapes': np.array([[1, 2, 3]]), 'types': np.array([np.float32]), 'kind': 'op'},
|
||||
'batch_join': {'op': 'QueueDequeueUpToV2', 'kind': 'op'},
|
||||
'image_batch': {'op': 'Identity', 'data_type': np.float32, 'kind': 'op'},
|
||||
'label_batch': {'op': 'Identity', 'kind': 'op'},
|
||||
'label_batch_op_output': {'op': 'Result', 'kind': 'op'},
|
||||
}
|
||||
edges = [
|
||||
('placeholder', 'batch_join', {'out': 0, 'in': 0}),
|
||||
('batch_join/fifo_queue', 'batch_join', {'out': 0, 'in': 1}),
|
||||
('batch_join', 'image_batch', {'out': 0, 'in': 0}),
|
||||
('batch_join', 'label_batch', {'out': 1, 'in': 0}),
|
||||
('label_batch', 'label_batch_op_output', {'out': 0, 'in': 0})
|
||||
]
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
return graph
|
||||
|
||||
|
||||
class TestFIFOQueueReplacement(unittest.TestCase):
|
||||
def test_fifo_with_label_batch(self):
|
||||
nodes = {
|
||||
'placeholder': {'op': 'Parameter', 'data_type': np.int32, 'kind': 'op', 'shape': np.array(1)},
|
||||
'batch_join/fifo_queue': {'op': 'FIFOQueueV2', 'name': 'batch_join/fifo_queue',
|
||||
'shapes': np.array([[1, 2, 3]]), 'types': np.array([np.float32]), 'kind': 'op'},
|
||||
'batch_join': {'op': 'QueueDequeueUpToV2', 'kind': 'op'},
|
||||
'image_batch': {'op': 'Identity', 'data_type': np.float32, 'kind': 'op'},
|
||||
'label_batch': {'op': 'Identity', 'kind': 'op'},
|
||||
'label_batch_op_output': {'op': 'Result', 'kind': 'op'},
|
||||
}
|
||||
edges = [
|
||||
('placeholder', 'batch_join', {'out': 0, 'in': 0}),
|
||||
('batch_join/fifo_queue', 'batch_join', {'out': 0, 'in': 1}),
|
||||
('batch_join', 'image_batch', {'out': 0, 'in': 0}),
|
||||
('batch_join', 'label_batch', {'out': 1, 'in': 0}),
|
||||
('label_batch', 'label_batch_op_output', {'out': 0, 'in': 0})
|
||||
]
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
graph = create_fifo_queue_graph(shape_array([1]))
|
||||
tested_class = FIFOQueue()
|
||||
tested_class.find_and_replace_pattern(graph=graph)
|
||||
after_pattern = graph.nodes()
|
||||
self.assertEqual(2, len(after_pattern))
|
||||
try:
|
||||
new_ph_dict = graph.node[[u for u, v in graph.in_edges('image_batch')][0]]
|
||||
except Exception as e:
|
||||
self.fail("Can't get new placeholder. Broken edge. Additional information: {}".format(e))
|
||||
self.assertEqual(new_ph_dict['name'], 'batch_join/fifo_queue')
|
||||
self.assertTrue(np.array_equal(new_ph_dict['shape'], [1, 2, 3]))
|
||||
|
||||
def test_fifo_with_undefined_batch_size(self):
|
||||
graph = create_fifo_queue_graph(None)
|
||||
tested_class = FIFOQueue()
|
||||
tested_class.find_and_replace_pattern(graph=graph)
|
||||
after_pattern = graph.nodes()
|
||||
@@ -128,7 +146,7 @@ class FIFOQueueDequeueCutTest(unittest.TestCase):
|
||||
{
|
||||
'queue_dequeue': {'kind': 'op', 'op': 'QueueDequeue', 'shapes': [shape_array([2, 2]),
|
||||
shape_array([1, 1])],
|
||||
'types': [np.int32, np.float32]},
|
||||
'types': [np.int32, np.float32]},
|
||||
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||
'add': {'kind': 'op', 'op': 'Add'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'}
|
||||
@@ -167,7 +185,7 @@ class FIFOQueueDequeueCutTest(unittest.TestCase):
|
||||
{
|
||||
'queue_dequeue': {'kind': 'op', 'op': 'QueueDequeueV2', 'shapes': [shape_array([2, 2]),
|
||||
shape_array([1, 1])],
|
||||
'types': [np.int32, np.float32]},
|
||||
'types': [np.int32, np.float32]},
|
||||
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||
'add': {'kind': 'op', 'op': 'Add'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'}
|
||||
@@ -199,4 +217,4 @@ class FIFOQueueDequeueCutTest(unittest.TestCase):
|
||||
FIFOQueueDequeueCut().find_and_replace_pattern(graph)
|
||||
|
||||
flag, msg = compare_graphs(graph, graph_ref, last_node='concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, msg)
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
Reference in New Issue
Block a user