[MO] add IteratorGetNextCut (#8040)
* added IteratorGetNextCut, some improvements in graph.py * added allowed types check * reused new graph API for ports * returned back old API 'out_nodes', removed soft-getting name from base class, changed run_after -> [] * correctly used new port API * corrected IteratorGetNext message
This commit is contained in:
parent
f1ca728ab1
commit
68badf5165
@ -433,6 +433,7 @@ extensions/front/tf/identityN_to_identity.py
|
|||||||
extensions/front/tf/if_ext.py
|
extensions/front/tf/if_ext.py
|
||||||
extensions/front/tf/InterpolateTransposes.py
|
extensions/front/tf/InterpolateTransposes.py
|
||||||
extensions/front/tf/IteratorGetNext_ext.py
|
extensions/front/tf/IteratorGetNext_ext.py
|
||||||
|
extensions/front/tf/IteratorGetNextCut.py
|
||||||
extensions/front/tf/log_softmax_ext.py
|
extensions/front/tf/log_softmax_ext.py
|
||||||
extensions/front/tf/LookupTableInsert_ext.py
|
extensions/front/tf/LookupTableInsert_ext.py
|
||||||
extensions/front/tf/LoopCond_ext.py
|
extensions/front/tf/LoopCond_ext.py
|
||||||
|
@ -56,22 +56,24 @@ class InputsAnalysis(AnalyzeAction):
|
|||||||
def iterator_get_next_analysis(cls, graph: Graph, inputs_desc: dict):
|
def iterator_get_next_analysis(cls, graph: Graph, inputs_desc: dict):
|
||||||
message = None
|
message = None
|
||||||
op_nodes = graph.get_op_nodes(op='IteratorGetNext')
|
op_nodes = graph.get_op_nodes(op='IteratorGetNext')
|
||||||
|
|
||||||
|
params = ''
|
||||||
|
for iter_get_next in op_nodes:
|
||||||
|
for port in iter_get_next.out_nodes().keys():
|
||||||
|
inputs_desc['{}:{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port)] = {
|
||||||
|
'shape': iter_get_next.shapes[port].tolist(),
|
||||||
|
'value': None,
|
||||||
|
'data_type': iter_get_next.types[port]
|
||||||
|
}
|
||||||
|
if params != '':
|
||||||
|
params = params + ','
|
||||||
|
shape = str(iter_get_next.shapes[port].tolist()).replace(',', '')
|
||||||
|
params = params + '{}:{}{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port, shape)
|
||||||
|
|
||||||
if len(op_nodes):
|
if len(op_nodes):
|
||||||
params = ''
|
|
||||||
for iter_get_next in op_nodes:
|
|
||||||
for port in iter_get_next.out_nodes().keys():
|
|
||||||
inputs_desc['{}:{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port)] = {
|
|
||||||
'shape': iter_get_next.shapes[port].tolist(),
|
|
||||||
'value': None,
|
|
||||||
'data_type': iter_get_next.types[port]
|
|
||||||
}
|
|
||||||
if params != '':
|
|
||||||
params = params + ','
|
|
||||||
shape = str(iter_get_next.shapes[port].tolist()).replace(',', '')
|
|
||||||
params = params + '{}:{}{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port, shape)
|
|
||||||
message = 'It looks like there is IteratorGetNext as input\n' \
|
message = 'It looks like there is IteratorGetNext as input\n' \
|
||||||
'Run the Model Optimizer with:\n\t\t--input "{}"\n' \
|
'Run the Model Optimizer without --input option \n' \
|
||||||
'And replace all negative values with positive values'.format(params)
|
'Otherwise, try to run the Model Optimizer with:\n\t\t--input "{}"\n'.format(params)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def analyze(self, graph: Graph):
|
def analyze(self, graph: Graph):
|
||||||
|
45
model-optimizer/extensions/front/tf/IteratorGetNextCut.py
Normal file
45
model-optimizer/extensions/front/tf/IteratorGetNextCut.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from mo.front.extractor import add_input_ops
|
||||||
|
from mo.graph.graph import Graph
|
||||||
|
from mo.middle.passes.convert_data_type import SUPPORTED_DATA_TYPES, np_data_type_to_precision
|
||||||
|
from mo.utils.error import Error
|
||||||
|
from mo.front.common.replacement import FrontReplacementPattern
|
||||||
|
|
||||||
|
|
||||||
|
class IteratorGetNextCut(FrontReplacementPattern):
|
||||||
|
"""
|
||||||
|
Cuts OneShotIterator -> IteratorGetNext pattern
|
||||||
|
in order to enable Out Of the Box (OOB) usage.
|
||||||
|
Pass is run only if user didn't specify any inputs names and shapes.
|
||||||
|
"""
|
||||||
|
enabled = True
|
||||||
|
graph_condition = [lambda graph: graph.graph['cmd_params'].input is None]
|
||||||
|
|
||||||
|
def run_before(self):
|
||||||
|
from extensions.front.output_cut import OutputCut
|
||||||
|
from extensions.front.input_cut import InputCut
|
||||||
|
return [OutputCut, InputCut]
|
||||||
|
|
||||||
|
def run_after(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
iter_get_next_shapes = defaultdict(list)
|
||||||
|
for iter_get_next in graph.get_op_nodes(op='IteratorGetNext'):
|
||||||
|
iter_get_next_name = iter_get_next.soft_get('name', iter_get_next.id)
|
||||||
|
for port in iter_get_next.out_ports():
|
||||||
|
if not np_data_type_to_precision(iter_get_next.types[port]) in SUPPORTED_DATA_TYPES:
|
||||||
|
raise Error("In IteratorGetNext node '{}' data type '{}' is not supported".format(
|
||||||
|
iter_get_next_name, iter_get_next.types[port]))
|
||||||
|
|
||||||
|
iter_get_next_shapes[iter_get_next_name].append(dict(
|
||||||
|
shape=iter_get_next.shapes[port],
|
||||||
|
out=port,
|
||||||
|
data_type=iter_get_next.types[port]
|
||||||
|
))
|
||||||
|
|
||||||
|
add_input_ops(graph, iter_get_next_shapes, True)
|
@ -20,5 +20,5 @@ class IteratorGetNextExtractor(FrontExtractorOp):
|
|||||||
result_shapes = []
|
result_shapes = []
|
||||||
for shape_pb in shapes:
|
for shape_pb in shapes:
|
||||||
result_shapes.append(tf_tensor_shape(shape_pb))
|
result_shapes.append(tf_tensor_shape(shape_pb))
|
||||||
Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types})
|
Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types, 'out_ports_count': 1})
|
||||||
return cls.enabled
|
return cls.enabled
|
||||||
|
@ -26,8 +26,8 @@ class IteratorGetNextAnalysisTest(unittest.TestCase):
|
|||||||
inputs_desc = {}
|
inputs_desc = {}
|
||||||
message = InputsAnalysis.iterator_get_next_analysis(graph, inputs_desc)
|
message = InputsAnalysis.iterator_get_next_analysis(graph, inputs_desc)
|
||||||
ref_message = 'It looks like there is IteratorGetNext as input\n' \
|
ref_message = 'It looks like there is IteratorGetNext as input\n' \
|
||||||
'Run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n' \
|
'Run the Model Optimizer without --input option \n' \
|
||||||
'And replace all negative values with positive values'
|
'Otherwise, try to run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n'
|
||||||
self.assertEqual(message, ref_message)
|
self.assertEqual(message, ref_message)
|
||||||
|
|
||||||
def test_negative(self):
|
def test_negative(self):
|
||||||
|
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from extensions.front.tf.IteratorGetNextCut import IteratorGetNextCut
|
||||||
|
from mo.front.common.partial_infer.utils import shape_array
|
||||||
|
from mo.utils.error import Error
|
||||||
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from unit_tests.utils.graph import build_graph_with_edge_attrs
|
||||||
|
|
||||||
|
|
||||||
|
class IteratorGetNextAnalysisTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_one_output(self):
|
||||||
|
graph = build_graph_with_edge_attrs(
|
||||||
|
{
|
||||||
|
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]),
|
||||||
|
'types': [np.int32]},
|
||||||
|
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||||
|
},
|
||||||
|
[
|
||||||
|
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_ref = build_graph_with_edge_attrs(
|
||||||
|
{
|
||||||
|
'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'type': np.int32},
|
||||||
|
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||||
|
},
|
||||||
|
[
|
||||||
|
('parameter_1', 'sub', {'out': 0, 'in': 0}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
IteratorGetNextCut().find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
flag, msg = compare_graphs(graph, graph_ref, last_node='sub')
|
||||||
|
self.assertTrue(flag, msg)
|
||||||
|
|
||||||
|
def test_two_outputs(self):
|
||||||
|
graph = build_graph_with_edge_attrs(
|
||||||
|
{
|
||||||
|
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': [shape_array([2, 2]),
|
||||||
|
shape_array([1, 1])],
|
||||||
|
'types': [np.int32, np.float32]},
|
||||||
|
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||||
|
'add': {'kind': 'op', 'op': 'Add'},
|
||||||
|
'concat': {'kind': 'op', 'op': 'Concat'}
|
||||||
|
},
|
||||||
|
[
|
||||||
|
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
|
||||||
|
('iter_get_next', 'add', {'out': 1, 'in': 0}),
|
||||||
|
('sub', 'concat', {'out': 0, 'in': 0}),
|
||||||
|
('add', 'concat', {'out': 0, 'in': 1})
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_ref = build_graph_with_edge_attrs(
|
||||||
|
{
|
||||||
|
'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'data_type': np.int32},
|
||||||
|
'parameter_2': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([1, 1]), 'data_type': np.float32},
|
||||||
|
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||||
|
'add': {'kind': 'op', 'op': 'Add'},
|
||||||
|
'concat': {'kind': 'op', 'op': 'Concat'}
|
||||||
|
},
|
||||||
|
[
|
||||||
|
('parameter_1', 'sub', {'out': 0, 'in': 0}),
|
||||||
|
('parameter_2', 'add', {'out': 0, 'in': 0}),
|
||||||
|
('sub', 'concat', {'out': 0, 'in': 0}),
|
||||||
|
('add', 'concat', {'out': 0, 'in': 1})
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
IteratorGetNextCut().find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
flag, msg = compare_graphs(graph, graph_ref, last_node='concat', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, msg)
|
||||||
|
|
||||||
|
def test_unsupported_data_type(self):
|
||||||
|
graph = build_graph_with_edge_attrs(
|
||||||
|
{
|
||||||
|
'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]),
|
||||||
|
'types': [None]},
|
||||||
|
'sub': {'kind': 'op', 'op': 'Sub'},
|
||||||
|
},
|
||||||
|
[
|
||||||
|
('iter_get_next', 'sub', {'out': 0, 'in': 0}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertRaises(Error, IteratorGetNextCut().find_and_replace_pattern, graph)
|
Loading…
Reference in New Issue
Block a user