[MO] StridedSlice improvements (#4139)

* fix ss

* successfully converted

* successfully run moved infer and normalizer unit-tests

* successfully rewritten StridedSlice infer unittests

* int64 array

* Successfully converter crash-when-loading, xj_feauture and toy nets (cherry-picked maxpoolV4 and tf_broadcast_ext)

* successfully moved PermuteAttrs to general mechanism

* successfully converted xj_feauture and crash when loading with the new rewritten SS infer

* fixed get_shape_from_slice and moved to common utils

* fixed extending masks and some other

* some refactoring

* fixed extending masks in extractor, fixed licence year and some other code clearing

* corrected a couple of unittests

* fox permute for 5 rank slice and 4 rank inputs/

* WIP

* Added comments

* fixed StridedSlice in ProposalMutation.py

* rechecked shape_infer unittests added some new cases

* added shape_infer unit-tests after StridedSliceNormalizer pass and Permute unit-tests

* corrected unittests

* Applied review comments

* general permutations for inputs implemented, corrected ellipsis unrolling when shrink_axis is at the beginning, some other corrections

* removed code duplication in infer and normalizer, moved 'slices' attr normalizing to StridedSliceNormalizer.py

* removed some code duplication and other minor improvements

* Added tests

* minor corrections

* wider range of unittests added (froze the number)

* review comments applied

* enabled skipped unit-test

* comment corrections

* applied review comments: changed op -> type, added some asserts, corrected comments and other minor corrections

* sorted inputs, updated Supported_Frameworks_Layers.md, some minor
This commit is contained in:
Pavel Esir 2021-02-16 11:48:49 +03:00 committed by GitHub
parent d2548ddb60
commit 22169a05b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 3809 additions and 1108 deletions

View File

@ -233,7 +233,7 @@ Standard TensorFlow\* operations:
| Square| No |
| Squeeze | The case when squeeze axis is not specified is not supported |
| StopGradient | Not needed for shape inference |
| StridedSlice | No |
| StridedSlice | Supported only for constant begin, end, and strides inputs |
| Sub | No |
| Sum | No |
| Swish | No |

View File

@ -606,6 +606,7 @@ extensions/middle/SliceLikeToStridedSlice.py
extensions/middle/sparse_reshape.py
extensions/middle/split_tdnn_memoryoffset.py
extensions/middle/SplitConcatPairToInterpolate.py
extensions/middle/StridedSliceNormalizer.py
extensions/middle/SwapAxesMiddleReplacer.py
extensions/middle/TensorIterator_utils.py
extensions/middle/TensorIteratorBackEdge.py
@ -800,7 +801,6 @@ mo/front/common/partial_infer/multi_box_prior.py
mo/front/common/partial_infer/random_uniform.py
mo/front/common/partial_infer/reshape.py
mo/front/common/partial_infer/roipooling.py
mo/front/common/partial_infer/slice.py
mo/front/common/partial_infer/utils.py
mo/front/common/register_custom_ops.py
mo/front/common/replacement.py

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -64,8 +64,10 @@ class CropToStridedSlice(BackReplacementPattern):
end_mask = axis_mask.copy()
ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
'end_mask': end_mask, 'new_axis_mask': np.array([0]),
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
'end_mask': end_mask,
'new_axis_mask': np.zeros(len(end_mask)),
'shrink_axis_mask': np.zeros(len(end_mask)),
'ellipsis_mask': np.zeros(len(end_mask))}).create_node()
if len(node.in_nodes()) == 2 and node.has_valid('offset'):
# Crop Type 1
@ -112,7 +114,7 @@ class CropToStridedSlice(BackReplacementPattern):
source = node.in_port(0).get_connection().get_source()
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
'name': ss.name + '/stride'}).create_node()
'name': ss.name + '/stride'}).create_node()
source.connect(ss.in_port(0))
begin.out_port(0).connect(ss.in_port(1))

View File

