Automatically detect --disable_nhwc_to_nchw option in MO (#8450)
Clean code
Fix Result runtime info
Fix documentation
Fix documentation
Apply suggestions from code review
Co-authored-by: Tatiana Savina <tatiana.savina@intel.com>
Update docs/MO_DG/prepare_model/customize_model_optimizer/Customize_Model_Optimizer.md
Co-authored-by: Tatiana Savina <tatiana.savina@intel.com>
Apply review feedback
Apply review feedback
Apply review feedback
parent d24a48901e
commit e4b5c54006
@@ -299,7 +299,9 @@ TensorFlow*-specific parameters:
                         TensorFlow*: comma separated list of shared libraries
                         with TensorFlow* custom operations implementation.
   --disable_nhwc_to_nchw
-                        Disables default translation from NHWC to NCHW
+                        [DEPRECATED] Disables default translation from NHWC to NCHW. Since 2022.1
+                        this option is deprecated and used only to maintain backward compatibility
+                        with previous releases.
 ```
 
 > **NOTE:** Models produced with TensorFlow\* usually do not have fully defined shapes (they contain `-1` in some dimensions). It is necessary to pass an explicit shape for the input using the command line parameter `--input_shape`, or `-b` to override just the batch dimension. If the shape is fully defined, there is no need to specify either the `-b` or the `--input_shape` option.
@@ -35,7 +35,6 @@ To generate the BERT Intermediate Representation (IR) of the model, run the Mode
 python3 ./mo_tf.py
 --input_meta_graph uncased_L-12_H-768_A-12/bert_model.ckpt.meta \
 --output bert/pooler/dense/Tanh \
---disable_nhwc_to_nchw \
 --input Placeholder{i32},Placeholder_1{i32},Placeholder_2{i32}
 ```
@@ -110,10 +109,9 @@ python3 run_classifier.py \
 
 Run the Model Optimizer with the following command line parameters to generate reshape-able BERT Intermediate Representation (IR):
 ```sh
-python3 ./mo_tf.py
---input_model inference_graph.pb
---input "IteratorGetNext:0{i32}[1 128],IteratorGetNext:1{i32}[1 128],IteratorGetNext:4{i32}[1 128]"
---disable_nhwc_to_nchw
+python3 ./mo_tf.py \
+--input_model inference_graph.pb \
+--input "IteratorGetNext:0{i32}[1 128],IteratorGetNext:1{i32}[1 128],IteratorGetNext:4{i32}[1 128]"
 ```
 For other applicable parameters, refer to [Convert Model from TensorFlow](../Convert_Model_From_TensorFlow.md).
@@ -71,8 +71,7 @@ To generate the IR, run the Model Optimizer with the following parameters:
 python3 {path_to_mo}/mo_tf.py \
 --input_model output_graph.pb \
 --input "input_lengths->[16],input_node[1 16 19 26],previous_state_h[1 2048],previous_state_c[1 2048]" \
---output "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1,cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd,logits" \
---disable_nhwc_to_nchw
+--output "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1,cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd,logits"
 ```
 
 Where:
@@ -186,6 +186,7 @@ The script should save into `~/XLNet-Large/xlnet`.
 
 To generate the XLNet Intermediate Representation (IR) of the model, run the Model Optimizer with the following parameters:
 ```sh
-python3 mo.py --input_model path-to-model/model_frozen.pb --input "input_mask[50 1],input_ids[50 1],seg_ids[50 1]" --log_level DEBUG --disable_nhwc_to_nchw --output_dir <OUTPUT_MODEL_DIR>
+python3 mo.py --input_model path-to-model/model_frozen.pb \
+--input "input_mask[50 1],input_ids[50 1],seg_ids[50 1]"
 ```
@@ -285,10 +285,9 @@ More information on how to develop middle transformations and dedicated API desc
 ### NHWC to NCHW Layout Change <a name="layout-change"></a>
 There are several middle transformations responsible for changing model layout from NHWC to NCHW. These transformations
 are triggered by default for TensorFlow\* models only because it is the only framework with Convolution operations in
-NHWC layout.
-
-> **NOTE**: If a TensorFlow\* model is in NCHW layout, you should specify the `--disable_nhwc_to_nchw` command line
-> parameter to disable these transformations.
+NHWC layout. This layout change is disabled if the model does not have operations that OpenVINO&trade; needs to execute in
+NCHW layout, for example, Convolutions in NHWC layout. It is still possible to disable the layout change explicitly with
+the `--disable_nhwc_to_nchw` command-line parameter.
 
 The layout change is a complex problem and a detailed explanation of it is out of the scope of this document. A very brief
 explanation of this process is provided below:
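To make the permutation behind this layout change concrete, here is a minimal standalone Python sketch (not taken from the Model Optimizer sources) showing how a rank-4 NHWC shape maps to NCHW and back:

```python
import numpy as np

# NHWC -> NCHW: move the channel axis (the last one) right after the batch axis.
nhwc_shape = np.array([1, 224, 224, 3])   # N, H, W, C
nhwc_to_nchw = [0, 3, 1, 2]
nchw_shape = nhwc_shape[nhwc_to_nchw]     # [1, 3, 224, 224]

# NCHW -> NHWC is the inverse permutation.
nchw_to_nhwc = [0, 2, 3, 1]
print(nchw_shape, nchw_shape[nchw_to_nhwc])  # [1 3 224 224] and back to [1 224 224 3]
```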
@@ -23,6 +23,7 @@ from mo.front.extractor import restore_edges, extract_node_attrs, remove_control
 from mo.front.tf.extractor import get_tf_edges, create_tf_edge, tf_op_extractor, tf_op_extractors
 from mo.front.tf.loader import load_tf_graph_def, protobuf2nx
 from mo.graph.graph import Graph, Node
+from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
 from mo.utils import tensorboard_util
 from mo.utils.error import Error
 from mo.utils.telemetry_utils import send_op_names_info, send_shapes_info, send_framework_info
@@ -103,10 +104,10 @@ class TFLoader(Loader):
 
         # try to detect layout from the nodes of the graph. If there are no convolution nodes in N(D)HWC layout then we
         # consider that the graph is in NCHW layout and no layout conversion should be performed
-        if not argv.disable_nhwc_to_nchw and not argv.silent and not graph_or_sub_graph_has_nhwc_ops(graph):
-            log.error('The TensorFlow model does not contain Convolution operations with N(D)HWC layout. Most likely '
-                      'the model should be converted using additional "--disable_nhwc_to_nchw" command line parameter '
-                      'which disables model layout conversion inside the Model Optimizer.', extra={'is_warning': True})
+        if not argv.disable_nhwc_to_nchw and not graph_or_sub_graph_has_nhwc_ops(graph):
+            if not argv.silent:
+                log.debug('"disable_nhwc_to_nchw" was automatically enabled.')
+            for_graph_and_each_sub_graph_recursively(graph, update_cmd_params_and_layout)
 
         send_op_names_info(framework, graph)
         send_shapes_info(framework, graph)
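As a rough, self-contained illustration of the detection idea in the hunk above (the node representation and helper below are hypothetical, not the Model Optimizer graph API):

```python
# Hypothetical flat node list standing in for a TensorFlow graph.
NHWC_LAYOUTS = ('NHWC', 'NDHWC')
CONV_LIKE_OPS = {'Conv2D', 'Conv3D', 'DepthwiseConv2dNative'}

def has_nhwc_convolutions(nodes) -> bool:
    """Return True if any convolution-like node is declared in an N(D)HWC layout."""
    return any(
        node.get('op') in CONV_LIKE_OPS and node.get('data_format') in NHWC_LAYOUTS
        for node in nodes
    )

nodes = [
    {'op': 'Placeholder'},
    {'op': 'Conv2D', 'data_format': 'NCHW'},  # already channel-first
    {'op': 'Relu'},
]
# No NHWC convolutions found, so a converter could skip the NHWC->NCHW layout change
# and treat the whole graph as NCHW, mirroring the automatic behaviour above.
print(has_nhwc_convolutions(nodes))  # False
```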
@@ -143,3 +144,15 @@ def graph_or_sub_graph_has_nhwc_ops(graph: Graph):
             NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(node.soft_get(sub_graph_name))
 
     return NHWC_conv_detected
+
+
+def update_cmd_params_and_layout(graph: Graph):
+    """
+    Updates the "cmd_params" and "layout" attributes as the model has only NCHW layout operations.
+    :param graph: graph to update attributes
+    :return: None
+    """
+    if 'cmd_params' in graph.graph:
+        graph.graph['cmd_params'].auto_disable_nhwc_to_nchw = True
+    if 'layout' in graph.graph:
+        graph.graph['layout'] = 'NCHW'
@@ -5,8 +5,9 @@ import numpy as np
 
 from extensions.middle.MergeNodesPermutations import MergeNodesPermutations
 from extensions.ops.transpose import Transpose
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.tf.graph_utils import create_op_node_with_second_input
-from mo.graph.graph import Graph
+from mo.graph.graph import Graph, Node
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.utils.runtime_info import OldAPIMapOrder
@@ -39,30 +40,41 @@ class PreserveRuntimeInfo(MiddleReplacementPattern):
def find_and_replace_pattern(self, graph: Graph):
self.preserve_rt_info(graph)

@staticmethod
def add_old_api_map_order_into_rt_info(op: Node):
# rt info update
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op)

old_api_map = OldAPIMapOrder(version=0)
attr_name = old_api_map.get_name()
if (attr_name, old_api_map.get_version()) not in op.rt_info.info:
op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map
return attr_name, old_api_map.get_version()

@staticmethod
def preserve_rt_info(graph: Graph):
for op in graph.get_op_nodes():
for op in graph.get_op_nodes(type='Parameter'):
op_name = op.soft_get('name', op.id)
op_type = op.soft_get('type')
if op_type == 'Parameter' and op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout'):
if not op.out_node(0).has_valid('permutation'):
if 'auto_disable_nhwc_to_nchw' in graph.graph['cmd_params'] and \
graph.graph['cmd_params'].auto_disable_nhwc_to_nchw:
rank = op.out_port(0).data.get_shape().size
if rank < 4:
continue
order = list(range(rank))
order.remove(1)
order.append(1)
order = int64_array(order)
elif op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout') and \
op.out_node(0).has_valid('permutation'):
permutation = op.out_node(0).permutation
if np.array_equal(permutation.inv, range(len(permutation.inv))):
order = permutation.inv
if np.array_equal(order, range(len(permutation.inv))):
continue

# rt info update
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op_name)

old_api_map = OldAPIMapOrder(version=0)
attr_name = old_api_map.get_name()
if (attr_name, old_api_map.get_version()) not in op.rt_info.info:
op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map
op.rt_info.info[(attr_name, old_api_map.get_version())].old_api_transpose_parameter(permutation.inv)

# keep input in the framework format
transpose = create_op_node_with_second_input(
graph, Transpose, permutation.perm, {'name': op_name + '/Transpose({})'.format(permutation.perm)})
graph, Transpose, permutation.perm,
{'name': op_name + '/Transpose({})'.format(permutation.perm)})

# source mode is used to keep tensor names at Parameter node
op.out_port(0).get_connection().insert_node(transpose, "source")
@@ -71,25 +83,34 @@ class PreserveRuntimeInfo(MiddleReplacementPattern):
del op['permute_attrs']
if op.out_node(0).has_valid('permutation'):
del op.out_node(0)['permutation']
else:
continue

elif op_type == 'Result' and op.in_ports():
rt_info_key = PreserveRuntimeInfo.add_old_api_map_order_into_rt_info(op)
op.rt_info.info[rt_info_key].old_api_transpose_parameter(order)

for op in graph.get_op_nodes(type='Result'):
if op.in_ports():
prev_node_out_port = op.in_port(0).get_connection().get_source()
if prev_node_out_port is None:
continue
in_node = prev_node_out_port.node
in_data_node = in_node.out_node(prev_node_out_port.idx)
if in_data_node.has_and_set('permutation'):
permutation = in_data_node['permutation']
if np.array_equal(permutation.perm, range(len(permutation.perm))):
continue

# rt info update
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op)
old_api_map = OldAPIMapOrder(version=0)
attr_name = old_api_map.get_name()
if (attr_name, old_api_map.get_version()) not in op.rt_info.info:
op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map
op.rt_info.info[(attr_name, old_api_map.get_version())].old_api_transpose_result(permutation.perm)
if 'auto_disable_nhwc_to_nchw' in graph.graph['cmd_params'] and \
graph.graph['cmd_params'].auto_disable_nhwc_to_nchw:
rank = prev_node_out_port.data.get_shape().size
if rank < 4:
continue
order = list(range(rank - 1))
order.insert(1, rank - 1)
order = int64_array(order)
elif in_data_node.has_and_set('permutation'):
permutation = in_data_node['permutation']
order = permutation.perm

if np.array_equal(order, range(len(permutation.perm))):
continue

# keep result in the framework format
transpose = create_op_node_with_second_input(graph, Transpose, permutation.inv)
@@ -98,3 +119,8 @@ class PreserveRuntimeInfo(MiddleReplacementPattern):
in_node.name += "/prev"

prev_node_out_port.get_connection().insert_node(transpose)
else:
continue

rt_info_key = PreserveRuntimeInfo.add_old_api_map_order_into_rt_info(op)
op.rt_info.info[rt_info_key].old_api_transpose_result(order)
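As a sanity check on the two order computations above (a standalone sketch, not part of the transformation code): for a rank-4 tensor, the Parameter branch yields [0, 2, 3, 1] and the Result branch yields [0, 3, 1, 2], which are mutually inverse permutations, matching the values the new unit test below expects.

```python
import numpy as np

rank = 4

# Parameter branch: move the channel axis (index 1 in NCHW) to the end -> [0, 2, 3, 1].
param_order = list(range(rank))
param_order.remove(1)
param_order.append(1)

# Result branch: move the last axis back to position 1 -> [0, 3, 1, 2].
result_order = list(range(rank - 1))
result_order.insert(1, rank - 1)

# Composing one with the other gives the identity, i.e. they are inverse permutations.
print(param_order, result_order)
print(np.array_equal(np.array(param_order)[result_order], np.arange(rank)))  # True
```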
@@ -564,7 +564,8 @@ def get_tf_cli_parser(parser: argparse.ArgumentParser = None):
                           default=None,
                           action=CanonicalizePathCheckExistenceAction)
     tf_group.add_argument('--disable_nhwc_to_nchw',
-                          help='Disables default translation from NHWC to NCHW',
+                          help='[DEPRECATED] Disables the default translation from NHWC to NCHW. Since 2022.1 this option '
+                               'is deprecated and used only to maintain backward compatibility with previous releases.',
                           action='store_true')
     return parser
@@ -88,3 +88,37 @@ class PreserveRuntimeInfoTest(unittest.TestCase):
         rt_info = result_node.rt_info.info
         old_api_map = rt_info[('old_api_map_order', 0)].info
         self.assertTrue(np.array_equal(old_api_map['order'], nhwc_to_nchw_order))
+
+    def test_auto_disable_nhwc_to_nchw(self):
+        shape_len = 4
+        shape = np.array(range(shape_len))
+        add_shape = shape
+        graph_nodes = {
+            **regular_op_with_shaped_data('placeholder1', shape,
+                                          {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}),
+            **regular_op_with_shaped_data('placeholder2', shape,
+                                          {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}),
+            **regular_op_with_shaped_data('result', shape, {'type': 'Result', 'rt_info': RTInfo(), 'shape': shape}),
+            **regular_op_with_shaped_data('add', add_shape,
+                                          {'type': 'Add', 'op': 'Add', 'infer': copy_shape_infer}),
+        }
+
+        graph = build_graph(graph_nodes, edges)
+        graph.graph['cmd_params'].auto_disable_nhwc_to_nchw = True
+        graph_ref = build_graph(graph_nodes, edges)
+
+        param_node = Node(graph, 'placeholder1')
+        result_node = Node(graph, 'result')
+
+        PreserveRuntimeInfo().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+        rt_info = param_node.rt_info.info
+        old_api_map = rt_info[('old_api_map_order', 0)].info
+        self.assertTrue(np.array_equal(old_api_map['inverse_order'], [0, 2, 3, 1]))
+
+        rt_info = result_node.rt_info.info
+        old_api_map = rt_info[('old_api_map_order', 0)].info
+        self.assertTrue(np.array_equal(old_api_map['order'], [0, 3, 1, 2]))