[MO] turn on MarkSubGraphsWithCorrectLayout for TF NCHW (#7150)
* turned on MarkingSubgraphsWithCorrectLayout for TF NCHW * restricted MarkSubgraphsWithCorrectLayout.py only to TF * added comments why need to MarkSubgraphsWithCorrectLayout even for TF NCHW models
This commit is contained in:
parent
ef84c90367
commit
f77d838e6c
@ -22,10 +22,14 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
1. Prevents from adding Transpose operations before and after "reinterp_shape" like operations which change rank of
|
||||
the input and output tensors of this layout agnostic op.
|
||||
2. Disable attributes permutation for all intermediate ops between these "reinterp_shape" nodes.
|
||||
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW
|
||||
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW.
|
||||
The latest is needed for TF NCHW graphs as well. In Conv/Deconv infer functions "set_permutation()"
|
||||
ads "permutation" attr to weights data node even for NCHW, it is needed to permute Conv weights from the
|
||||
original TF layout into IE even for NCHW graphs. Therefore for TF models
|
||||
to prevent unwarranted permutations need to mark weights path as having correct layout even for NCHW graphs.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
|
||||
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
|
||||
op_conditions = [lambda n: n.soft_get('op') == 'MatMul' and
|
||||
any([len(port.data.get_shape()) in (4, 5) for port in n.in_ports().values()]),
|
||||
]
|
||||
|
@ -256,6 +256,9 @@ class Convolution(Op):
|
||||
('output_feature_channel', 'input:{}'.format(weights_index)),
|
||||
])
|
||||
|
||||
# is needed to permute Conv weights from the original TF [H, W, C_IN, C_OUT] into IE [C_OUT, C_IN, H, W]
|
||||
# but for other nodes in weights subgraph permutations must turned off
|
||||
# by marking with MarkSubGraphsWithCorrectLayout even if graph layout is NCHW.
|
||||
PermuteAttrs.set_permutation(node.in_node(weights_index), node, node.soft_get('get_weights_permute', None))
|
||||
PermuteInputs().set_input_permutation(
|
||||
node.in_node(weights_index), node, 'input:{}'.format(weights_index), 'transpose')
|
||||
|
@ -99,7 +99,10 @@ class Deconvolution(Op):
|
||||
('input_feature_channel', 'input:1'),
|
||||
('output_feature_channel', 'input:1'),
|
||||
])
|
||||
|
||||
|
||||
# is needed to permute Deconv weights from the original TF [H, W, C_OUT, C_IN] into IE [C_IN, C_OUT, H, W]
|
||||
# but for other nodes in weights subgraph permutations must turned off
|
||||
# by marking with MarkSubGraphsWithCorrectLayout even if graph layout is NCHW.
|
||||
PermuteAttrs.set_permutation(node.in_node(1), node, node.soft_get('get_weights_permute', None))
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'transpose')
|
||||
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
|
||||
|
Loading…
Reference in New Issue
Block a user