[TF FE] Refactor translators for Reverse, ReverseV2 and test it (#13602)
* [TF FE] Refactor translators for Reverse, ReverseV2 and test it Make these operations reshapeable. Add layer tests for them to the pre-commit Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Apply code-review feedback: simplify checks in Reverse * Apply the rest of code-review feedback: simplify code for Reverse * Remove redundant check for axes * Apply code-review feedback: support dynamic rank Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
ebbf5e3f10
commit
c6dda68387
@ -3,47 +3,112 @@
|
||||
//
|
||||
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
shared_ptr<Node> compute_sequence_lengths(const Output<Node>& input_shape, int64_t batch_axis, int64_t seq_axis) {
|
||||
auto batch_axis_const = make_shared<Constant>(element::i32, Shape{1}, batch_axis);
|
||||
auto seq_axis_const = make_shared<Constant>(element::i32, Shape{1}, seq_axis);
|
||||
auto gather_axis = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto batch_dim = make_shared<Gather>(input_shape, batch_axis_const, gather_axis);
|
||||
auto seq_dim = make_shared<Gather>(input_shape, seq_axis_const, gather_axis);
|
||||
auto seq_lengths = make_shared<Broadcast>(seq_dim, batch_dim);
|
||||
|
||||
OutputVector translate_reverse_op(const NodeContext& node) {
|
||||
auto input = node.get_input(0);
|
||||
auto axes = node.get_input(1);
|
||||
return seq_lengths;
|
||||
}
|
||||
|
||||
auto axes_const = dynamic_pointer_cast<Constant>(axes.get_node_shared_ptr());
|
||||
TENSORFLOW_OP_VALIDATION(node, axes_const != nullptr, "Axes input must be constant.");
|
||||
TENSORFLOW_OP_VALIDATION(node, axes_const->get_shape().size() == 1, "Axes input must be 1D.");
|
||||
TENSORFLOW_OP_VALIDATION(node, axes_const->get_shape()[0] == 1, "Axes input must have only one value.");
|
||||
auto seq_axis = axes_const->cast_vector<int64_t>().at(0);
|
||||
int64_t batch_axis = !seq_axis;
|
||||
|
||||
Output<Node> seq_lengths;
|
||||
if (input.get_partial_shape().is_static()) {
|
||||
auto in_shape = input.get_shape();
|
||||
seq_lengths = make_shared<Constant>(element::i64, Shape{in_shape[batch_axis]}, in_shape[seq_axis]);
|
||||
} else {
|
||||
auto shape = make_shared<ShapeOf>(input);
|
||||
auto one = make_shared<Constant>(element::i64, Shape{1}, 1);
|
||||
auto gather_batch = make_shared<Gather>(shape,
|
||||
make_shared<Constant>(element::i64, Shape{1}, batch_axis),
|
||||
make_shared<Constant>(element::i64, Shape{1}, 0));
|
||||
auto gather_seq = make_shared<Gather>(shape,
|
||||
make_shared<Constant>(element::i64, Shape{1}, seq_axis),
|
||||
make_shared<Constant>(element::i64, Shape{1}, 0));
|
||||
auto broadcast = make_shared<Broadcast>(one, gather_batch);
|
||||
seq_lengths = make_shared<Multiply>(broadcast, gather_seq);
|
||||
OutputVector translate_reverse_base_op(const NodeContext& node,
|
||||
const Output<Node>& input,
|
||||
const std::vector<int64_t>& axes) {
|
||||
auto reverse_node_name = node.get_name();
|
||||
if (axes.size() == 0) {
|
||||
// there is nothing to reverse
|
||||
input.get_tensor().add_names({reverse_node_name + ":0"});
|
||||
return {input};
|
||||
}
|
||||
|
||||
auto res = make_shared<ReverseSequence>(input, seq_lengths, batch_axis, seq_axis);
|
||||
set_node_name(node.get_name(), res);
|
||||
return res->outputs();
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
axes.size() == 1,
|
||||
"OpenVINO TensorFlow Frontend does not support Reverse or ReverseV2 with multiple axes for the reversing.");
|
||||
|
||||
int64_t seq_axis = axes[0];
|
||||
int64_t batch_axis = 0;
|
||||
|
||||
// when we are not sure that input rank greater than 1
|
||||
// based on seq_axis, introduce the auxiliary dimension for the batch
|
||||
std::vector<int64_t> unsqueeze_axes;
|
||||
if (seq_axis == 0 || seq_axis == -1) {
|
||||
unsqueeze_axes.push_back(0);
|
||||
}
|
||||
|
||||
// make sure that batch and sequence dimensions are different
|
||||
// in case seq_axis is zero, we added the temporal dimension in the previous step
|
||||
// so we have to shift it by one
|
||||
seq_axis = (seq_axis == 0) ? 1 : seq_axis;
|
||||
auto batched_input = input;
|
||||
if (unsqueeze_axes.size() > 0) {
|
||||
// prepare input to issue auxiliary dimensions for batch
|
||||
auto unsqueeze_axes_const = make_shared<Constant>(element::i32, Shape{unsqueeze_axes.size()}, unsqueeze_axes);
|
||||
batched_input = make_shared<Unsqueeze>(input, unsqueeze_axes_const);
|
||||
}
|
||||
|
||||
auto input_shape = make_shared<ShapeOf>(batched_input, element::i32);
|
||||
auto seq_lenghts = compute_sequence_lengths(input_shape, batch_axis, seq_axis);
|
||||
auto reverse_sequence = make_shared<ReverseSequence>(batched_input, seq_lenghts, batch_axis, seq_axis)->output(0);
|
||||
|
||||
if (unsqueeze_axes.size() > 0) {
|
||||
// remove earlier added additional dimensions from the result
|
||||
auto squeeze_axes_const = make_shared<Constant>(element::i32, Shape{unsqueeze_axes.size()}, unsqueeze_axes);
|
||||
reverse_sequence = make_shared<Squeeze>(reverse_sequence, squeeze_axes_const);
|
||||
}
|
||||
|
||||
set_node_name(node.get_name(), reverse_sequence.get_node_shared_ptr());
|
||||
return {reverse_sequence};
|
||||
}
|
||||
|
||||
OutputVector translate_reverse_op(const NodeContext& node) {
|
||||
// The second input of Reverse is a boolean vector.
|
||||
// True elements correspond the axes along which
|
||||
// elements of the input tensor are reversed
|
||||
default_op_checks(node, 2, {"Reverse"});
|
||||
auto input = node.get_input(0);
|
||||
|
||||
std::vector<bool> dims;
|
||||
get_const_input(node, 1, &dims);
|
||||
|
||||
// collect axes along which to reverse
|
||||
std::vector<int64_t> axes;
|
||||
for (int64_t ind = 0; ind < static_cast<int64_t>(dims.size()); ++ind) {
|
||||
if (dims[ind]) {
|
||||
axes.push_back(ind);
|
||||
}
|
||||
}
|
||||
|
||||
return translate_reverse_base_op(node, input, axes);
|
||||
}
|
||||
|
||||
OutputVector translate_reverse_v2_op(const NodeContext& node) {
|
||||
// The second input of ReverseV2 is a vector of axes along which
|
||||
// elements of the input tensor are reversed
|
||||
default_op_checks(node, 2, {"ReverseV2"});
|
||||
auto input = node.get_input(0);
|
||||
|
||||
// the translator is able to convert ReverseV2 only
|
||||
// if axis is constant and has one element.
|
||||
// this limitation is due to the presence of batch_axis and seq_axis attributes.
|
||||
// the current limitation is sufficient for parity with Legacy MO frontend.
|
||||
std::vector<int64_t> axes;
|
||||
get_const_input(node, 1, &axes);
|
||||
|
||||
return translate_reverse_base_op(node, input, axes);
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -95,6 +95,7 @@ OP_CONVERTER(translate_reciprocal_op);
|
||||
OP_CONVERTER(translate_reshape_op);
|
||||
OP_CONVERTER(translate_resource_gather_op);
|
||||
OP_CONVERTER(translate_reverse_op);
|
||||
OP_CONVERTER(translate_reverse_v2_op);
|
||||
OP_CONVERTER(translate_reverse_sequence_op);
|
||||
OP_CONVERTER(translate_roll_op);
|
||||
OP_CONVERTER(translate_round_op);
|
||||
@ -277,7 +278,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"Reshape", translate_reshape_op},
|
||||
{"Reverse", translate_reverse_op},
|
||||
{"ReverseSequence", translate_reverse_sequence_op},
|
||||
{"ReverseV2", translate_reverse_op},
|
||||
{"ReverseV2", translate_reverse_v2_op},
|
||||
{"ResizeBilinear", translate_interpolate_op},
|
||||
{"ResizeNearestNeighbor", translate_interpolate_op},
|
||||
{"ResourceGather", translate_resource_gather_op},
|
||||
|
31
tests/layer_tests/tensorflow_tests/test_tf_Reverse.py
Normal file
31
tests/layer_tests/tensorflow_tests/test_tf_Reverse.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestReverse(CommonTFLayerTest):
|
||||
def create_reverse_net(self, shape, dims):
|
||||
import tensorflow as tf
|
||||
tf.compat.v1.reset_default_graph()
|
||||
with tf.compat.v1.Session() as sess:
|
||||
x = tf.compat.v1.placeholder(tf.float32, shape, 'Input')
|
||||
tf.raw_ops.Reverse(tensor=x, dims=dims, name='reverse')
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(shape=[4], dims=[True]),
|
||||
dict(shape=[3, 2], dims=[False, True]),
|
||||
dict(shape=[4, 2, 3], dims=[False, True, False]),
|
||||
dict(shape=[1, 2, 4, 3], dims=[True, False, False, False]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
def test_reverse_basic(self, params, ie_device, precision, ir_version, temp_dir, use_old_api):
|
||||
self._test(*self.create_reverse_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api)
|
@ -1,56 +1,33 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestReverseV2Ops(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_dict):
|
||||
for input in inputs_dict.keys():
|
||||
inputs_dict[input] = np.random.random(inputs_dict[input])
|
||||
return inputs_dict
|
||||
|
||||
def create_reversev2_net(self, shape, keep_dims, axis, ir_version):
|
||||
class TestReverseV2(CommonTFLayerTest):
|
||||
def create_reverse_v2_net(self, shape, axis):
|
||||
import tensorflow as tf
|
||||
tf.compat.v1.reset_default_graph()
|
||||
with tf.compat.v1.Session() as sess:
|
||||
shapes = shape.copy()
|
||||
if len(shapes) >= 4:
|
||||
shapes.append(shapes.pop(1))
|
||||
|
||||
x = tf.compat.v1.placeholder(tf.float32, shapes, 'Input')
|
||||
tf.compat.v1.reverse_v2(x, axis)
|
||||
x = tf.compat.v1.placeholder(tf.float32, shape, 'Input')
|
||||
tf.raw_ops.ReverseV2(tensor=x, axis=axis, name='reverse')
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_data = []
|
||||
test_data.extend([
|
||||
test_data_basic = [
|
||||
dict(shape=[5], axis=[0]),
|
||||
pytest.param(dict(shape=[2, 3], axis=[1]), marks=pytest.mark.precommit_tf_fe),
|
||||
dict(shape=[3], axis=[-1]),
|
||||
dict(shape=[2, 3], axis=[1]),
|
||||
dict(shape=[2, 3, 5], axis=[-2]),
|
||||
dict(shape=[2, 3, 5, 7], axis=[0]),
|
||||
])
|
||||
dict(shape=[2, 3, 5, 7], axis=[3]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data)
|
||||
@pytest.mark.parametrize("keep_dims", [True, False])
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.nightly
|
||||
def test_reversev2(self, params, keep_dims, ie_device, precision, ir_version, temp_dir, use_old_api):
|
||||
self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version),
|
||||
@pytest.mark.precommit_tf_fe
|
||||
def test_reverse_v2_basic(self, params, ie_device, precision, ir_version, temp_dir, use_old_api):
|
||||
self._test(*self.create_reverse_v2_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api)
|
||||
|
||||
test_data_pre_commit = []
|
||||
test_data_pre_commit.extend([dict(shape=[5], axis=[0]),
|
||||
dict(shape=[2, 3, 5], axis=[-2])
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_pre_commit)
|
||||
@pytest.mark.parametrize("keep_dims", [True])
|
||||
@pytest.mark.precommit
|
||||
def test_reversev2_precommit(self, params, keep_dims, ie_device, precision, ir_version,
|
||||
temp_dir, use_old_api):
|
||||
self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api, use_new_frontend=False)
|
||||
|
Loading…
Reference in New Issue
Block a user