[TF FE] Support multioutput body graph nodes (#16142)

This is a corner case because body graph nodes have named output ports.
This allows supporting a custom RetinaNet model.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev
2023-03-08 17:29:42 +04:00
committed by GitHub
parent 3dbea43ef1
commit f3e7e55968
13 changed files with 456 additions and 23 deletions

View File

@@ -50,6 +50,7 @@ public:
class FRONTEND_API DecoderBase {
public:
using OpTypeByName = std::unordered_map<std::string, std::string>;
/// \brief Get attribute value by name
///
/// \param name Attribute name
@@ -68,6 +69,17 @@ public:
std::string& producer_name,
size_t& producer_output_port_index) const = 0;
/// \brief Get a producer name and its output port index
///
/// \param input_port_idx Input port index by which data is consumed
/// \param producer_name A producer name
/// \param producer_output_port_index Output port index from which data is generated
/// \param op_type_by_name Map of operation names to their types
virtual void get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index,
const OpTypeByName& op_type_by_name) const = 0;
/// \brief Get operation type
virtual const std::string& get_op_type() const = 0;

View File

@@ -4,6 +4,7 @@
#include "decoder_argdef.hpp"
#include "decoder_proto.hpp"
#include "op_def.pb.h"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/frontend/tensorflow/special_types.hpp"
@@ -58,19 +59,16 @@ void DecoderArgDef::get_input_node(size_t input_port_idx,
// and output port is 2
FRONT_END_GENERAL_CHECK(m_op_type == "output_arg",
"[TensorFlow Frontend] Internal error: get_input_node is supported only for output_arg.");
auto first_colon = m_producer_name.find_first_of(":");
auto last_colon = m_producer_name.find_last_of(":");
if (first_colon != std::string::npos && last_colon != std::string::npos) {
producer_name = m_producer_name.substr(0, first_colon);
auto port_id = m_producer_name.substr(last_colon + 1);
FRONT_END_GENERAL_CHECK(!port_id.empty() && std::all_of(port_id.begin(), port_id.end(), ::isdigit),
"Port id is not specified or not a number. Value: ",
port_id);
producer_output_port_index = std::stoi(port_id);
return;
}
producer_name = m_producer_name;
producer_output_port_index = 0;
parse_producer_name(m_producer_name, producer_name, producer_output_port_index, {});
}
// Overload that additionally receives a map of operation names to operation
// types, used to resolve named output ports of multi-output body graph nodes
// (e.g. `producer:indices:0`). Only valid for `output_arg` decoders;
// input_port_idx is unused because an output_arg references exactly one
// producer, encoded in m_producer_name.
void DecoderArgDef::get_input_node(size_t input_port_idx,
                                   std::string& producer_name,
                                   size_t& producer_output_port_index,
                                   const OpTypeByName& op_type_by_name) const {
    // Guard: only output_arg nodes carry a producer reference to parse.
    FRONT_END_GENERAL_CHECK(m_op_type == "output_arg",
                            "[TensorFlow Frontend] Internal error: get_input_node is supported only for output_arg.");
    // Delegate parsing of `producer_name[:port_name]:port_idx` to the shared helper.
    parse_producer_name(m_producer_name, producer_name, producer_output_port_index, op_type_by_name);
}
ov::Any DecoderArgDef::get_attribute(const std::string& name) const {

View File

@@ -38,6 +38,11 @@ public:
std::string& producer_name,
size_t& producer_output_port_index) const override;
void get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index,
const OpTypeByName& op_type_by_name) const override;
const std::string& get_op_type() const override;
const std::string& get_op_name() const override;

View File

