[TF FE] Skip Assert operation and add test (#16484)

At the conversion stage we can't resolve Assert node because the condition
is computed only during inference time.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-03-23 11:49:46 +04:00 committed by GitHub
parent 17174a3839
commit aaa4a4c210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 53 additions and 66 deletions

View File

@ -1,24 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_op_table.hpp"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_lookup_table_insert_op(const ov::frontend::tensorflow::NodeContext& node) {
// auto-pruning of unsupported sub-graphs that contain
// operations working with dictionaries
default_op_checks(node, 3, {"LookupTableInsert", "LookupTableInsertV2"});
return {};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -26,7 +26,6 @@ TF_OP_CONVERTER(translate_gru_block_cell_op);
TF_OP_CONVERTER(translate_hash_table_op);
TF_OP_CONVERTER(translate_iterator_get_next_op);
TF_OP_CONVERTER(translate_iterator_op);
TF_OP_CONVERTER(translate_lookup_table_insert_op);
TF_OP_CONVERTER(translate_partitioned_call_op);
TF_OP_CONVERTER(translate_queue_dequeue_op);
TF_OP_CONVERTER(translate_queue_dequeue_many_op);
@ -105,7 +104,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"AddN", translate_add_n_op},
{"ArgMax", translate_arg_max_op},
{"ArgMin", translate_arg_min_op},
{"Assert", translate_assert_op},
{"Assert", translate_no_op},
{"AvgPool", translate_avg_pool_op},
{"AvgPool3D", translate_avg_pool_op},
{"BatchMatMul", translate_batch_mat_mul_op},
@ -164,8 +163,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"ListDiff", translate_list_diff_op},
{"LogSoftmax", translate_log_softmax_op},
{"Log1p", translate_log_1p_op},
{"LookupTableInsert", translate_lookup_table_insert_op},
{"LookupTableInsertV2", translate_lookup_table_insert_op},
{"LookupTableInsert", translate_no_op},
{"LookupTableInsertV2", translate_no_op},
{"LRN", translate_lrn_op},
{"MatMul", translate_mat_mul_op},
{"MatrixDiag", translate_matrix_diag_op},

View File

@ -346,6 +346,7 @@ TEST_F(TransformationTestsF, ModelWithIteratorGetNextAndUnsupportedOp) {
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x, y});
}
}
TEST_F(TransformationTestsF, ModelWithMultioutputBodyGraphNode) {
{ model = convert_model("partitioned_call2/partitioned_call2.pb"); }
{
@ -376,3 +377,13 @@ TEST_F(TransformationTestsF, ModelWithEmptyTensorListAndPushBack) {
model_ref = make_shared<Model>(OutputVector{recover_item}, ParameterVector{x});
}
}
TEST_F(TransformationTestsF, ModelWithAssertNode) {
{ model = convert_model("model_with_assert/model_with_assert.pb"); }
{
auto x = make_shared<Parameter>(i32, PartialShape{Dimension::dynamic()});
auto y = make_shared<Parameter>(i32, PartialShape{Dimension::dynamic()});
auto add = make_shared<Add>(x, y);
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x, y});
}
}

View File

@ -0,0 +1,38 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# model with Assert node generator
#
import os
import sys
import numpy as np
import tensorflow as tf
def main():
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(dtype=tf.int32, shape=[None], name='x')
y = tf.compat.v1.placeholder(dtype=tf.int32, shape=[None], name='y')
tf.raw_ops.AddV2(x=x, y=y)
shape1 = tf.raw_ops.Shape(input=x)
shape2 = tf.raw_ops.Shape(input=y)
equal = tf.raw_ops.Equal(x=shape1, y=shape2)
axis = tf.constant([0], dtype=tf.int32)
all_equal = tf.raw_ops.All(input=equal, axis=axis)
message = tf.constant("Shapes of operands are incompatible", dtype=tf.string)
tf.raw_ops.Assert(condition=all_equal, data=[message])
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "model_with_assert"), "model_with_assert.pb", False)
if __name__ == "__main__":
main()

View File

@ -34,7 +34,6 @@ OP_T_CONVERTER(translate_direct_reduce_op);
OP_CONVERTER(translate_add_n_op);
OP_CONVERTER(translate_arg_max_op);
OP_CONVERTER(translate_arg_min_op);
OP_CONVERTER(translate_assert_op);
OP_CONVERTER(translate_avg_pool_op);
OP_CONVERTER(translate_batch_mat_mul_op);
OP_CONVERTER(translate_batch_to_space_nd_op);

View File

@ -1,36 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <numeric>
#include "common_op_table.hpp"
#include "openvino/core/validation_util.hpp"
using namespace std;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_assert_op(const NodeContext& node) {
default_op_checks(node, 1, {"Assert"});
auto cond = node.get_input(0);
auto cond_const = get_constant_from_source(cond);
TENSORFLOW_OP_VALIDATION(node,
cond_const,
"[TensorFlow Frontend] The condition must be constant for further model conversion.");
auto cond_values = cond_const->cast_vector<bool>();
TENSORFLOW_OP_VALIDATION(node,
cond_values.size() == 1,
"[TensorFlow Frontend] Incorrect model - the condition must have one element.");
TENSORFLOW_OP_VALIDATION(node,
cond_values[0],
"[TensorFlow Frontend] The condition must be true for further model conversion.");
return {};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -15,7 +15,7 @@ namespace op {
OutputVector translate_no_op(const NodeContext& node) {
// the operation does nothing in terms of data generation
default_op_checks(node, 0, {"NoOp", "SaveV2"});
default_op_checks(node, 0, {"NoOp", "SaveV2", "Assert", "LookupTableInsert", "LookupTableInsertV2"});
return {};
}
} // namespace op