[TF FE] Support If and PartitionedCall operations (#14910)
* [TF FE] Support If and PartitionedCall operations Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix build * Fix frontend wrapper for tests * Erase tensor names in body graph before caching * Apply code-review feedback: recover m_op_translators in Frontend * Rename test models * Rename unit-tests * Correct scripts for test model generation Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
efb602e13b
commit
f13e7e1352
@ -21,8 +21,6 @@
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
using CachedBodyModelsType = std::unordered_map<std::string, std::shared_ptr<const ov::Model>>;
|
||||
|
||||
class TENSORFLOW_API FrontEnd : public ov::frontend::FrontEnd {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<FrontEnd>;
|
||||
@ -68,13 +66,6 @@ protected:
|
||||
|
||||
ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override;
|
||||
|
||||
void translate_graph(const ov::frontend::InputModel::Ptr& model,
|
||||
const std::string& model_name,
|
||||
bool fail_fast,
|
||||
bool no_conversion,
|
||||
std::shared_ptr<ov::Model>& ov_model,
|
||||
const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const;
|
||||
|
||||
TelemetryExtension::Ptr m_telemetry;
|
||||
std::vector<DecoderTransformationExtension::Ptr> m_transformation_extensions;
|
||||
std::vector<ConversionExtensionBase::Ptr> m_conversion_extensions;
|
||||
|
@ -12,15 +12,19 @@
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
class TranslateSession;
|
||||
|
||||
/// Keep necessary data for a single node in the original FW graph to facilitate
|
||||
/// conversion process in the rules code.
|
||||
class NodeContext : public ov::frontend::NodeContext {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<NodeContext>;
|
||||
NodeContext(const std::shared_ptr<DecoderBase>& decoder, const OutputVector& inputs)
|
||||
NodeContext(const std::shared_ptr<DecoderBase>& decoder,
|
||||
const OutputVector& inputs,
|
||||
TranslateSession* translate_session = nullptr)
|
||||
: ov::frontend::NodeContext(decoder->get_op_type()),
|
||||
m_decoder(decoder),
|
||||
m_translate_session(translate_session),
|
||||
m_inputs(inputs) {}
|
||||
|
||||
/// Detects if there is at least one input attached with a given name
|
||||
@ -51,10 +55,16 @@ public:
|
||||
return res;
|
||||
}
|
||||
|
||||
/// \brief Get a pointer to TranslateSession object
|
||||
TranslateSession* get_translate_session() const {
|
||||
return m_translate_session;
|
||||
}
|
||||
|
||||
private:
|
||||
ov::Any apply_additional_conversion_rules(const ov::Any& data, const std::type_info& type_info) const override;
|
||||
|
||||
std::shared_ptr<DecoderBase> m_decoder;
|
||||
TranslateSession* m_translate_session;
|
||||
const OutputVector& m_inputs;
|
||||
};
|
||||
|
||||
|
@ -263,10 +263,11 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
|
||||
name,
|
||||
"' attribute is not supported.");
|
||||
case ::tensorflow::AttrValue::ValueCase::kFunc:
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"Conversion from Tensorflow to OpenVINO data type failed: Function type for '",
|
||||
name,
|
||||
"' attribute is not supported.");
|
||||
// attrs[0].func() returns NameAttrList object from which
|
||||
// we retrieve the function name
|
||||
// Further, InputModel object is created for FunctionDef with this name
|
||||
// and is converted to ov::Model object.
|
||||
return attrs[0].func().name();
|
||||
default:
|
||||
FRONT_END_GENERAL_CHECK(false, "Conversion from Tensorflow to OpenVINO data type failed.");
|
||||
}
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "so_extension.hpp"
|
||||
#include "tf_framework_node.hpp"
|
||||
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
|
||||
#include "translate_session.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
@ -45,314 +46,10 @@ void translate_framework_node(const std::shared_ptr<FrameworkNode>& node,
|
||||
old_output->replace(*new_output);
|
||||
}
|
||||
}
|
||||
|
||||
void inject_body_model(std::shared_ptr<ov::Model> body_model,
|
||||
const std::string& operation_type,
|
||||
const ov::OutputVector& ov_inputs,
|
||||
ov::OutputVector& ov_outputs) {
|
||||
ov_outputs.clear();
|
||||
auto body_parameters = body_model->get_parameters();
|
||||
FRONT_END_GENERAL_CHECK(body_parameters.size() == ov_inputs.size(),
|
||||
"[TensorFlow Error] Internal error or incorrect input models: number of "
|
||||
"inputs and arguments to the function " +
|
||||
operation_type + " do not match.");
|
||||
for (size_t param_ind = 0; param_ind < body_parameters.size(); ++param_ind) {
|
||||
body_parameters[param_ind]->output(0).replace(ov_inputs[param_ind]);
|
||||
}
|
||||
for (const auto& result_node : body_model->get_results()) {
|
||||
ov_outputs.push_back(result_node->input_value(0));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FrontEnd::FrontEnd() : m_op_translators(tensorflow::op::get_supported_ops()) {}
|
||||
|
||||
void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
|
||||
const std::string& model_name,
|
||||
bool fail_fast,
|
||||
bool no_conversion,
|
||||
std::shared_ptr<ov::Model>& ov_model,
|
||||
const std::shared_ptr<CachedBodyModelsType>& cached_body_models) const {
|
||||
// a map from operation names to generated OV Output<TFNodeDecoder>
|
||||
tensorflow::OpMap ng_op_map;
|
||||
|
||||
ov::ParameterVector params;
|
||||
ov::ResultVector results;
|
||||
const auto& model_tf = std::dynamic_pointer_cast<InputModel>(model);
|
||||
FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into OV Model");
|
||||
const auto& operation_places = model_tf->get_op_places();
|
||||
const auto& model_inputs = model_tf->get_inputs();
|
||||
const auto& model_outputs = model_tf->get_outputs();
|
||||
const auto& model_frozen_inputs = model_tf->get_tensor_values();
|
||||
TranslatorDictionaryType translate_map;
|
||||
|
||||
const auto& TRANSLATE_OP_MAP = m_op_translators;
|
||||
if (no_conversion) {
|
||||
const std::set<std::string> required_types{"Placeholder", "NoOp"};
|
||||
for (const auto& name : required_types) {
|
||||
translate_map.emplace(name, TRANSLATE_OP_MAP.at(name));
|
||||
}
|
||||
} else {
|
||||
translate_map.insert(TRANSLATE_OP_MAP.begin(), TRANSLATE_OP_MAP.end());
|
||||
}
|
||||
|
||||
// fill ng_op_map with Constant outputs for frozen inputs
|
||||
for (const auto& frozen_input : model_frozen_inputs) {
|
||||
const auto& frozen_input_name = frozen_input.first;
|
||||
const auto& frozen_input_value = frozen_input.second;
|
||||
FRONT_END_GENERAL_CHECK(ng_op_map.count(frozen_input_name) == 0,
|
||||
"Input with frozen value has been already met: " + frozen_input_name);
|
||||
ng_op_map[frozen_input_name] = {frozen_input_value};
|
||||
}
|
||||
// create parameter nodes for all tensor places corresponding to inputs
|
||||
for (const auto& input_place : model_inputs) {
|
||||
FRONT_END_GENERAL_CHECK(input_place->get_names().size() == 1, "Input place must have one name.");
|
||||
auto input_name = input_place->get_names()[0];
|
||||
if (ng_op_map.count(input_name)) {
|
||||
// probably this input is frozen
|
||||
continue;
|
||||
}
|
||||
const auto& input_tensor_place = std::dynamic_pointer_cast<TensorPlace>(input_place);
|
||||
auto input_shape = input_tensor_place->get_partial_shape();
|
||||
auto input_type = input_tensor_place->get_element_type();
|
||||
|
||||
// in case of cutting graph, types of custom inputs can be undefined,
|
||||
// according to MO help, fp32 is used by default in such cases
|
||||
if (input_type == element::undefined) {
|
||||
input_type = element::f32;
|
||||
}
|
||||
|
||||
auto param = std::make_shared<ov::opset8::Parameter>(input_type, input_shape);
|
||||
set_node_name(input_name, param);
|
||||
params.push_back(param);
|
||||
ng_op_map[input_name] = {param};
|
||||
}
|
||||
|
||||
// create the OV ops from TensorFlow ops
|
||||
for (const auto& operation_place : operation_places) {
|
||||
auto operation_decoder = operation_place->get_decoder();
|
||||
auto operation_name = operation_place->get_names()[0];
|
||||
// output for parameter nodes has been already generated
|
||||
if (ng_op_map.count(operation_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// prepare a list of OV node inputs for each node
|
||||
ov::OutputVector ov_inputs;
|
||||
size_t operation_input_size = operation_decoder->get_input_size();
|
||||
|
||||
if (operation_decoder->get_op_type() == "NextIteration") {
|
||||
// we expect no inputs for NextIteration because we break-up the cycle in InputModel
|
||||
operation_input_size = 0;
|
||||
}
|
||||
for (size_t input_port_idx = 0; input_port_idx < operation_input_size; ++input_port_idx) {
|
||||
// TODO: Implement more general approach. Skipping Constants that have input edges
|
||||
if (operation_decoder->get_op_type() == "Const") {
|
||||
break;
|
||||
}
|
||||
std::string producer_name;
|
||||
size_t producer_port_idx;
|
||||
try {
|
||||
operation_decoder->get_input_node(input_port_idx, producer_name, producer_port_idx);
|
||||
} catch (const std::exception&) {
|
||||
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(input_port_idx) +
|
||||
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
|
||||
producer_name + "', expected input port index: " + std::to_string(producer_port_idx) +
|
||||
'\n');
|
||||
}
|
||||
|
||||
// skip conditional edges that must be resolved before operation translation
|
||||
// now we can meet them because we still work with TensorFlow protobuf
|
||||
if (is_conditional_edge(producer_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: re-implement the logic below once Place graph structure is implemented
|
||||
// Using Place graph structure (OpPlace, In/OutPortPlace places and their connections) can give
|
||||
// names of ports and operations that can be used for further check about existence in ng_op_map
|
||||
|
||||
// check if output vector for places have been already defined and the order of this check is important
|
||||
// it moves from places corresponding to input port of the current operation node to output port of original
|
||||
// producers
|
||||
if (ng_op_map.count(std::to_string(input_port_idx) + ":" + operation_name)) {
|
||||
const auto& input_outputs_vector = ng_op_map.at(std::to_string(input_port_idx) + ":" + operation_name);
|
||||
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
|
||||
"Input created with pruning must have one output");
|
||||
ov_inputs.push_back(input_outputs_vector.at(0));
|
||||
} else if (ng_op_map.count(producer_name + ":" + std::to_string(producer_port_idx))) {
|
||||
const auto& input_outputs_vector =
|
||||
ng_op_map.at(producer_name + ":" + std::to_string(producer_port_idx));
|
||||
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
|
||||
"Input created with pruning must have one output");
|
||||
ov_inputs.push_back(input_outputs_vector.at(0));
|
||||
} else if (ng_op_map.count(producer_name)) {
|
||||
const auto& input_outputs_vector = ng_op_map.at(producer_name);
|
||||
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() > producer_port_idx,
|
||||
"Input created with pruning must have one output");
|
||||
ov_inputs.push_back(input_outputs_vector.at(producer_port_idx));
|
||||
} else {
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"No input is found for node \"" + operation_name + "\" by port " +
|
||||
std::to_string(producer_port_idx));
|
||||
}
|
||||
}
|
||||
|
||||
// generate OV node output vector for the current operation node
|
||||
ov::OutputVector ov_outputs;
|
||||
bool is_converted = false;
|
||||
auto operation_type = operation_decoder->get_op_type();
|
||||
try {
|
||||
if (translate_map.count(operation_type)) {
|
||||
auto translator = translate_map[operation_decoder->get_op_type()];
|
||||
NodeContext node_context(operation_decoder, ov_inputs);
|
||||
ov_outputs = translator(node_context);
|
||||
is_converted = true;
|
||||
} else if (cached_body_models->count(operation_type)) {
|
||||
// check if such body graph has been converted before
|
||||
// re-use it from the cache for further injection
|
||||
|
||||
// create new instance of the required body model
|
||||
// since it will be modified by injection
|
||||
auto cached_body_model = cached_body_models->at(operation_type);
|
||||
auto body_model = cached_body_model->clone();
|
||||
inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
|
||||
is_converted = true;
|
||||
} else if (auto body_input_model = model_tf->get_body_input_model(operation_type)) {
|
||||
// try to find a function by name in the model library
|
||||
std::shared_ptr<ov::Model> body_model;
|
||||
translate_graph(body_input_model,
|
||||
operation_decoder->get_op_type(),
|
||||
fail_fast,
|
||||
no_conversion,
|
||||
body_model,
|
||||
cached_body_models);
|
||||
// save new instance of body_model in the cache of body models
|
||||
// before its injection into the parent graph
|
||||
auto cached_body_model = body_model->clone();
|
||||
cached_body_models->insert(std::make_pair(operation_type, cached_body_model));
|
||||
|
||||
inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
|
||||
is_converted = true;
|
||||
}
|
||||
FRONT_END_OP_CONVERSION_CHECK(is_converted, "No translator found for " + operation_type + " node.");
|
||||
} catch (...) {
|
||||
if (fail_fast) {
|
||||
// in case of decode, unsupported operation will be converted to FrameworkNode
|
||||
if (m_telemetry && translate_map.count(operation_decoder->get_op_type()) == 0) {
|
||||
// send event about which operation is not supported for conversion
|
||||
m_telemetry->send_event("error_cause", "tf_" + operation_decoder->get_op_type());
|
||||
}
|
||||
// re-throw any exception
|
||||
throw;
|
||||
} else {
|
||||
auto ng_node = std::make_shared<FrameworkNode>(operation_decoder,
|
||||
ov_inputs,
|
||||
operation_place->get_output_ports().size());
|
||||
set_node_name(operation_name, ng_node);
|
||||
ov_outputs = ng_node->outputs();
|
||||
}
|
||||
}
|
||||
|
||||
// register OV node outputs in the map for new operation node
|
||||
for (const auto& output : ov_outputs) {
|
||||
if (auto result = std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
|
||||
// do not add RetVal type operation to ng_op_map
|
||||
results.push_back(result);
|
||||
} else {
|
||||
auto param = std::dynamic_pointer_cast<ov::opset8::Parameter>(output.get_node_shared_ptr());
|
||||
// avoid duplicating Parameter nodes if they are already in the Parameters vector
|
||||
if (param && operation_decoder->get_op_type() != "Identity" &&
|
||||
std::find(params.begin(), params.end(), param) == params.end()) {
|
||||
params.push_back(param);
|
||||
}
|
||||
ng_op_map[operation_name].push_back(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create Result nodes for all model outputs
|
||||
for (const auto& model_output : model_outputs) {
|
||||
auto model_output_tensor_place = std::dynamic_pointer_cast<TensorPlace>(model_output);
|
||||
auto model_output_name = model_output_tensor_place->get_names()[0];
|
||||
std::string operation_name;
|
||||
std::string port_type;
|
||||
size_t port_index;
|
||||
ov::frontend::tensorflow::extract_operation_name_and_port(model_output_name,
|
||||
operation_name,
|
||||
port_index,
|
||||
port_type);
|
||||
|
||||
if (port_type == "none") {
|
||||
for (const auto& node_output : ng_op_map[operation_name]) {
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(node_output);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
}
|
||||
} else if (port_type == "out") {
|
||||
const auto& node_outputs = ng_op_map[operation_name];
|
||||
FRONT_END_GENERAL_CHECK(node_outputs.size() > port_index,
|
||||
"Output port with index " + std::to_string(port_index) + " of " + operation_name +
|
||||
"node specified as custom output does not exist");
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[port_index]);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
} else if (port_type == "in") {
|
||||
// TODO: avoid this traversing by having a map for OpPlace objects, for example
|
||||
std::shared_ptr<OpPlace> operation_place = nullptr;
|
||||
for (const auto& op_place : operation_places) {
|
||||
FRONT_END_GENERAL_CHECK(!op_place->get_names().empty(), "No names for OpPlace found.");
|
||||
if (op_place->get_names()[0] == operation_name) {
|
||||
operation_place = op_place;
|
||||
}
|
||||
}
|
||||
FRONT_END_GENERAL_CHECK(operation_place, "There is no operation place with a name: " + operation_name);
|
||||
auto operation_decoder = operation_place->get_decoder();
|
||||
|
||||
// get to know a producer node and by which its output port data is generated
|
||||
std::string producer_name;
|
||||
size_t producer_port_idx;
|
||||
try {
|
||||
operation_decoder->get_input_node(port_index, producer_name, producer_port_idx);
|
||||
} catch (const std::exception&) {
|
||||
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(port_index) +
|
||||
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
|
||||
producer_name + "', expected input port index: " + std::to_string(producer_port_idx) +
|
||||
'\n');
|
||||
}
|
||||
|
||||
// add Result node for this producer output port
|
||||
const auto& node_outputs = ng_op_map[producer_name];
|
||||
FRONT_END_GENERAL_CHECK(node_outputs.size() > producer_port_idx,
|
||||
"Output port with index " + std::to_string(producer_port_idx) + " of " +
|
||||
producer_name + "node specified as custom output does not exist");
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[producer_port_idx]);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
}
|
||||
}
|
||||
// find all terminal nodes in OV graph to complete list of results
|
||||
if (results.empty()) {
|
||||
for (const auto& node_output_vector : ng_op_map) {
|
||||
for (size_t output_ind = 0; output_ind < node_output_vector.second.size(); ++output_ind) {
|
||||
auto output = node_output_vector.second[output_ind];
|
||||
if (output.get_target_inputs().empty() &&
|
||||
!std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
|
||||
auto model_output_name =
|
||||
output.get_node_shared_ptr()->get_friendly_name() + ":" + std::to_string(output_ind);
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(output);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: reorder results and params according to indices given in RT info (if any)
|
||||
|
||||
// create the OV Model
|
||||
ov_model = std::make_shared<ov::Model>(results, params, model_name);
|
||||
}
|
||||
|
||||
/// \brief Check if FrontEndTensorflow can recognize model from given parts
|
||||
bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
|
||||
// TODO: Support other TensorFlow formats: SavedModel, .meta, checkpoint, pbtxt
|
||||
@ -430,9 +127,25 @@ std::shared_ptr<ov::Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr
|
||||
return function;
|
||||
}
|
||||
|
||||
// create a shared pointer to the cloned dictionary of translators
|
||||
auto translator_map = std::make_shared<TranslatorDictionaryType>(m_op_translators);
|
||||
|
||||
std::shared_ptr<ov::Model> f;
|
||||
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
|
||||
translate_graph(model_tf, "TensorFlow_Frontend_IR", true, false, f, cached_body_models);
|
||||
TranslateSession translate_session(model, translator_map, "TensorFlow_Frontend_IR", true, m_telemetry != nullptr);
|
||||
try {
|
||||
f = translate_session.get_converted_model();
|
||||
} catch (const std::exception&) {
|
||||
if (m_telemetry) {
|
||||
auto telemetry_data = translate_session.get_telemetry_data();
|
||||
if (telemetry_data) {
|
||||
// send event about which operation is not supported for conversion
|
||||
for (const auto& telemetry_item : *telemetry_data.get()) {
|
||||
m_telemetry->send_event(telemetry_item.first, telemetry_item.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
throw;
|
||||
}
|
||||
normalize(f);
|
||||
|
||||
for (const auto& node : f->get_ordered_ops()) {
|
||||
@ -464,18 +177,55 @@ std::shared_ptr<ov::Model> FrontEnd::convert_partially(const ov::frontend::Input
|
||||
return function;
|
||||
}
|
||||
|
||||
// create a shared pointer to the cloned dictionary of translators
|
||||
auto translator_map = std::make_shared<TranslatorDictionaryType>(m_op_translators);
|
||||
|
||||
std::shared_ptr<ov::Model> f;
|
||||
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
|
||||
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, false, f, cached_body_models);
|
||||
TranslateSession translate_session(model, translator_map, "TensorFlow_Frontend_IR", false, m_telemetry != nullptr);
|
||||
try {
|
||||
f = translate_session.get_converted_model();
|
||||
} catch (const std::exception&) {
|
||||
if (m_telemetry) {
|
||||
auto telemetry_data = translate_session.get_telemetry_data();
|
||||
if (telemetry_data) {
|
||||
// send event about which operation is not supported for conversion
|
||||
for (const auto& telemetry_item : *telemetry_data.get()) {
|
||||
m_telemetry->send_event(telemetry_item.first, telemetry_item.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
throw;
|
||||
}
|
||||
normalize(f);
|
||||
|
||||
return f;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> FrontEnd::decode(const ov::frontend::InputModel::Ptr& model) const {
|
||||
auto model_tf = std::dynamic_pointer_cast<InputModel>(model);
|
||||
auto translator_map = std::make_shared<TranslatorDictionaryType>();
|
||||
|
||||
const std::set<std::string> required_types{"Placeholder", "NoOp"};
|
||||
for (const auto& name : required_types) {
|
||||
translator_map->emplace(name, m_op_translators.at(name));
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> f;
|
||||
std::shared_ptr<CachedBodyModelsType> cached_body_models = std::make_shared<CachedBodyModelsType>();
|
||||
translate_graph(model_tf, "TensorFlow_Frontend_IR", false, true, f, cached_body_models);
|
||||
TranslateSession translate_session(model, translator_map, "TensorFlow_Frontend_IR", false, m_telemetry != nullptr);
|
||||
try {
|
||||
f = translate_session.get_converted_model();
|
||||
} catch (const std::exception&) {
|
||||
if (m_telemetry) {
|
||||
auto telemetry_data = translate_session.get_telemetry_data();
|
||||
if (telemetry_data) {
|
||||
// send event about which operation is not supported for conversion
|
||||
for (const auto& telemetry_item : *telemetry_data.get()) {
|
||||
m_telemetry->send_event(telemetry_item.first, telemetry_item.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
throw;
|
||||
}
|
||||
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "openvino/frontend/input_model.hpp"
|
||||
#include "openvino/frontend/tensorflow/graph_iterator.hpp"
|
||||
#include "place.hpp"
|
||||
#include "translate_session.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
@ -17,7 +18,7 @@ class OpPlace;
|
||||
class TensorPlace;
|
||||
|
||||
class InputModel : public ov::frontend::InputModel {
|
||||
friend class FrontEnd;
|
||||
friend class TranslateSession;
|
||||
class InputModelTFImpl;
|
||||
std::shared_ptr<InputModelTFImpl> _impl;
|
||||
|
||||
|
@ -23,7 +23,7 @@ OutputVector translate_input_arg_op(const NodeContext& node) {
|
||||
|
||||
OutputVector translate_output_arg_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"output_arg"});
|
||||
auto result = std::make_shared<Result>();
|
||||
auto result = std::make_shared<Result>(node.get_input(0));
|
||||
set_node_name(node.get_name(), result);
|
||||
return result->outputs();
|
||||
}
|
||||
|
87
src/frontends/tensorflow/src/op/if.cpp
Normal file
87
src/frontends/tensorflow/src/op/if.cpp
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "input_model.hpp"
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_if_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"If", "StatelessIf"});
|
||||
auto node_name = node.get_name();
|
||||
auto translate_session = node.get_translate_session();
|
||||
FRONT_END_GENERAL_CHECK(translate_session, "[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
|
||||
// retrieve body ov::Model for then and else branches
|
||||
auto then_branch_type = node.get_attribute<std::string>("then_branch");
|
||||
auto else_branch_type = node.get_attribute<std::string>("else_branch");
|
||||
auto then_branch_body = translate_session->get_body_ov_model(then_branch_type);
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
then_branch_body,
|
||||
"[TensorFlow Frontend] Internal error or incorrect input model. Cannot find branch body graph with name " +
|
||||
then_branch_type);
|
||||
auto else_branch_body = translate_session->get_body_ov_model(else_branch_type);
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
else_branch_body,
|
||||
"[TensorFlow Frontend] Internal error or incorrect input model. Cannot find branch body graph with name " +
|
||||
else_branch_type);
|
||||
|
||||
// get condition output
|
||||
auto cond = node.get_input(0);
|
||||
size_t input_size_t = node.get_input_size() - 1;
|
||||
auto then_params = then_branch_body->get_parameters();
|
||||
auto else_params = else_branch_body->get_parameters();
|
||||
FRONT_END_GENERAL_CHECK(input_size_t == then_params.size(),
|
||||
"[TensorFlow Frontend] Internal error or incorrect input model: number of inputs to If and "
|
||||
"number of inputs in then branch do not match.");
|
||||
FRONT_END_GENERAL_CHECK(input_size_t == else_params.size(),
|
||||
"[TensorFlow Frontend] Internal error or incorrect input model: number of inputs to If and "
|
||||
"number of inputs in else branch do not match.");
|
||||
|
||||
// create If operation and set body graphs
|
||||
auto if_op = std::make_shared<If>(cond);
|
||||
if_op->set_then_body(then_branch_body);
|
||||
if_op->set_else_body(else_branch_body);
|
||||
|
||||
// set inputs
|
||||
int input_size = static_cast<int>(input_size_t);
|
||||
for (int ind = 0; ind < input_size; ++ind) {
|
||||
auto curr_input = node.get_input(ind + 1);
|
||||
auto then_param = then_params[ind];
|
||||
auto else_param = else_params[ind];
|
||||
if_op->set_input(curr_input, then_param, else_param);
|
||||
}
|
||||
|
||||
// set outputs
|
||||
auto then_results = then_branch_body->get_results();
|
||||
auto else_results = else_branch_body->get_results();
|
||||
FRONT_END_GENERAL_CHECK(then_results.size() == else_results.size(),
|
||||
"[TensorFlow Frontend] Internal error or incorrect input model: number of result nodes in "
|
||||
"then and else branches do not match.");
|
||||
int output_size = static_cast<int>(then_results.size());
|
||||
for (int ind = 0; ind < output_size; ++ind) {
|
||||
if_op->set_output(then_results[ind], else_results[ind]);
|
||||
}
|
||||
|
||||
auto ov_outputs = if_op->outputs();
|
||||
|
||||
// set output tensor names
|
||||
for (size_t idx = 0; idx < ov_outputs.size(); ++idx) {
|
||||
ov_outputs[idx].get_tensor().set_names({node_name + ":" + std::to_string(idx)});
|
||||
}
|
||||
|
||||
return ov_outputs;
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
52
src/frontends/tensorflow/src/op/partitioned_call.cpp
Normal file
52
src/frontends/tensorflow/src/op/partitioned_call.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "input_model.hpp"
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_partitioned_call_op(const NodeContext& node) {
|
||||
default_op_checks(node, 0, {"PartitionedCall", "StatefulPartitionedCall"});
|
||||
auto node_name = node.get_name();
|
||||
auto translate_session = node.get_translate_session();
|
||||
FRONT_END_GENERAL_CHECK(translate_session, "[TensorFlow Frontend] Internal error: Translate session is nullptr.");
|
||||
auto operation_type = node.get_attribute<std::string>("f");
|
||||
|
||||
// prepare a vector of inputs
|
||||
OutputVector ov_inputs;
|
||||
int input_size = static_cast<int>(node.get_input_size());
|
||||
for (int ind = 0; ind < input_size; ++ind) {
|
||||
ov_inputs.push_back(node.get_input(ind));
|
||||
}
|
||||
|
||||
// try to retrieve ov::Model for body graph
|
||||
auto body_model = translate_session->get_body_ov_model(operation_type);
|
||||
FRONT_END_OP_CONVERSION_CHECK(
|
||||
body_model,
|
||||
"[TensorFlow Frontend] Internal error or incorrect input model: body graph is not found for " + operation_type +
|
||||
".");
|
||||
|
||||
// inject the body graph into the parent graph
|
||||
OutputVector ov_outputs;
|
||||
translate_session->inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
|
||||
|
||||
// set output tensor names
|
||||
for (size_t idx = 0; idx < ov_outputs.size(); ++idx) {
|
||||
set_out_name({node_name + ":" + std::to_string(idx)}, ov_outputs[idx]);
|
||||
}
|
||||
|
||||
return ov_outputs;
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -66,6 +66,7 @@ OP_CONVERTER(translate_identity_op);
|
||||
OP_CONVERTER(translate_identity_n_op);
|
||||
OP_CONVERTER(translate_input_arg_op);
|
||||
OP_CONVERTER(translate_output_arg_op);
|
||||
OP_CONVERTER(translate_if_op);
|
||||
OP_CONVERTER(translate_interpolate_op);
|
||||
OP_CONVERTER(translate_is_finite_op);
|
||||
OP_CONVERTER(translate_is_inf_op);
|
||||
@ -84,6 +85,7 @@ OP_CONVERTER(translate_mirror_pad_op);
|
||||
OP_CONVERTER(translate_non_max_suppression_op);
|
||||
OP_CONVERTER(translate_normalize_l2_op);
|
||||
OP_CONVERTER(translate_parallel_dynamic_stitch_op);
|
||||
OP_CONVERTER(translate_partitioned_call_op);
|
||||
OP_CONVERTER(translate_placeholder_op);
|
||||
OP_CONVERTER(translate_placeholder_with_default_op);
|
||||
OP_CONVERTER(translate_no_op);
|
||||
@ -248,6 +250,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"GatherNd", translate_gather_nd_op},
|
||||
{"Identity", translate_identity_op},
|
||||
{"IdentityN", translate_identity_n_op},
|
||||
{"If", translate_if_op},
|
||||
{"input_arg", translate_input_arg_op},
|
||||
{"output_arg", translate_output_arg_op},
|
||||
{"L2Loss", translate_l2_loss_op},
|
||||
@ -276,6 +279,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"PadV2", translate_padv2_op},
|
||||
{"DynamicStitch", translate_parallel_dynamic_stitch_op},
|
||||
{"ParallelDynamicStitch", translate_parallel_dynamic_stitch_op},
|
||||
{"PartitionedCall", translate_partitioned_call_op},
|
||||
{"Placeholder", translate_placeholder_op},
|
||||
{"PlaceholderWithDefault", translate_placeholder_with_default_op},
|
||||
{"PreventGradient", translate_identity_op},
|
||||
@ -314,6 +318,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"Square", translate_square_op},
|
||||
{"Squeeze", translate_squeeze_op},
|
||||
{"SpaceToBatchND", translate_space_to_batch_nd_op},
|
||||
{"StatefulPartitionedCall", translate_partitioned_call_op},
|
||||
{"StatelessIf", translate_if_op},
|
||||
{"StridedSlice", translate_strided_slice_op},
|
||||
{"Tile", translate_tile_op},
|
||||
{"TopK", translate_top_k_op},
|
||||
|
343
src/frontends/tensorflow/src/translate_session.cpp
Normal file
343
src/frontends/tensorflow/src/translate_session.cpp
Normal file
@ -0,0 +1,343 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "translate_session.hpp"
|
||||
|
||||
#include "input_model.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "tf_framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov::frontend::tensorflow;
|
||||
|
||||
TranslateSession::TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
|
||||
const std::shared_ptr<TranslatorDictionaryType>& translator_map,
|
||||
const std::string& model_name,
|
||||
bool fail_fast,
|
||||
bool telemetry)
|
||||
: m_fail_fast(fail_fast),
|
||||
m_telemetry(telemetry),
|
||||
m_input_model(input_model),
|
||||
m_translator_map(translator_map),
|
||||
m_model_name(model_name),
|
||||
m_ov_model(nullptr),
|
||||
m_cached_body_models(std::make_shared<CachedBodyModelsType>()),
|
||||
m_telemetry_data(std::make_shared<TelemetryDataType>()) {}
|
||||
|
||||
std::shared_ptr<ov::Model> TranslateSession::get_converted_model() {
|
||||
if (m_ov_model) {
|
||||
return m_ov_model;
|
||||
}
|
||||
translate_graph(m_input_model, m_ov_model);
|
||||
return m_ov_model;
|
||||
}
|
||||
|
||||
std::shared_ptr<TelemetryDataType> TranslateSession::get_telemetry_data() const {
|
||||
return m_telemetry_data;
|
||||
}
|
||||
|
||||
void TranslateSession::inject_body_model(std::shared_ptr<ov::Model> body_model,
|
||||
const std::string& operation_type,
|
||||
const ov::OutputVector& ov_inputs,
|
||||
ov::OutputVector& ov_outputs) {
|
||||
ov_outputs.clear();
|
||||
auto body_parameters = body_model->get_parameters();
|
||||
FRONT_END_GENERAL_CHECK(body_parameters.size() == ov_inputs.size(),
|
||||
"[TensorFlow Error] Internal error or incorrect input models: number of "
|
||||
"inputs and arguments to the function " +
|
||||
operation_type + " do not match.");
|
||||
for (size_t param_ind = 0; param_ind < body_parameters.size(); ++param_ind) {
|
||||
body_parameters[param_ind]->output(0).replace(ov_inputs[param_ind]);
|
||||
}
|
||||
for (const auto& result_node : body_model->get_results()) {
|
||||
ov_outputs.push_back(result_node->input_value(0));
|
||||
}
|
||||
}
|
||||
|
||||
void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model,
|
||||
std::shared_ptr<ov::Model>& ov_model) {
|
||||
OpMap ng_op_map;
|
||||
ov::ParameterVector params;
|
||||
ov::ResultVector results;
|
||||
const auto& model_tf = std::dynamic_pointer_cast<InputModel>(input_model);
|
||||
FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into OV Model");
|
||||
const auto& operation_places = model_tf->get_op_places();
|
||||
const auto& model_inputs = model_tf->get_inputs();
|
||||
const auto& model_outputs = model_tf->get_outputs();
|
||||
const auto& model_frozen_inputs = model_tf->get_tensor_values();
|
||||
|
||||
// fill ng_op_map with Constant outputs for frozen inputs
|
||||
for (const auto& frozen_input : model_frozen_inputs) {
|
||||
const auto& frozen_input_name = frozen_input.first;
|
||||
const auto& frozen_input_value = frozen_input.second;
|
||||
FRONT_END_GENERAL_CHECK(ng_op_map.count(frozen_input_name) == 0,
|
||||
"Input with frozen value has been already met: " + frozen_input_name);
|
||||
ng_op_map[frozen_input_name] = {frozen_input_value};
|
||||
}
|
||||
// create parameter nodes for all tensor places corresponding to inputs
|
||||
for (const auto& input_place : model_inputs) {
|
||||
FRONT_END_GENERAL_CHECK(input_place->get_names().size() == 1, "Input place must have one name.");
|
||||
auto input_name = input_place->get_names()[0];
|
||||
if (ng_op_map.count(input_name)) {
|
||||
// probably this input is frozen
|
||||
continue;
|
||||
}
|
||||
const auto& input_tensor_place = std::dynamic_pointer_cast<TensorPlace>(input_place);
|
||||
auto input_shape = input_tensor_place->get_partial_shape();
|
||||
auto input_type = input_tensor_place->get_element_type();
|
||||
|
||||
// in case of cutting graph, types of custom inputs can be undefined,
|
||||
// according to MO help, fp32 is used by default in such cases
|
||||
if (input_type == element::undefined) {
|
||||
input_type = element::f32;
|
||||
}
|
||||
|
||||
auto param = std::make_shared<ov::opset8::Parameter>(input_type, input_shape);
|
||||
set_node_name(input_name, param);
|
||||
params.push_back(param);
|
||||
ng_op_map[input_name] = {param};
|
||||
}
|
||||
|
||||
// create the OV ops from TensorFlow ops
|
||||
for (const auto& operation_place : operation_places) {
|
||||
auto operation_decoder = operation_place->get_decoder();
|
||||
auto operation_name = operation_place->get_names()[0];
|
||||
// output for parameter nodes has been already generated
|
||||
if (ng_op_map.count(operation_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// prepare a list of OV node inputs for each node
|
||||
ov::OutputVector ov_inputs;
|
||||
size_t operation_input_size = operation_decoder->get_input_size();
|
||||
|
||||
if (operation_decoder->get_op_type() == "NextIteration") {
|
||||
// we expect no inputs for NextIteration because we break-up the cycle in InputModel
|
||||
operation_input_size = 0;
|
||||
}
|
||||
for (size_t input_port_idx = 0; input_port_idx < operation_input_size; ++input_port_idx) {
|
||||
// TODO: Implement more general approach. Skipping Constants that have input edges
|
||||
if (operation_decoder->get_op_type() == "Const") {
|
||||
break;
|
||||
}
|
||||
std::string producer_name;
|
||||
size_t producer_port_idx;
|
||||
try {
|
||||
operation_decoder->get_input_node(input_port_idx, producer_name, producer_port_idx);
|
||||
} catch (const std::exception&) {
|
||||
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(input_port_idx) +
|
||||
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
|
||||
producer_name + "', expected input port index: " + std::to_string(producer_port_idx) +
|
||||
'\n');
|
||||
}
|
||||
|
||||
// skip conditional edges that must be resolved before operation translation
|
||||
// now we can meet them because we still work with TensorFlow protobuf
|
||||
if (is_conditional_edge(producer_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: re-implement the logic below once Place graph structure is implemented
|
||||
// Using Place graph structure (OpPlace, In/OutPortPlace places and their connections) can give
|
||||
// names of ports and operations that can be used for further check about existence in ng_op_map
|
||||
|
||||
// check if output vector for places have been already defined and the order of this check is important
|
||||
// it moves from places corresponding to input port of the current operation node to output port of original
|
||||
// producers
|
||||
if (ng_op_map.count(std::to_string(input_port_idx) + ":" + operation_name)) {
|
||||
const auto& input_outputs_vector = ng_op_map.at(std::to_string(input_port_idx) + ":" + operation_name);
|
||||
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
|
||||
"Input created with pruning must have one output");
|
||||
ov_inputs.push_back(input_outputs_vector.at(0));
|
||||
} else if (ng_op_map.count(producer_name + ":" + std::to_string(producer_port_idx))) {
|
||||
const auto& input_outputs_vector =
|
||||
ng_op_map.at(producer_name + ":" + std::to_string(producer_port_idx));
|
||||
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() == 1,
|
||||
"Input created with pruning must have one output");
|
||||
ov_inputs.push_back(input_outputs_vector.at(0));
|
||||
} else if (ng_op_map.count(producer_name)) {
|
||||
const auto& input_outputs_vector = ng_op_map.at(producer_name);
|
||||
FRONT_END_GENERAL_CHECK(input_outputs_vector.size() > producer_port_idx,
|
||||
"Input created with pruning must have one output");
|
||||
ov_inputs.push_back(input_outputs_vector.at(producer_port_idx));
|
||||
} else {
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"No input is found for node \"" + operation_name + "\" by port " +
|
||||
std::to_string(producer_port_idx));
|
||||
}
|
||||
}
|
||||
|
||||
// generate OV node output vector for the current operation node
|
||||
ov::OutputVector ov_outputs;
|
||||
bool is_converted = false;
|
||||
auto operation_type = operation_decoder->get_op_type();
|
||||
try {
|
||||
if (m_translator_map->count(operation_type)) {
|
||||
auto translator = m_translator_map->at(operation_decoder->get_op_type());
|
||||
NodeContext node_context(operation_decoder, ov_inputs, this);
|
||||
ov_outputs = translator(node_context);
|
||||
is_converted = true;
|
||||
} else if (auto body_ov_model = get_body_ov_model(operation_type)) {
|
||||
inject_body_model(body_ov_model, operation_type, ov_inputs, ov_outputs);
|
||||
|
||||
// set output tensor names
|
||||
for (size_t idx = 0; idx < ov_outputs.size(); ++idx) {
|
||||
ov_outputs[idx].get_tensor().set_names({operation_name + ":" + std::to_string(idx)});
|
||||
}
|
||||
is_converted = true;
|
||||
}
|
||||
FRONT_END_OP_CONVERSION_CHECK(is_converted, "No translator found for " + operation_type + " node.");
|
||||
} catch (...) {
|
||||
if (m_fail_fast) {
|
||||
// in case of decode, unsupported operation will be converted to FrameworkNode
|
||||
if (m_telemetry && !is_converted) {
|
||||
// send event about which operation is not supported for conversion
|
||||
m_telemetry_data->push_back(
|
||||
std::make_pair<std::string, std::string>("error_cause", "tf_" + operation_type));
|
||||
}
|
||||
// re-throw any exception
|
||||
throw;
|
||||
} else {
|
||||
auto ng_node = std::make_shared<FrameworkNode>(operation_decoder,
|
||||
ov_inputs,
|
||||
operation_place->get_output_ports().size());
|
||||
set_node_name(operation_name, ng_node);
|
||||
ov_outputs = ng_node->outputs();
|
||||
}
|
||||
}
|
||||
|
||||
// register OV node outputs in the map for new operation node
|
||||
for (const auto& output : ov_outputs) {
|
||||
if (auto result = std::dynamic_pointer_cast<ov::opset10::Result>(output.get_node_shared_ptr())) {
|
||||
// do not add RetVal type operation to ng_op_map
|
||||
results.push_back(result);
|
||||
} else {
|
||||
auto param = std::dynamic_pointer_cast<ov::opset8::Parameter>(output.get_node_shared_ptr());
|
||||
// avoid duplicating Parameter nodes if they are already in the Parameters vector
|
||||
if (param && operation_decoder->get_op_type() != "Identity" &&
|
||||
std::find(params.begin(), params.end(), param) == params.end()) {
|
||||
params.push_back(param);
|
||||
}
|
||||
ng_op_map[operation_name].push_back(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create Result nodes for all model outputs
|
||||
if (results.empty()) {
|
||||
for (const auto& model_output : model_outputs) {
|
||||
auto model_output_tensor_place = std::dynamic_pointer_cast<TensorPlace>(model_output);
|
||||
auto model_output_name = model_output_tensor_place->get_names()[0];
|
||||
std::string operation_name;
|
||||
std::string port_type;
|
||||
size_t port_index;
|
||||
ov::frontend::tensorflow::extract_operation_name_and_port(model_output_name,
|
||||
operation_name,
|
||||
port_index,
|
||||
port_type);
|
||||
|
||||
if (port_type == "none") {
|
||||
for (const auto& node_output : ng_op_map[operation_name]) {
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(node_output);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
}
|
||||
} else if (port_type == "out") {
|
||||
const auto& node_outputs = ng_op_map[operation_name];
|
||||
FRONT_END_GENERAL_CHECK(node_outputs.size() > port_index,
|
||||
"Output port with index " + std::to_string(port_index) + " of " +
|
||||
operation_name + "node specified as custom output does not exist");
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[port_index]);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
} else if (port_type == "in") {
|
||||
// TODO: avoid this traversing by having a map for OpPlace objects, for example
|
||||
std::shared_ptr<OpPlace> operation_place = nullptr;
|
||||
for (const auto& op_place : operation_places) {
|
||||
FRONT_END_GENERAL_CHECK(!op_place->get_names().empty(), "No names for OpPlace found.");
|
||||
if (op_place->get_names()[0] == operation_name) {
|
||||
operation_place = op_place;
|
||||
}
|
||||
}
|
||||
FRONT_END_GENERAL_CHECK(operation_place, "There is no operation place with a name: " + operation_name);
|
||||
auto operation_decoder = operation_place->get_decoder();
|
||||
|
||||
// get to know a producer node and by which its output port data is generated
|
||||
std::string producer_name;
|
||||
size_t producer_port_idx;
|
||||
try {
|
||||
operation_decoder->get_input_node(port_index, producer_name, producer_port_idx);
|
||||
} catch (const std::exception&) {
|
||||
FRONT_END_THROW("[ ERROR ] Exception happened when preparing input " + std::to_string(port_index) +
|
||||
" for op '" + operation_decoder->get_op_name() + "', expected input name: '" +
|
||||
producer_name +
|
||||
"', expected input port index: " + std::to_string(producer_port_idx) + '\n');
|
||||
}
|
||||
|
||||
// add Result node for this producer output port
|
||||
const auto& node_outputs = ng_op_map[producer_name];
|
||||
FRONT_END_GENERAL_CHECK(node_outputs.size() > producer_port_idx,
|
||||
"Output port with index " + std::to_string(producer_port_idx) + " of " +
|
||||
producer_name + "node specified as custom output does not exist");
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(node_outputs[producer_port_idx]);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: it may be redundant step since models_output is filled in InputModel constructor
|
||||
// find all terminal nodes in OV graph to complete list of results
|
||||
if (results.empty()) {
|
||||
for (const auto& node_output_vector : ng_op_map) {
|
||||
for (size_t output_ind = 0; output_ind < node_output_vector.second.size(); ++output_ind) {
|
||||
auto output = node_output_vector.second[output_ind];
|
||||
if (output.get_target_inputs().empty() &&
|
||||
!std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
|
||||
auto model_output_name =
|
||||
output.get_node_shared_ptr()->get_friendly_name() + ":" + std::to_string(output_ind);
|
||||
auto result_node = std::make_shared<ov::opset8::Result>(output);
|
||||
result_node->set_friendly_name(model_output_name);
|
||||
results.push_back(result_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: reorder results and params according to indices given in RT info (if any)
|
||||
|
||||
// create the OV Model
|
||||
ov_model = std::make_shared<ov::Model>(results, params, m_model_name);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> TranslateSession::get_body_ov_model(const std::string& body_graph_name) {
|
||||
std::shared_ptr<ov::Model> body_model = nullptr;
|
||||
auto input_model = std::dynamic_pointer_cast<InputModel>(m_input_model);
|
||||
if (m_cached_body_models->count(body_graph_name)) {
|
||||
// check if such body graph has been converted before
|
||||
// re-use it from the cache for further injection
|
||||
|
||||
// create new instance of the required body model
|
||||
// since it will be modified by injection
|
||||
auto cached_body_model = m_cached_body_models->at(body_graph_name);
|
||||
body_model = cached_body_model->clone();
|
||||
} else if (auto body_input_model = input_model->get_body_input_model(body_graph_name)) {
|
||||
// try to find a function by name in the model library
|
||||
translate_graph(body_input_model, body_model);
|
||||
// save new instance of body_model in the cache of body models
|
||||
// before its injection into the parent graph
|
||||
|
||||
// before caching, erase tensor names from the body graph
|
||||
// otherwise, it can lead tensor names conflicts
|
||||
for (const auto& op : body_model->get_ordered_ops()) {
|
||||
for (size_t ind = 0; ind < op->get_output_size(); ++ind) {
|
||||
op->get_output_tensor(ind).set_names({});
|
||||
}
|
||||
}
|
||||
|
||||
auto cached_body_model = body_model->clone();
|
||||
update_cached_body_models(body_graph_name, cached_body_model);
|
||||
}
|
||||
return body_model;
|
||||
}
|
57
src/frontends/tensorflow/src/translate_session.hpp
Normal file
57
src/frontends/tensorflow/src/translate_session.hpp
Normal file
@ -0,0 +1,57 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/frontend/input_model.hpp"
|
||||
#include "openvino/frontend/tensorflow/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
using CachedBodyModelsType = std::unordered_map<std::string, std::shared_ptr<const ov::Model>>;
|
||||
using TelemetryDataType = std::vector<std::pair<std::string, std::string>>;
|
||||
|
||||
/// For one call of convert and decode method of Frontend, it creates one TranslateSession object to save data for the
|
||||
/// translation session: telemetry statistics, cache of convrted body graph models, operation translators (including
|
||||
/// extensions) registered for this translation session.
|
||||
class TranslateSession {
|
||||
public:
|
||||
TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
|
||||
const std::shared_ptr<TranslatorDictionaryType>& translator_map,
|
||||
const std::string& model_name,
|
||||
bool fail_fast,
|
||||
bool telemetry);
|
||||
std::shared_ptr<ov::Model> get_converted_model();
|
||||
std::shared_ptr<TelemetryDataType> get_telemetry_data() const;
|
||||
|
||||
void translate_graph(const ov::frontend::InputModel::Ptr& input_model, std::shared_ptr<ov::Model>& ov_model);
|
||||
|
||||
void inject_body_model(std::shared_ptr<ov::Model> body_model,
|
||||
const std::string& operation_type,
|
||||
const ov::OutputVector& ov_inputs,
|
||||
ov::OutputVector& ov_outputs);
|
||||
|
||||
std::shared_ptr<ov::Model> get_body_ov_model(const std::string& body_graph_name);
|
||||
|
||||
private:
|
||||
const ov::frontend::InputModel::Ptr m_input_model;
|
||||
const bool m_fail_fast;
|
||||
const bool m_telemetry;
|
||||
const std::shared_ptr<TranslatorDictionaryType> m_translator_map;
|
||||
const std::string m_model_name;
|
||||
|
||||
std::shared_ptr<CachedBodyModelsType> m_cached_body_models;
|
||||
std::shared_ptr<TelemetryDataType> m_telemetry_data;
|
||||
std::shared_ptr<ov::Model> m_ov_model;
|
||||
|
||||
void update_cached_body_models(const std::string& operation_type,
|
||||
const std::shared_ptr<const ov::Model>& cached_body_model) {
|
||||
m_cached_body_models->insert(std::make_pair(operation_type, cached_body_model));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -114,3 +114,92 @@ TEST_F(TransformationTestsF, ModelWithSwishF32BodyGraph) {
|
||||
model_ref = make_shared<Model>(OutputVector{sigmoid2}, ParameterVector{x});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PartitionedCall) {
|
||||
{
|
||||
model = convert_model("partitioned_call/partitioned_call.pb");
|
||||
// need to call shape inference since body graphs can be injected with undefined shapes
|
||||
model->validate_nodes_and_infer_types();
|
||||
}
|
||||
{
|
||||
auto x = make_shared<Parameter>(i32, Shape{2});
|
||||
auto y = make_shared<Parameter>(i32, Shape{1});
|
||||
auto sub = make_shared<Subtract>(x, y);
|
||||
auto const_pow = make_shared<Constant>(i32, Shape{}, 2);
|
||||
auto pow = make_shared<Power>(sub, const_pow);
|
||||
|
||||
model_ref = make_shared<Model>(OutputVector{pow}, ParameterVector{x, y});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ModelWithIf) {
|
||||
{ model = convert_model("model_with_if/model_with_if.pb"); }
|
||||
{
|
||||
// create then branch body graph
|
||||
auto then_x = make_shared<Parameter>(i32, Shape{2});
|
||||
auto then_y = make_shared<Parameter>(i32, Shape{1});
|
||||
auto add = make_shared<Add>(then_x, then_y);
|
||||
auto then_result = make_shared<Result>(add);
|
||||
auto then_model = make_shared<Model>(OutputVector{then_result}, ParameterVector{then_x, then_y});
|
||||
|
||||
// create else branch body graph
|
||||
auto else_x = make_shared<Parameter>(i32, Shape{2});
|
||||
auto else_y = make_shared<Parameter>(i32, Shape{1});
|
||||
auto sub = make_shared<Subtract>(else_x, else_y);
|
||||
auto else_result = make_shared<Result>(sub);
|
||||
auto else_model = make_shared<Model>(OutputVector{else_result}, ParameterVector{else_x, else_y});
|
||||
|
||||
// create the main graph
|
||||
auto x = make_shared<Parameter>(i32, Shape{2});
|
||||
auto y = make_shared<Parameter>(i32, Shape{1});
|
||||
auto cond_const = make_shared<Constant>(i32, Shape{}, 10);
|
||||
auto cond = make_shared<Greater>(x, cond_const);
|
||||
auto if_op = make_shared<If>(cond);
|
||||
if_op->set_then_body(then_model);
|
||||
if_op->set_else_body(else_model);
|
||||
if_op->set_input(x, then_x, else_x);
|
||||
if_op->set_input(y, then_y, else_y);
|
||||
if_op->set_output(then_result, else_result);
|
||||
|
||||
model_ref = make_shared<Model>(OutputVector{if_op}, ParameterVector{x, y});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, InjectedBodyAndIf) {
|
||||
{
|
||||
model = convert_model("injected_body_and_if/injected_body_and_if.pb");
|
||||
// need to call shape inference since body graphs can be injected with undefined shapes
|
||||
model->validate_nodes_and_infer_types();
|
||||
}
|
||||
{
|
||||
// create then branch body graph
|
||||
auto then_x = make_shared<Parameter>(i32, Shape{2});
|
||||
auto then_y = make_shared<Parameter>(i32, Shape{1});
|
||||
auto add = make_shared<Add>(then_x, then_y);
|
||||
auto then_result = make_shared<Result>(add);
|
||||
auto then_model = make_shared<Model>(OutputVector{then_result}, ParameterVector{then_x, then_y});
|
||||
|
||||
// create else branch body graph
|
||||
auto else_x = make_shared<Parameter>(i32, Shape{2});
|
||||
auto else_y = make_shared<Parameter>(i32, Shape{1});
|
||||
auto sub = make_shared<Subtract>(else_x, else_y);
|
||||
auto pow_const = make_shared<Constant>(i32, Shape{}, 2);
|
||||
auto pow = make_shared<Power>(sub, pow_const);
|
||||
auto else_result = make_shared<Result>(pow);
|
||||
auto else_model = make_shared<Model>(OutputVector{else_result}, ParameterVector{else_x, else_y});
|
||||
|
||||
// create the main graph
|
||||
auto x = make_shared<Parameter>(i32, Shape{2});
|
||||
auto y = make_shared<Parameter>(i32, Shape{1});
|
||||
auto cond_const = make_shared<Constant>(i32, Shape{}, 10);
|
||||
auto cond = make_shared<Greater>(x, cond_const);
|
||||
auto if_op = make_shared<If>(cond);
|
||||
if_op->set_then_body(then_model);
|
||||
if_op->set_else_body(else_model);
|
||||
if_op->set_input(x, then_x, else_x);
|
||||
if_op->set_input(y, then_y, else_y);
|
||||
if_op->set_output(then_result, else_result);
|
||||
|
||||
model_ref = make_shared<Model>(OutputVector{if_op}, ParameterVector{x, y});
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,368 @@
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Greater"
|
||||
op: "Greater"
|
||||
input: "x"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "If"
|
||||
op: "If"
|
||||
input: "Greater"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "Tcond"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "else_branch"
|
||||
value {
|
||||
func {
|
||||
name: "else_branch_func_AthmcAbLnco"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "then_branch"
|
||||
value {
|
||||
func {
|
||||
name: "then_branch_func_mdn8Hcdd6RQ"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "init"
|
||||
op: "NoOp"
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "then_branch_func_mdn8Hcdd6RQ"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "add"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "add_0"
|
||||
op: "AddV2"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "add"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "add"
|
||||
value: "add_0:z:0"
|
||||
}
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 1
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "else_branch_func_AthmcAbLnco"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "aux_func_gmpkxbsu4wi"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "sub"
|
||||
op: "Sub"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "sub"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "aux_func_GmpkxbsU4WI"
|
||||
op: "aux_func_GmpkxbsU4WI"
|
||||
input: "sub:z:0"
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "aux_func_GmpkxbsU4WI"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "aux_func_gmpkxbsu4wi"
|
||||
value: "aux_func_GmpkxbsU4WI:pow:0"
|
||||
}
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 1
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "aux_func_GmpkxbsU4WI"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "pow"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "pow/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "pow/y"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "pow_0"
|
||||
op: "Pow"
|
||||
input: "x"
|
||||
input: "pow/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "pow"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "pow"
|
||||
value: "pow_0:z:0"
|
||||
}
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 808
|
||||
min_consumer: 12
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.python.framework import function
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
|
||||
@function.Defun(tf.int32)
|
||||
def aux_func(x):
|
||||
return x ** 2
|
||||
|
||||
|
||||
@function.Defun(tf.int32, tf.int32)
|
||||
def then_branch_func(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
@function.Defun(tf.int32, tf.int32)
|
||||
def else_branch_func(x, y):
|
||||
return aux_func(x - y)
|
||||
|
||||
|
||||
with tf.Session() as sess:
|
||||
x = tf.placeholder(tf.int32, [2], 'x')
|
||||
y = tf.placeholder(tf.int32, [1], 'y')
|
||||
const_cond = tf.constant(10, dtype=tf.int32)
|
||||
cond = tf.raw_ops.Greater(x=x, y=const_cond)
|
||||
if_op = tf.raw_ops.If(cond=cond, input=[x, y], Tout=[tf.int32], then_branch=then_branch_func,
|
||||
else_branch=else_branch_func)
|
||||
tf.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
tf.io.write_graph(tf_net, './', 'injected_body_and_if.pbtxt', as_text=True)
|
@ -0,0 +1,278 @@
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Greater"
|
||||
op: "Greater"
|
||||
input: "x"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "If"
|
||||
op: "If"
|
||||
input: "Greater"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "Tcond"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "else_branch"
|
||||
value {
|
||||
func {
|
||||
name: "else_branch_func_Fw4jHLGozIk"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "then_branch"
|
||||
value {
|
||||
func {
|
||||
name: "then_branch_func_mdn8Hcdd6RQ"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "init"
|
||||
op: "NoOp"
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "then_branch_func_mdn8Hcdd6RQ"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "add"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "add_0"
|
||||
op: "AddV2"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "add"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "add"
|
||||
value: "add_0:z:0"
|
||||
}
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 1
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "else_branch_func_Fw4jHLGozIk"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "sub"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "sub_0"
|
||||
op: "Sub"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "sub"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "sub"
|
||||
value: "sub_0:z:0"
|
||||
}
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 1
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 808
|
||||
min_consumer: 12
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.python.framework import function
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
|
||||
@function.Defun(tf.int32, tf.int32)
|
||||
def then_branch_func(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
@function.Defun(tf.int32, tf.int32)
|
||||
def else_branch_func(x, y):
|
||||
return x - y
|
||||
|
||||
|
||||
with tf.Session() as sess:
|
||||
x = tf.placeholder(tf.int32, [2], 'x')
|
||||
y = tf.placeholder(tf.int32, [1], 'y')
|
||||
const_cond = tf.constant(10, dtype=tf.int32)
|
||||
cond = tf.raw_ops.Greater(x=x, y=const_cond)
|
||||
if_op = tf.raw_ops.If(cond=cond, input=[x, y], Tout=[tf.int32], then_branch=then_branch_func,
|
||||
else_branch=else_branch_func)
|
||||
tf.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
tf.io.write_graph(tf_net, './', 'model_with_if.pbtxt', as_text=True)
|
@ -0,0 +1,240 @@
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "_user_specified_name"
|
||||
value {
|
||||
s: "x"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "_user_specified_name"
|
||||
value {
|
||||
s: "y"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "sub"
|
||||
op: "Sub"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "PartitionedCall"
|
||||
op: "PartitionedCall"
|
||||
input: "sub"
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_collective_manager_ids"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_read_only_resource_inputs"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "config"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "config_proto"
|
||||
value {
|
||||
s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001\202\001\000"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "executor_type"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "f"
|
||||
value {
|
||||
func {
|
||||
name: "__inference_second_func_14"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Identity"
|
||||
op: "Identity"
|
||||
input: "PartitionedCall"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "__inference_second_func_14"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "identity"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "pow/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "pow/y"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "pow"
|
||||
op: "Pow"
|
||||
input: "x"
|
||||
input: "pow/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "pow"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Identity"
|
||||
op: "Identity"
|
||||
input: "pow:z:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "Identity"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "identity"
|
||||
value: "Identity:output:0"
|
||||
}
|
||||
attr {
|
||||
key: "_construction_context"
|
||||
value {
|
||||
s: "kEagerRuntime"
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_user_specified_name"
|
||||
value {
|
||||
s: "x"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 808
|
||||
min_consumer: 12
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
@tf.function
|
||||
def second_func(x):
|
||||
return x ** 2
|
||||
|
||||
|
||||
@tf.function
|
||||
def first_func(x, y):
|
||||
return second_func(x - y)
|
||||
|
||||
|
||||
graph_def = first_func.get_concrete_function(tf.constant([6, 3]), tf.constant([7])).graph.as_graph_def()
|
||||
tf.io.write_graph(graph_def, '.', 'partitioned_call.pbtxt', as_text=True)
|
Loading…
Reference in New Issue
Block a user