[TF FE] Refactor translation for Concat and ConcatV2 operations and add tests (#13542)
* [TF FE] Refactor translation for Concat and ConcatV2 operations and add tests Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * [TF FE] Correct the layer tests for passing * Apply code-review feedback: remove redundant function * Fix a name of function to create Concat model in the layer test Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
88ffc23e3a
commit
22f3bee0f9
@ -15,36 +15,50 @@ namespace ov {
|
|||||||
namespace frontend {
|
namespace frontend {
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace op {
|
namespace op {
|
||||||
|
|
||||||
OutputVector translate_concat_op(const NodeContext& node) {
|
OutputVector translate_concat_op(const NodeContext& node) {
|
||||||
size_t axis_idx, concat_idx_start, concat_idx_stop;
|
// The difference between Concat and ConcatV2 is that
|
||||||
if (node.get_op_type() == "ConcatV2") {
|
// axis is the first input for Concat
|
||||||
axis_idx = node.get_input_size() - 1;
|
// and is the last input to ConcatV2
|
||||||
concat_idx_start = 0;
|
default_op_checks(node, 2, {"Concat", "ConcatV2"});
|
||||||
concat_idx_stop = node.get_input_size() - 1;
|
auto input_size = node.get_input_size();
|
||||||
} else if (node.get_op_type() == "Concat") {
|
|
||||||
axis_idx = 0;
|
int64_t axis;
|
||||||
concat_idx_start = 1;
|
OutputVector inputs;
|
||||||
concat_idx_stop = node.get_input_size();
|
|
||||||
|
if (node.get_op_type() == "Concat") {
|
||||||
|
std::vector<int64_t> axis_vector;
|
||||||
|
get_const_input(node, 0, &axis_vector);
|
||||||
|
TENSORFLOW_OP_VALIDATION(
|
||||||
|
node,
|
||||||
|
axis_vector.size() == 1,
|
||||||
|
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
|
||||||
|
axis = axis_vector[0];
|
||||||
|
for (size_t input_idx = 1; input_idx < input_size; ++input_idx) {
|
||||||
|
inputs.push_back(node.get_input(input_idx));
|
||||||
|
}
|
||||||
|
} else if (node.get_op_type() == "ConcatV2") {
|
||||||
|
std::vector<int64_t> axis_vector;
|
||||||
|
get_const_input(node, input_size - 1, &axis_vector);
|
||||||
|
TENSORFLOW_OP_VALIDATION(
|
||||||
|
node,
|
||||||
|
axis_vector.size() == 1,
|
||||||
|
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
|
||||||
|
axis = axis_vector[0];
|
||||||
|
for (size_t input_idx = 0; input_idx < input_size - 1; ++input_idx) {
|
||||||
|
inputs.push_back(node.get_input(input_idx));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TENSORFLOW_OP_VALIDATION(node, false, "Incorrect operation type.");
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
|
false,
|
||||||
|
"Internal TensorFlow Frontend error: incorrect operation type is passed to "
|
||||||
|
"translate_concat_op function.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> tf_concat_axis_vec;
|
auto concat = make_shared<Concat>(inputs, axis);
|
||||||
get_const_input(node, axis_idx, &tf_concat_axis_vec);
|
set_node_name(node.get_name(), concat);
|
||||||
int64_t concat_axis = tf_concat_axis_vec[0];
|
return {concat};
|
||||||
|
|
||||||
OutputVector ng_args;
|
|
||||||
for (size_t i = concat_idx_start; i < concat_idx_stop; i++) {
|
|
||||||
Output<Node> ng_arg = node.get_input(static_cast<int>(i));
|
|
||||||
ng_args.push_back(ng_arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto res = make_shared<Concat>(ng_args, size_t(concat_axis));
|
|
||||||
set_node_name(node.get_name(), res);
|
|
||||||
return res->outputs();
|
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
} // namespace frontend
|
} // namespace frontend
|
||||||
} // namespace ov
|
} // namespace ov
|
||||||
|
@ -4,94 +4,72 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from common.tf_layer_test_class import CommonTFLayerTest
|
from common.tf_layer_test_class import CommonTFLayerTest
|
||||||
from common.utils.tf_utils import permute_nchw_to_nhwc
|
|
||||||
|
|
||||||
|
|
||||||
class TestConcat(CommonTFLayerTest):
|
class TestConcat(CommonTFLayerTest):
|
||||||
def create_concat_net(self, shape, axis, ir_version, use_new_frontend):
|
def create_concat_net(self, input_shapes, axis, is_v2, ir_version, use_new_frontend):
|
||||||
"""
|
# tf.concat is equivalent to tf.raw_ops.ConcatV2
|
||||||
Tensorflow net IR net
|
# only tf.concat accepts one input
|
||||||
|
|
||||||
Input->Concat => Input->Concat
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
tf.compat.v1.reset_default_graph()
|
tf.compat.v1.reset_default_graph()
|
||||||
|
|
||||||
# Create the graph and model
|
|
||||||
with tf.compat.v1.Session() as sess:
|
with tf.compat.v1.Session() as sess:
|
||||||
ax = axis
|
placeholders = []
|
||||||
|
for ind, input_shape in enumerate(input_shapes):
|
||||||
tf_x_shape = shape.copy()
|
placeholders.append(tf.compat.v1.placeholder(tf.float32, input_shape, 'input_{}'.format(ind)))
|
||||||
|
if len(input_shapes) == 1:
|
||||||
tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
|
tf.concat(values=placeholders, axis=axis, name='concat')
|
||||||
|
elif is_v2:
|
||||||
# TODO: add concat with const inputs to check fusing (as in ONNX)
|
tf.raw_ops.ConcatV2(values=placeholders, axis=axis, name='concat')
|
||||||
|
else:
|
||||||
x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input')
|
tf.raw_ops.Concat(values=placeholders, concat_dim=axis, name='concat')
|
||||||
y = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input') # Input_1 in graph_def
|
|
||||||
|
|
||||||
concat = tf.concat([x, y], axis=ax, name='Operation')
|
|
||||||
concat_shape = concat.shape.as_list()
|
|
||||||
|
|
||||||
tf.compat.v1.global_variables_initializer()
|
tf.compat.v1.global_variables_initializer()
|
||||||
tf_net = sess.graph_def
|
tf_net = sess.graph_def
|
||||||
|
|
||||||
#
|
|
||||||
# Create reference IR net
|
|
||||||
# Please, specify 'type': 'Input' for input node
|
|
||||||
# Moreover, do not forget to validate ALL layer attributes!!!
|
|
||||||
#
|
|
||||||
|
|
||||||
ref_net = None
|
ref_net = None
|
||||||
|
|
||||||
return tf_net, ref_net
|
return tf_net, ref_net
|
||||||
|
|
||||||
# TODO: create tests for concat with 1 input and multiple inputs
|
test_data_1D = [dict(input_shapes=[[1], [2]], axis=0, is_v2=False),
|
||||||
|
dict(input_shapes=[[1], [3]], axis=-1, is_v2=True)]
|
||||||
test_data_1D = [dict(shape=[1], axis=0),
|
|
||||||
dict(shape=[1], axis=-1)]
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_data_1D)
|
@pytest.mark.parametrize("params", test_data_1D)
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
def test_concat_1D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
def test_concat_1D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||||
use_old_api):
|
use_old_api):
|
||||||
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
||||||
use_new_frontend=use_new_frontend),
|
use_new_frontend=use_new_frontend),
|
||||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||||
|
|
||||||
test_data_2D = [dict(shape=[1, 224], axis=0),
|
test_data_2D = [dict(input_shapes=[[3, 4]], axis=0, is_v2=True),
|
||||||
dict(shape=[1, 224], axis=-1)]
|
dict(input_shapes=[[1, 4], [1, 2]], axis=-1, is_v2=True)]
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_data_2D)
|
@pytest.mark.parametrize("params", test_data_2D)
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
def test_concat_2D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
def test_concat_2D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||||
use_old_api):
|
use_old_api):
|
||||||
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
||||||
use_new_frontend=use_new_frontend),
|
use_new_frontend=use_new_frontend),
|
||||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||||
|
|
||||||
test_data_3D = [
|
test_data_3D = [
|
||||||
dict(shape=[1, 3, 224], axis=0),
|
dict(input_shapes=[[1, 3, 5], [2, 3, 5]], axis=0, is_v2=False),
|
||||||
pytest.param(dict(shape=[1, 3, 224], axis=-1), marks=pytest.mark.precommit_tf_fe),
|
dict(input_shapes=[[1, 3, 2], [1, 3, 5]], axis=-1, is_v2=True),
|
||||||
dict(shape=[1, 3, 224], axis=2)]
|
dict(input_shapes=[[1, 3, 1], [1, 3, 4], [1, 3, 3]], axis=2, is_v2=True)]
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_data_3D)
|
@pytest.mark.parametrize("params", test_data_3D)
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit_tf_fe
|
||||||
def test_concat_3D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
def test_concat_3D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||||
use_old_api):
|
use_old_api):
|
||||||
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
||||||
use_new_frontend=use_new_frontend),
|
use_new_frontend=use_new_frontend),
|
||||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||||
|
|
||||||
test_data_4D = [dict(shape=[1, 3, 100, 224], axis=0),
|
test_data_4D = [dict(input_shapes=[[1, 3, 5, 7], [3, 3, 5, 7], [2, 3, 5, 7]], axis=0, is_v2=False),
|
||||||
dict(shape=[1, 3, 100, 224], axis=-1),
|
dict(input_shapes=[[1, 3, 5, 5], [1, 3, 5, 7]], axis=-1, is_v2=True),
|
||||||
dict(shape=[1, 3, 100, 224], axis=2)]
|
dict(input_shapes=[[1, 3, 5, 7], [1, 3, 3, 7]], axis=2, is_v2=False)]
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_data_4D)
|
@pytest.mark.parametrize("params", test_data_4D)
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
@ -99,19 +77,20 @@ class TestConcat(CommonTFLayerTest):
|
|||||||
def test_concat_4D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
def test_concat_4D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||||
use_old_api):
|
use_old_api):
|
||||||
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
||||||
use_new_frontend=use_new_frontend),
|
use_new_frontend=use_new_frontend),
|
||||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||||
|
|
||||||
test_data_5D = [dict(shape=[1, 3, 50, 100, 224], axis=0),
|
test_data_5D = [dict(input_shapes=[[1, 3, 5, 7, 8], [2, 3, 5, 7, 8]], axis=0, is_v2=True),
|
||||||
dict(shape=[1, 3, 50, 100, 224], axis=-1),
|
dict(input_shapes=[[1, 3, 5, 7, 2], [1, 3, 5, 7, 3], [1, 3, 5, 7, 2], [1, 3, 5, 7, 4]],
|
||||||
dict(shape=[1, 3, 50, 100, 224], axis=2)]
|
axis=-1, is_v2=True),
|
||||||
|
dict(input_shapes=[[1, 3, 5, 7, 8], [1, 3, 2, 7, 8]], axis=2, is_v2=False)]
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_data_5D)
|
@pytest.mark.parametrize("params", test_data_5D)
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
def test_concat_5D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
def test_concat_5D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||||
use_old_api):
|
use_old_api):
|
||||||
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
self._test(*self.create_concat_net(**params, ir_version=ir_version,
|
||||||
use_new_frontend=use_new_frontend),
|
use_new_frontend=use_new_frontend),
|
||||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||||
|
Loading…
Reference in New Issue
Block a user