tf.Graph decoder. (#16355)

* tf.Graph decoder.

* Fix conflicts.

* Fixed det_input_node()

* Added support of non-frozen models.

* Cleaned code.

* Small fix.

* Small corrections.

* Error fixes.

* Code style.

* Code style.

* Code style.

* Small correction.

* Fixed float32 attributes.

* Small correction.

* Fixed tests.

* Fixed errors.

* Added stateful partitioned call test.

* Import fix.

* Code corrections.

* BOM test fixed.

* Corrected check, added comment.

* Added checks.

* Supported TF Graph Iterator in load_by_model().

* Clang format.

* Small correction.

* Fixed example_input logic, added tests.

* Added comment.

* Small correction.

* Corrected example_input description.

* Moved load_by_model test to MO Python API tests.

* Minor corrections.

* Code corrections.

* Small correction.

* Clang format.

* Fixed tests.

* Import change.

* Moved GraphIterator to common FE.

* Tests refactoring, minor fixes.

* Small test correction.

* Removed not needed change.

* Removed commented code.

* Removed not needed change.

* Unit tests fix.

* Temporarily added debug output.

* Test fix.

* Applied comments.

* Fixed test.
This commit is contained in:
Anastasiia Pnevskaia
2023-06-13 14:04:26 +02:00
committed by GitHub
parent af9204488d
commit 77711be786
32 changed files with 1101 additions and 129 deletions

View File

@@ -13,6 +13,7 @@ from openvino.utils import _add_openvino_libs_to_search_path
_add_openvino_libs_to_search_path()
try:
from openvino.frontend.tensorflow.py_tensorflow_frontend import _FrontEndPyGraphIterator as GraphIterator
from openvino.frontend.tensorflow.py_tensorflow_frontend import ConversionExtensionTensorflow as ConversionExtension
from openvino.frontend.tensorflow.py_tensorflow_frontend import OpExtensionTensorflow as OpExtension
except ImportError as err:

View File

@@ -0,0 +1,89 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
# mypy: ignore-errors
import tensorflow as tf
from openvino.frontend.tensorflow.node_decoder import TFGraphNodeDecoder
from openvino.frontend.tensorflow.py_tensorflow_frontend import _FrontEndPyGraphIterator as GraphIterator
class GraphIteratorTFGraph(GraphIterator):
    """Exposes a tf.Graph to the OpenVINO TF Frontend through the GraphIterator interface."""

    def __init__(self, tf_graph: tf.Graph, inner_graph: bool = False):
        GraphIterator.__init__(self)
        self.m_graph = tf_graph
        self.m_node_index = 0
        self.m_inner_graph = inner_graph
        self.m_vars = None
        if hasattr(tf_graph, "variables"):
            # This field is needed to keep the link to graph variables,
            # otherwise Python releases memory kept by variables when it is accessed from c++ bindings
            self.m_vars = tf_graph.variables
        self.m_decoders = [TFGraphNodeDecoder(operation, inner_graph)
                           for operation in tf_graph.get_operations()]
        # Body-graph iterators are created lazily, see get_body_graph_iterator().
        self.m_iterators = {func_name: None for func_name in self.m_graph._functions}

    def get_input_names(self) -> list:
        """Returns names of the graph inputs (Placeholder nodes) in graph order."""
        input_names = []
        for operation in self.m_graph.get_operations():
            if operation.type != "Placeholder":
                continue
            assert isinstance(operation, tf.Operation), "Unknown node type. Expected tf.Operation, got {}".format(type(operation))
            assert hasattr(operation, "node_def") and isinstance(operation.node_def, tf.compat.v1.NodeDef), \
                "Could not find node_def in node {}".format(operation.name)
            dtype_code = operation.node_def.attr["dtype"].type
            # Placeholders with type "resource" have exact values in "variables" field,
            # so they are passed to TF FE as constants.
            # For this reason they are not listed as model inputs.
            if self.m_inner_graph or tf.dtypes.DType(dtype_code).name != "resource":
                input_names.append(operation.name)
        return input_names

    def get_output_names(self) -> list:
        """Returns names of the graph output tensors.

        tf.Graph has ordered outputs which are stored in 'outputs' field,
        but using this field results in mismatch of outputs in inner graph and outputs in outer graph
        during the injection of subgraph.
        For this reason only nodes without outputs are considered graph outputs here
        as this approach does not lead to conflicts.
        The order of outputs is important and wrong order may lead to conversion error.
        """
        consumed_node_names = set()
        for operation in self.m_graph.get_operations():
            assert isinstance(operation, tf.Operation), "Unknown node type. Expected tf.Operation, got {}".format(type(operation))
            for input_tensor in operation.inputs:
                consumed_node_names.add(input_tensor.op.name)
        output_names = []
        for operation in self.m_graph.get_operations():
            if operation.name in consumed_node_names:
                continue
            for output_tensor in operation.outputs:
                # prepend to keep the same ordering as the original implementation
                output_names.insert(0, output_tensor.name)
        return output_names

    def is_end(self) -> bool:
        """True when the iterator has passed the last node."""
        return self.m_node_index >= len(self.m_decoders)

    def reset(self):
        """Rewinds the iterator to the first node."""
        self.m_node_index = 0

    def size(self) -> int:
        """Number of operation nodes in the graph."""
        return len(self.m_decoders)

    def next_impl(self):
        """Advances the iterator to the next node."""
        self.m_node_index += 1

    def get_decoder(self):
        """Returns the decoder of the current node."""
        return self.m_decoders[self.m_node_index]

    def get_body_graph_iterator(self, func_name):
        """Returns a (lazily created) iterator over the named body graph, or None if unknown."""
        if func_name not in self.m_iterators:
            return None
        if self.m_iterators[func_name] is None:
            body_graph = self.m_graph._functions[func_name].graph
            self.m_iterators[func_name] = GraphIteratorTFGraph(body_graph, True)
        return self.m_iterators[func_name]

View File

@@ -0,0 +1,157 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
# mypy: ignore-errors
import numpy as np
import tensorflow as tf
from openvino.frontend.tensorflow.py_tensorflow_frontend import _FrontEndDecoderBase as DecoderBase
from openvino.runtime import PartialShape, Type, OVAny, Tensor
def tf_type_to_ov_type(tf_type_int):
    """Convert a TF DataType enum value to an OpenVINO element Type.

    :param tf_type_int: integer value of the TF DataType enum.
    :return: matching OpenVINO Type; Type.dynamic for "variant",
             Type.undefined for types without an OpenVINO counterpart.
    """
    tf_type = tf.dtypes.as_dtype(tf_type_int)
    # "variant" has no numpy counterpart, report it as a fully dynamic type
    if tf_type.name == "variant":
        return Type.dynamic
    numpy_type = tf_type.as_numpy_dtype
    try:
        ret_type = Type(numpy_type)
    except Exception:
        # Narrowed from a bare `except:`; types with no OpenVINO mapping
        # (e.g. string) fall back to undefined instead of raising.
        ret_type = Type.undefined
    return ret_type
def tf_attr_to_numpy(attr):
    """Convert a TF AttrValue protobuf to a plain Python/numpy value.

    :param attr: protobuf message with a `WhichOneof("value")` accessor.
    :return: decoded value, or None when the attribute is unset.
    """
    attr_type = attr.WhichOneof("value")
    # Unset attribute — check first so the branches below can assume a string.
    if attr_type is None:
        return None
    if attr_type == "func":
        return attr.func.name
    if attr_type == "s":
        return attr.s.decode("utf-8")
    if attr_type == "f":
        # keep single-precision to match the original TF attribute type
        return np.float32(attr.f)
    if attr_type == "type":
        return tf_type_to_ov_type(attr.type)
    if attr_type == "list":
        list_value = attr.list
        # the "list" message has exactly one populated repeated field
        return list(list_value.ListFields()[0][1])
    # Reuse the already-computed oneof name instead of calling WhichOneof again.
    return getattr(attr, attr_type)
def tf_attr_to_ov(attr):
    # Wraps the decoded TF attribute value into OVAny for passing through the FE API.
    return OVAny(tf_attr_to_numpy(attr))
class TFGraphNodeDecoder(DecoderBase):
    """Decoder of a single tf.Operation node for the OpenVINO TF Frontend."""

    def __init__(self, operation: tf.Operation, inner_graph: bool):
        # :param operation: tf.Operation to decode.
        # :param inner_graph: True when the node belongs to a body (inner) graph.
        DecoderBase.__init__(self)
        assert isinstance(operation, tf.Operation), "Unknown operation type. " \
                                                    "Expected tf.Operation, got {}".format(type(operation))
        self.m_operation = operation
        self.m_inner_graph = inner_graph
        # NOTE(review): m_parsed_content is set only for "Const" nodes and for
        # variable-backed resource Placeholders; get_attribute("value") relies on it
        # being present — confirm it is never requested for other node types.
        if self.m_operation.type == "Const":
            value = self.m_operation.node_def.attr["value"].tensor
            # copies tensor value from node_def
            self.m_parsed_content = tf.make_ndarray(value)

        if self.m_operation.type == "Placeholder":
            data_type = self.m_operation.node_def.attr["dtype"].type
            if tf.dtypes.DType(data_type).name == "resource" and not self.m_inner_graph:
                variable_value = TFGraphNodeDecoder.get_variable(self.m_operation)
                if variable_value is not None:
                    # does not copy data
                    self.m_parsed_content = variable_value.value().__array__()

    def get_op_name(self) -> str:
        """Returns the name of the operation node."""
        return self.m_operation.name

    def get_op_type(self) -> str:
        """Returns the node type; variable-backed resource Placeholders are reported as "Const"."""
        if self.m_operation.type == "Placeholder":
            type_attr = tf.dtypes.DType(self.m_operation.node_def.attr["dtype"].type)
            if type_attr.name == "resource" and not self.m_inner_graph:
                if TFGraphNodeDecoder.get_variable(self.m_operation) is not None:
                    return "Const"
                raise Exception("Could not get variable for resource Placeholder {0}".format(self.m_operation.name))
        return self.m_operation.type

    @staticmethod
    def get_variable(operation):
        """Finds the graph variable backing the given resource operation.

        Matches the operation's first output against the graph's captured tensors,
        then looks up the variable with the captured resource name.
        Returns None when the graph has no captures or no variable matches.
        """
        tf_graph = operation.graph
        if not hasattr(tf_graph, "captures"):
            return None
        for var_tensor, op_tensor in tf_graph.captures:
            if operation.outputs[0].name == op_tensor.name:
                resource_name = var_tensor._name
                for variable_value in operation.graph.variables:
                    if variable_value.name == resource_name:
                        return variable_value
                return None
        return None

    def get_attribute(self, name):
        """Returns the named node attribute wrapped in OVAny.

        "shape"/"_output_shapes", "dtype" and "value" get special handling;
        any other attribute is converted generically from node_def.
        """
        if name == "shape" or name == "_output_shapes":
            if self.m_operation.node_def.attr["shape"].shape.unknown_rank:
                return OVAny(PartialShape.dynamic())
            shape_dims = self.m_operation.node_def.attr["shape"].shape.dim
            shape = [dim.size for dim in shape_dims]
            type_num = self.m_operation.node_def.attr["dtype"].type
            if type_num is not None and tf.dtypes.DType(type_num).name == "resource":
                if self.m_inner_graph:
                    return OVAny(PartialShape.dynamic())
                # shape of a resource Placeholder is taken from the backing variable
                variable_value = TFGraphNodeDecoder.get_variable(self.m_operation)
                return OVAny(PartialShape(list(variable_value.shape)))
            return OVAny(PartialShape(shape))
        if name == "dtype":
            type_num = self.m_operation.node_def.attr["dtype"].type
            if tf.dtypes.DType(type_num).name == "resource":
                if not self.m_inner_graph:
                    variable_value = TFGraphNodeDecoder.get_variable(self.m_operation)
                    return OVAny(tf_type_to_ov_type(variable_value.dtype))
                else:
                    # resource type inside a body graph cannot be resolved here
                    return OVAny(Type.undefined)
            return OVAny(tf_type_to_ov_type(type_num))

        if name == "value":
            if self.m_parsed_content.size == 1:
                if isinstance(self.m_parsed_content, np.ndarray):
                    return OVAny(Tensor(self.m_parsed_content))
                # non-ndarray scalar is wrapped into a one-element tensor
                return OVAny(Tensor(np.array([self.m_parsed_content]), shape=[1]))
            # larger constants share memory with the parsed content instead of copying
            ov_tensor = Tensor(self.m_parsed_content, shared_memory=True)
            ov_tensor = OVAny(ov_tensor)
            return ov_tensor
        attr_value = self.m_operation.node_def.attr[name]
        return tf_attr_to_ov(attr_value)

    def get_input_size(self) -> int:
        """Returns the number of inputs of the operation."""
        return len(self.m_operation.inputs)

    def get_input_node_name(self, input_port_idx):
        """Returns the producer node name for the given input port index."""
        assert input_port_idx >= 0, "Got negative input node index."
        assert input_port_idx < len(self.m_operation.inputs), "Input node index is out of range. Got {}, " \
                                                              "when number of input nodes {}.".format(input_port_idx,
                                                                                                      len(self.m_operation.inputs))
        return self.m_operation.inputs[input_port_idx].op.name

    def get_input_node_name_output_port_index(self, input_port_idx):
        """Returns the producer output port index parsed from "name:idx"; 0 when absent or non-numeric."""
        tensor_name = self.m_operation.inputs[input_port_idx].name
        if ":" in tensor_name:
            port_idx_str = tensor_name[tensor_name.rfind(":") + 1:len(tensor_name)]
            if port_idx_str.isdigit():
                return int(port_idx_str)
            else:
                return 0
        return 0

    def get_input_node_name_output_port_name(self, input_port_idx):
        """Returns the port name from "op:port_name:idx" tensor names; "" when not present."""
        tensor_name = self.m_operation.inputs[input_port_idx].name
        if ":" not in tensor_name:
            return ""
        first_col_idx = tensor_name.find(":")
        last_col_idx = tensor_name.rfind(":")
        if first_col_idx == last_col_idx:
            return ""
        return tensor_name[first_col_idx + 1: last_col_idx]

View File

@@ -0,0 +1,221 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
# mypy: ignore-errors
from openvino.tools.mo.moc_frontend.shape_utils import get_static_shape
from openvino.tools.mo.utils.versions_checker import get_environment_setup # pylint: disable=no-name-in-module
from openvino.tools.mo.utils.error import Error
from distutils.version import LooseVersion
import logging as log
def trace_tf_model_if_needed(input_model, placeholder_shapes, placeholder_data_types, example_input):
    """Traces the model when its type requires tracing; otherwise returns it unchanged."""
    import tensorflow as tf
    traceable_types = (tf.keras.layers.Layer, tf.Module, tf.keras.Model,
                       tf.types.experimental.GenericFunction)
    if not isinstance(input_model, traceable_types):
        return input_model
    return trace_tf_model(input_model, placeholder_shapes, placeholder_data_types, example_input)
def get_input_spec_from_model(model):
    """Builds a list of tf.TensorSpec from the model's recorded build shape, if any."""
    import tensorflow as tf
    build_shape = getattr(model, "_build_input_shape", None)
    if build_shape is None:
        # shape is unknown — trace with a fully dynamic spec
        return [tf.TensorSpec(None)]
    if isinstance(build_shape, list):
        # multiple inputs: one spec per recorded shape, nested in a single argument
        return [[tf.TensorSpec(shape) for shape in build_shape]]
    return [tf.TensorSpec(build_shape)]
def create_example_input_by_user_shapes(input_shapes, input_types):
    """Creates zero-filled tensors from user-provided shapes/types to use as tracing inputs."""
    import tensorflow as tf
    if input_shapes is None:
        return None
    if isinstance(input_shapes, dict):
        examples = {}
        for name, user_shape in input_shapes.items():
            tensor_args = {}
            if name in input_types:
                tensor_args["dtype"] = input_types[name]
            examples[name] = tf.zeros(shape=get_static_shape(user_shape, 1), **tensor_args)
        return examples
    if isinstance(input_shapes, list):
        examples = []
        for idx, user_shape in enumerate(input_shapes):
            tensor_args = {}
            if idx < len(input_types):
                tensor_args["dtype"] = input_types[idx]
            examples.append(tf.zeros(shape=get_static_shape(user_shape, 1), **tensor_args))
        return examples
    raise Error("Could not create example input by provided shape {}".format(input_shapes))
def get_concrete_func(tf_function, example_input, input_needs_packing, error_message, use_example_input=True):
    """
    Runs tracing of TF function and returns a concrete function.

    :param tf_function: TF function that needs to be traced.
    :param example_input: Example of function input.
    :param input_needs_packing: determines if input needs to be packed in a list before passing to TF function.
    It is used when original function was wrapped in outer TF function, which changes function signature.
    In this case wrapper TF function always expects list of inputs which are unpacked inside subfunction.
    So list/tuple are treated as multiple inputs of original model.
    Non list/tuple are treated as single input, and it needs packing to a list,
    as wrapper function always expect list of inputs.
    :param error_message: Error message which should be shown in case of tracing error.
    :param use_example_input: Determines if example_input should be used.

    :returns: Object of type tf.types.experimental.ConcreteFunction.
    """
    if input_needs_packing and not isinstance(example_input, (list, tuple)):
        example_input = [example_input]
    try:
        if use_example_input:
            if not input_needs_packing and isinstance(example_input, (list, tuple)):
                concrete_func = tf_function.get_concrete_function(*example_input)
            else:
                concrete_func = tf_function.get_concrete_function(example_input)
        else:
            concrete_func = tf_function.get_concrete_function()
    except Exception as e:
        # Chain the original exception so the tracing traceback is not lost.
        raise Exception(error_message.format(e)) from e
    return concrete_func
def trace_tf_model(model, input_shapes, input_types, example_input):
    """Traces the given TF model object and returns a concrete function for conversion."""
    import tensorflow as tf

    trace_error = "Could not trace the TF model with the following error: {}"

    # Select the tf.function to trace and whether inputs must be packed into a list.
    if isinstance(model.__call__, tf.types.experimental.GenericFunction):
        tf_function = model.__call__
        input_needs_packing = False
    elif isinstance(model, tf.types.experimental.GenericFunction):
        tf_function = model
        input_needs_packing = False
    else:
        # Wrap model to tf.Function.
        # In this case we lose input/output tensor names.
        @tf.function
        def tf_function(args):
            return model(*args)
        input_needs_packing = True

    if example_input is not None:
        return get_concrete_func(tf_function, example_input, input_needs_packing, trace_error)

    if input_shapes is not None:
        synthetic_input = create_example_input_by_user_shapes(input_shapes, input_types)
        return get_concrete_func(tf_function, synthetic_input, input_needs_packing, trace_error)

    if isinstance(tf_function, tf.types.experimental.GenericFunction) and \
            tf_function.input_signature is not None:
        # function already carries a signature — trace without an example input
        return get_concrete_func(tf_function, None, input_needs_packing, trace_error,
                                 use_example_input=False)

    input_spec = get_input_spec_from_model(model)
    return get_concrete_func(tf_function, input_spec, input_needs_packing,
                             "Could not trace the TF model with the following error: {}.\n"
                             "Please provide 'example_input'.")
def type_supported_by_tf_fe(input_model):
    """Checks whether the TF Frontend can consume the given model object."""
    import tensorflow as tf
    # Types that require tracing
    requires_tracing = (tf.keras.layers.Layer, tf.Module, tf.keras.Model,
                        tf.types.experimental.GenericFunction)
    if isinstance(input_model, requires_tracing):
        return True
    # Types that do not require tracing
    if isinstance(input_model, (tf.Graph, tf.types.experimental.ConcreteFunction)):
        return True
    # GraphIterator is accepted as-is
    return model_is_graph_iterator(input_model)
def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_types, example_input):
    """Traces the model if needed and wraps the resulting graph into GraphIteratorTFGraph."""
    input_model = trace_tf_model_if_needed(input_model, placeholder_shapes, placeholder_data_types, example_input)
    import tensorflow as tf
    from openvino.frontend.tensorflow.graph_iterator import GraphIteratorTFGraph
    if model_is_graph_iterator(input_model):
        # already wrapped — pass through unchanged
        return input_model
    if isinstance(input_model, tf.types.experimental.ConcreteFunction):
        return GraphIteratorTFGraph(input_model.graph)
    if isinstance(input_model, tf.Graph):
        return GraphIteratorTFGraph(input_model)
    raise Exception("Could not wrap model of type {} to GraphIteratorTFGraph.".format(type(input_model)))
