Fix the NHWC->NCHW transformation for dynamic weights (#2848)

* Fix the NHWC->NCHW transformation when weights and data come from the same input

* Simplify code
This commit is contained in:
Maxim Vafin 2020-11-06 19:04:46 +03:00 committed by GitHub
parent 15d7919817
commit 9f0b26e14c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 27 deletions

View File

@ -19,10 +19,12 @@ from collections import deque
from typing import Set
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
mark_as_correct_data_layout
mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
from extensions.middle.pass_separator import PostMiddleStart
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import PermuteInputs
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
@ -34,8 +36,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
1. Prevents from adding Transpose operations before and after "reinterp_shape" like operations which change rank of
the input and output tensors of this layout agnostic op.
2. Disable attributes permutation for all intermediate ops between these "reinterp_shape" nodes.
For now the transformation is triggered for MatMul operation only getting input as 4D or 5D tensors.
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW
"""
enabled = True
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
@ -69,7 +70,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
:param start_nodes: Nodes to start search from
:param visited: set of already visited nodes where traversing should not happen
:param condition: function getting a Node as input and returning whether the node should be included into the
resukt or not. If the value is None then the node is added unconditionally.
result or not. If the value is None then the node is added unconditionally.
:param forward: boolean flag specifying the traverse direction
:return: the list of Nodes visited
"""
@ -127,8 +128,16 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
mark_as_correct_data_layout(visited_node)
visited_node['nchw_layout'] = True
for node in self.get_ports_and_nodes_on_weights(graph)[1]:
mark_as_correct_data_layout(node)
_, nodes_weigths, nodes_in_weights = self.get_ports_and_nodes_on_weights(graph)
for node in nodes_weigths:
if node in nodes_in_weights:
for ind, port in node.in_ports().items():
if ind not in nodes_in_weights[node]:
mark_input_as_in_correct_layout(node, ind)
for ind, port in node.out_ports().items():
mark_output_as_in_correct_layout(node, ind)
else:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
node.out_node()['nchw_layout'] = True
@ -140,8 +149,39 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
node.out_node()['nchw_layout'] = True
@staticmethod
def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port],
visited_ports: Set[Port] = None, visited_nodes: Set[Node] = None):
def get_weighted_layer_type_to_in_weights_port():
    """
    Return a mapping from weighted layer type to a callable that, given a node,
    yields the index of the input port carrying the weights.

    The callable honours an explicit ``weights_index`` attribute on the node and
    falls back to port 1 otherwise.
    """
    def in_weights_port(node):
        # `weights_index` overrides the conventional weights port 1 when present.
        return node.weights_index if node.has_valid('weights_index') else 1

    return {layer_type: in_weights_port for layer_type in
            ('Convolution', 'DeformableConvolution', 'Deconvolution', 'BinaryConvolution')}
@staticmethod
def insert_permute_inputs_before_dynamic_weights_subgraph(dynamic_subgraphs: Set[Node] = None):
    """
    Inserts NCHW->NHWC input permutations on the border of the weights sub-graphs.

    :param dynamic_subgraphs: Set of Nodes belonging to weight path subgraphs
    :return: dict mapping each border node to the list of its input port indices
             that received a permutation (the inputs to weight path subgraphs)
    """
    permuted_inputs_per_node = dict()
    for node in dynamic_subgraphs:
        # Sources of the sub-graph never need an input permutation.
        if node.soft_get('type') in ('Const', 'Parameter', 'ShapeOf'):
            continue
        permuted_indices = []
        for idx, port in node.in_ports().items():
            # Only inputs fed from OUTSIDE the weights sub-graph get permuted.
            if port.disconnected() or port.get_source().node in dynamic_subgraphs:
                continue
            PermuteInputs().set_input_permutation(node.in_node(idx), node, 'input:{}'.format(idx),
                                                  'transpose_nchw_to_nhwc')
            permuted_indices.append(idx)
        if permuted_indices:
            permuted_inputs_per_node[node] = permuted_indices
    return permuted_inputs_per_node
