[TF FE] Support Assert with string tensors (#14640)

* [TF FE] Support Assert with string tensors

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Update src/frontends/tensorflow/src/op/const.cpp

* Update src/frontends/tensorflow/src/op/const.cpp

* Apply code-review feedback: better to use UnsupportedConstant

* Correct unit-test

* Replace Op with Node

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev 2022-12-15 19:28:34 +04:00 committed by GitHub
parent 9c0ec2c9b4
commit e44a4fc6d2
7 changed files with 291 additions and 5 deletions

@@ -0,0 +1,32 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <string>
#include <vector>

#include "internal_operation.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

class UnsupportedConstant : public InternalOperation {
public:
    OPENVINO_OP("UnsupportedConstant", "ov::frontend::tensorflow::util", InternalOperation);
    UnsupportedConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
        : InternalOperation(decoder, {}, 1) {
        validate_and_infer_types();
    }

    void validate_and_infer_types() override {
        set_output_type(0, ov::element::undefined, ov::PartialShape::dynamic());
    }
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov
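
A minimal sketch (not part of the change) of how other frontend code could recognize the placeholder produced by this helper: the dedicated node type and the undefined output element type it reports both identify it. The helper function name below is hypothetical, and the internal header is assumed to be on the caller's include path.

#include <memory>

#include "helper_ops/unsupported_constant.hpp"  // internal frontend header (assumed available here)
#include "openvino/core/node.hpp"
#include "openvino/core/type/element_type.hpp"

// Hypothetical helper for illustration only.
bool is_unsupported_constant(const std::shared_ptr<ov::Node>& node) {
    // Exact check: the node is the dedicated placeholder type.
    if (std::dynamic_pointer_cast<ov::frontend::tensorflow::UnsupportedConstant>(node)) {
        return true;
    }
    // Looser check: any node whose single output has the undefined element type,
    // which is what UnsupportedConstant::validate_and_infer_types() sets above.
    return node->get_output_size() == 1 && node->get_output_element_type(0) == ov::element::undefined;
}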

@@ -0,0 +1,36 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <numeric>

#include "op_table.hpp"
#include "openvino/core/validation_util.hpp"

using namespace std;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_assert_op(const NodeContext& node) {
    default_op_checks(node, 1, {"Assert"});
    auto cond = node.get_input(0);
    auto cond_const = get_constant_from_source(cond);
    TENSORFLOW_OP_VALIDATION(node,
                             cond_const,
                             "[TensorFlow Frontend] The condition must be constant for further model conversion.");
    auto cond_values = cond_const->cast_vector<bool>();
    TENSORFLOW_OP_VALIDATION(node,
                             cond_values.size() == 1,
                             "[TensorFlow Frontend] Incorrect model - the condition must have one element.");
    TENSORFLOW_OP_VALIDATION(node,
                             cond_values[0],
                             "[TensorFlow Frontend] The condition must be true for further model conversion.");
    return {};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
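
For context, a small self-contained sketch (not part of the commit) of the behaviour the translator above relies on: ov::get_constant_from_source() constant-folds the subgraph feeding an output and returns nullptr when the value is not statically known, which is what turns a data-dependent Assert condition into a conversion error.

#include <memory>
#include <vector>

#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset8.hpp"

int main() {
    using namespace ov::opset8;

    // Statically known condition: folding succeeds and the value can be inspected.
    auto static_cond = std::make_shared<Constant>(ov::element::boolean, ov::Shape{1}, std::vector<bool>{true});
    auto folded = ov::get_constant_from_source(static_cond->output(0));
    bool cond_is_true = folded && folded->cast_vector<bool>()[0];

    // Data-dependent condition: folding fails, so the translator would reject the model.
    auto runtime_cond = std::make_shared<Parameter>(ov::element::boolean, ov::Shape{1});
    auto not_folded = ov::get_constant_from_source(runtime_cond->output(0));

    return (cond_is_true && !not_folded) ? 0 : 1;
}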

@@ -2,11 +2,13 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_ops/unsupported_constant.hpp"
#include "op_table.hpp"
#include "openvino/opsets/opset8.hpp"
using namespace std;
using namespace ov::opset8;
using namespace ov;
namespace ov {
namespace frontend {
@@ -14,12 +16,18 @@ namespace tensorflow {
namespace op {

OutputVector translate_const_op(const NodeContext& node) {
-    auto tensor = node.get_attribute<ov::Tensor>("value");
-    auto res = std::make_shared<ov::opset8::Constant>(tensor.get_element_type(), tensor.get_shape(), tensor.data());
-    set_node_name(node.get_name(), res);
-    return {res};
+    auto ov_type = node.get_attribute<element::Type>("dtype");
+    std::shared_ptr<Node> const_node;
+    if (ov_type == element::undefined) {
+        const_node = std::make_shared<UnsupportedConstant>();
+    } else {
+        auto tensor = node.get_attribute<Tensor>("value");
+        const_node = std::make_shared<Constant>(tensor.get_element_type(), tensor.get_shape(), tensor.data());
+    }
+    set_node_name(node.get_name(), const_node);
+    return {const_node};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

@@ -27,6 +27,7 @@ OP_T_CONVERTER(translate_direct_reduce_op);
OP_CONVERTER(translate_add_n_op);
OP_CONVERTER(translate_arg_max_op);
OP_CONVERTER(translate_arg_min_op);
OP_CONVERTER(translate_assert_op);
OP_CONVERTER(translate_avg_pool_op);
OP_CONVERTER(translate_batch_mat_mul_op);
OP_CONVERTER(translate_batch_to_space_nd_op);
@@ -204,6 +205,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"AddN", translate_add_n_op},
        {"ArgMax", translate_arg_max_op},
        {"ArgMin", translate_arg_min_op},
        {"Assert", translate_assert_op},
        {"AvgPool", translate_avg_pool_op},
        {"AvgPool3D", translate_avg_pool_op},
        {"BatchMatMul", translate_batch_mat_mul_op},

@@ -4,13 +4,19 @@
#include <openvino/frontend/exception.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <transformations/common_optimizations/moc_transformations.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "test_common.hpp"
#include "tf_utils.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::element;
using namespace ov::opset10;
using namespace ov::frontend;
namespace {
@@ -69,3 +75,19 @@ TEST(FrontEndConvertTrickyModels, model_with_output_shapes) {
        }
    }
}

TEST_F(TransformationTestsF, AssertAndStringTensors) {
    {
        model = convert_model("string_tensors_model/string_tensors_model.pb");
        // TODO: investigate - why we have redundant nodes after the conversion
        manager.register_pass<pass::MOCTransformations>(false);
    }
    {
        auto x = make_shared<Parameter>(f32, Shape{2, 3});
        auto y = make_shared<Parameter>(f32, Shape{2, 3});
        auto cond = make_shared<Constant>(boolean, Shape{1, 1}, std::vector<bool>{true});
        auto select = make_shared<Select>(cond, x, y);
        model_ref = make_shared<Model>(OutputVector{select}, ParameterVector{x, y});
    }
}

@@ -0,0 +1,167 @@
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 2
        }
        dim {
          size: 3
        }
      }
    }
  }
}
node {
  name: "y"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 2
        }
        dim {
          size: 3
        }
      }
    }
  }
}
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        bool_val: true
      }
    }
  }
}
node {
  name: "Const_1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        tensor_shape {
        }
        string_val: "TensorFlow Frontend"
      }
    }
  }
}
node {
  name: "Const_2"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        tensor_shape {
        }
        string_val: "TensorFlow Frontend, ONNX Frontend"
      }
    }
  }
}
node {
  name: "Const_3"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        tensor_shape {
        }
        string_val: "TensorFlow Frontend, ONNX Frontend, PDPD Frontend"
      }
    }
  }
}
node {
  name: "Select"
  op: "Select"
  input: "Const"
  input: "x"
  input: "y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Assert/Assert"
  op: "Assert"
  input: "Const"
  input: "Const_1"
  input: "Const_2"
  input: "Const_3"
  attr {
    key: "T"
    value {
      list {
        type: DT_STRING
        type: DT_STRING
        type: DT_STRING
      }
    }
  }
  attr {
    key: "summarize"
    value {
      i: 3
    }
  }
}

@@ -0,0 +1,19 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import tensorflow.compat.v1 as tf

tf.reset_default_graph()
# Create the graph and model
with tf.Session() as sess:
    x = tf.placeholder(tf.float32, [2, 3], 'x')
    y = tf.placeholder(tf.float32, [2, 3], 'y')
    cond = tf.constant(True, dtype=tf.bool)
    message1 = tf.constant("TensorFlow Frontend", dtype=tf.string)
    message2 = tf.constant("TensorFlow Frontend, ONNX Frontend", dtype=tf.string)
    message3 = tf.constant("TensorFlow Frontend, ONNX Frontend, PDPD Frontend", dtype=tf.string)
    select = tf.where(cond, x, y)
    assert_op = tf.Assert(cond, [message1, message2, message3])
    tf.global_variables_initializer()
    tf.io.write_graph(sess.graph, '.', 'string_tensors_model.pbtxt', as_text=True)
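
Finally, a hedged C++ sketch (again, not part of the change) of converting the generated model outside the unit-test harness; the file name assumes the text graph written by the script above has already been converted to a binary string_tensors_model.pb, as the test infrastructure does.

#include <iostream>

#include "openvino/core/model.hpp"
#include "openvino/frontend/manager.hpp"

int main() {
    ov::frontend::FrontEndManager fe_manager;
    auto frontend = fe_manager.load_by_framework("tf");  // TensorFlow frontend
    auto input_model = frontend->load("string_tensors_model.pb");
    // With this change, the Assert over string tensors is validated and dropped during conversion.
    auto model = frontend->convert(input_model);
    std::cout << "Operations after conversion: " << model->get_ops().size() << std::endl;
    return 0;
}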