diff --git a/src/frontends/tensorflow/src/op/reverse.cpp b/src/frontends/tensorflow/src/op/reverse.cpp index 786b3e909af..d1ead7696e3 100644 --- a/src/frontends/tensorflow/src/op/reverse.cpp +++ b/src/frontends/tensorflow/src/op/reverse.cpp @@ -3,47 +3,112 @@ // #include "op_table.hpp" -#include "openvino/opsets/opset8.hpp" +#include "openvino/opsets/opset9.hpp" using namespace std; -using namespace ov::opset8; +using namespace ov; +using namespace ov::opset9; namespace ov { namespace frontend { namespace tensorflow { namespace op { +shared_ptr compute_sequence_lengths(const Output& input_shape, int64_t batch_axis, int64_t seq_axis) { + auto batch_axis_const = make_shared(element::i32, Shape{1}, batch_axis); + auto seq_axis_const = make_shared(element::i32, Shape{1}, seq_axis); + auto gather_axis = make_shared(element::i32, Shape{}, 0); + auto batch_dim = make_shared(input_shape, batch_axis_const, gather_axis); + auto seq_dim = make_shared(input_shape, seq_axis_const, gather_axis); + auto seq_lengths = make_shared(seq_dim, batch_dim); -OutputVector translate_reverse_op(const NodeContext& node) { - auto input = node.get_input(0); - auto axes = node.get_input(1); + return seq_lengths; +} - auto axes_const = dynamic_pointer_cast(axes.get_node_shared_ptr()); - TENSORFLOW_OP_VALIDATION(node, axes_const != nullptr, "Axes input must be constant."); - TENSORFLOW_OP_VALIDATION(node, axes_const->get_shape().size() == 1, "Axes input must be 1D."); - TENSORFLOW_OP_VALIDATION(node, axes_const->get_shape()[0] == 1, "Axes input must have only one value."); - auto seq_axis = axes_const->cast_vector().at(0); - int64_t batch_axis = !seq_axis; - - Output seq_lengths; - if (input.get_partial_shape().is_static()) { - auto in_shape = input.get_shape(); - seq_lengths = make_shared(element::i64, Shape{in_shape[batch_axis]}, in_shape[seq_axis]); - } else { - auto shape = make_shared(input); - auto one = make_shared(element::i64, Shape{1}, 1); - auto gather_batch = make_shared(shape, - make_shared(element::i64, Shape{1}, batch_axis), - make_shared(element::i64, Shape{1}, 0)); - auto gather_seq = make_shared(shape, - make_shared(element::i64, Shape{1}, seq_axis), - make_shared(element::i64, Shape{1}, 0)); - auto broadcast = make_shared(one, gather_batch); - seq_lengths = make_shared(broadcast, gather_seq); +OutputVector translate_reverse_base_op(const NodeContext& node, + const Output& input, + const std::vector& axes) { + auto reverse_node_name = node.get_name(); + if (axes.size() == 0) { + // there is nothing to reverse + input.get_tensor().add_names({reverse_node_name + ":0"}); + return {input}; } - auto res = make_shared(input, seq_lengths, batch_axis, seq_axis); - set_node_name(node.get_name(), res); - return res->outputs(); + TENSORFLOW_OP_VALIDATION( + node, + axes.size() == 1, + "OpenVINO TensorFlow Frontend does not support Reverse or ReverseV2 with multiple axes for the reversing."); + + int64_t seq_axis = axes[0]; + int64_t batch_axis = 0; + + // when we are not sure that input rank greater than 1 + // based on seq_axis, introduce the auxiliary dimension for the batch + std::vector unsqueeze_axes; + if (seq_axis == 0 || seq_axis == -1) { + unsqueeze_axes.push_back(0); + } + + // make sure that batch and sequence dimensions are different + // in case seq_axis is zero, we added the temporal dimension in the previous step + // so we have to shift it by one + seq_axis = (seq_axis == 0) ? 1 : seq_axis; + auto batched_input = input; + if (unsqueeze_axes.size() > 0) { + // prepare input to issue auxiliary dimensions for batch + auto unsqueeze_axes_const = make_shared(element::i32, Shape{unsqueeze_axes.size()}, unsqueeze_axes); + batched_input = make_shared(input, unsqueeze_axes_const); + } + + auto input_shape = make_shared(batched_input, element::i32); + auto seq_lenghts = compute_sequence_lengths(input_shape, batch_axis, seq_axis); + auto reverse_sequence = make_shared(batched_input, seq_lenghts, batch_axis, seq_axis)->output(0); + + if (unsqueeze_axes.size() > 0) { + // remove earlier added additional dimensions from the result + auto squeeze_axes_const = make_shared(element::i32, Shape{unsqueeze_axes.size()}, unsqueeze_axes); + reverse_sequence = make_shared(reverse_sequence, squeeze_axes_const); + } + + set_node_name(node.get_name(), reverse_sequence.get_node_shared_ptr()); + return {reverse_sequence}; +} + +OutputVector translate_reverse_op(const NodeContext& node) { + // The second input of Reverse is a boolean vector. + // True elements correspond the axes along which + // elements of the input tensor are reversed + default_op_checks(node, 2, {"Reverse"}); + auto input = node.get_input(0); + + std::vector dims; + get_const_input(node, 1, &dims); + + // collect axes along which to reverse + std::vector axes; + for (int64_t ind = 0; ind < static_cast(dims.size()); ++ind) { + if (dims[ind]) { + axes.push_back(ind); + } + } + + return translate_reverse_base_op(node, input, axes); +} + +OutputVector translate_reverse_v2_op(const NodeContext& node) { + // The second input of ReverseV2 is a vector of axes along which + // elements of the input tensor are reversed + default_op_checks(node, 2, {"ReverseV2"}); + auto input = node.get_input(0); + + // the translator is able to convert ReverseV2 only + // if axis is constant and has one element. + // this limitation is due to the presence of batch_axis and seq_axis attributes. + // the current limitation is sufficient for parity with Legacy MO frontend. + std::vector axes; + get_const_input(node, 1, &axes); + + return translate_reverse_base_op(node, input, axes); } } // namespace op } // namespace tensorflow diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 356626e8123..1805d610b83 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -95,6 +95,7 @@ OP_CONVERTER(translate_reciprocal_op); OP_CONVERTER(translate_reshape_op); OP_CONVERTER(translate_resource_gather_op); OP_CONVERTER(translate_reverse_op); +OP_CONVERTER(translate_reverse_v2_op); OP_CONVERTER(translate_reverse_sequence_op); OP_CONVERTER(translate_roll_op); OP_CONVERTER(translate_round_op); @@ -277,7 +278,7 @@ const std::map get_supported_ops() { {"Reshape", translate_reshape_op}, {"Reverse", translate_reverse_op}, {"ReverseSequence", translate_reverse_sequence_op}, - {"ReverseV2", translate_reverse_op}, + {"ReverseV2", translate_reverse_v2_op}, {"ResizeBilinear", translate_interpolate_op}, {"ResizeNearestNeighbor", translate_interpolate_op}, {"ResourceGather", translate_resource_gather_op}, diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Reverse.py b/tests/layer_tests/tensorflow_tests/test_tf_Reverse.py new file mode 100644 index 00000000000..9d4525a8327 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_Reverse.py @@ -0,0 +1,31 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestReverse(CommonTFLayerTest): + def create_reverse_net(self, shape, dims): + import tensorflow as tf + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(tf.float32, shape, 'Input') + tf.raw_ops.Reverse(tensor=x, dims=dims, name='reverse') + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(shape=[4], dims=[True]), + dict(shape=[3, 2], dims=[False, True]), + dict(shape=[4, 2, 3], dims=[False, True, False]), + dict(shape=[1, 2, 4, 3], dims=[True, False, False, False]), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + def test_reverse_basic(self, params, ie_device, precision, ir_version, temp_dir, use_old_api): + self._test(*self.create_reverse_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ReverseV2.py b/tests/layer_tests/tensorflow_tests/test_tf_ReverseV2.py index ba903c4ef46..a5983f1fff3 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_ReverseV2.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_ReverseV2.py @@ -1,56 +1,33 @@ # Copyright (C) 2018-2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import numpy as np import pytest from common.tf_layer_test_class import CommonTFLayerTest -class TestReverseV2Ops(CommonTFLayerTest): - def _prepare_input(self, inputs_dict): - for input in inputs_dict.keys(): - inputs_dict[input] = np.random.random(inputs_dict[input]) - return inputs_dict - - def create_reversev2_net(self, shape, keep_dims, axis, ir_version): +class TestReverseV2(CommonTFLayerTest): + def create_reverse_v2_net(self, shape, axis): import tensorflow as tf tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: - shapes = shape.copy() - if len(shapes) >= 4: - shapes.append(shapes.pop(1)) - - x = tf.compat.v1.placeholder(tf.float32, shapes, 'Input') - tf.compat.v1.reverse_v2(x, axis) + x = tf.compat.v1.placeholder(tf.float32, shape, 'Input') + tf.raw_ops.ReverseV2(tensor=x, axis=axis, name='reverse') tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def return tf_net, None - test_data = [] - test_data.extend([ + test_data_basic = [ dict(shape=[5], axis=[0]), - pytest.param(dict(shape=[2, 3], axis=[1]), marks=pytest.mark.precommit_tf_fe), + dict(shape=[3], axis=[-1]), + dict(shape=[2, 3], axis=[1]), dict(shape=[2, 3, 5], axis=[-2]), - dict(shape=[2, 3, 5, 7], axis=[0]), - ]) + dict(shape=[2, 3, 5, 7], axis=[3]), + ] - @pytest.mark.parametrize("params", test_data) - @pytest.mark.parametrize("keep_dims", [True, False]) + @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.nightly - def test_reversev2(self, params, keep_dims, ie_device, precision, ir_version, temp_dir, use_old_api): - self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version), + @pytest.mark.precommit_tf_fe + def test_reverse_v2_basic(self, params, ie_device, precision, ir_version, temp_dir, use_old_api): + self._test(*self.create_reverse_v2_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api) - - test_data_pre_commit = [] - test_data_pre_commit.extend([dict(shape=[5], axis=[0]), - dict(shape=[2, 3, 5], axis=[-2]) - ]) - - @pytest.mark.parametrize("params", test_data_pre_commit) - @pytest.mark.parametrize("keep_dims", [True]) - @pytest.mark.precommit - def test_reversev2_precommit(self, params, keep_dims, ie_device, precision, ir_version, - temp_dir, use_old_api): - self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version), - ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api, use_new_frontend=False)