[TF FE][TF Hub] Use ConcreteFunc input and output signatures (#19690)
* [TF Hub][TF FE] Preserve outputs of ConcreteFunction from signature and their names Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix naming and complete TODO * Apply code-review: extra assert to check input_signature * Fix inputs for fw * Fix input data preparation and import convert_model * Correct variable detection among all inputs * Handle special input and output signature * Fix adjust_saved_model_names --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
932ba63744
commit
37f61551a3
@ -10,14 +10,16 @@ from openvino.frontend.tensorflow.py_tensorflow_frontend import _FrontEndPyGraph
|
||||
|
||||
|
||||
class GraphIteratorTFGraph(GraphIterator):
|
||||
def __init__(self, tf_graph: tf.Graph, share_weights: bool, inner_graph: bool = False):
|
||||
def __init__(self, tf_graph: tf.Graph, share_weights: bool, inner_graph: bool = False,
|
||||
input_names_map: dict = None, output_names_map: dict = None):
|
||||
GraphIterator.__init__(self)
|
||||
self.m_graph = tf_graph
|
||||
self.m_node_index = 0
|
||||
self.m_decoders = []
|
||||
self.m_inner_graph = inner_graph
|
||||
self.m_share_weights = share_weights
|
||||
|
||||
self.m_input_names_map = input_names_map or {}
|
||||
self.m_output_names_map = output_names_map or {}
|
||||
self.m_vars = None
|
||||
if hasattr(tf_graph, "variables"):
|
||||
# This field is needed to keep the link to graph variables,
|
||||
@ -32,6 +34,10 @@ class GraphIteratorTFGraph(GraphIterator):
|
||||
self.m_iterators[func_name] = None
|
||||
|
||||
def get_input_names(self) -> list:
|
||||
# returns a vector of input names in the original order
|
||||
# Note: used only for the library functions
|
||||
if not self.m_inner_graph:
|
||||
return []
|
||||
inp_ops = filter(lambda op: op.type == "Placeholder", self.m_graph.get_operations())
|
||||
inp_names = []
|
||||
for inp in inp_ops:
|
||||
@ -48,6 +54,10 @@ class GraphIteratorTFGraph(GraphIterator):
|
||||
return inp_names
|
||||
|
||||
def get_output_names(self) -> list:
|
||||
# returns a vector of output names in the original order
|
||||
# Note: used only for the library functions
|
||||
if not self.m_inner_graph:
|
||||
return []
|
||||
# tf.Graph has ordered outputs which are stored in 'outputs' field,
|
||||
# but using this field results in mismatch of outputs in inner graph and outputs in outer graph
|
||||
# during the injection of subgraph.
|
||||
@ -67,6 +77,14 @@ class GraphIteratorTFGraph(GraphIterator):
|
||||
outputs = [output.name] + outputs
|
||||
return outputs
|
||||
|
||||
def get_input_names_map(self) -> dict:
|
||||
# returns a map from (user-defined) external tensor name to internal name for inputs
|
||||
return self.m_input_names_map
|
||||
|
||||
def get_output_names_map(self) -> dict:
|
||||
# returns a map from (user-defined) external tensor name to internal name for outputs
|
||||
return self.m_output_names_map
|
||||
|
||||
def is_end(self) -> bool:
|
||||
return self.m_node_index >= len(self.m_decoders)
|
||||
|
||||
|
@ -237,7 +237,32 @@ def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_t
|
||||
if isinstance(input_model, tf.Graph):
|
||||
return GraphIteratorTFGraph(input_model, share_weights)
|
||||
elif isinstance(input_model, tf.types.experimental.ConcreteFunction):
|
||||
return GraphIteratorTFGraph(input_model.graph, share_weights)
|
||||
# create a map for inputs to map internal tensor name to external one
|
||||
# collect all internal tensor names in a given order
|
||||
input_names_map = None
|
||||
if hasattr(input_model, 'inputs') and hasattr(input_model, 'structured_input_signature'):
|
||||
internal_tensor_names = []
|
||||
for func_input in input_model.inputs:
|
||||
if func_input.dtype == tf.resource:
|
||||
continue
|
||||
internal_tensor_names.append(func_input.name)
|
||||
if len(input_model.structured_input_signature) > 1 and \
|
||||
len(internal_tensor_names) == len(input_model.structured_input_signature[1]):
|
||||
external_tensor_names = sorted(input_model.structured_input_signature[1].keys())
|
||||
for internal_name, external_name in zip(internal_tensor_names, external_tensor_names):
|
||||
input_names_map = input_names_map or {}
|
||||
input_names_map[internal_name] = external_name
|
||||
|
||||
output_names_map = None
|
||||
if hasattr(input_model, 'outputs') and hasattr(input_model, 'structured_outputs') and \
|
||||
isinstance(input_model.structured_outputs, dict):
|
||||
external_names = sorted(list(input_model.structured_outputs.keys()))
|
||||
internal_names = sorted([tensor.name for tensor in input_model.outputs])
|
||||
if len(external_names) == len(internal_names):
|
||||
for external_name, internal_name in zip(external_names, internal_names):
|
||||
output_names_map = output_names_map or {}
|
||||
output_names_map[internal_name] = external_name
|
||||
return GraphIteratorTFGraph(input_model.graph, share_weights, False, input_names_map, output_names_map)
|
||||
raise Exception("Could not wrap model of type {} to GraphIteratorTFGraph.".format(type(input_model)))
|
||||
|
||||
|
||||
|
@ -5,8 +5,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include "openvino/frontend/graph_iterator.hpp"
|
||||
|
||||
#include "openvino/frontend/decoder.hpp"
|
||||
#include "openvino/frontend/graph_iterator.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
@ -14,9 +15,10 @@ namespace py = pybind11;
|
||||
class PyGraphIterator : public ov::frontend::tensorflow::GraphIterator {
|
||||
/* Inherit the constructors */
|
||||
using ov::frontend::tensorflow::GraphIterator::GraphIterator;
|
||||
using map_str_to_str = std::map<std::string, std::string>;
|
||||
|
||||
/// \brief Get a number of operation nodes in the graph
|
||||
size_t size() const override{
|
||||
size_t size() const override {
|
||||
PYBIND11_OVERRIDE_PURE(size_t, GraphIterator, size);
|
||||
}
|
||||
|
||||
@ -30,7 +32,8 @@ class PyGraphIterator : public ov::frontend::tensorflow::GraphIterator {
|
||||
next_impl();
|
||||
}
|
||||
|
||||
/// Implementation of next method, it is needed to be in separate method to avoid shadowing of Python "next" operator.
|
||||
/// Implementation of next method, it is needed to be in separate method to avoid shadowing of Python "next"
|
||||
/// operator.
|
||||
void next_impl() {
|
||||
PYBIND11_OVERRIDE_PURE(void, GraphIterator, next_impl);
|
||||
}
|
||||
@ -41,26 +44,35 @@ class PyGraphIterator : public ov::frontend::tensorflow::GraphIterator {
|
||||
}
|
||||
|
||||
/// \brief Return a pointer to a decoder of the current node
|
||||
std::shared_ptr<ov::frontend::DecoderBase> get_decoder() const override{
|
||||
std::shared_ptr<ov::frontend::DecoderBase> get_decoder() const override {
|
||||
PYBIND11_OVERRIDE_PURE(std::shared_ptr<ov::frontend::DecoderBase>, GraphIterator, get_decoder);
|
||||
}
|
||||
|
||||
/// \brief Checks if the main model graph contains a function of the requested name in the library
|
||||
/// Returns GraphIterator to this function and nullptr, if it does not exist
|
||||
std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override{
|
||||
std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override {
|
||||
PYBIND11_OVERRIDE_PURE(std::shared_ptr<GraphIterator>, GraphIterator, get_body_graph_iterator, func_name);
|
||||
}
|
||||
|
||||
/// \brief Returns a vector of input names in the original order
|
||||
std::vector<std::string> get_input_names() const override{
|
||||
std::vector<std::string> get_input_names() const override {
|
||||
PYBIND11_OVERRIDE_PURE(std::vector<std::string>, GraphIterator, get_input_names);
|
||||
|
||||
}
|
||||
|
||||
/// \brief Returns a vector of output names in the original order
|
||||
std::vector<std::string> get_output_names() const override{
|
||||
std::vector<std::string> get_output_names() const override {
|
||||
PYBIND11_OVERRIDE_PURE(std::vector<std::string>, GraphIterator, get_output_names);
|
||||
}
|
||||
|
||||
/// \brief Returns a map from internal tensor name to (user-defined) external name for inputs
|
||||
map_str_to_str get_input_names_map() const override {
|
||||
PYBIND11_OVERRIDE_PURE(map_str_to_str, GraphIterator, get_input_names_map);
|
||||
}
|
||||
|
||||
/// \brief Returns a map from internal tensor name to (user-defined) external name for outputs
|
||||
map_str_to_str get_output_names_map() const override {
|
||||
PYBIND11_OVERRIDE_PURE(map_str_to_str, GraphIterator, get_output_names_map);
|
||||
}
|
||||
};
|
||||
|
||||
void regclass_frontend_tensorflow_graph_iterator(py::module m);
|
@ -46,6 +46,16 @@ public:
|
||||
|
||||
/// \brief Returns a vector of output names in the original order
|
||||
virtual std::vector<std::string> get_output_names() const = 0;
|
||||
|
||||
/// \brief Returns a map from internal tensor name to (user-defined) external name for inputs
|
||||
virtual std::map<std::string, std::string> get_input_names_map() const {
|
||||
return {};
|
||||
}
|
||||
|
||||
/// \brief Returns a map from internal tensor name to (user-defined) external name for outputs
|
||||
virtual std::map<std::string, std::string> get_output_names_map() const {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -347,7 +347,23 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
|
||||
else if (variants[0].is<GraphIterator::Ptr>()) {
|
||||
// this is used for OpenVINO with TensorFlow Integration
|
||||
auto graph_iterator = variants[0].as<GraphIterator::Ptr>();
|
||||
return std::make_shared<InputModel>(graph_iterator, m_telemetry);
|
||||
std::shared_ptr<std::map<std::string, std::string>> input_names_map = nullptr;
|
||||
std::shared_ptr<std::map<std::string, std::string>> output_names_map = nullptr;
|
||||
if (graph_iterator->get_input_names_map().size() > 0) {
|
||||
input_names_map =
|
||||
std::make_shared<std::map<std::string, std::string>>(graph_iterator->get_input_names_map());
|
||||
}
|
||||
if (graph_iterator->get_output_names_map().size() > 0) {
|
||||
output_names_map =
|
||||
std::make_shared<std::map<std::string, std::string>>(graph_iterator->get_output_names_map());
|
||||
}
|
||||
return std::make_shared<InputModel>(graph_iterator,
|
||||
m_telemetry,
|
||||
nullptr,
|
||||
input_names_map,
|
||||
output_names_map,
|
||||
nullptr,
|
||||
false);
|
||||
}
|
||||
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
|
@ -66,31 +66,40 @@ void adjust_saved_model_names(ov::Output<ov::Node>& ov_output,
|
||||
// 2. find a set of clean-up names and aligned with the model signature
|
||||
const auto& tensor_names = ov_output.get_names();
|
||||
std::unordered_set<std::string> cleanup_names;
|
||||
bool signature_passed = true;
|
||||
if (is_input_tensor) {
|
||||
for (const auto& tensor_name : tensor_names) {
|
||||
if (saved_model_input_names->count(tensor_name) > 0) {
|
||||
cleanup_names.insert(saved_model_input_names->at(tensor_name));
|
||||
param_node->set_friendly_name(saved_model_input_names->at(tensor_name));
|
||||
if (saved_model_input_names) {
|
||||
for (const auto& tensor_name : tensor_names) {
|
||||
if (saved_model_input_names->count(tensor_name) > 0) {
|
||||
cleanup_names.insert(saved_model_input_names->at(tensor_name));
|
||||
param_node->set_friendly_name(saved_model_input_names->at(tensor_name));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
signature_passed = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_output_tensor) {
|
||||
std::vector<std::string> result_names;
|
||||
for (const auto& tensor_name : tensor_names) {
|
||||
if (saved_model_output_names->count(tensor_name) > 0) {
|
||||
cleanup_names.insert(saved_model_output_names->at(tensor_name));
|
||||
result_names.push_back(saved_model_output_names->at(tensor_name));
|
||||
if (saved_model_output_names) {
|
||||
std::vector<std::string> result_names;
|
||||
for (const auto& tensor_name : tensor_names) {
|
||||
if (saved_model_output_names->count(tensor_name) > 0) {
|
||||
cleanup_names.insert(saved_model_output_names->at(tensor_name));
|
||||
result_names.push_back(saved_model_output_names->at(tensor_name));
|
||||
}
|
||||
}
|
||||
}
|
||||
// align the Result node names as many as possible
|
||||
// it is not bad if we remain it as is because OV API 2.0 relies only on tensor names
|
||||
size_t result_names_size = result_names.size();
|
||||
if (result_names_size > 0) {
|
||||
for (size_t ind = 0; ind < results.size(); ++ind) {
|
||||
auto new_result_name = result_names[ind % result_names_size];
|
||||
results[ind]->set_friendly_name(new_result_name);
|
||||
// align the Result node names as many as possible
|
||||
// it is not bad if we remain it as is because OV API 2.0 relies only on tensor names
|
||||
size_t result_names_size = result_names.size();
|
||||
if (result_names_size > 0) {
|
||||
for (size_t ind = 0; ind < results.size(); ++ind) {
|
||||
auto new_result_name = result_names[ind % result_names_size];
|
||||
results[ind]->set_friendly_name(new_result_name);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
signature_passed = false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,7 +107,7 @@ void adjust_saved_model_names(ov::Output<ov::Node>& ov_output,
|
||||
// otherwise, the tensor corresponds to unused Parameter or Result nodes
|
||||
if (cleanup_names.size() > 0) {
|
||||
ov_output.set_names(cleanup_names);
|
||||
} else {
|
||||
} else if (signature_passed) {
|
||||
// this is unused tensor that should be removed
|
||||
// because it not present in the signature
|
||||
ov_output.add_names({"saved_model_unused"});
|
||||
|
@ -4,8 +4,8 @@ import gc
|
||||
|
||||
import numpy as np
|
||||
from models_hub_common.multiprocessing_utils import multiprocessing_run
|
||||
from openvino import convert_model
|
||||
from openvino.runtime import Core
|
||||
from openvino.tools.mo import convert_model
|
||||
|
||||
|
||||
class TestConvertModel:
|
||||
@ -32,9 +32,9 @@ class TestConvertModel:
|
||||
assert False, "Unsupported type {}".format(input_type)
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
inputs = []
|
||||
for input_shape, input_type in inputs_info:
|
||||
inputs.append(self.prepare_input(input_shape, input_type))
|
||||
inputs = {}
|
||||
for input_name, input_shape, input_type in inputs_info:
|
||||
inputs[input_name] = self.prepare_input(input_shape, input_type)
|
||||
return inputs
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
|
@ -32,7 +32,8 @@ class TestTFHubConvertModel(TestConvertModel):
|
||||
|
||||
def get_inputs_info(self, model_obj):
|
||||
inputs_info = []
|
||||
for input_info in model_obj.inputs:
|
||||
assert len(model_obj.structured_input_signature) > 1, "incorrect model or test issue"
|
||||
for input_name, input_info in model_obj.structured_input_signature[1].items():
|
||||
input_shape = []
|
||||
try:
|
||||
for dim in input_info.shape.as_list():
|
||||
@ -55,32 +56,25 @@ class TestTFHubConvertModel(TestConvertModel):
|
||||
tf.string: str,
|
||||
tf.bool: bool,
|
||||
}
|
||||
if input_info.dtype not in type_map:
|
||||
if input_info.dtype == tf.resource:
|
||||
# skip inputs corresponding to variables
|
||||
continue
|
||||
assert input_info.dtype in type_map, "Unsupported input type: {}".format(input_info.dtype)
|
||||
inputs_info.append((input_shape, type_map[input_info.dtype]))
|
||||
inputs_info.append((input_name, input_shape, type_map[input_info.dtype]))
|
||||
|
||||
return inputs_info
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
# TODO 119141 - use the same dictionary for OV inference
|
||||
# repack input dictionary to tensorflow constants
|
||||
tf_inputs = {}
|
||||
for input_ind, input_name in enumerate(sorted(model_obj.structured_input_signature[1].keys())):
|
||||
tf_inputs[input_name] = tf.constant(inputs[input_ind])
|
||||
for input_name, input_value in inputs.items():
|
||||
tf_inputs[input_name] = tf.constant(input_value)
|
||||
|
||||
output_dict = {}
|
||||
for out_name, out_value in model_obj(**tf_inputs).items():
|
||||
output_dict[out_name] = out_value.numpy()
|
||||
|
||||
# TODO: 119141 - remove this workaround
|
||||
# map external tensor names to internal names
|
||||
assert len(model_obj.outputs) == len(model_obj.structured_outputs)
|
||||
fw_outputs = {}
|
||||
for output_ind, external_name in enumerate(sorted(model_obj.structured_outputs.keys())):
|
||||
internal_name = model_obj.outputs[output_ind].name
|
||||
out_value = output_dict[external_name]
|
||||
fw_outputs[internal_name] = out_value
|
||||
return fw_outputs
|
||||
return output_dict
|
||||
|
||||
def teardown_method(self):
|
||||
# remove all downloaded files for TF Hub models
|
||||
|
Loading…
Reference in New Issue
Block a user