[TF FE][TF Hub] Use ConcreteFunc input and output signatures (#19690)

* [TF Hub][TF FE] Preserve outputs of ConcreteFunction from signature and their names

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix naming and complete TODO

* Apply code-review: extra assert to check input_signature

* Fix inputs for fw

* Fix input data preparation and import convert_model

* Correct variable detection among all inputs

* Handle special input and output signature

* Fix adjust_saved_model_names

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev 2023-09-10 07:46:01 +04:00 committed by GitHub
parent 932ba63744
commit 37f61551a3
8 changed files with 133 additions and 49 deletions

View File

@@ -10,14 +10,16 @@ from openvino.frontend.tensorflow.py_tensorflow_frontend import _FrontEndPyGraph
class GraphIteratorTFGraph(GraphIterator):
def __init__(self, tf_graph: tf.Graph, share_weights: bool, inner_graph: bool = False):
def __init__(self, tf_graph: tf.Graph, share_weights: bool, inner_graph: bool = False,
input_names_map: dict = None, output_names_map: dict = None):
GraphIterator.__init__(self)
self.m_graph = tf_graph
self.m_node_index = 0
self.m_decoders = []
self.m_inner_graph = inner_graph
self.m_share_weights = share_weights
self.m_input_names_map = input_names_map or {}
self.m_output_names_map = output_names_map or {}
self.m_vars = None
if hasattr(tf_graph, "variables"):
# This field is needed to keep the link to graph variables,
@@ -32,6 +34,10 @@ class GraphIteratorTFGraph(GraphIterator):
self.m_iterators[func_name] = None
def get_input_names(self) -> list:
# returns a vector of input names in the original order
# Note: used only for the library functions
if not self.m_inner_graph:
return []
inp_ops = filter(lambda op: op.type == "Placeholder", self.m_graph.get_operations())
inp_names = []
for inp in inp_ops:
@@ -48,6 +54,10 @@ class GraphIteratorTFGraph(GraphIterator):
return inp_names
def get_output_names(self) -> list:
# returns a vector of output names in the original order
# Note: used only for the library functions
if not self.m_inner_graph:
return []
# tf.Graph has ordered outputs which are stored in 'outputs' field,
# but using this field results in mismatch of outputs in inner graph and outputs in outer graph
# during the injection of subgraph.
@@ -67,6 +77,14 @@ class GraphIteratorTFGraph(GraphIterator):
outputs = [output.name] + outputs
return outputs
def get_input_names_map(self) -> dict:
# returns a map from internal tensor name to (user-defined) external name for inputs
return self.m_input_names_map
def get_output_names_map(self) -> dict:
# returns a map from internal tensor name to (user-defined) external name for outputs
return self.m_output_names_map
def is_end(self) -> bool:
return self.m_node_index >= len(self.m_decoders)

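The extended iterator simply stores and exposes the two maps. Below is a minimal sketch of how it could be constructed and queried for a toy ConcreteFunction; the import path and the concrete internal tensor names ("x:0", "y:0", "Identity:0") are assumptions for illustration, not taken from the commit.

import tensorflow as tf
from openvino.frontend.tensorflow.graph_iterator import GraphIteratorTFGraph  # assumed import path

@tf.function(input_signature=[tf.TensorSpec([2], tf.float32, name="x"),
                              tf.TensorSpec([2], tf.float32, name="y")])
def add(x, y):
    return {"sum": x + y}

concrete_func = add.get_concrete_function()
# hypothetical internal -> external name maps, as built by create_tf_graph_iterator
input_names_map = {"x:0": "x", "y:0": "y"}
output_names_map = {"Identity:0": "sum"}
iterator = GraphIteratorTFGraph(concrete_func.graph, True, False,
                                input_names_map, output_names_map)
print(iterator.get_input_names_map())   # {'x:0': 'x', 'y:0': 'y'}
print(iterator.get_output_names_map())  # {'Identity:0': 'sum'}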
View File

@@ -237,7 +237,32 @@ def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_t
if isinstance(input_model, tf.Graph):
return GraphIteratorTFGraph(input_model, share_weights)
elif isinstance(input_model, tf.types.experimental.ConcreteFunction):
return GraphIteratorTFGraph(input_model.graph, share_weights)
# create a map for inputs to map internal tensor name to external one
# collect all internal tensor names in a given order
input_names_map = None
if hasattr(input_model, 'inputs') and hasattr(input_model, 'structured_input_signature'):
internal_tensor_names = []
for func_input in input_model.inputs:
if func_input.dtype == tf.resource:
continue
internal_tensor_names.append(func_input.name)
if len(input_model.structured_input_signature) > 1 and \
len(internal_tensor_names) == len(input_model.structured_input_signature[1]):
external_tensor_names = sorted(input_model.structured_input_signature[1].keys())
for internal_name, external_name in zip(internal_tensor_names, external_tensor_names):
input_names_map = input_names_map or {}
input_names_map[internal_name] = external_name
output_names_map = None
if hasattr(input_model, 'outputs') and hasattr(input_model, 'structured_outputs') and \
isinstance(input_model.structured_outputs, dict):
external_names = sorted(list(input_model.structured_outputs.keys()))
internal_names = sorted([tensor.name for tensor in input_model.outputs])
if len(external_names) == len(internal_names):
for external_name, internal_name in zip(external_names, internal_names):
output_names_map = output_names_map or {}
output_names_map[internal_name] = external_name
return GraphIteratorTFGraph(input_model.graph, share_weights, False, input_names_map, output_names_map)
raise Exception("Could not wrap model of type {} to GraphIteratorTFGraph.".format(type(input_model)))

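The mapping logic above relies on three ConcreteFunction attributes. The following self-contained sketch shows what they look like for a serving signature restored from a SavedModel, which is the case where the kwargs part of structured_input_signature is populated; the model itself is a made-up example, not from the tests.

import tempfile
import tensorflow as tf

class AddModel(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([2], tf.float32, name="x"),
                                  tf.TensorSpec([2], tf.float32, name="y")])
    def __call__(self, x, y):
        return {"sum": x + y}

model = AddModel()
model_dir = tempfile.mkdtemp()
tf.saved_model.save(model, model_dir, signatures=model.__call__.get_concrete_function())
concrete_func = tf.saved_model.load(model_dir).signatures["serving_default"]

# kwargs part of the input signature: keys are the external (user-visible) input names
print(sorted(concrete_func.structured_input_signature[1].keys()))        # ['x', 'y']
# inputs holds the internal graph tensors; resource tensors (variables) are skipped above
print([t.name for t in concrete_func.inputs if t.dtype != tf.resource])  # e.g. ['x:0', 'y:0']
# structured_outputs is a dict keyed by the external output names
print(sorted(concrete_func.structured_outputs.keys()))                   # ['sum']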
View File

@@ -5,8 +5,9 @@
#pragma once
#include <pybind11/pybind11.h>
#include "openvino/frontend/graph_iterator.hpp"
#include "openvino/frontend/decoder.hpp"
#include "openvino/frontend/graph_iterator.hpp"
namespace py = pybind11;
@@ -14,9 +15,10 @@ namespace py = pybind11;
class PyGraphIterator : public ov::frontend::tensorflow::GraphIterator {
/* Inherit the constructors */
using ov::frontend::tensorflow::GraphIterator::GraphIterator;
using map_str_to_str = std::map<std::string, std::string>;
/// \brief Get a number of operation nodes in the graph
size_t size() const override{
size_t size() const override {
PYBIND11_OVERRIDE_PURE(size_t, GraphIterator, size);
}
@@ -30,7 +32,8 @@ class PyGraphIterator : public ov::frontend::tensorflow::GraphIterator {
next_impl();
}
/// Implementation of next method, it is needed to be in separate method to avoid shadowing of Python "next" operator.
/// Implementation of the next method; it is kept as a separate method to avoid shadowing of the Python "next"
/// operator.
void next_impl() {
PYBIND11_OVERRIDE_PURE(void, GraphIterator, next_impl);
}
@@ -41,26 +44,35 @@ class PyGraphIterator : public ov::frontend::tensorflow::GraphIterator {
}
/// \brief Return a pointer to a decoder of the current node
std::shared_ptr<ov::frontend::DecoderBase> get_decoder() const override{
std::shared_ptr<ov::frontend::DecoderBase> get_decoder() const override {
PYBIND11_OVERRIDE_PURE(std::shared_ptr<ov::frontend::DecoderBase>, GraphIterator, get_decoder);
}
/// \brief Checks if the main model graph contains a function of the requested name in the library
/// Returns GraphIterator to this function and nullptr, if it does not exist
std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override{
std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override {
PYBIND11_OVERRIDE_PURE(std::shared_ptr<GraphIterator>, GraphIterator, get_body_graph_iterator, func_name);
}
/// \brief Returns a vector of input names in the original order
std::vector<std::string> get_input_names() const override{
std::vector<std::string> get_input_names() const override {
PYBIND11_OVERRIDE_PURE(std::vector<std::string>, GraphIterator, get_input_names);
}
/// \brief Returns a vector of output names in the original order
std::vector<std::string> get_output_names() const override{
std::vector<std::string> get_output_names() const override {
PYBIND11_OVERRIDE_PURE(std::vector<std::string>, GraphIterator, get_output_names);
}
/// \brief Returns a map from internal tensor name to (user-defined) external name for inputs
map_str_to_str get_input_names_map() const override {
PYBIND11_OVERRIDE_PURE(map_str_to_str, GraphIterator, get_input_names_map);
}
/// \brief Returns a map from internal tensor name to (user-defined) external name for outputs
map_str_to_str get_output_names_map() const override {
PYBIND11_OVERRIDE_PURE(map_str_to_str, GraphIterator, get_output_names_map);
}
};
void regclass_frontend_tensorflow_graph_iterator(py::module m);

View File

@@ -46,6 +46,16 @@ public:
/// \brief Returns a vector of output names in the original order
virtual std::vector<std::string> get_output_names() const = 0;
/// \brief Returns a map from internal tensor name to (user-defined) external name for inputs
virtual std::map<std::string, std::string> get_input_names_map() const {
return {};
}
/// \brief Returns a map from internal tensor name to (user-defined) external name for outputs
virtual std::map<std::string, std::string> get_output_names_map() const {
return {};
}
};
} // namespace tensorflow

View File

@@ -347,7 +347,23 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
else if (variants[0].is<GraphIterator::Ptr>()) {
// this is used for OpenVINO with TensorFlow Integration
auto graph_iterator = variants[0].as<GraphIterator::Ptr>();
return std::make_shared<InputModel>(graph_iterator, m_telemetry);
std::shared_ptr<std::map<std::string, std::string>> input_names_map = nullptr;
std::shared_ptr<std::map<std::string, std::string>> output_names_map = nullptr;
if (graph_iterator->get_input_names_map().size() > 0) {
input_names_map =
std::make_shared<std::map<std::string, std::string>>(graph_iterator->get_input_names_map());
}
if (graph_iterator->get_output_names_map().size() > 0) {
output_names_map =
std::make_shared<std::map<std::string, std::string>>(graph_iterator->get_output_names_map());
}
return std::make_shared<InputModel>(graph_iterator,
m_telemetry,
nullptr,
input_names_map,
output_names_map,
nullptr,
false);
}
FRONT_END_GENERAL_CHECK(false,

View File

@@ -66,31 +66,40 @@ void adjust_saved_model_names(ov::Output<ov::Node>& ov_output,
// 2. find a set of clean-up names and aligned with the model signature
const auto& tensor_names = ov_output.get_names();
std::unordered_set<std::string> cleanup_names;
bool signature_passed = true;
if (is_input_tensor) {
for (const auto& tensor_name : tensor_names) {
if (saved_model_input_names->count(tensor_name) > 0) {
cleanup_names.insert(saved_model_input_names->at(tensor_name));
param_node->set_friendly_name(saved_model_input_names->at(tensor_name));
if (saved_model_input_names) {
for (const auto& tensor_name : tensor_names) {
if (saved_model_input_names->count(tensor_name) > 0) {
cleanup_names.insert(saved_model_input_names->at(tensor_name));
param_node->set_friendly_name(saved_model_input_names->at(tensor_name));
}
}
} else {
signature_passed = false;
}
}
if (is_output_tensor) {
std::vector<std::string> result_names;
for (const auto& tensor_name : tensor_names) {
if (saved_model_output_names->count(tensor_name) > 0) {
cleanup_names.insert(saved_model_output_names->at(tensor_name));
result_names.push_back(saved_model_output_names->at(tensor_name));
if (saved_model_output_names) {
std::vector<std::string> result_names;
for (const auto& tensor_name : tensor_names) {
if (saved_model_output_names->count(tensor_name) > 0) {
cleanup_names.insert(saved_model_output_names->at(tensor_name));
result_names.push_back(saved_model_output_names->at(tensor_name));
}
}
}
// align the Result node names as many as possible
// it is not bad if we remain it as is because OV API 2.0 relies only on tensor names
size_t result_names_size = result_names.size();
if (result_names_size > 0) {
for (size_t ind = 0; ind < results.size(); ++ind) {
auto new_result_name = result_names[ind % result_names_size];
results[ind]->set_friendly_name(new_result_name);
// align the Result node names as many as possible
// it is not bad if we remain it as is because OV API 2.0 relies only on tensor names
size_t result_names_size = result_names.size();
if (result_names_size > 0) {
for (size_t ind = 0; ind < results.size(); ++ind) {
auto new_result_name = result_names[ind % result_names_size];
results[ind]->set_friendly_name(new_result_name);
}
}
} else {
signature_passed = false;
}
}
@@ -98,7 +107,7 @@ void adjust_saved_model_names(ov::Output<ov::Node>& ov_output,
// otherwise, the tensor corresponds to unused Parameter or Result nodes
if (cleanup_names.size() > 0) {
ov_output.set_names(cleanup_names);
} else {
} else if (signature_passed) {
// this is an unused tensor that should be removed
// because it is not present in the signature
ov_output.add_names({"saved_model_unused"});

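From the user side, the expected net effect of the frontend changes is that a converted serving signature keeps the external names from the model signature on its input and output tensors. A hedged end-to-end sketch follows; the exact name sets in the comments are an expectation, not verified output.

import tempfile
import tensorflow as tf
from openvino import convert_model

class AddModel(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([2], tf.float32, name="x"),
                                  tf.TensorSpec([2], tf.float32, name="y")])
    def __call__(self, x, y):
        return {"sum": x + y}

model = AddModel()
model_dir = tempfile.mkdtemp()
tf.saved_model.save(model, model_dir, signatures=model.__call__.get_concrete_function())
concrete_func = tf.saved_model.load(model_dir).signatures["serving_default"]

ov_model = convert_model(concrete_func)
# the signature names are expected to appear among the OV tensor names
print([inp.get_names() for inp in ov_model.inputs])    # expected to contain 'x' and 'y'
print([out.get_names() for out in ov_model.outputs])   # expected to contain 'sum'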
View File

@@ -4,8 +4,8 @@ import gc
import numpy as np
from models_hub_common.multiprocessing_utils import multiprocessing_run
from openvino import convert_model
from openvino.runtime import Core
from openvino.tools.mo import convert_model
class TestConvertModel:
@@ -32,9 +32,9 @@ class TestConvertModel:
assert False, "Unsupported type {}".format(input_type)
def prepare_inputs(self, inputs_info):
inputs = []
for input_shape, input_type in inputs_info:
inputs.append(self.prepare_input(input_shape, input_type))
inputs = {}
for input_name, input_shape, input_type in inputs_info:
inputs[input_name] = self.prepare_input(input_shape, input_type)
return inputs
def convert_model(self, model_obj):

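The harness change above switches inputs_info entries from (shape, type) pairs to (name, shape, type) triples and returns a dictionary keyed by the input name. A tiny standalone stand-in; prepare_input here is a simplified placeholder, not the real test method.

import numpy as np

def prepare_input(input_shape, input_type):
    # simplified placeholder for TestConvertModel.prepare_input
    return np.zeros(input_shape, dtype=input_type)

def prepare_inputs(inputs_info):
    # each entry carries the input name; the result is keyed by that name
    inputs = {}
    for input_name, input_shape, input_type in inputs_info:
        inputs[input_name] = prepare_input(input_shape, input_type)
    return inputs

print(prepare_inputs([("x", [2], np.float32), ("y", [2], np.float32)]))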
View File

@@ -32,7 +32,8 @@ class TestTFHubConvertModel(TestConvertModel):
def get_inputs_info(self, model_obj):
inputs_info = []
for input_info in model_obj.inputs:
assert len(model_obj.structured_input_signature) > 1, "incorrect model or test issue"
for input_name, input_info in model_obj.structured_input_signature[1].items():
input_shape = []
try:
for dim in input_info.shape.as_list():
@@ -55,32 +56,25 @@ class TestTFHubConvertModel(TestConvertModel):
tf.string: str,
tf.bool: bool,
}
if input_info.dtype not in type_map:
if input_info.dtype == tf.resource:
# skip inputs corresponding to variables
continue
assert input_info.dtype in type_map, "Unsupported input type: {}".format(input_info.dtype)
inputs_info.append((input_shape, type_map[input_info.dtype]))
inputs_info.append((input_name, input_shape, type_map[input_info.dtype]))
return inputs_info
def infer_fw_model(self, model_obj, inputs):
# TODO 119141 - use the same dictionary for OV inference
# repack input dictionary to tensorflow constants
tf_inputs = {}
for input_ind, input_name in enumerate(sorted(model_obj.structured_input_signature[1].keys())):
tf_inputs[input_name] = tf.constant(inputs[input_ind])
for input_name, input_value in inputs.items():
tf_inputs[input_name] = tf.constant(input_value)
output_dict = {}
for out_name, out_value in model_obj(**tf_inputs).items():
output_dict[out_name] = out_value.numpy()
# TODO: 119141 - remove this workaround
# map external tensor names to internal names
assert len(model_obj.outputs) == len(model_obj.structured_outputs)
fw_outputs = {}
for output_ind, external_name in enumerate(sorted(model_obj.structured_outputs.keys())):
internal_name = model_obj.outputs[output_ind].name
out_value = output_dict[external_name]
fw_outputs[internal_name] = out_value
return fw_outputs
return output_dict
def teardown_method(self):
# remove all downloaded files for TF Hub models
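With named inputs and signature-keyed outputs, the framework inference no longer needs the internal-name remapping workaround; the simplified infer_fw_model above reduces to the following self-contained sketch, where model_obj stands for a serving signature restored from a SavedModel.

import tensorflow as tf

def infer_fw_model(model_obj, inputs):
    # wrap each named input into a TF constant and call the signature by keyword
    tf_inputs = {name: tf.constant(value) for name, value in inputs.items()}
    # the result keeps the structured_outputs keys, which now match the
    # external names preserved on the converted OV model
    return {name: value.numpy() for name, value in model_obj(**tf_inputs).items()}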