[TF FE] Refactor StridedSlice translator and add layer test to precommit (#16376)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
98237b06b5
commit
5cb20f8858
@ -16,42 +16,55 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_strided_slice_op(const NodeContext& node) {
|
||||
default_op_checks(node, 4, {"StridedSlice", "STRIDED_SLICE"});
|
||||
auto input = node.get_input(0);
|
||||
auto begin = node.get_input(1);
|
||||
auto end = node.get_input(2);
|
||||
auto strides = node.get_input(3);
|
||||
|
||||
auto begin_mask = node.get_attribute<int64_t>("begin_mask", 0);
|
||||
auto end_mask = node.get_attribute<int64_t>("end_mask", 0);
|
||||
auto new_axis_mask = node.get_attribute<int64_t>("new_axis_mask", 0);
|
||||
auto ellipsis_mask = node.get_attribute<int64_t>("ellipsis_mask", 0);
|
||||
auto shrink_axis_mask = node.get_attribute<int64_t>("shrink_axis_mask", 0);
|
||||
|
||||
auto mask_to_vector = [](int64_t mask) {
|
||||
size_t length = sizeof(mask) * CHAR_BIT;
|
||||
vector<int64_t> vec(length, 0);
|
||||
if (mask == 0) {
|
||||
return vec;
|
||||
return vector<int64_t>{};
|
||||
}
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (static_cast<unsigned char>(mask >> i & 0x1) == 1) {
|
||||
size_t max_length = sizeof(mask) * CHAR_BIT;
|
||||
vector<int64_t> vec{};
|
||||
for (size_t i = 0; i < max_length; ++i) {
|
||||
if ((mask >> i & 0x1) == 1) {
|
||||
// resize the vector by appending with required number of zeros
|
||||
vec.resize(i + 1, 0);
|
||||
vec[i] = 1;
|
||||
}
|
||||
}
|
||||
return vec;
|
||||
};
|
||||
|
||||
auto res = make_shared<StridedSlice>(input,
|
||||
begin,
|
||||
end,
|
||||
strides,
|
||||
mask_to_vector(begin_mask),
|
||||
mask_to_vector(end_mask),
|
||||
mask_to_vector(new_axis_mask),
|
||||
mask_to_vector(shrink_axis_mask),
|
||||
mask_to_vector(ellipsis_mask));
|
||||
set_node_name(node.get_name(), res);
|
||||
return res->outputs();
|
||||
// retrieve attributes for StridedSlice operation
|
||||
auto begin_mask = mask_to_vector(node.get_attribute<int64_t>("begin_mask", 0));
|
||||
auto end_mask = mask_to_vector(node.get_attribute<int64_t>("end_mask", 0));
|
||||
auto new_axis_mask = mask_to_vector(node.get_attribute<int64_t>("new_axis_mask", 0));
|
||||
auto ellipsis_mask = mask_to_vector(node.get_attribute<int64_t>("ellipsis_mask", 0));
|
||||
auto shrink_axis_mask = mask_to_vector(node.get_attribute<int64_t>("shrink_axis_mask", 0));
|
||||
|
||||
// the masks can be of different length and we need to align them by the maximum length
|
||||
size_t max_length = std::max(
|
||||
{begin_mask.size(), end_mask.size(), new_axis_mask.size(), ellipsis_mask.size(), shrink_axis_mask.size()});
|
||||
begin_mask.resize(max_length, 0);
|
||||
end_mask.resize(max_length, 0);
|
||||
new_axis_mask.resize(max_length, 0);
|
||||
ellipsis_mask.resize(max_length, 0);
|
||||
shrink_axis_mask.resize(max_length, 0);
|
||||
|
||||
auto strided_slice = make_shared<StridedSlice>(input,
|
||||
begin,
|
||||
end,
|
||||
strides,
|
||||
begin_mask,
|
||||
end_mask,
|
||||
new_axis_mask,
|
||||
shrink_axis_mask,
|
||||
ellipsis_mask);
|
||||
set_node_name(node.get_name(), strided_slice);
|
||||
return {strided_slice};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
@ -7,61 +7,72 @@ from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestStridedSlice(CommonTFLayerTest):
|
||||
|
||||
@staticmethod
|
||||
def create_strided_slice_net(input_shape, begin, end, strides, begin_mask, end_mask,
|
||||
def create_strided_slice_net(self, input_shape, begin_value, end_value, strides_value, begin_mask, end_mask,
|
||||
ellipsis_mask,
|
||||
new_axis_mask, shrink_axis_mask, ir_version, use_new_frontend):
|
||||
|
||||
new_axis_mask, shrink_axis_mask):
|
||||
import tensorflow as tf
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
with tf.compat.v1.Session() as sess:
|
||||
input_node = tf.compat.v1.placeholder(tf.float32, input_shape, 'Input')
|
||||
strided_slice = tf.compat.v1.strided_slice(input_node, begin=begin, end=end,
|
||||
strides=strides,
|
||||
begin_mask=begin_mask, end_mask=end_mask,
|
||||
ellipsis_mask=ellipsis_mask,
|
||||
new_axis_mask=new_axis_mask,
|
||||
shrink_axis_mask=shrink_axis_mask)
|
||||
input = tf.compat.v1.placeholder(tf.float32, input_shape, 'Input')
|
||||
begin = tf.constant(begin_value, dtype=tf.int32)
|
||||
end = tf.constant(end_value, dtype=tf.int32)
|
||||
strides = tf.constant(strides_value, dtype=tf.int32)
|
||||
tf.raw_ops.StridedSlice(input=input, begin=begin, end=end, strides=strides, begin_mask=begin_mask,
|
||||
end_mask=end_mask, ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
|
||||
shrink_axis_mask=shrink_axis_mask)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
|
||||
tf_net = sess.graph_def
|
||||
|
||||
ref_net = None
|
||||
return tf_net, ref_net
|
||||
return tf_net, None
|
||||
|
||||
test_basic_data = [
|
||||
dict(input_shape=[2, 5, 4, 3], begin_value=[1, 0, 2, 0], end_value=[2, 5, 4, 2], strides_value=[1, 2, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
|
||||
dict(input_shape=[1, 5, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 5, 3], strides_value=[1, 2, 3, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=8, shrink_axis_mask=0),
|
||||
dict(input_shape=[3, 4, 5, 7], begin_value=[2, 0, 3], end_value=[3, 0, 6], strides_value=[1, 1, 1],
|
||||
begin_mask=6, end_mask=6, ellipsis_mask=2, new_axis_mask=0, shrink_axis_mask=1),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize('params', test_basic_data)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
def test_strided_slice_basic(self, params, ie_device, precision, ir_version,
|
||||
temp_dir, use_new_frontend, use_old_api):
|
||||
self._test(*self.create_strided_slice_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
||||
test_squeeze_data = [
|
||||
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5], begin_value=[0, 0], end_value=[1, 5], strides_value=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
|
||||
dict(input_shape=[5, 1], begin=[0, 0], end=[5, 1], strides=[1, 1], begin_mask=0,
|
||||
dict(input_shape=[5, 1], begin_value=[0, 0], end_value=[5, 1], strides_value=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
|
||||
dict(input_shape=[1, 1, 3], begin=[0, 0, 0], end=[1, 1, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 1, 3], begin_value=[0, 0, 0], end_value=[1, 1, 3], strides_value=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
|
||||
dict(input_shape=[1, 5, 1], begin=[0, 0, 0], end=[1, 5, 1], strides=[1, 1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5, 1], begin_value=[0, 0, 0], end_value=[1, 5, 1], strides_value=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
|
||||
pytest.param(dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
|
||||
marks=pytest.mark.precommit_tf_fe),
|
||||
dict(input_shape=[1, 1, 5, 3], begin=[0, 0, 0, 0], end=[1, 1, 5, 3], strides=[1, 1, 1, 1],
|
||||
dict(input_shape=[1, 1, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 1, 5, 3], strides_value=[1, 1, 1, 1],
|
||||
begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
|
||||
dict(input_shape=[1, 5, 1, 3], begin=[0, 0, 0, 0], end=[1, 5, 1, 3], strides=[1, 1, 1, 1],
|
||||
dict(input_shape=[1, 5, 1, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 1, 3], strides_value=[1, 1, 1, 1],
|
||||
begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
|
||||
dict(input_shape=[1, 5, 5, 1], begin=[0, 0, 0, 0], end=[1, 5, 1, 1], strides=[1, 1, 1, 1],
|
||||
dict(input_shape=[1, 5, 5, 1], begin_value=[0, 0, 0, 0], end_value=[1, 5, 1, 1], strides_value=[1, 1, 1, 1],
|
||||
begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=8),
|
||||
dict(input_shape=[1, 1, 5, 5, 3], begin=[0, 0, 0, 0, 0], end=[1, 1, 5, 5, 3],
|
||||
strides=[1, 1, 1, 1, 1],
|
||||
dict(input_shape=[1, 1, 5, 5, 3], begin_value=[0, 0, 0, 0, 0], end_value=[1, 1, 5, 5, 3],
|
||||
strides_value=[1, 1, 1, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=3),
|
||||
dict(input_shape=[1, 5, 1, 5, 3], begin=[0, 0, 0, 0, 0], end=[1, 5, 1, 5, 3],
|
||||
strides=[1, 1, 1, 1, 1],
|
||||
dict(input_shape=[1, 5, 1, 5, 3], begin_value=[0, 0, 0, 0, 0], end_value=[1, 5, 1, 5, 3],
|
||||
strides_value=[1, 1, 1, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=5),
|
||||
dict(input_shape=[1, 5, 1, 5, 1], begin=[0, 0, 0, 0, 0], end=[1, 5, 1, 5, 1],
|
||||
strides=[1, 1, 1, 1, 1],
|
||||
dict(input_shape=[1, 5, 1, 5, 1], begin_value=[0, 0, 0, 0, 0], end_value=[1, 5, 1, 5, 1],
|
||||
strides_value=[1, 1, 1, 1, 1],
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=21),
|
||||
]
|
||||
|
||||
@ -69,32 +80,28 @@ class TestStridedSlice(CommonTFLayerTest):
|
||||
@pytest.mark.nightly
|
||||
def test_strided_slice_replace_with_squeeze(self, params, ie_device, precision, ir_version,
|
||||
temp_dir, use_new_frontend, use_old_api):
|
||||
self._test(*self.create_strided_slice_net(**params, ir_version=ir_version,
|
||||
use_new_frontend=use_new_frontend),
|
||||
self._test(*self.create_strided_slice_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
||||
test_unsqueeze_data = [
|
||||
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5], begin_value=[0, 0], end_value=[1, 5], strides_value=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=1, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5], begin=[0, 0], end=[1, 5], strides=[1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5], begin_value=[0, 0], end_value=[1, 5], strides_value=[1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 3], begin=[0, 0, 0], end=[1, 5, 3], strides=[1, 1, 1], begin_mask=0,
|
||||
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=5, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1],
|
||||
begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=8, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1],
|
||||
dict(input_shape=[1, 5, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 5, 3], strides_value=[1, 1, 1, 1],
|
||||
begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
|
||||
dict(input_shape=[1, 5, 5, 3], begin=[0, 0, 0, 0], end=[1, 5, 5, 3], strides=[1, 1, 1, 1],
|
||||
dict(input_shape=[1, 5, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 5, 3], strides_value=[1, 1, 1, 1],
|
||||
begin_mask=0,
|
||||
end_mask=0, ellipsis_mask=0, new_axis_mask=2, shrink_axis_mask=0),
|
||||
dict(input_shape=[16, 4, 64], begin=[0, 0, 0, 0], end=[0, 0, 0, 0], strides=[1, 1, 1, 1],
|
||||
dict(input_shape=[16, 4, 64], begin_value=[0, 0, 0, 0], end_value=[0, 0, 0, 0], strides_value=[1, 1, 1, 1],
|
||||
begin_mask=19,
|
||||
end_mask=19, ellipsis_mask=0, new_axis_mask=12, shrink_axis_mask=0),
|
||||
]
|
||||
@ -103,7 +110,6 @@ class TestStridedSlice(CommonTFLayerTest):
|
||||
@pytest.mark.nightly
|
||||
def test_strided_slice_replace_with_unsqueeze(self, params, ie_device, precision, ir_version,
|
||||
temp_dir, use_new_frontend, use_old_api):
|
||||
self._test(*self.create_strided_slice_net(**params, ir_version=ir_version,
|
||||
use_new_frontend=use_new_frontend),
|
||||
self._test(*self.create_strided_slice_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
Loading…
Reference in New Issue
Block a user