[TF FE] Support body graph conversion and its injection (#14841)

* [TF FE] Support body graph conversion and injection

Now this is a base for further implementation support for StatefulPartitionedOp, While, and If operations

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Remove unused variable

* Remove artifacts of serialization experiment

* Apply code-review feedback: comments for decoder_argdef

* Create a mat to cache function indices by names

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-12-29 00:08:05 +04:00 committed by GitHub
parent 0be53a66d2
commit 238c3234a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 522 additions and 42 deletions

View File

@ -21,6 +21,7 @@
namespace ov {
namespace frontend {
namespace tensorflow {
using CachedBodyModelsType = std::unordered_map<std::string, std::shared_ptr<const ov::Model>>;
class TENSORFLOW_API FrontEnd : public ov::frontend::FrontEnd {
public:
@ -71,7 +72,8 @@ protected:
const std::string& model_name,
bool fail_fast,
bool no_conversion,
std::shared_ptr<ov::Model>& ng_function) const;
std::shared_ptr<ov::Model>& ov_model,
const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const;
TelemetryExtension::Ptr m_telemetry;
std::vector<DecoderTransformationExtension::Ptr> m_transformation_extensions;

View File

@ -36,6 +36,10 @@ public:
/// \brief Destructor
virtual ~GraphIterator() = default;
/// \brief Checks if the main model graph contains a function of the requested name in the library
/// Returns GraphIterator to this function and nullptr, if it does not exist
virtual std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const = 0;
};
} // namespace tensorflow

View File

@ -0,0 +1,89 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "decoder_argdef.hpp"
#include "op_def.pb.h"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/frontend/tensorflow/special_types.hpp"
#include "types.pb.h"
namespace ov {
namespace frontend {
namespace tensorflow {
namespace {
const std::map<::tensorflow::DataType, ov::element::Type>& TYPE_MAP() {
static const std::map<::tensorflow::DataType, ov::element::Type> type_map{
{::tensorflow::DataType::DT_BOOL, ov::element::boolean},
{::tensorflow::DataType::DT_INT16, ov::element::i16},
{::tensorflow::DataType::DT_INT32, ov::element::i32},
{::tensorflow::DataType::DT_INT64, ov::element::i64},
{::tensorflow::DataType::DT_HALF, ov::element::f16},
{::tensorflow::DataType::DT_FLOAT, ov::element::f32},
{::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
{::tensorflow::DataType::DT_UINT8, ov::element::u8},
{::tensorflow::DataType::DT_INT8, ov::element::i8},
{::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16}};
return type_map;
}
} // namespace
size_t DecoderArgDef::get_input_size() const {
FRONT_END_GENERAL_CHECK(m_op_type == "input_arg" || m_op_type == "output_arg",
"[TensorFlow Frontend] Internal error: Incorrect use of DecoderArgDef class.");
if (m_op_type == "input_arg") {
return 0;
} else {
return 1;
}
}
const std::string& DecoderArgDef::get_op_type() const {
FRONT_END_GENERAL_CHECK(m_op_type == "input_arg" || m_op_type == "output_arg",
"[TensorFlow Frontend] Internal error: Incorrect use of DecoderArgDef class.");
return m_op_type;
}
const std::string& DecoderArgDef::get_op_name() const {
return m_arg_def->name();
}
void DecoderArgDef::get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const {
// Body graph nodes may have two colons `:`, for example,
// producer_name:z:2 means that producer operation name is `producer_name`
// and output port is 2
FRONT_END_GENERAL_CHECK(m_op_type == "output_arg",
"[TensorFlow Frontend] Internal error: get_input_node is supported only for output_arg.");
auto first_colon = m_producer_name.find_first_of(":");
auto last_colon = m_producer_name.find_last_of(":");
if (first_colon != std::string::npos && last_colon != std::string::npos) {
producer_name = m_producer_name.substr(0, first_colon);
auto port_id = m_producer_name.substr(last_colon + 1);
FRONT_END_GENERAL_CHECK(!port_id.empty() && std::all_of(port_id.begin(), port_id.end(), ::isdigit),
"Port id is not specified or not a number. Value: ",
port_id);
producer_output_port_index = std::stoi(port_id);
return;
}
producer_name = m_producer_name;
producer_output_port_index = 0;
}
ov::Any DecoderArgDef::get_attribute(const std::string& name) const {
FRONT_END_GENERAL_CHECK(name == "type",
"[TensorFlow Frontend] Internal error: DecoderArgDef supports only `type` attribute.");
if (TYPE_MAP().count(m_arg_def->type())) {
return TYPE_MAP().at(m_arg_def->type());
} else {
// for all unsupported types return undefined type
return ov::element::undefined;
}
}
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,53 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <vector>
#include "openvino/frontend/tensorflow/decoder.hpp"
namespace tensorflow {
class OpDef_ArgDef;
} // namespace tensorflow
namespace ov {
namespace frontend {
namespace tensorflow {
class DecoderArgDef : public ov::frontend::tensorflow::DecoderBase {
public:
explicit DecoderArgDef(const ::tensorflow::OpDef_ArgDef* arg_def, const std::string& op_type)
: m_arg_def(arg_def),
m_op_type(op_type) {}
explicit DecoderArgDef(const ::tensorflow::OpDef_ArgDef* arg_def,
const std::string& op_type,
const std::string& producer_name)
: m_arg_def(arg_def),
m_op_type(op_type),
m_producer_name(producer_name) {}
ov::Any get_attribute(const std::string& name) const override;
size_t get_input_size() const override;
void get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const override;
const std::string& get_op_type() const override;
const std::string& get_op_name() const override;
private:
const ::tensorflow::OpDef_ArgDef* m_arg_def;
const std::string m_op_type;
const std::string m_producer_name;
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -279,12 +279,15 @@ size_t DecoderProto::get_input_size() const {
void DecoderProto::get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const {
// TODO: handle body graph nodes with a couple of columns
// Body graph nodes may have two colons `:`, for example,
// producer_name:z:2 means that producer operation name is `producer_name`
// and output port is 2
std::string producer_port_name = m_node_def->input(static_cast<int>(input_port_idx));
auto delim_pos = producer_port_name.find(':');
if (delim_pos != std::string::npos) {
producer_name = producer_port_name.substr(0, delim_pos);
auto port_id = producer_port_name.substr(delim_pos + 1);
auto first_colon = producer_port_name.find_first_of(":");
auto last_colon = producer_port_name.find_last_of(":");
if (first_colon != std::string::npos && last_colon != std::string::npos) {
producer_name = producer_port_name.substr(0, first_colon);
auto port_id = producer_port_name.substr(last_colon + 1);
FRONT_END_GENERAL_CHECK(!port_id.empty() && std::all_of(port_id.begin(), port_id.end(), ::isdigit),
"Port id is not specified or not a number. Value: ",
port_id);

View File

@ -33,8 +33,8 @@ void translate_framework_node(const std::shared_ptr<FrameworkNode>& node,
auto translator_it = TRANSLATE_OP_MAP.find(type);
FRONT_END_OP_CONVERSION_CHECK(translator_it != TRANSLATE_OP_MAP.end(), "No translator found for ", type, " node.");
ov::OutputVector ng_inputs = node->input_values();
NodeContext node_ctx(node->get_decoder(), ng_inputs);
ov::OutputVector ov_inputs = node->input_values();
NodeContext node_ctx(node->get_decoder(), ov_inputs);
auto new_node_outputs = translator_it->second(node_ctx);
auto new_output = new_node_outputs.begin();
@ -45,6 +45,24 @@ void translate_framework_node(const std::shared_ptr<FrameworkNode>& node,
old_output->replace(*new_output);
}
}
void inject_body_model(std::shared_ptr<ov::Model> body_model,
const std::string& operation_type,
const ov::OutputVector& ov_inputs,
ov::OutputVector& ov_outputs) {
ov_outputs.clear();
auto body_parameters = body_model->get_parameters();
FRONT_END_GENERAL_CHECK(body_parameters.size() == ov_inputs.size(),
"[TensorFlow Error] Internal error or incorrect input models: number of "
"inputs and arguments to the function " +
operation_type + " do not match.");
for (size_t param_ind = 0; param_ind < body_parameters.size(); ++param_ind) {
body_parameters[param_ind]->output(0).replace(ov_inputs[param_ind]);
}
for (const auto& result_node : body_model->get_results()) {
ov_outputs.push_back(result_node->input_value(0));
}
}
} // namespace
FrontEnd::FrontEnd() : m_op_translators(tensorflow::op::get_supported_ops()) {}
@ -53,7 +71,8 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
const std::string& model_name,
bool fail_fast,
bool no_conversion,
std::shared_ptr<ov::Model>& ng_function) const {
std::shared_ptr<ov::Model>& ov_model,
const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const {
// a map from operation names to generated OV Output<TFNodeDecoder>
tensorflow::OpMap ng_op_map;
@ -65,7 +84,7 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
const auto& model_inputs = model_tf->get_inputs();
const auto& model_outputs = model_tf->get_outputs();
const auto& model_frozen_inputs = model_tf->get_tensor_values();
std::map<const std::string, const std::function<ov::OutputVector(const NodeContext&)>> translate_map;
TranslatorDictionaryType translate_map;
const auto& TRANSLATE_OP_MAP = m_op_translators;
if (no_conversion) {
@ -119,7 +138,7 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
}
// prepare a list of OV node inputs for each node
ov::OutputVector ng_inputs;
ov::OutputVector ov_inputs;
size_t operation_input_size = operation_decoder->get_input_size();
if (operation_decoder->get_op_type() == "NextIteration") {
@ -159,18 +178,18 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
const auto& input_outputs_vector = ng_op_map.at(std::to_string(input_port_idx) + ":" + operation_name);
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
"Input created with pruning must have one output");
ng_inputs.push_back(input_outputs_vector.at(0));
ov_inputs.push_back(input_outputs_vector.at(0));
} else if (ng_op_map.count(producer_name + ":" + std::to_string(producer_port_idx))) {
const auto& input_outputs_vector =
ng_op_map.at(producer_name + ":" + std::to_string(producer_port_idx));
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
"Input created with pruning must have one output");
ng_inputs.push_back(input_outputs_vector.at(0));
ov_inputs.push_back(input_outputs_vector.at(0));
} else if (ng_op_map.count(producer_name)) {
const auto& input_outputs_vector = ng_op_map.at(producer_name);
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() > producer_port_idx,
"Input created with pruning must have one output");
ng_inputs.push_back(input_outputs_vector.at(producer_port_idx));
ov_inputs.push_back(input_outputs_vector.at(producer_port_idx));
} else {
FRONT_END_GENERAL_CHECK(false,
"No input is found for node \"" + operation_name + "\" by port " +
@ -179,14 +198,43 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
}
// generate OV node output vector for the current operation node
ov::OutputVector ng_outputs;
ov::OutputVector ov_outputs;
bool is_converted = false;
auto operation_type = operation_decoder->get_op_type();
try {
FRONT_END_OP_CONVERSION_CHECK(translate_map.count(operation_decoder->get_op_type()),
"No translator found for " + operation_decoder->get_op_type() + " node.");
auto op_fun = &(translate_map[operation_decoder->get_op_type()]);
NodeContext node_context(operation_decoder, ng_inputs);
// generate OV node output vector using translator for given operation type
ng_outputs = (*op_fun)(node_context);
if (translate_map.count(operation_type)) {
auto translator = translate_map[operation_decoder->get_op_type()];
NodeContext node_context(operation_decoder, ov_inputs);
ov_outputs = translator(node_context);
is_converted = true;
} else if (cached_body_models->count(operation_type)) {
// check if such body graph has been converted before
// re-use it from the cache for further injection
// create new instance of the required body model
// since it will be modified by injection
auto cached_body_model = cached_body_models->at(operation_type);
auto body_model = cached_body_model->clone();
inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
is_converted = true;
} else if (auto body_input_model = model_tf->get_body_input_model(operation_type)) {
// try to find a function by name in the model library
std::shared_ptr<ov::Model> body_model;
translate_graph(body_input_model,
operation_decoder->get_op_type(),
fail_fast,
no_conversion,
body_model,
cached_body_models);
// save new instance of body_model in the cache of body models
// before its injection into the parent graph
auto cached_body_model = body_model->clone();
cached_body_models->insert(std::make_pair(operation_type, cached_body_model));
inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
is_converted = true;
}
FRONT_END_OP_CONVERSION_CHECK(is_converted, "No translator found for " + operation_type + " node.");
} catch (...) {
if (fail_fast) {
// in case of decode, unsupported operation will be converted to FrameworkNode
@ -198,15 +246,15 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
throw;
} else {
auto ng_node = std::make_shared<FrameworkNode>(operation_decoder,
ng_inputs,
ov_inputs,
operation_place->get_output_ports().size());
set_node_name(operation_name, ng_node);
ng_outputs = ng_node->outputs();
ov_outputs = ng_node->outputs();
}
}
// register OV node outputs in the map for new operation node
for (const auto& output : ng_outputs) {
for (const auto& output : ov_outputs) {
if (auto result = std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
// do not add RetVal type operation to ng_op_map
results.push_back(result);
@ -302,7 +350,7 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
// TODO: reorder results and params according to indices given in RT info (if any)
// create the OV Model
ng_function = std::make_shared<ov::Model>(results, params, model_name);
ov_model = std::make_shared<ov::Model>(results, params, model_name);
}
/// \brief Check if FrontEndTensorflow can recognize model from given parts
@ -383,7 +431,8 @@ std::shared_ptr<ov::Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr
}
std::shared_ptr<ov::Model> f;
translate_graph(model_tf, "TensorFlow_Frontend_IR", true, false, f);
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
translate_graph(model_tf, "TensorFlow_Frontend_IR", true, false, f, cached_body_models);
normalize(f);
for (const auto& node : f->get_ordered_ops()) {
@ -416,7 +465,8 @@ std::shared_ptr<ov::Model> FrontEnd::convert_partially(const ov::frontend::Input
}
std::shared_ptr<ov::Model> f;
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, false, f);
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, false, f, cached_body_models);
normalize(f);
return f;
}
@ -424,7 +474,8 @@ std::shared_ptr<ov::Model> FrontEnd::convert_partially(const ov::frontend::Input
std::shared_ptr<ov::Model> FrontEnd::decode(const ov::frontend::InputModel::Ptr& model) const {
auto model_tf = std::dynamic_pointer_cast<InputModel>(model);
std::shared_ptr<ov::Model> f;
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, true, f);
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, true, f, cached_body_models);
return f;
}

View File

@ -6,9 +6,9 @@
#include <fstream>
#include "decoder_argdef.hpp"
#include "decoder_proto.hpp"
#include "graph.pb.h"
#include "node_def.pb.h"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/tensorflow/decoder.hpp"
#include "openvino/frontend/tensorflow/graph_iterator.hpp"
@ -18,21 +18,70 @@ namespace frontend {
namespace tensorflow {
class GraphIteratorProto : public GraphIterator {
std::vector<const ::tensorflow::NodeDef*> m_nodes;
size_t node_index = 0;
std::shared_ptr<::tensorflow::GraphDef> m_graph_def;
std::shared_ptr<::tensorflow::FunctionDef> m_func_def;
size_t node_index = 0;
std::vector<std::shared_ptr<DecoderBase>> m_decoders;
std::unordered_map<std::string, int> m_library_map;
public:
GraphIteratorProto(const std::shared_ptr<::tensorflow::GraphDef>& graph_def,
const std::shared_ptr<::tensorflow::FunctionDef>& func_def,
const std::unordered_map<std::string, int>& library_map)
: m_graph_def(graph_def),
m_func_def(func_def),
m_library_map(library_map) {
auto nodes_size = m_func_def->node_def_size();
auto input_size = m_func_def->signature().input_arg_size();
auto output_size = m_func_def->signature().output_arg_size();
auto ret_map = m_func_def->ret();
// fill all inputs from library functions
// these input_arg objects are of different type OpDef_ArgDef
// they are not NodeDef so we use separate Decoder class
for (int input_ind = 0; input_ind < input_size; ++input_ind) {
auto input_arg = &m_func_def->signature().input_arg(input_ind);
m_decoders.push_back(std::make_shared<DecoderArgDef>(input_arg, "input_arg"));
}
// fill all node defs from library functions
for (int node_ind = 0; node_ind < nodes_size; ++node_ind) {
m_decoders.push_back(std::make_shared<DecoderProto>(&(m_func_def->node_def(node_ind))));
}
// fill all outputs from library functions
// these output_arg objects are of different type OpDef_ArgDef
// they are not NodeDef so we use separate Decoder class
for (int output_ind = 0; output_ind < output_size; ++output_ind) {
auto output_arg = &m_func_def->signature().output_arg(output_ind);
auto producer_name = ret_map.at(output_arg->name());
m_decoders.push_back(std::make_shared<DecoderArgDef>(output_arg, "output_arg", producer_name));
}
}
template <typename T>
GraphIteratorProto(const std::basic_string<T>& path) : m_graph_def(std::make_shared<::tensorflow::GraphDef>()) {
GraphIteratorProto(const std::basic_string<T>& path)
: m_graph_def(std::make_shared<::tensorflow::GraphDef>()),
m_func_def(nullptr) {
std::ifstream pb_stream(path, std::ios::in | std::ifstream::binary);
FRONT_END_GENERAL_CHECK(pb_stream && pb_stream.is_open(), "Model file does not exist");
FRONT_END_GENERAL_CHECK(m_graph_def->ParseFromIstream(&pb_stream), "Model cannot be parsed");
m_nodes.resize(m_graph_def->node_size());
for (size_t i = 0; i < m_nodes.size(); ++i)
m_nodes[i] = &m_graph_def->node(static_cast<int>(i));
auto nodes_size = m_graph_def->node_size();
m_decoders.resize(static_cast<size_t>(nodes_size));
for (int node_ind = 0; node_ind < nodes_size; ++node_ind) {
m_decoders[node_ind] = std::make_shared<DecoderProto>(&m_graph_def->node(node_ind));
}
// initialize a library map
auto num_funcs = m_graph_def->library().function_size();
for (int func_ind = 0; func_ind < num_funcs; ++func_ind) {
auto func = m_graph_def->library().function(func_ind);
auto func_name = func.signature().name();
m_library_map.insert(std::pair<std::string, int>(func_name, func_ind));
}
}
/// Set iterator to the start position
@ -41,7 +90,7 @@ public:
}
size_t size() const override {
return m_nodes.size();
return m_decoders.size();
}
/// Moves to the next node in the graph
@ -50,15 +99,30 @@ public:
}
bool is_end() const override {
return node_index >= m_nodes.size();
return node_index >= m_decoders.size();
}
/// Return NodeContext for the current node that iterator points to
std::shared_ptr<DecoderBase> get_decoder() const override {
return std::make_shared<DecoderProto>(m_nodes[node_index]);
return m_decoders[node_index];
}
std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override {
if (m_library_map.count(func_name)) {
auto func_ind = m_library_map.at(func_name);
auto func_size = m_graph_def->library().function_size();
FRONT_END_GENERAL_CHECK(
0 <= func_ind && func_ind < func_size,
"[TensorFlow Error] Internal Error: incorrect library map to cache function indices by names.");
auto func = m_graph_def->library().function(func_ind);
auto func_ptr = std::make_shared<::tensorflow::FunctionDef>(func);
return std::make_shared<GraphIteratorProto>(m_graph_def, func_ptr, m_library_map);
}
return nullptr;
}
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -76,6 +76,7 @@ public:
std::map<std::string, Output<Node>> get_tensor_values() const {
return m_tensor_values;
};
std::shared_ptr<InputModel> get_body_input_model(const std::string& body_model_name) const;
private:
void loadPlaces();
@ -314,6 +315,15 @@ InputModel::InputModelTFImpl::InputModelTFImpl(const GraphIterator::Ptr& graph_i
loadPlaces();
}
std::shared_ptr<InputModel> InputModel::InputModelTFImpl::get_body_input_model(
const std::string& body_model_name) const {
auto body_graph_iterator = m_graph_iterator->get_body_graph_iterator(body_model_name);
if (!body_graph_iterator) {
return nullptr;
}
return std::make_shared<InputModel>(body_graph_iterator, m_telemetry);
}
InputModel::InputModelTFImpl::InputModelTFImpl(const GraphIterator::Ptr& graph_iterator,
const ov::frontend::InputModel& input_model,
const std::shared_ptr<TelemetryExtension>& telemetry)
@ -427,6 +437,10 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::get_op_places() const {
return _impl->get_op_places();
}
std::shared_ptr<InputModel> InputModel::get_body_input_model(const std::string& body_model_name) const {
return _impl->get_body_input_model(body_model_name);
}
std::map<std::string, std::shared_ptr<TensorPlace>> InputModel::get_tensor_places() const {
return _impl->get_tensor_places();
}

View File

@ -4,8 +4,8 @@
#pragma once
#include "input_model.hpp"
#include "openvino/frontend/extension/telemetry.hpp"
#include "openvino/frontend/input_model.hpp"
#include "openvino/frontend/tensorflow/graph_iterator.hpp"
#include "place.hpp"
@ -24,6 +24,7 @@ class InputModel : public ov::frontend::InputModel {
std::vector<std::shared_ptr<OpPlace>> get_op_places() const;
std::map<std::string, std::shared_ptr<TensorPlace>> get_tensor_places() const;
std::map<std::string, Output<Node>> get_tensor_values() const;
std::shared_ptr<InputModel> get_body_input_model(const std::string& body_input_model_name) const;
public:
explicit InputModel(const GraphIterator::Ptr& graph_iterator,

View File

@ -0,0 +1,34 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op_table.hpp"
#include "openvino/opsets/opset10.hpp"
using namespace std;
using namespace ov::opset10;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_input_arg_op(const NodeContext& node) {
default_op_checks(node, 0, {"input_arg"});
auto param_type = node.get_attribute<ov::element::Type>("type");
auto param = std::make_shared<Parameter>(param_type, ov::PartialShape::dynamic());
set_node_name(node.get_name(), param);
return param->outputs();
}
OutputVector translate_output_arg_op(const NodeContext& node) {
default_op_checks(node, 1, {"output_arg"});
auto result = std::make_shared<Result>();
set_node_name(node.get_name(), result);
return result->outputs();
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -64,6 +64,8 @@ OP_CONVERTER(translate_gather_nd_op);
OP_CONVERTER(translate_gru_block_cell_op);
OP_CONVERTER(translate_identity_op);
OP_CONVERTER(translate_identity_n_op);
OP_CONVERTER(translate_input_arg_op);
OP_CONVERTER(translate_output_arg_op);
OP_CONVERTER(translate_interpolate_op);
OP_CONVERTER(translate_is_finite_op);
OP_CONVERTER(translate_is_inf_op);
@ -246,6 +248,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"GatherNd", translate_gather_nd_op},
{"Identity", translate_identity_op},
{"IdentityN", translate_identity_n_op},
{"input_arg", translate_input_arg_op},
{"output_arg", translate_output_arg_op},
{"L2Loss", translate_l2_loss_op},
{"LeakyRelu", translate_leaky_relu_op},
{"LinSpace", translate_linspace_op},

View File

@ -93,6 +93,24 @@ TEST_F(TransformationTestsF, AssertAndStringTensors) {
}
TEST_F(TransformationTestsF, UnsortedNodes) {
{ function = convert_model("forward_edge_model_unsorted/forward_edge_model_unsorted.pb"); }
{ function_ref = convert_model("forward_edge_model/forward_edge_model.pb"); }
{ model = convert_model("forward_edge_model_unsorted/forward_edge_model_unsorted.pb"); }
{ model_ref = convert_model("forward_edge_model/forward_edge_model.pb"); }
}
TEST_F(TransformationTestsF, ModelWithSwishF32BodyGraph) {
{
model = convert_model("swish_f32/swish_f32.pb");
// need to call shape inference since body graphs can be injected with undefined shapes
model->validate_nodes_and_infer_types();
}
{
auto x = make_shared<Parameter>(f32, Shape{1, 112, 112, 32});
auto const_add = make_shared<Constant>(f32, Shape{}, std::vector<float>{2});
auto add = make_shared<Add>(x, const_add);
auto sigmoid = make_shared<Sigmoid>(add);
auto mul = make_shared<Multiply>(add, sigmoid);
auto sigmoid2 = make_shared<Sigmoid>(mul);
model_ref = make_shared<Model>(OutputVector{sigmoid2}, ParameterVector{x});
}
}

View File

@ -0,0 +1,143 @@
node {
name: "placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
dim {
size: 112
}
dim {
size: 112
}
dim {
size: 32
}
}
}
}
}
node {
name: "const_add"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "add"
op: "AddV2"
input: "placeholder"
input: "const_add"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "swish_f32"
op: "swish_f32"
input: "add"
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
}
node {
name: "main/Sigmoid"
op: "Sigmoid"
input: "swish_f32"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
function {
signature {
name: "swish_f32"
input_arg {
name: "features"
type: DT_FLOAT
}
output_arg {
name: "mul"
type: DT_FLOAT
}
description: "Computes the Swish activation function: `x * sigmoid(x)`.\n\n Source: \"Searching for Activation Functions\" (Ramachandran et al. 2017)\n https://arxiv.org/abs/1710.05941\n\n Args:\n features: A `Tensor` representing preactivation values.\n name: A name for the operation (optional).\n\n Returns:\n The activation value.\n "
}
node_def {
name: "Sigmoid"
op: "Sigmoid"
input: "features"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node_def {
name: "mul_0"
op: "Mul"
input: "features"
input: "Sigmoid:y:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
ret {
key: "mul"
value: "mul_0:z:0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
attr {
key: "_noinline"
value {
b: true
}
}
arg_attr {
value {
}
}
}
}