[TF FE] Break the cycle in the different way (#14480)
* [TF FE] Break the cycle in the different way Earlier solution was incorrect due to inproper handling of forward edges cases (edges going from parent to grand-child) for which topological sorting of nodes can be interrupted. Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Break the cycle by NextIteration inputs Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
b1700d97f1
commit
32ae862f99
@ -207,9 +207,12 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_place
|
|||||||
std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cut_nodes() const {
|
std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cut_nodes() const {
|
||||||
std::vector<std::shared_ptr<OpPlace>> topologically_sorted_ops;
|
std::vector<std::shared_ptr<OpPlace>> topologically_sorted_ops;
|
||||||
std::stack<std::shared_ptr<OpPlace>> ops_to_do;
|
std::stack<std::shared_ptr<OpPlace>> ops_to_do;
|
||||||
std::unordered_set<std::shared_ptr<OpPlace>> ops_set_to_do;
|
|
||||||
std::unordered_set<std::shared_ptr<OpPlace>> ops_done;
|
std::unordered_set<std::shared_ptr<OpPlace>> ops_done;
|
||||||
|
|
||||||
|
// TODO: implement logic to check direct cycles in the graph
|
||||||
|
// and break them
|
||||||
|
// probably not only NextIteration can generate cycles
|
||||||
|
|
||||||
for (const auto& output_place : m_outputs) {
|
for (const auto& output_place : m_outputs) {
|
||||||
FRONT_END_GENERAL_CHECK(output_place->get_names().size() > 0, "TensorPlace must have at least one name.");
|
FRONT_END_GENERAL_CHECK(output_place->get_names().size() > 0, "TensorPlace must have at least one name.");
|
||||||
auto output_place_name = output_place->get_names()[0];
|
auto output_place_name = output_place->get_names()[0];
|
||||||
@ -221,7 +224,6 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
|
|||||||
"Custom specified output is incorrect: " + output_place_name);
|
"Custom specified output is incorrect: " + output_place_name);
|
||||||
auto output_operation_place = m_op_places_map.at(operation_name);
|
auto output_operation_place = m_op_places_map.at(operation_name);
|
||||||
ops_to_do.push(output_operation_place);
|
ops_to_do.push(output_operation_place);
|
||||||
ops_set_to_do.insert(output_operation_place);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in
|
// the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in
|
||||||
@ -233,6 +235,12 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
|
|||||||
if (ops_done.count(current_operation_place) == 0) {
|
if (ops_done.count(current_operation_place) == 0) {
|
||||||
bool can_add = true;
|
bool can_add = true;
|
||||||
auto input_count = current_operation_decoder->get_input_size();
|
auto input_count = current_operation_decoder->get_input_size();
|
||||||
|
auto current_operation_type = current_operation_decoder->get_op_type();
|
||||||
|
|
||||||
|
if (current_operation_type == "NextIteration") {
|
||||||
|
// break the cycle created by NextIteration
|
||||||
|
input_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) {
|
for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) {
|
||||||
std::string producer_name;
|
std::string producer_name;
|
||||||
@ -282,11 +290,9 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
|
|||||||
// in case presence of NextIteration in the graph (or cycle created by other operation),
|
// in case presence of NextIteration in the graph (or cycle created by other operation),
|
||||||
// we break the cycle by outputs from the NextIteration operation
|
// we break the cycle by outputs from the NextIteration operation
|
||||||
// otherwise, the operations nodes in the cycle will be added to ops_to_do infinitely
|
// otherwise, the operations nodes in the cycle will be added to ops_to_do infinitely
|
||||||
if (!is_input && ops_done.count(producer_operation_place) == 0 &&
|
if (!is_input && ops_done.count(producer_operation_place) == 0) {
|
||||||
ops_set_to_do.count(producer_operation_place) == 0) {
|
|
||||||
can_add = false;
|
can_add = false;
|
||||||
ops_to_do.push(producer_operation_place);
|
ops_to_do.push(producer_operation_place);
|
||||||
ops_set_to_do.insert(producer_operation_place);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,6 +13,8 @@ using TFConvertModelTest = FrontEndConvertModelTest;
|
|||||||
|
|
||||||
static const std::vector<std::string> models{
|
static const std::vector<std::string> models{
|
||||||
std::string("2in_2out/2in_2out.pb"),
|
std::string("2in_2out/2in_2out.pb"),
|
||||||
|
std::string("forward_edge_model/forward_edge_model.pb"),
|
||||||
|
std::string("forward_edge_model2/forward_edge_model2.pb"),
|
||||||
};
|
};
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,
|
INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,
|
||||||
|
@ -0,0 +1,79 @@
|
|||||||
|
node {
|
||||||
|
name: "Const"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_FLOAT
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
float_val: 2.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "x"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shape"
|
||||||
|
value {
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
size: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Relu"
|
||||||
|
op: "Relu"
|
||||||
|
input: "x"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "add"
|
||||||
|
op: "AddV2"
|
||||||
|
input: "Relu"
|
||||||
|
input: "Const"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Mul"
|
||||||
|
op: "Mul"
|
||||||
|
input: "add"
|
||||||
|
input: "Relu"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
with tf.Session() as sess:
|
||||||
|
const2 = tf.constant(2.0, dtype=tf.float32)
|
||||||
|
x = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x')
|
||||||
|
relu = tf.nn.relu(x)
|
||||||
|
add = tf.add(relu, const2, name="add")
|
||||||
|
# it has forward-edge from relu to multiply
|
||||||
|
# i.e. edge skipping direct child
|
||||||
|
tf.multiply(add, relu)
|
||||||
|
|
||||||
|
tf.global_variables_initializer()
|
||||||
|
tf.io.write_graph(sess.graph, '.', 'forward_edge_model.pb', as_text=False)
|
@ -0,0 +1,79 @@
|
|||||||
|
node {
|
||||||
|
name: "Const"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_FLOAT
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
float_val: 2.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "x"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shape"
|
||||||
|
value {
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
size: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Relu"
|
||||||
|
op: "Relu"
|
||||||
|
input: "x"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "mul"
|
||||||
|
op: "Mul"
|
||||||
|
input: "Relu"
|
||||||
|
input: "Const"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Add"
|
||||||
|
op: "AddV2"
|
||||||
|
input: "Relu"
|
||||||
|
input: "mul"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
with tf.Session() as sess:
|
||||||
|
const2 = tf.constant(2.0, dtype=tf.float32)
|
||||||
|
x = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='x')
|
||||||
|
relu = tf.nn.relu(x)
|
||||||
|
mul = tf.multiply(relu, const2, name="mul")
|
||||||
|
# it has forward-edge from relu to multiply
|
||||||
|
# i.e. edge skipping direct child
|
||||||
|
tf.add(relu, mul)
|
||||||
|
|
||||||
|
tf.global_variables_initializer()
|
||||||
|
tf.io.write_graph(sess.graph, '.', 'forward_edge_model2.pb', as_text=False)
|
Loading…
Reference in New Issue
Block a user