[MO] Replacing StridedSlice with Squeeze/Unsqueeze (#6693)
* added reinterp_shape parameter to tf ss extractor * removed reinterp_shape * added transformation to replace ss * updated bom * fix for e2e tests * updated a case when shrink_axis_mask and new_axis_mask are both initialized * unittests * added comments * updated graph_condition * comments resolving * updated the case, when shrink_axis_mask and new_axis_mask are both initialized * added layer tests for squeeze/unsqueeze cases * remove case when shrink and new axis masks are both set
This commit is contained in:
@@ -640,6 +640,7 @@ extensions/middle/sparse_reshape.py
|
||||
extensions/middle/split_tdnn_memoryoffset.py
|
||||
extensions/middle/SplitConcatPairToInterpolate.py
|
||||
extensions/middle/StridedSliceNormalizer.py
|
||||
extensions/middle/StridedSliceReplacer.py
|
||||
extensions/middle/SwapAxesMiddleReplacer.py
|
||||
extensions/middle/TensorIterator_utils.py
|
||||
extensions/middle/TensorIteratorBackEdge.py
|
||||
|
||||
64
model-optimizer/extensions/middle/StridedSliceReplacer.py
Normal file
64
model-optimizer/extensions/middle/StridedSliceReplacer.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose
|
||||
from extensions.middle.StridedSliceNormalizer import StridedSliceNormalizer
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, rename_nodes, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.squeeze import Squeeze
|
||||
from mo.ops.unsqueeze import Unsqueeze
|
||||
|
||||
|
||||
def replace_strided_slice(node: Node, mask: np.ndarray, op: callable):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
axes = np.where(mask == 1)[0]
|
||||
new_node = create_op_node_with_second_input(node.graph, op, int64_array(axes))
|
||||
node.in_port(0).get_connection().set_destination(new_node.in_port(0))
|
||||
node.out_port(0).get_connection().set_source(new_node.out_port(0))
|
||||
|
||||
rename_nodes([(node, node_name + '/ShouldBeDeleted'), (new_node, node_name)])
|
||||
node.graph.remove_node(node.id)
|
||||
|
||||
|
||||
class ReplaceStridedSliceWithSqueezeUnsqueeze(MiddleReplacementPattern):
|
||||
r"""
|
||||
The transformation replaces StridedSlice with a Squeeze/Unsqueeze node if StridedSlice executes like a Squeeze/Unsqueeze
|
||||
and does not slice values. This is necessary if StridedSlice is to be executed in original N(D)HWC layout, because
|
||||
the operation does not have reinterp_shape attribute and MO can not insert NC(D)HW -> N(D)HWC Transpose in
|
||||
extensions/middle/InsertLayoutPropagationTransposes.py.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
|
||||
|
||||
def run_before(self):
|
||||
return [InsertLayoutPropagationTranspose]
|
||||
|
||||
def run_after(self):
|
||||
return [StridedSliceNormalizer]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(op='StridedSlice'):
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
output_shape = node.out_port(0).data.get_shape()
|
||||
|
||||
if np.prod(input_shape) != np.prod(output_shape):
|
||||
continue
|
||||
|
||||
shrink_axis_mask = node.soft_get('shrink_axis_mask', np.zeros(len(input_shape)))
|
||||
new_axis_mask = node.soft_get('new_axis_mask', np.zeros(len(input_shape)))
|
||||
|
||||
is_shrink_axis_mask = any(x == 1 for x in shrink_axis_mask)
|
||||
is_new_axis_mask = any(x == 1 for x in new_axis_mask)
|
||||
|
||||
if is_shrink_axis_mask and is_new_axis_mask:
|
||||
# TODO: make it in a separate ticket
|
||||
continue
|
||||
elif is_shrink_axis_mask and not is_new_axis_mask:
|
||||
replace_strided_slice(node, shrink_axis_mask, Squeeze)
|
||||
elif not is_shrink_axis_mask and is_new_axis_mask:
|
||||
replace_strided_slice(node, new_axis_mask, Unsqueeze)
|
||||
@@ -0,0 +1,108 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.middle.StridedSliceReplacer import ReplaceStridedSliceWithSqueezeUnsqueeze
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import regular_op_with_shaped_data, regular_op_with_empty_data, shaped_const_with_data, \
|
||||
result, connect, build_graph
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', [1, 3, 5, 5], {'type': 'Parameter', 'op': 'Parameter'}),
|
||||
**regular_op_with_empty_data('strided_slice', {'type': 'StridedSlice', 'op': 'StridedSlice',
|
||||
'begin_mask': [0, 0, 0, 0], 'end_mask': [0, 0, 0, 0]}),
|
||||
**shaped_const_with_data('begin', [4]),
|
||||
**shaped_const_with_data('end', [4]),
|
||||
**result('result'),
|
||||
|
||||
**regular_op_with_empty_data('squeeze', {'type': 'Squeeze', 'op': 'Squeeze'}),
|
||||
**shaped_const_with_data('squeeze_axes', None),
|
||||
|
||||
**regular_op_with_empty_data('unsqueeze', {'type': 'Unsqueeze', 'op': 'Unsqueeze'}),
|
||||
**shaped_const_with_data('unsqueeze_axes', None)
|
||||
}
|
||||
|
||||
pattern_edges = [
|
||||
*connect('input', '0:strided_slice'),
|
||||
*connect('begin', '1:strided_slice'),
|
||||
*connect('end', '2:strided_slice'),
|
||||
*connect('strided_slice', 'result')
|
||||
]
|
||||
|
||||
|
||||
class TestStridedSliceReplacer(unittest.TestCase):
|
||||
|
||||
def test_negative_different_input_and_output_shapes(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs=nodes,
|
||||
edges=pattern_edges,
|
||||
update_attributes={
|
||||
'strided_slice_d': {'shape': [1, 3, 3, 3]}
|
||||
},
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
ref_graph = graph.copy()
|
||||
|
||||
ReplaceStridedSliceWithSqueezeUnsqueeze().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_replace_with_squeeze(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs=nodes,
|
||||
edges=pattern_edges,
|
||||
update_attributes={
|
||||
'strided_slice': {'shrink_axis_mask': [1, 0, 0, 0], 'new_axis_mask': [0, 0, 0, 0]},
|
||||
'strided_slice_d': {'shape': [3, 5, 5]}
|
||||
},
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=nodes,
|
||||
edges=[
|
||||
*connect('input', '0:squeeze'),
|
||||
*connect('squeeze_axes', '1:squeeze'),
|
||||
*connect('squeeze', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'squeeze_axes_d': {'value': [0]},
|
||||
'squeeze_d': {'shape': [3, 5, 5]}
|
||||
},
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
ReplaceStridedSliceWithSqueezeUnsqueeze().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_replace_with_unsqueeze(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs=nodes,
|
||||
edges=pattern_edges,
|
||||
update_attributes={
|
||||
'strided_slice': {'shrink_axis_mask': [0, 0, 0, 0], 'new_axis_mask': [1, 0, 0, 0]},
|
||||
'strided_slice_d': {'shape': [1, 1, 3, 5, 5]}
|
||||
},
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=nodes,
|
||||
edges=[
|
||||
*connect('input', '0:unsqueeze'),
|
||||
*connect('unsqueeze_axes', '1:unsqueeze'),
|
||||
*connect('unsqueeze', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'unsqueeze_axes_d': {'value': [0]},
|
||||
'unsqueeze_d': {'shape': [1, 1, 3, 5, 5]}
|
||||
},
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
ReplaceStridedSliceWithSqueezeUnsqueeze().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
90
tests/layer_tests/tensorflow_tests/test_tf_StridedSlice.py
Normal file
90
tests/layer_tests/tensorflow_tests/test_tf_StridedSlice.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestStridedSlice(CommonTFLayerTest):
|
||||
|
||||
@staticmethod
|
||||
def create_strided_slice_net(input_shape, begin, end, strides, begin_mask, end_mask, ellipsis_mask,
|
||||
new_axis_mask, shrink_axis_mask, ir_version):
|
||||
|
||||
#
|
||||
# Create Tensorflow model
|
||||
#
|
||||
import tensorflow as tf
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
with tf.compat.v1.Session() as sess:
|
||||
input_node = tf.compat.v1.placeholder(tf.float32, input_shape, 'Input')
|
||||
strided_slice = tf.compat.v1.strided_slice(input_node, begin=begin, end=end, strides=strides,
|
||||
begin_mask=begin_mask, end_mask=end_mask,
|
||||
ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
|
||||
shrink_axis_mask=shrink_axis_mask)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
ref_net = None
|
||||
return tf_net, ref_net
|
||||
|
||||
test_squeeze_data = [
|
||||
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
|
||||
dict(input_shape=[5, 1], begin=[0, 0], end=[5, 1], strides=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
|
||||
dict(input_shape=[1, 1, 3], begin=[0, 0, 0], end=[1, 1, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
|
||||
dict(input_shape=[1, 5, 1], begin=[0, 0, 0], end=[1, 5, 1], strides=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
|
||||
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
|
||||
dict(input_shape=[1, 1, 5, 3], begin=[0, 0, 0, 0], end=[1, 1, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
|
||||
dict(input_shape=[1, 5, 1, 3], begin=[0, 0, 0, 0], end=[1, 5, 1, 3], strides=[1, 1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
|
||||
dict(input_shape=[1, 5, 5, 1], begin=[0, 0, 0, 0], end=[1, 5, 1, 1], strides=[1, 1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=8),
|
||||
dict(input_shape=[1, 1, 5, 5, 3], begin=[0, 0, 0, 0, 0], end=[1, 1, 5, 5, 3], strides=[1, 1, 1, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=3),
|
||||
dict(input_shape=[1, 5, 1, 5, 3], begin=[0, 0, 0, 0, 0], end=[1, 5, 1, 5, 3], strides=[1, 1, 1, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=5),
|
||||
dict(input_shape=[1, 5, 1, 5, 1], begin=[0, 0, 0, 0, 0], end=[1, 5, 1, 5, 1], strides=[1, 1, 1, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=21),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize('params', test_squeeze_data)
|
||||
@pytest.mark.nightly
|
||||
def test_strided_slice_replace_with_squeeze(self, params, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_strided_slice_net(**params, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir)
|
||||
|
||||
test_unsqueeze_data = [
|
||||
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=1, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=5, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=8, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=2, shrink_axis_mask=0),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize('params', test_unsqueeze_data)
|
||||
@pytest.mark.nightly
|
||||
def test_strided_slice_replace_with_unsqueeze(self, params, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_strided_slice_net(**params, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir)
|
||||
Reference in New Issue
Block a user