@ -60,9 +60,9 @@ class ProposalMutation(BackReplacementPattern):
{'name': 'cropped_im_info',
'begin_mask': int64_array([1, 1]),
'end_mask': int64_array([1, 1]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0]),
'new_axis_mask': int64_array([0, 0]),
'shrink_axis_mask': int64_array([0, 0]),
'ellipsis_mask': int64_array([0, 0]),
'override_output_shape': True,
})

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,11 +14,9 @@
limitations under the License.
"""
from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
from extensions.back.CropToStridedSlice import CropToStridedSlice
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.graph import Graph
class StridedSliceMasksNormalizer(BackReplacementPattern):
@ -26,20 +24,13 @@ class StridedSliceMasksNormalizer(BackReplacementPattern):
force_clean_up = True
def run_after(self):
from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
from extensions.back.CropToStridedSlice import CropToStridedSlice
return [CropToStridedSlice, DeconvolutionNormalizer]
@staticmethod
def pattern():
return dict(
nodes=[
('strided_slice', dict(type='StridedSlice'))
],
edges=[]
)
def replace_pattern(self, graph: Graph, match: [str, Node]):
node = match['strided_slice']
assert node.has_valid('begin_mask')
assert node.has_valid('end_mask')
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
node.end_mask = int64_array([1 - i for i in node.end_mask])
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='StridedSlice'):
assert node.has_valid('begin_mask')
assert node.has_valid('end_mask')
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
node.end_mask = int64_array([1 - i for i in node.end_mask])

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -133,15 +133,14 @@ class ApplyPermutation(MiddleReplacementPattern):
input_permutations = [(in_port, edge_attrs['input_permutation']) for in_port, edge_attrs in
node.in_edges().items() if edge_attrs.get('input_permutation') is not None]
for in_port, input_perm in input_permutations:
permutation, port_info = input_perm
permutation, port_info, check_shape = input_perm
direction, port = port_info.split(':')
port = int(port)
port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port)
permutation_data_node = get_node_with_permutation(node, port_info)
if permutation_data_node.has_and_set('permutation') and \
not is_input_data_in_correct_layout(node, in_port) and \
len(port_to_check.data.get_shape()) >= 4:
not is_input_data_in_correct_layout(node, in_port) and check_shape(port_to_check):
permutation(node, port_info, in_port)
if node.has_and_set('need_shape_inference'):
node.infer(node)

View File

@ -69,7 +69,8 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
enabled = True
def run_after(self):
return [ConvertSlice]
from extensions.middle.StridedSliceNormalizer import StridedSliceNormalizer
return [ConvertSlice, StridedSliceNormalizer]
def run_before(self):
from extensions.middle.pass_separator import MiddleFinish

View File

@ -0,0 +1,251 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.split import VariadicSplit
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import PermuteInputs
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.op import PermuteAttrs
from mo.ops.strided_slice import StridedSlice
from mo.utils.error import Error
class StridedSliceNormalizer(MiddleReplacementPattern):
"""
StridedSlice is not normal if it cannot be permuted by ApplyPermutations. This normalizer
inserts blank colons ':' in slice expression so that it can be correctly permuted
from NHWC to NCHW layout. It changes masks and inserts blank begin, end and strides values.
In order to successfully handle StridedSlice in ShapeOf subgraphs
changes must be done by inserting nodes not just by overwriting constants.
StridedSlice is not normal in 2 cases:
1. rank of a slice expression is less than rank of input tensor
2. there is an ellipsis
1st case example
BEFORE:
|
begin
value=[0, 0]
|
AFTER:
|
begin Const
value=[0, 0] value=[0, 0]
\ /
\ /
Concat
value=[0, 0, 0, 0]
|
Input of a shape [16, 100, 100, 3] in NHWC layout, output = input[:, 0:50].
StridedSlice will be extended to input[:, 0:50, :, :].
After permutation to NCHW output = input[:, :, 0:50, :].
Example for 'begin' input transformation is shown above on the picture.
'end' and 'strides' inputs will be transformed the same way.
2nd case example
BEFORE:
|
begin
value=[1, 50]
|
AFTER:
|
begin
value=[1, 1, 1]
|
VariadicSplit
/ \
/ \
/ Const \
\ val=[0, 0] /
\ | /
\ | /
Concat
value=[1, 0, 0, 1, 1]
|
Input of a shape [16, 10, 100, 100, 3] in NDHWC layout, output = input[1:4, ..., 1:51, 1:3],
output_shape = [3, 10, 100, 50, 2]. In order to perform correct layout permutation
ellipsis must be replaced with colons: input[1:4, ..., 1:51, 1:3] => input[1:4, :, :, 1:51, 1:3].
After layour permutation input[1:4, 1:3, :, : 1:5].
In the places of colons blank begin, end and strides values should be inserted.
In order to do that we split input and insert blank zeros to the middle.
Example for 'begin' input transformation is shown above on the picture.
'end' and 'strides' inputs will be transformed the same way.
"""
enabled = True
def run_before(self):
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
return [LayoutChangeForConstantShapePaths]
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='StridedSlice'):
StridedSliceNormalizer.normalize_strided_slice(graph, node)
PermuteAttrs.create_permute_attrs(node,
attrs=[('begin_mask', 'input:0'), # but indeed depends from slice_rank
('end_mask', 'input:0'),
('new_axis_mask', 'input:0'),
('shrink_axis_mask', 'input:0'),
('ellipsis_mask', 'input:0')])
# StridedSliceNormalizer inserted nodes that changed original begin, end, and strides data nodes
# Until now it was not possible to set correct permutations
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'slice', 'dim_size')
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:2', 'slice', 'dim_size')
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size')
@staticmethod
def normalize_strided_slice(graph: Graph, node: Node):
input_shape = node.in_port(0).data.get_shape()
input_rank = len(input_shape)
begin, _, _ = StridedSlice.validate_inputs_and_get_args(node)
slice_rank = len(begin)
StridedSlice.align_mask_with_slice_rank(node, slice_rank) # if StridedSlice is created after partial_infer
StridedSliceNormalizer.normalize_slices_attr(node)
num_insertions = input_rank - slice_rank + np.count_nonzero(node.new_axis_mask)
assert num_insertions >= 0, 'slice_rank - num_new_axis must <= input rank. Got instead: ' \
'input_rank = {}, slice_rank = {}, num_new_axis = {}'. \
format(input_rank, slice_rank, np.count_nonzero(node.new_axis_mask))
if np.any(node.ellipsis_mask):
assert np.count_nonzero(node.ellipsis_mask) == 1, 'only one ellipsis_mask nonzero value is allowed'
ellipsis_start = np.nonzero(node.ellipsis_mask)[0][0]
# since we don't expect values in begin and end: take the whole range along ellipsis_start
node.begin_mask[ellipsis_start] = 0
node.end_mask[ellipsis_start] = 0
node.ellipsis_mask[ellipsis_start] = 0
insertion_start_idx = ellipsis_start + 1
StridedSliceNormalizer.unroll_ellipsis_for_inputs(graph, node, ellipsis_start, num_insertions)
elif num_insertions > 0:
insertion_start_idx = slice_rank # insert blank values to mask ends
StridedSliceNormalizer.extend_inputs(node, num_insertions)
if num_insertions > 0:
# insert blank values for ellipsis unrolling and extending
for mask_name in StridedSlice.get_mask_names():
node[mask_name] = np.insert(node[mask_name], insertion_start_idx, [0] * num_insertions).astype(int)
@staticmethod
def unroll_ellipsis_for_inputs(graph: Graph, node: Node, ellipsis_start: int, num_insertions: int):
node_name = node.soft_get('name', node.id)
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
'value': int64_array(blank_values_arr)}).create_node()
if i == 3 and node.in_port(3).disconnected():
continue # no need to extend strides if they are not connected
concat_in_ports_count = 3 if ellipsis_start != 0 else 2
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
'in_ports_count': concat_in_ports_count}).create_node()
if ellipsis_start != 0:
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0),
2: int64_array([ellipsis_start, -1])},
{'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
'out_ports_count': 2})
node.in_port(i).get_connection().set_destination(split.in_port(0))
concat.in_port(0).connect(split.out_port(0))
concat.in_port(1).connect(blank_values_node.out_port(0))
concat.in_port(2).connect(split.out_port(1))
else:
concat.in_port(0).connect(blank_values_node.out_port(0))
node.in_port(i).get_connection().set_destination(concat.in_port(1))
concat.out_port(0).get_connection().set_destination(node.in_port(i))
@staticmethod
def extend_inputs(node: Node, num_insertions: int):
graph = node.graph
node_name = node.soft_get('name', node.id)
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
'value': int64_array(blank_values_arr)}).create_node()
if i == 3 and node.in_port(3).disconnected():
continue # no need to extend strides if they are not connected
if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
# concat already exists
concat = node.in_port(i).get_source().node
last_in_port = max(concat.in_ports().keys())
assert not concat.in_port(last_in_port).disconnected(), 'The last in_port of Concat node {}' \
'should be connected'. \
format(concat.soft_get('name', node.id))
concat.add_input_port(last_in_port + 1)
concat.in_port(last_in_port + 1).connect(blank_values_node.out_port(0))
else:
# have to create concat
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
'in_ports_count': 2}).create_node()
node.in_port(i).get_connection().set_destination(concat.in_port(0))
concat.in_port(1).connect(blank_values_node.out_port(0))
concat.out_port(0).get_connection().set_destination(node.in_port(i))
@staticmethod
def normalize_slices_attr(node: Node):
# removes negative starts, ends and magic numbers from 'slice' attr which is used by ConvertGroupedStridedSlice
slice_rank = len(node['slices'])
data_shape = node.in_port(0).data.get_shape()
node_name = node.soft_get('name', node.id)
if node.is_in_port_connected(3):
strides = node.in_port(3).data.get_value()
if strides is None:
raise Error('StridedSlice operation for node {} supports only constant strides input'.format(node_name))
else:
strides = np.ones(slice_rank)
num_ellipsis_inserts = len(data_shape) - slice_rank + np.count_nonzero(node.new_axis_mask) + 1
res_slices = []
in_idx = 0
for i, s in enumerate(node['slices']):
if node.new_axis_mask[i]:
res_slices.append(slice(0, 1, 1))
elif node.shrink_axis_mask[i]:
res_slices.append(slice(s, s + 1, strides[i])) # need strides if shrink index is negative
elif node.ellipsis_mask[i]:
for idx in range(num_ellipsis_inserts):
res_slices.append(slice(0, data_shape[in_idx], 1))
in_idx += 1
else:
res_slices.append(s)
if not (node.new_axis_mask[i] or node.ellipsis_mask[i]):
res_slices[-1] = slice(*res_slices[-1].indices(data_shape[in_idx])) # convert negative begins/ends
in_idx += 1
node.slices = np.array(res_slices)

File diff suppressed because it is too large Load Diff

View File

@ -1,158 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.utils.error import Error
def tf_strided_slice_infer(node):
if node.in_node(1).value is None or node.in_node(2).value is None:
raise Error('Strided slice layer supports only constant begin and end inputs')
begin_id = node.in_node(1).value.copy()
end_id = node.in_node(2).value.copy()
if len(node.in_nodes()) > 3:
if node.in_node(3).value is None:
raise Error('Strided slice layer supports only constant stride input')
stride = node.in_node(3).value
else:
stride = []
shape = node.in_node(0).shape
if shape is None or any([x < 0 for x in shape]):
return
convert_negative_indices(begin_id, shape)
convert_negative_indices(end_id, shape)
slice_idx = []
dims = np.amax(np.array([len(begin_id), len(end_id), len(stride),
len(node.shrink_axis_mask), len(node.new_axis_mask), len(node.ellipsis_mask),
len(node.begin_mask), len(node.end_mask)]))
# make mask correct length
def extend_mask(in_mask, fin_len, zeros=True):
mask = list(in_mask)
if len(mask) < fin_len:
if zeros:
mask.extend(np.zeros(dims-len(mask), dtype=np.int32))
else:
mask.extend(np.ones(dims-len(mask), dtype=np.int32))
return np.array(mask, dtype=np.int32)
for mask in {'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask'}:
node[mask] = extend_mask(node[mask], dims)
node.begin_mask = extend_mask(node.begin_mask, dims, False)
node.end_mask = extend_mask(node.end_mask, dims, False)
old_idx = 0
ellips_ext = 0
id_em = 0
for idx in range(dims):
if node.new_axis_mask[idx]:
slice_idx.append(np.newaxis)
elif node.ellipsis_mask[idx]:
ellips_ext = len(shape) - (dims - np.count_nonzero(node.new_axis_mask) - 1)
id_em = idx
for i in range(0, ellips_ext):
slice_idx.append(slice(0, shape[old_idx], 1))
old_idx = old_idx + 1
else:
s = stride[idx] if len(stride) > idx else 1
def_beg = 0 if s > 0 else -1
def_end = shape[old_idx] if s > 0 else -shape[old_idx]-1
l = begin_id[idx] if node.begin_mask[idx] and idx < len(begin_id) else def_beg
r = end_id[idx] if node.end_mask[idx] and idx < len(end_id) else def_end
# Check shrink_axis_mask
if node.shrink_axis_mask[idx] and idx < len(shape):
slice_idx.append(slice(l, l+1, s))
else:
slice_idx.append(slice(l, r, s))
old_idx = old_idx + 1
value = node.in_node(0).value if node.in_node(0).value is not None else np.zeros(shape)
# fix for the warning: "FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated use
# `arr[tuple(seq)]` instead of `arr[seq]`"
value = value[tuple(slice_idx)]
for idx, flag in reversed(list(enumerate(node.shrink_axis_mask))):
if flag:
if ellips_ext > 0 and idx > id_em:
idx = idx + ellips_ext - 1
try:
value = np.squeeze(value, idx)
except ValueError:
# ignore this error
continue
for i, s in enumerate(slice_idx):
if s is None:
slice_idx[i] = slice(0, 1, 1)
node['slices'] = np.array(slice_idx)
for attr in ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask'):
node[attr] = np.array(node[attr], dtype=np.int32)
node['force_precision_in_ports'] = {port: 'int64' for port in range(1, len(node.in_nodes()))}
node.out_node().value = value.copy() if node.in_node(0).value is not None else None
node.out_node().shape = np.array(value.shape, dtype=np.int64)
def convert_negative_indices(indices: np.array, shape: np.array):
for ind, value in enumerate(indices):
if value < 0:
indices[ind] += shape[ind]
def mxnet_slice_axis_infer(node):
in_shape = node.in_node(0).shape
node.axis = get_canonical_axis_index(in_shape, node.axis)
slice_axis = node.axis
new_shape = np.array(in_shape, dtype=np.int64)
new_shape[slice_axis] = new_shape[slice_axis] / len(node.out_nodes())
axis_size = in_shape[slice_axis]
if node.offset < 0:
node.offset += axis_size
if not node.dim:
node.dim = axis_size
elif node.dim < 0:
node.dim += axis_size
input_dim = in_shape.size
node.dim = (node.dim - node.offset)
if node.dim > in_shape[slice_axis]:
raise Error(
'{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
'\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
'\nTo overcome, try to edit the original model "end" property of the {0} layer.',
node.name, ','.join(str(i) for i in in_shape), str(node.dim), str(in_shape[slice_axis])
)
for i in range(0, input_dim):
if i == slice_axis:
new_shape[i] = node.dim
else:
new_shape[i] = in_shape[i]
for i in range(0, len(node.out_nodes())):
node.out_node(i)['shape'] = new_shape

View File

@ -1,278 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from mo.front.common.partial_infer.slice import tf_strided_slice_infer, convert_negative_indices, mxnet_slice_axis_infer
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
nodes_attributes = {'node_1': {'value': None, 'kind': 'data'},
'Slice_node': {'type': 'Slice', 'kind': 'op'},
'node_2': {'value': None, 'kind': 'data'},
'node_3': {'value': None, 'kind': 'data'},
'node_4': {'value': None, 'kind': 'data'},
# StridedSlice node with attrs
'sslice_input': {'value': None, 'shape': None, 'kind': 'data'},
'sslice_1': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice'},
'sslice_begin_1': {'value': None, 'shape': None, 'kind': 'data'},
'sslice_end_1': {'value': None, 'shape': None, 'kind': 'data'},
'sslice_stride_1': {'value': None, 'shape': None, 'kind': 'data'},
'sslice_data_1': {'value': None, 'shape': None, 'kind': 'data'},
# TF slice
'tf_slice_input': {'value': None, 'shape': None, 'kind': 'data'},
'tf_slice_begin': {'value': None, 'shape': None, 'kind': 'data'},
'tf_slice_size': {'value': None, 'shape': None, 'kind': 'data'},
'tf_slice': {'kind': 'op'},
'tf_slice_output': {'value': None, 'shape': None, 'kind': 'data'},
'op_output': {'kind': 'op', 'op': 'Result'},
'op_output_1': {'kind': 'op', 'op': 'Result'},
'op_output_2': {'kind': 'op', 'op': 'Result'}
}
tf_slice_edges = [('tf_slice_input', 'tf_slice'), ('tf_slice_begin', 'tf_slice'), ('tf_slice_size', 'tf_slice'),
('tf_slice', 'tf_slice_output')]
class TestTFStridedSliceInfer(unittest.TestCase):
def build_test_graph2(self):
return build_graph(nodes_attributes,
[('sslice_input', 'sslice_1'),
('sslice_begin_1', 'sslice_1'),
('sslice_end_1', 'sslice_1'),
('sslice_stride_1', 'sslice_1'),
('sslice_1', 'sslice_data_1'),
('sslice_data_1', 'op_output')
],
{
'sslice_input': {'value': np.array([1, 34, 34, 62]),
'shape': np.array([3])},
'sslice_begin_1': {'value': np.array([0]), 'shape': np.array([1])},
'sslice_end_1': {'value': np.array([4]), 'shape': np.array([1])},
'sslice_stride_1': {'value': np.array([1]), 'shape': np.array([1])},
'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
'begin_mask': [1], 'end_mask': [1]},
})
def build_test_graph(self):
return build_graph(nodes_attributes,
[('sslice_input', 'sslice_1'),
('sslice_begin_1', 'sslice_1'),
('sslice_end_1', 'sslice_1'),
('sslice_stride_1', 'sslice_1'),
('sslice_1', 'sslice_data_1'),
('sslice_data_1', 'op_output')
],
{
'sslice_input': {'value': None, 'shape': np.array([1, 35, 35, 3])},
'sslice_begin_1': {'value': np.array([0, 0, 0, 0]), 'shape': np.array([4])},
'sslice_end_1': {'value': np.array([1, 34, 30, 2]), 'shape': np.array([4])},
'sslice_stride_1': {'value': np.array([1, 1, 1, 1]),
'shape': np.array([4])},
'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
'begin_mask': [1], 'end_mask': [1]},
})
def build_test_graph_dim_beg(self):
return build_graph(nodes_attributes,
[('sslice_input', 'sslice_1'),
('sslice_begin_1', 'sslice_1'),
('sslice_end_1', 'sslice_1'),
('sslice_stride_1', 'sslice_1'),
('sslice_1', 'sslice_data_1'),
('sslice_data_1', 'op_output')
],
{
'sslice_input': {'value': np.array([[1, 34, 34, 62]]),
'shape': np.array([1, 4])},
'sslice_begin_1': {'value': np.array([0]), 'shape': np.array([1])},
'sslice_end_1': {'value': np.array([4]), 'shape': np.array([1])},
'sslice_stride_1': {'value': np.array([1]), 'shape': np.array([1])},
'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
'begin_mask': [1], 'end_mask': [1]},
})
def test_slice_infer_1(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
def test_slice_infer_2(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.end_mask = [1, 0, 0, 1] # 6
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 2])), 'Wrong output shape detected')
def test_slice_infer_3(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.in_node(1).value = np.array([0, 10, 10, 0])
node.end_mask = [1, 0, 0, 1] # 6
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 25, 25, 2])), 'Wrong output shape detected')
def test_slice_infer_4(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.in_node(1).value = np.array([0, 10, 10, 0])
node.begin_mask = [1, 0, 0, 1] # 6
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
def test_slice_infer_5(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.in_node(1).value = np.array([0, 10, 10, 0])
node.begin_mask = [0, 0, 0, 0] # 15
node.end_mask = [0, 0, 0, 0] # 15
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 3])), 'Wrong output shape detected')
def test_slice_infer_6(self):
graph = self.build_test_graph2()
node = Node(graph, 'sslice_1')
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output value detected')
def test_slice_infer_7(self):
graph = self.build_test_graph2()
node = Node(graph, 'sslice_1')
node.in_node(1).value = np.array([1])
node.in_node(2).value = np.array([3])
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([2])), 'Wrong output shape detected')
self.assertTrue(np.array_equal(node.out_node().value, np.array([34, 34])), 'Wrong output value detected')
def test_slice_infer_8(self):
graph = self.build_test_graph2()
node = Node(graph, 'sslice_1')
node.new_axis_mask = [1]
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 4])), 'Wrong output shape detected')
self.assertTrue(np.array_equal(node.out_node().value, np.array([[1, 34, 34, 62]])),
'Wrong output value detected')
def test_slice_infer_9(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.begin_mask = [0, 0, 0, 0] # 15
node.end_mask = [0, 0, 0, 0] # 15
node.shrink_axis_mask = [1]
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 3])), 'Wrong output shape detected')
def test_slice_infer_10(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.begin_mask = [0, 0, 0, 0] # 15
node.end_mask = [0, 0, 0, 0] # 15
node.shrink_axis_mask = [1, 0, 0, 0]
node.new_axis_mask = [0, 0, 0, 1] # 8
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 1, 3])), 'Wrong output shape detected')
def test_slice_infer_11(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.begin_mask = [0, 0, 0, 0] # 15
node.end_mask = [0, 0, 0, 0] # 15
node.shrink_axis_mask = [1, 0, 1, 0] # 5
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 3])), 'Wrong output shape detected')
def test_slice_infer_12(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
node.begin_mask = [0, 0, 0, 0] # 15
node.end_mask = [0, 0, 0, 0] # 15
node.shrink_axis_mask = [1, 1, 1, 0] # 7
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([3])), 'Wrong output shape detected')
def test_slice_infer_13(self):
graph = self.build_test_graph2()
node = Node(graph, 'sslice_1')
node.in_node(1).value = np.array([1])
node.shrink_axis_mask = [1]
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([])), 'Wrong output shape detected')
self.assertTrue(np.array_equal(node.out_node().value, np.array(34)), 'Wrong output shape detected')
def test_slice_infer_14(self):
graph = self.build_test_graph2()
node = Node(graph, 'sslice_1')
node.in_node(3).value = np.array([-1])
node.end_mask = [0]
node.begin_mask = [0]
node.in_node(0).shape = [4]
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
print(node.out_node().value)
self.assertTrue(np.array_equal(node.out_node().value, np.array([62, 34, 34, 1])), 'Wrong output shape detected')
def test_slice_infer_dim_beg(self):
graph = self.build_test_graph_dim_beg()
node = Node(graph, 'sslice_1')
node.shrink_axis_mask = [1]
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output shape detected')
def test_slice_infer_neg_end(self):
graph = self.build_test_graph()
node = Node(graph, 'sslice_1')
end_node = Node(graph, 'sslice_end_1')
end_node.value = np.array([1, -1, -5, -1])
tf_strided_slice_infer(node)
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
self.assertTrue(np.array_equal(end_node.value, np.array([1, -1, -5, -1])), 'Negative values in end were converted to positive')
class TestConvertNegativeIndices(unittest.TestCase):
def test_convert_negative_indices(self):
dimensions = np.array([3, 4, 8, 10])
indices = np.array([2, 0, -3, -4])
convert_negative_indices(indices, dimensions)
self.assertTrue(np.array_equal(indices, np.array([2, 0, 5, 6])), 'Wrong dimension indices')
class TestMXNetSliceAxisInfer(unittest.TestCase):
def test_slice_axis_infer_layer(self):
graph = build_graph(
{'node_1': {'name': 'data', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Parameter'},
'slice_axis_node': {'name': 'slice_axis_node', 'type': 'sigmoid', 'value': None,
'kind': 'op', 'op': 'slice_axis', },
'node_3': {'name': 'node_3', 'type': 'Identity', 'value': None, 'kind': 'op'},
},
[
('node_1', 'slice_axis_node'),
('slice_axis_node', 'node_3'),
],
{
'node_1': {'shape': np.array([1, 1024, 19, 19])},
'slice_axis_node': {'axis': 1, 'offset': 10, 'dim': 25},
})
slice_axis_node = Node(graph, 'slice_axis_node')
mxnet_slice_axis_infer(slice_axis_node)
res_shape = [1, 15, 19, 19]
for i in range(0, len(graph.node['node_3']['shape'])):
self.assertEqual(graph.node['node_3']['shape'][i], res_shape[i])

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
"""
import logging as log
from typing import Iterable
from typing import Iterable, List, Union
import numpy as np
@ -113,4 +113,34 @@ def broadcast_shape(first_shape, second_shape):
assert a_val == 1 or b_val == 1 or a_val == b_val, "Input shape do not broadcast"
new_val = b_val if a_val == 1 else a_val
new_shape[-i - 1] = new_val
return int64_array(new_shape)
return int64_array(new_shape)
def get_shape_from_slice(input_shape: np.ndarray, slices: List) -> np.ndarray:
"""
Calculate shape of a tensor after slicing without actually creating the resulting tensor.
Is introduced to prevent potentially large memory consumption.
"""
output_shape = []
num_new_axes = np.count_nonzero(list(map(lambda x: x is np.newaxis, slices)))
num_ellipsis_inserts = len(input_shape) - len(slices) + num_new_axes + 1
in_idx = 0
for i, s in enumerate(slices):
if isinstance(s, slice):
output_shape.append(len(range(*s.indices(input_shape[in_idx]))))
in_idx += 1
elif s is np.newaxis:
output_shape.append(1)
elif isinstance(s, int): # shrink_axis
in_idx += 1
elif s is Ellipsis:
for idx in range(num_ellipsis_inserts):
output_shape.append(input_shape[in_idx])
in_idx += 1
else:
raise Exception('Element type of a slice List is unacceptable. '
'Allowed types are: Ellipsis, slice, int, and None. Instead got: '. format(type(s)))
for i in range(in_idx, len(input_shape)):
output_shape.append(input_shape[i])
return int64_array(output_shape)

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,7 +14,10 @@
limitations under the License.
"""
from mo.front.common.partial_infer.slice import mxnet_slice_axis_infer
from mo.front.caffe.extractors.utils import get_canonical_axis_index
import numpy as np
from mo.utils.error import Error
def slice_axis_ext(attrs):
axis = attrs.int("axis", 0)
@ -29,3 +32,40 @@ def slice_axis_ext(attrs):
'infer': mxnet_slice_axis_infer
}
return node_attrs
def mxnet_slice_axis_infer(node):
in_shape = node.in_port(0).data.get_shape()
node.axis = get_canonical_axis_index(in_shape, node.axis)
slice_axis = node.axis
new_shape = np.array(in_shape, dtype=np.int64)
new_shape[slice_axis] = new_shape[slice_axis] / len(node.out_nodes())
axis_size = in_shape[slice_axis]
if node.offset < 0:
node.offset += axis_size
if not node.dim:
node.dim = axis_size
elif node.dim < 0:
node.dim += axis_size
input_dim = in_shape.size
node.dim = (node.dim - node.offset)
if node.dim > in_shape[slice_axis]:
raise Error(
'{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
'\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
'\nTo overcome, try to edit the original model "end" property of the {0} layer.',
node.name, ','.join(str(i) for i in in_shape), str(node.dim), str(in_shape[slice_axis])
)
for i in range(0, input_dim):
if i == slice_axis:
new_shape[i] = node.dim
else:
new_shape[i] = in_shape[i]
for i in range(0, len(node.out_nodes())):
node.out_node(i)['shape'] = new_shape

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,6 +16,9 @@
import unittest
import numpy as np
from mo.front.mxnet.extractors.slice_axis import mxnet_slice_axis_infer
from mo.front.mxnet.extractors.slice_axis import slice_axis_ext
from mo.front.mxnet.extractors.utils import AttrDictionary
from mo.graph.graph import Node
@ -49,3 +52,27 @@ class TestMXNetSliceAxisExtractorOp(unittest.TestCase):
for key in exp_attrs.keys():
self.assertEqual(res[key], exp_attrs[key])
class TestMXNetSliceAxisInfer(unittest.TestCase):
def test_slice_axis_infer_layer(self):
graph = build_graph(
{'node_1': {'name': 'data', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Parameter'},
'slice_axis_node': {'name': 'slice_axis_node', 'type': 'sigmoid', 'value': None,
'kind': 'op', 'op': 'slice_axis', },
'node_3': {'name': 'node_3', 'type': 'Identity', 'value': None, 'kind': 'op'},
},
[
('node_1', 'slice_axis_node'),
('slice_axis_node', 'node_3'),
],
{
'node_1': {'shape': np.array([1, 1024, 19, 19])},
'slice_axis_node': {'axis': 1, 'offset': 10, 'dim': 25},
})
slice_axis_node = Node(graph, 'slice_axis_node')
mxnet_slice_axis_infer(slice_axis_node)
res_shape = [1, 15, 19, 19]
for i in range(0, len(graph.node['node_3']['shape'])):
self.assertEqual(graph.node['node_3']['shape'][i], res_shape[i])

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -126,20 +126,49 @@ def order(op_node: Node, port_info: str, input_port: int):
op_node['need_shape_inference'] = True
def shape(op_node: Node, port_info: str, input_port: int):
graph = op_node.graph
def strided_slice(op_node: Node, port_info: str, input_port: int):
"""
StridedSLice must be permuted even if input or output tensors have rank lesser than 4
e.g. input_shape = (1, 10, 10), out = input[:, 0:10, :, new_axis], input_rank < 4
input_shape = (1, 10, 10, 3), out = input[:, 0:5, 0:4, 0], output_rank < 4
in both examples slice_rank is >= 4
slice_rank is defined by length of begin, end, strides (they all are of the same length)
"""
permutation_data_node = get_node_with_permutation(op_node, port_info)
assert permutation_data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
'port_info "{}".'.format(permutation_data_node.id,
op_node.id, port_info)
permutation = permutation_data_node.permutation
if len(permutation.perm) == 0:
permute_indices_for_gather = permutation_data_node.permutation.perm
if len(permute_indices_for_gather) == 0:
return
from mo.ops.op import PermuteAttrs
slice_rank = op_node.in_port(input_port).data.get_shape()[0] # length of begin, end or strides
permute_indices_for_gather = PermuteAttrs.get_nhwc_to_nchw_permutation(slice_rank).perm
reorder_inputs_for_shape_or_slice(op_node, input_port, permute_indices_for_gather)
def shape(op_node: Node, port_info: str, input_port: int):
permutation_data_node = get_node_with_permutation(op_node, port_info)
assert permutation_data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
'port_info "{}".'.format(permutation_data_node.id,
op_node.id, port_info)
permute_indices_for_gather = permutation_data_node.permutation.perm
if len(permute_indices_for_gather) == 0:
return
reorder_inputs_for_shape_or_slice(op_node, input_port, permute_indices_for_gather)
def reorder_inputs_for_shape_or_slice(op_node: Node, input_port: int, permute_indices_for_gather: list):
"""
axis and slice permutations are almost the same the only difference is that for slice in general
case permutation depends from slice_rank not from input_rank or output_rank
"""
graph = op_node.graph
data_node = op_node.in_node(input_port)
gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'
const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
const = Const(graph, {'value': permute_indices_for_gather, 'name': gather_name + '/const',
'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
gather = Gather(graph, {'name': gather_name,
@ -191,10 +220,9 @@ def transpose_nchw_to_nhwc(op_node: Node, port_info: str, input_port: int):
class PermuteInputs:
common_inv_permutation = lambda node, port_info, input_port: axis(node, port_info, input_port)
input_permutes = {
'axis': common_inv_permutation,
'axis': lambda node, port_info, input_port: axis(node, port_info, input_port),
'slice': lambda node, port_info, input_port: strided_slice(node, port_info, input_port),
'order': lambda node, port_info, input_port: order(node, port_info, input_port),
'shape': lambda node, port_info, input_port: shape(node, port_info, input_port),
'transpose': lambda node, port_info, input_port: transpose(node, port_info, input_port),
@ -202,16 +230,27 @@ class PermuteInputs:
input_port),
}
def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str):
shape_check_rules = {
'rank': lambda port: bool(len(port.data.get_shape()) >= 4),
'dim_size': lambda port: bool(port.data.get_shape()[0] >= 4), # if input 'dim_size' >= 4 need to permute
}
def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str,
shape_check_rule: str = 'rank'):
"""
Sets input permutation attribute on the edge between node1 and node2.
Input permutation consists of function that perform input permutation and
input port info 'input' or 'output' + <port_number> that points on the input with PermuteAttr.Permutation which
current input depends on
current input depends on.
shape_check_rule defines the check rule if the op node inputs need to be permuted.
By default 'rank' rule is applied, 'dim_size' is used only for StridedSlice so far.
"""
assert permutation_rule in self.input_permutes, 'No `{}` permutation rule in {}'.format(permutation_rule,
__class__.__name__)
assert shape_check_rule in self.shape_check_rules, 'No `{}` permutation shape check rule ' \
'in {}'.format(shape_check_rule, __class__.__name__)
nx.set_edge_attributes(G=node1.graph,
values={(node1.id, node2.id, 0): (self.input_permutes[permutation_rule],
port_info)},
values={(node1.id, node2.id, 0): (self.input_permutes[permutation_rule], port_info,
self.shape_check_rules[shape_check_rule])},
name='input_permutation')

View File

@ -340,6 +340,8 @@ class PermuteAttrs:
Attr = namedtuple('Attr', ['name', 'port', 'func'])
common_permutation = lambda node, permutation, attr: node[attr][permutation.perm]
slice_permutation = lambda node, permutation, attr: node[attr][ # doesn't depend from permutation variable
PermuteAttrs.get_nhwc_to_nchw_permutation(len(node[attr])).perm]
common_permutation_inv = lambda node, permutation, attr: permutation.inv[node[attr]]
# List of default permutations
@ -354,9 +356,11 @@ class PermuteAttrs:
'dilation': common_permutation,
'kernel_shape': common_permutation,
'output_shape': common_permutation,
'slices': common_permutation,
'shrink_axis_mask': common_permutation,
'new_axis_mask': common_permutation,
'begin_mask': slice_permutation,
'end_mask': slice_permutation,
'shrink_axis_mask': slice_permutation,
'new_axis_mask': slice_permutation,
'ellipsis_mask': slice_permutation,
'axes': common_permutation_inv,
'axis': common_permutation_inv,
'batch_dims': common_permutation_inv,

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,10 +13,10 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List
import numpy as np
from mo.front.common.partial_infer.utils import get_shape_from_slice
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
from mo.utils.error import Error
@ -157,20 +157,9 @@ class Slice(Op):
# Ranged for output value for specified axis
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])
if input_value is None:
output_shape = get_shape_after_slice(input_shape, slice_idx)
output_shape = get_shape_from_slice(input_shape, slice_idx)
if np.any(output_shape <= 0):
raise Error('Output shape: {} of node "{}" contains non-positive values'.format(output_shape, node.name))
node.out_port(0).data.set_shape(output_shape)
else:
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])
def get_shape_after_slice(input_shape: np.ndarray, slice_idx: List[slice]) -> np.ndarray:
"""
Calculate shape of a tensor after slicing without actually creating the resulting tensor.
Is introduced to prevent potentially large memory consumption.
"""
output_shape = np.zeros(len(input_shape), dtype=np.int32)
for i, s in enumerate(slice_idx):
output_shape[i] = len(range(*s.indices(input_shape[i])))
return output_shape

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,58 +14,17 @@
limitations under the License.
"""
from typing import List, Tuple
import numpy as np
from mo.front.common.partial_infer.slice import tf_strided_slice_infer
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.partial_infer.utils import get_shape_from_slice
from mo.graph.graph import Node, Graph
from mo.ops.op import Op, PermuteAttrs
from mo.ops.op import Op
from mo.utils.error import Error
from mo.utils.utils import array_to_str
def extend_mask_according_ellipsis(ellipsis_mask, shrink_axis_mask, length_output_shape, attr_mask_extended, ins_value):
# ellipsis is set, add dimensions in right place otherwise insert in the end
if np.any(ellipsis_mask):
idx = np.nonzero(ellipsis_mask)
assert len(idx[0]) == 1
insert_ind = idx[0][0]
else:
insert_ind = len(attr_mask_extended) - 1
ellipse_ext = length_output_shape + np.count_nonzero(shrink_axis_mask) - len(attr_mask_extended)
for i in range(0, ellipse_ext):
attr_mask_extended.insert(insert_ind + i + 1, ins_value)
return attr_mask_extended
def permute_array(node: Node, array: np.array):
"""
This function permutes masks according to permutation parameter. Mask have the same or more length than output
"""
attr_mask_extended = list(array)
# If input and output have length of shape 3 and less, no need to permute
if len(node.in_port(0).data.get_shape()) < 4 and len(node.out_port(0).data.get_shape()) < 4:
return attr_mask_extended
perm_len = len(node.out_port(0).data.get_shape()) + np.count_nonzero(node.shrink_axis_mask)
perm = PermuteAttrs.get_nhwc_to_nchw_permutation(perm_len)
perm_list = list(perm.perm)
# if mask length is more than output, just add tail that will not be permuted to avoid error
for i in range(perm_len, len(attr_mask_extended)):
perm_list.append(i)
return int64_array(attr_mask_extended)[int64_array(perm_list)]
def permute_masks(node: Node, permutation: PermuteAttrs.Permutation, attr: str):
if not node.has_valid(attr):
return None
node[attr] = permute_array(node, node[attr])
return node[attr]
class StridedSlice(Op):
op = 'StridedSlice'
enabled = True
@ -79,11 +38,12 @@ class StridedSlice(Op):
'out_ports_count': 1,
'infer': __class__.infer
}, attrs)
assert 'new_axis_mask' in attrs, "Attribute 'new_axis_mask' of the StridedSlice node is not given."
assert 'shrink_axis_mask' in attrs, "Attribute 'shrink_axis_mask' of the StridedSlice node is not given."
assert 'ellipsis_mask' in attrs, "Attribute 'ellipsis_mask' of the StridedSlice node is not given."
assert 'begin_mask' in attrs, "Attribute 'begin_mask' of the StridedSlice node is not given."
assert 'end_mask' in attrs, "Attribute 'end_mask' of the StridedSlice node is not given."
for mask_name in StridedSlice.get_mask_names():
assert mask_name in attrs, 'Attribute {} of the StridedSlice node is not given.'.format(mask_name)
@staticmethod
def get_mask_names():
return ['begin_mask', 'end_mask', 'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask']
def backend_attrs(self):
al = list()
@ -91,61 +51,86 @@ class StridedSlice(Op):
def convert(attr):
return lambda node: array_to_str(node, attr)
for a in list(['new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask']):
for a in StridedSlice.get_mask_names():
al.append((a, convert(a)))
return al
@staticmethod
def infer(node: Node):
tf_strided_slice_infer(node)
begin, end, strides = StridedSlice.validate_inputs_and_get_args(node)
out_shape = node.out_port(0).data.get_shape()
assert out_shape is not None, \
'Output shape was not calculated for node {}'.format(node.name)
# extend inputs according to ellipsis mask and/or input_shape
for i_port in node.in_ports().values():
if i_port.idx == 0 or i_port.disconnected():
continue
old_value = i_port.data.get_value()
# additional check for non-const input
# error will be return in shape inference if non-const will be added
# it is paranoid check for case if shape inference will be changed
assert old_value is not None, \
'{} input of {} node is not constant: \'value\' attribute for edge ' + \
'contains None'.format(i_port.idx, node.name)
# insert 0 for begin and end and 1 for stride
new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
len(out_shape), list(old_value),
int(i_port.idx == 3)))
# set_value additionally set_shape and propagate value to Const node
if not np.array_equal(new_value, old_value):
i_port.data.set_value(new_value)
StridedSlice.align_mask_with_slice_rank(node, len(begin))
# extend masks before removing ellipsis
for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
len(out_shape), list(node[attr]), 0))
data_shape = node.in_port(0).data.get_shape()
data_value = node.in_port(0).data.get_value()
slices = StridedSlice.get_slices(node, data_shape, begin, end, strides)
# we will extend all masks and inputs to simplify future transformations
idx = np.nonzero(node.ellipsis_mask)
node.ellipsis_mask[idx] = 0
if data_value is not None:
node.out_port(0).data.set_value(data_value[tuple(slices)])
else:
node.out_port(0).data.set_shape(get_shape_from_slice(data_shape, slices))
if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
('new_axis_mask', 'input:0', permute_masks),
('ellipsis_mask', 'input:0', permute_masks),
('begin_mask', 'input:0', permute_masks),
('end_mask', 'input:0', permute_masks),
])
# permute inputs
in_shape = node.in_port(0).get_source().data.get_shape()
assert in_shape is not None, \
'Input shape is unknown for 0 input of node {}'.format(node.name)
input_rank = len(in_shape)
if input_rank > 3:
for i_port in node.in_ports().values():
if i_port.idx == 0 or i_port.disconnected():
continue
new_value = permute_array(node, i_port.data.get_value())
# set_value additionally set_shape and propagate value to Const node
i_port.data.set_value(new_value)
node['slices'] = slices
node['force_precision_in_ports'] = {port: 'int64' for port in range(1, len(node.in_nodes()))}
# StridedSliceNormalizer inserts nodes that change original begin, end, and strides data nodes
# and since input permutations are stored in data nodes we end up having permutations
# in the wrong place of the graph.
# Therefore PermuteInputs will be set after StridedSliceNormalizer.
@staticmethod
def get_slices(node: Node, data_shape: Tuple, begin: np.array, end: np.array, strides: np.array) -> List:
input_rank = len(data_shape)
slice_rank = len(begin)
# from now slices are without ellipsis
slices = [[]] * slice_rank
in_idx = 0 # index along input tensor shapes, note that input_rank not necessary is equal to slice_rank
for i in range(slice_rank):
if node.new_axis_mask[i]:
slices[i] = np.newaxis
elif node.shrink_axis_mask[i]:
slices[i] = int(begin[i])
if slices[i] < 0: # need for ConvertGroupedStridedSlice
slices[i] += int(data_shape[in_idx])
elif node.ellipsis_mask[i]:
slices[i] = ...
in_idx += input_rank - slice_rank + np.count_nonzero(node.new_axis_mask)
else:
start, stop = begin[i], end[i]
if not node.begin_mask[i]: # if begin, and end are not specified take the whole range
start = None
if not node.end_mask[i]:
stop = None
slices[i] = slice(start, stop, strides[i])
in_idx += 1 if not node.new_axis_mask[i] else 0
return slices
@staticmethod
def align_mask_with_slice_rank(node: Node, slice_rank: int):
# align masks sizes with slice_rank (not confuse with extending, mask_aligment != mask_extending)
for mask_name in StridedSlice.get_mask_names():
num_insertations = slice_rank - len(node[mask_name])
val = 0 if mask_name not in ['begin_mask', 'end_mask'] else 1 # extend with ones only for begin and end
node[mask_name] = np.append(node[mask_name], [val] * num_insertations).astype(int)
@staticmethod
def validate_inputs_and_get_args(node: Node) -> (np.ndarray, np.ndarray, np.ndarray):
node_name = node.soft_get('name', node.id)
begin = node.in_port(1).data.get_value()
end = node.in_port(2).data.get_value()
if begin is None or end is None:
raise Error(
'StridedSlice operation for node {} supports only constant begin and end inputs'.format(node_name))
if node.is_in_port_connected(3):
strides = node.in_port(3).data.get_value()
if strides is None:
raise Error(
'StridedSlice operation for node {} supports only constant strides input'.format(node_name))
else:
strides = np.ones_like(begin)
assert len(begin) == len(end) == len(strides), \
'begin, end, and strides of StridedSlice node {} must be of the same length. Got insted:' \
'begin = {}, end = {}, strides = {}'.format(node_name, begin, end, strides)
return begin, end, strides

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -15,6 +15,7 @@
"""
from mo.front.common.partial_infer.utils import int64_array
from mo.ops.strided_slice import StridedSlice
from mo.utils.graph import Node
from mo.utils.ir_reader.extender import Extender
@ -24,9 +25,7 @@ class StridedSlice_extender(Extender):
@staticmethod
def extend(op: Node):
attrs = ['shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask']
for attr in attrs:
for attr in StridedSlice.get_mask_names():
Extender.attr_to_list(op, attr)
op.begin_mask = int64_array([1 - i for i in op.begin_mask])