[MO] StridedSlice improvements (#4139)
* fix ss * successfully converted * successfully run moved infer and normalizer unit-tests * successfully rewritten StridedSlice infer unittests * int64 array * Successfully converter crash-when-loading, xj_feauture and toy nets (cherry-picked maxpoolV4 and tf_broadcast_ext) * successfully moved PermuteAttrs to general mechanism * successfully converted xj_feauture and crash when loading with the new rewritten SS infer * fixed get_shape_from_slice and moved to common utils * fixed extending masks and some other * some refactoring * fixed extending masks in extractor, fixed licence year and some other code clearing * corrected a couple of unittests * fox permute for 5 rank slice and 4 rank inputs/ * WIP * Added comments * fixed StridedSlice in ProposalMutation.py * rechecked shape_infer unittests added some new cases * added shape_infer unit-tests after StridedSliceNormalizer pass and Permute unit-tests * corrected unittests * Applied review comments * general permutations for inputs implemented, corrected ellipsis unrolling when shrink_axis is at the beginning, some other corrections * removed code duplication in infer and normalizer, moved 'slices' attr normalizing to StridedSliceNormalizer.py * removed some code duplication and other minor improvements * Added tests * minor corrections * wider range of unittests added (froze the number) * review comments applied * enabled skipped unit-test * comment corrections * applied review comments: changed op -> type, added some asserts, corrected comments and other minor corrections * sorted inputs, updated Supported_Frameworks_Layers.md, some minor
This commit is contained in:
parent
d2548ddb60
commit
22169a05b9
@ -233,7 +233,7 @@ Standard TensorFlow\* operations:
|
||||
| Square| No |
|
||||
| Squeeze | The case when squeeze axis is not specified is not supported |
|
||||
| StopGradient | Not needed for shape inference |
|
||||
| StridedSlice | No |
|
||||
| StridedSlice | Supported only for constant begin, end, and strides inputs |
|
||||
| Sub | No |
|
||||
| Sum | No |
|
||||
| Swish | No |
|
||||
|
@ -606,6 +606,7 @@ extensions/middle/SliceLikeToStridedSlice.py
|
||||
extensions/middle/sparse_reshape.py
|
||||
extensions/middle/split_tdnn_memoryoffset.py
|
||||
extensions/middle/SplitConcatPairToInterpolate.py
|
||||
extensions/middle/StridedSliceNormalizer.py
|
||||
extensions/middle/SwapAxesMiddleReplacer.py
|
||||
extensions/middle/TensorIterator_utils.py
|
||||
extensions/middle/TensorIteratorBackEdge.py
|
||||
@ -800,7 +801,6 @@ mo/front/common/partial_infer/multi_box_prior.py
|
||||
mo/front/common/partial_infer/random_uniform.py
|
||||
mo/front/common/partial_infer/reshape.py
|
||||
mo/front/common/partial_infer/roipooling.py
|
||||
mo/front/common/partial_infer/slice.py
|
||||
mo/front/common/partial_infer/utils.py
|
||||
mo/front/common/register_custom_ops.py
|
||||
mo/front/common/replacement.py
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -64,8 +64,10 @@ class CropToStridedSlice(BackReplacementPattern):
|
||||
end_mask = axis_mask.copy()
|
||||
|
||||
ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
|
||||
'end_mask': end_mask, 'new_axis_mask': np.array([0]),
|
||||
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
||||
'end_mask': end_mask,
|
||||
'new_axis_mask': np.zeros(len(end_mask)),
|
||||
'shrink_axis_mask': np.zeros(len(end_mask)),
|
||||
'ellipsis_mask': np.zeros(len(end_mask))}).create_node()
|
||||
|
||||
if len(node.in_nodes()) == 2 and node.has_valid('offset'):
|
||||
# Crop Type 1
|
||||
@ -112,7 +114,7 @@ class CropToStridedSlice(BackReplacementPattern):
|
||||
source = node.in_port(0).get_connection().get_source()
|
||||
|
||||
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
|
||||
'name': ss.name + '/stride'}).create_node()
|
||||
'name': ss.name + '/stride'}).create_node()
|
||||
|
||||
source.connect(ss.in_port(0))
|
||||
begin.out_port(0).connect(ss.in_port(1))
|
||||
|
@ -60,9 +60,9 @@ class ProposalMutation(BackReplacementPattern):
|
||||
{'name': 'cropped_im_info',
|
||||
'begin_mask': int64_array([1, 1]),
|
||||
'end_mask': int64_array([1, 1]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0]),
|
||||
'new_axis_mask': int64_array([0, 0]),
|
||||
'shrink_axis_mask': int64_array([0, 0]),
|
||||
'ellipsis_mask': int64_array([0, 0]),
|
||||
'override_output_shape': True,
|
||||
})
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,11 +14,9 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
|
||||
from extensions.back.CropToStridedSlice import CropToStridedSlice
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.graph.graph import Graph
|
||||
|
||||
|
||||
class StridedSliceMasksNormalizer(BackReplacementPattern):
|
||||
@ -26,20 +24,13 @@ class StridedSliceMasksNormalizer(BackReplacementPattern):
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
|
||||
from extensions.back.CropToStridedSlice import CropToStridedSlice
|
||||
return [CropToStridedSlice, DeconvolutionNormalizer]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('strided_slice', dict(type='StridedSlice'))
|
||||
],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: [str, Node]):
|
||||
node = match['strided_slice']
|
||||
assert node.has_valid('begin_mask')
|
||||
assert node.has_valid('end_mask')
|
||||
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
|
||||
node.end_mask = int64_array([1 - i for i in node.end_mask])
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(type='StridedSlice'):
|
||||
assert node.has_valid('begin_mask')
|
||||
assert node.has_valid('end_mask')
|
||||
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
|
||||
node.end_mask = int64_array([1 - i for i in node.end_mask])
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -133,15 +133,14 @@ class ApplyPermutation(MiddleReplacementPattern):
|
||||
input_permutations = [(in_port, edge_attrs['input_permutation']) for in_port, edge_attrs in
|
||||
node.in_edges().items() if edge_attrs.get('input_permutation') is not None]
|
||||
for in_port, input_perm in input_permutations:
|
||||
permutation, port_info = input_perm
|
||||
permutation, port_info, check_shape = input_perm
|
||||
direction, port = port_info.split(':')
|
||||
port = int(port)
|
||||
port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port)
|
||||
permutation_data_node = get_node_with_permutation(node, port_info)
|
||||
|
||||
if permutation_data_node.has_and_set('permutation') and \
|
||||
not is_input_data_in_correct_layout(node, in_port) and \
|
||||
len(port_to_check.data.get_shape()) >= 4:
|
||||
not is_input_data_in_correct_layout(node, in_port) and check_shape(port_to_check):
|
||||
permutation(node, port_info, in_port)
|
||||
if node.has_and_set('need_shape_inference'):
|
||||
node.infer(node)
|
||||
|
@ -69,7 +69,8 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
return [ConvertSlice]
|
||||
from extensions.middle.StridedSliceNormalizer import StridedSliceNormalizer
|
||||
return [ConvertSlice, StridedSliceNormalizer]
|
||||
|
||||
def run_before(self):
|
||||
from extensions.middle.pass_separator import MiddleFinish
|
||||
|
251
model-optimizer/extensions/middle/StridedSliceNormalizer.py
Normal file
251
model-optimizer/extensions/middle/StridedSliceNormalizer.py
Normal file
@ -0,0 +1,251 @@
|
||||
"""
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.split import VariadicSplit
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.graph.perm_inputs import PermuteInputs
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.op import PermuteAttrs
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
class StridedSliceNormalizer(MiddleReplacementPattern):
|
||||
"""
|
||||
StridedSlice is not normal if it cannot be permuted by ApplyPermutations. This normalizer
|
||||
inserts blank colons ':' in slice expression so that it can be correctly permuted
|
||||
from NHWC to NCHW layout. It changes masks and inserts blank begin, end and strides values.
|
||||
In order to successfully handle StridedSlice in ShapeOf subgraphs
|
||||
changes must be done by inserting nodes not just by overwriting constants.
|
||||
|
||||
StridedSlice is not normal in 2 cases:
|
||||
1. rank of a slice expression is less than rank of input tensor
|
||||
2. there is an ellipsis
|
||||
|
||||
1st case example
|
||||
BEFORE:
|
||||
|
|
||||
begin
|
||||
value=[0, 0]
|
||||
|
|
||||
|
||||
AFTER:
|
||||
|
|
||||
begin Const
|
||||
value=[0, 0] value=[0, 0]
|
||||
\ /
|
||||
\ /
|
||||
Concat
|
||||
value=[0, 0, 0, 0]
|
||||
|
|
||||
|
||||
Input of a shape [16, 100, 100, 3] in NHWC layout, output = input[:, 0:50].
|
||||
StridedSlice will be extended to input[:, 0:50, :, :].
|
||||
After permutation to NCHW output = input[:, :, 0:50, :].
|
||||
Example for 'begin' input transformation is shown above on the picture.
|
||||
'end' and 'strides' inputs will be transformed the same way.
|
||||
|
||||
2nd case example
|
||||
BEFORE:
|
||||
|
|
||||
begin
|
||||
value=[1, 50]
|
||||
|
|
||||
|
||||
AFTER:
|
||||
|
|
||||
begin
|
||||
value=[1, 1, 1]
|
||||
|
|
||||
VariadicSplit
|
||||
/ \
|
||||
/ \
|
||||
/ Const \
|
||||
\ val=[0, 0] /
|
||||
\ | /
|
||||
\ | /
|
||||
Concat
|
||||
value=[1, 0, 0, 1, 1]
|
||||
|
|
||||
|
||||
Input of a shape [16, 10, 100, 100, 3] in NDHWC layout, output = input[1:4, ..., 1:51, 1:3],
|
||||
output_shape = [3, 10, 100, 50, 2]. In order to perform correct layout permutation
|
||||
ellipsis must be replaced with colons: input[1:4, ..., 1:51, 1:3] => input[1:4, :, :, 1:51, 1:3].
|
||||
After layour permutation input[1:4, 1:3, :, : 1:5].
|
||||
|
||||
In the places of colons blank begin, end and strides values should be inserted.
|
||||
In order to do that we split input and insert blank zeros to the middle.
|
||||
Example for 'begin' input transformation is shown above on the picture.
|
||||
'end' and 'strides' inputs will be transformed the same way.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_before(self):
|
||||
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
|
||||
return [LayoutChangeForConstantShapePaths]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(type='StridedSlice'):
|
||||
StridedSliceNormalizer.normalize_strided_slice(graph, node)
|
||||
PermuteAttrs.create_permute_attrs(node,
|
||||
attrs=[('begin_mask', 'input:0'), # but indeed depends from slice_rank
|
||||
('end_mask', 'input:0'),
|
||||
('new_axis_mask', 'input:0'),
|
||||
('shrink_axis_mask', 'input:0'),
|
||||
('ellipsis_mask', 'input:0')])
|
||||
|
||||
# StridedSliceNormalizer inserted nodes that changed original begin, end, and strides data nodes
|
||||
# Until now it was not possible to set correct permutations
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'slice', 'dim_size')
|
||||
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:2', 'slice', 'dim_size')
|
||||
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size')
|
||||
|
||||
@staticmethod
|
||||
def normalize_strided_slice(graph: Graph, node: Node):
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
input_rank = len(input_shape)
|
||||
begin, _, _ = StridedSlice.validate_inputs_and_get_args(node)
|
||||
slice_rank = len(begin)
|
||||
|
||||
StridedSlice.align_mask_with_slice_rank(node, slice_rank) # if StridedSlice is created after partial_infer
|
||||
StridedSliceNormalizer.normalize_slices_attr(node)
|
||||
|
||||
num_insertions = input_rank - slice_rank + np.count_nonzero(node.new_axis_mask)
|
||||
assert num_insertions >= 0, 'slice_rank - num_new_axis must <= input rank. Got instead: ' \
|
||||
'input_rank = {}, slice_rank = {}, num_new_axis = {}'. \
|
||||
format(input_rank, slice_rank, np.count_nonzero(node.new_axis_mask))
|
||||
|
||||
if np.any(node.ellipsis_mask):
|
||||
assert np.count_nonzero(node.ellipsis_mask) == 1, 'only one ellipsis_mask nonzero value is allowed'
|
||||
ellipsis_start = np.nonzero(node.ellipsis_mask)[0][0]
|
||||
# since we don't expect values in begin and end: take the whole range along ellipsis_start
|
||||
node.begin_mask[ellipsis_start] = 0
|
||||
node.end_mask[ellipsis_start] = 0
|
||||
node.ellipsis_mask[ellipsis_start] = 0
|
||||
insertion_start_idx = ellipsis_start + 1
|
||||
|
||||
StridedSliceNormalizer.unroll_ellipsis_for_inputs(graph, node, ellipsis_start, num_insertions)
|
||||
elif num_insertions > 0:
|
||||
insertion_start_idx = slice_rank # insert blank values to mask ends
|
||||
StridedSliceNormalizer.extend_inputs(node, num_insertions)
|
||||
|
||||
if num_insertions > 0:
|
||||
# insert blank values for ellipsis unrolling and extending
|
||||
for mask_name in StridedSlice.get_mask_names():
|
||||
node[mask_name] = np.insert(node[mask_name], insertion_start_idx, [0] * num_insertions).astype(int)
|
||||
|
||||
@staticmethod
|
||||
def unroll_ellipsis_for_inputs(graph: Graph, node: Node, ellipsis_start: int, num_insertions: int):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
|
||||
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
|
||||
blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
|
||||
'value': int64_array(blank_values_arr)}).create_node()
|
||||
|
||||
if i == 3 and node.in_port(3).disconnected():
|
||||
continue # no need to extend strides if they are not connected
|
||||
|
||||
concat_in_ports_count = 3 if ellipsis_start != 0 else 2
|
||||
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
|
||||
'in_ports_count': concat_in_ports_count}).create_node()
|
||||
|
||||
if ellipsis_start != 0:
|
||||
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0),
|
||||
2: int64_array([ellipsis_start, -1])},
|
||||
{'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
|
||||
'out_ports_count': 2})
|
||||
node.in_port(i).get_connection().set_destination(split.in_port(0))
|
||||
|
||||
concat.in_port(0).connect(split.out_port(0))
|
||||
concat.in_port(1).connect(blank_values_node.out_port(0))
|
||||
concat.in_port(2).connect(split.out_port(1))
|
||||
else:
|
||||
concat.in_port(0).connect(blank_values_node.out_port(0))
|
||||
node.in_port(i).get_connection().set_destination(concat.in_port(1))
|
||||
|
||||
concat.out_port(0).get_connection().set_destination(node.in_port(i))
|
||||
|
||||
@staticmethod
|
||||
def extend_inputs(node: Node, num_insertions: int):
|
||||
graph = node.graph
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
|
||||
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
|
||||
blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
|
||||
'value': int64_array(blank_values_arr)}).create_node()
|
||||
|
||||
if i == 3 and node.in_port(3).disconnected():
|
||||
continue # no need to extend strides if they are not connected
|
||||
|
||||
if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
|
||||
# concat already exists
|
||||
concat = node.in_port(i).get_source().node
|
||||
last_in_port = max(concat.in_ports().keys())
|
||||
assert not concat.in_port(last_in_port).disconnected(), 'The last in_port of Concat node {}' \
|
||||
'should be connected'. \
|
||||
format(concat.soft_get('name', node.id))
|
||||
|
||||
concat.add_input_port(last_in_port + 1)
|
||||
concat.in_port(last_in_port + 1).connect(blank_values_node.out_port(0))
|
||||
else:
|
||||
# have to create concat
|
||||
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
|
||||
'in_ports_count': 2}).create_node()
|
||||
node.in_port(i).get_connection().set_destination(concat.in_port(0))
|
||||
concat.in_port(1).connect(blank_values_node.out_port(0))
|
||||
concat.out_port(0).get_connection().set_destination(node.in_port(i))
|
||||
|
||||
@staticmethod
|
||||
def normalize_slices_attr(node: Node):
|
||||
# removes negative starts, ends and magic numbers from 'slice' attr which is used by ConvertGroupedStridedSlice
|
||||
slice_rank = len(node['slices'])
|
||||
data_shape = node.in_port(0).data.get_shape()
|
||||
|
||||
node_name = node.soft_get('name', node.id)
|
||||
if node.is_in_port_connected(3):
|
||||
strides = node.in_port(3).data.get_value()
|
||||
if strides is None:
|
||||
raise Error('StridedSlice operation for node {} supports only constant strides input'.format(node_name))
|
||||
else:
|
||||
strides = np.ones(slice_rank)
|
||||
|
||||
num_ellipsis_inserts = len(data_shape) - slice_rank + np.count_nonzero(node.new_axis_mask) + 1
|
||||
res_slices = []
|
||||
|
||||
in_idx = 0
|
||||
for i, s in enumerate(node['slices']):
|
||||
if node.new_axis_mask[i]:
|
||||
res_slices.append(slice(0, 1, 1))
|
||||
elif node.shrink_axis_mask[i]:
|
||||
res_slices.append(slice(s, s + 1, strides[i])) # need strides if shrink index is negative
|
||||
elif node.ellipsis_mask[i]:
|
||||
for idx in range(num_ellipsis_inserts):
|
||||
res_slices.append(slice(0, data_shape[in_idx], 1))
|
||||
in_idx += 1
|
||||
else:
|
||||
res_slices.append(s)
|
||||
|
||||
if not (node.new_axis_mask[i] or node.ellipsis_mask[i]):
|
||||
res_slices[-1] = slice(*res_slices[-1].indices(data_shape[in_idx])) # convert negative begins/ends
|
||||
in_idx += 1
|
||||
node.slices = np.array(res_slices)
|
1810
model-optimizer/extensions/middle/StridedSliceNormalizer_test.py
Normal file
1810
model-optimizer/extensions/middle/StridedSliceNormalizer_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,158 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.caffe.extractors.utils import get_canonical_axis_index
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
def tf_strided_slice_infer(node):
|
||||
if node.in_node(1).value is None or node.in_node(2).value is None:
|
||||
raise Error('Strided slice layer supports only constant begin and end inputs')
|
||||
begin_id = node.in_node(1).value.copy()
|
||||
end_id = node.in_node(2).value.copy()
|
||||
if len(node.in_nodes()) > 3:
|
||||
if node.in_node(3).value is None:
|
||||
raise Error('Strided slice layer supports only constant stride input')
|
||||
stride = node.in_node(3).value
|
||||
else:
|
||||
stride = []
|
||||
|
||||
shape = node.in_node(0).shape
|
||||
|
||||
if shape is None or any([x < 0 for x in shape]):
|
||||
return
|
||||
|
||||
convert_negative_indices(begin_id, shape)
|
||||
convert_negative_indices(end_id, shape)
|
||||
|
||||
slice_idx = []
|
||||
dims = np.amax(np.array([len(begin_id), len(end_id), len(stride),
|
||||
len(node.shrink_axis_mask), len(node.new_axis_mask), len(node.ellipsis_mask),
|
||||
len(node.begin_mask), len(node.end_mask)]))
|
||||
|
||||
# make mask correct length
|
||||
def extend_mask(in_mask, fin_len, zeros=True):
|
||||
mask = list(in_mask)
|
||||
if len(mask) < fin_len:
|
||||
if zeros:
|
||||
mask.extend(np.zeros(dims-len(mask), dtype=np.int32))
|
||||
else:
|
||||
mask.extend(np.ones(dims-len(mask), dtype=np.int32))
|
||||
return np.array(mask, dtype=np.int32)
|
||||
|
||||
for mask in {'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask'}:
|
||||
node[mask] = extend_mask(node[mask], dims)
|
||||
node.begin_mask = extend_mask(node.begin_mask, dims, False)
|
||||
node.end_mask = extend_mask(node.end_mask, dims, False)
|
||||
|
||||
old_idx = 0
|
||||
ellips_ext = 0
|
||||
id_em = 0
|
||||
for idx in range(dims):
|
||||
if node.new_axis_mask[idx]:
|
||||
slice_idx.append(np.newaxis)
|
||||
elif node.ellipsis_mask[idx]:
|
||||
ellips_ext = len(shape) - (dims - np.count_nonzero(node.new_axis_mask) - 1)
|
||||
id_em = idx
|
||||
for i in range(0, ellips_ext):
|
||||
slice_idx.append(slice(0, shape[old_idx], 1))
|
||||
old_idx = old_idx + 1
|
||||
else:
|
||||
s = stride[idx] if len(stride) > idx else 1
|
||||
def_beg = 0 if s > 0 else -1
|
||||
def_end = shape[old_idx] if s > 0 else -shape[old_idx]-1
|
||||
l = begin_id[idx] if node.begin_mask[idx] and idx < len(begin_id) else def_beg
|
||||
r = end_id[idx] if node.end_mask[idx] and idx < len(end_id) else def_end
|
||||
|
||||
# Check shrink_axis_mask
|
||||
if node.shrink_axis_mask[idx] and idx < len(shape):
|
||||
slice_idx.append(slice(l, l+1, s))
|
||||
else:
|
||||
slice_idx.append(slice(l, r, s))
|
||||
old_idx = old_idx + 1
|
||||
|
||||
value = node.in_node(0).value if node.in_node(0).value is not None else np.zeros(shape)
|
||||
# fix for the warning: "FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated use
|
||||
# `arr[tuple(seq)]` instead of `arr[seq]`"
|
||||
value = value[tuple(slice_idx)]
|
||||
|
||||
for idx, flag in reversed(list(enumerate(node.shrink_axis_mask))):
|
||||
if flag:
|
||||
if ellips_ext > 0 and idx > id_em:
|
||||
idx = idx + ellips_ext - 1
|
||||
try:
|
||||
value = np.squeeze(value, idx)
|
||||
except ValueError:
|
||||
# ignore this error
|
||||
continue
|
||||
|
||||
for i, s in enumerate(slice_idx):
|
||||
if s is None:
|
||||
slice_idx[i] = slice(0, 1, 1)
|
||||
|
||||
node['slices'] = np.array(slice_idx)
|
||||
for attr in ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask'):
|
||||
node[attr] = np.array(node[attr], dtype=np.int32)
|
||||
|
||||
node['force_precision_in_ports'] = {port: 'int64' for port in range(1, len(node.in_nodes()))}
|
||||
|
||||
node.out_node().value = value.copy() if node.in_node(0).value is not None else None
|
||||
node.out_node().shape = np.array(value.shape, dtype=np.int64)
|
||||
|
||||
|
||||
def convert_negative_indices(indices: np.array, shape: np.array):
|
||||
for ind, value in enumerate(indices):
|
||||
if value < 0:
|
||||
indices[ind] += shape[ind]
|
||||
|
||||
|
||||
def mxnet_slice_axis_infer(node):
|
||||
in_shape = node.in_node(0).shape
|
||||
node.axis = get_canonical_axis_index(in_shape, node.axis)
|
||||
slice_axis = node.axis
|
||||
|
||||
new_shape = np.array(in_shape, dtype=np.int64)
|
||||
new_shape[slice_axis] = new_shape[slice_axis] / len(node.out_nodes())
|
||||
|
||||
axis_size = in_shape[slice_axis]
|
||||
if node.offset < 0:
|
||||
node.offset += axis_size
|
||||
|
||||
if not node.dim:
|
||||
node.dim = axis_size
|
||||
elif node.dim < 0:
|
||||
node.dim += axis_size
|
||||
|
||||
input_dim = in_shape.size
|
||||
node.dim = (node.dim - node.offset)
|
||||
if node.dim > in_shape[slice_axis]:
|
||||
raise Error(
|
||||
'{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
|
||||
'\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
|
||||
'\nTo overcome, try to edit the original model "end" property of the {0} layer.',
|
||||
node.name, ','.join(str(i) for i in in_shape), str(node.dim), str(in_shape[slice_axis])
|
||||
)
|
||||
|
||||
for i in range(0, input_dim):
|
||||
if i == slice_axis:
|
||||
new_shape[i] = node.dim
|
||||
else:
|
||||
new_shape[i] = in_shape[i]
|
||||
|
||||
for i in range(0, len(node.out_nodes())):
|
||||
node.out_node(i)['shape'] = new_shape
|
@ -1,278 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.slice import tf_strided_slice_infer, convert_negative_indices, mxnet_slice_axis_infer
|
||||
from mo.graph.graph import Node
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
|
||||
nodes_attributes = {'node_1': {'value': None, 'kind': 'data'},
|
||||
'Slice_node': {'type': 'Slice', 'kind': 'op'},
|
||||
'node_2': {'value': None, 'kind': 'data'},
|
||||
'node_3': {'value': None, 'kind': 'data'},
|
||||
'node_4': {'value': None, 'kind': 'data'},
|
||||
# StridedSlice node with attrs
|
||||
'sslice_input': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'sslice_1': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice'},
|
||||
'sslice_begin_1': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'sslice_end_1': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'sslice_stride_1': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'sslice_data_1': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
# TF slice
|
||||
'tf_slice_input': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'tf_slice_begin': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'tf_slice_size': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'tf_slice': {'kind': 'op'},
|
||||
'tf_slice_output': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'op_output': {'kind': 'op', 'op': 'Result'},
|
||||
'op_output_1': {'kind': 'op', 'op': 'Result'},
|
||||
'op_output_2': {'kind': 'op', 'op': 'Result'}
|
||||
}
|
||||
|
||||
tf_slice_edges = [('tf_slice_input', 'tf_slice'), ('tf_slice_begin', 'tf_slice'), ('tf_slice_size', 'tf_slice'),
|
||||
('tf_slice', 'tf_slice_output')]
|
||||
|
||||
|
||||
class TestTFStridedSliceInfer(unittest.TestCase):
|
||||
def build_test_graph2(self):
|
||||
return build_graph(nodes_attributes,
|
||||
[('sslice_input', 'sslice_1'),
|
||||
('sslice_begin_1', 'sslice_1'),
|
||||
('sslice_end_1', 'sslice_1'),
|
||||
('sslice_stride_1', 'sslice_1'),
|
||||
('sslice_1', 'sslice_data_1'),
|
||||
('sslice_data_1', 'op_output')
|
||||
],
|
||||
{
|
||||
'sslice_input': {'value': np.array([1, 34, 34, 62]),
|
||||
'shape': np.array([3])},
|
||||
'sslice_begin_1': {'value': np.array([0]), 'shape': np.array([1])},
|
||||
'sslice_end_1': {'value': np.array([4]), 'shape': np.array([1])},
|
||||
'sslice_stride_1': {'value': np.array([1]), 'shape': np.array([1])},
|
||||
'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
|
||||
'begin_mask': [1], 'end_mask': [1]},
|
||||
})
|
||||
|
||||
def build_test_graph(self):
|
||||
return build_graph(nodes_attributes,
|
||||
[('sslice_input', 'sslice_1'),
|
||||
('sslice_begin_1', 'sslice_1'),
|
||||
('sslice_end_1', 'sslice_1'),
|
||||
('sslice_stride_1', 'sslice_1'),
|
||||
('sslice_1', 'sslice_data_1'),
|
||||
('sslice_data_1', 'op_output')
|
||||
],
|
||||
{
|
||||
'sslice_input': {'value': None, 'shape': np.array([1, 35, 35, 3])},
|
||||
'sslice_begin_1': {'value': np.array([0, 0, 0, 0]), 'shape': np.array([4])},
|
||||
'sslice_end_1': {'value': np.array([1, 34, 30, 2]), 'shape': np.array([4])},
|
||||
'sslice_stride_1': {'value': np.array([1, 1, 1, 1]),
|
||||
'shape': np.array([4])},
|
||||
'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
|
||||
'begin_mask': [1], 'end_mask': [1]},
|
||||
})
|
||||
|
||||
def build_test_graph_dim_beg(self):
|
||||
return build_graph(nodes_attributes,
|
||||
[('sslice_input', 'sslice_1'),
|
||||
('sslice_begin_1', 'sslice_1'),
|
||||
('sslice_end_1', 'sslice_1'),
|
||||
('sslice_stride_1', 'sslice_1'),
|
||||
('sslice_1', 'sslice_data_1'),
|
||||
('sslice_data_1', 'op_output')
|
||||
],
|
||||
{
|
||||
'sslice_input': {'value': np.array([[1, 34, 34, 62]]),
|
||||
'shape': np.array([1, 4])},
|
||||
'sslice_begin_1': {'value': np.array([0]), 'shape': np.array([1])},
|
||||
'sslice_end_1': {'value': np.array([4]), 'shape': np.array([1])},
|
||||
'sslice_stride_1': {'value': np.array([1]), 'shape': np.array([1])},
|
||||
'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
|
||||
'begin_mask': [1], 'end_mask': [1]},
|
||||
})
|
||||
|
||||
def test_slice_infer_1(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_2(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.end_mask = [1, 0, 0, 1] # 6
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 2])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_3(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.in_node(1).value = np.array([0, 10, 10, 0])
|
||||
node.end_mask = [1, 0, 0, 1] # 6
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 25, 25, 2])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_4(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.in_node(1).value = np.array([0, 10, 10, 0])
|
||||
node.begin_mask = [1, 0, 0, 1] # 6
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_5(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.in_node(1).value = np.array([0, 10, 10, 0])
|
||||
node.begin_mask = [0, 0, 0, 0] # 15
|
||||
node.end_mask = [0, 0, 0, 0] # 15
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 3])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_6(self):
|
||||
graph = self.build_test_graph2()
|
||||
node = Node(graph, 'sslice_1')
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
|
||||
self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output value detected')
|
||||
|
||||
def test_slice_infer_7(self):
|
||||
graph = self.build_test_graph2()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.in_node(1).value = np.array([1])
|
||||
node.in_node(2).value = np.array([3])
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([2])), 'Wrong output shape detected')
|
||||
self.assertTrue(np.array_equal(node.out_node().value, np.array([34, 34])), 'Wrong output value detected')
|
||||
|
||||
def test_slice_infer_8(self):
|
||||
graph = self.build_test_graph2()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.new_axis_mask = [1]
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 4])), 'Wrong output shape detected')
|
||||
self.assertTrue(np.array_equal(node.out_node().value, np.array([[1, 34, 34, 62]])),
|
||||
'Wrong output value detected')
|
||||
|
||||
def test_slice_infer_9(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.begin_mask = [0, 0, 0, 0] # 15
|
||||
node.end_mask = [0, 0, 0, 0] # 15
|
||||
node.shrink_axis_mask = [1]
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 3])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_10(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.begin_mask = [0, 0, 0, 0] # 15
|
||||
node.end_mask = [0, 0, 0, 0] # 15
|
||||
node.shrink_axis_mask = [1, 0, 0, 0]
|
||||
node.new_axis_mask = [0, 0, 0, 1] # 8
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 1, 3])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_11(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.begin_mask = [0, 0, 0, 0] # 15
|
||||
node.end_mask = [0, 0, 0, 0] # 15
|
||||
node.shrink_axis_mask = [1, 0, 1, 0] # 5
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 3])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_12(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.begin_mask = [0, 0, 0, 0] # 15
|
||||
node.end_mask = [0, 0, 0, 0] # 15
|
||||
node.shrink_axis_mask = [1, 1, 1, 0] # 7
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([3])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_13(self):
|
||||
graph = self.build_test_graph2()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.in_node(1).value = np.array([1])
|
||||
node.shrink_axis_mask = [1]
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([])), 'Wrong output shape detected')
|
||||
self.assertTrue(np.array_equal(node.out_node().value, np.array(34)), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_14(self):
|
||||
graph = self.build_test_graph2()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.in_node(3).value = np.array([-1])
|
||||
node.end_mask = [0]
|
||||
node.begin_mask = [0]
|
||||
node.in_node(0).shape = [4]
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
|
||||
print(node.out_node().value)
|
||||
self.assertTrue(np.array_equal(node.out_node().value, np.array([62, 34, 34, 1])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_dim_beg(self):
|
||||
graph = self.build_test_graph_dim_beg()
|
||||
node = Node(graph, 'sslice_1')
|
||||
node.shrink_axis_mask = [1]
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
|
||||
self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output shape detected')
|
||||
|
||||
def test_slice_infer_neg_end(self):
|
||||
graph = self.build_test_graph()
|
||||
node = Node(graph, 'sslice_1')
|
||||
end_node = Node(graph, 'sslice_end_1')
|
||||
end_node.value = np.array([1, -1, -5, -1])
|
||||
tf_strided_slice_infer(node)
|
||||
self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
|
||||
self.assertTrue(np.array_equal(end_node.value, np.array([1, -1, -5, -1])), 'Negative values in end were converted to positive')
|
||||
|
||||
|
||||
class TestConvertNegativeIndices(unittest.TestCase):
|
||||
def test_convert_negative_indices(self):
|
||||
dimensions = np.array([3, 4, 8, 10])
|
||||
indices = np.array([2, 0, -3, -4])
|
||||
convert_negative_indices(indices, dimensions)
|
||||
self.assertTrue(np.array_equal(indices, np.array([2, 0, 5, 6])), 'Wrong dimension indices')
|
||||
|
||||
|
||||
class TestMXNetSliceAxisInfer(unittest.TestCase):
|
||||
def test_slice_axis_infer_layer(self):
|
||||
graph = build_graph(
|
||||
{'node_1': {'name': 'data', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Parameter'},
|
||||
'slice_axis_node': {'name': 'slice_axis_node', 'type': 'sigmoid', 'value': None,
|
||||
'kind': 'op', 'op': 'slice_axis', },
|
||||
'node_3': {'name': 'node_3', 'type': 'Identity', 'value': None, 'kind': 'op'},
|
||||
},
|
||||
[
|
||||
('node_1', 'slice_axis_node'),
|
||||
('slice_axis_node', 'node_3'),
|
||||
],
|
||||
{
|
||||
'node_1': {'shape': np.array([1, 1024, 19, 19])},
|
||||
'slice_axis_node': {'axis': 1, 'offset': 10, 'dim': 25},
|
||||
})
|
||||
|
||||
slice_axis_node = Node(graph, 'slice_axis_node')
|
||||
mxnet_slice_axis_infer(slice_axis_node)
|
||||
res_shape = [1, 15, 19, 19]
|
||||
for i in range(0, len(graph.node['node_3']['shape'])):
|
||||
self.assertEqual(graph.node['node_3']['shape'][i], res_shape[i])
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,7 +15,7 @@
|
||||
"""
|
||||
|
||||
import logging as log
|
||||
from typing import Iterable
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -113,4 +113,34 @@ def broadcast_shape(first_shape, second_shape):
|
||||
assert a_val == 1 or b_val == 1 or a_val == b_val, "Input shape do not broadcast"
|
||||
new_val = b_val if a_val == 1 else a_val
|
||||
new_shape[-i - 1] = new_val
|
||||
return int64_array(new_shape)
|
||||
return int64_array(new_shape)
|
||||
|
||||
|
||||
def get_shape_from_slice(input_shape: np.ndarray, slices: List) -> np.ndarray:
|
||||
"""
|
||||
Calculate shape of a tensor after slicing without actually creating the resulting tensor.
|
||||
Is introduced to prevent potentially large memory consumption.
|
||||
"""
|
||||
output_shape = []
|
||||
num_new_axes = np.count_nonzero(list(map(lambda x: x is np.newaxis, slices)))
|
||||
num_ellipsis_inserts = len(input_shape) - len(slices) + num_new_axes + 1
|
||||
|
||||
in_idx = 0
|
||||
for i, s in enumerate(slices):
|
||||
if isinstance(s, slice):
|
||||
output_shape.append(len(range(*s.indices(input_shape[in_idx]))))
|
||||
in_idx += 1
|
||||
elif s is np.newaxis:
|
||||
output_shape.append(1)
|
||||
elif isinstance(s, int): # shrink_axis
|
||||
in_idx += 1
|
||||
elif s is Ellipsis:
|
||||
for idx in range(num_ellipsis_inserts):
|
||||
output_shape.append(input_shape[in_idx])
|
||||
in_idx += 1
|
||||
else:
|
||||
raise Exception('Element type of a slice List is unacceptable. '
|
||||
'Allowed types are: Ellipsis, slice, int, and None. Instead got: '. format(type(s)))
|
||||
for i in range(in_idx, len(input_shape)):
|
||||
output_shape.append(input_shape[i])
|
||||
return int64_array(output_shape)
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,7 +14,10 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from mo.front.common.partial_infer.slice import mxnet_slice_axis_infer
|
||||
from mo.front.caffe.extractors.utils import get_canonical_axis_index
|
||||
import numpy as np
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
def slice_axis_ext(attrs):
|
||||
axis = attrs.int("axis", 0)
|
||||
@ -29,3 +32,40 @@ def slice_axis_ext(attrs):
|
||||
'infer': mxnet_slice_axis_infer
|
||||
}
|
||||
return node_attrs
|
||||
|
||||
|
||||
def mxnet_slice_axis_infer(node):
|
||||
in_shape = node.in_port(0).data.get_shape()
|
||||
node.axis = get_canonical_axis_index(in_shape, node.axis)
|
||||
slice_axis = node.axis
|
||||
|
||||
new_shape = np.array(in_shape, dtype=np.int64)
|
||||
new_shape[slice_axis] = new_shape[slice_axis] / len(node.out_nodes())
|
||||
|
||||
axis_size = in_shape[slice_axis]
|
||||
if node.offset < 0:
|
||||
node.offset += axis_size
|
||||
|
||||
if not node.dim:
|
||||
node.dim = axis_size
|
||||
elif node.dim < 0:
|
||||
node.dim += axis_size
|
||||
|
||||
input_dim = in_shape.size
|
||||
node.dim = (node.dim - node.offset)
|
||||
if node.dim > in_shape[slice_axis]:
|
||||
raise Error(
|
||||
'{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
|
||||
'\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
|
||||
'\nTo overcome, try to edit the original model "end" property of the {0} layer.',
|
||||
node.name, ','.join(str(i) for i in in_shape), str(node.dim), str(in_shape[slice_axis])
|
||||
)
|
||||
|
||||
for i in range(0, input_dim):
|
||||
if i == slice_axis:
|
||||
new_shape[i] = node.dim
|
||||
else:
|
||||
new_shape[i] = in_shape[i]
|
||||
|
||||
for i in range(0, len(node.out_nodes())):
|
||||
node.out_node(i)['shape'] = new_shape
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -16,6 +16,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.mxnet.extractors.slice_axis import mxnet_slice_axis_infer
|
||||
from mo.front.mxnet.extractors.slice_axis import slice_axis_ext
|
||||
from mo.front.mxnet.extractors.utils import AttrDictionary
|
||||
from mo.graph.graph import Node
|
||||
@ -49,3 +52,27 @@ class TestMXNetSliceAxisExtractorOp(unittest.TestCase):
|
||||
|
||||
for key in exp_attrs.keys():
|
||||
self.assertEqual(res[key], exp_attrs[key])
|
||||
|
||||
|
||||
class TestMXNetSliceAxisInfer(unittest.TestCase):
|
||||
def test_slice_axis_infer_layer(self):
|
||||
graph = build_graph(
|
||||
{'node_1': {'name': 'data', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Parameter'},
|
||||
'slice_axis_node': {'name': 'slice_axis_node', 'type': 'sigmoid', 'value': None,
|
||||
'kind': 'op', 'op': 'slice_axis', },
|
||||
'node_3': {'name': 'node_3', 'type': 'Identity', 'value': None, 'kind': 'op'},
|
||||
},
|
||||
[
|
||||
('node_1', 'slice_axis_node'),
|
||||
('slice_axis_node', 'node_3'),
|
||||
],
|
||||
{
|
||||
'node_1': {'shape': np.array([1, 1024, 19, 19])},
|
||||
'slice_axis_node': {'axis': 1, 'offset': 10, 'dim': 25},
|
||||
})
|
||||
|
||||
slice_axis_node = Node(graph, 'slice_axis_node')
|
||||
mxnet_slice_axis_infer(slice_axis_node)
|
||||
res_shape = [1, 15, 19, 19]
|
||||
for i in range(0, len(graph.node['node_3']['shape'])):
|
||||
self.assertEqual(graph.node['node_3']['shape'][i], res_shape[i])
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -126,20 +126,49 @@ def order(op_node: Node, port_info: str, input_port: int):
|
||||
op_node['need_shape_inference'] = True
|
||||
|
||||
|
||||
def shape(op_node: Node, port_info: str, input_port: int):
|
||||
graph = op_node.graph
|
||||
def strided_slice(op_node: Node, port_info: str, input_port: int):
|
||||
"""
|
||||
StridedSLice must be permuted even if input or output tensors have rank lesser than 4
|
||||
e.g. input_shape = (1, 10, 10), out = input[:, 0:10, :, new_axis], input_rank < 4
|
||||
input_shape = (1, 10, 10, 3), out = input[:, 0:5, 0:4, 0], output_rank < 4
|
||||
in both examples slice_rank is >= 4
|
||||
slice_rank is defined by length of begin, end, strides (they all are of the same length)
|
||||
"""
|
||||
permutation_data_node = get_node_with_permutation(op_node, port_info)
|
||||
assert permutation_data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
|
||||
'port_info "{}".'.format(permutation_data_node.id,
|
||||
op_node.id, port_info)
|
||||
permutation = permutation_data_node.permutation
|
||||
if len(permutation.perm) == 0:
|
||||
permute_indices_for_gather = permutation_data_node.permutation.perm
|
||||
if len(permute_indices_for_gather) == 0:
|
||||
return
|
||||
from mo.ops.op import PermuteAttrs
|
||||
|
||||
slice_rank = op_node.in_port(input_port).data.get_shape()[0] # length of begin, end or strides
|
||||
permute_indices_for_gather = PermuteAttrs.get_nhwc_to_nchw_permutation(slice_rank).perm
|
||||
reorder_inputs_for_shape_or_slice(op_node, input_port, permute_indices_for_gather)
|
||||
|
||||
|
||||
def shape(op_node: Node, port_info: str, input_port: int):
|
||||
permutation_data_node = get_node_with_permutation(op_node, port_info)
|
||||
assert permutation_data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
|
||||
'port_info "{}".'.format(permutation_data_node.id,
|
||||
op_node.id, port_info)
|
||||
permute_indices_for_gather = permutation_data_node.permutation.perm
|
||||
if len(permute_indices_for_gather) == 0:
|
||||
return
|
||||
reorder_inputs_for_shape_or_slice(op_node, input_port, permute_indices_for_gather)
|
||||
|
||||
|
||||
def reorder_inputs_for_shape_or_slice(op_node: Node, input_port: int, permute_indices_for_gather: list):
|
||||
"""
|
||||
axis and slice permutations are almost the same the only difference is that for slice in general
|
||||
case permutation depends from slice_rank not from input_rank or output_rank
|
||||
"""
|
||||
graph = op_node.graph
|
||||
data_node = op_node.in_node(input_port)
|
||||
|
||||
gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'
|
||||
const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
|
||||
const = Const(graph, {'value': permute_indices_for_gather, 'name': gather_name + '/const',
|
||||
'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': gather_name,
|
||||
@ -191,10 +220,9 @@ def transpose_nchw_to_nhwc(op_node: Node, port_info: str, input_port: int):
|
||||
|
||||
|
||||
class PermuteInputs:
|
||||
common_inv_permutation = lambda node, port_info, input_port: axis(node, port_info, input_port)
|
||||
|
||||
input_permutes = {
|
||||
'axis': common_inv_permutation,
|
||||
'axis': lambda node, port_info, input_port: axis(node, port_info, input_port),
|
||||
'slice': lambda node, port_info, input_port: strided_slice(node, port_info, input_port),
|
||||
'order': lambda node, port_info, input_port: order(node, port_info, input_port),
|
||||
'shape': lambda node, port_info, input_port: shape(node, port_info, input_port),
|
||||
'transpose': lambda node, port_info, input_port: transpose(node, port_info, input_port),
|
||||
@ -202,16 +230,27 @@ class PermuteInputs:
|
||||
input_port),
|
||||
}
|
||||
|
||||
def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str):
|
||||
shape_check_rules = {
|
||||
'rank': lambda port: bool(len(port.data.get_shape()) >= 4),
|
||||
'dim_size': lambda port: bool(port.data.get_shape()[0] >= 4), # if input 'dim_size' >= 4 need to permute
|
||||
}
|
||||
|
||||
def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str,
|
||||
shape_check_rule: str = 'rank'):
|
||||
"""
|
||||
Sets input permutation attribute on the edge between node1 and node2.
|
||||
Input permutation consists of function that perform input permutation and
|
||||
input port info 'input' or 'output' + <port_number> that points on the input with PermuteAttr.Permutation which
|
||||
current input depends on
|
||||
current input depends on.
|
||||
|
||||
shape_check_rule defines the check rule if the op node inputs need to be permuted.
|
||||
By default 'rank' rule is applied, 'dim_size' is used only for StridedSlice so far.
|
||||
"""
|
||||
assert permutation_rule in self.input_permutes, 'No `{}` permutation rule in {}'.format(permutation_rule,
|
||||
__class__.__name__)
|
||||
assert shape_check_rule in self.shape_check_rules, 'No `{}` permutation shape check rule ' \
|
||||
'in {}'.format(shape_check_rule, __class__.__name__)
|
||||
nx.set_edge_attributes(G=node1.graph,
|
||||
values={(node1.id, node2.id, 0): (self.input_permutes[permutation_rule],
|
||||
port_info)},
|
||||
values={(node1.id, node2.id, 0): (self.input_permutes[permutation_rule], port_info,
|
||||
self.shape_check_rules[shape_check_rule])},
|
||||
name='input_permutation')
|
||||
|
@ -340,6 +340,8 @@ class PermuteAttrs:
|
||||
Attr = namedtuple('Attr', ['name', 'port', 'func'])
|
||||
|
||||
common_permutation = lambda node, permutation, attr: node[attr][permutation.perm]
|
||||
slice_permutation = lambda node, permutation, attr: node[attr][ # doesn't depend from permutation variable
|
||||
PermuteAttrs.get_nhwc_to_nchw_permutation(len(node[attr])).perm]
|
||||
common_permutation_inv = lambda node, permutation, attr: permutation.inv[node[attr]]
|
||||
|
||||
# List of default permutations
|
||||
@ -354,9 +356,11 @@ class PermuteAttrs:
|
||||
'dilation': common_permutation,
|
||||
'kernel_shape': common_permutation,
|
||||
'output_shape': common_permutation,
|
||||
'slices': common_permutation,
|
||||
'shrink_axis_mask': common_permutation,
|
||||
'new_axis_mask': common_permutation,
|
||||
'begin_mask': slice_permutation,
|
||||
'end_mask': slice_permutation,
|
||||
'shrink_axis_mask': slice_permutation,
|
||||
'new_axis_mask': slice_permutation,
|
||||
'ellipsis_mask': slice_permutation,
|
||||
'axes': common_permutation_inv,
|
||||
'axis': common_permutation_inv,
|
||||
'batch_dims': common_permutation_inv,
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,10 +13,10 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.utils import get_shape_from_slice
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.op import Op
|
||||
from mo.utils.error import Error
|
||||
@ -157,20 +157,9 @@ class Slice(Op):
|
||||
# Ranged for output value for specified axis
|
||||
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])
|
||||
if input_value is None:
|
||||
output_shape = get_shape_after_slice(input_shape, slice_idx)
|
||||
output_shape = get_shape_from_slice(input_shape, slice_idx)
|
||||
if np.any(output_shape <= 0):
|
||||
raise Error('Output shape: {} of node "{}" contains non-positive values'.format(output_shape, node.name))
|
||||
node.out_port(0).data.set_shape(output_shape)
|
||||
else:
|
||||
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])
|
||||
|
||||
|
||||
def get_shape_after_slice(input_shape: np.ndarray, slice_idx: List[slice]) -> np.ndarray:
|
||||
"""
|
||||
Calculate shape of a tensor after slicing without actually creating the resulting tensor.
|
||||
Is introduced to prevent potentially large memory consumption.
|
||||
"""
|
||||
output_shape = np.zeros(len(input_shape), dtype=np.int32)
|
||||
for i, s in enumerate(slice_idx):
|
||||
output_shape[i] = len(range(*s.indices(input_shape[i])))
|
||||
return output_shape
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,58 +14,17 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.slice import tf_strided_slice_infer
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.partial_infer.utils import get_shape_from_slice
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.op import Op, PermuteAttrs
|
||||
from mo.ops.op import Op
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.utils import array_to_str
|
||||
|
||||
|
||||
def extend_mask_according_ellipsis(ellipsis_mask, shrink_axis_mask, length_output_shape, attr_mask_extended, ins_value):
|
||||
# ellipsis is set, add dimensions in right place otherwise insert in the end
|
||||
if np.any(ellipsis_mask):
|
||||
idx = np.nonzero(ellipsis_mask)
|
||||
assert len(idx[0]) == 1
|
||||
insert_ind = idx[0][0]
|
||||
else:
|
||||
insert_ind = len(attr_mask_extended) - 1
|
||||
|
||||
ellipse_ext = length_output_shape + np.count_nonzero(shrink_axis_mask) - len(attr_mask_extended)
|
||||
for i in range(0, ellipse_ext):
|
||||
attr_mask_extended.insert(insert_ind + i + 1, ins_value)
|
||||
|
||||
return attr_mask_extended
|
||||
|
||||
|
||||
def permute_array(node: Node, array: np.array):
|
||||
"""
|
||||
This function permutes masks according to permutation parameter. Mask have the same or more length than output
|
||||
"""
|
||||
attr_mask_extended = list(array)
|
||||
|
||||
# If input and output have length of shape 3 and less, no need to permute
|
||||
if len(node.in_port(0).data.get_shape()) < 4 and len(node.out_port(0).data.get_shape()) < 4:
|
||||
return attr_mask_extended
|
||||
|
||||
perm_len = len(node.out_port(0).data.get_shape()) + np.count_nonzero(node.shrink_axis_mask)
|
||||
perm = PermuteAttrs.get_nhwc_to_nchw_permutation(perm_len)
|
||||
perm_list = list(perm.perm)
|
||||
# if mask length is more than output, just add tail that will not be permuted to avoid error
|
||||
for i in range(perm_len, len(attr_mask_extended)):
|
||||
perm_list.append(i)
|
||||
return int64_array(attr_mask_extended)[int64_array(perm_list)]
|
||||
|
||||
|
||||
def permute_masks(node: Node, permutation: PermuteAttrs.Permutation, attr: str):
|
||||
if not node.has_valid(attr):
|
||||
return None
|
||||
|
||||
node[attr] = permute_array(node, node[attr])
|
||||
return node[attr]
|
||||
|
||||
|
||||
class StridedSlice(Op):
|
||||
op = 'StridedSlice'
|
||||
enabled = True
|
||||
@ -79,11 +38,12 @@ class StridedSlice(Op):
|
||||
'out_ports_count': 1,
|
||||
'infer': __class__.infer
|
||||
}, attrs)
|
||||
assert 'new_axis_mask' in attrs, "Attribute 'new_axis_mask' of the StridedSlice node is not given."
|
||||
assert 'shrink_axis_mask' in attrs, "Attribute 'shrink_axis_mask' of the StridedSlice node is not given."
|
||||
assert 'ellipsis_mask' in attrs, "Attribute 'ellipsis_mask' of the StridedSlice node is not given."
|
||||
assert 'begin_mask' in attrs, "Attribute 'begin_mask' of the StridedSlice node is not given."
|
||||
assert 'end_mask' in attrs, "Attribute 'end_mask' of the StridedSlice node is not given."
|
||||
for mask_name in StridedSlice.get_mask_names():
|
||||
assert mask_name in attrs, 'Attribute {} of the StridedSlice node is not given.'.format(mask_name)
|
||||
|
||||
@staticmethod
|
||||
def get_mask_names():
|
||||
return ['begin_mask', 'end_mask', 'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask']
|
||||
|
||||
def backend_attrs(self):
|
||||
al = list()
|
||||
@ -91,61 +51,86 @@ class StridedSlice(Op):
|
||||
def convert(attr):
|
||||
return lambda node: array_to_str(node, attr)
|
||||
|
||||
for a in list(['new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask']):
|
||||
for a in StridedSlice.get_mask_names():
|
||||
al.append((a, convert(a)))
|
||||
return al
|
||||
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
tf_strided_slice_infer(node)
|
||||
begin, end, strides = StridedSlice.validate_inputs_and_get_args(node)
|
||||
|
||||
out_shape = node.out_port(0).data.get_shape()
|
||||
assert out_shape is not None, \
|
||||
'Output shape was not calculated for node {}'.format(node.name)
|
||||
# extend inputs according to ellipsis mask and/or input_shape
|
||||
for i_port in node.in_ports().values():
|
||||
if i_port.idx == 0 or i_port.disconnected():
|
||||
continue
|
||||
old_value = i_port.data.get_value()
|
||||
# additional check for non-const input
|
||||
# error will be return in shape inference if non-const will be added
|
||||
# it is paranoid check for case if shape inference will be changed
|
||||
assert old_value is not None, \
|
||||
'{} input of {} node is not constant: \'value\' attribute for edge ' + \
|
||||
'contains None'.format(i_port.idx, node.name)
|
||||
# insert 0 for begin and end and 1 for stride
|
||||
new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
|
||||
len(out_shape), list(old_value),
|
||||
int(i_port.idx == 3)))
|
||||
# set_value additionally set_shape and propagate value to Const node
|
||||
if not np.array_equal(new_value, old_value):
|
||||
i_port.data.set_value(new_value)
|
||||
StridedSlice.align_mask_with_slice_rank(node, len(begin))
|
||||
|
||||
# extend masks before removing ellipsis
|
||||
for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
|
||||
node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
|
||||
len(out_shape), list(node[attr]), 0))
|
||||
data_shape = node.in_port(0).data.get_shape()
|
||||
data_value = node.in_port(0).data.get_value()
|
||||
slices = StridedSlice.get_slices(node, data_shape, begin, end, strides)
|
||||
|
||||
# we will extend all masks and inputs to simplify future transformations
|
||||
idx = np.nonzero(node.ellipsis_mask)
|
||||
node.ellipsis_mask[idx] = 0
|
||||
if data_value is not None:
|
||||
node.out_port(0).data.set_value(data_value[tuple(slices)])
|
||||
else:
|
||||
node.out_port(0).data.set_shape(get_shape_from_slice(data_shape, slices))
|
||||
|
||||
if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
|
||||
PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
|
||||
('new_axis_mask', 'input:0', permute_masks),
|
||||
('ellipsis_mask', 'input:0', permute_masks),
|
||||
('begin_mask', 'input:0', permute_masks),
|
||||
('end_mask', 'input:0', permute_masks),
|
||||
])
|
||||
# permute inputs
|
||||
in_shape = node.in_port(0).get_source().data.get_shape()
|
||||
assert in_shape is not None, \
|
||||
'Input shape is unknown for 0 input of node {}'.format(node.name)
|
||||
input_rank = len(in_shape)
|
||||
if input_rank > 3:
|
||||
for i_port in node.in_ports().values():
|
||||
if i_port.idx == 0 or i_port.disconnected():
|
||||
continue
|
||||
new_value = permute_array(node, i_port.data.get_value())
|
||||
# set_value additionally set_shape and propagate value to Const node
|
||||
i_port.data.set_value(new_value)
|
||||
node['slices'] = slices
|
||||
node['force_precision_in_ports'] = {port: 'int64' for port in range(1, len(node.in_nodes()))}
|
||||
|
||||
# StridedSliceNormalizer inserts nodes that change original begin, end, and strides data nodes
|
||||
# and since input permutations are stored in data nodes we end up having permutations
|
||||
# in the wrong place of the graph.
|
||||
# Therefore PermuteInputs will be set after StridedSliceNormalizer.
|
||||
|
||||
@staticmethod
|
||||
def get_slices(node: Node, data_shape: Tuple, begin: np.array, end: np.array, strides: np.array) -> List:
|
||||
input_rank = len(data_shape)
|
||||
slice_rank = len(begin)
|
||||
# from now slices are without ellipsis
|
||||
slices = [[]] * slice_rank
|
||||
in_idx = 0 # index along input tensor shapes, note that input_rank not necessary is equal to slice_rank
|
||||
for i in range(slice_rank):
|
||||
if node.new_axis_mask[i]:
|
||||
slices[i] = np.newaxis
|
||||
elif node.shrink_axis_mask[i]:
|
||||
slices[i] = int(begin[i])
|
||||
if slices[i] < 0: # need for ConvertGroupedStridedSlice
|
||||
slices[i] += int(data_shape[in_idx])
|
||||
elif node.ellipsis_mask[i]:
|
||||
slices[i] = ...
|
||||
in_idx += input_rank - slice_rank + np.count_nonzero(node.new_axis_mask)
|
||||
else:
|
||||
start, stop = begin[i], end[i]
|
||||
if not node.begin_mask[i]: # if begin, and end are not specified take the whole range
|
||||
start = None
|
||||
if not node.end_mask[i]:
|
||||
stop = None
|
||||
slices[i] = slice(start, stop, strides[i])
|
||||
in_idx += 1 if not node.new_axis_mask[i] else 0
|
||||
return slices
|
||||
|
||||
@staticmethod
|
||||
def align_mask_with_slice_rank(node: Node, slice_rank: int):
|
||||
# align masks sizes with slice_rank (not confuse with extending, mask_aligment != mask_extending)
|
||||
for mask_name in StridedSlice.get_mask_names():
|
||||
num_insertations = slice_rank - len(node[mask_name])
|
||||
val = 0 if mask_name not in ['begin_mask', 'end_mask'] else 1 # extend with ones only for begin and end
|
||||
node[mask_name] = np.append(node[mask_name], [val] * num_insertations).astype(int)
|
||||
|
||||
@staticmethod
|
||||
def validate_inputs_and_get_args(node: Node) -> (np.ndarray, np.ndarray, np.ndarray):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
begin = node.in_port(1).data.get_value()
|
||||
end = node.in_port(2).data.get_value()
|
||||
|
||||
if begin is None or end is None:
|
||||
raise Error(
|
||||
'StridedSlice operation for node {} supports only constant begin and end inputs'.format(node_name))
|
||||
|
||||
if node.is_in_port_connected(3):
|
||||
strides = node.in_port(3).data.get_value()
|
||||
if strides is None:
|
||||
raise Error(
|
||||
'StridedSlice operation for node {} supports only constant strides input'.format(node_name))
|
||||
else:
|
||||
strides = np.ones_like(begin)
|
||||
assert len(begin) == len(end) == len(strides), \
|
||||
'begin, end, and strides of StridedSlice node {} must be of the same length. Got insted:' \
|
||||
'begin = {}, end = {}, strides = {}'.format(node_name, begin, end, strides)
|
||||
return begin, end, strides
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,6 +15,7 @@
|
||||
"""
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
from mo.utils.graph import Node
|
||||
from mo.utils.ir_reader.extender import Extender
|
||||
|
||||
@ -24,9 +25,7 @@ class StridedSlice_extender(Extender):
|
||||
|
||||
@staticmethod
|
||||
def extend(op: Node):
|
||||
|
||||
attrs = ['shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask']
|
||||
for attr in attrs:
|
||||
for attr in StridedSlice.get_mask_names():
|
||||
Extender.attr_to_list(op, attr)
|
||||
|
||||
op.begin_mask = int64_array([1 - i for i in op.begin_mask])
|
||||
|
Loading…
Reference in New Issue
Block a user