[TF FE] Skip Assert operation and add test (#16484)
At the conversion stage we can't resolve Assert node because the condition is computed only during inference time. Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
17174a3839
commit
aaa4a4c210
@ -1,24 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "openvino/frontend/tensorflow/node_context.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_lookup_table_insert_op(const ov::frontend::tensorflow::NodeContext& node) {
|
||||
// auto-pruning of unsupported sub-graphs that contain
|
||||
// operations working with dictionaries
|
||||
default_op_checks(node, 3, {"LookupTableInsert", "LookupTableInsertV2"});
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -26,7 +26,6 @@ TF_OP_CONVERTER(translate_gru_block_cell_op);
|
||||
TF_OP_CONVERTER(translate_hash_table_op);
|
||||
TF_OP_CONVERTER(translate_iterator_get_next_op);
|
||||
TF_OP_CONVERTER(translate_iterator_op);
|
||||
TF_OP_CONVERTER(translate_lookup_table_insert_op);
|
||||
TF_OP_CONVERTER(translate_partitioned_call_op);
|
||||
TF_OP_CONVERTER(translate_queue_dequeue_op);
|
||||
TF_OP_CONVERTER(translate_queue_dequeue_many_op);
|
||||
@ -105,7 +104,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"AddN", translate_add_n_op},
|
||||
{"ArgMax", translate_arg_max_op},
|
||||
{"ArgMin", translate_arg_min_op},
|
||||
{"Assert", translate_assert_op},
|
||||
{"Assert", translate_no_op},
|
||||
{"AvgPool", translate_avg_pool_op},
|
||||
{"AvgPool3D", translate_avg_pool_op},
|
||||
{"BatchMatMul", translate_batch_mat_mul_op},
|
||||
@ -164,8 +163,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"ListDiff", translate_list_diff_op},
|
||||
{"LogSoftmax", translate_log_softmax_op},
|
||||
{"Log1p", translate_log_1p_op},
|
||||
{"LookupTableInsert", translate_lookup_table_insert_op},
|
||||
{"LookupTableInsertV2", translate_lookup_table_insert_op},
|
||||
{"LookupTableInsert", translate_no_op},
|
||||
{"LookupTableInsertV2", translate_no_op},
|
||||
{"LRN", translate_lrn_op},
|
||||
{"MatMul", translate_mat_mul_op},
|
||||
{"MatrixDiag", translate_matrix_diag_op},
|
||||
|
@ -346,6 +346,7 @@ TEST_F(TransformationTestsF, ModelWithIteratorGetNextAndUnsupportedOp) {
|
||||
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x, y});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ModelWithMultioutputBodyGraphNode) {
|
||||
{ model = convert_model("partitioned_call2/partitioned_call2.pb"); }
|
||||
{
|
||||
@ -376,3 +377,13 @@ TEST_F(TransformationTestsF, ModelWithEmptyTensorListAndPushBack) {
|
||||
model_ref = make_shared<Model>(OutputVector{recover_item}, ParameterVector{x});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ModelWithAssertNode) {
|
||||
{ model = convert_model("model_with_assert/model_with_assert.pb"); }
|
||||
{
|
||||
auto x = make_shared<Parameter>(i32, PartialShape{Dimension::dynamic()});
|
||||
auto y = make_shared<Parameter>(i32, PartialShape{Dimension::dynamic()});
|
||||
auto add = make_shared<Add>(x, y);
|
||||
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x, y});
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,38 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
# model with Assert node generator
|
||||
#
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def main():
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
x = tf.compat.v1.placeholder(dtype=tf.int32, shape=[None], name='x')
|
||||
y = tf.compat.v1.placeholder(dtype=tf.int32, shape=[None], name='y')
|
||||
tf.raw_ops.AddV2(x=x, y=y)
|
||||
shape1 = tf.raw_ops.Shape(input=x)
|
||||
shape2 = tf.raw_ops.Shape(input=y)
|
||||
equal = tf.raw_ops.Equal(x=shape1, y=shape2)
|
||||
axis = tf.constant([0], dtype=tf.int32)
|
||||
all_equal = tf.raw_ops.All(input=equal, axis=axis)
|
||||
message = tf.constant("Shapes of operands are incompatible", dtype=tf.string)
|
||||
tf.raw_ops.Assert(condition=all_equal, data=[message])
|
||||
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "model_with_assert"), "model_with_assert.pb", False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -34,7 +34,6 @@ OP_T_CONVERTER(translate_direct_reduce_op);
|
||||
OP_CONVERTER(translate_add_n_op);
|
||||
OP_CONVERTER(translate_arg_max_op);
|
||||
OP_CONVERTER(translate_arg_min_op);
|
||||
OP_CONVERTER(translate_assert_op);
|
||||
OP_CONVERTER(translate_avg_pool_op);
|
||||
OP_CONVERTER(translate_batch_mat_mul_op);
|
||||
OP_CONVERTER(translate_batch_to_space_nd_op);
|
||||
|
@ -1,36 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_assert_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"Assert"});
|
||||
auto cond = node.get_input(0);
|
||||
auto cond_const = get_constant_from_source(cond);
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
cond_const,
|
||||
"[TensorFlow Frontend] The condition must be constant for further model conversion.");
|
||||
auto cond_values = cond_const->cast_vector<bool>();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
cond_values.size() == 1,
|
||||
"[TensorFlow Frontend] Incorrect model - the condition must have one element.");
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
cond_values[0],
|
||||
"[TensorFlow Frontend] The condition must be true for further model conversion.");
|
||||
return {};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
OutputVector translate_no_op(const NodeContext& node) {
|
||||
// the operation does nothing in terms of data generation
|
||||
default_op_checks(node, 0, {"NoOp", "SaveV2"});
|
||||
default_op_checks(node, 0, {"NoOp", "SaveV2", "Assert", "LookupTableInsert", "LookupTableInsertV2"});
|
||||
return {};
|
||||
}
|
||||
} // namespace op
|
||||
|
Loading…
Reference in New Issue
Block a user