Fix the NHWC->NCHW transformation for dynamic weights (#2848)
* Fix the NHWC->NCHW transformation when weights and data come from the same input * Simplify code
This commit is contained in:
parent
15d7919817
commit
9f0b26e14c
@ -19,10 +19,12 @@ from collections import deque
|
||||
from typing import Set
|
||||
|
||||
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
|
||||
mark_as_correct_data_layout
|
||||
mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout
|
||||
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
|
||||
from extensions.middle.pass_separator import PostMiddleStart
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.graph.perm_inputs import PermuteInputs
|
||||
from mo.graph.port import Port
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
|
||||
@ -34,8 +36,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
1. Prevents from adding Transpose operations before and after "reinterp_shape" like operations which change rank of
|
||||
the input and output tensors of this layout agnostic op.
|
||||
2. Disable attributes permutation for all intermediate ops between these "reinterp_shape" nodes.
|
||||
|
||||
For now the transformation is triggered for MatMul operation only getting input as 4D or 5D tensors.
|
||||
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
|
||||
@ -69,7 +70,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
:param start_nodes: Nodes to start search from
|
||||
:param visited: set of already visited nodes where traversing should not happen
|
||||
:param condition: function getting a Node as input and returning whether the node should be included into the
|
||||
resukt or not. If the value is None then the node is added unconditionally.
|
||||
result or not. If the value is None then the node is added unconditionally.
|
||||
:param forward: boolean flag specifying the traverse direction
|
||||
:return: the list of Nodes visited
|
||||
"""
|
||||
@ -127,8 +128,16 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
mark_as_correct_data_layout(visited_node)
|
||||
visited_node['nchw_layout'] = True
|
||||
|
||||
for node in self.get_ports_and_nodes_on_weights(graph)[1]:
|
||||
mark_as_correct_data_layout(node)
|
||||
_, nodes_weigths, nodes_in_weights = self.get_ports_and_nodes_on_weights(graph)
|
||||
for node in nodes_weigths:
|
||||
if node in nodes_in_weights:
|
||||
for ind, port in node.in_ports().items():
|
||||
if ind not in nodes_in_weights[node]:
|
||||
mark_input_as_in_correct_layout(node, ind)
|
||||
for ind, port in node.out_ports().items():
|
||||
mark_output_as_in_correct_layout(node, ind)
|
||||
else:
|
||||
mark_as_correct_data_layout(node)
|
||||
node['nchw_layout'] = True
|
||||
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
|
||||
node.out_node()['nchw_layout'] = True
|
||||
@ -140,8 +149,39 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
node.out_node()['nchw_layout'] = True
|
||||
|
||||
@staticmethod
|
||||
def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port],
|
||||
visited_ports: Set[Port] = None, visited_nodes: Set[Node] = None):
|
||||
def get_weighted_layer_type_to_in_weights_port():
    """
    Builds the lookup table of layer types that have a weights input.

    :return: dictionary mapping the layer type name to a function which takes a node of that type and returns
     the index of its weights input port
    """
    def get_weights_port_index(node):
        # the node may override the weights port through the 'weights_index' attribute; otherwise port 1 is used
        return node.weights_index if node.has_valid('weights_index') else 1

    weighted_layer_types = ('Convolution', 'DeformableConvolution', 'Deconvolution', 'BinaryConvolution')
    # all the weighted layer types share the same way of resolving the weights port index
    return dict.fromkeys(weighted_layer_types, get_weights_port_index)
|
||||
|
||||
@staticmethod
|
||||
def insert_permute_inputs_before_dynamic_weights_subgraph(dynamic_subgraphs: Set[Node] = None):
    """
    The function inserts permutations on input nodes in the weights subgraph

    :param dynamic_subgraphs: Set of Nodes belonging to weight path subgraphs. None is treated as an empty set
    :return: the dictionary mapping each Node of the subgraph which is fed from outside of the subgraph to the
     list of indices of its input ports connected to these outside producers
    """
    dynamic_in_nodes = dict()
    if not dynamic_subgraphs:  # nothing to process (also covers the default None value)
        return dynamic_in_nodes

    for node in dynamic_subgraphs:
        if node.soft_get('type') in ('Const', 'Parameter', 'ShapeOf'):
            continue
        # connected input ports whose producer is not a part of the weights subgraph
        external_port_ids = [idx for idx, port in node.in_ports().items()
                             if not port.disconnected() and port.get_source().node not in dynamic_subgraphs]
        for idx in external_port_ids:
            PermuteInputs().set_input_permutation(node.in_node(idx), node, 'input:{}'.format(idx),
                                                  'transpose_nchw_to_nhwc')
        if external_port_ids:
            dynamic_in_nodes[node] = external_port_ids
    return dynamic_in_nodes
|
||||
|
||||
@staticmethod
|
||||
def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port], port_condition=None):
|
||||
""""
|
||||
Returns all intermediate ports and nodes of such a sub-graph:
|
||||
|
||||
@ -153,14 +193,14 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
\/ \/
|
||||
in_ports
|
||||
"""
|
||||
if visited_ports is None:
|
||||
visited_ports = set()
|
||||
if visited_nodes is None:
|
||||
visited_nodes = set()
|
||||
visited_ports = set()
|
||||
visited_nodes = set()
|
||||
|
||||
deque_of_in_ports = deque(in_ports)
|
||||
while len(deque_of_in_ports):
|
||||
in_port = deque_of_in_ports.popleft()
|
||||
if in_port.get_source() is None:
|
||||
continue
|
||||
source_node = in_port.get_source().node
|
||||
if in_port in visited_ports: # do not check visited_nodes as search is based on ports
|
||||
continue
|
||||
@ -169,40 +209,51 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
if not len(in_port.get_source().node.in_ports()): # for Constants and Parameters to be visited
|
||||
visited_nodes.add(in_port.get_source().node)
|
||||
continue
|
||||
deque_of_in_ports.extend([port for port in source_node.in_ports().values() if not port.disconnected()])
|
||||
for idx, port in source_node.in_ports().items():
|
||||
if not port.disconnected() and (not port_condition or port_condition(source_node, idx)):
|
||||
deque_of_in_ports.append(port)
|
||||
visited_nodes.add(source_node)
|
||||
return visited_ports, visited_nodes
|
||||
|
||||
@staticmethod
|
||||
def is_not_weight_port(node: Node, idx: int):
    """
    Checks whether the input port with index idx of a weighted layer is not its weights port.

    NOTE(review): False is returned for every node whose type is not a weighted layer type, so a traversal using
    this function as a port condition does not go through the inputs of such nodes -- confirm that this is intended.

    :param node: the node owning the input port
    :param idx: the index of the input port of the node
    :return: True if the node type is a weighted layer type and idx differs from the weights port index of the
     node, False otherwise
    """
    weights_port_getters = MarkSubGraphsWithCorrectLayout.get_weighted_layer_type_to_in_weights_port()
    layer_type = node.soft_get('type')
    if layer_type not in weights_port_getters:
        return False
    return idx != weights_port_getters[layer_type](node)
|
||||
|
||||
@staticmethod
|
||||
def get_ports_and_nodes_on_weights(graph):
|
||||
get_weights_port_index = lambda node: node.weights_index if node.has_valid('weights_index') else 1
|
||||
weighted_layer_type_to_in_weights_port = {
|
||||
'Convolution': get_weights_port_index,
|
||||
'DeformableConvolution': get_weights_port_index,
|
||||
'Deconvolution': get_weights_port_index,
|
||||
'BinaryConvolution': get_weights_port_index,
|
||||
}
|
||||
nodes = graph.get_op_nodes()
|
||||
weighted_types = list(weighted_layer_type_to_in_weights_port.keys())
|
||||
|
||||
# collect all input ports with weights
|
||||
weight_ports = set()
|
||||
result_ports = set()
|
||||
start_ports = set()
|
||||
w_types_to_in_port_dict = MarkSubGraphsWithCorrectLayout.get_weighted_layer_type_to_in_weights_port()
|
||||
for node in nodes:
|
||||
node_type = node.soft_get('type', 'unknown')
|
||||
if node_type not in weighted_types:
|
||||
if node_type in ['Const', 'Parameter', 'ShapeOf']:
|
||||
if node_type not in w_types_to_in_port_dict.keys():
|
||||
if node_type in ['Const', 'Parameter', 'ShapeOf', 'ExtractImagePatches']:
|
||||
start_ports.add(node.out_port(0))
|
||||
continue
|
||||
weight_port_idx = weighted_layer_type_to_in_weights_port[node_type](node)
|
||||
weight_port_idx = w_types_to_in_port_dict[node_type](node)
|
||||
assert node.is_in_port_connected(weight_port_idx), \
|
||||
'Unexpected port configuration of {} node with name=`{}`'.format(node_type,
|
||||
node.soft_get('name', node.id))
|
||||
weight_ports.add(node.in_port(weight_port_idx))
|
||||
for result in graph.get_op_nodes(type='Result'):
|
||||
result_ports.update(result.in_ports().values())
|
||||
|
||||
# collect all sub-graphs that start with Constant/Parameter/ShapeOf and end at in_port as weights
|
||||
ports, nodes = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports)
|
||||
return ports, nodes
|
||||
# collect all sub-graphs that start with Constant/Parameter/ShapeOf/ExtractImagePatches and end at in_port as
|
||||
# weights
|
||||
ports_w, nodes_w = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports)
|
||||
# collect all sub-graphs that start with Constant/Parameter/ShapeOf/ExtractImagePatches, end at Result nodes and
|
||||
# do not contain branches that end as weights
|
||||
ports_d, nodes_d = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(
|
||||
result_ports, start_ports, MarkSubGraphsWithCorrectLayout.is_not_weight_port)
|
||||
nodes_dif = nodes_w.difference(nodes_d)
|
||||
nodes_in_w = MarkSubGraphsWithCorrectLayout.insert_permute_inputs_before_dynamic_weights_subgraph(nodes_dif)
|
||||
return ports_w.difference(ports_d), nodes_dif, nodes_in_w
|
||||
|
||||
@staticmethod
|
||||
def get_ports_and_nodes_on_shape_subgraphs(graph):
|
||||
|
@ -171,6 +171,23 @@ def transpose(op_node: Node, port_info: str, input_port: int):
|
||||
op_node.in_port(input_port).get_connection().insert_node(transpose)
|
||||
|
||||
|
||||
def transpose_nchw_to_nhwc(op_node: Node, port_info: str, input_port: int):
    """
    Inserts a Transpose operation into the connection of the input port with index input_port of the op_node.

    The permutation order is built by moving the last axis to position 1, for example, [0, 3, 1, 2] for rank 4.
    NOTE(review): confirm that this order matches the direction stated in the function name.

    :param op_node: the node to insert the Transpose before
    :param port_info: the port description passed to get_node_with_permutation to get the node defining the rank
    :param input_port: the index of the input port of the op_node to insert the Transpose to
    :return: None
    """
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    rank = len(permutation_data_node.shape)
    assert rank >= 4, 'Rank must be 4D or higher for NCHW to NHWC permutation on node {}.'.format(op_node.id)

    # move the last axis right after the first one keeping the order of the other axes
    perm = list(range(rank))
    perm.insert(1, perm.pop())
    perm = int64_array(perm)

    transpose_name = op_node.soft_get('name', op_node.id) + '/Transpose'
    from mo.front.tf.graph_utils import create_op_with_const_inputs  # avoiding recursive imports
    # the local name differs from the module-level "transpose" function to not shadow it
    transpose_node = create_op_with_const_inputs(
        graph, Transpose, {1: perm}, {'name': transpose_name, 'override_output_shape': True})
    op_node.in_port(input_port).get_connection().insert_node(transpose_node)
|
||||
|
||||
|
||||
class PermuteInputs:
|
||||
common_inv_permutation = lambda node, port_info, input_port: axis(node, port_info, input_port)
|
||||
|
||||
@ -179,6 +196,8 @@ class PermuteInputs:
|
||||
'order': lambda node, port_info, input_port: order(node, port_info, input_port),
|
||||
'shape': lambda node, port_info, input_port: shape(node, port_info, input_port),
|
||||
'transpose': lambda node, port_info, input_port: transpose(node, port_info, input_port),
|
||||
'transpose_nchw_to_nhwc': lambda node, port_info, input_port: transpose_nchw_to_nhwc(node, port_info,
|
||||
input_port),
|
||||
}
|
||||
|
||||
def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str):
|
||||
|
Loading…
Reference in New Issue
Block a user