[TF FE] Refactor Gather operations and add layer tests (#15808)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-02-20 11:02:42 +04:00 committed by GitHub
parent d9f0890a84
commit 699a1d1708
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 52 deletions

View File

@ -15,16 +15,17 @@ namespace op {
OutputVector translate_basic_gather_op(const NodeContext& node, const ov::Output<ov::Node>& axis, int64_t batch_dims) {
auto op_type = node.get_op_type();
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 2, op_type + " must have at least two inputs.");
auto input = node.get_input(0);
auto input_indices = node.get_input(1);
auto res = make_shared<Gather>(input, input_indices, axis, batch_dims);
set_node_name(node.get_name(), res);
return res->outputs();
auto params = node.get_input(0);
auto indices = node.get_input(1);
auto gather = make_shared<Gather>(params, indices, axis, batch_dims);
set_node_name(node.get_name(), gather);
return {gather};
}
OutputVector translate_gather_op(const NodeContext& node) {
// Gather has two inputs: data and indices
// axis by which data is sliced is always equal to 0, batch_dims is always equal to 0
default_op_checks(node, 2, {"Gather"});
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
return translate_basic_gather_op(node, axis, 0);
}
@ -32,15 +33,16 @@ OutputVector translate_gather_op(const NodeContext& node) {
OutputVector translate_resource_gather_op(const NodeContext& node) {
// ResourceGather has two inputs: data and indices
// axis by which data is sliced is always equal to 0, batch_dims is an attribute and can vary
default_op_checks(node, 2, {"ResourceGather"});
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);
return translate_basic_gather_op(node, axis, batch_dims);
}
OutputVector translate_gather_v2_op(const NodeContext& node) {
// ResourceGather has three inputs: data, indices, and axis by which data is sliced
// GatherV2 has three inputs: data, indices, and axis by which data is sliced
// batch_dims is an attribute and can vary
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 3, "GatherV2 must have at least three inputs.");
default_op_checks(node, 3, {"GatherV2"});
auto axis = node.get_input(2);
auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);
return translate_basic_gather_op(node, axis, batch_dims);
@ -49,12 +51,13 @@ OutputVector translate_gather_v2_op(const NodeContext& node) {
OutputVector translate_gather_nd_op(const NodeContext& node) {
// GatherND has two inputs: data and indices
// batch_dims is always equal to 0
default_op_checks(node, 2, {"GatherNd", "GATHER_ND"});
auto input = node.get_input(0);
auto input_indices = node.get_input(1);
auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);
auto res = make_shared<GatherND>(input, input_indices, batch_dims);
set_node_name(node.get_name(), res);
return res->outputs();
auto gather_nd = make_shared<GatherND>(input, input_indices, batch_dims);
set_node_name(node.get_name(), gather_nd);
return {gather_nd};
}
} // namespace op

View File

@ -1,26 +1,48 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestGather(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'params' in inputs_info
assert 'indices' in inputs_info
params_shape = inputs_info['params']
indices_shape = inputs_info['indices']
inputs_data = {}
inputs_data['params'] = np.random.randint(-50, 50, params_shape).astype(self.params_type)
inputs_data['indices'] = np.random.randint(0, self.max_index, indices_shape).astype(self.indices_type)
return inputs_data
def create_indices_constant(self):
pass
def create_gather_net(self, data_shape, indices, axis, batch_dims, use_new_frontend, **kwargs):
import tensorflow as tf
def create_gather_net(self, params_shape, params_type, indices_shape, indices_type, axis_value, batch_dims,
operation_type):
self.params_type = params_type
self.indices_type = indices_type
if batch_dims is None:
batch_dims = 0
if axis_value is None:
axis_value = 0
axis_norm = axis_value
if axis_norm < 0:
axis_norm += len(params_shape)
assert 0 <= axis_norm < len(params_shape), "Incorrect `axis` value for the test case"
self.max_index = params_shape[axis_norm]
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
data = tf.compat.v1.placeholder(tf.float32, data_shape, 'data')
indices = tf.constant(indices, dtype=tf.int32)
gather = tf.gather(data, indices, axis=axis, batch_dims=batch_dims,
name='gather_output')
params = tf.compat.v1.placeholder(params_type, params_shape, 'params')
indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices')
if operation_type == "Gather":
tf.raw_ops.Gather(params=params, indices=indices)
elif operation_type == "GatherV2":
axis = tf.constant(axis_value, dtype=tf.int32)
tf.raw_ops.GatherV2(params=params, indices=indices, axis=axis, batch_dims=batch_dims)
else:
assert False, "Incorrect operation type is tested"
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
@ -30,41 +52,25 @@ class TestGather(CommonTFLayerTest):
return tf_net, ref_net
test_data_precommit = [
dict(data_shape=[6, 8, 10, 12], indices=[[0, 2, 4], [5, 7, 9]], axis=2, batch_dims=0),
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=1, batch_dims=0),
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=-1, batch_dims=0)
dict(params_shape=[4, 6], params_type=np.float32, indices_shape=[], indices_type=np.int32,
axis_value=None, batch_dims=None,
operation_type="Gather"),
dict(params_shape=[3, 4, 6], params_type=np.float32, indices_shape=[3, 4], indices_type=np.int32,
axis_value=None, batch_dims=None,
operation_type="Gather"),
dict(params_shape=[5, 4, 3], params_type=np.int32, indices_shape=[5, 2, 1], indices_type=np.int64,
axis_value=2, batch_dims=1,
operation_type="GatherV2"),
dict(params_shape=[3, 2, 6, 4], params_type=np.float32, indices_shape=[3, 2, 1, 3], indices_type=np.int32,
axis_value=-1, batch_dims=-2,
operation_type="GatherV2"),
]
@pytest.mark.parametrize("params", test_data_precommit)
@pytest.mark.precommit
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_gather(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api):
self._test(*self.create_gather_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_nightly = [
dict(data_shape=[2, 3], axis=1, indices=[0, 2], batch_dims=0),
dict(data_shape=[10, 12], axis=0, indices=[3, 6], batch_dims=0),
dict(data_shape=[10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]],
batch_dims=0),
dict(data_shape=[8, 10, 12], axis=0, indices=[3, 6], batch_dims=0),
pytest.param(dict(data_shape=[8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
marks=pytest.mark.precommit_tf_fe),
dict(data_shape=[6, 8, 10, 12], axis=0, indices=[2, 5], batch_dims=0),
dict(data_shape=[6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
dict(data_shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], batch_dims=0),
dict(data_shape=[2, 14, 10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]],
batch_dims=1),
dict(data_shape=[4, 6, 8, 10, 12], axis=0, indices=[1, 3], batch_dims=0),
dict(data_shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
]
@pytest.mark.parametrize("params", test_data_nightly)
@pytest.mark.nightly
def test_gather_nightly(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_gather_net(**params, use_new_frontend=use_new_frontend),
self._test(*self.create_gather_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)