[TF FE] Report a reason of no conversion of internal operations (#17534)

* [TF FE] Report a reason of no conversion of internal operations

Some operations during translations can be temporarily converted to InternalOperation
such as Const operation of string type for which we need to define more elaborated reason
why it is represented as InternalOperation.
Also, restrict instantiation of InternalOperation because instead user should use FrameworkNode.
InternalOperation is a base class for internal operation types of TF FE that have
extended API compare to FrameWorkNode.
For all internal operation we defined a reason why it is not converted to OpenVINO opset
that will be reported in TF FE if they are not gone finally.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Update src/frontends/tensorflow/tests/convert_unsupported.cpp

* Correct a script for generation of the test model

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-05-15 15:17:04 +03:00 committed by GitHub
parent aa932d341a
commit 1ded4ede41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 94 additions and 21 deletions

View File

@ -8,6 +8,7 @@
#include "graph_iterator_proto.hpp" #include "graph_iterator_proto.hpp"
#include "graph_iterator_proto_txt.hpp" #include "graph_iterator_proto_txt.hpp"
#include "graph_iterator_saved_model.hpp" #include "graph_iterator_saved_model.hpp"
#include "helper_ops/internal_operation.hpp"
#include "helper_transforms/block_lstm_replacer.hpp" #include "helper_transforms/block_lstm_replacer.hpp"
#include "helper_transforms/const_to_result_remover.hpp" #include "helper_transforms/const_to_result_remover.hpp"
#include "helper_transforms/embedding_segments_feature_fusing.hpp" #include "helper_transforms/embedding_segments_feature_fusing.hpp"
@ -37,7 +38,16 @@ void get_unsupported_operations_and_failures(const std::shared_ptr<Model>& model
std::set<std::string>& unsupported_operations, std::set<std::string>& unsupported_operations,
std::unordered_map<std::string, std::string>& failures) { std::unordered_map<std::string, std::string>& failures) {
for (const auto& node : model->get_ordered_ops()) { for (const auto& node : model->get_ordered_ops()) {
if (const auto& fw_node = ov::as_type_ptr<FrameworkNode>(node)) { if (const auto& internal_op = std::dynamic_pointer_cast<InternalOperation>(node)) {
// handle internal operations separately
// which can have elaborated reason of unconverted operation
// like Const of string type
auto op_type = internal_op->get_no_conversion_reason();
if (unsupported_operations.count(op_type) > 0) {
continue;
}
unsupported_operations.insert(op_type);
} else if (const auto& fw_node = ov::as_type_ptr<FrameworkNode>(node)) {
auto op_type = fw_node->get_decoder()->get_op_type(); auto op_type = fw_node->get_decoder()->get_op_type();
// if this operation is encountered among unsupported operations // if this operation is encountered among unsupported operations
// or conversion failures, skip it // or conversion failures, skip it

View File

@ -225,3 +225,19 @@ TEST(FrontEndConvertModelTest, test_unsupported_resource_gather_translator) {
FAIL() << "Conversion of the model with ResourceGather failed by wrong reason."; FAIL() << "Conversion of the model with ResourceGather failed by wrong reason.";
} }
} }
TEST(FrontEndConvertModelTest, test_unsupported_operation_conversion_with_reason) {
shared_ptr<Model> model = nullptr;
try {
model = convert_model("gather_with_string_table/gather_with_string_table.pb");
FAIL() << "The model with Const of string type must not be converted.";
} catch (const OpConversionFailure& error) {
string error_message = error.what();
string ref_message =
"[TensorFlow Frontend] Internal error, no translator found for operation(s): Const of string type";
ASSERT_TRUE(error_message.find(ref_message) != string::npos);
ASSERT_EQ(model, nullptr);
} catch (...) {
FAIL() << "Conversion of the model with Const of string type failed by wrong reason.";
}
}

View File

@ -0,0 +1,18 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import sys
import tensorflow as tf
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
params = tf.constant(["First sentence", "Second sentence sentence", "Third"], dtype=tf.string)
indices = tf.compat.v1.placeholder(tf.int32, [2, 3, 5], name='data')
axes = tf.constant([0], dtype=tf.int32)
gather = tf.raw_ops.GatherV2(params=params, indices=indices, axis=0)
tf.compat.v1.global_variables_initializer()
tf.io.write_graph(sess.graph, os.path.join(sys.argv[1], "gather_with_string_table"),
"gather_with_string_table.pb", False)

View File

@ -30,7 +30,10 @@ public:
float cell_clip, float cell_clip,
bool use_peephole, bool use_peephole,
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>()) const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, OutputVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b}, 7), : InternalOperation(decoder,
OutputVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b},
7,
"BlockLSTM"),
m_hidden_size(ov::Dimension::dynamic()), m_hidden_size(ov::Dimension::dynamic()),
m_forget_bias(forget_bias), m_forget_bias(forget_bias),
m_cell_clip(cell_clip), m_cell_clip(cell_clip),

