[MO] add IteratorGetNextCut (#8040)

* added IteratorGetNextCut, some improvements in graph.py

* added allowed types check

* reused new graph API for ports

* returned back old API 'out_nodes', removed soft-getting name from base class, changed run_after -> []

* correctly used new port API

* corrected IteratorGetNext message
This commit is contained in:
Pavel Esir 2021-11-11 10:27:51 +03:00 committed by GitHub
parent f1ca728ab1
commit 68badf5165
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 160 additions and 17 deletions

View File

@ -433,6 +433,7 @@ extensions/front/tf/identityN_to_identity.py
extensions/front/tf/if_ext.py extensions/front/tf/if_ext.py
extensions/front/tf/InterpolateTransposes.py extensions/front/tf/InterpolateTransposes.py
extensions/front/tf/IteratorGetNext_ext.py extensions/front/tf/IteratorGetNext_ext.py
extensions/front/tf/IteratorGetNextCut.py
extensions/front/tf/log_softmax_ext.py extensions/front/tf/log_softmax_ext.py
extensions/front/tf/LookupTableInsert_ext.py extensions/front/tf/LookupTableInsert_ext.py
extensions/front/tf/LoopCond_ext.py extensions/front/tf/LoopCond_ext.py

View File

@ -56,22 +56,24 @@ class InputsAnalysis(AnalyzeAction):
def iterator_get_next_analysis(cls, graph: Graph, inputs_desc: dict): def iterator_get_next_analysis(cls, graph: Graph, inputs_desc: dict):
message = None message = None
op_nodes = graph.get_op_nodes(op='IteratorGetNext') op_nodes = graph.get_op_nodes(op='IteratorGetNext')
params = ''
for iter_get_next in op_nodes:
for port in iter_get_next.out_nodes().keys():
inputs_desc['{}:{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port)] = {
'shape': iter_get_next.shapes[port].tolist(),
'value': None,
'data_type': iter_get_next.types[port]
}
if params != '':
params = params + ','
shape = str(iter_get_next.shapes[port].tolist()).replace(',', '')
params = params + '{}:{}{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port, shape)
if len(op_nodes): if len(op_nodes):
params = ''
for iter_get_next in op_nodes:
for port in iter_get_next.out_nodes().keys():
inputs_desc['{}:{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port)] = {
'shape': iter_get_next.shapes[port].tolist(),
'value': None,
'data_type': iter_get_next.types[port]
}
if params != '':
params = params + ','
shape = str(iter_get_next.shapes[port].tolist()).replace(',', '')
params = params + '{}:{}{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port, shape)
message = 'It looks like there is IteratorGetNext as input\n' \ message = 'It looks like there is IteratorGetNext as input\n' \
'Run the Model Optimizer with:\n\t\t--input "{}"\n' \ 'Run the Model Optimizer without --input option \n' \
'And replace all negative values with positive values'.format(params) 'Otherwise, try to run the Model Optimizer with:\n\t\t--input "{}"\n'.format(params)
return message return message
def analyze(self, graph: Graph): def analyze(self, graph: Graph):

View File

@ -0,0 +1,45 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from mo.front.extractor import add_input_ops
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import SUPPORTED_DATA_TYPES, np_data_type_to_precision
from mo.utils.error import Error
from mo.front.common.replacement import FrontReplacementPattern
class IteratorGetNextCut(FrontReplacementPattern):
"""
Cuts OneShotIterator -> IteratorGetNext pattern
in order to enable Out Of the Box (OOB) usage.
Pass is run only if user didn't specify any inputs names and shapes.
"""
enabled = True
graph_condition = [lambda graph: graph.graph['cmd_params'].input is None]
def run_before(self):
from extensions.front.output_cut import OutputCut
from extensions.front.input_cut import InputCut
return [OutputCut, InputCut]
def run_after(self):
return []
def find_and_replace_pattern(self, graph: Graph):
iter_get_next_shapes = defaultdict(list)
for iter_get_next in graph.get_op_nodes(op='IteratorGetNext'):
iter_get_next_name = iter_get_next.soft_get('name', iter_get_next.id)
for port in iter_get_next.out_ports():
if not np_data_type_to_precision(iter_get_next.types[port]) in SUPPORTED_DATA_TYPES:
raise Error("In IteratorGetNext node '{}' data type '{}' is not supported".format(
iter_get_next_name, iter_get_next.types[port]))
iter_get_next_shapes[iter_get_next_name].append(dict(
shape=iter_get_next.shapes[port],
out=port,
data_type=iter_get_next.types[port]
))
add_input_ops(graph, iter_get_next_shapes, True)

View File

@ -20,5 +20,5 @@ class IteratorGetNextExtractor(FrontExtractorOp):
result_shapes = [] result_shapes = []
for shape_pb in shapes: for shape_pb in shapes:
result_shapes.append(tf_tensor_shape(shape_pb)) result_shapes.append(tf_tensor_shape(shape_pb))
Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types}) Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types, 'out_ports_count': 1})
return cls.enabled return cls.enabled

