[TF FE] Support SegmentSum operation (#13354)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-10-07 17:01:14 +03:00 committed by GitHub
parent cdb486d838
commit f5febef8a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 112 additions and 0 deletions

View File

@ -0,0 +1,48 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op_table.hpp"
#include "openvino/opsets/opset9.hpp"
using namespace std;
using namespace ov;
using namespace ov::opset9;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_segment_sum_op(const NodeContext& node) {
default_op_checks(node, 2, {"SegmentSum"});
auto data = node.get_input(0);
auto segment_ids = node.get_input(1);
// compute SegmentSum using EmbeddingSegmentSum
// for this prepare all the required inputs
auto indices_type = segment_ids.get_element_type();
// 1. compute a number of segments using segment_ids values
// do not forget that segment ids are counting from zero
auto reduction_axis = make_shared<Constant>(element::i32, Shape{1}, 0);
auto num_segments_minus1 = make_shared<ReduceMax>(segment_ids, reduction_axis, false);
auto one = make_shared<Constant>(indices_type, Shape{}, 1);
auto num_segments = make_shared<Add>(num_segments_minus1, one);
// 2. generate indices input for EmbeddingSegmentSum
// that will collect slices consequently from data for each segment
auto squeeze_axis = make_shared<Constant>(element::i32, Shape{1}, 0);
auto segment_ids_shape = make_shared<ShapeOf>(segment_ids, indices_type);
auto num_indices = make_shared<Squeeze>(segment_ids_shape, squeeze_axis);
auto indices = make_shared<Range>(make_shared<Constant>(indices_type, ov::Shape{}, 0),
num_indices,
make_shared<Constant>(indices_type, ov::Shape{}, 1),
indices_type);
auto emb_segment_sum = make_shared<EmbeddingSegmentsSum>(data, indices, segment_ids, num_segments);
set_node_name(node.get_name(), emb_segment_sum);
return {emb_segment_sum};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -95,6 +95,7 @@ OP_CONVERTER(translate_roll_op);
OP_CONVERTER(translate_round_op);
OP_CONVERTER(translate_rsqrt_op);
OP_CONVERTER(translate_scatter_nd_op);
OP_CONVERTER(translate_segment_sum_op);
OP_CONVERTER(translate_sparse_to_dense_op);
OP_CONVERTER(translate_select_op);
OP_CONVERTER(translate_shape_op);
@ -274,6 +275,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Round", translate_round_op},
{"Rsqrt", translate_rsqrt_op},
{"ScatterNd", translate_scatter_nd_op},
{"SegmentSum", translate_segment_sum_op},
{"SparseToDense", translate_sparse_to_dense_op},
{"Select", translate_select_op},
{"SelectV2", translate_select_op},

View File

@ -0,0 +1,62 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestSegmentSum(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'data' in inputs_info, "Test error: inputs_info must contain `data`"
assert 'segment_ids' in inputs_info, "Test error: inputs_info must contain `segment_ids`"
data_shape = inputs_info['data']
segment_ids_shape = inputs_info['segment_ids']
inputs_data = {}
inputs_data['data'] = np.random.randint(-50, 50, data_shape)
# segment_ids data must be sorted according to TensorFlow SegmentSum specification
inputs_data['segment_ids'] = np.sort(np.random.randint(0, 20, segment_ids_shape))
return inputs_data
def create_segment_sum_net(self, data_shape, segment_ids_shape, data_type, segment_ids_type):
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
data = tf.compat.v1.placeholder(data_type, data_shape, 'data')
segment_ids = tf.compat.v1.placeholder(segment_ids_type, segment_ids_shape, 'segment_ids')
tf.math.segment_sum(data, segment_ids)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(data_shape=[8], segment_ids_shape=[8], data_type=tf.float32, segment_ids_type=tf.int32),
dict(data_shape=[3, 7], segment_ids_shape=[3], data_type=tf.float32, segment_ids_type=tf.int32),
dict(data_shape=[4, 3, 2], segment_ids_shape=[4], data_type=tf.float32, segment_ids_type=tf.int32),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.nightly
def test_segment_sum_basic(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_segment_sum_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_different_types = [
dict(data_shape=[2, 3], segment_ids_shape=[2], data_type=tf.int32, segment_ids_type=tf.int32),
dict(data_shape=[3, 2], segment_ids_shape=[3], data_type=tf.float64, segment_ids_type=tf.int32),
dict(data_shape=[3, 1, 2], segment_ids_shape=[3], data_type=tf.float32, segment_ids_type=tf.int64),
dict(data_shape=[4, 2, 1], segment_ids_shape=[4], data_type=tf.float64, segment_ids_type=tf.int64),
]
@pytest.mark.parametrize("params", test_data_different_types)
@pytest.mark.nightly
def test_segment_sum_different_types(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_segment_sum_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)