[TF FE] Break the cycle for NextIteration (#14063)

* [TF FE] Break the cycle for NextIteration

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix code-style

* Better to cut by outputs of operation creating the cycle

* Remove extra thing

* Add test model

* Add test for conversion of TF1 While

* Apply code-review feedback: use pointers and correct error message

* Remove extra check

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-11-19 03:31:11 +03:00 committed by GitHub
parent 9d8a03f90c
commit 384a961793
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 296 additions and 9 deletions

View File

@ -171,6 +171,7 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_place
std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cut_nodes() const {
std::vector<std::shared_ptr<OpPlace>> topologically_sorted_ops;
std::stack<std::shared_ptr<OpPlace>> ops_to_do;
std::unordered_set<std::shared_ptr<OpPlace>> ops_set_to_do;
std::unordered_set<std::shared_ptr<OpPlace>> ops_done;
for (const auto& output_place : m_outputs) {
@ -184,6 +185,7 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
"Custom specified output is incorrect: " + output_place_name);
auto output_operation_place = m_op_places_map.at(operation_name);
ops_to_do.push(output_operation_place);
ops_set_to_do.insert(output_operation_place);
}
// the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in
@ -195,6 +197,7 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
if (ops_done.count(current_operation_place) == 0) {
bool can_add = true;
auto input_count = current_operation_decoder->get_input_size();
for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) {
std::string producer_name;
size_t producer_output_port_idx;
@ -240,9 +243,14 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
is_input |= tensor_place->is_input();
}
if (!is_input && ops_done.count(producer_operation_place) == 0) {
// in case presence of NextIteration in the graph (or cycle created by other operation),
// we break the cycle by outputs from the NextIteration operation
// otherwise, the operations nodes in the cycle will be added to ops_to_do infinitely
if (!is_input && ops_done.count(producer_operation_place) == 0 &&
ops_set_to_do.count(producer_operation_place) == 0) {
can_add = false;
ops_to_do.push(producer_operation_place);
ops_set_to_do.insert(producer_operation_place);
}
}

View File

@ -38,6 +38,8 @@ if (tensorflow_FOUND)
set(TEST_TENSORFLOW_MODELS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TEST_TENSORFLOW_MODELS_DIRNAME}/)
file(GLOB_RECURSE TENSORFLOW_GEN_SCRIPTS ${CMAKE_CURRENT_SOURCE_DIR}/test_models/gen_scripts/generate_*.py)
file(GLOB_RECURSE TENSORFLOW_MODELS_PBTXT ${CMAKE_CURRENT_SOURCE_DIR}/test_models/models_pbtxt/*.pbtxt)
list (APPEND TENSORFLOW_GEN_SCRIPTS ${TENSORFLOW_MODELS_PBTXT})
file(GLOB_RECURSE TENSORFLOW_ALL_SCRIPTS ${CMAKE_CURRENT_SOURCE_DIR}/*.py)
set(OUT_FILES "")
foreach(GEN_SCRIPT ${TENSORFLOW_GEN_SCRIPTS})

View File

@ -10,6 +10,7 @@
#include "tf_utils.hpp"
#include "utils.hpp"
using namespace std;
using namespace ngraph;
using namespace ov::frontend;
@ -19,11 +20,11 @@ TEST(FrontEndConvertModelTest, test_unsupported_op) {
InputModel::Ptr inputModel;
ASSERT_NO_THROW(frontEnd = fem.load_by_framework(TF_FE));
ASSERT_NE(frontEnd, nullptr);
auto model_filename = FrontEndTestUtils::make_model_path(std::string(TEST_TENSORFLOW_MODELS_DIRNAME) +
std::string("relu_unsupported/relu_unsupported.pb"));
auto model_filename = FrontEndTestUtils::make_model_path(string(TEST_TENSORFLOW_MODELS_DIRNAME) +
string("relu_unsupported/relu_unsupported.pb"));
ASSERT_NO_THROW(inputModel = frontEnd->load(model_filename));
ASSERT_NE(inputModel, nullptr);
std::shared_ptr<ngraph::Function> function;
shared_ptr<ngraph::Function> function;
ASSERT_THROW(function = frontEnd->convert(inputModel), OpConversionFailure);
ASSERT_EQ(function, nullptr);
ASSERT_NO_THROW(function = frontEnd->decode(inputModel));
@ -32,9 +33,35 @@ TEST(FrontEndConvertModelTest, test_unsupported_op) {
ASSERT_THROW(frontEnd->convert(function), OpConversionFailure);
for (auto& node : function->get_ordered_ops()) {
if (node->get_friendly_name() == "relu_0" && std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node)) {
function->replace_node(node, std::make_shared<opset6::Relu>(node->input(0).get_source_output()));
if (node->get_friendly_name() == "relu_0" && dynamic_pointer_cast<ov::op::util::FrameworkNode>(node)) {
function->replace_node(node, make_shared<opset6::Relu>(node->input(0).get_source_output()));
}
}
ASSERT_NO_THROW(frontEnd->convert(function));
}
TEST(FrontEndConvertModelTest, test_unsupported_tf1_while) {
FrontEndManager fem;
FrontEnd::Ptr frontEnd;
InputModel::Ptr inputModel;
ASSERT_NO_THROW(frontEnd = fem.load_by_framework(TF_FE));
ASSERT_NE(frontEnd, nullptr);
auto model_filename = FrontEndTestUtils::make_model_path(string(TEST_TENSORFLOW_MODELS_DIRNAME) +
string("model_tf1_while/model_tf1_while.pb"));
ASSERT_NO_THROW(inputModel = frontEnd->load(model_filename));
ASSERT_NE(inputModel, nullptr);
shared_ptr<ngraph::Function> function;
try {
function = frontEnd->convert(inputModel);
FAIL() << "TensorFlow 1 While is not supported in TF FE but conversion passed without errors. "
"OpConversionFailure is expected.";
} catch (const OpConversionFailure& error) {
string error_message = error.what();
string ref_message = "No translator found for Enter node.";
ASSERT_TRUE(error_message.find(ref_message) != string::npos);
ASSERT_EQ(function, nullptr);
} catch (...) {
FAIL() << "Conversion of TensorFlow 1 While failed by wrong reason.";
}
}

View File

@ -3,12 +3,11 @@
import os
import subprocess
import sys
print(sys.argv)
if len(sys.argv) < 4:
print("Script, output folder and mark file must be specified as arguments")
print("Script[model in pbtxt format], output folder and mark file must be specified as arguments")
exit(1)
gen_script = sys.argv[1]
@ -16,7 +15,20 @@ out_folder = sys.argv[2]
mark_file = sys.argv[3]
print("Processing: {} ".format(gen_script))
subprocess.run([sys.executable, gen_script, out_folder], env=os.environ)
if gen_script.endswith('.py'):
subprocess.run([sys.executable, gen_script, out_folder], env=os.environ)
elif gen_script.endswith('.pbtxt'):
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
model_pbtxt = gen_script
with open(model_pbtxt, "r") as f:
model_name = os.path.basename(model_pbtxt).split('.')[0]
graph_def = tf.GraphDef()
text_format.Merge(f.read(), graph_def)
tf.import_graph_def(graph_def, name='')
tf.io.write_graph(graph_def, os.path.join(sys.argv[2], model_name), model_name + '.pb', False)
# Create mark file indicating that script was executed
with open(mark_file, "w") as fp:

View File

@ -0,0 +1,219 @@
node {
name: "i"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "j"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "while/Enter"
op: "Enter"
input: "i"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "frame_name"
value {
s: "while/while_context"
}
}
attr {
key: "is_constant"
value {
b: false
}
}
attr {
key: "parallel_iterations"
value {
i: 10
}
}
}
node {
name: "while/Merge"
op: "Merge"
input: "while/Enter"
input: "while/NextIteration"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Less/y"
op: "Const"
input: "^while/Merge"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 10
}
}
}
}
node {
name: "while/Less"
op: "Less"
input: "while/Merge"
input: "while/Less/y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/LoopCond"
op: "LoopCond"
input: "while/Less"
}
node {
name: "while/Switch"
op: "Switch"
input: "while/Merge"
input: "while/LoopCond"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_class"
value {
list {
s: "loc:@while/Merge"
}
}
}
}
node {
name: "while/Identity"
op: "Identity"
input: "while/Switch:1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Add/y"
op: "Const"
input: "^while/Identity"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
}
node {
name: "while/Add"
op: "Add"
input: "while/Identity"
input: "while/Add/y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/NextIteration"
op: "NextIteration"
input: "while/Add"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Exit"
op: "Exit"
input: "while/Switch"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "Add"
op: "Add"
input: "while/Exit"
input: "j"
attr {
key: "T"
value {
type: DT_INT32
}
}
}

View File

@ -0,0 +1,19 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
# Note: run this script in TensorFlow 1 environment to generate model_tf1_while.pbtxt
# The model with Switch, NextIteration and other TF1 While stuff cannot be generated in TF2 environment
with tf.Session() as sess:
i = tf.placeholder(tf.int32, [], 'i')
j = tf.placeholder(tf.int32, [], 'j')
r = tf.while_loop(lambda i: tf.less(i, 10), lambda i: (tf.add(i, 1),), [i])
tf.add(r, j)
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'model_tf1_while.pbtxt', True)