View File

@ -26,8 +26,8 @@ class IteratorGetNextAnalysisTest(unittest.TestCase):
inputs_desc = {} inputs_desc = {}
message = InputsAnalysis.iterator_get_next_analysis(graph, inputs_desc) message = InputsAnalysis.iterator_get_next_analysis(graph, inputs_desc)
ref_message = 'It looks like there is IteratorGetNext as input\n' \ ref_message = 'It looks like there is IteratorGetNext as input\n' \
'Run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n' \ 'Run the Model Optimizer without --input option \n' \
'And replace all negative values with positive values' 'Otherwise, try to run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n'
self.assertEqual(message, ref_message) self.assertEqual(message, ref_message)
def test_negative(self): def test_negative(self):

View File

@ -0,0 +1,95 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.front.tf.IteratorGetNextCut import IteratorGetNextCut
from mo.front.common.partial_infer.utils import shape_array
from mo.utils.error import Error
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph_with_edge_attrs
class IteratorGetNextAnalysisTest(unittest.TestCase):
def test_one_output(self):
graph = build_graph_with_edge_attrs(
{
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]),
'types': [np.int32]},
'sub': {'kind': 'op', 'op': 'Sub'},
},
[
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
]
)
graph_ref = build_graph_with_edge_attrs(
{
'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'type': np.int32},
'sub': {'kind': 'op', 'op': 'Sub'},
},
[
('parameter_1', 'sub', {'out': 0, 'in': 0}),
]
)
IteratorGetNextCut().find_and_replace_pattern(graph)
flag, msg = compare_graphs(graph, graph_ref, last_node='sub')
self.assertTrue(flag, msg)
def test_two_outputs(self):
graph = build_graph_with_edge_attrs(
{
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': [shape_array([2, 2]),
shape_array([1, 1])],
'types': [np.int32, np.float32]},
'sub': {'kind': 'op', 'op': 'Sub'},
'add': {'kind': 'op', 'op': 'Add'},
'concat': {'kind': 'op', 'op': 'Concat'}
},
[
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
('iter_get_next', 'add', {'out': 1, 'in': 0}),
('sub', 'concat', {'out': 0, 'in': 0}),
('add', 'concat', {'out': 0, 'in': 1})
]
)
graph_ref = build_graph_with_edge_attrs(
{
'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'data_type': np.int32},
'parameter_2': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([1, 1]), 'data_type': np.float32},
'sub': {'kind': 'op', 'op': 'Sub'},
'add': {'kind': 'op', 'op': 'Add'},
'concat': {'kind': 'op', 'op': 'Concat'}
},
[
('parameter_1', 'sub', {'out': 0, 'in': 0}),
('parameter_2', 'add', {'out': 0, 'in': 0}),
('sub', 'concat', {'out': 0, 'in': 0}),
('add', 'concat', {'out': 0, 'in': 1})
]
)
IteratorGetNextCut().find_and_replace_pattern(graph)
flag, msg = compare_graphs(graph, graph_ref, last_node='concat', check_op_attrs=True)
self.assertTrue(flag, msg)
def test_unsupported_data_type(self):
graph = build_graph_with_edge_attrs(
{
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]),
'types': [None]},
'sub': {'kind': 'op', 'op': 'Sub'},
},
[
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
]
)
self.assertRaises(Error, IteratorGetNextCut().find_and_replace_pattern, graph)