[TF FE] Refactor translation for Concat and ConcatV2 operations and add tests (#13542)
* [TF FE] Refactor translation for Concat and ConcatV2 operations and add tests

  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* [TF FE] Correct the layer tests for passing

* Apply code-review feedback: remove redundant function

* Fix a name of function to create Concat model in the layer test

  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
parent 88ffc23e3a
commit 22f3bee0f9
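Background for the change: TensorFlow's legacy Concat op takes the concatenation axis ("concat_dim") as its first input, while ConcatV2, which is what tf.concat emits, takes the axis as its last input. The refactored translator in the diff below handles both layouts explicitly. A minimal Python illustration of the two raw ops (not part of this commit; assumes TensorFlow 2.x in eager mode):

    import tensorflow as tf

    x = tf.constant([[1.0, 2.0]])
    y = tf.constant([[3.0, 4.0]])

    # Concat (V1): the axis comes first and is called concat_dim
    v1 = tf.raw_ops.Concat(concat_dim=1, values=[x, y])

    # ConcatV2: the axis is the last input; tf.concat lowers to this op
    v2 = tf.raw_ops.ConcatV2(values=[x, y], axis=1)

    assert v1.numpy().tolist() == v2.numpy().tolist() == [[1.0, 2.0, 3.0, 4.0]]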
@@ -15,36 +15,50 @@ namespace ov {
 namespace frontend {
 namespace tensorflow {
 namespace op {
 
 OutputVector translate_concat_op(const NodeContext& node) {
-    size_t axis_idx, concat_idx_start, concat_idx_stop;
-    if (node.get_op_type() == "ConcatV2") {
-        axis_idx = node.get_input_size() - 1;
-        concat_idx_start = 0;
-        concat_idx_stop = node.get_input_size() - 1;
-    } else if (node.get_op_type() == "Concat") {
-        axis_idx = 0;
-        concat_idx_start = 1;
-        concat_idx_stop = node.get_input_size();
+    // The difference between Concat and ConcatV2 is that
+    // axis is the first input for Concat
+    // and is the last input to ConcatV2
+    default_op_checks(node, 2, {"Concat", "ConcatV2"});
+    auto input_size = node.get_input_size();
+
+    int64_t axis;
+    OutputVector inputs;
+    if (node.get_op_type() == "Concat") {
+        std::vector<int64_t> axis_vector;
+        get_const_input(node, 0, &axis_vector);
+        TENSORFLOW_OP_VALIDATION(
+            node,
+            axis_vector.size() == 1,
+            "Input model is incorrect: axis input for Concat operation must have exactly one element.");
+        axis = axis_vector[0];
+        for (size_t input_idx = 1; input_idx < input_size; ++input_idx) {
+            inputs.push_back(node.get_input(input_idx));
+        }
+    } else if (node.get_op_type() == "ConcatV2") {
+        std::vector<int64_t> axis_vector;
+        get_const_input(node, input_size - 1, &axis_vector);
+        TENSORFLOW_OP_VALIDATION(
+            node,
+            axis_vector.size() == 1,
+            "Input model is incorrect: axis input for Concat operation must have exactly one element.");
+        axis = axis_vector[0];
+        for (size_t input_idx = 0; input_idx < input_size - 1; ++input_idx) {
+            inputs.push_back(node.get_input(input_idx));
+        }
     } else {
-        TENSORFLOW_OP_VALIDATION(node, false, "Incorrect operation type.");
+        TENSORFLOW_OP_VALIDATION(node,
+                                 false,
+                                 "Internal TensorFlow Frontend error: incorrect operation type is passed to "
+                                 "translate_concat_op function.");
     }
 
-    std::vector<int64_t> tf_concat_axis_vec;
-    get_const_input(node, axis_idx, &tf_concat_axis_vec);
-    int64_t concat_axis = tf_concat_axis_vec[0];
-
-    OutputVector ng_args;
-    for (size_t i = concat_idx_start; i < concat_idx_stop; i++) {
-        Output<Node> ng_arg = node.get_input(static_cast<int>(i));
-        ng_args.push_back(ng_arg);
-    }
-
-    auto res = make_shared<Concat>(ng_args, size_t(concat_axis));
-    set_node_name(node.get_name(), res);
-    return res->outputs();
+    auto concat = make_shared<Concat>(inputs, axis);
+    set_node_name(node.get_name(), concat);
+    return {concat};
 }
 }  // namespace op
 }  // namespace tensorflow
 }  // namespace frontend
 }  // namespace ov
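The refactored translator still requires the axis input to be a constant with exactly one element (enforced by the TENSORFLOW_OP_VALIDATION checks above), which is what a frozen graph produced by tf.concat provides. A rough end-to-end sketch of feeding such a graph through the TensorFlow Frontend; the convert_variables_to_constants_v2 plus Core.read_model workflow is a plausible way to exercise this path on an OpenVINO build with the TF Frontend enabled, not something prescribed by this commit, and the file name is illustrative:

    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
    from openvino.runtime import Core

    @tf.function(input_signature=[tf.TensorSpec([1, 3, 2], tf.float32),
                                  tf.TensorSpec([1, 3, 5], tf.float32)])
    def model(x, y):
        # tf.concat produces a ConcatV2 node with the axis as its last (const) input
        return tf.concat([x, y], axis=-1)

    frozen = convert_variables_to_constants_v2(model.get_concrete_function())
    tf.io.write_graph(frozen.graph, ".", "concat_model.pb", as_text=False)

    # read_model dispatches the frozen graph to the TF Frontend,
    # which maps the ConcatV2 node through translate_concat_op
    ov_model = Core().read_model("concat_model.pb")
    print(ov_model)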
@@ -4,94 +4,72 @@
 import pytest
 
 from common.tf_layer_test_class import CommonTFLayerTest
-from common.utils.tf_utils import permute_nchw_to_nhwc
 
 
 class TestConcat(CommonTFLayerTest):
-    def create_concat_net(self, shape, axis, ir_version, use_new_frontend):
-        """
-            Tensorflow net                 IR net
-
-            Input->Concat    =>     Input->Concat
-
-        """
-
+    def create_concat_net(self, input_shapes, axis, is_v2, ir_version, use_new_frontend):
+        # tf.concat is equivalent to tf.raw_ops.ConcatV2
+        # only tf.concat accepts one input
         import tensorflow as tf
 
         tf.compat.v1.reset_default_graph()
 
         # Create the graph and model
         with tf.compat.v1.Session() as sess:
-            ax = axis
-
-            tf_x_shape = shape.copy()
-
-            tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
-
-            # TODO: add concat with const inputs to check fusing (as in ONNX)
-
-            x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input')
-            y = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input')  # Input_1 in graph_def
-
-            concat = tf.concat([x, y], axis=ax, name='Operation')
-            concat_shape = concat.shape.as_list()
-
+            placeholders = []
+            for ind, input_shape in enumerate(input_shapes):
+                placeholders.append(tf.compat.v1.placeholder(tf.float32, input_shape, 'input_{}'.format(ind)))
+            if len(input_shapes) == 1:
+                tf.concat(values=placeholders, axis=axis, name='concat')
+            elif is_v2:
+                tf.raw_ops.ConcatV2(values=placeholders, axis=axis, name='concat')
+            else:
+                tf.raw_ops.Concat(values=placeholders, concat_dim=axis, name='concat')
             tf.compat.v1.global_variables_initializer()
             tf_net = sess.graph_def
 
-        #
-        #   Create reference IR net
-        #   Please, specify 'type': 'Input' for input node
-        #   Moreover, do not forget to validate ALL layer attributes!!!
-        #
-
         ref_net = None
 
         return tf_net, ref_net
 
-    # TODO: create tests for concat with 1 input and multiple inputs
-
-    test_data_1D = [dict(shape=[1], axis=0),
-                    dict(shape=[1], axis=-1)]
+    test_data_1D = [dict(input_shapes=[[1], [2]], axis=0, is_v2=False),
+                    dict(input_shapes=[[1], [3]], axis=-1, is_v2=True)]
 
     @pytest.mark.parametrize("params", test_data_1D)
     @pytest.mark.nightly
     def test_concat_1D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
                        use_old_api):
         self._test(*self.create_concat_net(**params, ir_version=ir_version,
                                            use_new_frontend=use_new_frontend),
                    ie_device, precision, ir_version, temp_dir=temp_dir,
                    use_new_frontend=use_new_frontend, use_old_api=use_old_api)
 
-    test_data_2D = [dict(shape=[1, 224], axis=0),
-                    dict(shape=[1, 224], axis=-1)]
+    test_data_2D = [dict(input_shapes=[[3, 4]], axis=0, is_v2=True),
+                    dict(input_shapes=[[1, 4], [1, 2]], axis=-1, is_v2=True)]
 
     @pytest.mark.parametrize("params", test_data_2D)
     @pytest.mark.nightly
     def test_concat_2D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
                        use_old_api):
         self._test(*self.create_concat_net(**params, ir_version=ir_version,
                                            use_new_frontend=use_new_frontend),
                    ie_device, precision, ir_version, temp_dir=temp_dir,
                    use_new_frontend=use_new_frontend, use_old_api=use_old_api)
 
     test_data_3D = [
-        dict(shape=[1, 3, 224], axis=0),
-        pytest.param(dict(shape=[1, 3, 224], axis=-1), marks=pytest.mark.precommit_tf_fe),
-        dict(shape=[1, 3, 224], axis=2)]
+        dict(input_shapes=[[1, 3, 5], [2, 3, 5]], axis=0, is_v2=False),
+        dict(input_shapes=[[1, 3, 2], [1, 3, 5]], axis=-1, is_v2=True),
+        dict(input_shapes=[[1, 3, 1], [1, 3, 4], [1, 3, 3]], axis=2, is_v2=True)]
 
     @pytest.mark.parametrize("params", test_data_3D)
     @pytest.mark.nightly
+    @pytest.mark.precommit_tf_fe
     def test_concat_3D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
                        use_old_api):
         self._test(*self.create_concat_net(**params, ir_version=ir_version,
                                            use_new_frontend=use_new_frontend),
                    ie_device, precision, ir_version, temp_dir=temp_dir,
                    use_new_frontend=use_new_frontend, use_old_api=use_old_api)
 
-    test_data_4D = [dict(shape=[1, 3, 100, 224], axis=0),
-                    dict(shape=[1, 3, 100, 224], axis=-1),
-                    dict(shape=[1, 3, 100, 224], axis=2)]
+    test_data_4D = [dict(input_shapes=[[1, 3, 5, 7], [3, 3, 5, 7], [2, 3, 5, 7]], axis=0, is_v2=False),
+                    dict(input_shapes=[[1, 3, 5, 5], [1, 3, 5, 7]], axis=-1, is_v2=True),
+                    dict(input_shapes=[[1, 3, 5, 7], [1, 3, 3, 7]], axis=2, is_v2=False)]
 
     @pytest.mark.parametrize("params", test_data_4D)
     @pytest.mark.nightly
@@ -99,19 +77,20 @@ class TestConcat(CommonTFLayerTest):
     def test_concat_4D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
                        use_old_api):
         self._test(*self.create_concat_net(**params, ir_version=ir_version,
                                            use_new_frontend=use_new_frontend),
                    ie_device, precision, ir_version, temp_dir=temp_dir,
                    use_new_frontend=use_new_frontend, use_old_api=use_old_api)
 
-    test_data_5D = [dict(shape=[1, 3, 50, 100, 224], axis=0),
-                    dict(shape=[1, 3, 50, 100, 224], axis=-1),
-                    dict(shape=[1, 3, 50, 100, 224], axis=2)]
+    test_data_5D = [dict(input_shapes=[[1, 3, 5, 7, 8], [2, 3, 5, 7, 8]], axis=0, is_v2=True),
+                    dict(input_shapes=[[1, 3, 5, 7, 2], [1, 3, 5, 7, 3], [1, 3, 5, 7, 2], [1, 3, 5, 7, 4]],
+                         axis=-1, is_v2=True),
+                    dict(input_shapes=[[1, 3, 5, 7, 8], [1, 3, 2, 7, 8]], axis=2, is_v2=False)]
 
     @pytest.mark.parametrize("params", test_data_5D)
     @pytest.mark.nightly
     def test_concat_5D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
                        use_old_api):
         self._test(*self.create_concat_net(**params, ir_version=ir_version,
                                            use_new_frontend=use_new_frontend),
                    ie_device, precision, ir_version, temp_dir=temp_dir,
                    use_new_frontend=use_new_frontend, use_old_api=use_old_api)
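The helper's branches can also be checked in isolation: building the same graph outside the test harness and listing the node types shows which raw op ends up in the GraphDef that the frontend will see. A small standalone sketch mirroring the ConcatV2 branch of create_concat_net (not part of the commit; assumes TensorFlow 2.x, with eager execution disabled standing in for whatever the layer-test infrastructure does):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.reset_default_graph()
    with tf.compat.v1.Session() as sess:
        a = tf.compat.v1.placeholder(tf.float32, [1, 3, 2], 'input_0')
        b = tf.compat.v1.placeholder(tf.float32, [1, 3, 5], 'input_1')
        tf.raw_ops.ConcatV2(values=[a, b], axis=-1, name='concat')
        graph_def = sess.graph_def

    # Expect two Placeholder nodes, a Const for the axis, and the ConcatV2 itself
    print([node.op for node in graph_def.node])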