[MO] add IteratorGetNextCut (#8040)
* added IteratorGetNextCut, some improvements in graph.py * added allowed types check * reused new graph API for ports * returned back old API 'out_nodes', removed soft-getting name from base class, changed run_after -> [] * correctly used new port API * corrected IteratorGetNext message
This commit is contained in:
parent
f1ca728ab1
commit
68badf5165
@ -433,6 +433,7 @@ extensions/front/tf/identityN_to_identity.py
|
||||
extensions/front/tf/if_ext.py
|
||||
extensions/front/tf/InterpolateTransposes.py
|
||||
extensions/front/tf/IteratorGetNext_ext.py
|
||||
extensions/front/tf/IteratorGetNextCut.py
|
||||
extensions/front/tf/log_softmax_ext.py
|
||||
extensions/front/tf/LookupTableInsert_ext.py
|
||||
extensions/front/tf/LoopCond_ext.py
|
||||
|
@ -56,7 +56,7 @@ class InputsAnalysis(AnalyzeAction):
|
||||
def iterator_get_next_analysis(cls, graph: Graph, inputs_desc: dict):
|
||||
message = None
|
||||
op_nodes = graph.get_op_nodes(op='IteratorGetNext')
|
||||
if len(op_nodes):
|
||||
|
||||
params = ''
|
||||
for iter_get_next in op_nodes:
|
||||
for port in iter_get_next.out_nodes().keys():
|
||||
@ -69,9 +69,11 @@ class InputsAnalysis(AnalyzeAction):
|
||||
params = params + ','
|
||||
shape = str(iter_get_next.shapes[port].tolist()).replace(',', '')
|
||||
params = params + '{}:{}{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port, shape)
|
||||
|
||||
if len(op_nodes):
|
||||
message = 'It looks like there is IteratorGetNext as input\n' \
|
||||
'Run the Model Optimizer with:\n\t\t--input "{}"\n' \
|
||||
'And replace all negative values with positive values'.format(params)
|
||||
'Run the Model Optimizer without --input option \n' \
|
||||
'Otherwise, try to run the Model Optimizer with:\n\t\t--input "{}"\n'.format(params)
|
||||
return message
|
||||
|
||||
def analyze(self, graph: Graph):
|
||||
|
45
model-optimizer/extensions/front/tf/IteratorGetNextCut.py
Normal file
45
model-optimizer/extensions/front/tf/IteratorGetNextCut.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from mo.front.extractor import add_input_ops
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.passes.convert_data_type import SUPPORTED_DATA_TYPES, np_data_type_to_precision
|
||||
from mo.utils.error import Error
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
|
||||
|
||||
class IteratorGetNextCut(FrontReplacementPattern):
|
||||
"""
|
||||
Cuts OneShotIterator -> IteratorGetNext pattern
|
||||
in order to enable Out Of the Box (OOB) usage.
|
||||
Pass is run only if user didn't specify any inputs names and shapes.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['cmd_params'].input is None]
|
||||
|
||||
def run_before(self):
|
||||
from extensions.front.output_cut import OutputCut
|
||||
from extensions.front.input_cut import InputCut
|
||||
return [OutputCut, InputCut]
|
||||
|
||||
def run_after(self):
|
||||
return []
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
iter_get_next_shapes = defaultdict(list)
|
||||
for iter_get_next in graph.get_op_nodes(op='IteratorGetNext'):
|
||||
iter_get_next_name = iter_get_next.soft_get('name', iter_get_next.id)
|
||||
for port in iter_get_next.out_ports():
|
||||
if not np_data_type_to_precision(iter_get_next.types[port]) in SUPPORTED_DATA_TYPES:
|
||||
raise Error("In IteratorGetNext node '{}' data type '{}' is not supported".format(
|
||||
iter_get_next_name, iter_get_next.types[port]))
|
||||
|
||||
iter_get_next_shapes[iter_get_next_name].append(dict(
|
||||
shape=iter_get_next.shapes[port],
|
||||
out=port,
|
||||
data_type=iter_get_next.types[port]
|
||||
))
|
||||
|
||||
add_input_ops(graph, iter_get_next_shapes, True)
|
@ -20,5 +20,5 @@ class IteratorGetNextExtractor(FrontExtractorOp):
|
||||
result_shapes = []
|
||||
for shape_pb in shapes:
|
||||
result_shapes.append(tf_tensor_shape(shape_pb))
|
||||
Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types})
|
||||
Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types, 'out_ports_count': 1})
|
||||
return cls.enabled
|
||||
|
@ -26,8 +26,8 @@ class IteratorGetNextAnalysisTest(unittest.TestCase):
|
||||
inputs_desc = {}
|
||||
message = InputsAnalysis.iterator_get_next_analysis(graph, inputs_desc)
|
||||
ref_message = 'It looks like there is IteratorGetNext as input\n' \
|
||||
'Run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n' \
|
||||
'And replace all negative values with positive values'
|
||||
'Run the Model Optimizer without --input option \n' \
|
||||
'Otherwise, try to run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n'
|
||||
self.assertEqual(message, ref_message)
|
||||
|
||||
def test_negative(self):
|
||||
|
@ -0,0 +1,95 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.tf.IteratorGetNextCut import IteratorGetNextCut
|
||||
from mo.front.common.partial_infer.utils import shape_array
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph_with_edge_attrs
|
||||
|
||||
|
||||
class IteratorGetNextAnalysisTest(unittest.TestCase):
|
||||
|
||||
def test_one_output(self):
|
||||
graph = build_graph_with_edge_attrs(
|
||||
{
|
||||
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]),
|
||||
'types': [np.int32]},
|
||||
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||
},
|
||||
[
|
||||
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
|
||||
]
|
||||
)
|
||||
|
||||
graph_ref = build_graph_with_edge_attrs(
|
||||
{
|
||||
'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'type': np.int32},
|
||||
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||
},
|
||||
[
|
||||
('parameter_1', 'sub', {'out': 0, 'in': 0}),
|
||||
]
|
||||
)
|
||||
|
||||
IteratorGetNextCut().find_and_replace_pattern(graph)
|
||||
|
||||
flag, msg = compare_graphs(graph, graph_ref, last_node='sub')
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_two_outputs(self):
|
||||
graph = build_graph_with_edge_attrs(
|
||||
{
|
||||
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': [shape_array([2, 2]),
|
||||
shape_array([1, 1])],
|
||||
'types': [np.int32, np.float32]},
|
||||
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||
'add': {'kind': 'op', 'op': 'Add'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'}
|
||||
},
|
||||
[
|
||||
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
|
||||
('iter_get_next', 'add', {'out': 1, 'in': 0}),
|
||||
('sub', 'concat', {'out': 0, 'in': 0}),
|
||||
('add', 'concat', {'out': 0, 'in': 1})
|
||||
]
|
||||
)
|
||||
|
||||
graph_ref = build_graph_with_edge_attrs(
|
||||
{
|
||||
'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'data_type': np.int32},
|
||||
'parameter_2': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([1, 1]), 'data_type': np.float32},
|
||||
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||
'add': {'kind': 'op', 'op': 'Add'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'}
|
||||
},
|
||||
[
|
||||
('parameter_1', 'sub', {'out': 0, 'in': 0}),
|
||||
('parameter_2', 'add', {'out': 0, 'in': 0}),
|
||||
('sub', 'concat', {'out': 0, 'in': 0}),
|
||||
('add', 'concat', {'out': 0, 'in': 1})
|
||||
]
|
||||
)
|
||||
|
||||
IteratorGetNextCut().find_and_replace_pattern(graph)
|
||||
|
||||
flag, msg = compare_graphs(graph, graph_ref, last_node='concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, msg)
|
||||
|
||||
def test_unsupported_data_type(self):
|
||||
graph = build_graph_with_edge_attrs(
|
||||
{
|
||||
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]),
|
||||
'types': [None]},
|
||||
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||
},
|
||||
[
|
||||
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
|
||||
]
|
||||
)
|
||||
|
||||
self.assertRaises(Error, IteratorGetNextCut().find_and_replace_pattern, graph)
|
Loading…
Reference in New Issue
Block a user