[FE] Support freezing Placeholder without specifying type (#13984)

* [FE] Fix freezing placeholder via freeze_placeholder_with_value option

Currently, freezing placeholder via freeze_placeholder_with_value option does not work for any frontends. It happens due to absence of a node in _input_shapes dictionary.

* Add get_element_type method for InputModel

* Revert not needed changes

* Revert not needed changes

* Update freeze_placeholder_test

* Add tests for TF FE with freezing values of different types

* Fix Python API return value

* Correct returned type for get_numpy_ctype

* Apply code-review feedback

* Update src/frontends/tensorflow/src/input_model.hpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Apply code-review feedback: no tf legacy specific routine and parameter names

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Roman Kazantsev 2022-12-01 14:37:25 +04:00 committed by GitHub
parent 8ad74c17a4
commit 38a1783527
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 593 additions and 111 deletions

View File

@ -106,6 +106,19 @@ def get_dtype(openvino_type: Type) -> np.dtype:
raise OVTypeError("Unidentified data type %s", openvino_type)
def get_numpy_ctype(openvino_type: Type) -> type:
"""Return numpy ctype for an openvino element type."""
np_type = next(
(np_type for (ov_type, np_type) in openvino_to_numpy_types_map if ov_type == openvino_type),
None,
)
if np_type:
return np_type
raise OVTypeError("Unidentified data type %s", openvino_type)
def get_ndarray(data: NumericData) -> np.ndarray:
"""Wrap data into a numpy ndarray."""
if type(data) == np.ndarray:

View File

@ -294,6 +294,18 @@ void regclass_frontend_InputModel(py::module m) {
:type type: openvino.runtime.Type
)");
im.def("get_element_type",
&ov::frontend::InputModel::get_element_type,
py::arg("place"),
R"(
Returns current element type used for this place.
:param place: Model place.
:type place: openvino.frontend.Place
:return: Element type for this place.
:rtype: openvino.runtime.Type
)");
im.def(
"set_tensor_value",
[](ov::frontend::InputModel& self, const ov::frontend::Place::Ptr& place, py::array& value) {

View File

@ -195,6 +195,11 @@ public:
/// \param type New element type
virtual void set_element_type(const Place::Ptr& place, const ov::element::Type& type);
/// \brief Returns current element type used for this place
/// \param place Model place
/// \return Element type for this place
virtual ov::element::Type get_element_type(const Place::Ptr& place) const;
/// \brief Freezes a tensor with statically defined value or replace existing value for
/// already constant node or tensor
/// \param place Tensor place

View File

@ -140,6 +140,11 @@ void InputModel::set_element_type(const Place::Ptr& place, const element::Type&
FRONTEND_CALL_STATEMENT("set_element_type", m_actual->set_element_type(place, type))
}
element::Type InputModel::get_element_type(const Place::Ptr& place) const {
FRONT_END_CHECK_IMPLEMENTED(m_actual, get_element_type);
FRONTEND_RETURN_STATEMENT("get_element_type", m_actual->get_element_type(place))
}
void InputModel::set_tensor_value(const Place::Ptr& place, const void* value) {
FRONT_END_CHECK_IMPLEMENTED(m_actual, set_tensor_value);
FRONTEND_CALL_STATEMENT("set_tensor_value", m_actual->set_tensor_value(place, value))

View File

@ -20,6 +20,7 @@
#include "tf_framework_node.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::frontend::tensorflow;
namespace {
@ -95,6 +96,12 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
auto input_shape = input_tensor_place->get_partial_shape();
auto input_type = input_tensor_place->get_element_type();
// in case of cutting graph, types of custom inputs can be undefined,
// according to MO help, fp32 is used by default in such cases
if (input_type == element::undefined) {
input_type = element::f32;
}
auto param = std::make_shared<ov::opset8::Parameter>(input_type, input_shape);
set_node_name(input_name, param);
params.push_back(param);

View File

@ -66,6 +66,7 @@ public:
void setPartialShape(ov::frontend::Place::Ptr place, const ov::PartialShape&);
ov::PartialShape getPartialShape(ov::frontend::Place::Ptr place) const;
void setElementType(ov::frontend::Place::Ptr place, const ov::element::Type&);
ov::element::Type getElementType(ov::frontend::Place::Ptr place) const;
void setTensorValue(ov::frontend::Place::Ptr place, const void* value);
std::vector<std::shared_ptr<OpPlace>> get_op_places() const;
@ -373,6 +374,10 @@ void InputModel::InputModelTFImpl::setElementType(ov::frontend::Place::Ptr place
castToTensorPlace(place)->set_element_type(type);
}
ov::element::Type InputModel::InputModelTFImpl::getElementType(ov::frontend::Place::Ptr place) const {
return castToTensorPlace(place)->get_element_type();
}
void InputModel::InputModelTFImpl::setTensorValue(ov::frontend::Place::Ptr place, const void* value) {
m_graph_changed = true;
auto tensor_place = castToTensorPlace(place);
@ -436,6 +441,10 @@ void InputModel::set_element_type(const ov::frontend::Place::Ptr& place, const o
_impl->setElementType(place, type);
}
ov::element::Type InputModel::get_element_type(const ov::frontend::Place::Ptr& place) const {
return _impl->getElementType(place);
}
void InputModel::set_tensor_value(const ov::frontend::Place::Ptr& place, const void* value) {
_impl->setTensorValue(place, value);
}

View File

@ -39,6 +39,7 @@ public:
void set_partial_shape(const ov::frontend::Place::Ptr& place, const ov::PartialShape&) override;
ov::PartialShape get_partial_shape(const ov::frontend::Place::Ptr& place) const override;
void set_element_type(const ov::frontend::Place::Ptr& place, const ov::element::Type&) override;
ov::element::Type get_element_type(const ov::frontend::Place::Ptr& place) const override;
void set_tensor_value(const ov::frontend::Place::Ptr& place, const void* value) override;
};

View File

@ -2,16 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
import re
from openvino.tools.mo.front.extractor import raise_no_node, raise_node_name_collision
from openvino.tools.mo.utils.error import Error
from openvino._pyopenvino import Place, Type, PartialShape
from openvino.frontend import InputModel # pylint: disable=no-name-in-module,import-error
from enum import Enum
import numpy as np
from openvino._pyopenvino import Place, PartialShape
from enum import Enum
from openvino.frontend import InputModel # pylint: disable=no-name-in-module,import-error
from openvino.tools.mo.front.extractor import raise_no_node, raise_node_name_collision
from openvino.tools.mo.utils.error import Error
class IOType(Enum):
@ -20,7 +18,7 @@ class IOType(Enum):
def decode_name_with_port(
input_model: InputModel, node_name: str, framework="", io_type=IOType.Input
input_model: InputModel, node_name: str, framework="", io_type=IOType.Input
) -> Place or None:
"""
Decode name with optional port specification w/o traversing all the nodes in the graph
@ -144,11 +142,11 @@ def decode_name_with_port(
def fe_input_user_data_repack(
input_model: InputModel,
input_user_shapes: [None, list, dict, np.ndarray],
freeze_placeholder: dict,
framework: str,
input_user_data_types=None,
input_model: InputModel,
input_user_shapes: [None, list, dict, np.ndarray],
freeze_placeholder: dict,
framework: str,
input_user_data_types=None,
):
"""
Restructures user input cutting request. Splits ports out of node names.
@ -188,7 +186,9 @@ def fe_input_user_data_repack(
}
"""
_input_shapes = []
if isinstance(input_user_shapes, list) and len(input_user_shapes) > 1 and isinstance(input_user_shapes[0], PartialShape):
_input_names = []
if isinstance(input_user_shapes, list) and len(input_user_shapes) > 1 and isinstance(input_user_shapes[0],
PartialShape):
for shape in input_user_shapes:
assert isinstance(shape, PartialShape), "Got incorrect format of input shapes."
model_inputs = input_model.get_inputs()
@ -227,14 +227,63 @@ def fe_input_user_data_repack(
"input_name": input_name
}
)
_input_names.append(input_name)
elif isinstance(input_user_shapes, PartialShape):
# this branch covers the single use of `input_shape` without `input` option
# but it can be used along with `freeze_placeholder_with_value` option
# for example, --input_shape [3] --freeze_placeholder_with_value "is_training->False"
# means the model has two inputs: one is is_training to be frozen, the other to re-write the shape
# NOTE: the logic relies on parameters with the single name
model_inputs = input_model.get_inputs()
assert len(model_inputs) == 1
_input_shapes.append({"node": model_inputs[0], "shape": input_user_shapes})
frozen_names = freeze_placeholder.keys()
assert len(model_inputs) == len(frozen_names) + 1, "Please check the conversion command-line. " \
"Total number of model inputs must match to a number " \
"of input shapes along with frozen inputs."
for node in model_inputs:
assert len(node.get_names()) > 0, "Original input models must have names."
input_name = node.get_names()[0]
if input_name not in frozen_names:
_input_shapes.append(
{
"node": node,
"shape": input_user_shapes,
"input_name": input_name
}
)
_input_names.append(input_name)
break
else:
# this case means that we use original inputs of the model
# and they should not be changed and their properties (shape and type) should not be over-written
# NOTE: the logic relies on parameters with the single name
assert input_user_shapes is None
for node in input_model.get_inputs():
assert len(node.get_names()) > 0, "Original input models must have names."
input_name = node.get_names()[0]
_input_shapes.append(
{
"node": node,
"input_name": input_name
}
)
# mark-up Place names we already put into the _input_names
# to avoid duplicates in updates by freeze_placeholder below
_input_names.append(input_name)
if freeze_placeholder:
# in case freezing via freeze_placeholder_with_value option, _input_shapes can miss some frozen places
for input_name in freeze_placeholder:
if input_name in _input_names:
continue
node = decode_name_with_port(
input_model, input_name, framework, IOType.Input
)
_input_shapes.append(
{
"node": node,
"input_name": input_name
}
)
return _input_shapes, freeze_placeholder
return _input_shapes, dict()
@ -274,12 +323,12 @@ def fe_output_user_data_repack(input_model: InputModel, outputs: list, framework
def fe_user_data_repack(
input_model: InputModel,
input_user_shapes: [None, list, dict, np.array],
input_user_data_types: dict,
outputs: list,
freeze_placeholder: dict,
framework: str,
input_model: InputModel,
input_user_shapes: [None, list, dict, np.array],
input_user_data_types: dict,
outputs: list,
freeze_placeholder: dict,
framework: str,
):
"""
:param input_model: Input Model to operate on

View File

@ -4,22 +4,22 @@
import argparse
import io
import logging as log
from typing import List
import sys
from os import environ
from openvino.tools.mo.moc_frontend.analysis import json_model_analysis_dump
from openvino.tools.mo.moc_frontend.extractor import fe_user_data_repack
from openvino.tools.mo.middle.passes.infer import validate_batch_in_shape
from openvino.tools.mo.utils.class_registration import get_enabled_and_disabled_transforms
from openvino.tools.mo.utils.error import Error
from openvino.runtime import Dimension, PartialShape, Type # pylint: disable=no-name-in-module,import-error
from openvino.frontend import FrontEnd, InputModel, NotImplementedFailure, Place # pylint: disable=no-name-in-module,import-error
from openvino.runtime.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error
from typing import List
import numpy as np
from openvino.frontend import FrontEnd, InputModel, NotImplementedFailure, \
Place # pylint: disable=no-name-in-module,import-error
from openvino.runtime import Dimension, PartialShape, Type # pylint: disable=no-name-in-module,import-error
from openvino.runtime.utils.types import get_element_type, \
get_numpy_ctype # pylint: disable=no-name-in-module,import-error
from openvino.tools.mo.middle.passes.infer import validate_batch_in_shape
from openvino.tools.mo.moc_frontend.analysis import json_model_analysis_dump
from openvino.tools.mo.moc_frontend.extractor import fe_user_data_repack
from openvino.tools.mo.utils.class_registration import get_enabled_and_disabled_transforms
from openvino.tools.mo.utils.error import Error
def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
"""
@ -136,17 +136,50 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
if freeze_placeholder:
for name, value in freeze_placeholder.items():
for node in user_shapes:
if node.get('input_name') == name:
place = node['node']
if node.get('shape'):
input_model.set_partial_shape(place, node['shape'])
if node.get('data_type'):
value = np.array(value, dtype=node['data_type'])
input_model.set_element_type(place, Type(node['data_type']))
else:
value = np.array(value, dtype=np.float32)
input_model.set_tensor_value(place, value)
node = None
# look for the certain place in user_shapes
for node_cur in user_shapes:
if node_cur.get('input_name') == name:
node = node_cur
break
if node is None:
raise Error("Please check correctness of the command-line. "
"Place (operation or tensor) with name {} is not found.".format(name))
place = node.get('node')
if node.get('shape'):
input_model.set_partial_shape(place, node['shape'])
if node.get('data_type'):
dtype = node['data_type']
ov_type = Type(dtype)
else:
# we need to detect type of Placeholder
try:
ov_type = input_model.get_element_type(place)
except NotImplementedFailure:
raise Error("Please specify type for value freezing {} node explicitly "
"because the frontend does not support automatic type detection.".format(name))
# in case of cutting graph (or using custom inputs) and unspecified type,
# the default type is fp32
if ov_type == Type.undefined:
ov_type = Type.f32
dtype = get_numpy_ctype(ov_type)
input_model.set_element_type(place, ov_type)
# prepare and cast value to dtype
from openvino.tools.mo.utils.type_utils import np_map_cast
from openvino.tools.mo.front.common.partial_infer.utils import mo_array
if isinstance(value, list):
casted_list = list()
for v in mo_array(value):
casted_list.append(np_map_cast[dtype](v))
value = mo_array(casted_list, dtype=dtype)
else:
value = np_map_cast[dtype](value)
value = np.array(value, dtype=dtype)
input_model.set_tensor_value(place, value)
def shape_to_array(shape: PartialShape):
return [shape.get_dimension(i) for i in range(shape.rank.get_length())]

View File

@ -1457,7 +1457,8 @@ def get_shape_from_input_value(input_value: str):
if len(shape) == 0:
shape = None
elif len(shape) == 1 and shape[0] in ['', ' ']:
shape = ()
# this shape corresponds to scalar
shape = PartialShape([])
elif len(shape) == 1:
dims = re.split(r', *| +', shape[0])
dims = list(filter(None, dims))

View File

@ -9,6 +9,27 @@ from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.pipeline.common import convert_const_node_value_type
from openvino.tools.mo.utils.error import Error
np_map_cast = {np.bool: lambda x: bool_cast(x),
np.int8: lambda x: np.int8(x),
np.int16: lambda x: np.int16(x),
np.int32: lambda x: np.int32(x),
np.int64: lambda x: np.int64(x),
np.uint8: lambda x: np.uint8(x),
np.uint16: lambda x: np.uint16(x),
np.uint32: lambda x: np.uint32(x),
np.uint64: lambda x: np.uint64(x),
np.float16: lambda x: np.float16(x),
np.float32: lambda x: np.float32(x),
np.double: lambda x: np.double(x),
np.str: lambda x: np.str(x)}
def bool_cast(x):
if isinstance(x, str):
return False if x.lower() in ['false', '0'] else True if x.lower() in ['true', '1'] else 'unknown_boolean_cast'
else:
return np.bool(x)
def override_data_type_of_constant(node: Node, lhs_idx: int = 0, rhs_idx: int = 1):
in_type_0 = node.in_port(lhs_idx).get_data_type()

View File

@ -1,24 +1,23 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import unittest
from unittest.mock import patch, Mock
import pytest
from openvino.runtime import Core
from openvino.tools.mo.convert_impl import prepare_ir
import numpy as np
import onnx
from generator import generator, generate
from onnx.helper import make_graph, make_model, make_tensor_value_info
from openvino.frontend import (
FrontEndManager,
FrontEnd,
) # pylint: disable=no-name-in-module,import-error
from openvino.runtime import Core
from openvino.tools.mo.convert_impl import prepare_ir
from openvino.tools.mo.utils.error import Error
from onnx.helper import make_graph, make_model, make_tensor_value_info
import argparse
import os
from os import environ
import onnx
import numpy as np
from generator import generator, generate
def base_args_config(use_legacy_fe: bool = None, use_new_fe: bool = None):
@ -116,36 +115,37 @@ class TestMoFreezePlaceholder(unittest.TestCase):
@generate(
*[
(
"in1[1 4]{f32}->[1.0 2.0 3.0 4.0],in2[1 4]{f32}->[1.0 2.0 3.0 4.0]",
True,
{},
np.array([2.0, 4.0, 6.0, 8.0]),
np.float32,
"in1[1 4]{f32}->[1.0 2.0 3.0 4.0],in2[1 4]{f32}->[1.0 2.0 3.0 4.0]",
True,
{},
np.array([2.0, 4.0, 6.0, 8.0]),
np.float32,
),
(
"in2->[0.0 0.0 0.0 0.0]",
True,
{"in1": np.array([[1.0, 2.0], [3.0, 4.0]])},
np.array([[1.0, 2.0], [3.0, 4.0]]),
np.float32,
"in2{f32}->[0.0 0.0 0.0 0.0]",
True,
{"in1": np.array([[1.0, 2.0], [3.0, 4.0]])},
np.array([[1.0, 2.0], [3.0, 4.0]]),
np.float32,
),
(
"in2->[1.0 15.0 15.5 1.0]",
True,
{"in1": np.array([[2.0, 4.0], [12.0, 8.0]])},
np.array([[3.0, 19.0], [27.5, 9.0]]),
np.float32,
"in2{f32}->[1.0 15.0 15.5 1.0]",
True,
{"in1": np.array([[2.0, 4.0], [12.0, 8.0]])},
np.array([[3.0, 19.0], [27.5, 9.0]]),
np.float32,
),
(
"in1[1 4]{i32}->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]",
True,
{},
np.array([2.0, 4.0, 6.0, 8.0]),
np.int32,
"in1[1 4]{i32}->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]",
True,
{},
np.array([2.0, 4.0, 6.0, 8.0]),
np.int32,
),
],
)
def test_freeze_placeholder_with_value_onnx_fe(self, input_freezing_value, use_new_fe, inputs, expected, dtype=None):
def test_freeze_placeholder_with_value_onnx_fe(self, input_freezing_value, use_new_fe, inputs, expected,
dtype=None):
with patch("openvino.tools.mo.convert_impl.get_default_frontends") as default_fe:
default_fe.return_value = get_test_default_frontends()
args = base_args_config(use_new_fe=use_new_fe)
@ -166,49 +166,49 @@ class TestMoFreezePlaceholder(unittest.TestCase):
@generate(
*[
(
"in1->[1.0 15.0 1.0]",
True,
{"in2": np.array([2])},
np.array([2.0, 30.0, 2.0]),
np.float32,
"in1{f32}->[1.0 15.0 1.0]",
True,
{"in2": np.array([2])},
np.array([2.0, 30.0, 2.0]),
np.float32,
),
(
"in1->[7.0 11.0 -1.0],in2->3.0",
True,
{},
np.array([21.0, 33.0, -3.0]),
np.float32,
"in1{f32}->[7.0 11.0 -1.0],in2{f32}->3.0",
True,
{},
np.array([21.0, 33.0, -3.0]),
np.float32,
),
(
None,
True,
{
"in1": np.array([2.0, 2.0, 2.0]).reshape(1, 1, 3),
"in2": np.array([-1.0]),
},
np.array([-2.0, -2.0, -2.0]),
np.float32,
None,
True,
{
"in1": np.array([2.0, 2.0, 2.0]).reshape(1, 1, 3),
"in2": np.array([-1.0]),
},
np.array([-2.0, -2.0, -2.0]),
np.float32,
),
(
"in1[3 1]{f32}->[7.0 11.0 -1.0],in2{f32}->3.0",
True,
{},
np.array([21.0, 33.0, -3.0]).reshape(3, 1),
np.float32,
"in1[3 1]{f32}->[7.0 11.0 -1.0],in2{f32}->3.0",
True,
{},
np.array([21.0, 33.0, -3.0]).reshape(3, 1),
np.float32,
),
(
"in1[3 1]{f16}->[7.0 11.0 -1.0],in2{f16}->3.0",
True,
{},
np.array([21.0, 33.0, -3.0]).reshape(3, 1),
np.float16,
"in1[3 1]{f16}->[7.0 11.0 -1.0],in2{f16}->3.0",
True,
{},
np.array([21.0, 33.0, -3.0]).reshape(3, 1),
np.float16,
),
(
"in1[3 1]{i32}->[7 11 -1],in2{i32}->3.0",
True,
{},
np.array([21, 33, -3]).reshape(3, 1),
np.int32,
"in1[3 1]{i32}->[7 11 -1],in2{i32}->3.0",
True,
{},
np.array([21, 33, -3]).reshape(3, 1),
np.int32,
),
],
)
@ -229,3 +229,25 @@ class TestMoFreezePlaceholder(unittest.TestCase):
if dtype is not None:
assert values.dtype == dtype
assert np.allclose(values, expected)
@generate(
*[
(
"in1->[1.0 15.0 1.0]",
True,
{"in2": np.array([2])},
np.array([2.0, 30.0, 2.0]),
np.float32,
),
],
)
def test_value_without_type(self, input_freezing_value, use_new_fe, inputs, expected,
dtype=None):
with patch("openvino.tools.mo.convert_impl.get_default_frontends") as default_fe:
default_fe.return_value = get_test_default_frontends()
args = base_args_config(use_new_fe=use_new_fe)
args.input_model = "test_model_2.onnx"
args.input = input_freezing_value
self.assertRaisesRegex(Error, "Please specify type for value freezing in1 node explicitly "
"because the frontend does not support automatic type detection.",
prepare_ir, args)

View File

@ -0,0 +1,304 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import unittest
from unittest.mock import Mock
import numpy as np
from generator import generator, generate
from openvino.frontend import (
FrontEndManager,
FrontEnd,
) # pylint: disable=no-name-in-module,import-error
from openvino.runtime import Core
from openvino.tools.mo.convert_impl import prepare_ir
def base_args_config():
args = argparse.Namespace()
args.feManager = FrontEndManager()
args.extensions = None
# use new TF FE
args.use_legacy_frontend = False
args.use_new_frontend = True
args.framework = "tf"
args.model_name = None
args.input_model = None
args.input_model_is_text = False
args.input_checkpoint = None
args.saved_model_dir = None
args.input_meta_graph = None
args.saved_model_tags = None
args.silent = True
args.transform = []
args.scale = None
args.output = None
args.input = None
args.input_shape = None
args.batch = None
args.mean_values = None
args.scale_values = None
args.output_dir = os.getcwd()
args.freeze_placeholder_with_value = None
args.tensorflow_use_custom_operations_config = None
args.transformations_config = None
args.disable_fusing = None
args.finegrain_fusing = None
args.disable_resnet_optimization = None
args.enable_concat_optimization = None
args.static_shape = None
args.disable_weights_compression = None
args.reverse_input_channels = None
args.data_type = None
args.layout = None
args.source_layout = None
args.target_layout = None
return args
try:
import openvino_telemetry as tm
except ImportError:
import openvino.tools.mo.utils.telemetry_stub as tm
@generator
class TestMoFreezePlaceholderTFFE(unittest.TestCase):
def setUp(self):
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
tm.Telemetry.__init__ = Mock(return_value=None)
tm.Telemetry.send_event = Mock()
FrontEnd.add_extension = Mock()
self.models = []
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [2, 2], 'in1')
y = tf.placeholder(tf.float32, [2, 2], 'in2')
tf.add(x, y, name="add")
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_fp32.pb', as_text=False)
self.models.append("model_fp32.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.int32, [2, 3], 'in1')
y = tf.placeholder(tf.int32, [2, 3], 'in2')
tf.multiply(x, y, name="add")
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_int32.pb', as_text=False)
self.models.append("model_int32.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.bool, [2, 3], 'in1')
y = tf.placeholder(tf.bool, [2, 3], 'in2')
tf.math.logical_and(x, y)
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_bool.pb', as_text=False)
self.models.append("model_bool.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [3], 'in1')
y = tf.placeholder(tf.float32, [3], 'in2')
cond = tf.placeholder(tf.bool, [], 'cond')
tf.where(cond, x, y)
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_bool2.pb', as_text=False)
self.models.append("model_bool2.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [3], 'x')
y = tf.placeholder(tf.float32, [3], 'y')
z = tf.placeholder(tf.float32, [3], 'z')
add = tf.add(x, y, name="add")
tf.multiply(add, z, name="multiply")
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_three_inputs.pb', as_text=False)
self.models.append("model_three_inputs.pb")
def tearDown(self):
for name in self.models:
os.remove(name)
def basic(self, input_model, argv_input, inputs, dtype, expected, freeze_placeholder_with_value=None,
input_shape=None, only_conversion=False):
args = base_args_config()
args.input_model = input_model
args.input = argv_input
args.freeze_placeholder_with_value = freeze_placeholder_with_value
args.input_shape = input_shape
try:
_, model = prepare_ir(args)
except Exception as ex:
self.fail("Model conversion failed due to error: {}".format(ex))
if only_conversion:
return
ie = Core()
exec_net = ie.compile_model(model, "CPU")
req = exec_net.create_infer_request()
results = req.infer(inputs)
values = list(results.values())[0]
if dtype is not None:
assert values.dtype == dtype
assert np.allclose(values, expected)
@generate(
*[
(
"in1[1 4]->[1.0 2.0 3.0 4.0],in2[1 4]{f32}->[1.0 2.0 3.0 4.0]",
{},
np.array([2.0, 4.0, 6.0, 8.0]),
np.float32,
),
(
"in2{f32}->[0.0 0.0 0.0 0.0]",
{"in1": np.array([[1.0, 2.0], [3.0, 4.0]])},
np.array([[1.0, 2.0], [3.0, 4.0]]),
np.float32,
),
(
"in2->[1.0 15.0 15.5 1.0]",
{"in1": np.array([[2.0, 4.0], [12.0, 8.0]])},
np.array([[3.0, 19.0], [27.5, 9.0]]),
np.float32,
),
(
"in1[1 4]{i32}->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]",
{},
np.array([2.0, 4.0, 6.0, 8.0]),
np.int32,
),
],
)
def test_fp32(self, input_freezing_value, inputs, expected,
dtype):
self.basic("model_fp32.pb", input_freezing_value, inputs, dtype, expected)
@generate(
*[
(
"in1[1 4]->[1 2 3 4],in2[1 4]{i32}->[1 2 3 4]",
{},
np.array([1, 4, 9, 16]),
np.int32,
),
(
"in2->[2 5 6 7 3 2]",
{"in1": np.array([[2, 4, 1], [1, 2, 8]])},
np.array([[4, 20, 6], [7, 6, 16]]),
np.int32,
),
],
)
def test_int32(self, input_freezing_value, inputs, expected,
dtype=None):
self.basic("model_int32.pb", input_freezing_value, inputs, dtype, expected)
@generate(
*[
(
"in1[2]->[True False],in2[2]->[True True]",
{},
np.array([True, False], dtype=bool),
bool,
),
(
"in2[2,3]->[True,True,False,True,True,False]",
{"in1": np.array([[False, True, True], [False, True, True]], dtype=bool)},
np.array([[False, True, False], [False, True, False]], dtype=bool),
bool,
),
(
"in2[]->True",
{"in1": np.array([[False, True, True], [False, True, True]], dtype=bool)},
np.array([[False, True, True], [False, True, True]], dtype=bool),
bool,
),
],
)
def test_bool(self, input_freezing_value, inputs, expected,
dtype=None):
self.basic("model_bool.pb", input_freezing_value, inputs, dtype, expected)
@generate(
*[
(
"in1[3]->[1 2 3],in2[3]->[4 5 6],cond->False",
{},
np.array([4, 5, 6], dtype=np.float32),
np.float32,
None
),
(
None,
{"in1": np.array([2.0, 4.0, 6.0], dtype=np.float32),
"in2": np.array([1.0, 3.0, 5.0], dtype=np.float32)},
np.array([2, 4, 6], dtype=np.float32),
np.float32,
"cond->False",
None,
True # fill a bug to investigate why compilation of this model is hang on
),
# case: input_shape + freeze_placeholder_with_value
(
None,
{"in2": np.array([1.0, 3.0, 5.0], dtype=np.float32)},
np.array([2, 4, 6], dtype=np.float32),
np.float32,
"in1->[2.0 4.0 6.0],cond->True",
"[3]",
False
),
],
)
def test_bool2(self, input_freezing_value, inputs, expected,
dtype=None, freeze_placeholder_with_value=None, input_shape=None, only_conversion=False):
self.basic("model_bool2.pb", input_freezing_value, inputs, dtype, expected, freeze_placeholder_with_value,
input_shape, only_conversion)
@generate(
*[
(
"add:0[3],z",
{"add:0": np.array([4, 5, 6], dtype=np.float32), "z": np.array([1, 2, 3], dtype=np.float32)},
np.array([4, 10, 18], dtype=np.float32),
np.float32,
None
),
(
"add:0{i32}[3],z{i32}",
{"add:0": np.array([4, 5, 6], dtype=np.int32), "z": np.array([1, 2, 3], dtype=np.int32)},
np.array([4, 10, 18], dtype=np.int32),
np.int32,
None
),
],
)
def test_cutting_fp32(self, input_freezing_value, inputs, expected,
dtype=None, freeze_placeholder_with_value=None, input_shape=None, only_conversion=False):
self.basic("model_three_inputs.pb", input_freezing_value, inputs, dtype, expected,
freeze_placeholder_with_value,
input_shape, only_conversion)