diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index dad585c2f43..79be4781543 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -433,6 +433,7 @@ extensions/front/tf/identityN_to_identity.py extensions/front/tf/if_ext.py extensions/front/tf/InterpolateTransposes.py extensions/front/tf/IteratorGetNext_ext.py +extensions/front/tf/IteratorGetNextCut.py extensions/front/tf/log_softmax_ext.py extensions/front/tf/LookupTableInsert_ext.py extensions/front/tf/LoopCond_ext.py diff --git a/model-optimizer/extensions/analysis/inputs.py b/model-optimizer/extensions/analysis/inputs.py index a3fee9ffd6c..887dee66a9a 100644 --- a/model-optimizer/extensions/analysis/inputs.py +++ b/model-optimizer/extensions/analysis/inputs.py @@ -56,22 +56,24 @@ class InputsAnalysis(AnalyzeAction): def iterator_get_next_analysis(cls, graph: Graph, inputs_desc: dict): message = None op_nodes = graph.get_op_nodes(op='IteratorGetNext') + + params = '' + for iter_get_next in op_nodes: + for port in iter_get_next.out_nodes().keys(): + inputs_desc['{}:{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port)] = { + 'shape': iter_get_next.shapes[port].tolist(), + 'value': None, + 'data_type': iter_get_next.types[port] + } + if params != '': + params = params + ',' + shape = str(iter_get_next.shapes[port].tolist()).replace(',', '') + params = params + '{}:{}{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port, shape) + if len(op_nodes): - params = '' - for iter_get_next in op_nodes: - for port in iter_get_next.out_nodes().keys(): - inputs_desc['{}:{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port)] = { - 'shape': iter_get_next.shapes[port].tolist(), - 'value': None, - 'data_type': iter_get_next.types[port] - } - if params != '': - params = params + ',' - shape = str(iter_get_next.shapes[port].tolist()).replace(',', '') - params = params + '{}:{}{}'.format(iter_get_next.soft_get('name', iter_get_next.id), port, shape) message = 'It looks like there is IteratorGetNext as input\n' \ - 'Run the Model Optimizer with:\n\t\t--input "{}"\n' \ - 'And replace all negative values with positive values'.format(params) + 'Run the Model Optimizer without --input option \n' \ + 'Otherwise, try to run the Model Optimizer with:\n\t\t--input "{}"\n'.format(params) return message def analyze(self, graph: Graph): diff --git a/model-optimizer/extensions/front/tf/IteratorGetNextCut.py b/model-optimizer/extensions/front/tf/IteratorGetNextCut.py new file mode 100644 index 00000000000..3cced57f658 --- /dev/null +++ b/model-optimizer/extensions/front/tf/IteratorGetNextCut.py @@ -0,0 +1,45 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict + +from mo.front.extractor import add_input_ops +from mo.graph.graph import Graph +from mo.middle.passes.convert_data_type import SUPPORTED_DATA_TYPES, np_data_type_to_precision +from mo.utils.error import Error +from mo.front.common.replacement import FrontReplacementPattern + + +class IteratorGetNextCut(FrontReplacementPattern): + """ + Cuts OneShotIterator -> IteratorGetNext pattern + in order to enable Out Of the Box (OOB) usage. + Pass is run only if user didn't specify any inputs names and shapes. + """ + enabled = True + graph_condition = [lambda graph: graph.graph['cmd_params'].input is None] + + def run_before(self): + from extensions.front.output_cut import OutputCut + from extensions.front.input_cut import InputCut + return [OutputCut, InputCut] + + def run_after(self): + return [] + + def find_and_replace_pattern(self, graph: Graph): + iter_get_next_shapes = defaultdict(list) + for iter_get_next in graph.get_op_nodes(op='IteratorGetNext'): + iter_get_next_name = iter_get_next.soft_get('name', iter_get_next.id) + for port in iter_get_next.out_ports(): + if not np_data_type_to_precision(iter_get_next.types[port]) in SUPPORTED_DATA_TYPES: + raise Error("In IteratorGetNext node '{}' data type '{}' is not supported".format( + iter_get_next_name, iter_get_next.types[port])) + + iter_get_next_shapes[iter_get_next_name].append(dict( + shape=iter_get_next.shapes[port], + out=port, + data_type=iter_get_next.types[port] + )) + + add_input_ops(graph, iter_get_next_shapes, True) diff --git a/model-optimizer/extensions/front/tf/IteratorGetNext_ext.py b/model-optimizer/extensions/front/tf/IteratorGetNext_ext.py index e7999164714..c9318c976e5 100644 --- a/model-optimizer/extensions/front/tf/IteratorGetNext_ext.py +++ b/model-optimizer/extensions/front/tf/IteratorGetNext_ext.py @@ -20,5 +20,5 @@ class IteratorGetNextExtractor(FrontExtractorOp): result_shapes = [] for shape_pb in shapes: result_shapes.append(tf_tensor_shape(shape_pb)) - Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types}) + Op.update_node_stat(node, {'shapes': result_shapes, 'types': extracted_types, 'out_ports_count': 1}) return cls.enabled diff --git a/model-optimizer/unit_tests/extensions/analysis/Iterator_get_next_test.py b/model-optimizer/unit_tests/extensions/analysis/Iterator_get_next_test.py index 801873bdbb9..05946988c0b 100644 --- a/model-optimizer/unit_tests/extensions/analysis/Iterator_get_next_test.py +++ b/model-optimizer/unit_tests/extensions/analysis/Iterator_get_next_test.py @@ -26,8 +26,8 @@ class IteratorGetNextAnalysisTest(unittest.TestCase): inputs_desc = {} message = InputsAnalysis.iterator_get_next_analysis(graph, inputs_desc) ref_message = 'It looks like there is IteratorGetNext as input\n' \ - 'Run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n' \ - 'And replace all negative values with positive values' + 'Run the Model Optimizer without --input option \n' \ + 'Otherwise, try to run the Model Optimizer with:\n\t\t--input "iter_get_next:0[2 2],iter_get_next:1[1 1]"\n' self.assertEqual(message, ref_message) def test_negative(self): diff --git a/model-optimizer/unit_tests/extensions/front/tf/IteratorGetNextCut_test.py b/model-optimizer/unit_tests/extensions/front/tf/IteratorGetNextCut_test.py new file mode 100644 index 00000000000..165e5189fc7 --- /dev/null +++ b/model-optimizer/unit_tests/extensions/front/tf/IteratorGetNextCut_test.py @@ -0,0 +1,95 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np + +from extensions.front.tf.IteratorGetNextCut import IteratorGetNextCut +from mo.front.common.partial_infer.utils import shape_array +from mo.utils.error import Error +from mo.utils.ir_engine.compare_graphs import compare_graphs +from unit_tests.utils.graph import build_graph_with_edge_attrs + + +class IteratorGetNextAnalysisTest(unittest.TestCase): + + def test_one_output(self): + graph = build_graph_with_edge_attrs( + { + 'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]), + 'types': [np.int32]}, + 'sub': {'kind': 'op', 'op': 'Sub'}, + }, + [ + ('iter_get_next', 'sub', {'out': 0, 'in': 0}), + ] + ) + + graph_ref = build_graph_with_edge_attrs( + { + 'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'type': np.int32}, + 'sub': {'kind': 'op', 'op': 'Sub'}, + }, + [ + ('parameter_1', 'sub', {'out': 0, 'in': 0}), + ] + ) + + IteratorGetNextCut().find_and_replace_pattern(graph) + + flag, msg = compare_graphs(graph, graph_ref, last_node='sub') + self.assertTrue(flag, msg) + + def test_two_outputs(self): + graph = build_graph_with_edge_attrs( + { + 'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': [shape_array([2, 2]), + shape_array([1, 1])], + 'types': [np.int32, np.float32]}, + 'sub': {'kind': 'op', 'op': 'Sub'}, + 'add': {'kind': 'op', 'op': 'Add'}, + 'concat': {'kind': 'op', 'op': 'Concat'} + }, + [ + ('iter_get_next', 'sub', {'out': 0, 'in': 0}), + ('iter_get_next', 'add', {'out': 1, 'in': 0}), + ('sub', 'concat', {'out': 0, 'in': 0}), + ('add', 'concat', {'out': 0, 'in': 1}) + ] + ) + + graph_ref = build_graph_with_edge_attrs( + { + 'parameter_1': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([2, 2]), 'data_type': np.int32}, + 'parameter_2': {'kind': 'op', 'op': 'Parameter', 'shape': shape_array([1, 1]), 'data_type': np.float32}, + 'sub': {'kind': 'op', 'op': 'Sub'}, + 'add': {'kind': 'op', 'op': 'Add'}, + 'concat': {'kind': 'op', 'op': 'Concat'} + }, + [ + ('parameter_1', 'sub', {'out': 0, 'in': 0}), + ('parameter_2', 'add', {'out': 0, 'in': 0}), + ('sub', 'concat', {'out': 0, 'in': 0}), + ('add', 'concat', {'out': 0, 'in': 1}) + ] + ) + + IteratorGetNextCut().find_and_replace_pattern(graph) + + flag, msg = compare_graphs(graph, graph_ref, last_node='concat', check_op_attrs=True) + self.assertTrue(flag, msg) + + def test_unsupported_data_type(self): + graph = build_graph_with_edge_attrs( + { + 'iter_get_next': {'kind': 'op', 'op': 'IteratorGetNext', 'shapes': shape_array([[2, 2]]), + 'types': [None]}, + 'sub': {'kind': 'op', 'op': 'Sub'}, + }, + [ + ('iter_get_next', 'sub', {'out': 0, 'in': 0}), + ] + ) + + self.assertRaises(Error, IteratorGetNextCut().find_and_replace_pattern, graph)