[TF FE] Report a reason of no conversion of internal operations (#17534)
* [TF FE] Report a reason of no conversion of internal operations Some operations during translations can be temporarily converted to InternalOperation such as Const operation of string type for which we need to define more elaborated reason why it is represented as InternalOperation. Also, restrict instantiation of InternalOperation because instead user should use FrameworkNode. InternalOperation is a base class for internal operation types of TF FE that have extended API compare to FrameWorkNode. For all internal operation we defined a reason why it is not converted to OpenVINO opset that will be reported in TF FE if they are not gone finally. Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update src/frontends/tensorflow/tests/convert_unsupported.cpp * Correct a script for generation of the test model --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
aa932d341a
commit
1ded4ede41
@ -8,6 +8,7 @@
|
||||
#include "graph_iterator_proto.hpp"
|
||||
#include "graph_iterator_proto_txt.hpp"
|
||||
#include "graph_iterator_saved_model.hpp"
|
||||
#include "helper_ops/internal_operation.hpp"
|
||||
#include "helper_transforms/block_lstm_replacer.hpp"
|
||||
#include "helper_transforms/const_to_result_remover.hpp"
|
||||
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
|
||||
@ -37,7 +38,16 @@ void get_unsupported_operations_and_failures(const std::shared_ptr<Model>& model
|
||||
std::set<std::string>& unsupported_operations,
|
||||
std::unordered_map<std::string, std::string>& failures) {
|
||||
for (const auto& node : model->get_ordered_ops()) {
|
||||
if (const auto& fw_node = ov::as_type_ptr<FrameworkNode>(node)) {
|
||||
if (const auto& internal_op = std::dynamic_pointer_cast<InternalOperation>(node)) {
|
||||
// handle internal operations separately
|
||||
// which can have elaborated reason of unconverted operation
|
||||
// like Const of string type
|
||||
auto op_type = internal_op->get_no_conversion_reason();
|
||||
if (unsupported_operations.count(op_type) > 0) {
|
||||
continue;
|
||||
}
|
||||
unsupported_operations.insert(op_type);
|
||||
} else if (const auto& fw_node = ov::as_type_ptr<FrameworkNode>(node)) {
|
||||
auto op_type = fw_node->get_decoder()->get_op_type();
|
||||
// if this operation is encountered among unsupported operations
|
||||
// or conversion failures, skip it
|
||||
|
@ -225,3 +225,19 @@ TEST(FrontEndConvertModelTest, test_unsupported_resource_gather_translator) {
|
||||
FAIL() << "Conversion of the model with ResourceGather failed by wrong reason.";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FrontEndConvertModelTest, test_unsupported_operation_conversion_with_reason) {
|
||||
shared_ptr<Model> model = nullptr;
|
||||
try {
|
||||
model = convert_model("gather_with_string_table/gather_with_string_table.pb");
|
||||
FAIL() << "The model with Const of string type must not be converted.";
|
||||
} catch (const OpConversionFailure& error) {
|
||||
string error_message = error.what();
|
||||
string ref_message =
|
||||
"[TensorFlow Frontend] Internal error, no translator found for operation(s): Const of string type";
|
||||
ASSERT_TRUE(error_message.find(ref_message) != string::npos);
|
||||
ASSERT_EQ(model, nullptr);
|
||||
} catch (...) {
|
||||
FAIL() << "Conversion of the model with Const of string type failed by wrong reason.";
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,18 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
params = tf.constant(["First sentence", "Second sentence sentence", "Third"], dtype=tf.string)
|
||||
indices = tf.compat.v1.placeholder(tf.int32, [2, 3, 5], name='data')
|
||||
axes = tf.constant([0], dtype=tf.int32)
|
||||
gather = tf.raw_ops.GatherV2(params=params, indices=indices, axis=0)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf.io.write_graph(sess.graph, os.path.join(sys.argv[1], "gather_with_string_table"),
|
||||
"gather_with_string_table.pb", False)
|
@ -30,7 +30,10 @@ public:
|
||||
float cell_clip,
|
||||
bool use_peephole,
|
||||
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, OutputVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b}, 7),
|
||||
: InternalOperation(decoder,
|
||||
OutputVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b},
|
||||
7,
|
||||
"BlockLSTM"),
|
||||
m_hidden_size(ov::Dimension::dynamic()),
|
||||
m_forget_bias(forget_bias),
|
||||
m_cell_clip(cell_clip),
|
||||
|
@ -23,7 +23,7 @@ public:
|
||||
const std::string& container,
|
||||
const std::string& shared_name,
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: InternalOperation(decoder, OutputVector{}, 1),
|
||||
: InternalOperation(decoder, OutputVector{}, 1, "FIFOQueue"),
|
||||
m_component_types(component_types),
|
||||
m_shapes(shapes),
|
||||
m_container(container),
|
||||
|
@ -24,7 +24,7 @@ public:
|
||||
const Output<Node>& b_ru,
|
||||
const Output<Node>& b_c,
|
||||
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, OutputVector{x, h_prev, w_ru, w_c, b_ru, b_c}, 4),
|
||||
: InternalOperation(decoder, OutputVector{x, h_prev, w_ru, w_c, b_ru, b_c}, 4, "GRUBlockCell"),
|
||||
m_hidden_size(ov::Dimension::dynamic()) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
@ -14,7 +14,8 @@ class HashTable : public InternalOperation {
|
||||
public:
|
||||
OPENVINO_OP("HashTable", "ov::frontend::tensorflow", InternalOperation);
|
||||
|
||||
HashTable(const std::shared_ptr<DecoderBase>& decoder = nullptr) : InternalOperation(decoder, OutputVector{}, 1) {
|
||||
HashTable(const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: InternalOperation(decoder, OutputVector{}, 1, "HashTable") {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -51,11 +51,26 @@ private:
|
||||
};
|
||||
|
||||
class InternalOperation : public ov::frontend::tensorflow::FrameworkNode {
|
||||
public:
|
||||
InternalOperation(const std::shared_ptr<DecoderBase>& decoder, const OutputVector& inputs, size_t num_outputs)
|
||||
protected:
|
||||
InternalOperation(const std::shared_ptr<DecoderBase>& decoder,
|
||||
const OutputVector& inputs,
|
||||
size_t num_outputs,
|
||||
const std::string& no_conversion_reason)
|
||||
: ov::frontend::tensorflow::FrameworkNode(decoder != nullptr ? decoder : std::make_shared<DecoderFake>(),
|
||||
inputs,
|
||||
num_outputs) {}
|
||||
num_outputs),
|
||||
m_no_conversion_reason(no_conversion_reason) {}
|
||||
|
||||
public:
|
||||
// get a reason why some operation is unable to convert to OpenVINO opset
|
||||
// we store this information for InternalOperation to elaborate the reason
|
||||
// for cases such as Constant node of string type
|
||||
std::string get_no_conversion_reason() const {
|
||||
return m_no_conversion_reason;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string m_no_conversion_reason;
|
||||
};
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
|
@ -22,7 +22,7 @@ public:
|
||||
const std::vector<ov::element::Type>& output_types,
|
||||
const std::vector<ov::PartialShape>& output_shapes,
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: InternalOperation(decoder, OutputVector{}, 1),
|
||||
: InternalOperation(decoder, OutputVector{}, 1, "Iterator"),
|
||||
m_shared_name(shared_name),
|
||||
m_container(container),
|
||||
m_output_types(output_types),
|
||||
|
@ -25,7 +25,8 @@ public:
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder,
|
||||
OutputVector{indices, values, dense_shape, default_value},
|
||||
4) {
|
||||
4,
|
||||
"SparseFillEmptyRows") {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -22,7 +22,10 @@ public:
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& segment_ids,
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data, indices, segment_ids}, 1) {
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder,
|
||||
OutputVector{data, indices, segment_ids},
|
||||
1,
|
||||
"SparseSegmentSum") {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
@ -33,7 +36,8 @@ public:
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder,
|
||||
OutputVector{data, indices, segment_ids, num_segments},
|
||||
1) {
|
||||
1,
|
||||
"SparseSegmentSum") {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -19,19 +19,19 @@ public:
|
||||
OPENVINO_OP("StringConstant", "ov::frontend::tensorflow::util", UnsupportedConstant);
|
||||
|
||||
StringConstant(ov::Any data, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: UnsupportedConstant(decoder),
|
||||
: UnsupportedConstant("Const of string type", decoder),
|
||||
m_data(data) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
StringConstant(std::string& str, const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: UnsupportedConstant(decoder),
|
||||
: UnsupportedConstant("Const of string type", decoder),
|
||||
m_data({str}) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
StringConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: UnsupportedConstant(decoder) {
|
||||
: UnsupportedConstant("Const of string type", decoder) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,13 @@ public:
|
||||
OPENVINO_OP("UnsupportedConstant", "ov::frontend::tensorflow::util", InternalOperation);
|
||||
|
||||
UnsupportedConstant(const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, {}, 1) {
|
||||
: InternalOperation(decoder, {}, 1, "Const of unknown type") {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
UnsupportedConstant(const std::string& no_conversion_reason,
|
||||
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
|
||||
: InternalOperation(decoder, {}, 1, no_conversion_reason) {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -18,9 +18,8 @@ class ComplexAbs : public ov::frontend::tensorflow::InternalOperation {
|
||||
public:
|
||||
OPENVINO_OP("ComplexAbs", "ov::frontend::tensorflow_lite::util", ov::frontend::tensorflow::InternalOperation);
|
||||
|
||||
ComplexAbs(const Output<Node>& data,
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder,OutputVector{data},1) {
|
||||
ComplexAbs(const Output<Node>& data, const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data}, 1, "ComplexAbs") {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,7 @@ public:
|
||||
Rfft2d(const Output<Node>& data,
|
||||
const Output<Node>& fft_length,
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder,OutputVector{data, fft_length},1) {
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data, fft_length}, 1, "Rfft2d") {
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -23,7 +23,7 @@ public:
|
||||
std::shared_ptr<ov::frontend::tensorflow_lite::QuantizationInfo> info,
|
||||
const element::Type& type,
|
||||
const std::shared_ptr<DecoderBase>& decoder = nullptr)
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data}, 1),
|
||||
: ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{data}, 1, "TFLQuantize"),
|
||||
m_info(info),
|
||||
m_type(type),
|
||||
m_original_type(type) {
|
||||
|
Loading…
Reference in New Issue
Block a user