View File

@ -23,7 +23,7 @@ public:
const std::string& container, const std::string& container,
const std::string& shared_name, const std::string& shared_name,
const std::shared_ptr<DecoderBase>& decoder = nullptr) const std::shared_ptr<DecoderBase>& decoder = nullptr)
: InternalOperation(decoder, OutputVector{}, 1), : InternalOperation(decoder, OutputVector{}, 1, "FIFOQueue"),
m_component_types(component_types), m_component_types(component_types),
m_shapes(shapes), m_shapes(shapes),
m_container(container), m_container(container),

View File

@ -24,7 +24,7 @@ public:
const Output<Node>& b_ru, const Output<Node>& b_ru,
const Output<Node>& b_c, const Output<Node>& b_c,
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>()) const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, OutputVector{x, h_prev, w_ru, w_c, b_ru, b_c}, 4), : InternalOperation(decoder, OutputVector{x, h_prev, w_ru, w_c, b_ru, b_c}, 4, "GRUBlockCell"),
m_hidden_size(ov::Dimension::dynamic()) { m_hidden_size(ov::Dimension::dynamic()) {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -14,7 +14,8 @@ class HashTable : public InternalOperation {
public: public:
OPENVINO_OP("HashTable", "ov::frontend::tensorflow", InternalOperation); OPENVINO_OP("HashTable", "ov::frontend::tensorflow", InternalOperation);
HashTable(const std::shared_ptr<DecoderBase>& decoder = nullptr) : InternalOperation(decoder, OutputVector{}, 1) { HashTable(const std::shared_ptr<DecoderBase>& decoder = nullptr)
: InternalOperation(decoder, OutputVector{}, 1, "HashTable") {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -51,11 +51,26 @@ private:
}; };
class InternalOperation : public ov::frontend::tensorflow::FrameworkNode { class InternalOperation : public ov::frontend::tensorflow::FrameworkNode {
public: protected:
InternalOperation(const std::shared_ptr<DecoderBase>& decoder, const OutputVector& inputs, size_t num_outputs) InternalOperation(const std::shared_ptr<DecoderBase>& decoder,
const OutputVector& inputs,
size_t num_outputs,
const std::string& no_conversion_reason)
: ov::frontend::tensorflow::FrameworkNode(decoder != nullptr ? decoder : std::make_shared<DecoderFake>(), : ov::frontend::tensorflow::FrameworkNode(decoder != nullptr ? decoder : std::make_shared<DecoderFake>(),
inputs, inputs,
num_outputs) {} num_outputs),
m_no_conversion_reason(no_conversion_reason) {}
public:
// get a reason why some operation is unable to convert to OpenVINO opset
// we store this information for InternalOperation to elaborate the reason
// for cases such as Constant node of string type
std::string get_no_conversion_reason() const {
return m_no_conversion_reason;
}
private:
std::string m_no_conversion_reason;
}; };
} // namespace tensorflow } // namespace tensorflow
} // namespace frontend } // namespace frontend

View File

@ -22,7 +22,7 @@ public:
const std::vector<ov::element::Type>& output_types, const std::vector<ov::element::Type>& output_types,
const std::vector<ov::PartialShape>& output_shapes, const std::vector<ov::PartialShape>& output_shapes,
const std::shared_ptr<DecoderBase>& decoder = nullptr) const std::shared_ptr<DecoderBase>& decoder = nullptr)
: InternalOperation(decoder, OutputVector{}, 1), : InternalOperation(decoder, OutputVector{}, 1, "Iterator"),
m_shared_name(shared_name), m_shared_name(shared_name),
m_container(container), m_container(container),
m_output_types(output_types), m_output_types(output_types),

View File

