* Added Caffe Slice_ext * Added TFSlice, AttributedSlice (both with extractors and replacers), corrected SliceConverter and added unittests for all cases * added comments to each type of Slice operation; optimized shape inference; moved mxlice inside of slice.py; renamed slice_replacers * removed type annotation for get_shape_after_slice routine * replaced zeros_like with zeros * Corrected preserving node names, renamed attribute names, added tests for slice_replacer onnx phase * Renamed slice_replacers.py * added more unittest cases * added type annotations, moved to more relevant place routines for shape calculation, and some other minor corrections * corrected a typo `normalize_slice_indices` comment * corrected shape calculation for non-constant inputs * corrected a few typos * corrected type declarations * corrected shape inference with rounding * refactored unit-tests for front transforms of Slice * added error raising for negative and zero shapes * removed magic_num * corrected AttributedSlice, clarified comments * fixed unit-test for AttributedSliceToSlice * typo in type hints corrected * removed supported_attrs * returned back default None for attrs of Slice
224 lines
9.7 KiB
Python
224 lines
9.7 KiB
Python
"""
|
|
Copyright (C) 2018-2020 Intel Corporation
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
from extensions.ops.elementwise import Add
|
|
from extensions.ops.gather import Gather
|
|
from extensions.ops.range import Range
|
|
from mo.front.common.partial_infer.utils import int64_array
|
|
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
|
from mo.graph.graph import Node
|
|
from mo.graph.port import Port
|
|
from mo.ops.concat import Concat
|
|
from mo.ops.const import Const
|
|
from mo.ops.shape import Shape
|
|
from mo.ops.squeeze import Squeeze
|
|
|
|
|
|
def get_canonical_axis_index_node(rank: Node, axis: int) -> Node:
    """
    Builds a node yielding the non-negative (canonical) form of `axis`.

    A non-negative axis is emitted directly as a 1-element Const. A negative axis is turned into `rank + axis` by an
    Add node, because the rank is only available here as the output of a graph node.

    :param rank: the node of 0D output shape to get rank of tensor from
    :param axis: integer value from [-rank; rank - 1]
    :return: node producing positive integer value of axis
    """
    graph = rank.graph
    base_name = rank.soft_get('name', rank.id)

    if axis >= 0:
        # already canonical: no runtime arithmetic needed
        return Const(graph, {'name': base_name + '/positive_axis', 'value': int64_array([axis])}).create_node()

    # negative axis: canonical value is rank + axis, computed inside the graph
    negative_axis_const = Const(graph, {'name': base_name + '/negative_axis',
                                        'value': int64_array([axis])}).create_node()
    positive_axis = Add(graph, {'name': base_name + '/positive_axis'}).create_node()
    rank.out_port(0).connect(positive_axis.in_port(0))
    negative_axis_const.out_port(0).connect(positive_axis.in_port(1))
    return positive_axis
|
|
|
|
|
|
def get_range_node_of_idxs(rank: Node, begin: int, end: int,
                           include_begin: bool = True, include_end: bool = False) -> Node:
    """
    Creates a Range node producing the 1D sequence of indices between `begin` and `end` with step 1.

    Both bounds are first canonicalized with `get_canonical_axis_index_node`. With the default flags the produced
    interval is [begin; end). `include_begin=False` moves the start forward by one, and `include_end=True` moves the
    stop forward by one.

    :param rank: the node of 0D output shape to get rank of tensor from
    :param begin: integer value from [-rank; rank - 1]
    :param end: integer value from [-rank; +rank]
    :param include_begin: boolean flag to include or exclude start point from range output
    :param include_end: boolean flag to include or exclude end point from range output
    :return: range node producing 1D output
    """
    graph = rank.graph
    base_name = rank.soft_get('name', rank.id)

    def shifted_by_one(source, suffix):
        # emits `source + 1` in the graph (Const created before Add, as the names below expect)
        one = Const(graph, {'value': int64_array([1]), 'name': base_name + suffix + '/value'}).create_node()
        plus_one = Add(graph, {'name': base_name + suffix}).create_node()
        source.out_port(0).connect(plus_one.in_port(0))
        one.out_port(0).connect(plus_one.in_port(1))
        return plus_one

    range_start = get_canonical_axis_index_node(rank, begin)
    range_stop = get_canonical_axis_index_node(rank, end)

    if not include_begin:
        range_start = shifted_by_one(range_start, '/exclude_begin')
    if include_end:
        range_stop = shifted_by_one(range_stop, '/including_end')

    step = Const(graph, {'name': base_name + '/delta', 'value': int64_array([1])}).create_node()
    range_node = Range(graph, {'name': base_name + '/range_idxs'}).create_node()

    # Range inputs: 0 - start, 1 - stop, 2 - step
    for port_idx, source in enumerate((range_start, range_stop, step)):
        source.out_port(0).connect(range_node.in_port(port_idx))

    return range_node
|
|
|
|
|
|
def get_shape_values_by_indices_node(shape_node: Node, indices_node: Node) -> Node:
    """
    The function returns a node that produces values of the specified indices node of the input node 'shape_node'

    A Gather along axis 0 is inserted. Its data input comes from `shape_node` and its indices input from
    `indices_node`.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices_node: the node of 1D output shape with the list of element indices to get
    :return: node producing required elements of the node
    """
    graph = shape_node.graph
    # like the other helpers in this module, fall back to the node id when the node has no 'name' attribute
    shape_node_name = shape_node.soft_get('name', shape_node.id)

    axis = Const(graph, {'value': int64_array(0), 'name': shape_node_name + '/Axis'}).create_node()
    gather_node = Gather(graph, {'name': shape_node_name + '/Gather'}).create_node()

    shape_node.out_port(0).connect(gather_node.in_port(0))
    indices_node.out_port(0).connect(gather_node.in_port(1))
    axis.out_port(0).connect(gather_node.in_port(2))
    return gather_node
|
|
|
|
|
|
def node_to_get_shape_value_of_indices(shape_node: Node, indices: list) -> Node:
    """
    The function returns a node that produces values of the specified indices of the input node 'shape_node'

    The `indices` list is materialized as an int64 Const and passed to `get_shape_values_by_indices_node`.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices: the list of element indices to get
    :return: node producing required elements of the node
    """
    graph = shape_node.graph
    # like the other helpers in this module, fall back to the node id when the node has no 'name' attribute
    shape_node_name = shape_node.soft_get('name', shape_node.id)
    indices_node = Const(graph, {'value': int64_array(indices), 'name': shape_node_name + '/Indices'}).create_node()

    return get_shape_values_by_indices_node(shape_node, indices_node)
|
|
|
|
|
|
def get_shape_values_by_range_idxs(shape: Node, rank: Node, begin: int, end: int,
                                   include_begin: bool = True, include_end: bool = False):
    """
    Gathers the elements of `shape` whose indices form the range built by `get_range_node_of_idxs`.

    :param shape: the node of 1D output shape to get elements from
    :param rank: the node of 0D output shape to get rank of tensor from
    :param begin: integer value from [-rank; rank - 1]
    :param end: integer value from [-rank; +rank]
    :param include_begin: boolean flag to include or exclude start point from range output
    :param include_end: boolean flag to include or exclude end point from range output
    :return: gather node producing 1D output
    """
    # the Range sub-graph is built first, then fed as the indices input of the Gather
    return get_shape_values_by_indices_node(
        shape,
        get_range_node_of_idxs(rank, begin, end, include_begin=include_begin, include_end=include_end),
    )
|
|
|
|
|
|
def node_to_get_batch_value(shape_node: Node) -> Node:
    """
    Builds a node yielding the batch value, taken as element 0 of the shape produced by `shape_node`.

    :param shape_node: the node of 1D output shape to get batch from
    :return: the node producing batch value
    """
    batch_index = [0]
    return node_to_get_shape_value_of_indices(shape_node, batch_index)
|
|
|
|
|
|
def node_to_get_features_dimension_value(shape_node: Node) -> Node:
    """
    The function returns a node that produces the feature dimension value

    The feature dimension index depends on the graph layout: 1 for 'NCHW', -1 (last) for 'NHWC'.

    :param shape_node: the node of 1D output shape to get the feature dimension value from
    :return: the node producing feature dimension value
    :raises ValueError: if the graph layout is neither 'NCHW' nor 'NHWC'
    """
    layout = shape_node.graph.graph['layout']
    if layout == 'NCHW':
        return node_to_get_shape_value_of_indices(shape_node, [1])
    if layout == 'NHWC':
        return node_to_get_shape_value_of_indices(shape_node, [-1])
    # Raise explicitly. `assert '<message>'` asserts a non-empty string and therefore never fails, which would let
    # unsupported layouts silently return None.
    raise ValueError('Unsupported layout "{}"'.format(layout))
|
|
|
|
|
|
def node_to_get_spatial_dimensions_value(shape_node: Node) -> Node:
    """
    The function returns a node that produces the spatial dimension values

    The rank is taken from the already-inferred shape of the data feeding input port 0 of `shape_node`. Spatial
    indices are [2; rank) for 'NCHW' and [1; rank - 1) for 'NHWC'.

    :param shape_node: the node of 1D output shape to get the spatial dimension values from
    :return: the node producing the spatial dimension values
    :raises ValueError: if the graph layout is neither 'NCHW' nor 'NHWC'
    """
    layout = shape_node.graph.graph['layout']
    shape = shape_node.in_port(0).get_connection().get_source().data.get_shape()
    assert shape is not None, 'The shape must be inferred before running this function'

    if layout == 'NCHW':
        return node_to_get_shape_value_of_indices(shape_node, list(range(2, len(shape))))
    if layout == 'NHWC':
        return node_to_get_shape_value_of_indices(shape_node, list(range(1, len(shape) - 1)))
    # Raise explicitly. `assert '<message>'` asserts a non-empty string and therefore never fails, which would let
    # unsupported layouts silently return None.
    raise ValueError('Unsupported layout "{}"'.format(layout))
|
|
|
|
|
|
def new_shape_node_from_shape_nodes(input_shape_nodes: list):
    """
    Concatenates, along axis 0, the outputs of all nodes in `input_shape_nodes` into a single 1D tensor.

    The Concat node is created in the graph of the first node and named after it.

    :param input_shape_nodes: list of nodes producing 1D tensors
    :return: the node producing concatenated values of nodes from the "input_shape_nodes"
    """
    assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty'

    first_node = input_shape_nodes[0]
    concat_name = first_node.soft_get('name', first_node.id) + '/shapes_concat'
    concat = Concat(first_node.graph, {'axis': 0, 'name': concat_name}).create_node()

    # one input port is added explicitly per shape node, in list order, before it is connected
    for port_idx, shape_source in enumerate(input_shape_nodes):
        concat.add_input_port(port_idx)
        concat.in_port(port_idx).connect(shape_source.out_port(0))
    return concat
|
|
|
|
|
|
def get_shape_and_rank_nodes_by_port(port: Port, return_as_a_scalar: bool = True):
    """
    Inserts nodes computing the shape and the rank of the data coming from `port`, for use on the middle/back phase.

    The rank is obtained as the Shape of the Shape output (1D). When `return_as_a_scalar` is True, it is passed
    through a Squeeze whose second input is [0], producing a 0D rank.

    :param port: Port object that specifies node output port
    :param return_as_a_scalar: boolean flag to return 1D or 0D rank
    :return: tuple (shape node, rank node)
    """
    source_node = port.node
    prefix = source_node.soft_get('name', source_node.id)
    graph = source_node.graph

    shape_of = Shape(graph, {'name': prefix + '/ShapeOf'}).create_node()
    rank_1d = Shape(graph, {'name': prefix + '/1dRankOf'}).create_node()
    rank_1d.in_port(0).connect(shape_of.out_port(0))
    shape_of.in_port(0).connect(port)

    if return_as_a_scalar:
        rank_0d = create_op_node_with_second_input(graph, Squeeze, int64_array([0]), {'name': prefix + '/0dRankOf'},
                                                   rank_1d)
        return shape_of, rank_0d
    return shape_of, rank_1d
|
|
|