[TF FE][TF Hub] Support UnsortedSegmentSum operation (#19165)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-08-13 21:49:38 +04:00 committed by GitHub
parent 6067ab17ba
commit 329abd8864
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 173 additions and 1 deletions

View File

@ -272,6 +272,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Transpose", CreatorFunction(translate_transpose_op)},
{"Unpack", CreatorFunction(translate_unpack_op)},
{"UnravelIndex", CreatorFunction(translate_unravel_index_op)},
{"UnsortedSegmentSum", CreatorFunction(translate_unsorted_segment_sum_op)},
{"While", CreatorFunction(translate_while_op)},
{"Where", CreatorFunction(translate_where_op)},
{"Xdivy", CreatorFunction(translate_x_div_y_op)},

View File

@ -142,9 +142,9 @@ OP_CONVERTER_NAMED(translate_top_k_v2_op);
OP_CONVERTER(translate_transpose_op);
OP_CONVERTER(translate_unpack_op);
OP_CONVERTER(translate_unravel_index_op);
OP_CONVERTER(translate_unsorted_segment_sum_op);
OP_CONVERTER(translate_where_op);
OP_CONVERTER(translate_x_div_y_op);
OP_CONVERTER(translate_xla_dot_op);
OP_CONVERTER(translate_zeros_like_op);
// Translators for internal operations

View File

@ -149,6 +149,12 @@ void convert_nhwc_to_hw(bool is_nhwc, const std::vector<T>& src, std::vector<siz
}
}
// Retrieves the part of `data` that lies in the half-open range [start; stop) along the first dimension,
// taking elements with the given `step`.
//   data  - output of the node to slice
//   start - first index included in the slice
//   stop  - index at which the slice ends (not included)
//   step  - stride between the selected indices
// Returns the output of the node that produces the slice.
ov::Output<ov::Node> get_data_slice(const ov::Output<ov::Node>& data,
const int64_t& start,
const int64_t& stop,
const int64_t& step);
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,91 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_op_table.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/embedding_segments_sum.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/topk.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::op;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
// Translates TensorFlow UnsortedSegmentSum into a sub-graph built around EmbeddingSegmentsSum.
// Node inputs: 0 - data, 1 - segment_ids, 2 - num_segments.
// Returns a vector with the single output of the created EmbeddingSegmentsSum node.
OutputVector translate_unsorted_segment_sum_op(const NodeContext& node) {
    default_op_checks(node, 3, {"UnsortedSegmentSum"});
    auto values = node.get_input(0);
    auto seg_ids = node.get_input(1);
    auto seg_count = node.get_input(2);

    // EmbeddingSegmentsSum requires segment ids and the number of segments
    // to share one type, so both are converted to int64
    seg_ids = make_shared<v0::Convert>(seg_ids, element::i64);
    seg_count = make_shared<v0::Convert>(seg_count, element::i64);

    // helper constants used by the steps below
    auto zero_scalar = make_shared<v0::Constant>(element::i64, Shape{}, 0);
    auto one_1d = make_shared<v0::Constant>(element::i64, Shape{1}, 1);
    auto one_scalar = make_shared<v0::Constant>(element::i64, Shape{}, 1);
    auto zero_of_data_type = create_same_type_const_scalar<float>(values, 0.0f);

    // negative segment ids must contribute nothing to the result,
    // so a slice of zeros is appended to the data and used for them
    // step 1: build a zero slice of shape [1, <data shape without the first dimension>]
    auto values_shape = make_shared<v3::ShapeOf>(values, element::i64);
    auto zero_slice_shape = get_data_slice(values_shape, 1, numeric_limits<int>::max(), 1);
    zero_slice_shape = make_shared<v0::Concat>(OutputVector{one_1d, zero_slice_shape}, 0);
    auto zero_slice = make_shared<v3::Broadcast>(zero_of_data_type, zero_slice_shape);
    // step 2: append the zero slice to the data along the first dimension
    values = make_shared<v0::Concat>(OutputVector{values, zero_slice}, 0);

    // the appended slice is addressed by the original size of the first dimension
    auto axis_zero = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    auto zero_slice_index = get_data_slice(values_shape, 0, 1, 1);
    zero_slice_index = make_shared<v0::Squeeze>(zero_slice_index, axis_zero);

    // replace negative segment ids with zero
    auto negative_mask = make_shared<v1::Less>(seg_ids, zero_scalar);
    seg_ids = make_shared<v1::Select>(negative_mask, zero_scalar, seg_ids);

    // build the indices input for EmbeddingSegmentsSum:
    // consecutive positions 0..N-1, one per segment id
    auto seg_ids_shape = make_shared<v3::ShapeOf>(seg_ids, element::i64);
    auto ids_count = make_shared<v0::Squeeze>(seg_ids_shape, axis_zero);
    auto positions = make_shared<v4::Range>(zero_scalar, ids_count, one_scalar, element::i64)->output(0);

    // positions of originally negative segment ids are redirected to the zero slice
    positions = make_shared<v1::Select>(negative_mask, zero_slice_index, positions);

    // EmbeddingSegmentsSum accepts only sorted segment ids,
    // so sort them and reorder the positions accordingly
    auto sorter =
        make_shared<v11::TopK>(seg_ids, ids_count, 0, TopKMode::MIN, TopKSortType::SORT_VALUES, element::i32);
    seg_ids = sorter->output(0);
    auto reorder_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    positions = make_shared<v8::Gather>(positions, sorter->output(1), reorder_axis);

    // express UnsortedSegmentSum through EmbeddingSegmentsSum
    auto segments_sum =
        make_shared<v3::EmbeddingSegmentsSum>(values, positions, seg_ids, seg_count, zero_slice_index);
    set_node_name(node.get_name(), segments_sum);
    return {segments_sum};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -347,6 +347,13 @@ shared_ptr<Reshape> make_reshape(const Output<Node>& arg, const vector<int64_t>&
return reshape;
}
// Creates a Slice node over `data` driven by one-element int64 constants holding
// `start`, `stop` and `step`, and returns the first output of that node.
Output<Node> get_data_slice(const Output<Node>& data, const int64_t& start, const int64_t& stop, const int64_t& step) {
    // each slicing parameter is passed to Slice as a 1D tensor with a single element
    const Shape single_element{1};
    auto begin_node = make_shared<Constant>(element::i64, single_element, start);
    auto end_node = make_shared<Constant>(element::i64, single_element, stop);
    auto stride_node = make_shared<Constant>(element::i64, single_element, step);
    auto slice_node = make_shared<Slice>(data, begin_node, end_node, stride_node);
    return slice_node->output(0);
}
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,67 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestUnsortedSegmentSum(CommonTFLayerTest):
    """Layer test for the TensorFlow UnsortedSegmentSum operation."""

    def _prepare_input(self, inputs_info):
        """Generate random input arrays for the `data` and `segment_ids` placeholders.

        `inputs_info` maps input names to shapes. Returns a dict with the same
        names mapped to numpy arrays of the types selected for the current test.
        """
        assert 'data' in inputs_info, "Test error: inputs_info must contain `data`"
        assert 'segment_ids' in inputs_info, "Test error: inputs_info must contain `segment_ids`"
        data_shape = inputs_info['data']
        segment_ids_shape = inputs_info['segment_ids']
        inputs_data = {}
        inputs_data['data'] = np.random.randint(-50, 50, data_shape).astype(self.data_type)
        # segment_ids can have negative values
        # cast to the parametrized type: np.random.randint otherwise returns
        # the platform default integer type regardless of the placeholder type
        inputs_data['segment_ids'] = np.random.randint(-self.num_segments_val, self.num_segments_val,
                                                       segment_ids_shape).astype(self.segment_ids_type)
        return inputs_data

    def create_unsorted_segment_sum_net(self, data_shape, segment_ids_shape, num_segments_val, data_type,
                                        segment_ids_type, num_segments_type):
        """Build a TF graph with a single UnsortedSegmentSum operation.

        `data` and `segment_ids` are placeholders, `num_segments` is a scalar constant.
        The types and the number of segments are remembered on the instance
        because `_prepare_input` needs them. Returns `(graph_def, None)`.
        """
        self.data_type = data_type
        self.segment_ids_type = segment_ids_type
        self.num_segments_val = num_segments_val
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            data = tf.compat.v1.placeholder(data_type, data_shape, 'data')
            segment_ids = tf.compat.v1.placeholder(segment_ids_type, segment_ids_shape, 'segment_ids')
            num_segments = tf.constant(num_segments_val, dtype=num_segments_type, shape=[])
            tf.raw_ops.UnsortedSegmentSum(data=data, segment_ids=segment_ids, num_segments=num_segments)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(data_shape=[8], segment_ids_shape=[8], num_segments_val=10),
        dict(data_shape=[10, 4], segment_ids_shape=[10], num_segments_val=5),
        dict(data_shape=[5, 6, 7], segment_ids_shape=[5], num_segments_val=100),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.parametrize("data_type", [
        np.float32, np.int32
    ])
    @pytest.mark.parametrize("segment_ids_type", [
        np.int32, np.int64
    ])
    @pytest.mark.parametrize("num_segments_type", [
        np.int32, np.int64
    ])
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_unsorted_segment_sum_basic(self, params, data_type, segment_ids_type, num_segments_type, ie_device,
                                        precision, ir_version, temp_dir,
                                        use_new_frontend, use_old_api):
        """Run the UnsortedSegmentSum layer test for every shape/type combination (new frontend only)."""
        if not use_new_frontend:
            pytest.skip("UnsortedSegmentSum operation is not supported via legacy frontend.")
        self._test(
            *self.create_unsorted_segment_sum_net(**params, data_type=data_type, segment_ids_type=segment_ids_type,
                                                  num_segments_type=num_segments_type),
            ie_device, precision, ir_version, temp_dir=temp_dir,
            use_new_frontend=use_new_frontend, use_old_api=use_old_api)