def extract_model_graph(argv):
    """Normalizes argv["input_model"] to an object the TF Frontend can consume.

    Rewrites argv["input_model"] in place for GraphDef/Session/Checkpoint/Trackable
    inputs and returns True when the model type was recognized, False otherwise.

    :param argv: dict holding the "input_model" entry.
    :raises Error: for unknown checkpoint formats or Trackable objects without a usable signature/graph.
    """
    model = argv["input_model"]
    import tensorflow as tf
    trackable_is_imported = False
    try:
        from tensorflow.python.training.tracking.base import Trackable # pylint: disable=no-name-in-module,import-error
        trackable_is_imported = True
    except ImportError:
        # Narrowed from a bare `except:`; the Trackable check below is simply skipped.
        log.warning("Could not import tensorflow.python.training.tracking.base.Trackable type.")
    env_setup = get_environment_setup("tf")
    if isinstance(model, tf.Graph):
        return True
    if isinstance(model, tf.compat.v1.GraphDef):
        # wrap the GraphDef into a fresh tf.Graph
        graph = tf.Graph()
        with graph.as_default():
            tf.graph_util.import_graph_def(model)
        argv["input_model"] = graph
        return True
    if isinstance(model, tf.compat.v1.Session):
        argv["input_model"] = model.graph
        return True
    # GenericFunction/ConcreteFunction are available starting from TF 2.6
    if env_setup["tensorflow"] >= LooseVersion("2.6.0") and isinstance(model, (tf.types.experimental.GenericFunction,
                                                                               tf.types.experimental.ConcreteFunction)):
        return True
    if isinstance(model, tf.train.Checkpoint):
        if isinstance(model.root, tf.keras.Model):
            argv["input_model"] = model.root
            return True
        else:
            raise Error("Unknown checkpoint format.")
    if isinstance(model, (tf.keras.layers.Layer, tf.Module, tf.keras.Model)):
        return True
    if trackable_is_imported and isinstance(model, Trackable):
        # SavedModel-like object: prefer "serving_default", then "default", then any signature
        if hasattr(model, "signatures") and len(model.signatures.items()):
            if "serving_default" in model.signatures:
                argv["input_model"] = model.signatures["serving_default"]
            elif "default" in model.signatures:
                argv["input_model"] = model.signatures["default"]
            else:
                for signature_name, signature in model.signatures.items():
                    argv["input_model"] = model.signatures[signature_name]
                    log.warning("Could not find the default signature. "
                                "The following signature was used for conversion: {}".format(signature_name))
                    break
        elif hasattr(model, "graph"):
            argv["input_model"] = model.graph
        else:
            raise Error("Could not find signature of graph in a Trackable object.")
        return True
    if model_is_graph_iterator(model):
        return True
    return False
