[TF FE] Break the cycle in the different way (#14480)
* [TF FE] Break the cycle in the different way Earlier solution was incorrect due to inproper handling of forward edges cases (edges going from parent to grand-child) for which topological sorting of nodes can be interrupted. Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Break the cycle by NextIteration inputs Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
b1700d97f1
commit
32ae862f99
@ -207,9 +207,12 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_place
|
||||
std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cut_nodes() const {
|
||||
std::vector<std::shared_ptr<OpPlace>> topologically_sorted_ops;
|
||||
std::stack<std::shared_ptr<OpPlace>> ops_to_do;
|
||||
std::unordered_set<std::shared_ptr<OpPlace>> ops_set_to_do;
|
||||
std::unordered_set<std::shared_ptr<OpPlace>> ops_done;
|
||||
|
||||
// TODO: implement logic to check direct cycles in the graph
|
||||
// and break them
|
||||
// probably not only NextIteration can generate cycles
|
||||
|
||||
for (const auto& output_place : m_outputs) {
|
||||
FRONT_END_GENERAL_CHECK(output_place->get_names().size() > 0, "TensorPlace must have at least one name.");
|
||||
auto output_place_name = output_place->get_names()[0];
|
||||
@ -221,7 +224,6 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
|
||||
"Custom specified output is incorrect: " + output_place_name);
|
||||
auto output_operation_place = m_op_places_map.at(operation_name);
|
||||
ops_to_do.push(output_operation_place);
|
||||
ops_set_to_do.insert(output_operation_place);
|
||||
}
|
||||
|
||||
// the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in
|
||||
@ -233,6 +235,12 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
|
||||
if (ops_done.count(current_operation_place) == 0) {
|
||||
bool can_add = true;
|
||||
auto input_count = current_operation_decoder->get_input_size();
|
||||
auto current_operation_type = current_operation_decoder->get_op_type();
|
||||
|
||||
if (current_operation_type == "NextIteration") {
|
||||
// break the cycle created by NextIteration
|
||||
input_count = 0;
|
||||
}
|
||||
|
||||
for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) {
|
||||
std::string producer_name;
|
||||
@ -282,11 +290,9 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
|
||||
// in case presence of NextIteration in the graph (or cycle created by other operation),
|
||||
// we break the cycle by outputs from the NextIteration operation
|
||||
// otherwise, the operations nodes in the cycle will be added to ops_to_do infinitely
|
||||
if (!is_input && ops_done.count(producer_operation_place) == 0 &&
|
||||
ops_set_to_do.count(producer_operation_place) == 0) {
|
||||
if (!is_input && ops_done.count(producer_operation_place) == 0) {
|
||||
can_add = false;
|
||||
ops_to_do.push(producer_operation_place);
|
||||
ops_set_to_do.insert(producer_operation_place);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,8 @@ using TFConvertModelTest = FrontEndConvertModelTest;
|
||||
|
||||
static const std::vector<std::string> models{
|
||||
std::string("2in_2out/2in_2out.pb"),
|
||||
std::string("forward_edge_model/forward_edge_model.pb"),
|
||||
std::string("forward_edge_model2/forward_edge_model2.pb"),
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,
|
||||
|
@ -0,0 +1,79 @@
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 2.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Relu"
|
||||
op: "Relu"
|
||||
input: "x"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "add"
|
||||
op: "AddV2"
|
||||
input: "Relu"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Mul"
|
||||
op: "Mul"
|
||||
input: "add"
|
||||
input: "Relu"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
with tf.Session() as sess:
|
||||
const2 = tf.constant(2.0, dtype=tf.float32)
|
||||
x = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x')
|
||||
relu = tf.nn.relu(x)
|
||||
add = tf.add(relu, const2, name="add")
|
||||
# it has forward-edge from relu to multiply
|
||||
# i.e. edge skipping direct child
|
||||
tf.multiply(add, relu)
|
||||
|
||||
tf.global_variables_initializer()
|
||||
tf.io.write_graph(sess.graph, '.', 'forward_edge_model.pb', as_text=False)
|
@ -0,0 +1,79 @@
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 2.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Relu"
|
||||
op: "Relu"
|
||||
input: "x"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "mul"
|
||||
op: "Mul"
|
||||
input: "Relu"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Add"
|
||||
op: "AddV2"
|
||||
input: "Relu"
|
||||
input: "mul"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
with tf.Session() as sess:
|
||||
const2 = tf.constant(2.0, dtype=tf.float32)
|
||||
x = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x')
|
||||
relu = tf.nn.relu(x)
|
||||
mul = tf.multiply(relu, const2, name="mul")
|
||||
# it has forward-edge from relu to multiply
|
||||
# i.e. edge skipping direct child
|
||||
tf.add(relu, mul)
|
||||
|
||||
tf.global_variables_initializer()
|
||||
tf.io.write_graph(sess.graph, '.', 'forward_edge_model2.pb', as_text=False)
|
Loading…
Reference in New Issue
Block a user