Automatically detect --disable_nhwc_to_nchw option in MO (#8450)
Clean code
Fix Result runtime info
Fix documentation
Apply suggestions from code review
Update docs/MO_DG/prepare_model/customize_model_optimizer/Customize_Model_Optimizer.md
Apply review feedback

Co-authored-by: Tatiana Savina <tatiana.savina@intel.com>
parent d24a48901e
commit e4b5c54006
````diff
@@ -299,7 +299,9 @@ TensorFlow*-specific parameters:
                         TensorFlow*: comma separated list of shared libraries
                         with TensorFlow* custom operations implementation.
   --disable_nhwc_to_nchw
-                        Disables default translation from NHWC to NCHW
+                        [DEPRECATED] Disables default translation from NHWC to NCHW. Since 2022.1
+                        this option is deprecated and used only to maintain backward compatibility
+                        with previous releases.
 ```
 
 > **NOTE:** Models produced with TensorFlow\* usually do not have fully defined shapes (they contain `-1` in some dimensions). It is necessary to pass an explicit shape for the input using the command-line parameter `--input_shape`, or `-b` to override just the batch dimension. If the shape is fully defined, there is no need to specify either the `-b` or `--input_shape` option.
````
````diff
@@ -35,7 +35,6 @@ To generate the BERT Intermediate Representation (IR) of the model, run the Model Optimizer with the following parameters:
 python3 ./mo_tf.py \
 --input_meta_graph uncased_L-12_H-768_A-12/bert_model.ckpt.meta \
 --output bert/pooler/dense/Tanh \
---disable_nhwc_to_nchw \
 --input Placeholder{i32},Placeholder_1{i32},Placeholder_2{i32}
 ```
 
````
````diff
@@ -110,10 +109,9 @@ python3 run_classifier.py \
 
 Run the Model Optimizer with the following command line parameters to generate reshape-able BERT Intermediate Representation (IR):
 ```sh
-python3 ./mo_tf.py
---input_model inference_graph.pb
+python3 ./mo_tf.py \
+--input_model inference_graph.pb \
 --input "IteratorGetNext:0{i32}[1 128],IteratorGetNext:1{i32}[1 128],IteratorGetNext:4{i32}[1 128]"
---disable_nhwc_to_nchw
 ```
 For other applicable parameters, refer to [Convert Model from TensorFlow](../Convert_Model_From_TensorFlow.md).
 
````
````diff
@@ -71,8 +71,7 @@ To generate the IR, run the Model Optimizer with the following parameters:
 python3 {path_to_mo}/mo_tf.py \
 --input_model output_graph.pb \
 --input "input_lengths->[16],input_node[1 16 19 26],previous_state_h[1 2048],previous_state_c[1 2048]" \
---output "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1,cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd,logits" \
---disable_nhwc_to_nchw
+--output "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1,cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd,logits"
 ```
 
 Where:
````
````diff
@@ -186,6 +186,7 @@ The script should save into `~/XLNet-Large/xlnet`.
 
 To generate the XLNet Intermediate Representation (IR) of the model, run the Model Optimizer with the following parameters:
 ```sh
-python3 mo.py --input_model path-to-model/model_frozen.pb --input "input_mask[50 1],input_ids[50 1],seg_ids[50 1]" --log_level DEBUG --disable_nhwc_to_nchw --output_dir <OUTPUT_MODEL_DIR>
+python3 mo.py --input_model path-to-model/model_frozen.pb \
+--input "input_mask[50 1],input_ids[50 1],seg_ids[50 1]"
 ```
 
````
````diff
@@ -285,10 +285,9 @@ More information on how to develop middle transformations and dedicated API desc
 ### NHWC to NCHW Layout Change <a name="layout-change"></a>
 There are several middle transformations responsible for changing model layout from NHWC to NCHW. These transformations
 are triggered by default for TensorFlow\* models only because it is the only framework with Convolution operations in
-NHWC layout.
+NHWC layout. This layout change is disabled if the model does not have operations that OpenVINO&trade; needs to execute in
+NCHW layout, for example, Convolutions in NHWC layout. It is still possible to disable the layout change explicitly
+using the `--disable_nhwc_to_nchw` command-line parameter.
 
-> **NOTE**: If a TensorFlow\* model is in NCHW layout, you should specify the `--disable_nhwc_to_nchw` command line
-> parameter to disable these transformations.
-
 The layout change is a complex problem, and a detailed explanation of it is out of the scope of this document. A very brief
 explanation of this process is provided below:
````
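For intuition, the layout change boils down to two fixed rank-4 permutations; a quick NumPy sketch (illustration only, not part of the commit):

```python
import numpy as np

# NHWC -> NCHW is the permutation [0, 3, 1, 2]; its inverse, NCHW -> NHWC,
# is [0, 2, 3, 1]. These are the same orders the code changes below compute.
nhwc_to_nchw = [0, 3, 1, 2]
nchw_to_nhwc = [0, 2, 3, 1]

x = np.zeros((1, 224, 224, 3))                     # NHWC: batch, height, width, channels
y = x.transpose(nhwc_to_nchw)                      # NCHW: (1, 3, 224, 224)
assert y.transpose(nchw_to_nhwc).shape == x.shape  # the inverse round-trips
```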
````diff
@@ -23,6 +23,7 @@ from mo.front.extractor import restore_edges, extract_node_attrs, remove_control
 from mo.front.tf.extractor import get_tf_edges, create_tf_edge, tf_op_extractor, tf_op_extractors
 from mo.front.tf.loader import load_tf_graph_def, protobuf2nx
 from mo.graph.graph import Graph, Node
+from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
 from mo.utils import tensorboard_util
 from mo.utils.error import Error
 from mo.utils.telemetry_utils import send_op_names_info, send_shapes_info, send_framework_info
````
````diff
@@ -103,10 +104,10 @@ class TFLoader(Loader):
 
         # try to detect layout from the nodes of the graph. If there are no convolution nodes in N(D)HWC layout then we
         # consider that the graph is in NCHW layout and no layout conversion should be performed
-        if not argv.disable_nhwc_to_nchw and not argv.silent and not graph_or_sub_graph_has_nhwc_ops(graph):
-            log.error('The TensorFlow model does not contain Convolution operations with N(D)HWC layout. Most likely '
-                      'the model should be converted using additional "--disable_nhwc_to_nchw" command line parameter '
-                      'which disables model layout conversion inside the Model Optimizer.', extra={'is_warning': True})
+        if not argv.disable_nhwc_to_nchw and not graph_or_sub_graph_has_nhwc_ops(graph):
+            if not argv.silent:
+                log.debug('"disable_nhwc_to_nchw" was automatically enabled.')
+            for_graph_and_each_sub_graph_recursively(graph, update_cmd_params_and_layout)
 
         send_op_names_info(framework, graph)
         send_shapes_info(framework, graph)
````
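For context, `graph_or_sub_graph_has_nhwc_ops` (extended in the next hunk) reports whether the graph, or any nested sub-graph, contains convolution operations in N(D)HWC layout. A simplified sketch of that check; the `op` and `layout` attribute names are assumptions for illustration, not the exact MO implementation:

```python
def has_nhwc_convolutions(graph) -> bool:
    # Simplified sketch: scan operation nodes for convolutions in N(D)HWC layout
    # and recurse into sub-graphs, mirroring the real helper's accumulation
    # `NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(...)`.
    detected = False
    for node in graph.get_op_nodes():
        if node.soft_get('op') in ('Conv2D', 'Conv3D') and \
                node.soft_get('layout') in ('NHWC', 'NDHWC'):  # assumed attributes
            detected = True
        for sub_graph_name in node.soft_get('sub_graphs', []):
            detected |= has_nhwc_convolutions(node.soft_get(sub_graph_name))
    return detected
```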
````diff
@@ -143,3 +144,15 @@ def graph_or_sub_graph_has_nhwc_ops(graph: Graph):
             NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(node.soft_get(sub_graph_name))
 
     return NHWC_conv_detected
+
+
+def update_cmd_params_and_layout(graph: Graph):
+    """
+    Updates the "cmd_params" and "layout" graph attributes, as the model has only NCHW layout operations.
+    :param graph: graph to update attributes
+    :return: None
+    """
+    if 'cmd_params' in graph.graph:
+        graph.graph['cmd_params'].auto_disable_nhwc_to_nchw = True
+    if 'layout' in graph.graph:
+        graph.graph['layout'] = 'NCHW'
````
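Since a TensorFlow model may contain nested sub-graphs (for example, `Loop` or `If` bodies), the update has to reach every level, which is why the loader wraps it in `for_graph_and_each_sub_graph_recursively`. A minimal usage sketch, assuming a loaded `graph`:

```python
# Mark the whole model, including nested sub-graphs, as NCHW so that the
# NHWC->NCHW middle transformations are skipped later in the pipeline.
for_graph_and_each_sub_graph_recursively(graph, update_cmd_params_and_layout)

# Downstream transformations such as PreserveRuntimeInfo can now read the
# auto-detected flag back from the graph's command-line parameters.
assert graph.graph['cmd_params'].auto_disable_nhwc_to_nchw
```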
````diff
@@ -5,8 +5,9 @@ import numpy as np
 
 from extensions.middle.MergeNodesPermutations import MergeNodesPermutations
 from extensions.ops.transpose import Transpose
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.tf.graph_utils import create_op_node_with_second_input
-from mo.graph.graph import Graph
+from mo.graph.graph import Graph, Node
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.utils.runtime_info import OldAPIMapOrder
 
````
````diff
@@ -39,30 +40,41 @@ class PreserveRuntimeInfo(MiddleReplacementPattern):
     def find_and_replace_pattern(self, graph: Graph):
         self.preserve_rt_info(graph)
 
+    @staticmethod
+    def add_old_api_map_order_into_rt_info(op: Node):
+        # rt info update
+        assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op)
+
+        old_api_map = OldAPIMapOrder(version=0)
+        attr_name = old_api_map.get_name()
+        if (attr_name, old_api_map.get_version()) not in op.rt_info.info:
+            op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map
+        return attr_name, old_api_map.get_version()
+
     @staticmethod
     def preserve_rt_info(graph: Graph):
-        for op in graph.get_op_nodes():
+        for op in graph.get_op_nodes(type='Parameter'):
             op_name = op.soft_get('name', op.id)
-            op_type = op.soft_get('type')
-            if op_type == 'Parameter' and op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout'):
-                if not op.out_node(0).has_valid('permutation'):
+            if 'auto_disable_nhwc_to_nchw' in graph.graph['cmd_params'] and \
+                    graph.graph['cmd_params'].auto_disable_nhwc_to_nchw:
+                rank = op.out_port(0).data.get_shape().size
+                if rank < 4:
                     continue
+                order = list(range(rank))
+                order.remove(1)
+                order.append(1)
+                order = int64_array(order)
+            elif op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout') and \
+                    op.out_node(0).has_valid('permutation'):
                 permutation = op.out_node(0).permutation
-                if np.array_equal(permutation.inv, range(len(permutation.inv))):
+                order = permutation.inv
+                if np.array_equal(order, range(len(permutation.inv))):
                     continue
 
-                # rt info update
-                assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op_name)
-
-                old_api_map = OldAPIMapOrder(version=0)
-                attr_name = old_api_map.get_name()
-                if (attr_name, old_api_map.get_version()) not in op.rt_info.info:
-                    op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map
-                op.rt_info.info[(attr_name, old_api_map.get_version())].old_api_transpose_parameter(permutation.inv)
-
                 # keep input in the framework format
                 transpose = create_op_node_with_second_input(
-                    graph, Transpose, permutation.perm, {'name': op_name + '/Transpose({})'.format(permutation.perm)})
+                    graph, Transpose, permutation.perm,
+                    {'name': op_name + '/Transpose({})'.format(permutation.perm)})
 
                 # source mode is used to keep tensor names at Parameter node
                 op.out_port(0).get_connection().insert_node(transpose, "source")
````
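A side note on the membership test above: `graph.graph['cmd_params']` holds the parsed command line as an `argparse.Namespace`, which supports the `in` operator, so the check is safe even when the loader never set the attribute:

```python
import argparse

ns = argparse.Namespace()
assert 'auto_disable_nhwc_to_nchw' not in ns  # Namespace implements __contains__
ns.auto_disable_nhwc_to_nchw = True           # what update_cmd_params_and_layout sets
assert 'auto_disable_nhwc_to_nchw' in ns and ns.auto_disable_nhwc_to_nchw
```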
````diff
@@ -71,25 +83,34 @@ class PreserveRuntimeInfo(MiddleReplacementPattern):
                 del op['permute_attrs']
                 if op.out_node(0).has_valid('permutation'):
                     del op.out_node(0)['permutation']
+            else:
+                continue
 
-            elif op_type == 'Result' and op.in_ports():
+            rt_info_key = PreserveRuntimeInfo.add_old_api_map_order_into_rt_info(op)
+            op.rt_info.info[rt_info_key].old_api_transpose_parameter(order)
+
+        for op in graph.get_op_nodes(type='Result'):
+            if op.in_ports():
                 prev_node_out_port = op.in_port(0).get_connection().get_source()
                 if prev_node_out_port is None:
                     continue
                 in_node = prev_node_out_port.node
                 in_data_node = in_node.out_node(prev_node_out_port.idx)
-                if in_data_node.has_and_set('permutation'):
-                    permutation = in_data_node['permutation']
-                    if np.array_equal(permutation.perm, range(len(permutation.perm))):
-                        continue
 
-                    # rt info update
-                    assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op)
-                    old_api_map = OldAPIMapOrder(version=0)
-                    attr_name = old_api_map.get_name()
-                    if (attr_name, old_api_map.get_version()) not in op.rt_info.info:
-                        op.rt_info.info[(attr_name, old_api_map.get_version())] = old_api_map
-                    op.rt_info.info[(attr_name, old_api_map.get_version())].old_api_transpose_result(permutation.perm)
+                if 'auto_disable_nhwc_to_nchw' in graph.graph['cmd_params'] and \
+                        graph.graph['cmd_params'].auto_disable_nhwc_to_nchw:
+                    rank = prev_node_out_port.data.get_shape().size
+                    if rank < 4:
+                        continue
+                    order = list(range(rank - 1))
+                    order.insert(1, rank - 1)
+                    order = int64_array(order)
+                elif in_data_node.has_and_set('permutation'):
+                    permutation = in_data_node['permutation']
+                    order = permutation.perm
+
+                    if np.array_equal(order, range(len(permutation.perm))):
+                        continue
 
                     # keep result in the framework format
                     transpose = create_op_node_with_second_input(graph, Transpose, permutation.inv)
````
````diff
@@ -98,3 +119,8 @@ class PreserveRuntimeInfo(MiddleReplacementPattern):
                     in_node.name += "/prev"
 
                     prev_node_out_port.get_connection().insert_node(transpose)
+                else:
+                    continue
+
+                rt_info_key = PreserveRuntimeInfo.add_old_api_map_order_into_rt_info(op)
+                op.rt_info.info[rt_info_key].old_api_transpose_result(order)
````
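The two `order` computations introduced above are just the rank-4 layout permutations written out; checking the arithmetic in a standalone sketch:

```python
# Parameter side: move the channel axis (position 1 in NCHW) to the end.
rank = 4
order = list(range(rank))   # [0, 1, 2, 3]
order.remove(1)             # [0, 2, 3]
order.append(1)             # [0, 2, 3, 1] -- NCHW to NHWC
assert order == [0, 2, 3, 1]

# Result side: insert the trailing axis back at position 1.
order = list(range(rank - 1))  # [0, 1, 2]
order.insert(1, rank - 1)      # [0, 3, 1, 2] -- NHWC to NCHW
assert order == [0, 3, 1, 2]
```

These match the `inverse_order` and `order` values asserted in the new test at the end of this commit.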
````diff
@@ -564,7 +564,8 @@ def get_tf_cli_parser(parser: argparse.ArgumentParser = None):
                           default=None,
                           action=CanonicalizePathCheckExistenceAction)
     tf_group.add_argument('--disable_nhwc_to_nchw',
-                          help='Disables default translation from NHWC to NCHW',
+                          help='[DEPRECATED] Disables the default translation from NHWC to NCHW. Since 2022.1 this option '
+                               'is deprecated and used only to maintain backward compatibility with previous releases.',
                           action='store_true')
     return parser
 
````
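The deprecated flag is still parsed so that existing command lines keep working; a hypothetical standalone reproduction of the parser behaviour (not taken from the commit):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--disable_nhwc_to_nchw',
                    help='[DEPRECATED] Disables the default translation from NHWC to NCHW.',
                    action='store_true')

args = parser.parse_args([])                          # flag omitted: layout is auto-detected
assert args.disable_nhwc_to_nchw is False
args = parser.parse_args(['--disable_nhwc_to_nchw'])  # still accepted for backward compatibility
assert args.disable_nhwc_to_nchw is True
```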
````diff
@@ -88,3 +88,37 @@ class PreserveRuntimeInfoTest(unittest.TestCase):
         rt_info = result_node.rt_info.info
         old_api_map = rt_info[('old_api_map_order', 0)].info
         self.assertTrue(np.array_equal(old_api_map['order'], nhwc_to_nchw_order))
+
+    def test_auto_disable_nhwc_to_nchw(self):
+        shape_len = 4
+        shape = np.array(range(shape_len))
+        add_shape = shape
+        graph_nodes = {
+            **regular_op_with_shaped_data('placeholder1', shape,
+                                          {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}),
+            **regular_op_with_shaped_data('placeholder2', shape,
+                                          {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}),
+            **regular_op_with_shaped_data('result', shape, {'type': 'Result', 'rt_info': RTInfo(), 'shape': shape}),
+            **regular_op_with_shaped_data('add', add_shape,
+                                          {'type': 'Add', 'op': 'Add', 'infer': copy_shape_infer}),
+        }
+
+        graph = build_graph(graph_nodes, edges)
+        graph.graph['cmd_params'].auto_disable_nhwc_to_nchw = True
+        graph_ref = build_graph(graph_nodes, edges)
+
+        param_node = Node(graph, 'placeholder1')
+        result_node = Node(graph, 'result')
+
+        PreserveRuntimeInfo().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+        rt_info = param_node.rt_info.info
+        old_api_map = rt_info[('old_api_map_order', 0)].info
+        self.assertTrue(np.array_equal(old_api_map['inverse_order'], [0, 2, 3, 1]))
+
+        rt_info = result_node.rt_info.info
+        old_api_map = rt_info[('old_api_map_order', 0)].info
+        self.assertTrue(np.array_equal(old_api_map['order'], [0, 3, 1, 2]))
````