Revert "Reshape-able SliceConverter (#2954)" (#3118)

This reverts commit b437387bd5.
This commit is contained in:
Yegor Kruglov 2020-11-13 15:36:04 +03:00 committed by GitHub
parent 2b23eb8ade
commit 302ded7bd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 360 additions and 197 deletions

View File

@ -16,7 +16,6 @@
import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Add, Equal
from extensions.ops.select import Select
from mo.front.common.replacement import FrontReplacementOp
@ -75,7 +74,4 @@ class TFSliceToSliceReplacer(FrontReplacementOp):
# out of select to end (2nd of slice)
select_node.out_port(0).connect(slice_node.in_port(2))
cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
select_node.in_port(2).get_connection().insert_node(cast)
node.out_port(0).get_connection().set_source(slice_node.out_port(0))

View File

@ -37,7 +37,6 @@ nodes = {
**regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Equal'}),
**regular_op_with_empty_data('select', {'op': 'Select', 'type': 'Select'}),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert'}),
}
@ -69,8 +68,7 @@ class SliceReplacerTest(unittest.TestCase):
*connect_front('equal:0', 'select:0'),
*connect_front('end_const:0', 'cast:0'),
*connect_front('cast:0', 'select:2'),
*connect_front('end_const:0', 'select:2'),
*connect_front('select:0', 'slice:2'),
*connect_front('slice:0', 'output'),
@ -99,8 +97,7 @@ class SliceReplacerTest(unittest.TestCase):
*connect_front('int32_max:0', '1:select'),
*connect_front('minus_one:0', '1:equal'),
*connect_front('equal:0', '0:select'),
*connect_front('end_const:0', '0:cast'),
*connect_front('cast:0', '2:select'),
*connect_front('end_const:0', '2:select'),
*connect_front('select:0', '2:slice'),
*connect_front('slice:0', 'output'),
], nodes_with_edges_only=True)

View File

