[TF FE] Support Assert with string tensors (#14640)
* [TF FE] Support Assert with string tensors Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update src/frontends/tensorflow/src/op/const.cpp * Update src/frontends/tensorflow/src/op/const.cpp * Apply code-review feedback: better to use UnsupportedConstant * Correct unit-test * Replace Op with Node Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
9c0ec2c9b4
commit
e44a4fc6d2
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "internal_operation.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
|
||||
class UnsupportedConstant : public InternalOperation {
|
||||
public:
|
||||
OPENVINO_OP("UnsupportedConstant", "ov::frontend::tensorflow::util", InternalOperation);
|
||||
|
||||
UnsupportedConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, {}, 1) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
void validate_and_infer_types() override {
|
||||
set_output_type(0, ov::element::undefined, ov::PartialShape::dynamic());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
36
src/frontends/tensorflow/src/op/assert.cpp
Normal file
36
src/frontends/tensorflow/src/op/assert.cpp
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_assert_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"Assert"});
|
||||
auto cond = node.get_input(0);
|
||||
auto cond_const = get_constant_from_source(cond);
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
cond_const,
|
||||
"[TensorFlow Frontend] The condition must be constant for further model conversion.");
|
||||
auto cond_values = cond_const->cast_vector<bool>();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
cond_values.size() == 1,
|
||||
"[TensorFlow Frontend] Incorrect model - the condition must have one element.");
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
cond_values[0],
|
||||
"[TensorFlow Frontend] The condition must be true for further model conversion.");
|
||||
return {};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -2,11 +2,13 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "helper_ops/unsupported_constant.hpp"
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
using namespace ov;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
@ -14,12 +16,18 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_const_op(const NodeContext& node) {
|
||||
auto tensor = node.get_attribute<ov::Tensor>("value");
|
||||
auto res = std::make_shared<ov::opset8::Constant>(tensor.get_element_type(), tensor.get_shape(), tensor.data());
|
||||
set_node_name(node.get_name(), res);
|
||||
return {res};
|
||||
auto ov_type = node.get_attribute<element::Type>("dtype");
|
||||
std::shared_ptr<Node> const_node;
|
||||
if (ov_type == element::undefined) {
|
||||
const_node = std::make_shared<UnsupportedConstant>();
|
||||
} else {
|
||||
auto tensor = node.get_attribute<Tensor>("value");
|
||||
const_node = std::make_shared<Constant>(tensor.get_element_type(), tensor.get_shape(), tensor.data());
|
||||
}
|
||||
set_node_name(node.get_name(), const_node);
|
||||
return {const_node};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
} // namespace ov
|
||||
|
@ -27,6 +27,7 @@ OP_T_CONVERTER(translate_direct_reduce_op);
|
||||
OP_CONVERTER(translate_add_n_op);
|
||||
OP_CONVERTER(translate_arg_max_op);
|
||||
OP_CONVERTER(translate_arg_min_op);
|
||||
OP_CONVERTER(translate_assert_op);
|
||||
OP_CONVERTER(translate_avg_pool_op);
|
||||
OP_CONVERTER(translate_batch_mat_mul_op);
|
||||
OP_CONVERTER(translate_batch_to_space_nd_op);
|
||||
@ -204,6 +205,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"AddN", translate_add_n_op},
|
||||
{"ArgMax", translate_arg_max_op},
|
||||
{"ArgMin", translate_arg_min_op},
|
||||
{"Assert", translate_assert_op},
|
||||
{"AvgPool", translate_avg_pool_op},
|
||||
{"AvgPool3D", translate_avg_pool_op},
|
||||
{"BatchMatMul", translate_batch_mat_mul_op},
|
||||
|
@ -4,13 +4,19 @@
|
||||
|
||||
#include <openvino/frontend/exception.hpp>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <transformations/common_optimizations/moc_transformations.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_common.hpp"
|
||||
#include "tf_utils.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::element;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::frontend;
|
||||
|
||||
namespace {
|
||||
@ -69,3 +75,19 @@ TEST(FrontEndConvertTrickyModels, model_with_output_shapes) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, AssertAndStringTensors) {
|
||||
{
|
||||
model = convert_model("string_tensors_model/string_tensors_model.pb");
|
||||
// TODO: investigate - why we have redundant nodes after the conversion
|
||||
manager.register_pass<pass::MOCTransformations>(false);
|
||||
}
|
||||
{
|
||||
auto x = make_shared<Parameter>(f32, Shape{2, 3});
|
||||
auto y = make_shared<Parameter>(f32, Shape{2, 3});
|
||||
auto cond = make_shared<Constant>(boolean, Shape{1, 1}, std::vector<bool>{true});
|
||||
auto select = make_shared<Select>(cond, x, y);
|
||||
|
||||
model_ref = make_shared<Model>(OutputVector{select}, ParameterVector{x, y});
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,167 @@
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_BOOL
|
||||
tensor_shape {
|
||||
}
|
||||
bool_val: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const_1"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "TensorFlow Frontend"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const_2"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "TensorFlow Frontend, ONNX Frontend"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const_3"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "TensorFlow Frontend, ONNX Frontend, PDPD Frontend"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Select"
|
||||
op: "Select"
|
||||
input: "Const"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Assert/Assert"
|
||||
op: "Assert"
|
||||
input: "Const"
|
||||
input: "Const_1"
|
||||
input: "Const_2"
|
||||
input: "Const_3"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
list {
|
||||
type: DT_STRING
|
||||
type: DT_STRING
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "summarize"
|
||||
value {
|
||||
i: 3
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
tf.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.Session() as sess:
|
||||
x = tf.placeholder(tf.float32, [2, 3], 'x')
|
||||
y = tf.placeholder(tf.float32, [2, 3], 'y')
|
||||
cond = tf.constant(True, dtype=tf.bool)
|
||||
message1 = tf.constant("TensorFlow Frontend", dtype=tf.string)
|
||||
message2 = tf.constant("TensorFlow Frontend, ONNX Frontend", dtype=tf.string)
|
||||
message3 = tf.constant("TensorFlow Frontend, ONNX Frontend, PDPD Frontend", dtype=tf.string)
|
||||
select = tf.where(cond, x, y)
|
||||
assert_op = tf.Assert(cond, [message1, message2, message3])
|
||||
|
||||
tf.global_variables_initializer()
|
||||
tf.io.write_graph(sess.graph, '.', 'string_tensors_model.pbtxt', as_text=True)
|
Loading…
Reference in New Issue
Block a user