From 8a9890dbf1dcd234b342a2cd97c24f21b5a3f9ba Mon Sep 17 00:00:00 2001
From: Ivan Tikhonov
Date: Wed, 12 Oct 2022 08:04:19 +0400
Subject: [PATCH] [TF FE] Support DynamicStitch operation (#13408)

* Delete unnecessary changes

* codestyle

* skip dynamic stitch layer tests for legacy frontend

* apply review comments

* fix unit tests, apply review comments
---
 .../src/op/parallel_dynamic_stitch.cpp        | 77 ++++++++++++++++
 src/frontends/tensorflow/src/op_table.cpp     |  3 +
 tests/layer_tests/common/utils/tf_utils.py    |  2 +-
 .../test_tf_ParallelDynamicStitch.py          | 91 +++++++++++++++++++
 4 files changed, 172 insertions(+), 1 deletion(-)
 create mode 100644 src/frontends/tensorflow/src/op/parallel_dynamic_stitch.cpp
 create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_ParallelDynamicStitch.py

diff --git a/src/frontends/tensorflow/src/op/parallel_dynamic_stitch.cpp b/src/frontends/tensorflow/src/op/parallel_dynamic_stitch.cpp
new file mode 100644
index 00000000000..cf7227a4968
--- /dev/null
+++ b/src/frontends/tensorflow/src/op/parallel_dynamic_stitch.cpp
@@ -0,0 +1,77 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "op_table.hpp"
+#include "openvino/opsets/opset9.hpp"
+
+using namespace std;
+using namespace ov::opset9;
+
+namespace ov {
+namespace frontend {
+namespace tensorflow {
+namespace op {
+
+OutputVector translate_parallel_dynamic_stitch_op(const NodeContext& node) {
+    // format for inputs: [indices1, indices2, ..., indicesN, data1, data2, ..., dataN]
+    // so we expect at least two inputs, and the total number of inputs must be divisible by 2
+    default_op_checks(node, 2, {"ParallelDynamicStitch", "DynamicStitch"});
+    auto in_size = node.get_input_size();
+    TENSORFLOW_OP_VALIDATION(node,
+                             in_size % 2 == 0,
+                             "The total number of inputs to DynamicStitch or ParallelDynamicStitch operation "
+                             "must be divisible by 2.");
+
+    int N = static_cast<int>(in_size / 2);
+    OutputVector indices_to_concat;
+    OutputVector data_to_concat;
+    auto data_element_type = node.get_input(N).get_element_type();
+    auto const_minus_one = make_shared<Constant>(ov::element::i32, Shape{1}, -1);
+    auto const_zero = make_shared<Constant>(ov::element::i32, Shape{1}, 0);
+    auto const_one = make_shared<Constant>(ov::element::i32, Shape{1}, 1);
+    for (int i = 0; i < N; ++i) {
+        auto indices = node.get_input(i);
+        auto data = node.get_input(N + i);
+
+        const auto& indices_pshape = indices.get_partial_shape();
+        auto rank = indices_pshape.rank();
+        TENSORFLOW_OP_VALIDATION(node,
+                                 indices_pshape.rank().is_static(),
+                                 "Only static rank for `indices` input is supported.");
+        auto rank_val = rank.get_length();
+        auto norm_indices = make_shared<Reshape>(indices, const_minus_one, false);
+        if (rank_val < 1) {
+            data = make_shared<Unsqueeze>(data, const_zero);
+        } else if (rank_val > 1) {
+            auto data_shape = make_shared<ShapeOf>(data, ov::element::i32);
+            auto start = make_shared<Constant>(ov::element::i32, Shape{1}, rank_val);
+            auto stop = make_shared<Constant>(ov::element::i32, Shape{1}, numeric_limits<int>::max());
+            auto shape_of_single_element = make_shared<Slice>(data_shape, start, stop, const_one);
+            auto new_shape = make_shared<Concat>(OutputVector{const_minus_one, shape_of_single_element}, 0);
+            data = make_shared<Reshape>(data, new_shape, false);
+        }
+        data_to_concat.push_back(data);
+        indices_to_concat.push_back(norm_indices);
+    }
+    auto update = make_shared<Concat>(data_to_concat, 0);
+    auto indices = make_shared<Concat>(indices_to_concat, 0);
+    auto data_shape = make_shared<ShapeOf>(update, ov::element::i32);
+
+    auto zero = make_shared<Constant>(data_element_type, Shape{}, 0);
+    auto zeros = make_shared<Broadcast>(zero, data_shape);
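+    // the output has max(indices) + 1 elements along axis 0; positions that are not
+    // referenced by any index keep the zero values produced by the broadcast above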
+    auto max_idx = make_shared<ReduceMax>(indices, Constant::create(element::i32, {1}, {0}), true);
+    auto stop = make_shared<Add>(max_idx->output(0), const_one);
+    auto start = make_shared<Constant>(ov::element::i32, Shape{1}, 0);
+    auto axis = make_shared<Constant>(ov::element::i32, Shape{1}, 0);
+    auto sliced_zeros = make_shared<Slice>(zeros, start, stop, const_one, axis);
+
+    auto result = make_shared<ScatterUpdate>(sliced_zeros, indices, update, const_zero);
+    set_node_name(node.get_name(), result);
+    return result->outputs();
+}
+
+}  // namespace op
+}  // namespace tensorflow
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp
index 2f405fd46dc..5a92e1b467a 100644
--- a/src/frontends/tensorflow/src/op_table.cpp
+++ b/src/frontends/tensorflow/src/op_table.cpp
@@ -77,6 +77,7 @@ OP_CONVERTER(translate_max_pool_op);
 OP_CONVERTER(translate_non_max_suppression_op);
 OP_CONVERTER(translate_normalize_l2_op);
 OP_CONVERTER(translate_pad_op);
+OP_CONVERTER(translate_parallel_dynamic_stitch_op);
 OP_CONVERTER(translate_placeholder_op);
 OP_CONVERTER(translate_placeholder_with_default_op);
 OP_CONVERTER(translate_no_op);
@@ -257,6 +258,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"Pack", translate_pack_op},
         {"Pad", translate_pad_op},
         {"PadV2", translate_pad_op},
+        {"DynamicStitch", translate_parallel_dynamic_stitch_op},
+        {"ParallelDynamicStitch", translate_parallel_dynamic_stitch_op},
         {"Placeholder", translate_placeholder_op},
         {"PlaceholderWithDefault", translate_placeholder_with_default_op},
         {"PreventGradient", translate_identity_op},
diff --git a/tests/layer_tests/common/utils/tf_utils.py b/tests/layer_tests/common/utils/tf_utils.py
index f5e60d322cd..ec437a61329 100644
--- a/tests/layer_tests/common/utils/tf_utils.py
+++ b/tests/layer_tests/common/utils/tf_utils.py
@@ -86,7 +86,7 @@ def summarize_graph(model_path, output_nodes_for_freeze=None, reshape_net=None):
             node_dict['type'] = tf.DType(node.attr['dtype'].type).name
             node_dict['shape'] = str(node.attr['shape'].shape.dim).replace('\n', '').replace(' ', '').replace(
                 'size:', '').replace('[', '').replace(']', '')
-            node_dict['shape'] = tuple(map(lambda x: int(x), node_dict['shape'].split(',')))
+            node_dict['shape'] = tuple(map(lambda x: int(x) if x else 0, node_dict['shape'].split(',')))
             placeholders[node.name] = node_dict
         if node.op == "Variable" or node.op == "VariableV2":
             variables.append(node.name)
diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ParallelDynamicStitch.py b/tests/layer_tests/tensorflow_tests/test_tf_ParallelDynamicStitch.py
new file mode 100644
index 00000000000..eb940b6d96e
--- /dev/null
+++ b/tests/layer_tests/tensorflow_tests/test_tf_ParallelDynamicStitch.py
@@ -0,0 +1,91 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import tensorflow as tf
+from common.tf_layer_test_class import CommonTFLayerTest
+
+
+class TestParallelDynamicStitch(CommonTFLayerTest):
+    def _prepare_input(self, inputs_info):
+        inputs_data = {}
+        num_elements = 0
+        assert len(inputs_info) % 2 == 0, "Number of inputs should be divisible by 2."
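+        # all `indices` inputs are filled from one shared pool: count the total number of
+        # index elements, build a shuffled range [0, k) with a random k <= that total,
+        # resize it so values may repeat, and split it between the inputs below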
+        data_input_cnt = len(inputs_info)//2
+        for i in range(1, data_input_cnt + 1):
+            indices_in_name = "indices{}".format(i)
+            assert indices_in_name in inputs_info, "Test error: inputs_info must contain `{}`".format(indices_in_name)
+            indices_shape = inputs_info[indices_in_name]
+            num_elements = num_elements + np.prod(indices_shape, dtype=int)
+
+        indices_array = np.arange(np.random.randint(1, num_elements+1), dtype=np.intc)
+        np.random.shuffle(indices_array)
+        indices_array = np.resize(indices_array, num_elements)
+
+        idx = 0
+        for i in range(1, data_input_cnt + 1):
+            data_in_name = "data{}".format(i)
+            indices_in_name = "indices{}".format(i)
+            assert data_in_name in inputs_info, "Test error: inputs_info must contain `{}`".format(data_in_name)
+            data_shape = inputs_info[data_in_name]
+            indices_shape = inputs_info[indices_in_name]
+            inputs_data[data_in_name] = np.random.randint(-50, 50, data_shape)
+
+            num_elements_i = np.prod(indices_shape, dtype=int)
+            inputs_data[indices_in_name] = np.reshape(indices_array[idx:idx+num_elements_i], indices_shape)
+            idx = idx + num_elements_i
+        return inputs_data
+
+    def create_parallel_dynamic_stitch_net(self, data_input_cnt, shape_of_element, data_type):
+        tf.compat.v1.reset_default_graph()
+        # Create the graph and model
+        with tf.compat.v1.Session() as sess:
+            indices = []
+            data = []
+            data_shape = shape_of_element
+            indices_shape = []
+
+            for i in range(1, data_input_cnt + 1):
+                indices.append(tf.compat.v1.placeholder(tf.int32, indices_shape, 'indices{}'.format(i)))
+                data.append(tf.compat.v1.placeholder(data_type, data_shape, 'data{}'.format(i)))
+                data_shape.insert(0, i)
+                indices_shape.insert(0, i)
+            tf.dynamic_stitch(indices, data)
+            tf.compat.v1.global_variables_initializer()
+
+            tf_net = sess.graph_def
+
+        return tf_net, None
+
+    test_data_basic = [
+        dict(data_input_cnt=1, shape_of_element=[1], data_type=tf.float32),
+        dict(data_input_cnt=2, shape_of_element=[2, 2], data_type=tf.float32),
+        dict(data_input_cnt=3, shape_of_element=[2, 1, 2], data_type=tf.float32),
+    ]
+
+    @pytest.mark.parametrize("params", test_data_basic)
+    @pytest.mark.precommit_tf_fe
+    def test_parallel_dynamic_stitch_basic(self, params, ie_device, precision, ir_version, temp_dir,
+                                           use_new_frontend, use_old_api):
+        if not use_new_frontend:
+            pytest.skip("DynamicStitch operation is not supported via legacy frontend.")
+        self._test(*self.create_parallel_dynamic_stitch_net(**params),
+                   ie_device, precision, ir_version, temp_dir=temp_dir,
+                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
+
+    test_data_different_types = [
+        dict(data_input_cnt=4, shape_of_element=[3, 2], data_type=tf.float64),
+        dict(data_input_cnt=2, shape_of_element=[2, 2, 1], data_type=tf.int64),
+        dict(data_input_cnt=3, shape_of_element=[2, 1, 2, 4], data_type=tf.int32),
+    ]
+
+    @pytest.mark.parametrize("params", test_data_different_types)
+    @pytest.mark.nightly
+    def test_parallel_dynamic_stitch_different_types(self, params, ie_device, precision, ir_version, temp_dir,
+                                                     use_new_frontend, use_old_api):
+        if not use_new_frontend:
+            pytest.skip("DynamicStitch operation is not supported via legacy frontend.")
+        self._test(*self.create_parallel_dynamic_stitch_net(**params),
+                   ie_device, precision, ir_version, temp_dir=temp_dir,
+                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
\ No newline at end of file
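
For reference, TensorFlow's DynamicStitch/ParallelDynamicStitch (exposed as tf.dynamic_stitch)
computes merged[indices[m][i]] = data[m][i] and produces max(indices) + 1 rows; this is the
behaviour the converter above reproduces with Concat, Broadcast, Slice and ScatterUpdate.
A minimal sketch of those semantics (the values below are illustrative only, not taken
from the tests):

    import numpy as np
    import tensorflow as tf

    # two index groups referencing output rows 0..3
    indices = [np.array([0, 2], dtype=np.int32), np.array([1, 3], dtype=np.int32)]
    data = [np.array([[1., 2.], [3., 4.]]), np.array([[5., 6.], [7., 8.]])]

    merged = tf.dynamic_stitch(indices, data)
    # merged == [[1., 2.], [5., 6.], [3., 4.], [7., 8.]]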