Reshape-able SliceConverter (#2954)
* initial commit * add cast * data type fix * added tests * added test without axes and steps * remove redundant imports * discussions resolving * Add cast to TFSliceToSlice * layer tests fix * update unittest
This commit is contained in:
parent
e3b879ad3b
commit
b437387bd5
@ -16,6 +16,7 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Add, Equal
|
||||
from extensions.ops.select import Select
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
@ -74,4 +75,7 @@ class TFSliceToSliceReplacer(FrontReplacementOp):
|
||||
# out of select to end (2nd of slice)
|
||||
select_node.out_port(0).connect(slice_node.in_port(2))
|
||||
|
||||
cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
|
||||
select_node.in_port(2).get_connection().insert_node(cast)
|
||||
|
||||
node.out_port(0).get_connection().set_source(slice_node.out_port(0))
|
||||
|
@ -37,6 +37,7 @@ nodes = {
|
||||
**regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Equal'}),
|
||||
**regular_op_with_empty_data('select', {'op': 'Select', 'type': 'Select'}),
|
||||
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
|
||||
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert'}),
|
||||
}
|
||||
|
||||
|
||||
@ -68,7 +69,8 @@ class SliceReplacerTest(unittest.TestCase):
|
||||
|
||||
*connect_front('equal:0', 'select:0'),
|
||||
|
||||
*connect_front('end_const:0', 'select:2'),
|
||||
*connect_front('end_const:0', 'cast:0'),
|
||||
*connect_front('cast:0', 'select:2'),
|
||||
*connect_front('select:0', 'slice:2'),
|
||||
|
||||
*connect_front('slice:0', 'output'),
|
||||
@ -97,7 +99,8 @@ class SliceReplacerTest(unittest.TestCase):
|
||||
*connect_front('int32_max:0', '1:select'),
|
||||
*connect_front('minus_one:0', '1:equal'),
|
||||
*connect_front('equal:0', '0:select'),
|
||||
*connect_front('end_const:0', '2:select'),
|
||||
*connect_front('end_const:0', '0:cast'),
|
||||
*connect_front('cast:0', '2:select'),
|
||||
*connect_front('select:0', '2:slice'),
|
||||
*connect_front('slice:0', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
@ -16,18 +16,30 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from mo.front.caffe.extractors.utils import get_canonical_axis_index
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, rename_nodes
|
||||
from mo.graph.port import Port
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
def convert_negative_indices(indices: np.array, shape: np.array):
|
||||
for ind, value in enumerate(indices):
|
||||
if value < 0:
|
||||
indices[ind] += shape[ind]
|
||||
def create_ss_interval_border(graph: Graph, shape, axes, port_to_connect: Port, node_name):
|
||||
shape_mask = np.zeros(len(shape), dtype=np.int64)
|
||||
first_part = shape_mask[:axes[0]]
|
||||
last_part = shape_mask[axes[-1] + 1:]
|
||||
|
||||
cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
|
||||
port_to_connect.get_connection().set_destination(cast.in_port(0))
|
||||
concat = create_op_with_const_inputs(graph, Concat, port_value_dict={0: first_part, 2: last_part},
|
||||
op_attrs={'name': node_name + '/Concat', 'axis': 0,
|
||||
'in_ports_count': 3})
|
||||
cast.out_port(0).connect(concat.in_port(1))
|
||||
return concat
|
||||
|
||||
|
||||
class ConvertSlice(MiddleReplacementPattern):
|
||||
@ -36,80 +48,57 @@ class ConvertSlice(MiddleReplacementPattern):
|
||||
"""
|
||||
|
||||
enabled = True
|
||||
op = "Slice"
|
||||
force_clean_up = True
|
||||
op = "Slice"
|
||||
|
||||
def run_after(self):
|
||||
from extensions.middle.pass_separator import MiddleStart
|
||||
return [MiddleStart]
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('slice', dict(kind='op', op='Slice'))
|
||||
],
|
||||
edges=[]
|
||||
)
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(op='Slice'):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
node = match['slice']
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
if node.is_in_port_connected(3):
|
||||
axes = node.in_port(3).data.get_value().copy()
|
||||
assert axes is not None, 'The input with axes is not constant for node {}'.format(node_name)
|
||||
for i, val in enumerate(axes):
|
||||
axes[i] = get_canonical_axis_index(input_shape, val)
|
||||
else:
|
||||
axes = int64_array(range(len(input_shape)))
|
||||
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
output_shape = node.out_port(0).data.get_shape()
|
||||
starts = node.in_port(1).data.get_value()
|
||||
ends = node.in_port(2).data.get_value()
|
||||
if starts is None or ends is None:
|
||||
raise Error('The input with starts or end is not constant for node {}'.format(node.id))
|
||||
ss_begin = create_ss_interval_border(graph, input_shape, axes, node.in_port(1).get_source(), node_name)
|
||||
ss_end = create_ss_interval_border(graph, input_shape, axes, node.in_port(2).get_source(), node_name)
|
||||
rename_nodes([(ss_begin, node_name + '/Begin'), (ss_end, node_name + '/End')])
|
||||
|
||||
# the value for 'ends' is usually maximum possible value of int64. This
|
||||
# value must be converted to maximum of int32 because such big values do not fit into the int32 which is
|
||||
# supported by the StridedSlice layer
|
||||
ends = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
|
||||
if node.is_in_port_connected(3):
|
||||
axes = node.in_port(3).data.get_value()
|
||||
if axes is None:
|
||||
raise Error('The input with axes is not constant for node {}'.format(node.id))
|
||||
else:
|
||||
axes = int64_array(list(range(starts.size)))
|
||||
if node.is_in_port_connected(4):
|
||||
steps = node.in_port(4).data.get_value()
|
||||
assert steps is not None, 'The input with steps is not constant for node {}'.format(node_name)
|
||||
else:
|
||||
steps = np.ones([axes.size])
|
||||
|
||||
if node.is_in_port_connected(4):
|
||||
steps = node.in_port(4).data.get_value()
|
||||
if steps is None:
|
||||
raise Error('The input with steps is not constant for node {}'.format(node.id))
|
||||
else:
|
||||
steps = np.ones([starts.size])
|
||||
ss_begin_mask = np.zeros(len(input_shape), dtype=np.int64)
|
||||
ss_end_mask = np.zeros(len(input_shape), dtype=np.int64)
|
||||
ss_step = np.ones(len(input_shape), dtype=np.int64)
|
||||
|
||||
ss_begin_mask = np.zeros(len(input_shape), dtype=np.int32)
|
||||
ss_end_mask = np.zeros(len(input_shape), dtype=np.int32)
|
||||
ss_begin = np.zeros(len(input_shape), dtype=np.int32)
|
||||
ss_end = np.zeros(len(input_shape), dtype=np.int32)
|
||||
ss_step = np.ones(len(input_shape), dtype=np.int32)
|
||||
|
||||
# prepare inputs and attributes for the StridedSlice layer
|
||||
for i, axis in enumerate(axes):
|
||||
if starts[i] != 0:
|
||||
for i, axis in enumerate(axes):
|
||||
ss_begin_mask[axis] = 1
|
||||
ss_begin[axis] = starts[i]
|
||||
ss_end_mask[axis] = 1
|
||||
ss_step[axis] = steps[i]
|
||||
|
||||
ss_end_mask[axis] = 1
|
||||
ss_end[axis] = ends[i]
|
||||
ss_strides = Const(graph, dict(name=node_name + '/Strides', value=ss_step)).create_node()
|
||||
|
||||
ss_step[axis] = steps[i]
|
||||
ss = StridedSlice(graph, dict(name='ss', new_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
|
||||
shrink_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
|
||||
ellipsis_mask=np.zeros(len(input_shape), dtype=np.int64),
|
||||
begin_mask=ss_begin_mask,
|
||||
end_mask=ss_end_mask, override_output_shape=True)).create_node()
|
||||
|
||||
slice_node_name = node.soft_get('name', node.id)
|
||||
node.in_port(0).get_connection().set_destination(ss.in_port(0))
|
||||
ss.in_port(1).connect(ss_begin.out_port(0))
|
||||
ss.in_port(2).connect(ss_end.out_port(0))
|
||||
ss.in_port(3).connect(ss_strides.out_port(0))
|
||||
node.out_port(0).get_connection().set_source(ss.out_port(0))
|
||||
|
||||
begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
|
||||
end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()
|
||||
strides_node = Const(graph, {'value': ss_step, 'name': slice_node_name + '/stride'}).create_node()
|
||||
|
||||
ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
|
||||
shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
|
||||
ellipsis_mask=np.zeros(len(output_shape), dtype=np.int32),
|
||||
begin_mask=ss_begin_mask,
|
||||
end_mask=ss_end_mask)).create_node()
|
||||
rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])
|
||||
node.in_port(0).get_connection().set_destination(ss.in_port(0))
|
||||
begin_node.out_port(0).connect(ss.in_port(1))
|
||||
end_node.out_port(0).connect(ss.in_port(2))
|
||||
strides_node.out_port(0).connect(ss.in_port(3))
|
||||
node.out_port(0).get_connection().set_source(ss.out_port(0))
|
||||
rename_nodes([(node, node_name + '/ShouldBeDeleted'), (ss, node_name)])
|
||||
|
@ -17,307 +17,148 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from generator import generate, generator
|
||||
|
||||
from extensions.middle.SliceConverter import ConvertSlice
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node
|
||||
from mo.ops.slice import Slice
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
|
||||
nodes_attributes = {
|
||||
# input data
|
||||
'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_2': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'placeholder_3': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
# Slice layer
|
||||
'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'},
|
||||
'slice_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
# Output operation
|
||||
'output_op': {'type': 'Const', 'value': None, 'kind': 'op', 'op': 'Const'},
|
||||
'output_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'op_output': { 'kind': 'op', 'op': 'Result'},
|
||||
# StridedSlice layer
|
||||
'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None}
|
||||
}
|
||||
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
|
||||
regular_op_with_empty_data, result, connect, const, empty_data
|
||||
|
||||
|
||||
@generator
|
||||
class ConvertSliceTests(unittest.TestCase):
|
||||
nodes_attributes = {
|
||||
# input data
|
||||
'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
# Slice layer inputs
|
||||
'starts': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'starts_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'ends': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'ends_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'strides': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'strides_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'axes': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'axes_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'steps': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'steps_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
# Slice layer
|
||||
'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'},
|
||||
'slice_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
# Output operation
|
||||
'output_op': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
|
||||
'output_data': {'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'op_output': {'kind': 'op', 'op': 'Result'},
|
||||
# StridedSlice layer
|
||||
'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None}
|
||||
}
|
||||
@generate(*[
|
||||
(int64_array([1, 3, 300, 300]), np.array([0, 0]), np.array([150, 150]), np.array([2, 3]), np.array([1, 1]),
|
||||
(int64_array([0, 0]), int64_array([])), (int64_array([0, 0]), int64_array([])), int64_array([1, 1, 1, 1]),
|
||||
int64_array([0, 0, 1, 1]), int64_array([0, 0, 1, 1])),
|
||||
|
||||
def test_slice_all_params(self):
|
||||
input_shape = int64_array([5, 10, 20])
|
||||
starts_value = int64_array([4, 2])
|
||||
ends_value = int64_array([15, 8])
|
||||
axes_value = int64_array([2, 1])
|
||||
steps_value = int64_array([1, 1])
|
||||
(int64_array([1, 3, 300, 300]), np.array([0]), np.array([150]), np.array([2]), np.array([1]),
|
||||
(int64_array([0, 0]), int64_array([0])), (int64_array([0, 0]), int64_array([0])), int64_array([1, 1, 1, 1]),
|
||||
int64_array([0, 0, 1, 0]), int64_array([0, 0, 1, 0])),
|
||||
|
||||
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'slice', {'in': 2}),
|
||||
('axes', 'axes_data'),
|
||||
('axes_data', 'slice', {'in': 3}),
|
||||
('steps', 'steps_data'),
|
||||
('steps_data', 'slice', {'in': 4}),
|
||||
('slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'starts': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'ends': {'shape': ends_value.shape, 'value': ends_value},
|
||||
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
|
||||
'steps': {'shape': steps_value.shape, 'value': steps_value},
|
||||
'steps_data': {'shape': steps_value.shape, 'value': steps_value},
|
||||
'axes': {'shape': axes_value.shape, 'value': axes_value},
|
||||
'axes_data': {'shape': axes_value.shape, 'value': axes_value},
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
slice_node = Node(graph, 'slice')
|
||||
Slice.infer(slice_node)
|
||||
|
||||
pattern = ConvertSlice()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
|
||||
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
|
||||
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
|
||||
|
||||
graph_ref = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'strided_slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'strided_slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'strided_slice', {'in': 2}),
|
||||
('strides', 'strides_data'),
|
||||
('strides_data', 'strided_slice', {'in': 3}),
|
||||
('strided_slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
|
||||
'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]),
|
||||
'end_mask': int64_array([0, 1, 1])},
|
||||
'slice_data': {'shape': int64_array([5, 6, 11])}
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
|
||||
(int64_array([1, 3, 300, 300]), np.array([0, 0]), np.array([150, 150]), np.array([-2, -1]), np.array([1, 1]),
|
||||
(int64_array([0, 0]), int64_array([])), (int64_array([0, 0]), int64_array([])), int64_array([1, 1, 1, 1]),
|
||||
int64_array([0, 0, 1, 1]), int64_array([0, 0, 1, 1]))
|
||||
])
|
||||
def test_convert_slice_to_strided_slice(self, input_shape, start, end, axes, steps,
|
||||
ss_begin_parts: tuple, ss_end_parts: tuple, ss_steps,
|
||||
ss_begin_mask, ss_end_mask):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
|
||||
**valued_const_with_data('start', start),
|
||||
**valued_const_with_data('end', end),
|
||||
**valued_const_with_data('axes', axes),
|
||||
**valued_const_with_data('steps', steps),
|
||||
**regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'slice'),
|
||||
*connect('start', '1:slice'),
|
||||
*connect('end', '2:slice'),
|
||||
*connect('axes', '3:slice'),
|
||||
*connect('steps', '4:slice'),
|
||||
*connect('slice', 'result')
|
||||
]
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs={
|
||||
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
|
||||
**valued_const_with_data('start', start),
|
||||
**valued_const_with_data('begin_first_part', ss_begin_parts[0]),
|
||||
**valued_const_with_data('begin_last_part', ss_begin_parts[1]),
|
||||
**regular_op_with_empty_data('convert_start', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
|
||||
**regular_op_with_empty_data('ss_begin', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
|
||||
**valued_const_with_data('end', end),
|
||||
**valued_const_with_data('end_first_part', ss_end_parts[0]),
|
||||
**valued_const_with_data('end_last_part', ss_end_parts[1]),
|
||||
**regular_op_with_empty_data('convert_end', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
|
||||
**regular_op_with_empty_data('ss_end', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
|
||||
**const('ss_steps', ss_steps),
|
||||
**empty_data('ss_steps_d'),
|
||||
**regular_op_with_empty_data('ss', {'op': 'StridedSlice', 'type': 'StridedSlice',
|
||||
'begin_mask': ss_begin_mask, 'end_mask': ss_end_mask,
|
||||
'new_axis_mask': np.zeros(len(input_shape), dtype=np.int64),
|
||||
'shrink_axis_mask': np.zeros(len(input_shape), dtype=np.int64),
|
||||
'ellipsis_mask': np.zeros(len(input_shape), dtype=np.int64)}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'ss'),
|
||||
*connect('begin_first_part', 'ss_begin'),
|
||||
*connect('start', 'convert_start'),
|
||||
*connect('convert_start', '1:ss_begin'),
|
||||
*connect('begin_last_part', '2:ss_begin'),
|
||||
*connect('ss_begin', '1:ss'),
|
||||
*connect('end_first_part', 'ss_end'),
|
||||
*connect('end', 'convert_end'),
|
||||
*connect('convert_end', '1:ss_end'),
|
||||
*connect('end_last_part', '2:ss_end'),
|
||||
*connect('ss_end', '2:ss'),
|
||||
*connect('ss_steps', '3:ss'),
|
||||
*connect('ss', 'result')
|
||||
]
|
||||
)
|
||||
ConvertSlice().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_no_steps_no_axes(self):
|
||||
input_shape = int64_array([5, 10, 20])
|
||||
starts_value = int64_array([3, 2, 7])
|
||||
ends_value = int64_array([5, 8, 15])
|
||||
steps_value = int64_array([1, 1, 1])
|
||||
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'slice', {'in': 2}),
|
||||
('slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'starts': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'ends': {'shape': ends_value.shape, 'value': ends_value},
|
||||
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
slice_node = Node(graph, 'slice')
|
||||
Slice.infer(slice_node)
|
||||
|
||||
pattern = ConvertSlice()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
|
||||
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
|
||||
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
|
||||
|
||||
graph_ref = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'strided_slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'strided_slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'strided_slice', {'in': 2}),
|
||||
('strides', 'strides_data'),
|
||||
('strides_data', 'strided_slice', {'in': 3}),
|
||||
('strided_slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
|
||||
'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]),
|
||||
'end_mask': np.ones([3])},
|
||||
'slice_data': {'shape': int64_array([2, 6, 8])}
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_no_axes(self):
|
||||
input_shape = int64_array([5, 10, 20])
|
||||
starts_value = int64_array([3, 2, 7])
|
||||
ends_value = int64_array([5, 8, 15])
|
||||
steps_value = int64_array([2, 3, 1])
|
||||
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'slice', {'in': 2}),
|
||||
('steps', 'steps_data'),
|
||||
('steps_data', 'slice', {'in': 4}),
|
||||
('slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'starts': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'ends': {'shape': ends_value.shape, 'value': ends_value},
|
||||
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
|
||||
'steps': {'shape': steps_value.shape, 'value': steps_value},
|
||||
'steps_data': {'shape': steps_value.shape, 'value': steps_value},
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
slice_node = Node(graph, 'slice')
|
||||
Slice.infer(slice_node)
|
||||
|
||||
pattern = ConvertSlice()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
|
||||
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
|
||||
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
|
||||
|
||||
graph_ref = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'strided_slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'strided_slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'strided_slice', {'in': 2}),
|
||||
('strides', 'strides_data'),
|
||||
('strides_data', 'strided_slice', {'in': 3}),
|
||||
('strided_slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
|
||||
'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]),
|
||||
'end_mask': np.ones([3])},
|
||||
'slice_data': {'shape': int64_array([1, 2, 8])}
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_no_steps(self):
|
||||
input_shape = int64_array([5, 10, 20])
|
||||
starts_value = int64_array([4, 2])
|
||||
ends_value = int64_array([15, 8])
|
||||
axes_value = int64_array([2, 1])
|
||||
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'slice', {'in': 2}),
|
||||
('axes', 'axes_data'),
|
||||
('axes_data', 'slice', {'in': 3}),
|
||||
('slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'starts': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
|
||||
'ends': {'shape': ends_value.shape, 'value': ends_value},
|
||||
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
|
||||
'axes': {'shape': axes_value.shape, 'value': axes_value},
|
||||
'axes_data': {'shape': axes_value.shape, 'value': axes_value},
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
slice_node = Node(graph, 'slice')
|
||||
Slice.infer(slice_node)
|
||||
|
||||
pattern = ConvertSlice()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
|
||||
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
|
||||
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
|
||||
|
||||
graph_ref = build_graph(self.nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'strided_slice', {'in': 0}),
|
||||
('starts', 'starts_data'),
|
||||
('starts_data', 'strided_slice', {'in': 1}),
|
||||
('ends', 'ends_data'),
|
||||
('ends_data', 'strided_slice', {'in': 2}),
|
||||
('strides', 'strides_data'),
|
||||
('strides_data', 'strided_slice', {'in': 3}),
|
||||
('strided_slice', 'slice_data'),
|
||||
('slice_data', 'output_op'),
|
||||
('output_op', 'output_data'),
|
||||
('output_data', 'op_output')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': input_shape},
|
||||
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
|
||||
'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]),
|
||||
'end_mask': int64_array([0, 1, 1])},
|
||||
'slice_data': {'shape': int64_array([5, 6, 11])}
|
||||
}, nodes_with_edges_only=True
|
||||
)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
|
||||
def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
**regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
|
||||
**valued_const_with_data('start', np.array([0, 0, 0])),
|
||||
**valued_const_with_data('end', np.array([1, 3, 5])),
|
||||
**regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'slice'),
|
||||
*connect('start', '1:slice'),
|
||||
*connect('end', '2:slice'),
|
||||
*connect('slice', 'result')
|
||||
]
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs={
|
||||
**regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
|
||||
**valued_const_with_data('start', np.array([0, 0, 0])),
|
||||
**valued_const_with_data('begin_first_part', int64_array([])),
|
||||
**valued_const_with_data('begin_last_part', int64_array([])),
|
||||
**regular_op_with_empty_data('convert_start', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
|
||||
**regular_op_with_empty_data('ss_begin', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
|
||||
**valued_const_with_data('end', np.array([1, 3, 5])),
|
||||
**valued_const_with_data('end_first_part', int64_array([])),
|
||||
**valued_const_with_data('end_last_part', int64_array([])),
|
||||
**regular_op_with_empty_data('convert_end', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
|
||||
**regular_op_with_empty_data('ss_end', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
|
||||
**const('ss_steps', int64_array([1, 1, 1])),
|
||||
**empty_data('ss_steps_d'),
|
||||
**regular_op_with_empty_data('ss', {'op': 'StridedSlice', 'type': 'StridedSlice',
|
||||
'begin_mask': int64_array([1, 1, 1]), 'end_mask': int64_array([1, 1, 1]),
|
||||
'new_axis_mask': np.zeros(3, dtype=np.int64),
|
||||
'shrink_axis_mask': np.zeros(3, dtype=np.int64),
|
||||
'ellipsis_mask': np.zeros(3, dtype=np.int64)}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'ss'),
|
||||
*connect('begin_first_part', 'ss_begin'),
|
||||
*connect('start', 'convert_start'),
|
||||
*connect('convert_start', '1:ss_begin'),
|
||||
*connect('begin_last_part', '2:ss_begin'),
|
||||
*connect('ss_begin', '1:ss'),
|
||||
*connect('end_first_part', 'ss_end'),
|
||||
*connect('end', 'convert_end'),
|
||||
*connect('convert_end', '1:ss_end'),
|
||||
*connect('end_last_part', '2:ss_end'),
|
||||
*connect('ss_end', '2:ss'),
|
||||
*connect('ss_steps', '3:ss'),
|
||||
*connect('ss', 'result')
|
||||
]
|
||||
)
|
||||
ConvertSlice().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
Loading…
Reference in New Issue
Block a user