@@ -285,16 +285,53 @@ size_t DecoderProto::get_input_size() const {
return m_node_def->input_size();
}
void DecoderProto::get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const {
// Body graph nodes may have two colons `:`, for example,
// producer_name:z:2 means that producer operation name is `producer_name`
// and output port is 2
std::string producer_port_name = m_node_def->input(static_cast<int>(input_port_idx));
void parse_producer_name(const std::string& producer_port_name,
std::string& producer_name,
size_t& producer_output_port_index,
const DecoderBase::OpTypeByName& op_type_by_name) {
using OutputPortIdxMax = std::unordered_map<std::string, int>;
// create a table of operation types and their output ports
// for which we specify output port indices manually;
// it mainly affects operations with multiple outputs.
// This information is extracted from tensorflow/core/ops/*.cc files
const OutputPortIdxMax output_port_idx_map = {
{"TopK:indices", 1},
{"TopKV2:indices", 1},
{"CTCGreedyDecoder:decoded_values", 1},
{"CTCGreedyDecoder:decoded_shape", 2},
{"CTCGreedyDecoder:log_probability", 3},
{"CTCGreedyDecoder:log_probability", 3},
{"FusedBatchNorm:batch_mean", 1},
{"FusedBatchNorm:batch_variance", 2},
{"FusedBatchNormV2:batch_mean", 1},
{"FusedBatchNormV2:batch_variance", 2},
{"FusedBatchNormV3:batch_mean", 1},
{"FusedBatchNormV3:batch_variance", 2},
};
// Body graph nodes may have input names with two colons `:`, for example,
// `TopKV2Name:indices:0` means that producer operation name is `TopKV2Name`
// the middle name is output port name of the producer `indices` that means
// the second output port of TopKV2 is used.
// The first output port of TopKV2 is described as `TopKV2Name:values:0`
auto first_colon = producer_port_name.find_first_of(":");
auto last_colon = producer_port_name.find_last_of(":");
if (first_colon != std::string::npos && last_colon != std::string::npos) {
if (first_colon != std::string::npos && first_colon < last_colon) {
// we have at least two colons producer_name:output_port_name:port_idx
producer_name = producer_port_name.substr(0, first_colon);
auto port_id = producer_port_name.substr(last_colon + 1);
auto port_name = producer_port_name.substr(first_colon + 1, last_colon - first_colon - 1);
FRONT_END_GENERAL_CHECK(!port_id.empty() && std::all_of(port_id.begin(), port_id.end(), ::isdigit),
"Port id is not specified or not a number. Value: ",
port_id);
producer_output_port_index = std::stoi(port_id);
auto producer_op_type =
(op_type_by_name.count(producer_name) > 0) ? op_type_by_name.at(producer_name) : "Unknown";
auto producer_key = producer_op_type + ":" + port_name;
producer_output_port_index = output_port_idx_map.count(producer_key) > 0 ? output_port_idx_map.at(producer_key)
: producer_output_port_index;
return;
} else if (first_colon != std::string::npos) {
// just one colon case
producer_name = producer_port_name.substr(0, first_colon);
auto port_id = producer_port_name.substr(last_colon + 1);
FRONT_END_GENERAL_CHECK(!port_id.empty() && std::all_of(port_id.begin(), port_id.end(), ::isdigit),
@@ -307,6 +344,21 @@ void DecoderProto::get_input_node(size_t input_port_idx,
producer_output_port_index = 0;
}
// Resolve the producer node name and its output port index for the given
// input port. Delegates to parse_producer_name with an empty operation-type
// map, so named output ports (e.g. `node:indices:0`) fall back to the
// trailing numeric index.
//
// \param input_port_idx Input port index by which data is consumed
// \param producer_name [out] A producer name
// \param producer_output_port_index [out] Output port index from which data is generated
void DecoderProto::get_input_node(size_t input_port_idx,
                                  std::string& producer_name,
                                  size_t& producer_output_port_index) const {
    // Bind by const reference to avoid copying the proto-owned string:
    // the generated accessor NodeDef::input(int) returns const std::string&.
    const std::string& producer_port_name = m_node_def->input(static_cast<int>(input_port_idx));
    parse_producer_name(producer_port_name, producer_name, producer_output_port_index, {});
}
// Resolve the producer node name and its output port index for the given
// input port, using op_type_by_name to map named output ports of
// multi-output producers (e.g. `TopKV2Name:indices:0`) to numeric indices.
//
// \param input_port_idx Input port index by which data is consumed
// \param producer_name [out] A producer name
// \param producer_output_port_index [out] Output port index from which data is generated
// \param op_type_by_name Map of operation names to their types
void DecoderProto::get_input_node(size_t input_port_idx,
                                  std::string& producer_name,
                                  size_t& producer_output_port_index,
                                  const OpTypeByName& op_type_by_name) const {
    // Bind by const reference to avoid copying the proto-owned string:
    // the generated accessor NodeDef::input(int) returns const std::string&.
    const std::string& producer_port_name = m_node_def->input(static_cast<int>(input_port_idx));
    parse_producer_name(producer_port_name, producer_name, producer_output_port_index, op_type_by_name);
}
// Returns the operation type (the `op` field of the underlying NodeDef).
// The reference stays valid as long as the wrapped NodeDef is alive.
const std::string& DecoderProto::get_op_type() const {
    return m_node_def->op();
}

View File

@@ -18,6 +18,11 @@ namespace ov {
namespace frontend {
namespace tensorflow {
void parse_producer_name(const std::string& producer_port_name,
std::string& producer_name,
size_t& producer_output_port_index,
const DecoderBase::OpTypeByName& op_type_by_name);
class DecoderProto : public ov::frontend::tensorflow::DecoderBase {
public:
explicit DecoderProto(const ::tensorflow::NodeDef* node_def) : m_node_def(node_def) {}
@@ -30,6 +35,11 @@ public:
std::string& producer_name,
size_t& producer_output_port_index) const override;
void get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index,
const OpTypeByName& op_type_by_name) const override;
const std::string& get_op_type() const override;
const std::string& get_op_name() const override;

View File

@@ -84,6 +84,7 @@ void TranslateSession::inject_body_model(std::shared_ptr<ov::Model> body_model,
void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model,
std::shared_ptr<ov::Model>& ov_model) {
DecoderBase::OpTypeByName op_type_by_name;
OpMap ng_op_map;
ov::ParameterVector params;
ov::ResultVector results;
@@ -130,6 +131,7 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
for (const auto& operation_place : operation_places) {
auto operation_decoder = operation_place->get_decoder();
auto operation_name = operation_place->get_names()[0];
op_type_by_name[operation_name] = operation_decoder->get_op_type();
// output for parameter nodes has been already generated
if (ng_op_map.count(operation_name)) {
continue;
@@ -151,7 +153,7 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
std::string producer_name;
size_t producer_port_idx;
try {
operation_decoder->get_input_node(input_port_idx, producer_name, producer_port_idx);
operation_decoder->get_input_node(input_port_idx, producer_name, producer_port_idx, op_type_by_name);
} catch (const std::exception&) {
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(input_port_idx) +
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
@@ -297,7 +299,7 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu
std::string producer_name;
size_t producer_port_idx;
try {
operation_decoder->get_input_node(port_index, producer_name, producer_port_idx);
operation_decoder->get_input_node(port_index, producer_name, producer_port_idx, op_type_by_name);
} catch (const std::exception&) {
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(port_index) +
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +

View File

@@ -335,3 +335,18 @@ TEST_F(TransformationTestsF, ModelWithLookupTableOperations) {
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x});
}
}
// Checks conversion of a body graph node with named output ports:
// the inner function of partitioned_call2.pb consumes `TopKV2:indices:0`,
// i.e. the second output of a multi-output TopKV2 node.
TEST_F(TransformationTestsF, ModelWithMultioutputBodyGraphNode) {
    { model = convert_model("partitioned_call2/partitioned_call2.pb"); }
    {
        // Reference model: (x - y) -> TopK(k=3), then add 10 to the
        // `indices` output (output port 1), mirroring the TF function body.
        auto x = make_shared<Parameter>(i32, Shape{5});
        auto y = make_shared<Parameter>(i32, Shape{5});
        auto sub = make_shared<Subtract>(x, y);
        auto const_three = make_shared<Constant>(i32, Shape{}, 3);
        auto const_ten = make_shared<Constant>(i32, Shape{}, 10);
        auto topk =
            make_shared<TopK>(sub, const_three, -1, op::v1::TopK::Mode::MAX, op::v1::TopK::SortType::SORT_VALUES, i32);
        // output(1) is the indices tensor — the named port `indices` in TF.
        auto add = make_shared<Add>(topk->output(1), const_ten);
        model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x, y});
    }
}

View File

@@ -0,0 +1,285 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 5
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "y"
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 5
}
}
}
}
}
node {
name: "sub"
op: "Sub"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "PartitionedCall"
op: "PartitionedCall"
input: "sub"
attr {
key: "Tin"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "_collective_manager_ids"
value {
list {
}
}
}
attr {
key: "_read_only_resource_inputs"
value {
list {
}
}
}
attr {
key: "config"
value {
s: ""
}
}
attr {
key: "config_proto"
value {
s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001\202\001\000"
}
}
attr {
key: "executor_type"
value {
s: ""
}
}
attr {
key: "f"
value {
func {
name: "__inference_second_func_17"
}
}
}
}
node {
name: "Identity"
op: "Identity"
input: "PartitionedCall"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
library {
function {
signature {
name: "__inference_second_func_17"
input_arg {
name: "x"
type: DT_INT32
}
output_arg {
name: "identity"
type: DT_INT32
}
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 3
}
}
}
experimental_debug_info {
original_node_names: "Const"
}
}
node_def {
name: "TopKV2"
op: "TopKV2"
input: "x"
input: "Const:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "sorted"
value {
b: true
}
}
experimental_debug_info {
original_node_names: "TopKV2"
}
}
node_def {
name: "Const_1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 10
}
}
}
experimental_debug_info {
original_node_names: "Const_1"
}
}
node_def {
name: "Add"
op: "AddV2"
input: "TopKV2:indices:0"
input: "Const_1:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "Add"
}
}
node_def {
name: "Identity"
op: "Identity"
input: "Add:z:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "Identity"
}
}
ret {
key: "identity"
value: "Identity:output:0"
}
attr {
key: "_construction_context"
value {
s: "kEagerRuntime"
}
}
arg_attr {
value {
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
}
}
}
}
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
}
}
}
}
versions {
producer: 808
min_consumer: 12
}

