[TF FE] Refactor LinSpace translator and add layer test (#15495)

* [TF FE] Refactor LinSpace translator and add layer test

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Remove start_shape from test parameters

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev 2023-02-04 00:14:47 +04:00 committed by GitHub
parent d9dbf23ea3
commit 2cae7479a0
2 changed files with 49 additions and 57 deletions

@@ -13,68 +13,25 @@ namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_linspace_op(const NodeContext& node) {
/* The LinSpace operation can be expressed in the following form:
 * 1) Compute the delta by which each element is increased in each slice
 *    along the new dimension: delta = (stop - start) / (num - 1)
 * 2) Generate a range of indices range = [0, 1, ..., num - 1]; each index is
 *    multiplied by delta and added to start to produce the elements of each slice
 * 3) Unsqueeze start and delta by the new axis, and unsqueeze range by all axes except the given axis
 * 4) Compute the result of the operation as result = start + delta * range
 */
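// Worked example with illustrative values (assumed for clarity, not taken from the tests):
// for start = [2, 10], stop = [6, 20], num = 3, axis = -1:
//   delta = [2, 5], range = [0, 1, 2]
//   unsqueezed start and delta have shape [2, 1], unsqueezed range has shape [1, 3]
//   result = start + delta * range = [[2, 4, 6], [10, 15, 20]]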
auto num_inputs = node.get_input_size();
// The operation simply generates a range [start, ..., stop]
// of num evenly spaced elements
default_op_checks(node, 3, {"LinSpace"});
auto start = node.get_input(0);
auto stop = node.get_input(1);
auto num = node.get_input(2);
TENSORFLOW_OP_VALIDATION(node, start.get_partial_shape().rank().is_static(), "Input rank must be static.");
int64_t start_rank = start.get_partial_shape().rank().get_length();
// compute the delta value, i.e. the distance between neighboring values of the result
auto const_one = make_shared<Constant>(num.get_element_type(), Shape{}, 1);
Output<Node> num_minus_one = make_shared<Subtract>(num, const_one);
num_minus_one = make_shared<Convert>(num_minus_one, start.get_element_type());
Output<Node> delta = make_shared<Subtract>(stop, start);
delta = make_shared<Divide>(delta, num_minus_one);
// retrieve the axis from the Constant node (used to unsqueeze start and delta)
// and compute the list of all other axes (used to unsqueeze range)
std::vector<int64_t> axis;
std::vector<int64_t> except_axis_range;
if (num_inputs > 3 && start_rank > 0) {
get_const_input(node, 3, &axis);
TENSORFLOW_OP_VALIDATION(node, axis.size() == 1, "Axis must be a scalar for LinSpace operation.");
axis[0] = axis[0] >= 0 ? axis[0] : start_rank + 1 + axis[0];
for (int64_t dim_ind = 0; dim_ind < start_rank + 1; ++dim_ind) {
if (dim_ind != axis[0]) {
except_axis_range.push_back(dim_ind);
}
}
}
TENSORFLOW_OP_VALIDATION(node,
axis.empty() && start_rank == 0 || axis.size() == 1 && start_rank > 0,
"Axis must be used only if input for LinSpace operation is ND tensor.");
auto one = make_shared<Constant>(num.get_element_type(), Shape{}, 1);
auto num_minus_1 = make_shared<ConvertLike>(make_shared<Subtract>(num, one), start);
auto delta = make_shared<Divide>(make_shared<Subtract>(stop, start), num_minus_1);
auto zero = make_shared<Constant>(num.get_element_type(), Shape{}, 0);
auto range_0_num_minus_1 =
make_shared<ConvertLike>(make_shared<Range>(zero, num, one, num.get_element_type()), start);
// handle the case of scalar inputs
if (axis.empty() && start_rank == 0) {
auto delta_mul_range = make_shared<Multiply>(delta, range_0_num_minus_1);
auto result = make_shared<Add>(start, delta_mul_range);
set_node_name(node.get_name(), result);
return result->outputs();
}
auto const_axis = make_shared<Constant>(element::i64, Shape{axis.size()}, axis);
auto const_except_axis = make_shared<Constant>(element::i64, Shape{except_axis_range.size()}, except_axis_range);
auto unsqueeze_start = make_shared<Unsqueeze>(start, const_axis);
auto unsqueeze_delta = make_shared<Unsqueeze>(delta, const_axis);
auto unsqueeze_range = make_shared<Unsqueeze>(range_0_num_minus_1, const_except_axis);
auto delta_mul_range = make_shared<Multiply>(unsqueeze_delta, unsqueeze_range);
auto result = make_shared<Add>(unsqueeze_start, delta_mul_range);
set_node_name(node.get_name(), result);
return result->outputs();
// compute the result as Range(start, stop + delta, delta) so that stop is included as the last element
auto stop_plus_delta = make_shared<Add>(stop, delta);
auto linspace = make_shared<Range>(start, stop_plus_delta, delta, start.get_element_type());
set_node_name(node.get_name(), linspace);
return {linspace};
}
} // namespace op
} // namespace tensorflow
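
For reference, the simplified decomposition above can be checked numerically. Below is a minimal NumPy sketch (illustrative only; the variable names, values, and use of NumPy are assumptions, not part of the change) of delta = (stop - start) / (num - 1) followed by Range(start, stop + delta, delta):

import numpy as np

# illustrative scalar inputs (assumed, not taken from the tests)
start, stop, num = 2.0, 11.0, 4

# delta: distance between neighboring values of the result
delta = (stop - start) / (num - 1)

# Range(start, stop + delta, delta): the exclusive upper bound stop + delta
# makes stop itself the last generated element
decomposed = np.arange(start, stop + delta, delta, dtype=np.float32)

reference = np.linspace(start, stop, num, dtype=np.float32)
assert np.allclose(decomposed, reference)  # both yield [2., 5., 8., 11.]

This Range-based form needs no axis handling or Unsqueeze nodes, which is what allows dropping the old branch for ND inputs.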

@@ -0,0 +1,35 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest


class TestLinSpace(CommonTFLayerTest):
    def create_lin_space_net(self, num_value):
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            start = tf.compat.v1.placeholder(tf.float32, [], 'start')
            stop = tf.compat.v1.placeholder(tf.float32, [], 'stop')
            tf.raw_ops.LinSpace(start=start, stop=stop, num=num_value)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(num_value=2),
        dict(num_value=10),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_lin_space_basic(self, params, ie_device, precision, ir_version, temp_dir,
                             use_new_frontend, use_old_api):
        self._test(*self.create_lin_space_net(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
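
As a quick sanity check outside the test infrastructure (illustrative only; the concrete start/stop values are assumed here, while the layer test itself feeds the placeholders), the same raw op can be evaluated eagerly:

import tensorflow as tf

# eager evaluation of the op that create_lin_space_net builds into a graph
print(tf.raw_ops.LinSpace(start=2.0, stop=11.0, num=4).numpy())
# [ 2.  5.  8. 11.]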