[TF FE] Refactor Gather operations and add layer tests (#15808)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
d9f0890a84
commit
699a1d1708
@ -15,16 +15,17 @@ namespace op {
|
||||
OutputVector translate_basic_gather_op(const NodeContext& node, const ov::Output<ov::Node>& axis, int64_t batch_dims) {
|
||||
auto op_type = node.get_op_type();
|
||||
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 2, op_type + " must have at least two inputs.");
|
||||
auto input = node.get_input(0);
|
||||
auto input_indices = node.get_input(1);
|
||||
auto res = make_shared<Gather>(input, input_indices, axis, batch_dims);
|
||||
set_node_name(node.get_name(), res);
|
||||
return res->outputs();
|
||||
auto params = node.get_input(0);
|
||||
auto indices = node.get_input(1);
|
||||
auto gather = make_shared<Gather>(params, indices, axis, batch_dims);
|
||||
set_node_name(node.get_name(), gather);
|
||||
return {gather};
|
||||
}
|
||||
|
||||
OutputVector translate_gather_op(const NodeContext& node) {
|
||||
// Gather has two inputs: data and indices
|
||||
// axis by which data is sliced is always equal to 0, batch_dims is always equal to 0
|
||||
default_op_checks(node, 2, {"Gather"});
|
||||
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
return translate_basic_gather_op(node, axis, 0);
|
||||
}
|
||||
@ -32,15 +33,16 @@ OutputVector translate_gather_op(const NodeContext& node) {
|
||||
OutputVector translate_resource_gather_op(const NodeContext& node) {
|
||||
// ResourceGather has two inputs: data and indices
|
||||
// axis by which data is sliced is always equal to 0, batch_dims is an attribute and can vary
|
||||
default_op_checks(node, 2, {"ResourceGather"});
|
||||
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);
|
||||
return translate_basic_gather_op(node, axis, batch_dims);
|
||||
}
|
||||
|
||||
OutputVector translate_gather_v2_op(const NodeContext& node) {
|
||||
// ResourceGather has three inputs: data, indices, and axis by which data is sliced
|
||||
// GatherV2 has three inputs: data, indices, and axis by which data is sliced
|
||||
// batch_dims is an attribute and can vary
|
||||
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 3, "GatherV2 must have at least three inputs.");
|
||||
default_op_checks(node, 3, {"GatherV2"});
|
||||
auto axis = node.get_input(2);
|
||||
auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);
|
||||
return translate_basic_gather_op(node, axis, batch_dims);
|
||||
@ -49,12 +51,13 @@ OutputVector translate_gather_v2_op(const NodeContext& node) {
|
||||
OutputVector translate_gather_nd_op(const NodeContext& node) {
|
||||
// GatherND has two inputs: data and indices
|
||||
// batch_dims is always equal to 0
|
||||
default_op_checks(node, 2, {"GatherNd", "GATHER_ND"});
|
||||
auto input = node.get_input(0);
|
||||
auto input_indices = node.get_input(1);
|
||||
auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);
|
||||
auto res = make_shared<GatherND>(input, input_indices, batch_dims);
|
||||
set_node_name(node.get_name(), res);
|
||||
return res->outputs();
|
||||
auto gather_nd = make_shared<GatherND>(input, input_indices, batch_dims);
|
||||
set_node_name(node.get_name(), gather_nd);
|
||||
return {gather_nd};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
@ -1,26 +1,48 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestGather(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_info):
|
||||
assert 'params' in inputs_info
|
||||
assert 'indices' in inputs_info
|
||||
params_shape = inputs_info['params']
|
||||
indices_shape = inputs_info['indices']
|
||||
inputs_data = {}
|
||||
inputs_data['params'] = np.random.randint(-50, 50, params_shape).astype(self.params_type)
|
||||
inputs_data['indices'] = np.random.randint(0, self.max_index, indices_shape).astype(self.indices_type)
|
||||
return inputs_data
|
||||
|
||||
def create_indices_constant(self):
|
||||
pass
|
||||
|
||||
def create_gather_net(self, data_shape, indices, axis, batch_dims, use_new_frontend, **kwargs):
|
||||
import tensorflow as tf
|
||||
def create_gather_net(self, params_shape, params_type, indices_shape, indices_type, axis_value, batch_dims,
|
||||
operation_type):
|
||||
self.params_type = params_type
|
||||
self.indices_type = indices_type
|
||||
if batch_dims is None:
|
||||
batch_dims = 0
|
||||
if axis_value is None:
|
||||
axis_value = 0
|
||||
axis_norm = axis_value
|
||||
if axis_norm < 0:
|
||||
axis_norm += len(params_shape)
|
||||
assert 0 <= axis_norm < len(params_shape), "Incorrect `axis` value for the test case"
|
||||
self.max_index = params_shape[axis_norm]
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
with tf.compat.v1.Session() as sess:
|
||||
data = tf.compat.v1.placeholder(tf.float32, data_shape, 'data')
|
||||
indices = tf.constant(indices, dtype=tf.int32)
|
||||
gather = tf.gather(data, indices, axis=axis, batch_dims=batch_dims,
|
||||
name='gather_output')
|
||||
params = tf.compat.v1.placeholder(params_type, params_shape, 'params')
|
||||
indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices')
|
||||
if operation_type == "Gather":
|
||||
tf.raw_ops.Gather(params=params, indices=indices)
|
||||
elif operation_type == "GatherV2":
|
||||
axis = tf.constant(axis_value, dtype=tf.int32)
|
||||
tf.raw_ops.GatherV2(params=params, indices=indices, axis=axis, batch_dims=batch_dims)
|
||||
else:
|
||||
assert False, "Incorrect operation type is tested"
|
||||
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
@ -30,41 +52,25 @@ class TestGather(CommonTFLayerTest):
|
||||
return tf_net, ref_net
|
||||
|
||||
test_data_precommit = [
|
||||
dict(data_shape=[6, 8, 10, 12], indices=[[0, 2, 4], [5, 7, 9]], axis=2, batch_dims=0),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=1, batch_dims=0),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=-1, batch_dims=0)
|
||||
dict(params_shape=[4, 6], params_type=np.float32, indices_shape=[], indices_type=np.int32,
|
||||
axis_value=None, batch_dims=None,
|
||||
operation_type="Gather"),
|
||||
dict(params_shape=[3, 4, 6], params_type=np.float32, indices_shape=[3, 4], indices_type=np.int32,
|
||||
axis_value=None, batch_dims=None,
|
||||
operation_type="Gather"),
|
||||
dict(params_shape=[5, 4, 3], params_type=np.int32, indices_shape=[5, 2, 1], indices_type=np.int64,
|
||||
axis_value=2, batch_dims=1,
|
||||
operation_type="GatherV2"),
|
||||
dict(params_shape=[3, 2, 6, 4], params_type=np.float32, indices_shape=[3, 2, 1, 3], indices_type=np.int32,
|
||||
axis_value=-1, batch_dims=-2,
|
||||
operation_type="GatherV2"),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_precommit)
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
def test_gather(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||
use_old_api):
|
||||
self._test(*self.create_gather_net(**params, ir_version=ir_version,
|
||||
use_new_frontend=use_new_frontend),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
||||
test_data_nightly = [
|
||||
dict(data_shape=[2, 3], axis=1, indices=[0, 2], batch_dims=0),
|
||||
dict(data_shape=[10, 12], axis=0, indices=[3, 6], batch_dims=0),
|
||||
dict(data_shape=[10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]],
|
||||
batch_dims=0),
|
||||
dict(data_shape=[8, 10, 12], axis=0, indices=[3, 6], batch_dims=0),
|
||||
pytest.param(dict(data_shape=[8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
|
||||
marks=pytest.mark.precommit_tf_fe),
|
||||
dict(data_shape=[6, 8, 10, 12], axis=0, indices=[2, 5], batch_dims=0),
|
||||
dict(data_shape=[6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
|
||||
dict(data_shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], batch_dims=0),
|
||||
dict(data_shape=[2, 14, 10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]],
|
||||
batch_dims=1),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], axis=0, indices=[1, 3], batch_dims=0),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_nightly)
|
||||
@pytest.mark.nightly
|
||||
def test_gather_nightly(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend, use_old_api):
|
||||
self._test(*self.create_gather_net(**params, use_new_frontend=use_new_frontend),
|
||||
self._test(*self.create_gather_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
Loading…
Reference in New Issue
Block a user