[MO] Replacing StridedSlice with Squeeze/Unsqueeze (#6693)

* added reinterp_shape parameter to tf ss extractor

* removed reinterp_shape

* added transformation to replace ss

* updated bom

* fix for e2e tests

* updated a case when shrink_axis_mask and new_axis_mask are both initialized

* unittests

* added comments

* updated graph_condition

* comments resolving

* updated the case, when shrink_axis_mask and new_axis_mask are both initialized

* added layer tests for squeeze/unsqueeze cases

* remove case when shrink and new axis masks are both set
This commit is contained in:
Yegor Kruglov
2021-08-24 13:19:40 +03:00
committed by GitHub
parent de46168e98
commit 14dcd43c32
4 changed files with 263 additions and 0 deletions

View File

@@ -640,6 +640,7 @@ extensions/middle/sparse_reshape.py
extensions/middle/split_tdnn_memoryoffset.py
extensions/middle/SplitConcatPairToInterpolate.py
extensions/middle/StridedSliceNormalizer.py
extensions/middle/StridedSliceReplacer.py
extensions/middle/SwapAxesMiddleReplacer.py
extensions/middle/TensorIterator_utils.py
extensions/middle/TensorIteratorBackEdge.py

View File

@@ -0,0 +1,64 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose
from extensions.middle.StridedSliceNormalizer import StridedSliceNormalizer
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph, rename_nodes, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.squeeze import Squeeze
from mo.ops.unsqueeze import Unsqueeze
def replace_strided_slice(node: Node, mask: np.ndarray, op: callable):
node_name = node.soft_get('name', node.id)
axes = np.where(mask == 1)[0]
new_node = create_op_node_with_second_input(node.graph, op, int64_array(axes))
node.in_port(0).get_connection().set_destination(new_node.in_port(0))
node.out_port(0).get_connection().set_source(new_node.out_port(0))
rename_nodes([(node, node_name + '/ShouldBeDeleted'), (new_node, node_name)])
node.graph.remove_node(node.id)
class ReplaceStridedSliceWithSqueezeUnsqueeze(MiddleReplacementPattern):
r"""
The transformation replaces StridedSlice with a Squeeze/Unsqueeze node if StridedSlice executes like a Squeeze/Unsqueeze
and does not slice values. This is necessary if StridedSlice is to be executed in original N(D)HWC layout, because
the operation does not have reinterp_shape attribute and MO can not insert NC(D)HW -> N(D)HWC Transpose in
extensions/middle/InsertLayoutPropagationTransposes.py.
"""
enabled = True
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
def run_before(self):
return [InsertLayoutPropagationTranspose]
def run_after(self):
return [StridedSliceNormalizer]
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='StridedSlice'):
input_shape = node.in_port(0).data.get_shape()
output_shape = node.out_port(0).data.get_shape()
if np.prod(input_shape) != np.prod(output_shape):
continue
shrink_axis_mask = node.soft_get('shrink_axis_mask', np.zeros(len(input_shape)))
new_axis_mask = node.soft_get('new_axis_mask', np.zeros(len(input_shape)))
is_shrink_axis_mask = any(x == 1 for x in shrink_axis_mask)
is_new_axis_mask = any(x == 1 for x in new_axis_mask)
if is_shrink_axis_mask and is_new_axis_mask:
# TODO: make it in a separate ticket
continue
elif is_shrink_axis_mask and not is_new_axis_mask:
replace_strided_slice(node, shrink_axis_mask, Squeeze)
elif not is_shrink_axis_mask and is_new_axis_mask:
replace_strided_slice(node, new_axis_mask, Unsqueeze)

View File

