[MO] turn on MarkSubGraphsWithCorrectLayout for TF NCHW (#7150)

* turned on MarkingSubgraphsWithCorrectLayout for TF NCHW

* restricted MarkSubgraphsWithCorrectLayout.py only to TF

* added comments why need to MarkSubgraphsWithCorrectLayout even for TF NCHW models
This commit is contained in:
Pavel Esir 2021-08-20 18:47:45 +03:00 committed by GitHub
parent ef84c90367
commit f77d838e6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 3 deletions

View File

@ -22,10 +22,14 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
1. Prevents from adding Transpose operations before and after "reinterp_shape" like operations which change rank of
the input and output tensors of this layout agnostic op.
2. Disable attributes permutation for all intermediate ops between these "reinterp_shape" nodes.
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW.
The latest is needed for TF NCHW graphs as well. In Conv/Deconv infer functions "set_permutation()"
ads "permutation" attr to weights data node even for NCHW, it is needed to permute Conv weights from the
original TF layout into IE even for NCHW graphs. Therefore for TF models
to prevent unwarranted permutations need to mark weights path as having correct layout even for NCHW graphs.
"""
enabled = True
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
op_conditions = [lambda n: n.soft_get('op') == 'MatMul' and
any([len(port.data.get_shape()) in (4, 5) for port in n.in_ports().values()]),
]

View File

@ -256,6 +256,9 @@ class Convolution(Op):
('output_feature_channel', 'input:{}'.format(weights_index)),
])
# is needed to permute Conv weights from the original TF [H, W, C_IN, C_OUT] into IE [C_OUT, C_IN, H, W]
# but for other nodes in weights subgraph permutations must turned off
# by marking with MarkSubGraphsWithCorrectLayout even if graph layout is NCHW.
PermuteAttrs.set_permutation(node.in_node(weights_index), node, node.soft_get('get_weights_permute', None))
PermuteInputs().set_input_permutation(
node.in_node(weights_index), node, 'input:{}'.format(weights_index), 'transpose')

View File

@ -99,7 +99,10 @@ class Deconvolution(Op):
('input_feature_channel', 'input:1'),
('output_feature_channel', 'input:1'),
])
# is needed to permute Deconv weights from the original TF [H, W, C_OUT, C_IN] into IE [C_IN, C_OUT, H, W]
# but for other nodes in weights subgraph permutations must turned off
# by marking with MarkSubGraphsWithCorrectLayout even if graph layout is NCHW.
PermuteAttrs.set_permutation(node.in_node(1), node, node.soft_get('get_weights_permute', None))
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'transpose')
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')