Files
openvino/model-optimizer/mo/utils/shape.py
Pavel Esir 75d2d88b61 Reshape able slice (#1241)
* Added Caffe Slice_ext

* Added TFSlice, AttributedSlice (both with extractors and replacers), corrected SliceConverter and added unittests for all cases

* added comments to each type of Slice operation; optimized shape inference; moved mxlice inside of slice.py; renamed slice_replacers

* removed type annotation for get_shape_after_slice routine

* replaced zeros_like with zeros

* Corrected preserving node names, renamed attribute names, added tests for slice_replacer onnx phase

* Renamed slice_replacers.py

* added more unittest cases

* added type annotations, moved to more relevant place routines for shape calculation, and some other minor corrections

* corrected a typo `normalize_slice_indices` comment

* corrected shape calculation for Nonconstant inputs

* corrected a few typos

* corrected type declarations

* corrected shape inference with rounding

* refactored unit-tests for front transforms of Slice

* added error raising for negative and zero shapes

* removed magic_num

* corrected AttributedSlice, clarified comments

* fixed unit-test for AttributedSliceToSlice

* typo in type hints corrected

* removed supported_attrs

* returned back default None for attrs of Slice
2020-08-10 12:19:08 +03:00

224 lines
9.7 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.elementwise import Add
from extensions.ops.gather import Gather
from extensions.ops.range import Range
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Node
from mo.graph.port import Port
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.squeeze import Squeeze
def get_canonical_axis_index_node(rank: Node, axis: int) -> Node:
    """
    Builds a node that yields the non-negative form of an axis index.

    A non-negative axis is emitted directly as a 1-element int64 Const. A negative axis is emitted
    as a 1-element int64 Const fed, together with the rank node, into an Add node, so the result
    is computed as rank + axis when the graph is evaluated.

    :param rank: the node of 0D output shape to get rank of tensor from
    :param axis: integer value from [-rank; rank - 1]
    :return: node producing positive integer value of axis
    """
    graph = rank.graph
    rank_name = rank.soft_get('name', rank.id)

    if axis >= 0:
        # Already canonical: no graph arithmetic needed.
        return Const(graph, {'name': rank_name + '/positive_axis', 'value': int64_array([axis])}).create_node()

    negative_axis = Const(graph, {'name': rank_name + '/negative_axis', 'value': int64_array([axis])}).create_node()
    positive_axis = Add(graph, {'name': rank_name + '/positive_axis'}).create_node()
    rank.out_port(0).connect(positive_axis.in_port(0))
    negative_axis.out_port(0).connect(positive_axis.in_port(1))
    return positive_axis
def get_range_node_of_idxs(rank: Node, begin: int, end: int,
                           include_begin: bool = True, include_end: bool = False) -> Node:
    """
    Builds a Range node enumerating axis indices between two (possibly negative) boundaries.

    Both boundaries are first canonicalized against the rank. The start is then shifted by +1 when
    `include_begin` is False, and the end is shifted by +1 when `include_end` is True. The Range
    uses a step of 1.

    :param rank: the node of 0D output shape to get rank of tensor from
    :param begin: integer value from [-rank; rank - 1]
    :param end: integer value from [-rank; +rank]
    :param include_begin: boolean flag to include or exclude start point from range output
    :param include_end: boolean flag to include or exclude end point from range output
    :return: range node producing 1D output
    """
    graph = rank.graph
    name = rank.soft_get('name', rank.id)

    def shifted_by_one(source: Node, suffix: str) -> Node:
        # Emits `source + 1` as a Const(1) feeding an Add; moves a range boundary one step right.
        one = Const(graph, {'value': int64_array([1]), 'name': name + suffix + '/value'}).create_node()
        incremented = Add(graph, {'name': name + suffix}).create_node()
        source.out_port(0).connect(incremented.in_port(0))
        one.out_port(0).connect(incremented.in_port(1))
        return incremented

    # NOTE(review): when begin and end share a sign, both calls below create nodes with the same
    # names (e.g. two '<rank>/positive_axis'); confirm duplicate names are acceptable downstream.
    start_idx = get_canonical_axis_index_node(rank, begin)
    end_idx = get_canonical_axis_index_node(rank, end)

    if not include_begin:
        start_idx = shifted_by_one(start_idx, '/exclude_begin')
    if include_end:
        end_idx = shifted_by_one(end_idx, '/including_end')

    step = Const(graph, {'name': name + '/delta', 'value': int64_array([1])}).create_node()
    range_node = Range(graph, {'name': name + '/range_idxs'}).create_node()
    # Range inputs: 0 = start, 1 = limit, 2 = delta.
    for port_idx, producer in enumerate((start_idx, end_idx, step)):
        producer.out_port(0).connect(range_node.in_port(port_idx))
    return range_node
def get_shape_values_by_indices_node(shape_node: Node, indices_node: Node) -> Node:
    """
    The function returns a node that produces values of the specified indices node of the input node 'shape_node'

    The values are selected with a Gather node whose axis input (port 2) is a 0D int64 Const equal to 0.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices_node: the node of 1D output shape with the list of element indices to get
    :return: node producing required elements of the node
    """
    graph = shape_node.graph
    # soft_get falls back to the node id, so an unnamed node does not raise (consistent with the
    # other helpers in this module, which all name derived nodes this way).
    shape_name = shape_node.soft_get('name', shape_node.id)
    axis = Const(graph, {'value': int64_array(0), 'name': shape_name + '/Axis'}).create_node()
    gather_node = Gather(graph, {'name': shape_name + '/Gather'}).create_node()
    shape_node.out_port(0).connect(gather_node.in_port(0))
    indices_node.out_port(0).connect(gather_node.in_port(1))
    axis.out_port(0).connect(gather_node.in_port(2))
    return gather_node
def node_to_get_shape_value_of_indices(shape_node: Node, indices: list) -> Node:
    """
    The function returns a node that produces values of the specified indices of the input node 'shape_node'

    The indices are materialized as an int64 Const and passed to get_shape_values_by_indices_node.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices: the list of element indices to get
    :return: node producing required elements of the node
    """
    graph = shape_node.graph
    # soft_get falls back to the node id, so an unnamed node does not raise.
    shape_name = shape_node.soft_get('name', shape_node.id)
    indices_node = Const(graph, {'value': int64_array(indices), 'name': shape_name + '/Indices'}).create_node()
    return get_shape_values_by_indices_node(shape_node, indices_node)
def get_shape_values_by_range_idxs(shape: Node, rank: Node, begin: int, end: int,
                                   include_begin: bool = True, include_end: bool = False):
    """
    Selects the elements of a shape tensor whose positions lie in a range of axis indices.

    The index range is built by get_range_node_of_idxs and applied through
    get_shape_values_by_indices_node.

    :param shape: the node of 1D output shape to get elements from
    :param rank: the node of 0D output shape to get rank of tensor from
    :param begin: integer value from [-rank; rank - 1]
    :param end: integer value from [-rank; +rank]
    :param include_begin: boolean flag to include or exclude start point from range output
    :param include_end: boolean flag to include or exclude end point from range output
    :return: gather node producing 1D output
    """
    return get_shape_values_by_indices_node(
        shape,
        get_range_node_of_idxs(rank, begin, end, include_begin=include_begin, include_end=include_end),
    )
def node_to_get_batch_value(shape_node: Node) -> Node:
    """
    Builds a node that extracts the batch size from a shape tensor.

    The batch is taken to be the element at index 0 of the shape.

    :param shape_node: the node of 1D output shape to get batch from
    :return: the node producing batch value
    """
    batch_index = 0
    return node_to_get_shape_value_of_indices(shape_node, [batch_index])
def node_to_get_features_dimension_value(shape_node: Node) -> Node:
    """
    The function returns a node that produces the feature dimension value

    The feature dimension is index 1 for the 'NCHW' graph layout and the last index (-1) for 'NHWC'.

    :param shape_node: the node of 1D output shape to get the feature dimension value from
    :return: the node producing feature dimension value
    :raises ValueError: if the graph layout is neither 'NCHW' nor 'NHWC'
    """
    layout = shape_node.graph.graph['layout']
    if layout == 'NCHW':
        return node_to_get_shape_value_of_indices(shape_node, [1])
    if layout == 'NHWC':
        return node_to_get_shape_value_of_indices(shape_node, [-1])
    # Asserting a non-empty string can never fail, so the previous check let an unsupported
    # layout silently return None; raise explicitly (also unaffected by `python -O`).
    raise ValueError('Unsupported layout "{}"'.format(layout))
def node_to_get_spatial_dimensions_value(shape_node: Node) -> Node:
    """
    The function returns a node that produces the spatial dimension values

    The number of dimensions is taken from the inferred shape of the data feeding input port 0 of
    'shape_node'. Spatial indices are [2, rank) for 'NCHW' and [1, rank - 1) for 'NHWC'.
    NOTE(review): this assumes 'shape_node' has a connected input port 0 (presumably a ShapeOf
    node); confirm against callers.

    :param shape_node: the node of 1D output shape to get the spatial dimension values from
    :return: the node producing the spatial dimension values
    :raises ValueError: if the graph layout is neither 'NCHW' nor 'NHWC'
    """
    layout = shape_node.graph.graph['layout']
    shape = shape_node.in_port(0).get_connection().get_source().data.get_shape()
    assert shape is not None, 'The shape must be inferred before running this function'
    if layout == 'NCHW':
        return node_to_get_shape_value_of_indices(shape_node, list(range(2, len(shape))))
    if layout == 'NHWC':
        return node_to_get_shape_value_of_indices(shape_node, list(range(1, len(shape) - 1)))
    # Asserting a non-empty string can never fail, so the previous check let an unsupported
    # layout silently return None; raise explicitly (also unaffected by `python -O`).
    raise ValueError('Unsupported layout "{}"'.format(layout))
def new_shape_node_from_shape_nodes(input_shape_nodes: list):
    """
    Concatenates the 1D outputs of several nodes into a single 1D tensor.

    A Concat node with axis 0 is created in the graph of the first input node and named after it.
    Output port 0 of each input is wired to Concat input port i, in list order.

    :param input_shape_nodes: list of nodes producing 1D tensors
    :return: the node producing concatenated values of nodes from the "input_shape_nodes"
    """
    assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty'
    first_node = input_shape_nodes[0]
    concat_name = first_node.soft_get('name', first_node.id) + '/shapes_concat'
    concat_node = Concat(first_node.graph, {'axis': 0, 'name': concat_name}).create_node()
    for port_idx, producer in enumerate(input_shape_nodes):
        # Input ports are added on demand, one per concatenated tensor.
        concat_node.add_input_port(port_idx)
        concat_node.in_port(port_idx).connect(producer.out_port(0))
    return concat_node
def get_shape_and_rank_nodes_by_port(port: Port, return_as_a_scalar: bool = True):
    """
    Builds ShapeOf-based nodes that compute the shape and the rank of the data on an output port,
    so those values can be used on the middle/back phase.

    The rank is obtained by applying ShapeOf to the shape output, which gives a 1-element tensor.
    Unless a 1D rank is requested, that result is squeezed on axis 0 to a scalar.

    :param port: Port object that specifies node output port
    :param return_as_a_scalar: boolean flag to return 1D or 0D rank
    :return: shape and rank nodes
    """
    graph = port.node.graph
    base_name = port.node.soft_get('name', port.node.id)

    shape_of = Shape(graph, dict(name=base_name + '/ShapeOf')).create_node()
    rank_as_1d = Shape(graph, dict(name=base_name + '/1dRankOf')).create_node()
    rank_as_1d.in_port(0).connect(shape_of.out_port(0))
    shape_of.in_port(0).connect(port)

    if return_as_a_scalar:
        rank_as_0d = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                      {'name': base_name + '/0dRankOf'}, rank_as_1d)
        return shape_of, rank_as_0d
    return shape_of, rank_as_1d