View File

@@ -0,0 +1,21 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
# Generates a frozen graph in which a body graph (second_func, lowered to a
# PartitionedCall function) consumes a named output port of a multi-output
# node: TopKV2's `indices`. This reproduces the corner case under test.
@tf.function
def second_func(x):
    # [1] selects the `indices` output — the second output port of TopKV2.
    x = tf.raw_ops.TopKV2(input=x, k=tf.constant(3, tf.int32))[1]
    x = tf.add(x, tf.constant(10, tf.int32))
    return x


@tf.function
def first_func(x, y):
    # The subtraction feeds the nested function, which appears in the graph
    # as a PartitionedCall node referencing second_func's function library entry.
    return second_func(x - y)


graph_def = first_func.get_concrete_function(tf.constant([1, 2, 3, 4, 5], dtype=tf.int32),
                                             tf.constant([0, 1, 1, 1, 1], dtype=tf.int32)).graph.as_graph_def()
tf.io.write_graph(graph_def, '.', 'partitioned_call2.pbtxt', as_text=True)

View File

@@ -35,6 +35,14 @@ public:
"Internal error: the get_input_node method of the fake node decoder is invoked.");
}
// Fake-node decoder stub: producer resolution must never be requested from a
// framework (fake) node, so this overload unconditionally raises a
// conversion error, matching the 3-argument overload above.
void get_input_node(size_t input_port_idx,
                    std::string& producer_name,
                    size_t& producer_output_port_index,
                    const OpTypeByName& op_type_by_name) const override {
    FRONT_END_OP_CONVERSION_CHECK(false,
                                  "Internal error: the get_input_node method of the fake node decoder is invoked.");
}
const std::string& get_op_type() const override {
// this method must not throw an exception since it is used by TF FE FrameworkNode constructor
return op_type;

View File

@@ -32,6 +32,13 @@ void DecoderFlatBuffer::get_input_node(size_t input_port_idx,
producer_output_port_index = input_tensor_idx;
}
// Named-output-port resolution via an operation-type map is a TensorFlow
// (protobuf) concept; the TFLite flatbuffer decoder has no named ports, so
// this overload is deliberately not implemented.
void DecoderFlatBuffer::get_input_node(size_t input_port_idx,
                                       std::string& producer_name,
                                       size_t& producer_output_port_index,
                                       const OpTypeByName& op_type_by_name) const {
    FRONT_END_NOT_IMPLEMENTED("get_input_node method with op_type_by_name map is not implemented for TFL FE.");
}
const std::string& DecoderFlatBuffer::get_op_type() const {
return m_type;
}

View File

@@ -55,6 +55,11 @@ public:
void get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const override;
void get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index,
const OpTypeByName& op_type_by_name) const override;
std::string get_output_tensor_name(size_t idx) const;
element::Type get_output_tensor_type(size_t idx) const;
std::string get_input_tensor_name(size_t idx) const;

View File

@@ -61,6 +61,19 @@ public:
m_decoder->get_input_node(input_port_idx, producer_name, producer_output_port_index);
}
/// \brief Get a producer name and its output port index
///
/// \param input_port_idx Input port index by which data is consumed
/// \param producer_name A producer name
/// \param producer_output_port_index Output port index from which data is generated
/// \param op_type_by_name Map of operation names to their types
void get_input_node(size_t input_port_idx,
                    std::string& producer_name,
                    size_t& producer_output_port_index,
                    const OpTypeByName& op_type_by_name) const override {
    // Named output ports exist only in the TF protobuf format; the TFLite
    // decoder intentionally rejects this overload.
    FRONT_END_NOT_IMPLEMENTED("get_input_node method with op_type_by_name map is not implemented for TFL FE.");
}
/// \brief Get operation type
const std::string& get_op_type() const override {
if (m_type.empty())