[TF FE] Break the cycle for NextIteration (#14063)
* [TF FE] Break the cycle for NextIteration
* Fix code-style
* Better to cut by outputs of operation creating the cycle
* Remove extra thing
* Add test model
* Add test for conversion of TF1 While
* Apply code-review feedback: use pointers and correct error message
* Remove extra check

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
parent: 9d8a03f90c
commit: 384a961793
@@ -171,6 +171,7 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::get_op_place
 std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cut_nodes() const {
     std::vector<std::shared_ptr<OpPlace>> topologically_sorted_ops;
     std::stack<std::shared_ptr<OpPlace>> ops_to_do;
+    std::unordered_set<std::shared_ptr<OpPlace>> ops_set_to_do;
     std::unordered_set<std::shared_ptr<OpPlace>> ops_done;

     for (const auto& output_place : m_outputs) {
@@ -184,6 +185,7 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
                                 "Custom specified output is incorrect: " + output_place_name);
         auto output_operation_place = m_op_places_map.at(operation_name);
         ops_to_do.push(output_operation_place);
+        ops_set_to_do.insert(output_operation_place);
     }

     // the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in
@@ -195,6 +197,7 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
         if (ops_done.count(current_operation_place) == 0) {
             bool can_add = true;
             auto input_count = current_operation_decoder->get_input_size();

             for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) {
                 std::string producer_name;
                 size_t producer_output_port_idx;
@@ -240,9 +243,14 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::determine_cu
                 is_input |= tensor_place->is_input();
             }

-            if (!is_input && ops_done.count(producer_operation_place) == 0) {
+            // in case presence of NextIteration in the graph (or cycle created by other operation),
+            // we break the cycle by outputs from the NextIteration operation
+            // otherwise, the operations nodes in the cycle will be added to ops_to_do infinitely
+            if (!is_input && ops_done.count(producer_operation_place) == 0 &&
+                ops_set_to_do.count(producer_operation_place) == 0) {
                 can_add = false;
                 ops_to_do.push(producer_operation_place);
+                ops_set_to_do.insert(producer_operation_place);
             }
         }

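The comments in the last hunk summarize the fix: every operation pushed onto ops_to_do is also recorded in ops_set_to_do, and a producer is re-pushed only when it is neither finished nor already scheduled, so a back edge such as NextIteration -> Merge can no longer make the stack grow without bound. The following minimal sketch shows the same idea on a reduced version of the TF1 While test model added later in this commit. It is an illustration of the technique only, not the frontend sources: OpPlace objects, decoders and input-tensor handling are replaced by plain strings, and the Switch/LoopCond/Const nodes are omitted.

// Minimal, self-contained illustration of breaking the NextIteration cycle
// during an iterative reverse DFS from the model outputs.
#include <iostream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
    // Producers of each node; "while/NextIteration" -> "while/Merge" is the
    // back edge that makes the graph cyclic.
    std::unordered_map<std::string, std::vector<std::string>> producers = {
        {"i", {}},
        {"while/Enter", {"i"}},
        {"while/Merge", {"while/Enter", "while/NextIteration"}},
        {"while/Add", {"while/Merge"}},
        {"while/NextIteration", {"while/Add"}},  // closes the cycle
        {"while/Exit", {"while/Merge"}},
    };

    std::vector<std::string> topologically_sorted_ops;
    std::stack<std::string> ops_to_do;
    std::unordered_set<std::string> ops_set_to_do;  // everything ever scheduled
    std::unordered_set<std::string> ops_done;

    ops_to_do.push("while/Exit");  // model output
    ops_set_to_do.insert("while/Exit");

    while (!ops_to_do.empty()) {
        const auto current = ops_to_do.top();
        if (ops_done.count(current)) {
            ops_to_do.pop();
            continue;
        }
        bool can_add = true;
        for (const auto& producer : producers.at(current)) {
            // Re-push a producer only if it is neither finished nor already
            // scheduled; the second check is what breaks the cycle.
            if (ops_done.count(producer) == 0 && ops_set_to_do.count(producer) == 0) {
                can_add = false;
                ops_to_do.push(producer);
                ops_set_to_do.insert(producer);
            }
        }
        if (can_add) {
            topologically_sorted_ops.push_back(current);
            ops_done.insert(current);
            ops_to_do.pop();
        }
    }

    // Prints: while/Add, while/NextIteration, i, while/Enter, while/Merge, while/Exit
    for (const auto& name : topologically_sorted_ops)
        std::cout << name << "\n";
}

Without the ops_set_to_do check, while/Add would re-push while/Merge, while/Merge would re-push while/NextIteration, and so on indefinitely; with it, exactly one edge inside the cycle is left unsatisfied and the traversal terminates.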
@@ -38,6 +38,8 @@ if (tensorflow_FOUND)
     set(TEST_TENSORFLOW_MODELS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TEST_TENSORFLOW_MODELS_DIRNAME}/)

     file(GLOB_RECURSE TENSORFLOW_GEN_SCRIPTS ${CMAKE_CURRENT_SOURCE_DIR}/test_models/gen_scripts/generate_*.py)
+    file(GLOB_RECURSE TENSORFLOW_MODELS_PBTXT ${CMAKE_CURRENT_SOURCE_DIR}/test_models/models_pbtxt/*.pbtxt)
+    list (APPEND TENSORFLOW_GEN_SCRIPTS ${TENSORFLOW_MODELS_PBTXT})
     file(GLOB_RECURSE TENSORFLOW_ALL_SCRIPTS ${CMAKE_CURRENT_SOURCE_DIR}/*.py)
     set(OUT_FILES "")
     foreach(GEN_SCRIPT ${TENSORFLOW_GEN_SCRIPTS})
@@ -10,6 +10,7 @@
 #include "tf_utils.hpp"
 #include "utils.hpp"

+using namespace std;
 using namespace ngraph;
 using namespace ov::frontend;

@@ -19,11 +20,11 @@ TEST(FrontEndConvertModelTest, test_unsupported_op) {
     InputModel::Ptr inputModel;
     ASSERT_NO_THROW(frontEnd = fem.load_by_framework(TF_FE));
     ASSERT_NE(frontEnd, nullptr);
-    auto model_filename = FrontEndTestUtils::make_model_path(std::string(TEST_TENSORFLOW_MODELS_DIRNAME) +
-                                                             std::string("relu_unsupported/relu_unsupported.pb"));
+    auto model_filename = FrontEndTestUtils::make_model_path(string(TEST_TENSORFLOW_MODELS_DIRNAME) +
+                                                             string("relu_unsupported/relu_unsupported.pb"));
     ASSERT_NO_THROW(inputModel = frontEnd->load(model_filename));
     ASSERT_NE(inputModel, nullptr);
-    std::shared_ptr<ngraph::Function> function;
+    shared_ptr<ngraph::Function> function;
     ASSERT_THROW(function = frontEnd->convert(inputModel), OpConversionFailure);
     ASSERT_EQ(function, nullptr);
     ASSERT_NO_THROW(function = frontEnd->decode(inputModel));
@@ -32,9 +33,35 @@ TEST(FrontEndConvertModelTest, test_unsupported_op) {
     ASSERT_THROW(frontEnd->convert(function), OpConversionFailure);

     for (auto& node : function->get_ordered_ops()) {
-        if (node->get_friendly_name() == "relu_0" && std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node)) {
-            function->replace_node(node, std::make_shared<opset6::Relu>(node->input(0).get_source_output()));
+        if (node->get_friendly_name() == "relu_0" && dynamic_pointer_cast<ov::op::util::FrameworkNode>(node)) {
+            function->replace_node(node, make_shared<opset6::Relu>(node->input(0).get_source_output()));
         }
     }
     ASSERT_NO_THROW(frontEnd->convert(function));
 }
+
+TEST(FrontEndConvertModelTest, test_unsupported_tf1_while) {
+    FrontEndManager fem;
+    FrontEnd::Ptr frontEnd;
+    InputModel::Ptr inputModel;
+    ASSERT_NO_THROW(frontEnd = fem.load_by_framework(TF_FE));
+    ASSERT_NE(frontEnd, nullptr);
+    auto model_filename = FrontEndTestUtils::make_model_path(string(TEST_TENSORFLOW_MODELS_DIRNAME) +
+                                                             string("model_tf1_while/model_tf1_while.pb"));
+    ASSERT_NO_THROW(inputModel = frontEnd->load(model_filename));
+    ASSERT_NE(inputModel, nullptr);
+    shared_ptr<ngraph::Function> function;
+
+    try {
+        function = frontEnd->convert(inputModel);
+        FAIL() << "TensorFlow 1 While is not supported in TF FE but conversion passed without errors. "
+                  "OpConversionFailure is expected.";
+    } catch (const OpConversionFailure& error) {
+        string error_message = error.what();
+        string ref_message = "No translator found for Enter node.";
+        ASSERT_TRUE(error_message.find(ref_message) != string::npos);
+        ASSERT_EQ(function, nullptr);
+    } catch (...) {
+        FAIL() << "Conversion of TensorFlow 1 While failed by wrong reason.";
+    }
+}
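The new test asserts the expected failure with an inline try/catch. If more models with expected conversion failures are added, the same check could be factored into a small helper; the sketch below is hypothetical (the helper name, signature and included headers are assumptions, not part of this commit).

#include <string>

#include <gtest/gtest.h>

#include "openvino/frontend/exception.hpp"  // assumed location of OpConversionFailure
#include "openvino/frontend/frontend.hpp"   // assumed location of FrontEnd / InputModel

// Hypothetical helper: expects convert() to throw OpConversionFailure whose
// message contains `expected_substring`.
static void expect_conversion_failure(const ov::frontend::FrontEnd::Ptr& front_end,
                                      const ov::frontend::InputModel::Ptr& input_model,
                                      const std::string& expected_substring) {
    try {
        front_end->convert(input_model);
        FAIL() << "Conversion was expected to fail, but it succeeded.";
    } catch (const ov::frontend::OpConversionFailure& error) {
        EXPECT_NE(std::string(error.what()).find(expected_substring), std::string::npos);
    } catch (...) {
        FAIL() << "Conversion failed with an unexpected exception type.";
    }
}

The test body would then reduce to a single call such as expect_conversion_failure(frontEnd, inputModel, "No translator found for Enter node.").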
@@ -3,12 +3,11 @@

 import os
 import subprocess
-
 import sys

 print(sys.argv)
 if len(sys.argv) < 4:
-    print("Script, output folder and mark file must be specified as arguments")
+    print("Script[model in pbtxt format], output folder and mark file must be specified as arguments")
     exit(1)

 gen_script = sys.argv[1]
@@ -16,7 +15,20 @@ out_folder = sys.argv[2]
 mark_file = sys.argv[3]

 print("Processing: {} ".format(gen_script))
-subprocess.run([sys.executable, gen_script, out_folder], env=os.environ)
+
+if gen_script.endswith('.py'):
+    subprocess.run([sys.executable, gen_script, out_folder], env=os.environ)
+elif gen_script.endswith('.pbtxt'):
+    import tensorflow.compat.v1 as tf
+    from google.protobuf import text_format
+
+    model_pbtxt = gen_script
+    with open(model_pbtxt, "r") as f:
+        model_name = os.path.basename(model_pbtxt).split('.')[0]
+        graph_def = tf.GraphDef()
+        text_format.Merge(f.read(), graph_def)
+        tf.import_graph_def(graph_def, name='')
+        tf.io.write_graph(graph_def, os.path.join(sys.argv[2], model_name), model_name + '.pb', False)

 # Create mark file indicating that script was executed
 with open(mark_file, "w") as fp:
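The elif branch above turns a text-format GraphDef (.pbtxt) into the binary .pb that the tests load, using text_format.Merge plus tf.io.write_graph. For comparison, the same text-to-binary step can be expressed with the protobuf C++ API; the sketch below is an assumption-heavy illustration (it presumes the generated tensorflow::GraphDef headers from a TensorFlow checkout are on the include path) and is not part of this commit.

// Hypothetical C++ equivalent of the .pbtxt -> .pb conversion step.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

#include <google/protobuf/text_format.h>
#include "tensorflow/core/framework/graph.pb.h"  // generated from graph.proto (assumed available)

int main(int argc, char** argv) {
    if (argc < 3) {
        std::cerr << "usage: pbtxt2pb <model.pbtxt> <model.pb>\n";
        return 1;
    }
    std::ifstream in(argv[1]);
    std::stringstream buffer;
    buffer << in.rdbuf();

    tensorflow::GraphDef graph_def;
    // Parse the human-readable text format ...
    if (!google::protobuf::TextFormat::ParseFromString(buffer.str(), &graph_def)) {
        std::cerr << "failed to parse " << argv[1] << "\n";
        return 1;
    }
    // ... and serialize it back as the binary wire format.
    std::ofstream out(argv[2], std::ios::binary);
    if (!graph_def.SerializeToOstream(&out)) {
        std::cerr << "failed to write " << argv[2] << "\n";
        return 1;
    }
    return 0;
}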
@@ -0,0 +1,219 @@
node {
  name: "i"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "j"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "while/Enter"
  op: "Enter"
  input: "i"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "frame_name"
    value {
      s: "while/while_context"
    }
  }
  attr {
    key: "is_constant"
    value {
      b: false
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 10
    }
  }
}
node {
  name: "while/Merge"
  op: "Merge"
  input: "while/Enter"
  input: "while/NextIteration"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Less/y"
  op: "Const"
  input: "^while/Merge"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 10
      }
    }
  }
}
node {
  name: "while/Less"
  op: "Less"
  input: "while/Merge"
  input: "while/Less/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/LoopCond"
  op: "LoopCond"
  input: "while/Less"
}
node {
  name: "while/Switch"
  op: "Switch"
  input: "while/Merge"
  input: "while/LoopCond"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@while/Merge"
      }
    }
  }
}
node {
  name: "while/Identity"
  op: "Identity"
  input: "while/Switch:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Add/y"
  op: "Const"
  input: "^while/Identity"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "while/Add"
  op: "Add"
  input: "while/Identity"
  input: "while/Add/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/NextIteration"
  op: "NextIteration"
  input: "while/Add"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Exit"
  op: "Exit"
  input: "while/Switch"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Add"
  op: "Add"
  input: "while/Exit"
  input: "j"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
@@ -0,0 +1,19 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import tensorflow.compat.v1 as tf

tf.reset_default_graph()

# Note: run this script in TensorFlow 1 environment to generate model_tf1_while.pbtxt
# The model with Switch, NextIteration and other TF1 While stuff cannot be generated in TF2 environment
with tf.Session() as sess:
    i = tf.placeholder(tf.int32, [], 'i')
    j = tf.placeholder(tf.int32, [], 'j')

    r = tf.while_loop(lambda i: tf.less(i, 10), lambda i: (tf.add(i, 1),), [i])
    tf.add(r, j)
    tf.global_variables_initializer()
    tf_net = sess.graph_def

    tf.io.write_graph(tf_net, './', 'model_tf1_while.pbtxt', True)