[TF FE] Support body graph conversion and its injection (#14841)
* [TF FE] Support body graph conversion and injection

  This is a base for further support of the StatefulPartitionedOp, While, and If operations.

  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
* Remove unused variable
* Remove artifacts of serialization experiment
* Apply code-review feedback: comments for decoder_argdef
* Create a map to cache function indices by names

  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
parent 0be53a66d2
commit 238c3234a9
@@ -21,6 +21,7 @@
 namespace ov {
 namespace frontend {
 namespace tensorflow {
+using CachedBodyModelsType = std::unordered_map<std::string, std::shared_ptr<const ov::Model>>;

 class TENSORFLOW_API FrontEnd : public ov::frontend::FrontEnd {
 public:
@@ -71,7 +72,8 @@ protected:
                          const std::string& model_name,
                          bool fail_fast,
                          bool no_conversion,
-                         std::shared_ptr<ov::Model>& ng_function) const;
+                         std::shared_ptr<ov::Model>& ov_model,
+                         const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const;

     TelemetryExtension::Ptr m_telemetry;
     std::vector<DecoderTransformationExtension::Ptr> m_transformation_extensions;
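For orientation, a minimal sketch (not part of this commit) of how a cache of this type is used: converted body models are stored once per TensorFlow function name, and each consumer clones the cached copy, because injection mutates the model it receives. `store_and_fetch` is a made-up name.

#include <memory>
#include <string>
#include <unordered_map>

#include "openvino/core/model.hpp"

using CachedBodyModelsType = std::unordered_map<std::string, std::shared_ptr<const ov::Model>>;

std::shared_ptr<ov::Model> store_and_fetch(const std::shared_ptr<CachedBodyModelsType>& cache,
                                           const std::string& func_name,
                                           const std::shared_ptr<ov::Model>& converted_body) {
    // keep an immutable master copy keyed by function name
    cache->insert({func_name, converted_body->clone()});
    // hand out a private mutable clone; injection rewires Parameters/Results in place
    return cache->at(func_name)->clone();
}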
@@ -36,6 +36,10 @@ public:

     /// \brief Destructor
     virtual ~GraphIterator() = default;

+    /// \brief Checks if the main model graph contains a function of the requested name in the library.
+    /// Returns a GraphIterator for this function, or nullptr if it does not exist.
+    virtual std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const = 0;
 };

 }  // namespace tensorflow
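A hedged usage sketch for the new virtual method (illustrative only, not from the commit): any GraphIterator can now hand out an iterator over a library function's body, and nullptr signals that the function is not defined.

#include <memory>
#include <string>

#include "openvino/frontend/tensorflow/graph_iterator.hpp"

// Walks a body graph if the library defines `func_name`; assumes an existing iterator.
void walk_body(const std::shared_ptr<ov::frontend::tensorflow::GraphIterator>& graph_iterator,
               const std::string& func_name) {
    auto body = graph_iterator->get_body_graph_iterator(func_name);
    if (!body) {
        return;  // no such function in the library
    }
    for (body->reset(); !body->is_end(); body->next()) {
        auto decoder = body->get_decoder();  // a DecoderArgDef or DecoderProto under the hood
        (void)decoder;
    }
}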
src/frontends/tensorflow/src/decoder_argdef.cpp (new file, 89 lines)
@@ -0,0 +1,89 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "decoder_argdef.hpp"

#include "op_def.pb.h"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/frontend/tensorflow/special_types.hpp"
#include "types.pb.h"

namespace ov {
namespace frontend {
namespace tensorflow {

namespace {
const std::map<::tensorflow::DataType, ov::element::Type>& TYPE_MAP() {
    static const std::map<::tensorflow::DataType, ov::element::Type> type_map{
        {::tensorflow::DataType::DT_BOOL, ov::element::boolean},
        {::tensorflow::DataType::DT_INT16, ov::element::i16},
        {::tensorflow::DataType::DT_INT32, ov::element::i32},
        {::tensorflow::DataType::DT_INT64, ov::element::i64},
        {::tensorflow::DataType::DT_HALF, ov::element::f16},
        {::tensorflow::DataType::DT_FLOAT, ov::element::f32},
        {::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
        {::tensorflow::DataType::DT_UINT8, ov::element::u8},
        {::tensorflow::DataType::DT_INT8, ov::element::i8},
        {::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16}};
    return type_map;
}
}  // namespace

size_t DecoderArgDef::get_input_size() const {
    FRONT_END_GENERAL_CHECK(m_op_type == "input_arg" || m_op_type == "output_arg",
                            "[TensorFlow Frontend] Internal error: Incorrect use of DecoderArgDef class.");
    if (m_op_type == "input_arg") {
        return 0;
    } else {
        return 1;
    }
}

const std::string& DecoderArgDef::get_op_type() const {
    FRONT_END_GENERAL_CHECK(m_op_type == "input_arg" || m_op_type == "output_arg",
                            "[TensorFlow Frontend] Internal error: Incorrect use of DecoderArgDef class.");
    return m_op_type;
}

const std::string& DecoderArgDef::get_op_name() const {
    return m_arg_def->name();
}

void DecoderArgDef::get_input_node(size_t input_port_idx,
                                   std::string& producer_name,
                                   size_t& producer_output_port_index) const {
    // Body graph nodes may have two colons `:`, for example,
    // producer_name:z:2 means that producer operation name is `producer_name`
    // and output port is 2
    FRONT_END_GENERAL_CHECK(m_op_type == "output_arg",
                            "[TensorFlow Frontend] Internal error: get_input_node is supported only for output_arg.");
    auto first_colon = m_producer_name.find_first_of(":");
    auto last_colon = m_producer_name.find_last_of(":");
    if (first_colon != std::string::npos && last_colon != std::string::npos) {
        producer_name = m_producer_name.substr(0, first_colon);
        auto port_id = m_producer_name.substr(last_colon + 1);
        FRONT_END_GENERAL_CHECK(!port_id.empty() && std::all_of(port_id.begin(), port_id.end(), ::isdigit),
                                "Port id is not specified or not a number. Value: ",
                                port_id);
        producer_output_port_index = std::stoi(port_id);
        return;
    }
    producer_name = m_producer_name;
    producer_output_port_index = 0;
}

ov::Any DecoderArgDef::get_attribute(const std::string& name) const {
    FRONT_END_GENERAL_CHECK(name == "type",
                            "[TensorFlow Frontend] Internal error: DecoderArgDef supports only `type` attribute.");
    if (TYPE_MAP().count(m_arg_def->type())) {
        return TYPE_MAP().at(m_arg_def->type());
    } else {
        // for all unsupported types return undefined type
        return ov::element::undefined;
    }
}

}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
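To make the colon convention above concrete, here is a self-contained sketch of the same parsing rule (a hypothetical helper, not from the commit):

#include <string>

// "producer:z:2" -> ("producer", 2); "producer:2" -> ("producer", 2); "producer" -> ("producer", 0)
void parse_producer_port(const std::string& tensor_name, std::string& producer, size_t& port) {
    auto first_colon = tensor_name.find_first_of(':');
    auto last_colon = tensor_name.find_last_of(':');
    if (first_colon != std::string::npos) {
        producer = tensor_name.substr(0, first_colon);
        // the real code first checks that the suffix is all digits before converting
        port = static_cast<size_t>(std::stoi(tensor_name.substr(last_colon + 1)));
        return;
    }
    producer = tensor_name;
    port = 0;
}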
src/frontends/tensorflow/src/decoder_argdef.hpp (new file, 53 lines)
@@ -0,0 +1,53 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>
#include <vector>

#include "openvino/frontend/tensorflow/decoder.hpp"

namespace tensorflow {
class OpDef_ArgDef;
}  // namespace tensorflow

namespace ov {
namespace frontend {
namespace tensorflow {

class DecoderArgDef : public ov::frontend::tensorflow::DecoderBase {
public:
    explicit DecoderArgDef(const ::tensorflow::OpDef_ArgDef* arg_def, const std::string& op_type)
        : m_arg_def(arg_def),
          m_op_type(op_type) {}

    explicit DecoderArgDef(const ::tensorflow::OpDef_ArgDef* arg_def,
                           const std::string& op_type,
                           const std::string& producer_name)
        : m_arg_def(arg_def),
          m_op_type(op_type),
          m_producer_name(producer_name) {}

    ov::Any get_attribute(const std::string& name) const override;

    size_t get_input_size() const override;

    void get_input_node(size_t input_port_idx,
                        std::string& producer_name,
                        size_t& producer_output_port_index) const override;

    const std::string& get_op_type() const override;

    const std::string& get_op_name() const override;

private:
    const ::tensorflow::OpDef_ArgDef* m_arg_def;
    const std::string m_op_type;
    const std::string m_producer_name;
};

}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
@@ -279,12 +279,15 @@ size_t DecoderProto::get_input_size() const {
 void DecoderProto::get_input_node(size_t input_port_idx,
                                   std::string& producer_name,
                                   size_t& producer_output_port_index) const {
-    // TODO: handle body graph nodes with a couple of columns
+    // Body graph nodes may have two colons `:`, for example,
+    // producer_name:z:2 means that producer operation name is `producer_name`
+    // and output port is 2
     std::string producer_port_name = m_node_def->input(static_cast<int>(input_port_idx));
-    auto delim_pos = producer_port_name.find(':');
-    if (delim_pos != std::string::npos) {
-        producer_name = producer_port_name.substr(0, delim_pos);
-        auto port_id = producer_port_name.substr(delim_pos + 1);
+    auto first_colon = producer_port_name.find_first_of(":");
+    auto last_colon = producer_port_name.find_last_of(":");
+    if (first_colon != std::string::npos && last_colon != std::string::npos) {
+        producer_name = producer_port_name.substr(0, first_colon);
+        auto port_id = producer_port_name.substr(last_colon + 1);
         FRONT_END_GENERAL_CHECK(!port_id.empty() && std::all_of(port_id.begin(), port_id.end(), ::isdigit),
                                 "Port id is not specified or not a number. Value: ",
                                 port_id);
@@ -33,8 +33,8 @@ void translate_framework_node(const std::shared_ptr<FrameworkNode>& node,
     auto translator_it = TRANSLATE_OP_MAP.find(type);
     FRONT_END_OP_CONVERSION_CHECK(translator_it != TRANSLATE_OP_MAP.end(), "No translator found for ", type, " node.");

-    ov::OutputVector ng_inputs = node->input_values();
-    NodeContext node_ctx(node->get_decoder(), ng_inputs);
+    ov::OutputVector ov_inputs = node->input_values();
+    NodeContext node_ctx(node->get_decoder(), ov_inputs);
     auto new_node_outputs = translator_it->second(node_ctx);

     auto new_output = new_node_outputs.begin();
@@ -45,6 +45,24 @@ void translate_framework_node(const std::shared_ptr<FrameworkNode>& node,
         old_output->replace(*new_output);
     }
 }

+void inject_body_model(std::shared_ptr<ov::Model> body_model,
+                       const std::string& operation_type,
+                       const ov::OutputVector& ov_inputs,
+                       ov::OutputVector& ov_outputs) {
+    ov_outputs.clear();
+    auto body_parameters = body_model->get_parameters();
+    FRONT_END_GENERAL_CHECK(body_parameters.size() == ov_inputs.size(),
+                            "[TensorFlow Error] Internal error or incorrect input models: number of "
+                            "inputs and arguments to the function " +
+                                operation_type + " do not match.");
+    for (size_t param_ind = 0; param_ind < body_parameters.size(); ++param_ind) {
+        body_parameters[param_ind]->output(0).replace(ov_inputs[param_ind]);
+    }
+    for (const auto& result_node : body_model->get_results()) {
+        ov_outputs.push_back(result_node->input_value(0));
+    }
+}
 }  // namespace

 FrontEnd::FrontEnd() : m_op_translators(tensorflow::op::get_supported_ops()) {}
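What injection does, as a worked sketch (illustrative, not from the commit; the one-node body is made up): every body Parameter output is replaced by the caller's matching input, and each body Result is unwrapped to the output that fed it, so the body's operations become ordinary nodes of the parent graph with no subgraph boundary left behind.

#include "openvino/core/model.hpp"
#include "openvino/opsets/opset8.hpp"

void injection_example(const ov::Output<ov::Node>& caller_input, ov::OutputVector& caller_outputs) {
    // a tiny body model: Parameter -> Sigmoid -> Result
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape::dynamic());
    auto sigmoid = std::make_shared<ov::opset8::Sigmoid>(param);
    auto body = std::make_shared<ov::Model>(ov::OutputVector{sigmoid}, ov::ParameterVector{param});

    // consumers of the body Parameter now read straight from the caller's input ...
    body->get_parameters()[0]->output(0).replace(caller_input);
    // ... and the caller consumes whatever fed each body Result
    for (const auto& result : body->get_results()) {
        caller_outputs.push_back(result->input_value(0));
    }
}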
@@ -53,7 +71,8 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
                                const std::string& model_name,
                                bool fail_fast,
                                bool no_conversion,
-                               std::shared_ptr<ov::Model>& ng_function) const {
+                               std::shared_ptr<ov::Model>& ov_model,
+                               const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const {
     // a map from operation names to generated OV Output<TFNodeDecoder>
     tensorflow::OpMap ng_op_map;

@@ -65,7 +84,7 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
     const auto& model_inputs = model_tf->get_inputs();
     const auto& model_outputs = model_tf->get_outputs();
     const auto& model_frozen_inputs = model_tf->get_tensor_values();
-    std::map<const std::string, const std::function<ov::OutputVector(const NodeContext&)>> translate_map;
+    TranslatorDictionaryType translate_map;

     const auto& TRANSLATE_OP_MAP = m_op_translators;
     if (no_conversion) {
@@ -119,7 +138,7 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
         }

         // prepare a list of OV node inputs for each node
-        ov::OutputVector ng_inputs;
+        ov::OutputVector ov_inputs;
         size_t operation_input_size = operation_decoder->get_input_size();

         if (operation_decoder->get_op_type() == "NextIteration") {
@@ -159,18 +178,18 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
                 const auto& input_outputs_vector = ng_op_map.at(std::to_string(input_port_idx) + ":" + operation_name);
                 FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
                                         "Input created with pruning must have one output");
-                ng_inputs.push_back(input_outputs_vector.at(0));
+                ov_inputs.push_back(input_outputs_vector.at(0));
             } else if (ng_op_map.count(producer_name + ":" + std::to_string(producer_port_idx))) {
                 const auto& input_outputs_vector =
                     ng_op_map.at(producer_name + ":" + std::to_string(producer_port_idx));
                 FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
                                         "Input created with pruning must have one output");
-                ng_inputs.push_back(input_outputs_vector.at(0));
+                ov_inputs.push_back(input_outputs_vector.at(0));
             } else if (ng_op_map.count(producer_name)) {
                 const auto& input_outputs_vector = ng_op_map.at(producer_name);
                 FRONT_END_GENERAL_CHECK(input_outputs_vector.size() > producer_port_idx,
                                         "Input created with pruning must have one output");
-                ng_inputs.push_back(input_outputs_vector.at(producer_port_idx));
+                ov_inputs.push_back(input_outputs_vector.at(producer_port_idx));
             } else {
                 FRONT_END_GENERAL_CHECK(false,
                                         "No input is found for node \"" + operation_name + "\" by port " +
@@ -179,14 +198,43 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
         }

         // generate OV node output vector for the current operation node
-        ov::OutputVector ng_outputs;
+        ov::OutputVector ov_outputs;
+        bool is_converted = false;
+        auto operation_type = operation_decoder->get_op_type();
         try {
-            FRONT_END_OP_CONVERSION_CHECK(translate_map.count(operation_decoder->get_op_type()),
-                                          "No translator found for " + operation_decoder->get_op_type() + " node.");
-            auto op_fun = &(translate_map[operation_decoder->get_op_type()]);
-            NodeContext node_context(operation_decoder, ng_inputs);
-            // generate OV node output vector using translator for given operation type
-            ng_outputs = (*op_fun)(node_context);
+            if (translate_map.count(operation_type)) {
+                auto translator = translate_map[operation_decoder->get_op_type()];
+                NodeContext node_context(operation_decoder, ov_inputs);
+                ov_outputs = translator(node_context);
+                is_converted = true;
+            } else if (cached_body_models->count(operation_type)) {
+                // check if such body graph has been converted before
+                // re-use it from the cache for further injection
+
+                // create new instance of the required body model
+                // since it will be modified by injection
+                auto cached_body_model = cached_body_models->at(operation_type);
+                auto body_model = cached_body_model->clone();
+                inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
+                is_converted = true;
+            } else if (auto body_input_model = model_tf->get_body_input_model(operation_type)) {
+                // try to find a function by name in the model library
+                std::shared_ptr<ov::Model> body_model;
+                translate_graph(body_input_model,
+                                operation_decoder->get_op_type(),
+                                fail_fast,
+                                no_conversion,
+                                body_model,
+                                cached_body_models);
+                // save new instance of body_model in the cache of body models
+                // before its injection into the parent graph
+                auto cached_body_model = body_model->clone();
+                cached_body_models->insert(std::make_pair(operation_type, cached_body_model));
+
+                inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
+                is_converted = true;
+            }
+            FRONT_END_OP_CONVERSION_CHECK(is_converted, "No translator found for " + operation_type + " node.");
         } catch (...) {
             if (fail_fast) {
                 // in case of decode, unsupported operation will be converted to FrameworkNode
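A note on the two clone() calls above (design rationale, implied by the comments in the hunk): injection mutates the model it is given, replacing Parameter outputs and unwrapping Results, so the cache must never hand out its master copy. A freshly translated body is therefore cloned once on the way into the cache, and a cache hit is cloned again on the way out, giving every call site of the same function its own private instance.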
@@ -198,15 +246,15 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
                 throw;
             } else {
                 auto ng_node = std::make_shared<FrameworkNode>(operation_decoder,
-                                                               ng_inputs,
+                                                               ov_inputs,
                                                                operation_place->get_output_ports().size());
                 set_node_name(operation_name, ng_node);
-                ng_outputs = ng_node->outputs();
+                ov_outputs = ng_node->outputs();
             }
         }

         // register OV node outputs in the map for new operation node
-        for (const auto& output : ng_outputs) {
+        for (const auto& output : ov_outputs) {
             if (auto result = std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
                 // do not add RetVal type operation to ng_op_map
                 results.push_back(result);
@@ -302,7 +350,7 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
     // TODO: reorder results and params according to indices given in RT info (if any)

     // create the OV Model
-    ng_function = std::make_shared<ov::Model>(results, params, model_name);
+    ov_model = std::make_shared<ov::Model>(results, params, model_name);
 }

 /// \brief Check if FrontEndTensorflow can recognize model from given parts
@@ -383,7 +431,8 @@ std::shared_ptr<ov::Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr
     }

     std::shared_ptr<ov::Model> f;
-    translate_graph(model_tf, "TensorFlow_Frontend_IR", true, false, f);
+    std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
+    translate_graph(model_tf, "TensorFlow_Frontend_IR", true, false, f, cached_body_models);
     normalize(f);

     for (const auto& node : f->get_ordered_ops()) {
@@ -416,7 +465,8 @@ std::shared_ptr<ov::Model> FrontEnd::convert_partially(const ov::frontend::Input
     }

     std::shared_ptr<ov::Model> f;
-    translate_graph(model_tf, "TensorFlow_Frontend_IR", false, false, f);
+    std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
+    translate_graph(model_tf, "TensorFlow_Frontend_IR", false, false, f, cached_body_models);
     normalize(f);
     return f;
 }

@@ -424,7 +474,8 @@ std::shared_ptr<ov::Model> FrontEnd::convert_partially(const ov::frontend::Input
 std::shared_ptr<ov::Model> FrontEnd::decode(const ov::frontend::InputModel::Ptr& model) const {
     auto model_tf = std::dynamic_pointer_cast<InputModel>(model);
     std::shared_ptr<ov::Model> f;
-    translate_graph(model_tf, "TensorFlow_Frontend_IR", false, true, f);
+    std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
+    translate_graph(model_tf, "TensorFlow_Frontend_IR", false, true, f, cached_body_models);
     return f;
 }
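Each entry point above creates its own cache, so converted bodies are reused across call sites within one conversion but never leak between conversions. The flag combinations, summarized as a sketch (semantics follow from translate_graph above):

// entry point            fail_fast   no_conversion
// convert()              true        false   // full conversion; failures throw
// convert_partially()    false       false   // unsupported ops kept as FrameworkNode
// decode()               false       true    // no translation; ops stay as FrameworkNode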
@@ -6,9 +6,9 @@

 #include <fstream>

+#include "decoder_argdef.hpp"
 #include "decoder_proto.hpp"
 #include "graph.pb.h"
 #include "node_def.pb.h"
 #include "openvino/frontend/exception.hpp"
-#include "openvino/frontend/tensorflow/decoder.hpp"
 #include "openvino/frontend/tensorflow/graph_iterator.hpp"
@@ -18,21 +18,70 @@ namespace frontend {
 namespace tensorflow {

 class GraphIteratorProto : public GraphIterator {
-    std::vector<const ::tensorflow::NodeDef*> m_nodes;
-    size_t node_index = 0;
     std::shared_ptr<::tensorflow::GraphDef> m_graph_def;
+    std::shared_ptr<::tensorflow::FunctionDef> m_func_def;
+
+    size_t node_index = 0;
+    std::vector<std::shared_ptr<DecoderBase>> m_decoders;
+    std::unordered_map<std::string, int> m_library_map;

 public:
+    GraphIteratorProto(const std::shared_ptr<::tensorflow::GraphDef>& graph_def,
+                       const std::shared_ptr<::tensorflow::FunctionDef>& func_def,
+                       const std::unordered_map<std::string, int>& library_map)
+        : m_graph_def(graph_def),
+          m_func_def(func_def),
+          m_library_map(library_map) {
+        auto nodes_size = m_func_def->node_def_size();
+        auto input_size = m_func_def->signature().input_arg_size();
+        auto output_size = m_func_def->signature().output_arg_size();
+        auto ret_map = m_func_def->ret();
+
+        // fill all inputs from library functions
+        // these input_arg objects are of different type OpDef_ArgDef
+        // they are not NodeDef so we use separate Decoder class
+        for (int input_ind = 0; input_ind < input_size; ++input_ind) {
+            auto input_arg = &m_func_def->signature().input_arg(input_ind);
+            m_decoders.push_back(std::make_shared<DecoderArgDef>(input_arg, "input_arg"));
+        }
+
+        // fill all node defs from library functions
+        for (int node_ind = 0; node_ind < nodes_size; ++node_ind) {
+            m_decoders.push_back(std::make_shared<DecoderProto>(&(m_func_def->node_def(node_ind))));
+        }
+
+        // fill all outputs from library functions
+        // these output_arg objects are of different type OpDef_ArgDef
+        // they are not NodeDef so we use separate Decoder class
+        for (int output_ind = 0; output_ind < output_size; ++output_ind) {
+            auto output_arg = &m_func_def->signature().output_arg(output_ind);
+            auto producer_name = ret_map.at(output_arg->name());
+            m_decoders.push_back(std::make_shared<DecoderArgDef>(output_arg, "output_arg", producer_name));
+        }
+    }
+
     template <typename T>
-    GraphIteratorProto(const std::basic_string<T>& path) : m_graph_def(std::make_shared<::tensorflow::GraphDef>()) {
+    GraphIteratorProto(const std::basic_string<T>& path)
+        : m_graph_def(std::make_shared<::tensorflow::GraphDef>()),
+          m_func_def(nullptr) {
         std::ifstream pb_stream(path, std::ios::in | std::ifstream::binary);

         FRONT_END_GENERAL_CHECK(pb_stream && pb_stream.is_open(), "Model file does not exist");
         FRONT_END_GENERAL_CHECK(m_graph_def->ParseFromIstream(&pb_stream), "Model cannot be parsed");

-        m_nodes.resize(m_graph_def->node_size());
-        for (size_t i = 0; i < m_nodes.size(); ++i)
-            m_nodes[i] = &m_graph_def->node(static_cast<int>(i));
+        auto nodes_size = m_graph_def->node_size();
+        m_decoders.resize(static_cast<size_t>(nodes_size));
+        for (int node_ind = 0; node_ind < nodes_size; ++node_ind) {
+            m_decoders[node_ind] = std::make_shared<DecoderProto>(&m_graph_def->node(node_ind));
+        }
+
+        // initialize a library map
+        auto num_funcs = m_graph_def->library().function_size();
+        for (int func_ind = 0; func_ind < num_funcs; ++func_ind) {
+            auto func = m_graph_def->library().function(func_ind);
+            auto func_name = func.signature().name();
+            m_library_map.insert(std::pair<std::string, int>(func_name, func_ind));
+        }
     }

     /// Set iterator to the start position
@@ -41,7 +90,7 @@ public:
     }

     size_t size() const override {
-        return m_nodes.size();
+        return m_decoders.size();
     }

     /// Moves to the next node in the graph
@@ -50,15 +99,30 @@ public:
     }

     bool is_end() const override {
-        return node_index >= m_nodes.size();
+        return node_index >= m_decoders.size();
     }

     /// Return NodeContext for the current node that iterator points to
     std::shared_ptr<DecoderBase> get_decoder() const override {
-        return std::make_shared<DecoderProto>(m_nodes[node_index]);
+        return m_decoders[node_index];
     }

+    std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override {
+        if (m_library_map.count(func_name)) {
+            auto func_ind = m_library_map.at(func_name);
+            auto func_size = m_graph_def->library().function_size();
+            FRONT_END_GENERAL_CHECK(
+                0 <= func_ind && func_ind < func_size,
+                "[TensorFlow Error] Internal Error: incorrect library map to cache function indices by names.");
+
+            auto func = m_graph_def->library().function(func_ind);
+            auto func_ptr = std::make_shared<::tensorflow::FunctionDef>(func);
+            return std::make_shared<GraphIteratorProto>(m_graph_def, func_ptr, m_library_map);
+        }
+
+        return nullptr;
+    }
 };

 }  // namespace tensorflow
 }  // namespace frontend
 }  // namespace ov
@@ -76,6 +76,7 @@ public:
     std::map<std::string, Output<Node>> get_tensor_values() const {
         return m_tensor_values;
     };
+    std::shared_ptr<InputModel> get_body_input_model(const std::string& body_model_name) const;

 private:
     void loadPlaces();
@@ -314,6 +315,15 @@ InputModel::InputModelTFImpl::InputModelTFImpl(const GraphIterator::Ptr& graph_i
     loadPlaces();
 }

+std::shared_ptr<InputModel> InputModel::InputModelTFImpl::get_body_input_model(
+    const std::string& body_model_name) const {
+    auto body_graph_iterator = m_graph_iterator->get_body_graph_iterator(body_model_name);
+    if (!body_graph_iterator) {
+        return nullptr;
+    }
+    return std::make_shared<InputModel>(body_graph_iterator, m_telemetry);
+}
+
 InputModel::InputModelTFImpl::InputModelTFImpl(const GraphIterator::Ptr& graph_iterator,
                                                const ov::frontend::InputModel& input_model,
                                                const std::shared_ptr<TelemetryExtension>& telemetry)
@@ -427,6 +437,10 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::get_op_places() const {
     return _impl->get_op_places();
 }

+std::shared_ptr<InputModel> InputModel::get_body_input_model(const std::string& body_model_name) const {
+    return _impl->get_body_input_model(body_model_name);
+}
+
 std::map<std::string, std::shared_ptr<TensorPlace>> InputModel::get_tensor_places() const {
     return _impl->get_tensor_places();
 }
@@ -4,8 +4,8 @@

 #pragma once

-#include "input_model.hpp"
 #include "openvino/frontend/extension/telemetry.hpp"
 #include "openvino/frontend/input_model.hpp"
+#include "openvino/frontend/tensorflow/graph_iterator.hpp"
 #include "place.hpp"

@@ -24,6 +24,7 @@ class InputModel : public ov::frontend::InputModel {
     std::vector<std::shared_ptr<OpPlace>> get_op_places() const;
     std::map<std::string, std::shared_ptr<TensorPlace>> get_tensor_places() const;
     std::map<std::string, Output<Node>> get_tensor_values() const;
+    std::shared_ptr<InputModel> get_body_input_model(const std::string& body_input_model_name) const;

 public:
     explicit InputModel(const GraphIterator::Ptr& graph_iterator,
src/frontends/tensorflow/src/op/argdef_ops.cpp (new file, 34 lines)
@@ -0,0 +1,34 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "op_table.hpp"
#include "openvino/opsets/opset10.hpp"

using namespace std;
using namespace ov::opset10;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_input_arg_op(const NodeContext& node) {
    default_op_checks(node, 0, {"input_arg"});
    auto param_type = node.get_attribute<ov::element::Type>("type");

    auto param = std::make_shared<Parameter>(param_type, ov::PartialShape::dynamic());
    set_node_name(node.get_name(), param);
    return param->outputs();
}

OutputVector translate_output_arg_op(const NodeContext& node) {
    default_op_checks(node, 1, {"output_arg"});
    auto result = std::make_shared<Result>(node.get_input(0));
    set_node_name(node.get_name(), result);
    return result->outputs();
}

}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
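Note that an ArgDef carries only an element type (DecoderArgDef exposes nothing but the `type` attribute), so the Parameter is created with a fully dynamic shape; concrete shapes are recovered only when shape inference later runs on the graph the body is injected into. A minimal standalone sketch of what the input_arg translator produces, under that assumption:

#include "openvino/opsets/opset10.hpp"

// A body-graph input: element type from the ArgDef, rank and dimensions unknown.
std::shared_ptr<ov::opset10::Parameter> make_body_input(ov::element::Type type) {
    return std::make_shared<ov::opset10::Parameter>(type, ov::PartialShape::dynamic());
}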
@@ -64,6 +64,8 @@ OP_CONVERTER(translate_gather_nd_op);
 OP_CONVERTER(translate_gru_block_cell_op);
 OP_CONVERTER(translate_identity_op);
 OP_CONVERTER(translate_identity_n_op);
+OP_CONVERTER(translate_input_arg_op);
+OP_CONVERTER(translate_output_arg_op);
 OP_CONVERTER(translate_interpolate_op);
 OP_CONVERTER(translate_is_finite_op);
 OP_CONVERTER(translate_is_inf_op);
@@ -246,6 +248,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"GatherNd", translate_gather_nd_op},
         {"Identity", translate_identity_op},
         {"IdentityN", translate_identity_n_op},
+        {"input_arg", translate_input_arg_op},
+        {"output_arg", translate_output_arg_op},
         {"L2Loss", translate_l2_loss_op},
         {"LeakyRelu", translate_leaky_relu_op},
         {"LinSpace", translate_linspace_op},
@@ -93,6 +93,24 @@ TEST_F(TransformationTestsF, AssertAndStringTensors) {
 }

 TEST_F(TransformationTestsF, UnsortedNodes) {
-    { function = convert_model("forward_edge_model_unsorted/forward_edge_model_unsorted.pb"); }
-    { function_ref = convert_model("forward_edge_model/forward_edge_model.pb"); }
+    { model = convert_model("forward_edge_model_unsorted/forward_edge_model_unsorted.pb"); }
+    { model_ref = convert_model("forward_edge_model/forward_edge_model.pb"); }
 }

+TEST_F(TransformationTestsF, ModelWithSwishF32BodyGraph) {
+    {
+        model = convert_model("swish_f32/swish_f32.pb");
+        // need to call shape inference since body graphs can be injected with undefined shapes
+        model->validate_nodes_and_infer_types();
+    }
+    {
+        auto x = make_shared<Parameter>(f32, Shape{1, 112, 112, 32});
+        auto const_add = make_shared<Constant>(f32, Shape{}, std::vector<float>{2});
+        auto add = make_shared<Add>(x, const_add);
+        auto sigmoid = make_shared<Sigmoid>(add);
+        auto mul = make_shared<Multiply>(add, sigmoid);
+        auto sigmoid2 = make_shared<Sigmoid>(mul);
+
+        model_ref = make_shared<Model>(OutputVector{sigmoid2}, ParameterVector{x});
+    }
+}
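For anyone adapting this test, the key detail is the explicit re-validation after conversion (a hedged sketch of the same step): injected body nodes start out with the dynamic shapes of their input_arg Parameters, and only a shape-inference pass propagates the parent graph's static shapes through them.

// after convert_model(...) when body graphs were injected:
model->validate_nodes_and_infer_types();  // re-runs shape inference over the whole model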
@@ -0,0 +1,143 @@
node {
  name: "placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
        dim {
          size: 112
        }
        dim {
          size: 112
        }
        dim {
          size: 32
        }
      }
    }
  }
}
node {
  name: "const_add"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "add"
  op: "AddV2"
  input: "placeholder"
  input: "const_add"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "swish_f32"
  op: "swish_f32"
  input: "add"
  attr {
    key: "_disable_call_shape_inference"
    value {
      b: true
    }
  }
}
node {
  name: "main/Sigmoid"
  op: "Sigmoid"
  input: "swish_f32"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
  function {
    signature {
      name: "swish_f32"
      input_arg {
        name: "features"
        type: DT_FLOAT
      }
      output_arg {
        name: "mul"
        type: DT_FLOAT
      }
      description: "Computes the Swish activation function: `x * sigmoid(x)`.\n\n Source: \"Searching for Activation Functions\" (Ramachandran et al. 2017)\n https://arxiv.org/abs/1710.05941\n\n Args:\n features: A `Tensor` representing preactivation values.\n name: A name for the operation (optional).\n\n Returns:\n The activation value.\n "
    }
    node_def {
      name: "Sigmoid"
      op: "Sigmoid"
      input: "features"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    node_def {
      name: "mul_0"
      op: "Mul"
      input: "features"
      input: "Sigmoid:y:0"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    ret {
      key: "mul"
      value: "mul_0:z:0"
    }
    attr {
      key: "_disable_call_shape_inference"
      value {
        b: true
      }
    }
    attr {
      key: "_noinline"
      value {
        b: true
      }
    }
    arg_attr {
      value {
      }
    }
  }
}