diff --git a/docs/MO_DG/prepare_model/convert_model/Convert_Model_From_TensorFlow.md b/docs/MO_DG/prepare_model/convert_model/Convert_Model_From_TensorFlow.md index d5124fab21b..20ebbf0c0ed 100644 --- a/docs/MO_DG/prepare_model/convert_model/Convert_Model_From_TensorFlow.md +++ b/docs/MO_DG/prepare_model/convert_model/Convert_Model_From_TensorFlow.md @@ -299,7 +299,9 @@ TensorFlow*-specific parameters: TensorFlow*: comma separated list of shared libraries with TensorFlow* custom operations implementation. --disable_nhwc_to_nchw - Disables default translation from NHWC to NCHW + [DEPRECATED] Disables default translation from NHWC to NCHW. Since 2022.1 + this option is deprecated and used only to maintain backward compatibility + with previous releases. ``` > **NOTE:** Models produces with TensorFlow\* usually have not fully defined shapes (contain `-1` in some dimensions). It is necessary to pass explicit shape for the input using command line parameter `--input_shape` or `-b` to override just batch dimension. If the shape is fully defined, then there is no need to specify either `-b` or `--input_shape` options. diff --git a/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_BERT_From_Tensorflow.md b/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_BERT_From_Tensorflow.md index ff141425813..0f68ae6e39e 100644 --- a/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_BERT_From_Tensorflow.md +++ b/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_BERT_From_Tensorflow.md @@ -35,7 +35,6 @@ To generate the BERT Intermediate Representation (IR) of the model, run the Mode python3 ./mo_tf.py --input_meta_graph uncased_L-12_H-768_A-12/bert_model.ckpt.meta \ --output bert/pooler/dense/Tanh \ ---disable_nhwc_to_nchw \ --input Placeholder{i32},Placeholder_1{i32},Placeholder_2{i32} ``` @@ -110,10 +109,9 @@ python3 run_classifier.py \ Run the Model Optimizer with the following command line parameters to generate reshape-able BERT Intermediate Representation (IR): ```sh -python3 ./mo_tf.py ---input_model inference_graph.pb ---input "IteratorGetNext:0{i32}[1 128],IteratorGetNext:1{i32}[1 128],IteratorGetNext:4{i32}[1 128]" ---disable_nhwc_to_nchw +python3 ./mo_tf.py \ + --input_model inference_graph.pb \ + --input "IteratorGetNext:0{i32}[1 128],IteratorGetNext:1{i32}[1 128],IteratorGetNext:4{i32}[1 128]" ``` For other applicable parameters, refer to [Convert Model from TensorFlow](../Convert_Model_From_TensorFlow.md). diff --git a/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_DeepSpeech_From_Tensorflow.md b/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_DeepSpeech_From_Tensorflow.md index 29df0e4695d..2de384fe858 100644 --- a/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_DeepSpeech_From_Tensorflow.md +++ b/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_DeepSpeech_From_Tensorflow.md @@ -71,8 +71,7 @@ To generate the IR, run the Model Optimizer with the following parameters: python3 {path_to_mo}/mo_tf.py \ --input_model output_graph.pb \ --input "input_lengths->[16],input_node[1 16 19 26],previous_state_h[1 2048],previous_state_c[1 2048]" \ ---output "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1,cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd,logits" \ ---disable_nhwc_to_nchw +--output "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1,cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd,logits" ``` Where: diff --git a/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_XLNet_From_Tensorflow.md b/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_XLNet_From_Tensorflow.md index ac706c664f2..e8b903b5193 100644 --- a/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_XLNet_From_Tensorflow.md +++ b/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_XLNet_From_Tensorflow.md @@ -186,6 +186,7 @@ The script should save into `~/XLNet-Large/xlnet`. To generate the XLNet Intermediate Representation (IR) of the model, run the Model Optimizer with the following parameters: ```sh -python3 mo.py --input_model path-to-model/model_frozen.pb --input "input_mask[50 1],input_ids[50 1],seg_ids[50 1]" --log_level DEBUG --disable_nhwc_to_nchw --output_dir +python3 mo.py --input_model path-to-model/model_frozen.pb \ + --input "input_mask[50 1],input_ids[50 1],seg_ids[50 1]" ``` diff --git a/docs/MO_DG/prepare_model/customize_model_optimizer/Customize_Model_Optimizer.md b/docs/MO_DG/prepare_model/customize_model_optimizer/Customize_Model_Optimizer.md index 567543a01a8..d133654b8df 100644 --- a/docs/MO_DG/prepare_model/customize_model_optimizer/Customize_Model_Optimizer.md +++ b/docs/MO_DG/prepare_model/customize_model_optimizer/Customize_Model_Optimizer.md @@ -285,10 +285,9 @@ More information on how to develop middle transformations and dedicated API desc ### NHWC to NCHW Layout Change There are several middle transformations responsible for changing model layout from NHWC to NCHW. These transformations are triggered by default for TensorFlow\* models only because it is the only framework with Convolution operations in -NHWC layout. - -> **NOTE**: If a TensorFlow\* model is in NCHW layout, you should specify the `--disable_nhwc_to_nchw` command line -> parameter to disable these transformations. +NHWC layout. This layout change is disabled if the model does not have operations that OpenVINO&trade needs to execute in +NCHW layout, for example, Convolutions in NHWC layout. It is still possible to force Model Optimizer to do layout change +using `--disable_nhwc_to_nchw` command-line parameter. The layout change is a complex problem and detailed explanation of it is out of this document scope. A very brief explanation of this process is provided below: diff --git a/model-optimizer/extensions/load/tf/loader.py b/model-optimizer/extensions/load/tf/loader.py index 46156e21612..3021cb656ff 100644 --- a/model-optimizer/extensions/load/tf/loader.py +++ b/model-optimizer/extensions/load/tf/loader.py @@ -23,6 +23,7 @@ from mo.front.extractor import restore_edges, extract_node_attrs, remove_control from mo.front.tf.extractor import get_tf_edges, create_tf_edge, tf_op_extractor, tf_op_extractors from mo.front.tf.loader import load_tf_graph_def, protobuf2nx from mo.graph.graph import Graph, Node +from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively from mo.utils import tensorboard_util from mo.utils.error import Error from mo.utils.telemetry_utils import send_op_names_info, send_shapes_info, send_framework_info @@ -103,10 +104,10 @@ class TFLoader(Loader): # try to detect layout from the nodes of the graph. If there are no convolution nodes in N(D)HWC layout then we # consider that the graph is in NCHW layout and no layout conversion should be performed - if not argv.disable_nhwc_to_nchw and not argv.silent and not graph_or_sub_graph_has_nhwc_ops(graph): - log.error('The TensorFlow model does not contain Convolution operations with N(D)HWC layout. Most likely ' - 'the model should be converted using additional "--disable_nhwc_to_nchw" command line parameter ' - 'which disables model layout conversion inside the Model Optimizer.', extra={'is_warning': True}) + if not argv.disable_nhwc_to_nchw and not graph_or_sub_graph_has_nhwc_ops(graph): + if not argv.silent: + log.debug('disable_nhwc_to_nchw" was automatically enabled.') + for_graph_and_each_sub_graph_recursively(graph, update_cmd_params_and_layout) send_op_names_info(framework, graph) send_shapes_info(framework, graph) @@ -143,3 +144,15 @@ def graph_or_sub_graph_has_nhwc_ops(graph: Graph): NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(node.soft_get(sub_graph_name)) return NHWC_conv_detected + + +def update_cmd_params_and_layout(graph: Graph): + """ + Updates "cmd_params" and "layout" attribute as the model has only NCHW layout operations. + :param graph: graph to update attributes + :return: Nones + """ + if 'cmd_params' in graph.graph: + graph.graph['cmd_params'].auto_disable_nhwc_to_nchw = True + if 'layout' in graph.graph: + graph.graph['layout'] = 'NCHW' diff --git a/model-optimizer/extensions/middle/PreserveRuntimeInfo.py b/model-optimizer/extensions/middle/PreserveRuntimeInfo.py index c48a8734aae..dd926da7217 100644 --- a/model-optimizer/extensions/middle/PreserveRuntimeInfo.py +++ b/model-optimizer/extensions/middle/PreserveRuntimeInfo.py @@ -5,8 +5,9 @@ import numpy as np from extensions.middle.MergeNodesPermutations import MergeNodesPermutations from extensions.ops.transpose import Transpose +from mo.front.common.partial_infer.utils import int64_array from mo.front.tf.graph_utils import create_op_node_with_second_input -from mo.graph.graph import Graph +from mo.graph.graph import Graph, Node from mo.middle.replacement import MiddleReplacementPattern from mo.utils.runtime_info import OldAPIMapOrder @@ -39,30 +40,41 @@ class PreserveRuntimeInfo(MiddleReplacementPattern): def find_and_replace_pattern(self, graph: Graph): self.preserve_rt_info(graph) + @staticmethod + def add_old_api_map_order_into_rt_info(op: Node): + # rt info update + assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op) + + old_api_map = OldAPIMapOrder(version=0) + attr_name = old_api_map.get_name() + if (attr_name, old_api_map.get_version()) not in op.rt_info.info: + op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map + return attr_name, old_api_map.get_version() + @staticmethod def preserve_rt_info(graph: Graph): - for op in graph.get_op_nodes(): + for op in graph.get_op_nodes(type='Parameter'): op_name = op.soft_get('name', op.id) - op_type = op.soft_get('type') - if op_type == 'Parameter' and op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout'): - if not op.out_node(0).has_valid('permutation'): + if 'auto_disable_nhwc_to_nchw' in graph.graph['cmd_params'] and \ + graph.graph['cmd_params'].auto_disable_nhwc_to_nchw: + rank = op.out_port(0).data.get_shape().size + if rank < 4: continue + order = list(range(rank)) + order.remove(1) + order.append(1) + order = int64_array(order) + elif op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout') and \ + op.out_node(0).has_valid('permutation'): permutation = op.out_node(0).permutation - if np.array_equal(permutation.inv, range(len(permutation.inv))): + order = permutation.inv + if np.array_equal(order, range(len(permutation.inv))): continue - # rt info update - assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op_name) - - old_api_map = OldAPIMapOrder(version=0) - attr_name = old_api_map.get_name() - if (attr_name, old_api_map.get_version()) not in op.rt_info.info: - op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map - op.rt_info.info[(attr_name, old_api_map.get_version())].old_api_transpose_parameter(permutation.inv) - # keep input in the framework format transpose = create_op_node_with_second_input( - graph, Transpose, permutation.perm, {'name': op_name + '/Transpose({})'.format(permutation.perm)}) + graph, Transpose, permutation.perm, + {'name': op_name + '/Transpose({})'.format(permutation.perm)}) # source mode is used to keep tensor names at Parameter node op.out_port(0).get_connection().insert_node(transpose, "source") @@ -71,25 +83,34 @@ class PreserveRuntimeInfo(MiddleReplacementPattern): del op['permute_attrs'] if op.out_node(0).has_valid('permutation'): del op.out_node(0)['permutation'] + else: + continue - elif op_type == 'Result' and op.in_ports(): + rt_info_key = PreserveRuntimeInfo.add_old_api_map_order_into_rt_info(op) + op.rt_info.info[rt_info_key].old_api_transpose_parameter(order) + + for op in graph.get_op_nodes(type='Result'): + if op.in_ports(): prev_node_out_port = op.in_port(0).get_connection().get_source() if prev_node_out_port is None: continue in_node = prev_node_out_port.node in_data_node = in_node.out_node(prev_node_out_port.idx) - if in_data_node.has_and_set('permutation'): - permutation = in_data_node['permutation'] - if np.array_equal(permutation.perm, range(len(permutation.perm))): - continue - # rt info update - assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op) - old_api_map = OldAPIMapOrder(version=0) - attr_name = old_api_map.get_name() - if (attr_name, old_api_map.get_version()) not in op.rt_info.info: - op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map - op.rt_info.info[(attr_name, old_api_map.get_version())].old_api_transpose_result(permutation.perm) + if 'auto_disable_nhwc_to_nchw' in graph.graph['cmd_params'] and \ + graph.graph['cmd_params'].auto_disable_nhwc_to_nchw: + rank = prev_node_out_port.data.get_shape().size + if rank < 4: + continue + order = list(range(rank - 1)) + order.insert(1, rank - 1) + order = int64_array(order) + elif in_data_node.has_and_set('permutation'): + permutation = in_data_node['permutation'] + order = permutation.perm + + if np.array_equal(order, range(len(permutation.perm))): + continue # keep result in the framework format transpose = create_op_node_with_second_input(graph, Transpose, permutation.inv) @@ -98,3 +119,8 @@ class PreserveRuntimeInfo(MiddleReplacementPattern): in_node.name += "/prev" prev_node_out_port.get_connection().insert_node(transpose) + else: + continue + + rt_info_key = PreserveRuntimeInfo.add_old_api_map_order_into_rt_info(op) + op.rt_info.info[rt_info_key].old_api_transpose_result(order) diff --git a/model-optimizer/mo/utils/cli_parser.py b/model-optimizer/mo/utils/cli_parser.py index 0482c865a02..6b2dd86360c 100644 --- a/model-optimizer/mo/utils/cli_parser.py +++ b/model-optimizer/mo/utils/cli_parser.py @@ -564,7 +564,8 @@ def get_tf_cli_parser(parser: argparse.ArgumentParser = None): default=None, action=CanonicalizePathCheckExistenceAction) tf_group.add_argument('--disable_nhwc_to_nchw', - help='Disables default translation from NHWC to NCHW', + help='[DEPRECATED] Disables the default translation from NHWC to NCHW. Since 2022.1 this option ' + 'is deprecated and used only to maintain backward compatibility with previous releases.', action='store_true') return parser diff --git a/model-optimizer/unit_tests/extensions/middle/PreserveRuntimeInfo_test.py b/model-optimizer/unit_tests/extensions/middle/PreserveRuntimeInfo_test.py index 305855e33c8..a8b85342f26 100644 --- a/model-optimizer/unit_tests/extensions/middle/PreserveRuntimeInfo_test.py +++ b/model-optimizer/unit_tests/extensions/middle/PreserveRuntimeInfo_test.py @@ -88,3 +88,37 @@ class PreserveRuntimeInfoTest(unittest.TestCase): rt_info = result_node.rt_info.info old_api_map = rt_info[('old_api_map_order', 0)].info self.assertTrue(np.array_equal(old_api_map['order'], nhwc_to_nchw_order)) + + def test_auto_disable_nhwc_to_nchw(self): + shape_len = 4 + shape = np.array(range(shape_len)) + add_shape = shape + graph_nodes = { + **regular_op_with_shaped_data('placeholder1', shape, + {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}), + **regular_op_with_shaped_data('placeholder2', shape, + {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}), + **regular_op_with_shaped_data('result', shape, {'type': 'Result', 'rt_info': RTInfo(), 'shape': shape}), + **regular_op_with_shaped_data('add', add_shape, + {'type': 'Add', 'op': 'Add', 'infer': copy_shape_infer}), + } + + graph = build_graph(graph_nodes, edges) + graph.graph['cmd_params'].auto_disable_nhwc_to_nchw = True + graph_ref = build_graph(graph_nodes, edges) + + param_node = Node(graph, 'placeholder1') + result_node = Node(graph, 'result') + + PreserveRuntimeInfo().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + rt_info = param_node.rt_info.info + old_api_map = rt_info[('old_api_map_order', 0)].info + self.assertTrue(np.array_equal(old_api_map['inverse_order'], [0, 2, 3, 1])) + + rt_info = result_node.rt_info.info + old_api_map = rt_info[('old_api_map_order', 0)].info + self.assertTrue(np.array_equal(old_api_map['order'], [0, 3, 1, 2]))