[ MO ] Fixed layout interpretation for 4/5D tensors calculated from ShapeOfs (#1634)
This commit is contained in:
committed by
GitHub
parent
2b474c8a47
commit
2d2a6dbfd8
@@ -17,8 +17,9 @@ from collections import deque
|
||||
from typing import List, Set
|
||||
|
||||
from extensions.middle.InsertLayoutPropagationTransposes import is_output_data_in_correct_layout, \
|
||||
InsertLayoutPropagationTranspose
|
||||
InsertLayoutPropagationTranspose, mark_input_as_in_correct_layout, mark_output_as_in_correct_layout
|
||||
from extensions.ops.gather import Gather
|
||||
from extensions.ops.transpose import Transpose
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
@@ -50,6 +51,8 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern):
|
||||
Searches for input ports of data dependent operations starting from output ports passed to the function.
|
||||
Condition for data dependent operations is absence of node output value.
|
||||
|
||||
Side action: marking the sub-graph as it is in the correct layout
|
||||
|
||||
:param out_ports: list of output ports to start search from
|
||||
:param visited: set of input ports that were visited to avoid visiting them more than once
|
||||
:return: set of input ports of data dependent operations
|
||||
@@ -66,11 +69,14 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern):
|
||||
in_port = deque_of_in_ports.popleft()
|
||||
if in_port in visited:
|
||||
continue
|
||||
|
||||
next_in_ports = self.get_next_in_ports(in_port)
|
||||
if any([port.data.get_value() is None for port in next_in_ports]):
|
||||
end_points_in_ports.add(in_port)
|
||||
else:
|
||||
in_port.__setattr__('input_permutation', None)
|
||||
mark_input_as_in_correct_layout(in_port.node, in_port.idx)
|
||||
for port in next_in_ports:
|
||||
mark_output_as_in_correct_layout(port.get_source().node, port.get_source().idx)
|
||||
deque_of_in_ports.extend(next_in_ports)
|
||||
visited.add(in_port)
|
||||
return end_points_in_ports
|
||||
@@ -96,21 +102,34 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern):
|
||||
op_attrs={'name': name + '/GatherNCHWtoNHWC'})
|
||||
shape.out_port(0).get_connection().insert_node(gather)
|
||||
|
||||
# 2. Inserting Gather to NC* format
|
||||
# 2. Inserting Gather/Transpose to NC* format
|
||||
shape_sub_graph_end_points = self.find_shape_subgraph_endpoints([shape.out_port(0) for shape in shape_ops])
|
||||
for in_port in shape_sub_graph_end_points:
|
||||
name = in_port.node.soft_get('name', in_port.node.id)
|
||||
rank = in_port.data.get_shape().item(0)
|
||||
shape = in_port.data.get_shape()
|
||||
|
||||
should_insert_gather = rank in [4, 5] and not len(in_port.node.soft_get('correct_out_data_layout', {}))
|
||||
should_switch_layout = not any([is_output_data_in_correct_layout(port.node, port.idx)
|
||||
for port in in_port.node.out_ports().values() if not port.disconnected()])
|
||||
should_insert_gather = should_switch_layout and len(shape) == 1 and shape.item(0) in [4, 5]
|
||||
should_insert_transpose = should_switch_layout and len(shape) in [4, 5]
|
||||
|
||||
if should_insert_gather:
|
||||
# we should turn input permutation off to perform it with the following gather insertion
|
||||
in_port.__setattr__('input_permutation', None)
|
||||
index = int64_array([0, rank - 1, *list(range(1, rank - 1))])
|
||||
index = int64_array([0, shape.item(0) - 1, *list(range(1, shape.item(0) - 1))])
|
||||
gather = create_op_with_const_inputs(graph, op=Gather,
|
||||
port_value_dict={1: index, 2: int64_array(0)},
|
||||
op_attrs={'name': name + '/GatherNHWCtoNCHW'})
|
||||
in_port.get_connection().insert_node(gather)
|
||||
elif should_insert_transpose:
|
||||
# we should turn input permutation off to perform it with the following transpose insertion
|
||||
in_port.__setattr__('input_permutation', None)
|
||||
order = int64_array([0, len(shape) - 1, *list(range(1, len(shape) - 1))])
|
||||
transpose = create_op_with_const_inputs(graph, op=Transpose, port_value_dict={1: order},
|
||||
op_attrs={'name': name + '/TransposeNHWCtoNCHW',
|
||||
'override_output_shape': True})
|
||||
mark_input_as_in_correct_layout(transpose, 0)
|
||||
mark_output_as_in_correct_layout(transpose, 0)
|
||||
in_port.get_connection().insert_node(transpose)
|
||||
else:
|
||||
continue # data is layout independent
|
||||
|
||||
gather = create_op_with_const_inputs(graph, op=Gather, port_value_dict={1: index, 2: int64_array(0)},
|
||||
op_attrs={'name': name + '/GatherNHWCtoNCHW'})
|
||||
in_port.get_connection().insert_node(gather)
|
||||
|
||||
@@ -62,13 +62,6 @@ class Broadcast(Op):
|
||||
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
|
||||
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
|
||||
|
||||
if node.mode == 'numpy':
|
||||
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
|
||||
elif node.mode == 'bidirectional':
|
||||
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
|
||||
else:
|
||||
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
|
||||
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
|
||||
|
||||
if input_value is not None and not node.has_and_set('stop_value_propagation'):
|
||||
@@ -76,3 +69,12 @@ class Broadcast(Op):
|
||||
node.out_port(0).data.set_value(uni_directional_broadcasting(input_value, target_shape))
|
||||
elif node.mode == 'bidirectional':
|
||||
node.out_port(0).data.set_value(bi_directional_broadcasting(input_value, target_shape))
|
||||
else:
|
||||
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
|
||||
else:
|
||||
if node.mode == 'numpy':
|
||||
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
|
||||
elif node.mode == 'bidirectional':
|
||||
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
|
||||
else:
|
||||
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
|
||||
|
||||
Reference in New Issue
Block a user