[TF FE] Break the cycle in a different way (#14480)

* [TF FE] Break the cycle in a different way

The earlier solution was incorrect due to improper handling of forward-edge cases
(edges going from a parent to a grand-child), for which the topological sorting
of nodes could be interrupted.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Break the cycle by NextIteration inputs

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-12-08 18:23:49 +04:00 committed by GitHub
parent b1700d97f1
commit 32ae862f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 5 deletions

View File

@ -207,9 +207,12 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_place
std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cut_nodes() const {
std::vector<std::shared_ptr<OpPlace>> topologically_sorted_ops;
std::stack<std::shared_ptr<OpPlace>> ops_to_do;
std::unordered_set<std::shared_ptr<OpPlace>> ops_set_to_do;
std::unordered_set<std::shared_ptr<OpPlace>> ops_done;
// TODO: implement logic to check direct cycles in the graph
// and break them
// probably not only NextIteration can generate cycles
for (const auto& output_place : m_outputs) {
FRONT_END_GENERAL_CHECK(output_place->get_names().size() > 0, "TensorPlace must have at least one name.");
auto output_place_name = output_place->get_names()[0];
@ -221,7 +224,6 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
"Custom specified output is incorrect: " + output_place_name);
auto output_operation_place = m_op_places_map.at(operation_name);
ops_to_do.push(output_operation_place);
ops_set_to_do.insert(output_operation_place);
}
// the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in
@ -233,6 +235,12 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
if (ops_done.count(current_operation_place) == 0) {
bool can_add = true;
auto input_count = current_operation_decoder->get_input_size();
auto current_operation_type = current_operation_decoder->get_op_type();
if (current_operation_type == "NextIteration") {
// break the cycle created by NextIteration
input_count = 0;
}
for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) {
std::string producer_name;
@ -282,11 +290,9 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
// in case presence of NextIteration in the graph (or cycle created by other operation),
// we break the cycle by outputs from the NextIteration operation
// otherwise, the operations nodes in the cycle will be added to ops_to_do infinitely
if (!is_input && ops_done.count(producer_operation_place) == 0 &&
ops_set_to_do.count(producer_operation_place) == 0) {
if (!is_input && ops_done.count(producer_operation_place) == 0) {
can_add = false;
ops_to_do.push(producer_operation_place);
ops_set_to_do.insert(producer_operation_place);
}
}

View File

@ -13,6 +13,8 @@ using TFConvertModelTest = FrontEndConvertModelTest;
static const std::vector<std::string> models{
std::string("2in_2out/2in_2out.pb"),
std::string("forward_edge_model/forward_edge_model.pb"),
std::string("forward_edge_model2/forward_edge_model2.pb"),
};
INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,

View File

@ -0,0 +1,79 @@
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "Relu"
op: "Relu"
input: "x"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "add"
op: "AddV2"
input: "Relu"
input: "Const"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Mul"
op: "Mul"
input: "add"
input: "Relu"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}

View File

@ -0,0 +1,19 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf

# Generate forward_edge_model.pb: a tiny graph containing a forward edge,
# i.e. an edge that goes from a node (Relu) past its direct child (add)
# straight to a grand-child (multiply).
tf.reset_default_graph()
with tf.Session() as sess:
    two = tf.constant(2.0, dtype=tf.float32)
    inp = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x')
    activated = tf.nn.relu(inp)
    shifted = tf.add(activated, two, name="add")
    # Forward edge: Relu's output is also consumed here, bypassing the
    # "add" level of the graph.
    tf.multiply(shifted, activated)
    tf.global_variables_initializer()
    tf.io.write_graph(sess.graph, '.', 'forward_edge_model.pb', as_text=False)

View File

@ -0,0 +1,79 @@
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "Relu"
op: "Relu"
input: "x"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "mul"
op: "Mul"
input: "Relu"
input: "Const"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Add"
op: "AddV2"
input: "Relu"
input: "mul"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}

View File

@ -0,0 +1,19 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf

# Generates forward_edge_model2.pb: a graph with a forward edge where
# Relu feeds both its direct child (mul) and its grand-child (add).
tf.reset_default_graph()
with tf.Session() as sess:
    const2 = tf.constant(2.0, dtype=tf.float32)
    x = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x')
    relu = tf.nn.relu(x)
    mul = tf.multiply(relu, const2, name="mul")
    # it has a forward edge from relu to add
    # i.e. an edge skipping the direct child (mul)
    tf.add(relu, mul)
    tf.global_variables_initializer()
    tf.io.write_graph(sess.graph, '.', 'forward_edge_model2.pb', as_text=False)