diff --git a/src/frontends/tensorflow/src/input_model.cpp b/src/frontends/tensorflow/src/input_model.cpp index 33298721271..b60a77fa0c7 100644 --- a/src/frontends/tensorflow/src/input_model.cpp +++ b/src/frontends/tensorflow/src/input_model.cpp @@ -207,9 +207,12 @@ std::vector> InputModel::InputModelTFImpl::get_op_place std::vector> InputModel::InputModelTFImpl::determine_cut_nodes() const { std::vector> topologically_sorted_ops; std::stack> ops_to_do; - std::unordered_set> ops_set_to_do; std::unordered_set> ops_done; + // TODO: implement logic to check direct cycles in the graph + // and break them + // probably not only NextIteration can generate cycles + for (const auto& output_place : m_outputs) { FRONT_END_GENERAL_CHECK(output_place->get_names().size() > 0, "TensorPlace must have at least one name."); auto output_place_name = output_place->get_names()[0]; @@ -221,7 +224,6 @@ std::vector> InputModel::InputModelTFImpl::determine_cu "Custom specified output is incorrect: " + output_place_name); auto output_operation_place = m_op_places_map.at(operation_name); ops_to_do.push(output_operation_place); - ops_set_to_do.insert(output_operation_place); } // the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in @@ -233,6 +235,12 @@ std::vector> InputModel::InputModelTFImpl::determine_cu if (ops_done.count(current_operation_place) == 0) { bool can_add = true; auto input_count = current_operation_decoder->get_input_size(); + auto current_operation_type = current_operation_decoder->get_op_type(); + + if (current_operation_type == "NextIteration") { + // break the cycle created by NextIteration + input_count = 0; + } for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) { std::string producer_name; @@ -282,11 +290,9 @@ std::vector> InputModel::InputModelTFImpl::determine_cu // in case presence of NextIteration in the graph (or cycle created by other operation), // we break the cycle by outputs from the NextIteration operation // otherwise, the operations nodes in the cycle will be added to ops_to_do infinitely - if (!is_input && ops_done.count(producer_operation_place) == 0 && - ops_set_to_do.count(producer_operation_place) == 0) { + if (!is_input && ops_done.count(producer_operation_place) == 0) { can_add = false; ops_to_do.push(producer_operation_place); - ops_set_to_do.insert(producer_operation_place); } } diff --git a/src/frontends/tensorflow/tests/convert_model.cpp b/src/frontends/tensorflow/tests/convert_model.cpp index a3763759da5..e88df1dc867 100644 --- a/src/frontends/tensorflow/tests/convert_model.cpp +++ b/src/frontends/tensorflow/tests/convert_model.cpp @@ -13,6 +13,8 @@ using TFConvertModelTest = FrontEndConvertModelTest; static const std::vector models{ std::string("2in_2out/2in_2out.pb"), + std::string("forward_edge_model/forward_edge_model.pb"), + std::string("forward_edge_model2/forward_edge_model2.pb"), }; INSTANTIATE_TEST_SUITE_P(TFConvertModelTest, diff --git a/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model.pbtxt b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model.pbtxt new file mode 100644 index 00000000000..a9c3941e3bc --- /dev/null +++ b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model.pbtxt @@ -0,0 +1,79 @@ +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 2.0 + } + } + } +} +node { + name: "x" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 2 + } + dim { + size: 3 + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "x" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "add" + op: "AddV2" + input: "Relu" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "Mul" + op: "Mul" + input: "add" + input: "Relu" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} diff --git a/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model.py b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model.py new file mode 100644 index 00000000000..1f0e9d82f2c --- /dev/null +++ b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model.py @@ -0,0 +1,19 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import tensorflow.compat.v1 as tf + +tf.reset_default_graph() + +with tf.Session() as sess: + const2 = tf.constant(2.0, dtype=tf.float32) + x = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x') + relu = tf.nn.relu(x) + add = tf.add(relu, const2, name="add") + # it has forward-edge from relu to multiply + # i.e. edge skipping direct child + tf.multiply(add, relu) + + tf.global_variables_initializer() + tf.io.write_graph(sess.graph, '.', 'forward_edge_model.pb', as_text=False) diff --git a/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model2.pbtxt b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model2.pbtxt new file mode 100644 index 00000000000..a0afd76c7a8 --- /dev/null +++ b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model2.pbtxt @@ -0,0 +1,79 @@ +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 2.0 + } + } + } +} +node { + name: "x" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 2 + } + dim { + size: 3 + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "x" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "mul" + op: "Mul" + input: "Relu" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "Add" + op: "AddV2" + input: "Relu" + input: "mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} diff --git a/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model2.py b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model2.py new file mode 100644 index 00000000000..c9b5c1041d9 --- /dev/null +++ b/src/frontends/tensorflow/tests/test_models/models_pbtxt/forward_edge_model2.py @@ -0,0 +1,19 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import tensorflow.compat.v1 as tf + +tf.reset_default_graph() + +with tf.Session() as sess: + const2 = tf.constant(2.0, dtype=tf.float32) + x = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x') + relu = tf.nn.relu(x) + mul = tf.multiply(relu, const2, name="mul") + # it has forward-edge from relu to multiply + # i.e. edge skipping direct child + tf.add(relu, mul) + + tf.global_variables_initializer() + tf.io.write_graph(sess.graph, '.', 'forward_edge_model2.pb', as_text=False)