@ -25,7 +25,8 @@ public:
const std::shared_ptr<DecoderBase>& decoder = nullptr) const std::shared_ptr<DecoderBase>& decoder = nullptr)
: ov::frontend::tensorflow::InternalOperation(decoder, : ov::frontend::tensorflow::InternalOperation(decoder,
OutputVector{indices, values, dense_shape, default_value}, OutputVector{indices, values, dense_shape, default_value},
4) { 4,
"SparseFillEmptyRows") {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -22,7 +22,10 @@ public:
const Output<Node>& indices, const Output<Node>& indices,
const Output<Node>& segment_ids, const Output<Node>& segment_ids,
const std::shared_ptr<DecoderBase>& decoder = nullptr) const std::shared_ptr<DecoderBase>& decoder = nullptr)
: ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data, indices, segment_ids}, 1) { : ov::frontend::tensorflow::InternalOperation(decoder,
OutputVector{data, indices, segment_ids},
1,
"SparseSegmentSum") {
validate_and_infer_types(); validate_and_infer_types();
} }
@ -33,7 +36,8 @@ public:
const std::shared_ptr<DecoderBase>& decoder = nullptr) const std::shared_ptr<DecoderBase>& decoder = nullptr)
: ov::frontend::tensorflow::InternalOperation(decoder, : ov::frontend::tensorflow::InternalOperation(decoder,
OutputVector{data, indices, segment_ids, num_segments}, OutputVector{data, indices, segment_ids, num_segments},
1) { 1,
"SparseSegmentSum") {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -19,19 +19,19 @@ public:
OPENVINO_OP("StringConstant", "ov::frontend::tensorflow::util", UnsupportedConstant); OPENVINO_OP("StringConstant", "ov::frontend::tensorflow::util", UnsupportedConstant);
StringConstant(ov::Any data, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>()) StringConstant(ov::Any data, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: UnsupportedConstant(decoder), : UnsupportedConstant("Const of string type", decoder),
m_data(data) { m_data(data) {
validate_and_infer_types(); validate_and_infer_types();
} }
StringConstant(std::string& str, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>()) StringConstant(std::string& str, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: UnsupportedConstant(decoder), : UnsupportedConstant("Const of string type", decoder),
m_data({str}) { m_data({str}) {
validate_and_infer_types(); validate_and_infer_types();
} }
StringConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>()) StringConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: UnsupportedConstant(decoder) { : UnsupportedConstant("Const of string type", decoder) {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -18,7 +18,13 @@ public:
OPENVINO_OP("UnsupportedConstant", "ov::frontend::tensorflow::util", InternalOperation); OPENVINO_OP("UnsupportedConstant", "ov::frontend::tensorflow::util", InternalOperation);
UnsupportedConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>()) UnsupportedConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, {}, 1) { : InternalOperation(decoder, {}, 1, "Const of unknown type") {
validate_and_infer_types();
}
UnsupportedConstant(const std::string& no_conversion_reason,
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, {}, 1, no_conversion_reason) {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -18,9 +18,8 @@ class ComplexAbs : public ov::frontend::tensorflow::InternalOperation {
public: public:
OPENVINO_OP("ComplexAbs", "ov::frontend::tensorflow_lite::util", ov::frontend::tensorflow::InternalOperation); OPENVINO_OP("ComplexAbs", "ov::frontend::tensorflow_lite::util", ov::frontend::tensorflow::InternalOperation);
ComplexAbs(const Output<Node>& data, ComplexAbs(const Output<Node>& data, const std::shared_ptr<DecoderBase>& decoder = nullptr)
const std::shared_ptr<DecoderBase>& decoder = nullptr) : ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data}, 1, "ComplexAbs") {
: ov::frontend::tensorflow::InternalOperation(decoder,OutputVector{data},1) {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -21,7 +21,7 @@ public:
Rfft2d(const Output<Node>& data, Rfft2d(const Output<Node>& data,
const Output<Node>& fft_length, const Output<Node>& fft_length,
const std::shared_ptr<DecoderBase>& decoder = nullptr) const std::shared_ptr<DecoderBase>& decoder = nullptr)
: ov::frontend::tensorflow::InternalOperation(decoder,OutputVector{data, fft_length},1) { : ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data, fft_length}, 1, "Rfft2d") {
validate_and_infer_types(); validate_and_infer_types();
} }

View File

@ -23,7 +23,7 @@ public:
std::shared_ptr<ov::frontend::tensorflow_lite::QuantizationInfo> info, std::shared_ptr<ov::frontend::tensorflow_lite::QuantizationInfo> info,
const element::Type& type, const element::Type& type,
const std::shared_ptr<DecoderBase>& decoder = nullptr) const std::shared_ptr<DecoderBase>& decoder = nullptr)
: ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data}, 1), : ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data}, 1, "TFLQuantize"),
m_info(info), m_info(info),
m_type(type), m_type(type),
m_original_type(type) { m_original_type(type) {