def model_is_graph_iterator(model):
    """Checks whether the given object is a GraphIteratorTFGraph.

    Returns False when the TF frontend bindings are unavailable, since the
    object then cannot be a GraphIteratorTFGraph.
    """
    try:
        from openvino.frontend.tensorflow.graph_iterator import GraphIteratorTFGraph
    except ImportError:
        # Narrowed from a bare `except:`; a missing/broken TF frontend means "not an iterator".
        return False
    return isinstance(model, GraphIteratorTFGraph)

View File

@@ -59,6 +59,21 @@ void regclass_frontend_FrontEnd(py::module m) {
:rtype: openvino.frontend.InputModel
)");
fem.def(
"supported",
[](FrontEnd& self, const py::object& model) {
return self.supported({Common::utils::py_object_to_any(model)});
},
py::arg("model"),
R"(
Checks if model type is supported.
:param model: Object describing the model. It can be path to model file.
:type model: Any
:return: True if model type is supported, otherwise False.
:rtype: bool
)");
fem.def("convert",
static_cast<std::shared_ptr<ov::Model> (FrontEnd::*)(const InputModel::Ptr&) const>(&FrontEnd::convert),
py::arg("model"),

View File

@@ -10,6 +10,7 @@
#include "openvino/frontend/exception.hpp"
#include "pyopenvino/frontend/manager.hpp"
#include "pyopenvino/utils/utils.hpp"
namespace py = pybind11;
@@ -76,15 +77,19 @@ void regclass_frontend_FrontEndManager(py::module m) {
fem.def(
"load_by_model",
[](const std::shared_ptr<ov::frontend::FrontEndManager>& fem, const std::string& model_path) {
return fem->load_by_model(model_path);
[](const std::shared_ptr<ov::frontend::FrontEndManager>& fem, const py::object& model) {
if (py::isinstance(model, py::module_::import("pathlib").attr("Path"))) {
std::string model_path = Common::utils::convert_path_to_string(model);
return fem->load_by_model(model_path);
}
return fem->load_by_model({Common::utils::py_object_to_any(model)});
},
py::arg("model_path"),
py::arg("model"),
R"(
Selects and loads appropriate frontend depending on model file extension and other file info (header).
Selects and loads appropriate frontend depending on model type or model file extension and other file info (header).
:param model_path: A path to a model file/directory.
:type model_path: str
:param model_path: A model object or path to a model file/directory.
:type model_path: Any
:return: Frontend interface for further loading of models. 'None' if no suitable frontend is found.
:rtype: openvino.frontend.FrontEnd
)");

View File

@@ -0,0 +1,21 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include "decoder_base.hpp"
namespace py = pybind11;
using namespace ov::frontend;
using ov::Any;
// Registers the TF DecoderBase (with its Python trampoline) as _FrontEndDecoderBase.
void regclass_frontend_tensorflow_decoder_base(py::module m) {
    using DecoderBasePtr = std::shared_ptr<ov::frontend::tensorflow::DecoderBase>;
    py::class_<ov::frontend::tensorflow::DecoderBase, IDecoder, PyDecoderBase, DecoderBasePtr> cls(
        m,
        "_FrontEndDecoderBase");
    cls.def(py::init<>());
}

View File

@@ -0,0 +1,54 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
#include "openvino/frontend/tensorflow/decoder.hpp"
namespace py = pybind11;
/// Trampoline class to support inheritance from DecoderBase in Python
class PyDecoderBase : public ov::frontend::tensorflow::DecoderBase {
    /// \brief Returns the named attribute of the node, dispatched to the Python implementation
    ov::Any get_attribute(const std::string &name) const override{
        PYBIND11_OVERRIDE_PURE(ov::Any, DecoderBase, get_attribute, name);
    }
    /// \brief Returns the number of node inputs
    size_t get_input_size() const override{
        PYBIND11_OVERRIDE_PURE(size_t, DecoderBase, get_input_size);
    }
    // Helper (not part of the C++ DecoderBase interface): producer node name for the given input port,
    // implemented on the Python side.
    std::string get_input_node_name(size_t input_port_idx) const {
        PYBIND11_OVERRIDE_PURE(std::string, DecoderBase, get_input_node_name, input_port_idx);
    }
    // Helper: producer output port index for the given input port.
    size_t get_input_node_name_output_port_index(size_t input_port_idx) const {
        PYBIND11_OVERRIDE_PURE(size_t, DecoderBase, get_input_node_name_output_port_index, input_port_idx);
    }
    // Helper: producer output port name for the given input port.
    std::string get_input_node_name_output_port_name(size_t input_port_idx) const {
        PYBIND11_OVERRIDE_PURE(std::string, DecoderBase, get_input_node_name_output_port_name, input_port_idx);
    }
    /// \brief Composes get_input_node from the three Python-side helpers above
    void get_input_node(size_t input_port_idx,
                        std::string &producer_name,
                        std::string &producer_output_port_name,
                        size_t &producer_output_port_index) const override{
        producer_name = get_input_node_name(input_port_idx);
        producer_output_port_index = get_input_node_name_output_port_index(input_port_idx);
        producer_output_port_name = get_input_node_name_output_port_name(input_port_idx);
    }
    // NOTE(review): returning a const std::string& through PYBIND11_OVERRIDE_PURE binds the
    // reference to a conversion result inside the macro — verify the referenced object's lifetime.
    const std::string &get_op_type() const override{
        PYBIND11_OVERRIDE_PURE(std::string&, DecoderBase, get_op_type);
    }
    const std::string &get_op_name() const override{
        PYBIND11_OVERRIDE_PURE(std::string&, DecoderBase, get_op_name);
    }
};
void regclass_frontend_tensorflow_decoder_base(py::module m);

View File

@@ -0,0 +1,22 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include "graph_iterator.hpp"
#include "openvino/frontend/graph_iterator.hpp"
namespace py = pybind11;
using namespace ov::frontend;
using ov::Any;
// Registers the TF GraphIterator (with its Python trampoline) as _FrontEndPyGraphIterator.
void regclass_frontend_tensorflow_graph_iterator(py::module m) {
    using GraphIteratorPtr = std::shared_ptr<ov::frontend::tensorflow::GraphIterator>;
    py::class_<ov::frontend::tensorflow::GraphIterator, PyGraphIterator, GraphIteratorPtr> iterator_class(
        m,
        "_FrontEndPyGraphIterator");
    iterator_class.def(py::init<>());
}

View File

@@ -0,0 +1,66 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
#include "openvino/frontend/graph_iterator.hpp"
#include "openvino/frontend/decoder.hpp"
namespace py = pybind11;
/// Trampoline class to support inheritance from GraphIterator in Python
class PyGraphIterator : public ov::frontend::tensorflow::GraphIterator {
    /* Inherit the constructors */
    using ov::frontend::tensorflow::GraphIterator::GraphIterator;

    /// \brief Get a number of operation nodes in the graph
    size_t size() const override{
        PYBIND11_OVERRIDE_PURE(size_t, GraphIterator, size);
    }

    /// \brief Set iterator to the start position
    void reset() override {
        PYBIND11_OVERRIDE_PURE(void, GraphIterator, reset);
    }

    /// \brief Move to the next node in the graph
    void next() override {
        // delegates to next_impl() so the Python side overrides a name
        // that does not clash with Python's own "next"
        next_impl();
    }

    /// Implementation of next method, it is needed to be in separate method to avoid shadowing of Python "next" operator.
    void next_impl() {
        PYBIND11_OVERRIDE_PURE(void, GraphIterator, next_impl);
    }

    /// \brief Returns true if iterator goes out of the range of available nodes
    bool is_end() const override {
        PYBIND11_OVERRIDE_PURE(bool, GraphIterator, is_end);
    }

    /// \brief Return a pointer to a decoder of the current node
    std::shared_ptr<ov::frontend::DecoderBase> get_decoder() const override{
        PYBIND11_OVERRIDE_PURE(std::shared_ptr<ov::frontend::DecoderBase>, GraphIterator, get_decoder);
    }

    /// \brief Checks if the main model graph contains a function of the requested name in the library
    /// Returns GraphIterator to this function and nullptr, if it does not exist
    std::shared_ptr<GraphIterator> get_body_graph_iterator(const std::string& func_name) const override{
        PYBIND11_OVERRIDE_PURE(std::shared_ptr<GraphIterator>, GraphIterator, get_body_graph_iterator, func_name);
    }

    /// \brief Returns a vector of input names in the original order
    std::vector<std::string> get_input_names() const override{
        PYBIND11_OVERRIDE_PURE(std::vector<std::string>, GraphIterator, get_input_names);
    }

    /// \brief Returns a vector of output names in the original order
    std::vector<std::string> get_output_names() const override{
        PYBIND11_OVERRIDE_PURE(std::vector<std::string>, GraphIterator, get_output_names);
    }
};
void regclass_frontend_tensorflow_graph_iterator(py::module m);

View File

@@ -7,10 +7,14 @@
#include <string>
#include "extension.hpp"
#include "graph_iterator.hpp"
#include "decoder_base.hpp"
namespace py = pybind11;
// Entry point of the py_tensorflow_frontend Python extension module.
PYBIND11_MODULE(py_tensorflow_frontend, m) {
    // Register conversion/op extensions and the graph-iterator/decoder trampolines.
    regclass_frontend_tensorflow_ConversionExtension(m);
    regclass_frontend_tensorflow_OpExtension(m);
    regclass_frontend_tensorflow_graph_iterator(m);
    regclass_frontend_tensorflow_decoder_base(m);
}

View File

@@ -15,6 +15,7 @@
#include "meta_data.hpp"
#include "openvino/core/except.hpp"
#include "openvino/frontend/decoder.hpp"
#include "openvino/frontend/graph_iterator.hpp"
using Version = ov::pass::Serialize::Version;
@@ -284,17 +285,22 @@ ov::AnyMap py_object_to_any_map(const py::object& py_obj) {
ov::Any py_object_to_any(const py::object& py_obj) {
// Python types
py::object float_32_type = py::module_::import("numpy").attr("float32");
if (py::isinstance<py::str>(py_obj)) {
return py_obj.cast<std::string>();
} else if (py::isinstance<py::bool_>(py_obj)) {
return py_obj.cast<bool>();
} else if (py::isinstance<py::float_>(py_obj)) {
return py_obj.cast<double>();
} else if (py::isinstance(py_obj, float_32_type)) {
return py_obj.cast<float>();
} else if (py::isinstance<py::int_>(py_obj)) {
return py_obj.cast<int64_t>();
} else if (py::isinstance<py::none>(py_obj)) {
return {};
} else if (py::isinstance<py::list>(py_obj)) {
auto _list = py_obj.cast<py::list>();
enum class PY_TYPE : int { UNKNOWN = 0, STR, INT, FLOAT, BOOL };
enum class PY_TYPE : int { UNKNOWN = 0, STR, INT, FLOAT, BOOL, PARTIAL_SHAPE };
PY_TYPE detected_type = PY_TYPE::UNKNOWN;
for (const auto& it : _list) {
auto check_type = [&](PY_TYPE type) {
@@ -312,6 +318,8 @@ ov::Any py_object_to_any(const py::object& py_obj) {
check_type(PY_TYPE::FLOAT);
} else if (py::isinstance<py::bool_>(it)) {
check_type(PY_TYPE::BOOL);
} else if (py::isinstance<ov::PartialShape>(it)) {
check_type(PY_TYPE::PARTIAL_SHAPE);
}
}
@@ -327,6 +335,8 @@ ov::Any py_object_to_any(const py::object& py_obj) {
return _list.cast<std::vector<int64_t>>();
case PY_TYPE::BOOL:
return _list.cast<std::vector<bool>>();
case PY_TYPE::PARTIAL_SHAPE:
return _list.cast<std::vector<ov::PartialShape>>();
default:
OPENVINO_ASSERT(false, "Unsupported attribute type.");
}
@@ -337,6 +347,8 @@ ov::Any py_object_to_any(const py::object& py_obj) {
return py::cast<ov::Any>(py_obj);
} else if (py::isinstance<ov::element::Type>(py_obj)) {
return py::cast<ov::element::Type>(py_obj);
} else if (py::isinstance<ov::PartialShape>(py_obj)) {
return py::cast<ov::PartialShape>(py_obj);
} else if (py::isinstance<ov::hint::Priority>(py_obj)) {
return py::cast<ov::hint::Priority>(py_obj);
} else if (py::isinstance<ov::hint::PerformanceMode>(py_obj)) {
@@ -351,9 +363,14 @@ ov::Any py_object_to_any(const py::object& py_obj) {
return py::cast<ov::streams::Num>(py_obj);
} else if (py::isinstance<ov::Affinity>(py_obj)) {
return py::cast<ov::Affinity>(py_obj);
} else if (py::isinstance<ov::Tensor>(py_obj)) {
return py::cast<ov::Tensor>(py_obj);
// FrontEnd Decoder
} else if (py::isinstance<ov::frontend::IDecoder>(py_obj)) {
return py::cast<std::shared_ptr<ov::frontend::IDecoder>>(py_obj);
// TF FrontEnd GraphIterator
} else if (py::isinstance<ov::frontend::tensorflow::GraphIterator>(py_obj)) {
return py::cast<std::shared_ptr<ov::frontend::tensorflow::GraphIterator>>(py_obj);
// Custom FrontEnd Types
} else if (py::isinstance<ov::frontend::type::Tensor>(py_obj)) {
return py::cast<ov::frontend::type::Tensor>(py_obj);

View File

@@ -123,6 +123,18 @@ def test_load_by_model():
assert stat.supported == 1
@mock_needed
def test_load_by_model_path():
    # load_by_model() must accept pathlib.Path in addition to str.
    clear_all_stat()
    import pathlib
    fe = fem.load_by_model(pathlib.Path("abc.test_mock_py_mdl"))
    assert fe is not None
    assert fe.get_name() == MOCK_PY_FRONTEND_NAME
    stat = get_fe_stat()
    # the mock frontend must have been queried exactly once
    assert stat.get_name == 1
    assert stat.supported == 1
@mock_needed
def test_convert_model():
clear_all_stat()