[MO] turn on MarkSubGraphsWithCorrectLayout for TF NCHW (#7150)
* turned on MarkingSubgraphsWithCorrectLayout for TF NCHW * restricted MarkSubgraphsWithCorrectLayout.py only to TF * added comments why need to MarkSubgraphsWithCorrectLayout even for TF NCHW models
This commit is contained in:
parent
ef84c90367
commit
f77d838e6c
@ -22,10 +22,14 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
|||||||
1. Prevents from adding Transpose operations before and after "reinterp_shape" like operations which change rank of
|
1. Prevents from adding Transpose operations before and after "reinterp_shape" like operations which change rank of
|
||||||
the input and output tensors of this layout agnostic op.
|
the input and output tensors of this layout agnostic op.
|
||||||
2. Disable attributes permutation for all intermediate ops between these "reinterp_shape" nodes.
|
2. Disable attributes permutation for all intermediate ops between these "reinterp_shape" nodes.
|
||||||
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW
|
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW.
|
||||||
|
The latest is needed for TF NCHW graphs as well. In Conv/Deconv infer functions "set_permutation()"
|
||||||
|
ads "permutation" attr to weights data node even for NCHW, it is needed to permute Conv weights from the
|
||||||
|
original TF layout into IE even for NCHW graphs. Therefore for TF models
|
||||||
|
to prevent unwarranted permutations need to mark weights path as having correct layout even for NCHW graphs.
|
||||||
"""
|
"""
|
||||||
enabled = True
|
enabled = True
|
||||||
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
|
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
|
||||||
op_conditions = [lambda n: n.soft_get('op') == 'MatMul' and
|
op_conditions = [lambda n: n.soft_get('op') == 'MatMul' and
|
||||||
any([len(port.data.get_shape()) in (4, 5) for port in n.in_ports().values()]),
|
any([len(port.data.get_shape()) in (4, 5) for port in n.in_ports().values()]),
|
||||||
]
|
]
|
||||||
|
@ -256,6 +256,9 @@ class Convolution(Op):
|
|||||||
('output_feature_channel', 'input:{}'.format(weights_index)),
|
('output_feature_channel', 'input:{}'.format(weights_index)),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
# is needed to permute Conv weights from the original TF [H, W, C_IN, C_OUT] into IE [C_OUT, C_IN, H, W]
|
||||||
|
# but for other nodes in weights subgraph permutations must turned off
|
||||||
|
# by marking with MarkSubGraphsWithCorrectLayout even if graph layout is NCHW.
|
||||||
PermuteAttrs.set_permutation(node.in_node(weights_index), node, node.soft_get('get_weights_permute', None))
|
PermuteAttrs.set_permutation(node.in_node(weights_index), node, node.soft_get('get_weights_permute', None))
|
||||||
PermuteInputs().set_input_permutation(
|
PermuteInputs().set_input_permutation(
|
||||||
node.in_node(weights_index), node, 'input:{}'.format(weights_index), 'transpose')
|
node.in_node(weights_index), node, 'input:{}'.format(weights_index), 'transpose')
|
||||||
|
@ -99,7 +99,10 @@ class Deconvolution(Op):
|
|||||||
('input_feature_channel', 'input:1'),
|
('input_feature_channel', 'input:1'),
|
||||||
('output_feature_channel', 'input:1'),
|
('output_feature_channel', 'input:1'),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
# is needed to permute Deconv weights from the original TF [H, W, C_OUT, C_IN] into IE [C_IN, C_OUT, H, W]
|
||||||
|
# but for other nodes in weights subgraph permutations must turned off
|
||||||
|
# by marking with MarkSubGraphsWithCorrectLayout even if graph layout is NCHW.
|
||||||
PermuteAttrs.set_permutation(node.in_node(1), node, node.soft_get('get_weights_permute', None))
|
PermuteAttrs.set_permutation(node.in_node(1), node, node.soft_get('get_weights_permute', None))
|
||||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'transpose')
|
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'transpose')
|
||||||
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
|
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
|
||||||
|
Loading…
Reference in New Issue
Block a user