@staticmethod
def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port], port_condition=None):
# NOTE(review): this span is rendered diff residue -- removed (old) and added (new)
# lines are interleaved without +/- markers, and raw hunk headers appear inline,
# so the text below is not runnable as displayed.
""""
Returns all intermediate ports and nodes of such a sub-graph:
@ -153,14 +193,14 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
\/ \/
in_ports
"""
# Old version accepted caller-supplied visited sets (the next four lines were removed):
if visited_ports is None:
visited_ports = set()
if visited_nodes is None:
visited_nodes = set()
# New version always starts from fresh accumulators:
visited_ports = set()
visited_nodes = set()
# BFS upwards from the consumer in-ports towards the producing out-ports.
deque_of_in_ports = deque(in_ports)
while len(deque_of_in_ports):
in_port = deque_of_in_ports.popleft()
# New version skips ports that have no producer at all.
if in_port.get_source() is None:
continue
source_node = in_port.get_source().node
if in_port in visited_ports: # do not check visited_nodes as search is based on ports
continue
# (raw hunk header from the diff rendering:)
@ -169,40 +209,51 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
if not len(in_port.get_source().node.in_ports()): # for Constants and Parameters to be visited
visited_nodes.add(in_port.get_source().node)
continue
# Old unconditional traversal of all connected input ports (removed):
deque_of_in_ports.extend([port for port in source_node.in_ports().values() if not port.disconnected()])
# New traversal honours the optional per-port `port_condition(node, idx)` filter:
for idx, port in source_node.in_ports().items():
if not port.disconnected() and (not port_condition or port_condition(source_node, idx)):
deque_of_in_ports.append(port)
visited_nodes.add(source_node)
return visited_ports, visited_nodes
@staticmethod
def is_not_weight_port(node: Node, idx: int):
    """
    Tell whether input port `idx` of `node` is a NON-weights port of a weighted layer.

    Returns False both when `node` is not a weighted layer type at all and when
    `idx` is exactly its weights port.
    """
    weights_port_by_type = MarkSubGraphsWithCorrectLayout.get_weighted_layer_type_to_in_weights_port()
    port_index_fn = weights_port_by_type.get(node.soft_get('type'))
    if port_index_fn is None:
        # not a weighted layer -> no port of it counts as a non-weight weighted port
        return False
    return idx != port_index_fn(node)