@ -16,30 +16,18 @@
import numpy as np
from extensions.ops.Cast import Cast
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.strided_slice import StridedSlice
from mo.utils.error import Error
def create_ss_interval_border(graph: Graph, shape, axes, port_to_connect: Port, node_name):
shape_mask = np.zeros(len(shape), dtype=np.int64)
first_part = shape_mask[:axes[0]]
last_part = shape_mask[axes[-1] + 1:]
cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
port_to_connect.get_connection().set_destination(cast.in_port(0))
concat = create_op_with_const_inputs(graph, Concat, port_value_dict={0: first_part, 2: last_part},
op_attrs={'name': node_name + '/Concat', 'axis': 0,
'in_ports_count': 3})
cast.out_port(0).connect(concat.in_port(1))
return concat
def convert_negative_indices(indices: np.array, shape: np.array):
for ind, value in enumerate(indices):
if value < 0:
indices[ind] += shape[ind]
class ConvertSlice(MiddleReplacementPattern):
@ -48,57 +36,80 @@ class ConvertSlice(MiddleReplacementPattern):
"""
enabled = True
force_clean_up = True
op = "Slice"
force_clean_up = True
def run_after(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='Slice'):
node_name = node.soft_get('name', node.id)
def pattern(self):
return dict(
nodes=[
('slice', dict(kind='op', op='Slice'))
],
edges=[]
)
input_shape = node.in_port(0).data.get_shape()
if node.is_in_port_connected(3):
axes = node.in_port(3).data.get_value().copy()
assert axes is not None, 'The input with axes is not constant for node {}'.format(node_name)
for i, val in enumerate(axes):
axes[i] = get_canonical_axis_index(input_shape, val)
else:
axes = int64_array(range(len(input_shape)))
def replace_pattern(self, graph: Graph, match: dict):
node = match['slice']
ss_begin = create_ss_interval_border(graph, input_shape, axes, node.in_port(1).get_source(), node_name)
ss_end = create_ss_interval_border(graph, input_shape, axes, node.in_port(2).get_source(), node_name)
rename_nodes([(ss_begin, node_name + '/Begin'), (ss_end, node_name + '/End')])
input_shape = node.in_port(0).data.get_shape()
output_shape = node.out_port(0).data.get_shape()
starts = node.in_port(1).data.get_value()
ends = node.in_port(2).data.get_value()
if starts is None or ends is None:
raise Error('The input with starts or end is not constant for node {}'.format(node.id))
if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
assert steps is not None, 'The input with steps is not constant for node {}'.format(node_name)
else:
steps = np.ones([axes.size])
# the value for 'ends' is usually maximum possible value of int64. This
# value must be converted to maximum of int32 because such big values do not fit into the int32 which is
# supported by the StridedSlice layer
ends = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
if node.is_in_port_connected(3):
axes = node.in_port(3).data.get_value()
if axes is None:
raise Error('The input with axes is not constant for node {}'.format(node.id))
else:
axes = int64_array(list(range(starts.size)))
ss_begin_mask = np.zeros(len(input_shape), dtype=np.int64)
ss_end_mask = np.zeros(len(input_shape), dtype=np.int64)
ss_step = np.ones(len(input_shape), dtype=np.int64)
if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
if steps is None:
raise Error('The input with steps is not constant for node {}'.format(node.id))
else:
steps = np.ones([starts.size])
for i, axis in enumerate(axes):
ss_begin_mask = np.zeros(len(input_shape), dtype=np.int32)
ss_end_mask = np.zeros(len(input_shape), dtype=np.int32)
ss_begin = np.zeros(len(input_shape), dtype=np.int32)
ss_end = np.zeros(len(input_shape), dtype=np.int32)
ss_step = np.ones(len(input_shape), dtype=np.int32)
# prepare inputs and attributes for the StridedSlice layer
for i, axis in enumerate(axes):
if starts[i] != 0:
ss_begin_mask[axis] = 1
ss_end_mask[axis] = 1
ss_step[axis] = steps[i]
ss_begin[axis] = starts[i]
ss_strides = Const(graph, dict(name=node_name + '/Strides', value=ss_step)).create_node()
ss_end_mask[axis] = 1
ss_end[axis] = ends[i]
ss = StridedSlice(graph, dict(name='ss', new_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
shrink_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
ellipsis_mask=np.zeros(len(input_shape), dtype=np.int64),
begin_mask=ss_begin_mask,
end_mask=ss_end_mask, override_output_shape=True)).create_node()
ss_step[axis] = steps[i]
node.in_port(0).get_connection().set_destination(ss.in_port(0))
ss.in_port(1).connect(ss_begin.out_port(0))
ss.in_port(2).connect(ss_end.out_port(0))
ss.in_port(3).connect(ss_strides.out_port(0))
node.out_port(0).get_connection().set_source(ss.out_port(0))
slice_node_name = node.soft_get('name', node.id)
rename_nodes([(node, node_name + '/ShouldBeDeleted'), (ss, node_name)])
begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()
strides_node = Const(graph, {'value': ss_step, 'name': slice_node_name + '/stride'}).create_node()
ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
ellipsis_mask=np.zeros(len(output_shape), dtype=np.int32),
begin_mask=ss_begin_mask,
end_mask=ss_end_mask)).create_node()
rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])
node.in_port(0).get_connection().set_destination(ss.in_port(0))
begin_node.out_port(0).connect(ss.in_port(1))
end_node.out_port(0).connect(ss.in_port(2))
strides_node.out_port(0).connect(ss.in_port(3))
node.out_port(0).get_connection().set_source(ss.out_port(0))

View File

@ -17,148 +17,307 @@
import unittest
import numpy as np
from generator import generate, generator
from extensions.middle.SliceConverter import ConvertSlice
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.ops.slice import Slice
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
regular_op_with_empty_data, result, connect, const, empty_data
from mo.utils.unittest.graph import build_graph
nodes_attributes = {
# input data
'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_2': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'placeholder_3': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# Slice layer
'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'},
'slice_data': {'value': None, 'shape': None, 'kind': 'data'},
# Output operation
'output_op': {'type': 'Const', 'value': None, 'kind': 'op', 'op': 'Const'},
'output_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'op_output': { 'kind': 'op', 'op': 'Result'},
# StridedSlice layer
'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None}
}
@generator
class ConvertSliceTests(unittest.TestCase):
@generate(*[
(int64_array([1, 3, 300, 300]), np.array([0, 0]), np.array([150, 150]), np.array([2, 3]), np.array([1, 1]),
(int64_array([0, 0]), int64_array([])), (int64_array([0, 0]), int64_array([])), int64_array([1, 1, 1, 1]),
int64_array([0, 0, 1, 1]), int64_array([0, 0, 1, 1])),
nodes_attributes = {
# input data
'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# Slice layer inputs
'starts': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'starts_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'ends': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'ends_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'strides': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'strides_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'axes': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'axes_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'steps': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'steps_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# Slice layer
'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'},
'slice_data': {'value': None, 'shape': None, 'kind': 'data'},
# Output operation
'output_op': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'output_data': {'shape': None, 'kind': 'data', 'data_type': None},
'op_output': {'kind': 'op', 'op': 'Result'},
# StridedSlice layer
'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None}
}
(int64_array([1, 3, 300, 300]), np.array([0]), np.array([150]), np.array([2]), np.array([1]),
(int64_array([0, 0]), int64_array([0])), (int64_array([0, 0]), int64_array([0])), int64_array([1, 1, 1, 1]),
int64_array([0, 0, 1, 0]), int64_array([0, 0, 1, 0])),
def test_slice_all_params(self):
input_shape = int64_array([5, 10, 20])
starts_value = int64_array([4, 2])
ends_value = int64_array([15, 8])
axes_value = int64_array([2, 1])
steps_value = int64_array([1, 1])
(int64_array([1, 3, 300, 300]), np.array([0, 0]), np.array([150, 150]), np.array([-2, -1]), np.array([1, 1]),
(int64_array([0, 0]), int64_array([])), (int64_array([0, 0]), int64_array([])), int64_array([1, 1, 1, 1]),
int64_array([0, 0, 1, 1]), int64_array([0, 0, 1, 1]))
])
def test_convert_slice_to_strided_slice(self, input_shape, start, end, axes, steps,
ss_begin_parts: tuple, ss_end_parts: tuple, ss_steps,
ss_begin_mask, ss_end_mask):
graph = build_graph(
nodes_attrs={
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
**valued_const_with_data('start', start),
**valued_const_with_data('end', end),
**valued_const_with_data('axes', axes),
**valued_const_with_data('steps', steps),
**regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
**result('result')
},
edges=[
*connect('input', 'slice'),
*connect('start', '1:slice'),
*connect('end', '2:slice'),
*connect('axes', '3:slice'),
*connect('steps', '4:slice'),
*connect('slice', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
**valued_const_with_data('start', start),
**valued_const_with_data('begin_first_part', ss_begin_parts[0]),
**valued_const_with_data('begin_last_part', ss_begin_parts[1]),
**regular_op_with_empty_data('convert_start', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
**regular_op_with_empty_data('ss_begin', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
**valued_const_with_data('end', end),
**valued_const_with_data('end_first_part', ss_end_parts[0]),
**valued_const_with_data('end_last_part', ss_end_parts[1]),
**regular_op_with_empty_data('convert_end', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
**regular_op_with_empty_data('ss_end', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
**const('ss_steps', ss_steps),
**empty_data('ss_steps_d'),
**regular_op_with_empty_data('ss', {'op': 'StridedSlice', 'type': 'StridedSlice',
'begin_mask': ss_begin_mask, 'end_mask': ss_end_mask,
'new_axis_mask': np.zeros(len(input_shape), dtype=np.int64),
'shrink_axis_mask': np.zeros(len(input_shape), dtype=np.int64),
'ellipsis_mask': np.zeros(len(input_shape), dtype=np.int64)}),
**result('result')
},
edges=[
*connect('input', 'ss'),
*connect('begin_first_part', 'ss_begin'),
*connect('start', 'convert_start'),
*connect('convert_start', '1:ss_begin'),
*connect('begin_last_part', '2:ss_begin'),
*connect('ss_begin', '1:ss'),
*connect('end_first_part', 'ss_end'),
*connect('end', 'convert_end'),
*connect('convert_end', '1:ss_end'),
*connect('end_last_part', '2:ss_end'),
*connect('ss_end', '2:ss'),
*connect('ss_steps', '3:ss'),
*connect('ss', 'result')
]
)
ConvertSlice().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
graph = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'slice', {'in': 2}),
('axes', 'axes_data'),
('axes_data', 'slice', {'in': 3}),
('steps', 'steps_data'),
('steps_data', 'slice', {'in': 4}),
('slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'starts': {'shape': starts_value.shape, 'value': starts_value},
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
'ends': {'shape': ends_value.shape, 'value': ends_value},
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
'steps': {'shape': steps_value.shape, 'value': steps_value},
'steps_data': {'shape': steps_value.shape, 'value': steps_value},
'axes': {'shape': axes_value.shape, 'value': axes_value},
'axes_data': {'shape': axes_value.shape, 'value': axes_value},
}, nodes_with_edges_only=True
)
slice_node = Node(graph, 'slice')
Slice.infer(slice_node)
pattern = ConvertSlice()
pattern.find_and_replace_pattern(graph)
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
graph_ref = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'strided_slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'strided_slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'strided_slice', {'in': 2}),
('strides', 'strides_data'),
('strides_data', 'strided_slice', {'in': 3}),
('strided_slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]),
'end_mask': int64_array([0, 1, 1])},
'slice_data': {'shape': int64_array([5, 6, 11])}
}, nodes_with_edges_only=True
)
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
graph = build_graph(
nodes_attrs={
**regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
**valued_const_with_data('start', np.array([0, 0, 0])),
**valued_const_with_data('end', np.array([1, 3, 5])),
**regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
**result('result')
},
edges=[
*connect('input', 'slice'),
*connect('start', '1:slice'),
*connect('end', '2:slice'),
*connect('slice', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
**valued_const_with_data('start', np.array([0, 0, 0])),
**valued_const_with_data('begin_first_part', int64_array([])),
**valued_const_with_data('begin_last_part', int64_array([])),
**regular_op_with_empty_data('convert_start', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
**regular_op_with_empty_data('ss_begin', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
**valued_const_with_data('end', np.array([1, 3, 5])),
**valued_const_with_data('end_first_part', int64_array([])),
**valued_const_with_data('end_last_part', int64_array([])),
**regular_op_with_empty_data('convert_end', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
**regular_op_with_empty_data('ss_end', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
**const('ss_steps', int64_array([1, 1, 1])),
**empty_data('ss_steps_d'),
**regular_op_with_empty_data('ss', {'op': 'StridedSlice', 'type': 'StridedSlice',
'begin_mask': int64_array([1, 1, 1]), 'end_mask': int64_array([1, 1, 1]),
'new_axis_mask': np.zeros(3, dtype=np.int64),
'shrink_axis_mask': np.zeros(3, dtype=np.int64),
'ellipsis_mask': np.zeros(3, dtype=np.int64)}),
**result('result')
},
edges=[
*connect('input', 'ss'),
*connect('begin_first_part', 'ss_begin'),
*connect('start', 'convert_start'),
*connect('convert_start', '1:ss_begin'),
*connect('begin_last_part', '2:ss_begin'),
*connect('ss_begin', '1:ss'),
*connect('end_first_part', 'ss_end'),
*connect('end', 'convert_end'),
*connect('convert_end', '1:ss_end'),
*connect('end_last_part', '2:ss_end'),
*connect('ss_end', '2:ss'),
*connect('ss_steps', '3:ss'),
*connect('ss', 'result')
]
)
ConvertSlice().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
def test_no_steps_no_axes(self):
input_shape = int64_array([5, 10, 20])
starts_value = int64_array([3, 2, 7])
ends_value = int64_array([5, 8, 15])
steps_value = int64_array([1, 1, 1])
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
graph = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'slice', {'in': 2}),
('slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'starts': {'shape': starts_value.shape, 'value': starts_value},
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
'ends': {'shape': ends_value.shape, 'value': ends_value},
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
}, nodes_with_edges_only=True
)
slice_node = Node(graph, 'slice')
Slice.infer(slice_node)
pattern = ConvertSlice()
pattern.find_and_replace_pattern(graph)
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
graph_ref = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'strided_slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'strided_slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'strided_slice', {'in': 2}),
('strides', 'strides_data'),
('strides_data', 'strided_slice', {'in': 3}),
('strided_slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]),
'end_mask': np.ones([3])},
'slice_data': {'shape': int64_array([2, 6, 8])}
}, nodes_with_edges_only=True
)
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_no_axes(self):
input_shape = int64_array([5, 10, 20])
starts_value = int64_array([3, 2, 7])
ends_value = int64_array([5, 8, 15])
steps_value = int64_array([2, 3, 1])
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
graph = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'slice', {'in': 2}),
('steps', 'steps_data'),
('steps_data', 'slice', {'in': 4}),
('slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'starts': {'shape': starts_value.shape, 'value': starts_value},
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
'ends': {'shape': ends_value.shape, 'value': ends_value},
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
'steps': {'shape': steps_value.shape, 'value': steps_value},
'steps_data': {'shape': steps_value.shape, 'value': steps_value},
}, nodes_with_edges_only=True
)
slice_node = Node(graph, 'slice')
Slice.infer(slice_node)
pattern = ConvertSlice()
pattern.find_and_replace_pattern(graph)
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
graph_ref = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'strided_slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'strided_slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'strided_slice', {'in': 2}),
('strides', 'strides_data'),
('strides_data', 'strided_slice', {'in': 3}),
('strided_slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]),
'end_mask': np.ones([3])},
'slice_data': {'shape': int64_array([1, 2, 8])}
}, nodes_with_edges_only=True
)
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_no_steps(self):
input_shape = int64_array([5, 10, 20])
starts_value = int64_array([4, 2])
ends_value = int64_array([15, 8])
axes_value = int64_array([2, 1])
masks_value = np.zeros([len(input_shape)], dtype=np.int64)
graph = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'slice', {'in': 2}),
('axes', 'axes_data'),
('axes_data', 'slice', {'in': 3}),
('slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'starts': {'shape': starts_value.shape, 'value': starts_value},
'starts_data': {'shape': starts_value.shape, 'value': starts_value},
'ends': {'shape': ends_value.shape, 'value': ends_value},
'ends_data': {'shape': ends_value.shape, 'value': ends_value},
'axes': {'shape': axes_value.shape, 'value': axes_value},
'axes_data': {'shape': axes_value.shape, 'value': axes_value},
}, nodes_with_edges_only=True
)
slice_node = Node(graph, 'slice')
Slice.infer(slice_node)
pattern = ConvertSlice()
pattern.find_and_replace_pattern(graph)
ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
graph_ref = build_graph(self.nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'strided_slice', {'in': 0}),
('starts', 'starts_data'),
('starts_data', 'strided_slice', {'in': 1}),
('ends', 'ends_data'),
('ends_data', 'strided_slice', {'in': 2}),
('strides', 'strides_data'),
('strides_data', 'strided_slice', {'in': 3}),
('strided_slice', 'slice_data'),
('slice_data', 'output_op'),
('output_op', 'output_data'),
('output_data', 'op_output')
],
{'placeholder_1_data': {'shape': input_shape},
'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]),
'end_mask': int64_array([0, 1, 1])},
'slice_data': {'shape': int64_array([5, 6, 11])}
}, nodes_with_edges_only=True
)
(flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
self.assertTrue(flag, resp)