@@ -0,0 +1,108 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from extensions.middle.StridedSliceReplacer import ReplaceStridedSliceWithSqueezeUnsqueeze
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import regular_op_with_shaped_data, regular_op_with_empty_data, shaped_const_with_data, \
result, connect, build_graph
nodes = {
**regular_op_with_shaped_data('input', [1, 3, 5, 5], {'type': 'Parameter', 'op': 'Parameter'}),
**regular_op_with_empty_data('strided_slice', {'type': 'StridedSlice', 'op': 'StridedSlice',
'begin_mask': [0, 0, 0, 0], 'end_mask': [0, 0, 0, 0]}),
**shaped_const_with_data('begin', [4]),
**shaped_const_with_data('end', [4]),
**result('result'),
**regular_op_with_empty_data('squeeze', {'type': 'Squeeze', 'op': 'Squeeze'}),
**shaped_const_with_data('squeeze_axes', None),
**regular_op_with_empty_data('unsqueeze', {'type': 'Unsqueeze', 'op': 'Unsqueeze'}),
**shaped_const_with_data('unsqueeze_axes', None)
}
pattern_edges = [
*connect('input', '0:strided_slice'),
*connect('begin', '1:strided_slice'),
*connect('end', '2:strided_slice'),
*connect('strided_slice', 'result')
]
class TestStridedSliceReplacer(unittest.TestCase):
def test_negative_different_input_and_output_shapes(self):
graph = build_graph(
nodes_attrs=nodes,
edges=pattern_edges,
update_attributes={
'strided_slice_d': {'shape': [1, 3, 3, 3]}
},
nodes_with_edges_only=True
)
ref_graph = graph.copy()
ReplaceStridedSliceWithSqueezeUnsqueeze().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_replace_with_squeeze(self):
graph = build_graph(
nodes_attrs=nodes,
edges=pattern_edges,
update_attributes={
'strided_slice': {'shrink_axis_mask': [1, 0, 0, 0], 'new_axis_mask': [0, 0, 0, 0]},
'strided_slice_d': {'shape': [3, 5, 5]}
},
nodes_with_edges_only=True
)
ref_graph = build_graph(
nodes_attrs=nodes,
edges=[
*connect('input', '0:squeeze'),
*connect('squeeze_axes', '1:squeeze'),
*connect('squeeze', 'result')
],
update_attributes={
'squeeze_axes_d': {'value': [0]},
'squeeze_d': {'shape': [3, 5, 5]}
},
nodes_with_edges_only=True
)
ReplaceStridedSliceWithSqueezeUnsqueeze().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_replace_with_unsqueeze(self):
graph = build_graph(
nodes_attrs=nodes,
edges=pattern_edges,
update_attributes={
'strided_slice': {'shrink_axis_mask': [0, 0, 0, 0], 'new_axis_mask': [1, 0, 0, 0]},
'strided_slice_d': {'shape': [1, 1, 3, 5, 5]}
},
nodes_with_edges_only=True
)
ref_graph = build_graph(
nodes_attrs=nodes,
edges=[
*connect('input', '0:unsqueeze'),
*connect('unsqueeze_axes', '1:unsqueeze'),
*connect('unsqueeze', 'result')
],
update_attributes={
'unsqueeze_axes_d': {'value': [0]},
'unsqueeze_d': {'shape': [1, 1, 3, 5, 5]}
},
nodes_with_edges_only=True
)
ReplaceStridedSliceWithSqueezeUnsqueeze().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@@ -0,0 +1,90 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
class TestStridedSlice(CommonTFLayerTest):
@staticmethod
def create_strided_slice_net(input_shape, begin, end, strides, begin_mask, end_mask, ellipsis_mask,
new_axis_mask, shrink_axis_mask, ir_version):
#
# Create Tensorflow model
#
import tensorflow as tf
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
input_node = tf.compat.v1.placeholder(tf.float32, input_shape, 'Input')
strided_slice = tf.compat.v1.strided_slice(input_node, begin=begin, end=end, strides=strides,
begin_mask=begin_mask, end_mask=end_mask,
ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
ref_net = None
return tf_net, ref_net
test_squeeze_data = [
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[5, 1], begin=[0, 0], end=[5, 1], strides=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[1, 1, 3], begin=[0, 0, 0], end=[1, 1, 3], strides=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
dict(input_shape=[1, 5, 1], begin=[0, 0, 0], end=[1, 5, 1], strides=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[1, 1, 5, 3], begin=[0, 0, 0, 0], end=[1, 1, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
dict(input_shape=[1, 5, 1, 3], begin=[0, 0, 0, 0], end=[1, 5, 1, 3], strides=[1, 1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
dict(input_shape=[1, 5, 5, 1], begin=[0, 0, 0, 0], end=[1, 5, 1, 1], strides=[1, 1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=8),
dict(input_shape=[1, 1, 5, 5, 3], begin=[0, 0, 0, 0, 0], end=[1, 1, 5, 5, 3], strides=[1, 1, 1, 1, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=3),
dict(input_shape=[1, 5, 1, 5, 3], begin=[0, 0, 0, 0, 0], end=[1, 5, 1, 5, 3], strides=[1, 1, 1, 1, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=5),
dict(input_shape=[1, 5, 1, 5, 1], begin=[0, 0, 0, 0, 0], end=[1, 5, 1, 5, 1], strides=[1, 1, 1, 1, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=21),
]
@pytest.mark.parametrize('params', test_squeeze_data)
@pytest.mark.nightly
def test_strided_slice_replace_with_squeeze(self, params, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_strided_slice_net(**params, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir)
test_unsqueeze_data = [
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=1, shrink_axis_mask=0),
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=5, shrink_axis_mask=0),
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=8, shrink_axis_mask=0),
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=2, shrink_axis_mask=0),
]
@pytest.mark.parametrize('params', test_unsqueeze_data)
@pytest.mark.nightly
def test_strided_slice_replace_with_unsqueeze(self, params, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_strided_slice_net(**params, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir)