[TF FE] Refactor LinSpace translator and add layer test (#15495)
* [TF FE] Refactor LinSpace translator and add layer test Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Remove start_shape from test parameters --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
d9dbf23ea3
commit
2cae7479a0
@ -13,68 +13,25 @@ namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_linspace_op(const NodeContext& node) {
|
||||
/* LinSpace operation can be expressed in the following form:
|
||||
* 1) Compute deltas by which each element will be increased in each slice
|
||||
* through new dimension as delta = (stop - start) / (num - 1)
|
||||
* 2) Generate a range of numbers by which times delta will be added to start to
|
||||
* compute new elements in each slide as range = [0, 1, ..., num - 1]
|
||||
* 3) Unsqueeze start and delta by new axis. And unsqueeze range by axes except given axis
|
||||
* 4) Compute the result of the operation as result = start + delta * range
|
||||
*/
|
||||
auto num_inputs = node.get_input_size();
|
||||
// The operation is simple that generates a range [start, ..., stop]
|
||||
// with num elements staying in the same distance between each other
|
||||
default_op_checks(node, 3, {"LinSpace"});
|
||||
auto start = node.get_input(0);
|
||||
auto stop = node.get_input(1);
|
||||
auto num = node.get_input(2);
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node, start.get_partial_shape().rank().is_static(), "Input rank must be static.");
|
||||
int64_t start_rank = start.get_partial_shape().rank().get_length();
|
||||
// compute delta value, i.e. distance between neighbor values of the result
|
||||
auto const_one = make_shared<Constant>(num.get_element_type(), Shape{}, 1);
|
||||
Output<Node> num_minus_one = make_shared<Subtract>(num, const_one);
|
||||
num_minus_one = make_shared<Convert>(num_minus_one, start.get_element_type());
|
||||
Output<Node> delta = make_shared<Subtract>(stop, start);
|
||||
delta = make_shared<Divide>(delta, num_minus_one);
|
||||
|
||||
// retrieve axis from Constant node and compute a range of axes except given axis
|
||||
// for unsqueezing start and delta tensors
|
||||
std::vector<int64_t> axis;
|
||||
std::vector<int64_t> except_axis_range;
|
||||
if (num_inputs > 3 && start_rank > 0) {
|
||||
get_const_input(node, 3, &axis);
|
||||
TENSORFLOW_OP_VALIDATION(node, axis.size() == 1, "Axis must be a scalar for LinSpace operation.");
|
||||
axis[0] = axis[0] >= 0 ? axis[0] : start_rank + 1 + axis[0];
|
||||
for (int64_t dim_ind = 0; dim_ind < start_rank + 1; ++dim_ind) {
|
||||
if (dim_ind != axis[0]) {
|
||||
except_axis_range.push_back(dim_ind);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
axis.empty() && start_rank == 0 || axis.size() == 1 && start_rank > 0,
|
||||
"Axis must be used only if input for LinSpace operation is ND tensor.");
|
||||
|
||||
auto one = make_shared<Constant>(num.get_element_type(), Shape{}, 1);
|
||||
auto num_minus_1 = make_shared<ConvertLike>(make_shared<Subtract>(num, one), start);
|
||||
auto delta = make_shared<Divide>(make_shared<Subtract>(stop, start), num_minus_1);
|
||||
|
||||
auto zero = make_shared<Constant>(num.get_element_type(), Shape{}, 0);
|
||||
auto range_0_num_minus_1 =
|
||||
make_shared<ConvertLike>(make_shared<Range>(zero, num, one, num.get_element_type()), start);
|
||||
|
||||
// convert a case with scalar inputs
|
||||
if (axis.empty() && start_rank == 0) {
|
||||
auto delta_mul_range = make_shared<Multiply>(delta, range_0_num_minus_1);
|
||||
auto result = make_shared<Add>(start, delta_mul_range);
|
||||
set_node_name(node.get_name(), result);
|
||||
return result->outputs();
|
||||
}
|
||||
|
||||
auto const_axis = make_shared<Constant>(element::i64, Shape{axis.size()}, axis);
|
||||
auto const_except_axis = make_shared<Constant>(element::i64, Shape{except_axis_range.size()}, except_axis_range);
|
||||
|
||||
auto unsqueeze_start = make_shared<Unsqueeze>(start, const_axis);
|
||||
auto unsqueeze_delta = make_shared<Unsqueeze>(delta, const_axis);
|
||||
auto unsqueeze_range = make_shared<Unsqueeze>(range_0_num_minus_1, const_except_axis);
|
||||
|
||||
auto delta_mul_range = make_shared<Multiply>(unsqueeze_delta, unsqueeze_range);
|
||||
auto result = make_shared<Add>(unsqueeze_start, delta_mul_range);
|
||||
set_node_name(node.get_name(), result);
|
||||
return result->outputs();
|
||||
// compute the result
|
||||
auto stop_plus_delta = make_shared<Add>(stop, delta);
|
||||
auto linspace = make_shared<Range>(start, stop_plus_delta, delta, start.get_element_type());
|
||||
set_node_name(node.get_name(), linspace);
|
||||
return {linspace};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
35
tests/layer_tests/tensorflow_tests/test_tf_LinSpace.py
Normal file
35
tests/layer_tests/tensorflow_tests/test_tf_LinSpace.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestLinSpace(CommonTFLayerTest):
|
||||
def create_lin_space_net(self, num_value):
|
||||
tf.compat.v1.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
start = tf.compat.v1.placeholder(tf.float32, [], 'start')
|
||||
stop = tf.compat.v1.placeholder(tf.float32, [], 'stop')
|
||||
tf.raw_ops.LinSpace(start=start, stop=stop, num=num_value)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(num_value=2),
|
||||
dict(num_value=10),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
def test_lin_space_basic(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend, use_old_api):
|
||||
self._test(*self.create_lin_space_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
Loading…
Reference in New Issue
Block a user