@staticmethod
def get_ports_and_nodes_on_weights(graph):
# NOTE(review): this span is rendered diff residue -- removed (old) and added (new)
# lines are interleaved without +/- markers; it is not runnable as displayed.
# Old version built the type->weights-port mapping locally (removed lines):
get_weights_port_index = lambda node: node.weights_index if node.has_valid('weights_index') else 1
weighted_layer_type_to_in_weights_port = {
'Convolution': get_weights_port_index,
'DeformableConvolution': get_weights_port_index,
'Deconvolution': get_weights_port_index,
'BinaryConvolution': get_weights_port_index,
}
nodes = graph.get_op_nodes()
# Old version cached the weighted type names (removed line):
weighted_types = list(weighted_layer_type_to_in_weights_port.keys())
# collect all input ports with weights
weight_ports = set()
result_ports = set()
start_ports = set()
# New version takes the mapping from the shared class helper instead:
w_types_to_in_port_dict = MarkSubGraphsWithCorrectLayout.get_weighted_layer_type_to_in_weights_port()
for node in nodes:
node_type = node.soft_get('type', 'unknown')
# Old condition pair (removed):
if node_type not in weighted_types:
if node_type in ['Const', 'Parameter', 'ShapeOf']:
# New condition pair also treats ExtractImagePatches as a start node:
if node_type not in w_types_to_in_port_dict.keys():
if node_type in ['Const', 'Parameter', 'ShapeOf', 'ExtractImagePatches']:
start_ports.add(node.out_port(0))
continue
# Old lookup via the local dict (removed):
weight_port_idx = weighted_layer_type_to_in_weights_port[node_type](node)
# New lookup via the shared helper's dict:
weight_port_idx = w_types_to_in_port_dict[node_type](node)
assert node.is_in_port_connected(weight_port_idx), \
'Unexpected port configuration of {} node with name=`{}`'.format(node_type,
node.soft_get('name', node.id))
weight_ports.add(node.in_port(weight_port_idx))
# New version also gathers Result input ports to separate data paths from weight paths:
for result in graph.get_op_nodes(type='Result'):
result_ports.update(result.in_ports().values())
# Old single-walk implementation and its return (removed):
# collect all sub-graphs that start with Constant/Parameter/ShapeOf and end at in_port as weights
ports, nodes = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports)
return ports, nodes
# New implementation: two walks, then subtract the data-path results from the weight-path ones.
# collect all sub-graphs that start with Constant/Parameter/ShapeOf/ExtractImagePatches and end at in_port as
# weights
ports_w, nodes_w = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports)
# collect all sub-graphs that start with Constant/Parameter/ShapeOf/ExtractImagePatches, end at Result nodes and
# not contains branches that end as weights
ports_d, nodes_d = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(
result_ports, start_ports, MarkSubGraphsWithCorrectLayout.is_not_weight_port)
nodes_dif = nodes_w.difference(nodes_d)
# Inputs feeding the weight-only sub-graphs from outside get NCHW->NHWC permutations inserted.
nodes_in_w = MarkSubGraphsWithCorrectLayout.insert_permute_inputs_before_dynamic_weights_subgraph(nodes_dif)
# New return: weight-only ports, weight-only nodes, and the permuted-input map.
return ports_w.difference(ports_d), nodes_dif, nodes_in_w
@staticmethod
def get_ports_and_nodes_on_shape_subgraphs(graph):

View File

@ -171,6 +171,23 @@ def transpose(op_node: Node, port_info: str, input_port: int):
op_node.in_port(input_port).get_connection().insert_node(transpose)
def transpose_nchw_to_nhwc(op_node: Node, port_info: str, input_port: int):
    """
    Inserts a Transpose node on the given input port of ``op_node``.

    The inserted permutation order is ``[0, rank-1, 1, 2, ...]``, i.e. the last
    axis of the input tensor is moved to position 1.
    # NOTE(review): that order maps NHWC data to NCHW, while the function name
    # suggests the opposite direction -- confirm the intended semantics.

    :param op_node: node whose input receives the Transpose
    :param port_info: port descriptor (e.g. 'input:0') used to locate the data node defining the rank
    :param input_port: index of the input port to insert the Transpose on
    """
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    rank = len(permutation_data_node.shape)
    # Fixed typos in the assertion message: it previously read "HCHW to HHWC".
    assert rank >= 4, 'Rank must be 4D or higher for NCHW to NHWC permutation on node {}.'.format(op_node.id)
    perm = list(range(rank))
    perm.insert(1, perm.pop())  # move the last axis to position 1
    perm = int64_array(perm)
    transpose_name = op_node.soft_get('name', op_node.id) + '/Transpose'
    from mo.front.tf.graph_utils import create_op_with_const_inputs  # avoiding recursive imports
    transpose = create_op_with_const_inputs(
        graph, Transpose, {1: perm}, {'name': transpose_name, 'override_output_shape': True})
    op_node.in_port(input_port).get_connection().insert_node(transpose)
class PermuteInputs:
common_inv_permutation = lambda node, port_info, input_port: axis(node, port_info, input_port)
@ -179,6 +196,8 @@ class PermuteInputs:
'order': lambda node, port_info, input_port: order(node, port_info, input_port),
'shape': lambda node, port_info, input_port: shape(node, port_info, input_port),
'transpose': lambda node, port_info, input_port: transpose(node, port_info, input_port),
'transpose_nchw_to_nhwc': lambda node, port_info, input_port: transpose_nchw_to_nhwc(node, port_info,
input_port),
}
def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str):