From 329abd886491364029afed864c92432f37bae845 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Sun, 13 Aug 2023 21:49:38 +0400 Subject: [PATCH] [TF FE][TF Hub] Support UnsortedSegmentSum operation (#19165) Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow/src/op_table.cpp | 1 + .../include/common_op_table.hpp | 2 +- .../tensorflow_common/include/utils.hpp | 6 ++ .../src/op/unsorted_segment_sum.cpp | 91 +++++++++++++++++++ src/frontends/tensorflow_common/src/utils.cpp | 7 ++ .../test_tf_UnsortedSegmentSum.py | 67 ++++++++++++++ 6 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 src/frontends/tensorflow_common/src/op/unsorted_segment_sum.cpp create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_UnsortedSegmentSum.py diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 6958cef9559..955701d807c 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -272,6 +272,7 @@ const std::map get_supported_ops() { {"Transpose", CreatorFunction(translate_transpose_op)}, {"Unpack", CreatorFunction(translate_unpack_op)}, {"UnravelIndex", CreatorFunction(translate_unravel_index_op)}, + {"UnsortedSegmentSum", CreatorFunction(translate_unsorted_segment_sum_op)}, {"While", CreatorFunction(translate_while_op)}, {"Where", CreatorFunction(translate_where_op)}, {"Xdivy", CreatorFunction(translate_x_div_y_op)}, diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index dce77f82a9d..2baab49b74a 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -142,9 +142,9 @@ OP_CONVERTER_NAMED(translate_top_k_v2_op); OP_CONVERTER(translate_transpose_op); OP_CONVERTER(translate_unpack_op); OP_CONVERTER(translate_unravel_index_op); +OP_CONVERTER(translate_unsorted_segment_sum_op); OP_CONVERTER(translate_where_op); OP_CONVERTER(translate_x_div_y_op); -OP_CONVERTER(translate_xla_dot_op); OP_CONVERTER(translate_zeros_like_op); // Translators for internal operations diff --git a/src/frontends/tensorflow_common/include/utils.hpp b/src/frontends/tensorflow_common/include/utils.hpp index 1abbb13a3b2..acca76aaab8 100644 --- a/src/frontends/tensorflow_common/include/utils.hpp +++ b/src/frontends/tensorflow_common/include/utils.hpp @@ -149,6 +149,12 @@ void convert_nhwc_to_hw(bool is_nhwc, const std::vector& src, std::vector get_data_slice(const ov::Output& data, + const int64_t& start, + const int64_t& stop, + const int64_t& step); + } // namespace tensorflow } // namespace frontend } // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/unsorted_segment_sum.cpp b/src/frontends/tensorflow_common/src/op/unsorted_segment_sum.cpp new file mode 100644 index 00000000000..24eecc168f5 --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/unsorted_segment_sum.cpp @@ -0,0 +1,91 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/embedding_segments_sum.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/less.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/topk.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { +OutputVector translate_unsorted_segment_sum_op(const NodeContext& node) { + default_op_checks(node, 3, {"UnsortedSegmentSum"}); + auto data = node.get_input(0); + auto segment_ids = node.get_input(1); + auto num_segments = node.get_input(2); + + // convert both segment_ids and num_segments to int64 type + // since EmbeddingSegmentsSum requires to have them of the same type + segment_ids = make_shared(segment_ids, element::i64); + num_segments = make_shared(num_segments, element::i64); + + // create auxiliary constants + auto const_zero_i64 = make_shared(element::i64, Shape{}, 0); + auto const_one_i64 = make_shared(element::i64, Shape{1}, 1); + auto const_one_i64_scalar = make_shared(element::i64, Shape{}, 1); + auto data_const_zero = create_same_type_const_scalar(data, 0.0f); + + // segment ids can be negative for which the resulted data will be zeroed + // so it needs to introduce default slice of zeros in the data + // 1. create default slice that will be used for negative segment ids + auto data_shape = make_shared(data, element::i64); + auto slice_shape = get_data_slice(data_shape, 1, numeric_limits::max(), 1); + slice_shape = make_shared(OutputVector{const_one_i64, slice_shape}, 0); + auto default_slice = make_shared(data_const_zero, slice_shape); + // 2. update data with the default slice + data = make_shared(OutputVector{data, default_slice}, 0); + + // compute default index + auto squeeze_axis = make_shared(element::i32, Shape{1}, 0); + auto default_index = get_data_slice(data_shape, 0, 1, 1); + default_index = make_shared(default_index, squeeze_axis); + + // adjust segment ids to have zero instead of negative values + auto is_negative_segment_id = make_shared(segment_ids, const_zero_i64); + segment_ids = make_shared(is_negative_segment_id, const_zero_i64, segment_ids); + + // generate indices input for EmbeddingSegmentSum + // that will collect slices consequently from data for each segment + auto segment_ids_shape = make_shared(segment_ids, element::i64); + auto num_indices = make_shared(segment_ids_shape, squeeze_axis); + auto indices = make_shared(const_zero_i64, num_indices, const_one_i64_scalar, element::i64)->output(0); + + // adjust the generated indices to retrieve the default slice for original negative segment ids + indices = make_shared(is_negative_segment_id, default_index, indices); + + // since EmbeddingSegmentSum accepts only sorted segments ids + // it needs to sort them and reorder indices + auto topk = + make_shared(segment_ids, num_indices, 0, TopKMode::MIN, TopKSortType::SORT_VALUES, element::i32); + segment_ids = topk->output(0); + auto gather_axis = make_shared(element::i32, Shape{1}, 0); + indices = make_shared(indices, topk->output(1), gather_axis); + + // compute UnsortedSegmentSum using EmbeddingSegmentSum + auto unsorted_segment_sum = + make_shared(data, indices, segment_ids, num_segments, default_index); + set_node_name(node.get_name(), unsorted_segment_sum); + return {unsorted_segment_sum}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/utils.cpp b/src/frontends/tensorflow_common/src/utils.cpp index 83c2f6e8796..4d727cc1d77 100644 --- a/src/frontends/tensorflow_common/src/utils.cpp +++ b/src/frontends/tensorflow_common/src/utils.cpp @@ -347,6 +347,13 @@ shared_ptr make_reshape(const Output& arg, const vector& return reshape; } +Output get_data_slice(const Output& data, const int64_t& start, const int64_t& stop, const int64_t& step) { + auto start_const = make_shared(element::i64, Shape{1}, start); + auto stop_const = make_shared(element::i64, Shape{1}, stop); + auto step_const = make_shared(element::i64, Shape{1}, step); + return make_shared(data, start_const, stop_const, step_const)->output(0); +} + } // namespace tensorflow } // namespace frontend } // namespace ov diff --git a/tests/layer_tests/tensorflow_tests/test_tf_UnsortedSegmentSum.py b/tests/layer_tests/tensorflow_tests/test_tf_UnsortedSegmentSum.py new file mode 100644 index 00000000000..09afd6f2633 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_UnsortedSegmentSum.py @@ -0,0 +1,67 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestUnsortedSegmentSum(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'data' in inputs_info, "Test error: inputs_info must contain `data`" + assert 'segment_ids' in inputs_info, "Test error: inputs_info must contain `segment_ids`" + data_shape = inputs_info['data'] + segment_ids_shape = inputs_info['segment_ids'] + inputs_data = {} + inputs_data['data'] = np.random.randint(-50, 50, data_shape).astype(self.data_type) + # segment_ids can have negative values + inputs_data['segment_ids'] = np.random.randint(-self.num_segments_val, self.num_segments_val, segment_ids_shape) + return inputs_data + + def create_unsorted_segment_sum_net(self, data_shape, segment_ids_shape, num_segments_val, data_type, + segment_ids_type, num_segments_type): + self.data_type = data_type + self.segment_ids_type = segment_ids_type + self.num_segments_val = num_segments_val + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + data = tf.compat.v1.placeholder(data_type, data_shape, 'data') + segment_ids = tf.compat.v1.placeholder(segment_ids_type, segment_ids_shape, 'segment_ids') + num_segments = tf.constant(num_segments_val, dtype=num_segments_type, shape=[]) + tf.raw_ops.UnsortedSegmentSum(data=data, segment_ids=segment_ids, num_segments=num_segments) + tf.compat.v1.global_variables_initializer() + + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(data_shape=[8], segment_ids_shape=[8], num_segments_val=10), + dict(data_shape=[10, 4], segment_ids_shape=[10], num_segments_val=5), + dict(data_shape=[5, 6, 7], segment_ids_shape=[5], num_segments_val=100), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.parametrize("data_type", [ + np.float32, np.int32 + ]) + @pytest.mark.parametrize("segment_ids_type", [ + np.int32, np.int64 + ]) + @pytest.mark.parametrize("num_segments_type", [ + np.int32, np.int64 + ]) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_unsorted_segment_sum_basic(self, params, data_type, segment_ids_type, num_segments_type, ie_device, + precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + if not use_new_frontend: + pytest.skip("UnsortedSegmentSum operation is not supported via legacy frontend.") + self._test( + *self.create_unsorted_segment_sum_net(**params, data_type=data_type, segment_ids_type=segment_ids_type, + num_segments_type=num_segments_type), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api)