[TF FE] Support If and PartitionedCall operations (#14910)

* [TF FE] Support If and PartitionedCall operations

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix build

* Fix frontend wrapper for tests

* Erase tensor names in body graph before caching

* Apply code-review feedback: recover m_op_translators in Frontend

* Rename test models

* Rename unit-tests

* Correct scripts for test model generation

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev, 2023-01-05 16:34:15 +04:00, committed by GitHub
parent efb602e13b
commit f13e7e1352
18 changed files with 1683 additions and 327 deletions


@@ -21,8 +21,6 @@
namespace ov {
namespace frontend {
namespace tensorflow {
using CachedBodyModelsType = std::unordered_map<std::string, std::shared_ptr<const ov::Model>>;
class TENSORFLOW_API FrontEnd : public ov::frontend::FrontEnd {
public:
using Ptr = std::shared_ptr<FrontEnd>;
@@ -68,13 +66,6 @@ protected:
ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override;
void translate_graph(const ov::frontend::InputModel::Ptr& model,
const std::string& model_name,
bool fail_fast,
bool no_conversion,
std::shared_ptr<ov::Model>& ov_model,
const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const;
TelemetryExtension::Ptr m_telemetry;
std::vector<DecoderTransformationExtension::Ptr> m_transformation_extensions;
std::vector<ConversionExtensionBase::Ptr> m_conversion_extensions;


@@ -12,15 +12,19 @@
namespace ov {
namespace frontend {
namespace tensorflow {
class TranslateSession;
/// Keeps necessary data for a single node in the original FW graph to facilitate
/// the conversion process in the rules code.
class NodeContext : public ov::frontend::NodeContext {
public:
using Ptr = std::shared_ptr<NodeContext>;
NodeContext(const std::shared_ptr<DecoderBase>& decoder, const OutputVector& inputs)
NodeContext(const std::shared_ptr<DecoderBase>& decoder,
const OutputVector& inputs,
TranslateSession* translate_session = nullptr)
: ov::frontend::NodeContext(decoder->get_op_type()),
m_decoder(decoder),
m_translate_session(translate_session),
m_inputs(inputs) {}
/// Detects if there is at least one input attached with a given name
@@ -51,10 +55,16 @@ public:
return res;
}
/// \brief Get a pointer to TranslateSession object
TranslateSession* get_translate_session() const {
return m_translate_session;
}
private:
ov::Any apply_additional_conversion_rules(const ov::Any& data, const std::type_info& type_info) const override;
std::shared_ptr<DecoderBase> m_decoder;
TranslateSession* m_translate_session;
const OutputVector& m_inputs;
};


@@ -263,10 +263,11 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
name,
"' attribute is not supported.");
case ::tensorflow::AttrValue::ValueCase::kFunc:
FRONT_END_GENERAL_CHECK(false,
"Conversion from Tensorflow to OpenVINO data type failed: Function type for '",
name,
"' attribute is not supported.");
// attrs[0].func() returns a NameAttrList object from which
// we retrieve the function name.
// Later, an InputModel object is created for the FunctionDef with this name
// and converted to an ov::Model object.
return attrs[0].func().name();
default:
FRONT_END_GENERAL_CHECK(false, "Conversion from Tensorflow to OpenVINO data type failed.");
}
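
For context on how this attribute value is consumed: a translator can hand the returned function name to the translate session to obtain the converted body model. A minimal sketch mirroring translate_partitioned_call_op below (node is a NodeContext):

    auto operation_type = node.get_attribute<std::string>("f");  // resolves to the NameAttrList name
    auto body_model = node.get_translate_session()->get_body_ov_model(operation_type);
    // body_model is the ov::Model converted from the FunctionDef with this name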


@@ -19,6 +19,7 @@
#include "so_extension.hpp"
#include "tf_framework_node.hpp"
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "translate_session.hpp"
#include "utils.hpp"
using namespace ov;
@@ -45,314 +46,10 @@ void translate_framework_node(const std::shared_ptr<FrameworkNode>& node,
old_output->replace(*new_output);
}
}
void inject_body_model(std::shared_ptr<ov::Model> body_model,
const std::string& operation_type,
const ov::OutputVector& ov_inputs,
ov::OutputVector& ov_outputs) {
ov_outputs.clear();
auto body_parameters = body_model->get_parameters();
FRONT_END_GENERAL_CHECK(body_parameters.size() == ov_inputs.size(),
"[TensorFlow Error] Internal error or incorrect input models: number of "
"inputs and arguments to the function " +
operation_type + " do not match.");
for (size_t param_ind = 0; param_ind < body_parameters.size(); ++param_ind) {
body_parameters[param_ind]->output(0).replace(ov_inputs[param_ind]);
}
for (const auto& result_node : body_model->get_results()) {
ov_outputs.push_back(result_node->input_value(0));
}
}
} // namespace
FrontEnd::FrontEnd() : m_op_translators(tensorflow::op::get_supported_ops()) {}
void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
const std::string& model_name,
bool fail_fast,
bool no_conversion,
std::shared_ptr<ov::Model>& ov_model,
const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const {
// a map from operation names to generated OV Output<TFNodeDecoder>
tensorflow::OpMap ng_op_map;
ov::ParameterVector params;
ov::ResultVector results;
const auto& model_tf = std::dynamic_pointer_cast<InputModel>(model);
FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into OV Model");
const auto& operation_places = model_tf->get_op_places();
const auto& model_inputs = model_tf->get_inputs();
const auto& model_outputs = model_tf->get_outputs();
const auto& model_frozen_inputs = model_tf->get_tensor_values();
TranslatorDictionaryType translate_map;
const auto& TRANSLATE_OP_MAP = m_op_translators;
if (no_conversion) {
const std::set<std::string> required_types{"Placeholder", "NoOp"};
for (const auto& name : required_types) {
translate_map.emplace(name, TRANSLATE_OP_MAP.at(name));
}
} else {
translate_map.insert(TRANSLATE_OP_MAP.begin(), TRANSLATE_OP_MAP.end());
}
// fill ng_op_map with Constant outputs for frozen inputs
for (const auto& frozen_input : model_frozen_inputs) {
const auto& frozen_input_name = frozen_input.first;
const auto& frozen_input_value = frozen_input.second;
FRONT_END_GENERAL_CHECK(ng_op_map.count(frozen_input_name) == 0,
"Input with frozen value has been already met: " + frozen_input_name);
ng_op_map[frozen_input_name] = {frozen_input_value};
}
// create parameter nodes for all tensor places corresponding to inputs
for (const auto& input_place : model_inputs) {
FRONT_END_GENERAL_CHECK(input_place->get_names().size() == 1, "Input place must have one name.");
auto input_name = input_place->get_names()[0];
if (ng_op_map.count(input_name)) {
// probably this input is frozen
continue;
}
const auto& input_tensor_place = std::dynamic_pointer_cast<TensorPlace>(input_place);
auto input_shape = input_tensor_place->get_partial_shape();
auto input_type = input_tensor_place->get_element_type();
// in case of cutting the graph, types of custom inputs can be undefined;
// according to the MO help, fp32 is used by default in such cases
if (input_type == element::undefined) {
input_type = element::f32;
}
auto param = std::make_shared<ov::opset8::Parameter>(input_type, input_shape);
set_node_name(input_name, param);
params.push_back(param);
ng_op_map[input_name] = {param};
}
// create the OV ops from TensorFlow ops
for (const auto& operation_place : operation_places) {
auto operation_decoder = operation_place->get_decoder();
auto operation_name = operation_place->get_names()[0];
// output for parameter nodes has been already generated
if (ng_op_map.count(operation_name)) {
continue;
}
// prepare a list of OV node inputs for each node
ov::OutputVector ov_inputs;
size_t operation_input_size = operation_decoder->get_input_size();
if (operation_decoder->get_op_type() == "NextIteration") {
// we expect no inputs for NextIteration because we break up the cycle in InputModel
operation_input_size = 0;
}
for (size_t input_port_idx = 0; input_port_idx < operation_input_size; ++input_port_idx) {
// TODO: Implement more general approach. Skipping Constants that have input edges
if (operation_decoder->get_op_type() == "Const") {
break;
}
std::string producer_name;
size_t producer_port_idx;
try {
operation_decoder->get_input_node(input_port_idx, producer_name, producer_port_idx);
} catch (const std::exception&) {
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(input_port_idx) +
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
producer_name + "', expected input port index: " + std::to_string(producer_port_idx) +
'\n');
}
// skip conditional edges that must be resolved before operation translation
// now we can meet them because we still work with TensorFlow protobuf
if (is_conditional_edge(producer_name)) {
continue;
}
// TODO: re-implement the logic below once Place graph structure is implemented
// Using the Place graph structure (OpPlace, In/OutPortPlace places and their connections) can give
// names of ports and operations that can be used for a further existence check in ng_op_map.
// Check whether the output vector for a place has already been defined; the order of this check is important:
// it moves from places corresponding to the input port of the current operation node to the output ports of the
// original producers
if (ng_op_map.count(std::to_string(input_port_idx) + ":" + operation_name)) {
const auto& input_outputs_vector = ng_op_map.at(std::to_string(input_port_idx) + ":" + operation_name);
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
"Input created with pruning must have one output");
ov_inputs.push_back(input_outputs_vector.at(0));
} else if (ng_op_map.count(producer_name + ":" + std::to_string(producer_port_idx))) {
const auto& input_outputs_vector =
ng_op_map.at(producer_name + ":" + std::to_string(producer_port_idx));
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
"Input created with pruning must have one output");
ov_inputs.push_back(input_outputs_vector.at(0));
} else if (ng_op_map.count(producer_name)) {
const auto& input_outputs_vector = ng_op_map.at(producer_name);
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() > producer_port_idx,
"Input created with pruning must have one output");
ov_inputs.push_back(input_outputs_vector.at(producer_port_idx));
} else {
FRONT_END_GENERAL_CHECK(false,
"No input is found for node \"" + operation_name + "\" by port " +
std::to_string(producer_port_idx));
}
}
// generate OV node output vector for the current operation node
ov::OutputVector ov_outputs;
bool is_converted = false;
auto operation_type = operation_decoder->get_op_type();
try {
if (translate_map.count(operation_type)) {
auto translator = translate_map[operation_decoder->get_op_type()];
NodeContext node_context(operation_decoder, ov_inputs);
ov_outputs = translator(node_context);
is_converted = true;
} else if (cached_body_models->count(operation_type)) {
// check if such body graph has been converted before
// re-use it from the cache for further injection
// create new instance of the required body model
// since it will be modified by injection
auto cached_body_model = cached_body_models->at(operation_type);
auto body_model = cached_body_model->clone();
inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
is_converted = true;
} else if (auto body_input_model = model_tf->get_body_input_model(operation_type)) {
// try to find a function by name in the model library
std::shared_ptr<ov::Model> body_model;
translate_graph(body_input_model,
operation_decoder->get_op_type(),
fail_fast,
no_conversion,
body_model,
cached_body_models);
// save new instance of body_model in the cache of body models
// before its injection into the parent graph
auto cached_body_model = body_model->clone();
cached_body_models->insert(std::make_pair(operation_type, cached_body_model));
inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
is_converted = true;
}
FRONT_END_OP_CONVERSION_CHECK(is_converted, "No translator found for " + operation_type + " node.");
} catch (...) {
if (fail_fast) {
// in case of decode, unsupported operation will be converted to FrameworkNode
if (m_telemetry && translate_map.count(operation_decoder->get_op_type()) == 0) {
// send event about which operation is not supported for conversion
m_telemetry->send_event("error_cause", "tf_" + operation_decoder->get_op_type());
}
// re-throw any exception
throw;
} else {
auto ng_node = std::make_shared<FrameworkNode>(operation_decoder,
ov_inputs,
operation_place->get_output_ports().size());
set_node_name(operation_name, ng_node);
ov_outputs = ng_node->outputs();
}
}
// register OV node outputs in the map for new operation node
for (const auto& output : ov_outputs) {
if (auto result = std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
// do not add RetVal type operation to ng_op_map
results.push_back(result);
} else {
auto param = std::dynamic_pointer_cast<ov::opset8::Parameter>(output.get_node_shared_ptr());
// avoid duplicating Parameter nodes if they are already in the Parameters vector
if (param && operation_decoder->get_op_type() != "Identity" &&
std::find(params.begin(), params.end(), param) == params.end()) {
params.push_back(param);
}
ng_op_map[operation_name].push_back(output);
}
}
}
// create Result nodes for all model outputs
for (const auto& model_output : model_outputs) {
auto model_output_tensor_place = std::dynamic_pointer_cast<TensorPlace>(model_output);
auto model_output_name = model_output_tensor_place->get_names()[0];
std::string operation_name;
std::string port_type;
size_t port_index;
ov::frontend::tensorflow::extract_operation_name_and_port(model_output_name,
operation_name,
port_index,
port_type);
if (port_type == "none") {
for (const auto& node_output : ng_op_map[operation_name]) {
auto result_node = std::make_shared<ov::opset8::Result>(node_output);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
}
} else if (port_type == "out") {
const auto& node_outputs = ng_op_map[operation_name];
FRONT_END_GENERAL_CHECK(node_outputs.size() > port_index,
"Output port with index " + std::to_string(port_index) + " of " + operation_name +
"node specified as custom output does not exist");
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[port_index]);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
} else if (port_type == "in") {
// TODO: avoid this traversing by having a map for OpPlace objects, for example
std::shared_ptr<OpPlace> operation_place = nullptr;
for (const auto& op_place : operation_places) {
FRONT_END_GENERAL_CHECK(!op_place->get_names().empty(), "No names for OpPlace found.");
if (op_place->get_names()[0] == operation_name) {
operation_place = op_place;
}
}
FRONT_END_GENERAL_CHECK(operation_place, "There is no operation place with a name: " + operation_name);
auto operation_decoder = operation_place->get_decoder();
// get to know a producer node and by which its output port data is generated
std::string producer_name;
size_t producer_port_idx;
try {
operation_decoder->get_input_node(port_index, producer_name, producer_port_idx);
} catch (const std::exception&) {
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(port_index) +
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
producer_name + "', expected input port index: " + std::to_string(producer_port_idx) +
'\n');
}
// add Result node for this producer output port
const auto& node_outputs = ng_op_map[producer_name];
FRONT_END_GENERAL_CHECK(node_outputs.size() > producer_port_idx,
"Output port with index " + std::to_string(producer_port_idx) + " of " +
producer_name + "node specified as custom output does not exist");
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[producer_port_idx]);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
}
}
// find all terminal nodes in OV graph to complete list of results
if (results.empty()) {
for (const auto& node_output_vector : ng_op_map) {
for (size_t output_ind = 0; output_ind < node_output_vector.second.size(); ++output_ind) {
auto output = node_output_vector.second[output_ind];
if (output.get_target_inputs().empty() &&
!std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
auto model_output_name =
output.get_node_shared_ptr()->get_friendly_name() + ":" + std::to_string(output_ind);
auto result_node = std::make_shared<ov::opset8::Result>(output);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
}
}
}
}
// TODO: reorder results and params according to indices given in RT info (if any)
// create the OV Model
ov_model = std::make_shared<ov::Model>(results, params, model_name);
}
/// \brief Check if FrontEndTensorflow can recognize model from given parts
bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
// TODO: Support other TensorFlow formats: SavedModel, .meta, checkpoint, pbtxt
@@ -430,9 +127,25 @@ std::shared_ptr<ov::Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr
return function;
}
// create a shared pointer to the cloned dictionary of translators
auto translator_map = std::make_shared<TranslatorDictionaryType>(m_op_translators);
std::shared_ptr<ov::Model> f;
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
translate_graph(model_tf, "TensorFlow_Frontend_IR", true, false, f, cached_body_models);
TranslateSession translate_session(model, translator_map, "TensorFlow_Frontend_IR", true, m_telemetry != nullptr);
try {
f = translate_session.get_converted_model();
} catch (const std::exception&) {
if (m_telemetry) {
auto telemetry_data = translate_session.get_telemetry_data();
if (telemetry_data) {
// send event about which operation is not supported for conversion
for (const auto& telemetry_item : *telemetry_data.get()) {
m_telemetry->send_event(telemetry_item.first, telemetry_item.second);
}
}
}
throw;
}
normalize(f);
for (const auto& node : f->get_ordered_ops()) {
@@ -464,18 +177,55 @@ std::shared_ptr<ov::Model> FrontEnd::convert_partially(const ov::frontend::Input
return function;
}
// create a shared pointer to the cloned dictionary of translators
auto translator_map = std::make_shared<TranslatorDictionaryType>(m_op_translators);
std::shared_ptr<ov::Model> f;
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, false, f, cached_body_models);
TranslateSession translate_session(model, translator_map, "TensorFlow_Frontend_IR", false, m_telemetry != nullptr);
try {
f = translate_session.get_converted_model();
} catch (const std::exception&) {
if (m_telemetry) {
auto telemetry_data = translate_session.get_telemetry_data();
if (telemetry_data) {
// send event about which operation is not supported for conversion
for (const auto& telemetry_item : *telemetry_data.get()) {
m_telemetry->send_event(telemetry_item.first, telemetry_item.second);
}
}
}
throw;
}
normalize(f);
return f;
}
std::shared_ptr<ov::Model> FrontEnd::decode(const ov::frontend::InputModel::Ptr& model) const {
auto model_tf = std::dynamic_pointer_cast<InputModel>(model);
auto translator_map = std::make_shared<TranslatorDictionaryType>();
const std::set<std::string> required_types{"Placeholder", "NoOp"};
for (const auto& name : required_types) {
translator_map->emplace(name, m_op_translators.at(name));
}
std::shared_ptr<ov::Model> f;
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, true, f, cached_body_models);
TranslateSession translate_session(model, translator_map, "TensorFlow_Frontend_IR", false, m_telemetry != nullptr);
try {
f = translate_session.get_converted_model();
} catch (const std::exception&) {
if (m_telemetry) {
auto telemetry_data = translate_session.get_telemetry_data();
if (telemetry_data) {
// send event about which operation is not supported for conversion
for (const auto& telemetry_item : *telemetry_data.get()) {
m_telemetry->send_event(telemetry_item.first, telemetry_item.second);
}
}
}
throw;
}
return f;
}
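
The same catch-and-forward telemetry pattern now appears in convert, convert_partially, and decode. A hedged sketch of the shared shape, factored into a hypothetical helper (send_telemetry_and_rethrow is illustrative, not part of this commit):

    // must be invoked from inside a catch block so that `throw;` can re-throw
    void send_telemetry_and_rethrow(const TranslateSession& session, const TelemetryExtension::Ptr& telemetry) {
        if (telemetry) {
            if (auto telemetry_data = session.get_telemetry_data()) {
                // send events about operations not supported for conversion
                for (const auto& item : *telemetry_data) {
                    telemetry->send_event(item.first, item.second);
                }
            }
        }
        throw;
    }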


@@ -8,6 +8,7 @@
#include "openvino/frontend/input_model.hpp"
#include "openvino/frontend/tensorflow/graph_iterator.hpp"
#include "place.hpp"
#include "translate_session.hpp"
namespace ov {
namespace frontend {
@@ -17,7 +18,7 @@ class OpPlace;
class TensorPlace;
class InputModel : public ov::frontend::InputModel {
friend class FrontEnd;
friend class TranslateSession;
class InputModelTFImpl;
std::shared_ptr<InputModelTFImpl> _impl;


@@ -23,7 +23,7 @@ OutputVector translate_input_arg_op(const NodeContext& node) {
OutputVector translate_output_arg_op(const NodeContext& node) {
default_op_checks(node, 1, {"output_arg"});
auto result = std::make_shared<Result>();
auto result = std::make_shared<Result>(node.get_input(0));
set_node_name(node.get_name(), result);
return result->outputs();
}


@@ -0,0 +1,87 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/opsets/opset10.hpp"
using namespace std;
using namespace ov;
using namespace ov::opset10;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_if_op(const NodeContext& node) {
default_op_checks(node, 1, {"If", "StatelessIf"});
auto node_name = node.get_name();
auto translate_session = node.get_translate_session();
FRONT_END_GENERAL_CHECK(translate_session, "[TensorFlow Frontend] Internal error: Translate session is nullptr.");
// retrieve body ov::Model for then and else branches
auto then_branch_type = node.get_attribute<std::string>("then_branch");
auto else_branch_type = node.get_attribute<std::string>("else_branch");
auto then_branch_body = translate_session->get_body_ov_model(then_branch_type);
FRONT_END_GENERAL_CHECK(
then_branch_body,
"[TensorFlow Frontend] Internal error or incorrect input model. Cannot find branch body graph with name " +
then_branch_type);
auto else_branch_body = translate_session->get_body_ov_model(else_branch_type);
FRONT_END_GENERAL_CHECK(
else_branch_body,
"[TensorFlow Frontend] Internal error or incorrect input model. Cannot find branch body graph with name " +
else_branch_type);
// get condition output
auto cond = node.get_input(0);
size_t input_size_t = node.get_input_size() - 1;
auto then_params = then_branch_body->get_parameters();
auto else_params = else_branch_body->get_parameters();
FRONT_END_GENERAL_CHECK(input_size_t == then_params.size(),
"[TensorFlow Frontend] Internal error or incorrect input model: number of inputs to If and "
"number of inputs in then branch do not match.");
FRONT_END_GENERAL_CHECK(input_size_t == else_params.size(),
"[TensorFlow Frontend] Internal error or incorrect input model: number of inputs to If and "
"number of inputs in else branch do not match.");
// create If operation and set body graphs
auto if_op = std::make_shared<If>(cond);
if_op->set_then_body(then_branch_body);
if_op->set_else_body(else_branch_body);
// set inputs
int input_size = static_cast<int>(input_size_t);
for (int ind = 0; ind < input_size; ++ind) {
auto curr_input = node.get_input(ind + 1);
auto then_param = then_params[ind];
auto else_param = else_params[ind];
if_op->set_input(curr_input, then_param, else_param);
}
// set outputs
auto then_results = then_branch_body->get_results();
auto else_results = else_branch_body->get_results();
FRONT_END_GENERAL_CHECK(then_results.size() == else_results.size(),
"[TensorFlow Frontend] Internal error or incorrect input model: number of result nodes in "
"then and else branches do not match.");
int output_size = static_cast<int>(then_results.size());
for (int ind = 0; ind < output_size; ++ind) {
if_op->set_output(then_results[ind], else_results[ind]);
}
auto ov_outputs = if_op->outputs();
// set output tensor names
for (size_t idx = 0; idx < ov_outputs.size(); ++idx) {
ov_outputs[idx].get_tensor().set_names({node_name + ":" + std::to_string(idx)});
}
return ov_outputs;
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@@ -0,0 +1,52 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/opsets/opset10.hpp"
using namespace std;
using namespace ov;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_partitioned_call_op(const NodeContext& node) {
default_op_checks(node, 0, {"PartitionedCall", "StatefulPartitionedCall"});
auto node_name = node.get_name();
auto translate_session = node.get_translate_session();
FRONT_END_GENERAL_CHECK(translate_session, "[TensorFlow Frontend] Internal error: Translate session is nullptr.");
auto operation_type = node.get_attribute<std::string>("f");
// prepare a vector of inputs
OutputVector ov_inputs;
int input_size = static_cast<int>(node.get_input_size());
for (int ind = 0; ind < input_size; ++ind) {
ov_inputs.push_back(node.get_input(ind));
}
// try to retrieve ov::Model for body graph
auto body_model = translate_session->get_body_ov_model(operation_type);
FRONT_END_OP_CONVERSION_CHECK(
body_model,
"[TensorFlow Frontend] Internal error or incorrect input model: body graph is not found for " + operation_type +
".");
// inject the body graph into the parent graph
OutputVector ov_outputs;
translate_session->inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
// set output tensor names
for (size_t idx = 0; idx < ov_outputs.size(); ++idx) {
set_out_name({node_name + ":" + std::to_string(idx)}, ov_outputs[idx]);
}
return ov_outputs;
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@@ -66,6 +66,7 @@ OP_CONVERTER(translate_identity_op);
OP_CONVERTER(translate_identity_n_op);
OP_CONVERTER(translate_input_arg_op);
OP_CONVERTER(translate_output_arg_op);
OP_CONVERTER(translate_if_op);
OP_CONVERTER(translate_interpolate_op);
OP_CONVERTER(translate_is_finite_op);
OP_CONVERTER(translate_is_inf_op);
@@ -84,6 +85,7 @@ OP_CONVERTER(translate_mirror_pad_op);
OP_CONVERTER(translate_non_max_suppression_op);
OP_CONVERTER(translate_normalize_l2_op);
OP_CONVERTER(translate_parallel_dynamic_stitch_op);
OP_CONVERTER(translate_partitioned_call_op);
OP_CONVERTER(translate_placeholder_op);
OP_CONVERTER(translate_placeholder_with_default_op);
OP_CONVERTER(translate_no_op);
@@ -248,6 +250,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"GatherNd", translate_gather_nd_op},
{"Identity", translate_identity_op},
{"IdentityN", translate_identity_n_op},
{"If", translate_if_op},
{"input_arg", translate_input_arg_op},
{"output_arg", translate_output_arg_op},
{"L2Loss", translate_l2_loss_op},
@@ -276,6 +279,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"PadV2", translate_padv2_op},
{"DynamicStitch", translate_parallel_dynamic_stitch_op},
{"ParallelDynamicStitch", translate_parallel_dynamic_stitch_op},
{"PartitionedCall", translate_partitioned_call_op},
{"Placeholder", translate_placeholder_op},
{"PlaceholderWithDefault", translate_placeholder_with_default_op},
{"PreventGradient", translate_identity_op},
@@ -314,6 +318,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Square", translate_square_op},
{"Squeeze", translate_squeeze_op},
{"SpaceToBatchND", translate_space_to_batch_nd_op},
{"StatefulPartitionedCall", translate_partitioned_call_op},
{"StatelessIf", translate_if_op},
{"StridedSlice", translate_strided_slice_op},
{"Tile", translate_tile_op},
{"TopK", translate_top_k_op},


@@ -0,0 +1,343 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "translate_session.hpp"
#include "input_model.hpp"
#include "openvino/opsets/opset10.hpp"
#include "tf_framework_node.hpp"
#include "utils.hpp"
using namespace ov::frontend::tensorflow;
TranslateSession::TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
const std::shared_ptr<TranslatorDictionaryType>& translator_map,
const std::string& model_name,
bool fail_fast,
bool telemetry)
: m_fail_fast(fail_fast),
m_telemetry(telemetry),
m_input_model(input_model),
m_translator_map(translator_map),
m_model_name(model_name),
m_ov_model(nullptr),
m_cached_body_models(std::make_shared<CachedBodyModelsType>()),
m_telemetry_data(std::make_shared<TelemetryDataType>()) {}
std::shared_ptr<ov::Model> TranslateSession::get_converted_model() {
if (m_ov_model) {
return m_ov_model;
}
translate_graph(m_input_model, m_ov_model);
return m_ov_model;
}
std::shared_ptr<TelemetryDataType> TranslateSession::get_telemetry_data() const {
return m_telemetry_data;
}
void TranslateSession::inject_body_model(std::shared_ptr<ov::Model> body_model,
const std::string& operation_type,
const ov::OutputVector& ov_inputs,
ov::OutputVector& ov_outputs) {
ov_outputs.clear();
auto body_parameters = body_model->get_parameters();
FRONT_END_GENERAL_CHECK(body_parameters.size() == ov_inputs.size(),
"[TensorFlow Error] Internal error or incorrect input models: number of "
"inputs and arguments to the function " +
operation_type + " do not match.");
for (size_t param_ind = 0; param_ind < body_parameters.size(); ++param_ind) {
body_parameters[param_ind]->output(0).replace(ov_inputs[param_ind]);
}
for (const auto& result_node : body_model->get_results()) {
ov_outputs.push_back(result_node->input_value(0));
}
}
void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model,
std::shared_ptr<ov::Model>& ov_model) {
OpMap ng_op_map;
ov::ParameterVector params;
ov::ResultVector results;
const auto& model_tf = std::dynamic_pointer_cast<InputModel>(input_model);
FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into OV Model");
const auto& operation_places = model_tf->get_op_places();
const auto& model_inputs = model_tf->get_inputs();
const auto& model_outputs = model_tf->get_outputs();
const auto& model_frozen_inputs = model_tf->get_tensor_values();
// fill ng_op_map with Constant outputs for frozen inputs
for (const auto& frozen_input : model_frozen_inputs) {
const auto& frozen_input_name = frozen_input.first;
const auto& frozen_input_value = frozen_input.second;
FRONT_END_GENERAL_CHECK(ng_op_map.count(frozen_input_name) == 0,
"Input with frozen value has been already met: " + frozen_input_name);
ng_op_map[frozen_input_name] = {frozen_input_value};
}
// create parameter nodes for all tensor places corresponding to inputs
for (const auto& input_place : model_inputs) {
FRONT_END_GENERAL_CHECK(input_place->get_names().size() == 1, "Input place must have one name.");
auto input_name = input_place->get_names()[0];
if (ng_op_map.count(input_name)) {
// probably this input is frozen
continue;
}
const auto& input_tensor_place = std::dynamic_pointer_cast<TensorPlace>(input_place);
auto input_shape = input_tensor_place->get_partial_shape();
auto input_type = input_tensor_place->get_element_type();
// in case of cutting the graph, types of custom inputs can be undefined;
// according to the MO help, fp32 is used by default in such cases
if (input_type == element::undefined) {
input_type = element::f32;
}
auto param = std::make_shared<ov::opset8::Parameter>(input_type, input_shape);
set_node_name(input_name, param);
params.push_back(param);
ng_op_map[input_name] = {param};
}
// create the OV ops from TensorFlow ops
for (const auto& operation_place : operation_places) {
auto operation_decoder = operation_place->get_decoder();
auto operation_name = operation_place->get_names()[0];
// output for parameter nodes has been already generated
if (ng_op_map.count(operation_name)) {
continue;
}
// prepare a list of OV node inputs for each node
ov::OutputVector ov_inputs;
size_t operation_input_size = operation_decoder->get_input_size();
if (operation_decoder->get_op_type() == "NextIteration") {
// we expect no inputs for NextIteration because we break up the cycle in InputModel
operation_input_size = 0;
}
for (size_t input_port_idx = 0; input_port_idx < operation_input_size; ++input_port_idx) {
// TODO: Implement more general approach. Skipping Constants that have input edges
if (operation_decoder->get_op_type() == "Const") {
break;
}
std::string producer_name;
size_t producer_port_idx;
try {
operation_decoder->get_input_node(input_port_idx, producer_name, producer_port_idx);
} catch (const std::exception&) {
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(input_port_idx) +
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
producer_name + "', expected input port index: " + std::to_string(producer_port_idx) +
'\n');
}
// skip conditional edges that must be resolved before operation translation
// now we can meet them because we still work with TensorFlow protobuf
if (is_conditional_edge(producer_name)) {
continue;
}
// TODO: re-implement the logic below once Place graph structure is implemented
// Using the Place graph structure (OpPlace, In/OutPortPlace places and their connections) can give
// names of ports and operations that can be used for a further existence check in ng_op_map.
// Check whether the output vector for a place has already been defined; the order of this check is important:
// it moves from places corresponding to the input port of the current operation node to the output ports of the
// original producers
if (ng_op_map.count(std::to_string(input_port_idx) + ":" + operation_name)) {
const auto& input_outputs_vector = ng_op_map.at(std::to_string(input_port_idx) + ":" + operation_name);
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
"Input created with pruning must have one output");
ov_inputs.push_back(input_outputs_vector.at(0));
} else if (ng_op_map.count(producer_name + ":" + std::to_string(producer_port_idx))) {
const auto& input_outputs_vector =
ng_op_map.at(producer_name + ":" + std::to_string(producer_port_idx));
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
"Input created with pruning must have one output");
ov_inputs.push_back(input_outputs_vector.at(0));
} else if (ng_op_map.count(producer_name)) {
const auto& input_outputs_vector = ng_op_map.at(producer_name);
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() > producer_port_idx,
"Input created with pruning must have one output");
ov_inputs.push_back(input_outputs_vector.at(producer_port_idx));
} else {
FRONT_END_GENERAL_CHECK(false,
"No input is found for node \"" + operation_name + "\" by port " +
std::to_string(producer_port_idx));
}
}
// generate OV node output vector for the current operation node
ov::OutputVector ov_outputs;
bool is_converted = false;
auto operation_type = operation_decoder->get_op_type();
try {
if (m_translator_map->count(operation_type)) {
auto translator = m_translator_map->at(operation_decoder->get_op_type());
NodeContext node_context(operation_decoder, ov_inputs, this);
ov_outputs = translator(node_context);
is_converted = true;
} else if (auto body_ov_model = get_body_ov_model(operation_type)) {
inject_body_model(body_ov_model, operation_type, ov_inputs, ov_outputs);
// set output tensor names
for (size_t idx = 0; idx < ov_outputs.size(); ++idx) {
ov_outputs[idx].get_tensor().set_names({operation_name + ":" + std::to_string(idx)});
}
is_converted = true;
}
FRONT_END_OP_CONVERSION_CHECK(is_converted, "No translator found for " + operation_type + " node.");
} catch (...) {
if (m_fail_fast) {
// in case of decode, unsupported operation will be converted to FrameworkNode
if (m_telemetry && !is_converted) {
// send event about which operation is not supported for conversion
m_telemetry_data->push_back(
std::make_pair<std::string, std::string>("error_cause", "tf_" + operation_type));
}
// re-throw any exception
throw;
} else {
auto ng_node = std::make_shared<FrameworkNode>(operation_decoder,
ov_inputs,
operation_place->get_output_ports().size());
set_node_name(operation_name, ng_node);
ov_outputs = ng_node->outputs();
}
}
// register OV node outputs in the map for new operation node
for (const auto& output : ov_outputs) {
if (auto result = std::dynamic_pointer_cast<ov::opset10::Result>(output.get_node_shared_ptr())) {
// do not add RetVal type operation to ng_op_map
results.push_back(result);
} else {
auto param = std::dynamic_pointer_cast<ov::opset8::Parameter>(output.get_node_shared_ptr());
// avoid duplicating Parameter nodes if they are already in the Parameters vector
if (param && operation_decoder->get_op_type() != "Identity" &&
std::find(params.begin(), params.end(), param) == params.end()) {
params.push_back(param);
}
ng_op_map[operation_name].push_back(output);
}
}
}
// create Result nodes for all model outputs
if (results.empty()) {
for (const auto& model_output : model_outputs) {
auto model_output_tensor_place = std::dynamic_pointer_cast<TensorPlace>(model_output);
auto model_output_name = model_output_tensor_place->get_names()[0];
std::string operation_name;
std::string port_type;
size_t port_index;
ov::frontend::tensorflow::extract_operation_name_and_port(model_output_name,
operation_name,
port_index,
port_type);
if (port_type == "none") {
for (const auto& node_output : ng_op_map[operation_name]) {
auto result_node = std::make_shared<ov::opset8::Result>(node_output);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
}
} else if (port_type == "out") {
const auto& node_outputs = ng_op_map[operation_name];
FRONT_END_GENERAL_CHECK(node_outputs.size() > port_index,
"Output port with index " + std::to_string(port_index) + " of " +
operation_name + "node specified as custom output does not exist");
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[port_index]);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
} else if (port_type == "in") {
// TODO: avoid this traversing by having a map for OpPlace objects, for example
std::shared_ptr<OpPlace> operation_place = nullptr;
for (const auto& op_place : operation_places) {
FRONT_END_GENERAL_CHECK(!op_place->get_names().empty(), "No names for OpPlace found.");
if (op_place->get_names()[0] == operation_name) {
operation_place = op_place;
}
}
FRONT_END_GENERAL_CHECK(operation_place, "There is no operation place with a name: " + operation_name);
auto operation_decoder = operation_place->get_decoder();
// get to know a producer node and by which its output port data is generated
std::string producer_name;
size_t producer_port_idx;
try {
operation_decoder->get_input_node(port_index, producer_name, producer_port_idx);
} catch (const std::exception&) {
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(port_index) +
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
producer_name +
"', expected input port index: " + std::to_string(producer_port_idx) + '\n');
}
// add Result node for this producer output port
const auto& node_outputs = ng_op_map[producer_name];
FRONT_END_GENERAL_CHECK(node_outputs.size() > producer_port_idx,
"Output port with index " + std::to_string(producer_port_idx) + " of " +
producer_name + "node specified as custom output does not exist");
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[producer_port_idx]);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
}
}
}
// TODO: it may be a redundant step since model_outputs is filled in the InputModel constructor
// find all terminal nodes in OV graph to complete list of results
if (results.empty()) {
for (const auto& node_output_vector : ng_op_map) {
for (size_t output_ind = 0; output_ind < node_output_vector.second.size(); ++output_ind) {
auto output = node_output_vector.second[output_ind];
if (output.get_target_inputs().empty() &&
!std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
auto model_output_name =
output.get_node_shared_ptr()->get_friendly_name() + ":" + std::to_string(output_ind);
auto result_node = std::make_shared<ov::opset8::Result>(output);
result_node->set_friendly_name(model_output_name);
results.push_back(result_node);
}
}
}
}
// TODO: reorder results and params according to indices given in RT info (if any)
// create the OV Model
ov_model = std::make_shared<ov::Model>(results, params, m_model_name);
}
std::shared_ptr<ov::Model> TranslateSession::get_body_ov_model(const std::string& body_graph_name) {
std::shared_ptr<ov::Model> body_model = nullptr;
auto input_model = std::dynamic_pointer_cast<InputModel>(m_input_model);
if (m_cached_body_models->count(body_graph_name)) {
// check if such body graph has been converted before
// re-use it from the cache for further injection
// create new instance of the required body model
// since it will be modified by injection
auto cached_body_model = m_cached_body_models->at(body_graph_name);
body_model = cached_body_model->clone();
} else if (auto body_input_model = input_model->get_body_input_model(body_graph_name)) {
// try to find a function by name in the model library
translate_graph(body_input_model, body_model);
// save new instance of body_model in the cache of body models
// before its injection into the parent graph
// before caching, erase tensor names from the body graph
// otherwise, it can lead to tensor name conflicts
for (const auto& op : body_model->get_ordered_ops()) {
for (size_t ind = 0; ind < op->get_output_size(); ++ind) {
op->get_output_tensor(ind).set_names({});
}
}
auto cached_body_model = body_model->clone();
update_cached_body_models(body_graph_name, cached_body_model);
}
return body_model;
}
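
To make inject_body_model concrete, here is a hedged sketch on a hypothetical one-op body graph (session and parent_input are assumed to be in scope): each body Parameter output is replaced by the matching parent-graph input, and each body Result's input value becomes a call output.

    // hypothetical body graph: f(x) = Relu(x)
    auto body_x = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{2});
    auto body_relu = std::make_shared<ov::opset8::Relu>(body_x);
    auto body_result = std::make_shared<ov::opset8::Result>(body_relu);
    auto body = std::make_shared<ov::Model>(ov::ResultVector{body_result}, ov::ParameterVector{body_x});
    // parent_input is an existing ov::Output<ov::Node> in the parent graph
    ov::OutputVector call_outputs;
    session.inject_body_model(body, "MyFunc", {parent_input}, call_outputs);
    // call_outputs[0] is Relu's output, now consuming parent_input directly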


@@ -0,0 +1,57 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/frontend/input_model.hpp"
#include "openvino/frontend/tensorflow/node_context.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
using CachedBodyModelsType = std::unordered_map<std::string, std::shared_ptr<const ov::Model>>;
using TelemetryDataType = std::vector<std::pair<std::string, std::string>>;
/// For each call of the convert or decode method of FrontEnd, one TranslateSession object is created to save data for
/// the translation session: telemetry statistics, a cache of converted body graph models, and operation translators
/// (including extensions) registered for this translation session.
class TranslateSession {
public:
TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
const std::shared_ptr<TranslatorDictionaryType>& translator_map,
const std::string& model_name,
bool fail_fast,
bool telemetry);
std::shared_ptr<ov::Model> get_converted_model();
std::shared_ptr<TelemetryDataType> get_telemetry_data() const;
void translate_graph(const ov::frontend::InputModel::Ptr& input_model, std::shared_ptr<ov::Model>& ov_model);
void inject_body_model(std::shared_ptr<ov::Model> body_model,
const std::string& operation_type,
const ov::OutputVector& ov_inputs,
ov::OutputVector& ov_outputs);
std::shared_ptr<ov::Model> get_body_ov_model(const std::string& body_graph_name);
private:
const ov::frontend::InputModel::Ptr m_input_model;
const bool m_fail_fast;
const bool m_telemetry;
const std::shared_ptr<TranslatorDictionaryType> m_translator_map;
const std::string m_model_name;
std::shared_ptr<CachedBodyModelsType> m_cached_body_models;
std::shared_ptr<TelemetryDataType> m_telemetry_data;
std::shared_ptr<ov::Model> m_ov_model;
void update_cached_body_models(const std::string& operation_type,
const std::shared_ptr<const ov::Model>& cached_body_model) {
m_cached_body_models->insert(std::make_pair(operation_type, cached_body_model));
}
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov
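
Typical usage, as in FrontEnd::convert above: one session is constructed per conversion call and queried for the result. A condensed sketch (input_model and m_op_translators come from the frontend):

    auto translator_map = std::make_shared<TranslatorDictionaryType>(m_op_translators);
    TranslateSession translate_session(input_model, translator_map, "TensorFlow_Frontend_IR", true /* fail_fast */, false /* telemetry */);
    std::shared_ptr<ov::Model> converted = translate_session.get_converted_model();
    // subsequent calls return the same cached ov::Model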


@@ -114,3 +114,92 @@ TEST_F(TransformationTestsF, ModelWithSwishF32BodyGraph) {
model_ref = make_shared<Model>(OutputVector{sigmoid2}, ParameterVector{x});
}
}
TEST_F(TransformationTestsF, PartitionedCall) {
{
model = convert_model("partitioned_call/partitioned_call.pb");
// need to call shape inference since body graphs can be injected with undefined shapes
model->validate_nodes_and_infer_types();
}
{
auto x = make_shared<Parameter>(i32, Shape{2});
auto y = make_shared<Parameter>(i32, Shape{1});
auto sub = make_shared<Subtract>(x, y);
auto const_pow = make_shared<Constant>(i32, Shape{}, 2);
auto pow = make_shared<Power>(sub, const_pow);
model_ref = make_shared<Model>(OutputVector{pow}, ParameterVector{x, y});
}
}
TEST_F(TransformationTestsF, ModelWithIf) {
{ model = convert_model("model_with_if/model_with_if.pb"); }
{
// create then branch body graph
auto then_x = make_shared<Parameter>(i32, Shape{2});
auto then_y = make_shared<Parameter>(i32, Shape{1});
auto add = make_shared<Add>(then_x, then_y);
auto then_result = make_shared<Result>(add);
auto then_model = make_shared<Model>(OutputVector{then_result}, ParameterVector{then_x, then_y});
// create else branch body graph
auto else_x = make_shared<Parameter>(i32, Shape{2});
auto else_y = make_shared<Parameter>(i32, Shape{1});
auto sub = make_shared<Subtract>(else_x, else_y);
auto else_result = make_shared<Result>(sub);
auto else_model = make_shared<Model>(OutputVector{else_result}, ParameterVector{else_x, else_y});
// create the main graph
auto x = make_shared<Parameter>(i32, Shape{2});
auto y = make_shared<Parameter>(i32, Shape{1});
auto cond_const = make_shared<Constant>(i32, Shape{}, 10);
auto cond = make_shared<Greater>(x, cond_const);
auto if_op = make_shared<If>(cond);
if_op->set_then_body(then_model);
if_op->set_else_body(else_model);
if_op->set_input(x, then_x, else_x);
if_op->set_input(y, then_y, else_y);
if_op->set_output(then_result, else_result);
model_ref = make_shared<Model>(OutputVector{if_op}, ParameterVector{x, y});
}
}
TEST_F(TransformationTestsF, InjectedBodyAndIf) {
{
model = convert_model("injected_body_and_if/injected_body_and_if.pb");
// need to call shape inference since body graphs can be injected with undefined shapes
model->validate_nodes_and_infer_types();
}
{
// create then branch body graph
auto then_x = make_shared<Parameter>(i32, Shape{2});
auto then_y = make_shared<Parameter>(i32, Shape{1});
auto add = make_shared<Add>(then_x, then_y);
auto then_result = make_shared<Result>(add);
auto then_model = make_shared<Model>(OutputVector{then_result}, ParameterVector{then_x, then_y});
// create else branch body graph
auto else_x = make_shared<Parameter>(i32, Shape{2});
auto else_y = make_shared<Parameter>(i32, Shape{1});
auto sub = make_shared<Subtract>(else_x, else_y);
auto pow_const = make_shared<Constant>(i32, Shape{}, 2);
auto pow = make_shared<Power>(sub, pow_const);
auto else_result = make_shared<Result>(pow);
auto else_model = make_shared<Model>(OutputVector{else_result}, ParameterVector{else_x, else_y});
// create the main graph
auto x = make_shared<Parameter>(i32, Shape{2});
auto y = make_shared<Parameter>(i32, Shape{1});
auto cond_const = make_shared<Constant>(i32, Shape{}, 10);
auto cond = make_shared<Greater>(x, cond_const);
auto if_op = make_shared<If>(cond);
if_op->set_then_body(then_model);
if_op->set_else_body(else_model);
if_op->set_input(x, then_x, else_x);
if_op->set_input(y, then_y, else_y);
if_op->set_output(then_result, else_result);
model_ref = make_shared<Model>(OutputVector{if_op}, ParameterVector{x, y});
}
}


@@ -0,0 +1,368 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
}
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 10
}
}
}
}
node {
name: "Greater"
op: "Greater"
input: "x"
input: "Const"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "If"
op: "If"
input: "Greater"
input: "x"
input: "y"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_INT32
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "else_branch"
value {
func {
name: "else_branch_func_AthmcAbLnco"
}
}
}
attr {
key: "output_shapes"
value {
list {
}
}
}
attr {
key: "then_branch"
value {
func {
name: "then_branch_func_mdn8Hcdd6RQ"
}
}
}
}
node {
name: "init"
op: "NoOp"
}
library {
function {
signature {
name: "then_branch_func_mdn8Hcdd6RQ"
input_arg {
name: "x"
type: DT_INT32
}
input_arg {
name: "y"
type: DT_INT32
}
output_arg {
name: "add"
type: DT_INT32
}
}
node_def {
name: "add_0"
op: "AddV2"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "add"
}
}
ret {
key: "add"
value: "add_0:z:0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
arg_attr {
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
}
function {
signature {
name: "else_branch_func_AthmcAbLnco"
input_arg {
name: "x"
type: DT_INT32
}
input_arg {
name: "y"
type: DT_INT32
}
output_arg {
name: "aux_func_gmpkxbsu4wi"
type: DT_INT32
}
}
node_def {
name: "sub"
op: "Sub"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "sub"
}
}
node_def {
name: "aux_func_GmpkxbsU4WI"
op: "aux_func_GmpkxbsU4WI"
input: "sub:z:0"
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
experimental_debug_info {
original_node_names: "aux_func_GmpkxbsU4WI"
}
}
ret {
key: "aux_func_gmpkxbsu4wi"
value: "aux_func_GmpkxbsU4WI:pow:0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
arg_attr {
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
}
function {
signature {
name: "aux_func_GmpkxbsU4WI"
input_arg {
name: "x"
type: DT_INT32
}
output_arg {
name: "pow"
type: DT_INT32
}
}
node_def {
name: "pow/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 2
}
}
}
experimental_debug_info {
original_node_names: "pow/y"
}
}
node_def {
name: "pow_0"
op: "Pow"
input: "x"
input: "pow/y:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "pow"
}
}
ret {
key: "pow"
value: "pow_0:z:0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
arg_attr {
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
}
}
versions {
producer: 808
min_consumer: 12
}


@@ -0,0 +1,35 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
from tensorflow.python.framework import function
tf.reset_default_graph()
@function.Defun(tf.int32)
def aux_func(x):
return x ** 2
@function.Defun(tf.int32, tf.int32)
def then_branch_func(x, y):
return x + y
@function.Defun(tf.int32, tf.int32)
def else_branch_func(x, y):
return aux_func(x - y)
with tf.Session() as sess:
x = tf.placeholder(tf.int32, [2], 'x')
y = tf.placeholder(tf.int32, [1], 'y')
const_cond = tf.constant(10, dtype=tf.int32)
cond = tf.raw_ops.Greater(x=x, y=const_cond)
if_op = tf.raw_ops.If(cond=cond, input=[x, y], Tout=[tf.int32], then_branch=then_branch_func,
else_branch=else_branch_func)
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'injected_body_and_if.pbtxt', as_text=True)


@@ -0,0 +1,278 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
}
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 10
}
}
}
}
node {
name: "Greater"
op: "Greater"
input: "x"
input: "Const"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "If"
op: "If"
input: "Greater"
input: "x"
input: "y"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_INT32
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "else_branch"
value {
func {
name: "else_branch_func_Fw4jHLGozIk"
}
}
}
attr {
key: "output_shapes"
value {
list {
}
}
}
attr {
key: "then_branch"
value {
func {
name: "then_branch_func_mdn8Hcdd6RQ"
}
}
}
}
node {
name: "init"
op: "NoOp"
}
library {
function {
signature {
name: "then_branch_func_mdn8Hcdd6RQ"
input_arg {
name: "x"
type: DT_INT32
}
input_arg {
name: "y"
type: DT_INT32
}
output_arg {
name: "add"
type: DT_INT32
}
}
node_def {
name: "add_0"
op: "AddV2"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "add"
}
}
ret {
key: "add"
value: "add_0:z:0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
arg_attr {
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
}
function {
signature {
name: "else_branch_func_Fw4jHLGozIk"
input_arg {
name: "x"
type: DT_INT32
}
input_arg {
name: "y"
type: DT_INT32
}
output_arg {
name: "sub"
type: DT_INT32
}
}
node_def {
name: "sub_0"
op: "Sub"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "sub"
}
}
ret {
key: "sub"
value: "sub_0:z:0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
arg_attr {
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
}
}
}
}
versions {
producer: 808
min_consumer: 12
}


@@ -0,0 +1,30 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
from tensorflow.python.framework import function
tf.reset_default_graph()
@function.Defun(tf.int32, tf.int32)
def then_branch_func(x, y):
return x + y
@function.Defun(tf.int32, tf.int32)
def else_branch_func(x, y):
return x - y
with tf.Session() as sess:
x = tf.placeholder(tf.int32, [2], 'x')
y = tf.placeholder(tf.int32, [1], 'y')
const_cond = tf.constant(10, dtype=tf.int32)
cond = tf.raw_ops.Greater(x=x, y=const_cond)
if_op = tf.raw_ops.If(cond=cond, input=[x, y], Tout=[tf.int32], then_branch=then_branch_func,
else_branch=else_branch_func)
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'model_with_if.pbtxt', as_text=True)


@@ -0,0 +1,240 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "y"
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
}
node {
name: "sub"
op: "Sub"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "PartitionedCall"
op: "PartitionedCall"
input: "sub"
attr {
key: "Tin"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "_collective_manager_ids"
value {
list {
}
}
}
attr {
key: "_read_only_resource_inputs"
value {
list {
}
}
}
attr {
key: "config"
value {
s: ""
}
}
attr {
key: "config_proto"
value {
s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001\202\001\000"
}
}
attr {
key: "executor_type"
value {
s: ""
}
}
attr {
key: "f"
value {
func {
name: "__inference_second_func_14"
}
}
}
}
node {
name: "Identity"
op: "Identity"
input: "PartitionedCall"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
library {
function {
signature {
name: "__inference_second_func_14"
input_arg {
name: "x"
type: DT_INT32
}
output_arg {
name: "identity"
type: DT_INT32
}
}
node_def {
name: "pow/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 2
}
}
}
experimental_debug_info {
original_node_names: "pow/y"
}
}
node_def {
name: "pow"
op: "Pow"
input: "x"
input: "pow/y:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "pow"
}
}
node_def {
name: "Identity"
op: "Identity"
input: "pow:z:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "Identity"
}
}
ret {
key: "identity"
value: "Identity:output:0"
}
attr {
key: "_construction_context"
value {
s: "kEagerRuntime"
}
}
arg_attr {
value {
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 2
}
}
}
}
}
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
}
}
}
}
versions {
producer: 808
min_consumer: 12
}


@@ -0,0 +1,18 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
@tf.function
def second_func(x):
return x ** 2
@tf.function
def first_func(x, y):
return second_func(x - y)
graph_def = first_func.get_concrete_function(tf.constant([6, 3]), tf.constant([7])).graph.as_graph_def()
tf.io.write_graph(graph_def, '.', 'partitioned_call.pbtxt', as_text=True)