[ MO ] Fixed layout interpretation for 4/5D tensors calculated from ShapeOfs (#1634)

This commit is contained in:
Evgenya Stepyreva
2020-08-11 09:34:04 +03:00
committed by GitHub
parent 2b474c8a47
commit 2d2a6dbfd8
2 changed files with 38 additions and 17 deletions

View File

@@ -17,8 +17,9 @@ from collections import deque
from typing import List, Set
from extensions.middle.InsertLayoutPropagationTransposes import is_output_data_in_correct_layout, \
InsertLayoutPropagationTranspose
InsertLayoutPropagationTranspose, mark_input_as_in_correct_layout, mark_output_as_in_correct_layout
from extensions.ops.gather import Gather
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
@@ -50,6 +51,8 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern):
Searches for input ports of data dependent operations starting from output ports passed to the function.
Condition for data dependent operations is absence of node output value.
Side action: marking the sub-graph as it is in the correct layout
:param out_ports: list of output ports to start search from
:param visited: set of input ports that were visited to avoid visiting them more than once
:return: set of input ports of data dependent operations
@@ -66,11 +69,14 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern):
in_port = deque_of_in_ports.popleft()
if in_port in visited:
continue
next_in_ports = self.get_next_in_ports(in_port)
if any([port.data.get_value() is None for port in next_in_ports]):
end_points_in_ports.add(in_port)
else:
in_port.__setattr__('input_permutation', None)
mark_input_as_in_correct_layout(in_port.node, in_port.idx)
for port in next_in_ports:
mark_output_as_in_correct_layout(port.get_source().node, port.get_source().idx)
deque_of_in_ports.extend(next_in_ports)
visited.add(in_port)
return end_points_in_ports
@@ -96,21 +102,34 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern):
op_attrs={'name': name + '/GatherNCHWtoNHWC'})
shape.out_port(0).get_connection().insert_node(gather)
# 2. Inserting Gather to NC* format
# 2. Inserting Gather/Transpose to NC* format
shape_sub_graph_end_points = self.find_shape_subgraph_endpoints([shape.out_port(0) for shape in shape_ops])
for in_port in shape_sub_graph_end_points:
name = in_port.node.soft_get('name', in_port.node.id)
rank = in_port.data.get_shape().item(0)
shape = in_port.data.get_shape()
should_insert_gather = rank in [4, 5] and not len(in_port.node.soft_get('correct_out_data_layout', {}))
should_switch_layout = not any([is_output_data_in_correct_layout(port.node, port.idx)
for port in in_port.node.out_ports().values() if not port.disconnected()])
should_insert_gather = should_switch_layout and len(shape) == 1 and shape.item(0) in [4, 5]
should_insert_transpose = should_switch_layout and len(shape) in [4, 5]
if should_insert_gather:
# we should turn input permutation off to perform it with the following gather insertion
in_port.__setattr__('input_permutation', None)
index = int64_array([0, rank - 1, *list(range(1, rank - 1))])
index = int64_array([0, shape.item(0) - 1, *list(range(1, shape.item(0) - 1))])
gather = create_op_with_const_inputs(graph, op=Gather,
port_value_dict={1: index, 2: int64_array(0)},
op_attrs={'name': name + '/GatherNHWCtoNCHW'})
in_port.get_connection().insert_node(gather)
elif should_insert_transpose:
# we should turn input permutation off to perform it with the following transpose insertion
in_port.__setattr__('input_permutation', None)
order = int64_array([0, len(shape) - 1, *list(range(1, len(shape) - 1))])
transpose = create_op_with_const_inputs(graph, op=Transpose, port_value_dict={1: order},
op_attrs={'name': name + '/TransposeNHWCtoNCHW',
'override_output_shape': True})
mark_input_as_in_correct_layout(transpose, 0)
mark_output_as_in_correct_layout(transpose, 0)
in_port.get_connection().insert_node(transpose)
else:
continue # data is layout independent
gather = create_op_with_const_inputs(graph, op=Gather, port_value_dict={1: index, 2: int64_array(0)},
op_attrs={'name': name + '/GatherNHWCtoNCHW'})
in_port.get_connection().insert_node(gather)

View File

@@ -62,13 +62,6 @@ class Broadcast(Op):
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
if node.mode == 'numpy':
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
elif node.mode == 'bidirectional':
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
else:
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
if input_value is not None and not node.has_and_set('stop_value_propagation'):
@@ -76,3 +69,12 @@ class Broadcast(Op):
node.out_port(0).data.set_value(uni_directional_broadcasting(input_value, target_shape))
elif node.mode == 'bidirectional':
node.out_port(0).data.set_value(bi_directional_broadcasting(input_value, target_shape))
else:
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
else:
if node.mode == 'numpy':
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
elif node.mode == 'bidirectional':
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
else:
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))