From 38a1783527172d4ea11d6f87293b47c51a1192ae Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Thu, 1 Dec 2022 14:37:25 +0400 Subject: [PATCH] [FE] Support freezing Placeholder without specifying type (#13984) * [FE] Fix freezing placeholder via freeze_placeholder_with_value option Currently, freezing placeholder via freeze_placeholder_with_value option does not work for any frontends. It happens due to absence of a node in _input_shapes dictionary. * Add get_element_type method for InputModel * Revert not needed changes * Revert not needed changes * Update freeze_placeholder_test * Add tests for TF FE with freezing values of different types * Fix Python API return value * Correct returned type for get_numpy_ctype * Apply code-review feedback * Update src/frontends/tensorflow/src/input_model.hpp Co-authored-by: Maxim Vafin * Apply code-review feedback: no tf legacy specific routine and parameter names Co-authored-by: Maxim Vafin --- .../src/openvino/runtime/utils/types.py | 13 + .../src/pyopenvino/frontend/input_model.cpp | 12 + .../include/openvino/frontend/input_model.hpp | 5 + src/frontends/common/src/input_model.cpp | 5 + src/frontends/tensorflow/src/frontend.cpp | 7 + src/frontends/tensorflow/src/input_model.cpp | 9 + src/frontends/tensorflow/src/input_model.hpp | 1 + .../tools/mo/moc_frontend/extractor.py | 95 ++++-- .../tools/mo/moc_frontend/pipeline.py | 79 +++-- .../mo/openvino/tools/mo/utils/cli_parser.py | 3 +- .../mo/openvino/tools/mo/utils/type_utils.py | 21 ++ ...der_test.py => freeze_placeholder_test.py} | 150 +++++---- .../mo/utils/freeze_placeholder_test_tf_fe.py | 304 ++++++++++++++++++ 13 files changed, 593 insertions(+), 111 deletions(-) rename tools/mo/unit_tests/mo/utils/{freeze_placholder_test.py => freeze_placeholder_test.py} (62%) create mode 100644 tools/mo/unit_tests/mo/utils/freeze_placeholder_test_tf_fe.py diff --git a/src/bindings/python/src/openvino/runtime/utils/types.py b/src/bindings/python/src/openvino/runtime/utils/types.py index 71282c6ceba..794ad1220f9 100644 --- a/src/bindings/python/src/openvino/runtime/utils/types.py +++ b/src/bindings/python/src/openvino/runtime/utils/types.py @@ -106,6 +106,19 @@ def get_dtype(openvino_type: Type) -> np.dtype: raise OVTypeError("Unidentified data type %s", openvino_type) +def get_numpy_ctype(openvino_type: Type) -> type: + """Return numpy ctype for an openvino element type.""" + np_type = next( + (np_type for (ov_type, np_type) in openvino_to_numpy_types_map if ov_type == openvino_type), + None, + ) + + if np_type: + return np_type + + raise OVTypeError("Unidentified data type %s", openvino_type) + + def get_ndarray(data: NumericData) -> np.ndarray: """Wrap data into a numpy ndarray.""" if type(data) == np.ndarray: diff --git a/src/bindings/python/src/pyopenvino/frontend/input_model.cpp b/src/bindings/python/src/pyopenvino/frontend/input_model.cpp index a144b9d178e..0cfbc4a5efa 100644 --- a/src/bindings/python/src/pyopenvino/frontend/input_model.cpp +++ b/src/bindings/python/src/pyopenvino/frontend/input_model.cpp @@ -294,6 +294,18 @@ void regclass_frontend_InputModel(py::module m) { :type type: openvino.runtime.Type )"); + im.def("get_element_type", + &ov::frontend::InputModel::get_element_type, + py::arg("place"), + R"( + Returns current element type used for this place. + + :param place: Model place. + :type place: openvino.frontend.Place + :return: Element type for this place. + :rtype: openvino.runtime.Type + )"); + im.def( "set_tensor_value", [](ov::frontend::InputModel& self, const ov::frontend::Place::Ptr& place, py::array& value) { diff --git a/src/frontends/common/include/openvino/frontend/input_model.hpp b/src/frontends/common/include/openvino/frontend/input_model.hpp index adba7b76cb9..cf77eaf6747 100644 --- a/src/frontends/common/include/openvino/frontend/input_model.hpp +++ b/src/frontends/common/include/openvino/frontend/input_model.hpp @@ -195,6 +195,11 @@ public: /// \param type New element type virtual void set_element_type(const Place::Ptr& place, const ov::element::Type& type); + /// \brief Returns current element type used for this place + /// \param place Model place + /// \return Element type for this place + virtual ov::element::Type get_element_type(const Place::Ptr& place) const; + /// \brief Freezes a tensor with statically defined value or replace existing value for /// already constant node or tensor /// \param place Tensor place diff --git a/src/frontends/common/src/input_model.cpp b/src/frontends/common/src/input_model.cpp index 7a71ceaf6c4..e4c152078a9 100644 --- a/src/frontends/common/src/input_model.cpp +++ b/src/frontends/common/src/input_model.cpp @@ -140,6 +140,11 @@ void InputModel::set_element_type(const Place::Ptr& place, const element::Type& FRONTEND_CALL_STATEMENT("set_element_type", m_actual->set_element_type(place, type)) } +element::Type InputModel::get_element_type(const Place::Ptr& place) const { + FRONT_END_CHECK_IMPLEMENTED(m_actual, get_element_type); + FRONTEND_RETURN_STATEMENT("get_element_type", m_actual->get_element_type(place)) +} + void InputModel::set_tensor_value(const Place::Ptr& place, const void* value) { FRONT_END_CHECK_IMPLEMENTED(m_actual, set_tensor_value); FRONTEND_CALL_STATEMENT("set_tensor_value", m_actual->set_tensor_value(place, value)) diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index 031fcfce76c..02f1e6b40b2 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -20,6 +20,7 @@ #include "tf_framework_node.hpp" #include "utils.hpp" +using namespace ov; using namespace ov::frontend::tensorflow; namespace { @@ -95,6 +96,12 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model, auto input_shape = input_tensor_place->get_partial_shape(); auto input_type = input_tensor_place->get_element_type(); + // in case of cutting graph, types of custom inputs can be undefined, + // according to MO help, fp32 is used by default in such cases + if (input_type == element::undefined) { + input_type = element::f32; + } + auto param = std::make_shared(input_type, input_shape); set_node_name(input_name, param); params.push_back(param); diff --git a/src/frontends/tensorflow/src/input_model.cpp b/src/frontends/tensorflow/src/input_model.cpp index c3dd5d7bb1e..d6f6ccebbb1 100644 --- a/src/frontends/tensorflow/src/input_model.cpp +++ b/src/frontends/tensorflow/src/input_model.cpp @@ -66,6 +66,7 @@ public: void setPartialShape(ov::frontend::Place::Ptr place, const ov::PartialShape&); ov::PartialShape getPartialShape(ov::frontend::Place::Ptr place) const; void setElementType(ov::frontend::Place::Ptr place, const ov::element::Type&); + ov::element::Type getElementType(ov::frontend::Place::Ptr place) const; void setTensorValue(ov::frontend::Place::Ptr place, const void* value); std::vector> get_op_places() const; @@ -373,6 +374,10 @@ void InputModel::InputModelTFImpl::setElementType(ov::frontend::Place::Ptr place castToTensorPlace(place)->set_element_type(type); } +ov::element::Type InputModel::InputModelTFImpl::getElementType(ov::frontend::Place::Ptr place) const { + return castToTensorPlace(place)->get_element_type(); +} + void InputModel::InputModelTFImpl::setTensorValue(ov::frontend::Place::Ptr place, const void* value) { m_graph_changed = true; auto tensor_place = castToTensorPlace(place); @@ -436,6 +441,10 @@ void InputModel::set_element_type(const ov::frontend::Place::Ptr& place, const o _impl->setElementType(place, type); } +ov::element::Type InputModel::get_element_type(const ov::frontend::Place::Ptr& place) const { + return _impl->getElementType(place); +} + void InputModel::set_tensor_value(const ov::frontend::Place::Ptr& place, const void* value) { _impl->setTensorValue(place, value); } diff --git a/src/frontends/tensorflow/src/input_model.hpp b/src/frontends/tensorflow/src/input_model.hpp index e8f96365c30..d561d5b1adf 100644 --- a/src/frontends/tensorflow/src/input_model.hpp +++ b/src/frontends/tensorflow/src/input_model.hpp @@ -39,6 +39,7 @@ public: void set_partial_shape(const ov::frontend::Place::Ptr& place, const ov::PartialShape&) override; ov::PartialShape get_partial_shape(const ov::frontend::Place::Ptr& place) const override; void set_element_type(const ov::frontend::Place::Ptr& place, const ov::element::Type&) override; + ov::element::Type get_element_type(const ov::frontend::Place::Ptr& place) const override; void set_tensor_value(const ov::frontend::Place::Ptr& place, const void* value) override; }; diff --git a/tools/mo/openvino/tools/mo/moc_frontend/extractor.py b/tools/mo/openvino/tools/mo/moc_frontend/extractor.py index 641bbe530e9..681face03df 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/extractor.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/extractor.py @@ -2,16 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import re - -from openvino.tools.mo.front.extractor import raise_no_node, raise_node_name_collision -from openvino.tools.mo.utils.error import Error -from openvino._pyopenvino import Place, Type, PartialShape - -from openvino.frontend import InputModel # pylint: disable=no-name-in-module,import-error +from enum import Enum import numpy as np +from openvino._pyopenvino import Place, PartialShape -from enum import Enum +from openvino.frontend import InputModel # pylint: disable=no-name-in-module,import-error +from openvino.tools.mo.front.extractor import raise_no_node, raise_node_name_collision +from openvino.tools.mo.utils.error import Error class IOType(Enum): @@ -20,7 +18,7 @@ class IOType(Enum): def decode_name_with_port( - input_model: InputModel, node_name: str, framework="", io_type=IOType.Input + input_model: InputModel, node_name: str, framework="", io_type=IOType.Input ) -> Place or None: """ Decode name with optional port specification w/o traversing all the nodes in the graph @@ -144,11 +142,11 @@ def decode_name_with_port( def fe_input_user_data_repack( - input_model: InputModel, - input_user_shapes: [None, list, dict, np.ndarray], - freeze_placeholder: dict, - framework: str, - input_user_data_types=None, + input_model: InputModel, + input_user_shapes: [None, list, dict, np.ndarray], + freeze_placeholder: dict, + framework: str, + input_user_data_types=None, ): """ Restructures user input cutting request. Splits ports out of node names. @@ -188,7 +186,9 @@ def fe_input_user_data_repack( } """ _input_shapes = [] - if isinstance(input_user_shapes, list) and len(input_user_shapes) > 1 and isinstance(input_user_shapes[0], PartialShape): + _input_names = [] + if isinstance(input_user_shapes, list) and len(input_user_shapes) > 1 and isinstance(input_user_shapes[0], + PartialShape): for shape in input_user_shapes: assert isinstance(shape, PartialShape), "Got incorrect format of input shapes." model_inputs = input_model.get_inputs() @@ -227,14 +227,63 @@ def fe_input_user_data_repack( "input_name": input_name } ) + _input_names.append(input_name) elif isinstance(input_user_shapes, PartialShape): + # this branch covers the single use of `input_shape` without `input` option + # but it can be used along with `freeze_placeholder_with_value` option + # for example, --input_shape [3] --freeze_placeholder_with_value "is_training->False" + # means the model has two inputs: one is is_training to be frozen, the other to re-write the shape + # NOTE: the logic relies on parameters with the single name model_inputs = input_model.get_inputs() - assert len(model_inputs) == 1 - _input_shapes.append({"node": model_inputs[0], "shape": input_user_shapes}) + frozen_names = freeze_placeholder.keys() + assert len(model_inputs) == len(frozen_names) + 1, "Please check the conversion command-line. " \ + "Total number of model inputs must match to a number " \ + "of input shapes along with frozen inputs." + for node in model_inputs: + assert len(node.get_names()) > 0, "Original input models must have names." + input_name = node.get_names()[0] + if input_name not in frozen_names: + _input_shapes.append( + { + "node": node, + "shape": input_user_shapes, + "input_name": input_name + } + ) + _input_names.append(input_name) + break else: + # this case means that we use original inputs of the model + # and they should not be changed and their properties (shape and type) should not be over-written + # NOTE: the logic relies on parameters with the single name assert input_user_shapes is None - + for node in input_model.get_inputs(): + assert len(node.get_names()) > 0, "Original input models must have names." + input_name = node.get_names()[0] + _input_shapes.append( + { + "node": node, + "input_name": input_name + } + ) + # mark-up Place names we already put into the _input_names + # to avoid duplicates in updates by freeze_placeholder below + _input_names.append(input_name) + if freeze_placeholder: + # in case freezing via freeze_placeholder_with_value option, _input_shapes can miss some frozen places + for input_name in freeze_placeholder: + if input_name in _input_names: + continue + node = decode_name_with_port( + input_model, input_name, framework, IOType.Input + ) + _input_shapes.append( + { + "node": node, + "input_name": input_name + } + ) return _input_shapes, freeze_placeholder return _input_shapes, dict() @@ -274,12 +323,12 @@ def fe_output_user_data_repack(input_model: InputModel, outputs: list, framework def fe_user_data_repack( - input_model: InputModel, - input_user_shapes: [None, list, dict, np.array], - input_user_data_types: dict, - outputs: list, - freeze_placeholder: dict, - framework: str, + input_model: InputModel, + input_user_shapes: [None, list, dict, np.array], + input_user_data_types: dict, + outputs: list, + freeze_placeholder: dict, + framework: str, ): """ :param input_model: Input Model to operate on diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py index 796f48d3700..99f63bfd061 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py @@ -4,22 +4,22 @@ import argparse import io import logging as log -from typing import List import sys -from os import environ - -from openvino.tools.mo.moc_frontend.analysis import json_model_analysis_dump -from openvino.tools.mo.moc_frontend.extractor import fe_user_data_repack -from openvino.tools.mo.middle.passes.infer import validate_batch_in_shape -from openvino.tools.mo.utils.class_registration import get_enabled_and_disabled_transforms -from openvino.tools.mo.utils.error import Error - -from openvino.runtime import Dimension, PartialShape, Type # pylint: disable=no-name-in-module,import-error -from openvino.frontend import FrontEnd, InputModel, NotImplementedFailure, Place # pylint: disable=no-name-in-module,import-error -from openvino.runtime.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error +from typing import List import numpy as np +from openvino.frontend import FrontEnd, InputModel, NotImplementedFailure, \ + Place # pylint: disable=no-name-in-module,import-error +from openvino.runtime import Dimension, PartialShape, Type # pylint: disable=no-name-in-module,import-error +from openvino.runtime.utils.types import get_element_type, \ + get_numpy_ctype # pylint: disable=no-name-in-module,import-error +from openvino.tools.mo.middle.passes.infer import validate_batch_in_shape +from openvino.tools.mo.moc_frontend.analysis import json_model_analysis_dump +from openvino.tools.mo.moc_frontend.extractor import fe_user_data_repack +from openvino.tools.mo.utils.class_registration import get_enabled_and_disabled_transforms +from openvino.tools.mo.utils.error import Error + def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): """ @@ -136,17 +136,50 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): if freeze_placeholder: for name, value in freeze_placeholder.items(): - for node in user_shapes: - if node.get('input_name') == name: - place = node['node'] - if node.get('shape'): - input_model.set_partial_shape(place, node['shape']) - if node.get('data_type'): - value = np.array(value, dtype=node['data_type']) - input_model.set_element_type(place, Type(node['data_type'])) - else: - value = np.array(value, dtype=np.float32) - input_model.set_tensor_value(place, value) + node = None + # look for the certain place in user_shapes + for node_cur in user_shapes: + if node_cur.get('input_name') == name: + node = node_cur + break + if node is None: + raise Error("Please check correctness of the command-line. " + "Place (operation or tensor) with name {} is not found.".format(name)) + place = node.get('node') + + if node.get('shape'): + input_model.set_partial_shape(place, node['shape']) + + if node.get('data_type'): + dtype = node['data_type'] + ov_type = Type(dtype) + else: + # we need to detect type of Placeholder + try: + ov_type = input_model.get_element_type(place) + except NotImplementedFailure: + raise Error("Please specify type for value freezing {} node explicitly " + "because the frontend does not support automatic type detection.".format(name)) + # in case of cutting graph (or using custom inputs) and unspecified type, + # the default type is fp32 + if ov_type == Type.undefined: + ov_type = Type.f32 + dtype = get_numpy_ctype(ov_type) + + input_model.set_element_type(place, ov_type) + # prepare and cast value to dtype + from openvino.tools.mo.utils.type_utils import np_map_cast + from openvino.tools.mo.front.common.partial_infer.utils import mo_array + if isinstance(value, list): + casted_list = list() + for v in mo_array(value): + casted_list.append(np_map_cast[dtype](v)) + value = mo_array(casted_list, dtype=dtype) + else: + value = np_map_cast[dtype](value) + + value = np.array(value, dtype=dtype) + input_model.set_tensor_value(place, value) def shape_to_array(shape: PartialShape): return [shape.get_dimension(i) for i in range(shape.rank.get_length())] diff --git a/tools/mo/openvino/tools/mo/utils/cli_parser.py b/tools/mo/openvino/tools/mo/utils/cli_parser.py index c7f9a10eaa3..f1fe8e97230 100644 --- a/tools/mo/openvino/tools/mo/utils/cli_parser.py +++ b/tools/mo/openvino/tools/mo/utils/cli_parser.py @@ -1457,7 +1457,8 @@ def get_shape_from_input_value(input_value: str): if len(shape) == 0: shape = None elif len(shape) == 1 and shape[0] in ['', ' ']: - shape = () + # this shape corresponds to scalar + shape = PartialShape([]) elif len(shape) == 1: dims = re.split(r', *| +', shape[0]) dims = list(filter(None, dims)) diff --git a/tools/mo/openvino/tools/mo/utils/type_utils.py b/tools/mo/openvino/tools/mo/utils/type_utils.py index f6fe3dc356e..ade914c0075 100644 --- a/tools/mo/openvino/tools/mo/utils/type_utils.py +++ b/tools/mo/openvino/tools/mo/utils/type_utils.py @@ -9,6 +9,27 @@ from openvino.tools.mo.graph.graph import Node from openvino.tools.mo.pipeline.common import convert_const_node_value_type from openvino.tools.mo.utils.error import Error +np_map_cast = {np.bool: lambda x: bool_cast(x), + np.int8: lambda x: np.int8(x), + np.int16: lambda x: np.int16(x), + np.int32: lambda x: np.int32(x), + np.int64: lambda x: np.int64(x), + np.uint8: lambda x: np.uint8(x), + np.uint16: lambda x: np.uint16(x), + np.uint32: lambda x: np.uint32(x), + np.uint64: lambda x: np.uint64(x), + np.float16: lambda x: np.float16(x), + np.float32: lambda x: np.float32(x), + np.double: lambda x: np.double(x), + np.str: lambda x: np.str(x)} + + +def bool_cast(x): + if isinstance(x, str): + return False if x.lower() in ['false', '0'] else True if x.lower() in ['true', '1'] else 'unknown_boolean_cast' + else: + return np.bool(x) + def override_data_type_of_constant(node: Node, lhs_idx: int = 0, rhs_idx: int = 1): in_type_0 = node.in_port(lhs_idx).get_data_type() diff --git a/tools/mo/unit_tests/mo/utils/freeze_placholder_test.py b/tools/mo/unit_tests/mo/utils/freeze_placeholder_test.py similarity index 62% rename from tools/mo/unit_tests/mo/utils/freeze_placholder_test.py rename to tools/mo/unit_tests/mo/utils/freeze_placeholder_test.py index 726777f1eda..eafa0bdc9e7 100644 --- a/tools/mo/unit_tests/mo/utils/freeze_placholder_test.py +++ b/tools/mo/unit_tests/mo/utils/freeze_placeholder_test.py @@ -1,24 +1,23 @@ # Copyright (C) 2018-2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import argparse +import os import unittest from unittest.mock import patch, Mock -import pytest -from openvino.runtime import Core -from openvino.tools.mo.convert_impl import prepare_ir +import numpy as np +import onnx +from generator import generator, generate +from onnx.helper import make_graph, make_model, make_tensor_value_info + from openvino.frontend import ( FrontEndManager, FrontEnd, ) # pylint: disable=no-name-in-module,import-error +from openvino.runtime import Core +from openvino.tools.mo.convert_impl import prepare_ir from openvino.tools.mo.utils.error import Error -from onnx.helper import make_graph, make_model, make_tensor_value_info -import argparse -import os -from os import environ -import onnx -import numpy as np -from generator import generator, generate def base_args_config(use_legacy_fe: bool = None, use_new_fe: bool = None): @@ -116,36 +115,37 @@ class TestMoFreezePlaceholder(unittest.TestCase): @generate( *[ ( - "in1[1 4]{f32}->[1.0 2.0 3.0 4.0],in2[1 4]{f32}->[1.0 2.0 3.0 4.0]", - True, - {}, - np.array([2.0, 4.0, 6.0, 8.0]), - np.float32, + "in1[1 4]{f32}->[1.0 2.0 3.0 4.0],in2[1 4]{f32}->[1.0 2.0 3.0 4.0]", + True, + {}, + np.array([2.0, 4.0, 6.0, 8.0]), + np.float32, ), ( - "in2->[0.0 0.0 0.0 0.0]", - True, - {"in1": np.array([[1.0, 2.0], [3.0, 4.0]])}, - np.array([[1.0, 2.0], [3.0, 4.0]]), - np.float32, + "in2{f32}->[0.0 0.0 0.0 0.0]", + True, + {"in1": np.array([[1.0, 2.0], [3.0, 4.0]])}, + np.array([[1.0, 2.0], [3.0, 4.0]]), + np.float32, ), ( - "in2->[1.0 15.0 15.5 1.0]", - True, - {"in1": np.array([[2.0, 4.0], [12.0, 8.0]])}, - np.array([[3.0, 19.0], [27.5, 9.0]]), - np.float32, + "in2{f32}->[1.0 15.0 15.5 1.0]", + True, + {"in1": np.array([[2.0, 4.0], [12.0, 8.0]])}, + np.array([[3.0, 19.0], [27.5, 9.0]]), + np.float32, ), ( - "in1[1 4]{i32}->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]", - True, - {}, - np.array([2.0, 4.0, 6.0, 8.0]), - np.int32, + "in1[1 4]{i32}->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]", + True, + {}, + np.array([2.0, 4.0, 6.0, 8.0]), + np.int32, ), ], ) - def test_freeze_placeholder_with_value_onnx_fe(self, input_freezing_value, use_new_fe, inputs, expected, dtype=None): + def test_freeze_placeholder_with_value_onnx_fe(self, input_freezing_value, use_new_fe, inputs, expected, + dtype=None): with patch("openvino.tools.mo.convert_impl.get_default_frontends") as default_fe: default_fe.return_value = get_test_default_frontends() args = base_args_config(use_new_fe=use_new_fe) @@ -166,49 +166,49 @@ class TestMoFreezePlaceholder(unittest.TestCase): @generate( *[ ( - "in1->[1.0 15.0 1.0]", - True, - {"in2": np.array([2])}, - np.array([2.0, 30.0, 2.0]), - np.float32, + "in1{f32}->[1.0 15.0 1.0]", + True, + {"in2": np.array([2])}, + np.array([2.0, 30.0, 2.0]), + np.float32, ), ( - "in1->[7.0 11.0 -1.0],in2->3.0", - True, - {}, - np.array([21.0, 33.0, -3.0]), - np.float32, + "in1{f32}->[7.0 11.0 -1.0],in2{f32}->3.0", + True, + {}, + np.array([21.0, 33.0, -3.0]), + np.float32, ), ( - None, - True, - { - "in1": np.array([2.0, 2.0, 2.0]).reshape(1, 1, 3), - "in2": np.array([-1.0]), - }, - np.array([-2.0, -2.0, -2.0]), - np.float32, + None, + True, + { + "in1": np.array([2.0, 2.0, 2.0]).reshape(1, 1, 3), + "in2": np.array([-1.0]), + }, + np.array([-2.0, -2.0, -2.0]), + np.float32, ), ( - "in1[3 1]{f32}->[7.0 11.0 -1.0],in2{f32}->3.0", - True, - {}, - np.array([21.0, 33.0, -3.0]).reshape(3, 1), - np.float32, + "in1[3 1]{f32}->[7.0 11.0 -1.0],in2{f32}->3.0", + True, + {}, + np.array([21.0, 33.0, -3.0]).reshape(3, 1), + np.float32, ), ( - "in1[3 1]{f16}->[7.0 11.0 -1.0],in2{f16}->3.0", - True, - {}, - np.array([21.0, 33.0, -3.0]).reshape(3, 1), - np.float16, + "in1[3 1]{f16}->[7.0 11.0 -1.0],in2{f16}->3.0", + True, + {}, + np.array([21.0, 33.0, -3.0]).reshape(3, 1), + np.float16, ), ( - "in1[3 1]{i32}->[7 11 -1],in2{i32}->3.0", - True, - {}, - np.array([21, 33, -3]).reshape(3, 1), - np.int32, + "in1[3 1]{i32}->[7 11 -1],in2{i32}->3.0", + True, + {}, + np.array([21, 33, -3]).reshape(3, 1), + np.int32, ), ], ) @@ -229,3 +229,25 @@ class TestMoFreezePlaceholder(unittest.TestCase): if dtype is not None: assert values.dtype == dtype assert np.allclose(values, expected) + + @generate( + *[ + ( + "in1->[1.0 15.0 1.0]", + True, + {"in2": np.array([2])}, + np.array([2.0, 30.0, 2.0]), + np.float32, + ), + ], + ) + def test_value_without_type(self, input_freezing_value, use_new_fe, inputs, expected, + dtype=None): + with patch("openvino.tools.mo.convert_impl.get_default_frontends") as default_fe: + default_fe.return_value = get_test_default_frontends() + args = base_args_config(use_new_fe=use_new_fe) + args.input_model = "test_model_2.onnx" + args.input = input_freezing_value + self.assertRaisesRegex(Error, "Please specify type for value freezing in1 node explicitly " + "because the frontend does not support automatic type detection.", + prepare_ir, args) diff --git a/tools/mo/unit_tests/mo/utils/freeze_placeholder_test_tf_fe.py b/tools/mo/unit_tests/mo/utils/freeze_placeholder_test_tf_fe.py new file mode 100644 index 00000000000..33d07911d44 --- /dev/null +++ b/tools/mo/unit_tests/mo/utils/freeze_placeholder_test_tf_fe.py @@ -0,0 +1,304 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import unittest +from unittest.mock import Mock + +import numpy as np +from generator import generator, generate + +from openvino.frontend import ( + FrontEndManager, + FrontEnd, +) # pylint: disable=no-name-in-module,import-error +from openvino.runtime import Core +from openvino.tools.mo.convert_impl import prepare_ir + + +def base_args_config(): + args = argparse.Namespace() + args.feManager = FrontEndManager() + args.extensions = None + # use new TF FE + args.use_legacy_frontend = False + args.use_new_frontend = True + args.framework = "tf" + args.model_name = None + args.input_model = None + args.input_model_is_text = False + args.input_checkpoint = None + args.saved_model_dir = None + args.input_meta_graph = None + args.saved_model_tags = None + args.silent = True + args.transform = [] + args.scale = None + args.output = None + args.input = None + args.input_shape = None + args.batch = None + args.mean_values = None + args.scale_values = None + args.output_dir = os.getcwd() + args.freeze_placeholder_with_value = None + args.tensorflow_use_custom_operations_config = None + args.transformations_config = None + args.disable_fusing = None + args.finegrain_fusing = None + args.disable_resnet_optimization = None + args.enable_concat_optimization = None + args.static_shape = None + args.disable_weights_compression = None + args.reverse_input_channels = None + args.data_type = None + args.layout = None + args.source_layout = None + args.target_layout = None + return args + + +try: + import openvino_telemetry as tm +except ImportError: + import openvino.tools.mo.utils.telemetry_stub as tm + + +@generator +class TestMoFreezePlaceholderTFFE(unittest.TestCase): + def setUp(self): + try: + import tensorflow.compat.v1 as tf + except ImportError: + import tensorflow as tf + + tm.Telemetry.__init__ = Mock(return_value=None) + tm.Telemetry.send_event = Mock() + FrontEnd.add_extension = Mock() + + self.models = [] + tf.reset_default_graph() + with tf.Session() as sess: + x = tf.placeholder(tf.float32, [2, 2], 'in1') + y = tf.placeholder(tf.float32, [2, 2], 'in2') + tf.add(x, y, name="add") + + tf.global_variables_initializer() + tf.io.write_graph(sess.graph, '.', 'model_fp32.pb', as_text=False) + + self.models.append("model_fp32.pb") + + tf.reset_default_graph() + with tf.Session() as sess: + x = tf.placeholder(tf.int32, [2, 3], 'in1') + y = tf.placeholder(tf.int32, [2, 3], 'in2') + tf.multiply(x, y, name="add") + + tf.global_variables_initializer() + tf.io.write_graph(sess.graph, '.', 'model_int32.pb', as_text=False) + + self.models.append("model_int32.pb") + + tf.reset_default_graph() + with tf.Session() as sess: + x = tf.placeholder(tf.bool, [2, 3], 'in1') + y = tf.placeholder(tf.bool, [2, 3], 'in2') + tf.math.logical_and(x, y) + + tf.global_variables_initializer() + tf.io.write_graph(sess.graph, '.', 'model_bool.pb', as_text=False) + + self.models.append("model_bool.pb") + + tf.reset_default_graph() + with tf.Session() as sess: + x = tf.placeholder(tf.float32, [3], 'in1') + y = tf.placeholder(tf.float32, [3], 'in2') + cond = tf.placeholder(tf.bool, [], 'cond') + tf.where(cond, x, y) + + tf.global_variables_initializer() + tf.io.write_graph(sess.graph, '.', 'model_bool2.pb', as_text=False) + + self.models.append("model_bool2.pb") + + tf.reset_default_graph() + with tf.Session() as sess: + x = tf.placeholder(tf.float32, [3], 'x') + y = tf.placeholder(tf.float32, [3], 'y') + z = tf.placeholder(tf.float32, [3], 'z') + add = tf.add(x, y, name="add") + tf.multiply(add, z, name="multiply") + + tf.global_variables_initializer() + tf.io.write_graph(sess.graph, '.', 'model_three_inputs.pb', as_text=False) + + self.models.append("model_three_inputs.pb") + + def tearDown(self): + for name in self.models: + os.remove(name) + + def basic(self, input_model, argv_input, inputs, dtype, expected, freeze_placeholder_with_value=None, + input_shape=None, only_conversion=False): + args = base_args_config() + args.input_model = input_model + args.input = argv_input + args.freeze_placeholder_with_value = freeze_placeholder_with_value + args.input_shape = input_shape + + try: + _, model = prepare_ir(args) + except Exception as ex: + self.fail("Model conversion failed due to error: {}".format(ex)) + + if only_conversion: + return + + ie = Core() + exec_net = ie.compile_model(model, "CPU") + req = exec_net.create_infer_request() + results = req.infer(inputs) + values = list(results.values())[0] + if dtype is not None: + assert values.dtype == dtype + assert np.allclose(values, expected) + + @generate( + *[ + ( + "in1[1 4]->[1.0 2.0 3.0 4.0],in2[1 4]{f32}->[1.0 2.0 3.0 4.0]", + {}, + np.array([2.0, 4.0, 6.0, 8.0]), + np.float32, + ), + ( + "in2{f32}->[0.0 0.0 0.0 0.0]", + {"in1": np.array([[1.0, 2.0], [3.0, 4.0]])}, + np.array([[1.0, 2.0], [3.0, 4.0]]), + np.float32, + ), + ( + "in2->[1.0 15.0 15.5 1.0]", + {"in1": np.array([[2.0, 4.0], [12.0, 8.0]])}, + np.array([[3.0, 19.0], [27.5, 9.0]]), + np.float32, + ), + ( + "in1[1 4]{i32}->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]", + {}, + np.array([2.0, 4.0, 6.0, 8.0]), + np.int32, + ), + ], + ) + def test_fp32(self, input_freezing_value, inputs, expected, + dtype): + self.basic("model_fp32.pb", input_freezing_value, inputs, dtype, expected) + + @generate( + *[ + ( + "in1[1 4]->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]", + {}, + np.array([1, 4, 9, 16]), + np.int32, + ), + ( + "in2->[2 5 6 7 3 2]", + {"in1": np.array([[2, 4, 1], [1, 2, 8]])}, + np.array([[4, 20, 6], [7, 6, 16]]), + np.int32, + ), + ], + ) + def test_int32(self, input_freezing_value, inputs, expected, + dtype=None): + self.basic("model_int32.pb", input_freezing_value, inputs, dtype, expected) + + @generate( + *[ + ( + "in1[2]->[True False],in2[2]->[True True]", + {}, + np.array([True, False], dtype=bool), + bool, + ), + ( + "in2[2,3]->[True,True,False,True,True,False]", + {"in1": np.array([[False, True, True], [False, True, True]], dtype=bool)}, + np.array([[False, True, False], [False, True, False]], dtype=bool), + bool, + ), + ( + "in2[]->True", + {"in1": np.array([[False, True, True], [False, True, True]], dtype=bool)}, + np.array([[False, True, True], [False, True, True]], dtype=bool), + bool, + ), + ], + ) + def test_bool(self, input_freezing_value, inputs, expected, + dtype=None): + self.basic("model_bool.pb", input_freezing_value, inputs, dtype, expected) + + @generate( + *[ + ( + "in1[3]->[1 2 3],in2[3]->[4 5 6],cond->False", + {}, + np.array([4, 5, 6], dtype=np.float32), + np.float32, + None + ), + ( + None, + {"in1": np.array([2.0, 4.0, 6.0], dtype=np.float32), + "in2": np.array([1.0, 3.0, 5.0], dtype=np.float32)}, + np.array([2, 4, 6], dtype=np.float32), + np.float32, + "cond->False", + None, + True # fill a bug to investigate why compilation of this model is hang on + ), + # case: input_shape + freeze_placeholder_with_value + ( + None, + {"in2": np.array([1.0, 3.0, 5.0], dtype=np.float32)}, + np.array([2, 4, 6], dtype=np.float32), + np.float32, + "in1->[2.0 4.0 6.0],cond->True", + "[3]", + False + ), + ], + ) + def test_bool2(self, input_freezing_value, inputs, expected, + dtype=None, freeze_placeholder_with_value=None, input_shape=None, only_conversion=False): + self.basic("model_bool2.pb", input_freezing_value, inputs, dtype, expected, freeze_placeholder_with_value, + input_shape, only_conversion) + + @generate( + *[ + ( + "add:0[3],z", + {"add:0": np.array([4, 5, 6], dtype=np.float32), "z": np.array([1, 2, 3], dtype=np.float32)}, + np.array([4, 10, 18], dtype=np.float32), + np.float32, + None + ), + ( + "add:0{i32}[3],z{i32}", + {"add:0": np.array([4, 5, 6], dtype=np.int32), "z": np.array([1, 2, 3], dtype=np.int32)}, + np.array([4, 10, 18], dtype=np.int32), + np.int32, + None + ), + ], + ) + def test_cutting_fp32(self, input_freezing_value, inputs, expected, + dtype=None, freeze_placeholder_with_value=None, input_shape=None, only_conversion=False): + self.basic("model_three_inputs.pb", input_freezing_value, inputs, dtype, expected, + freeze_placeholder_with_value, + input_shape, only_conversion)