[TF FE] Support dynamic shape Placeholder freezing and PlaceholderWithDefault (#14450)

* [TF FE] Support dynamic shape Placeholder freezing and PlaceholderWithDefault

Also, this PR contains reorganization of python unit tests for TF FE that
covers conversion and inference of different models in pbtxt.
This mini-infrastructure will be used in the future for TF FE support.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Revert debug info

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-12-07 12:13:10 +04:00 committed by GitHub
parent a47688e593
commit d3fa858fcc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 665 additions and 83 deletions

View File

@ -116,7 +116,12 @@ void InputModel::InputModelTFImpl::loadPlaces() {
all_op_names.insert(op_name);
m_op_places.push_back(op_place);
m_op_places_map[op_name] = op_place;
if (op_type == "Placeholder") {
if (op_type == "Placeholder" || op_type == "PlaceholderWithDefault") {
// in case Placeholder we put created TensorPlace to both m_tensor_places container and m_inputs
// since they can be used if user does not override them
// in case PlaceholderWithDefault we put created TensorPlace only to m_tensor_places container
// so that we know its shape and type for a case of custom input
// by default, PlaceholderWithDefault is replaced by Constant with the default value
auto pshape = ov::PartialShape::dynamic();
auto shape_any = node_decoder->get_attribute("shape");
if (shape_any.is<ov::PartialShape>()) {
@ -148,8 +153,11 @@ void InputModel::InputModelTFImpl::loadPlaces() {
std::vector<std::string> names = {op_name};
auto tensor_place = std::make_shared<TensorPlace>(m_input_model, pshape, type, names);
m_tensor_places[op_name] = tensor_place;
if (op_type == "Placeholder") {
// by default, PlaceholderWithDefault is NOT used as input
m_inputs.push_back(tensor_place);
}
}
for (size_t input_port_idx = 0; input_port_idx < node_decoder->get_input_size(); ++input_port_idx) {
std::string producer_op_name;
size_t producer_output_port_idx;
@ -331,9 +339,10 @@ ov::frontend::Place::Ptr InputModel::InputModelTFImpl::getPlaceByTensorName(cons
std::string port_type;
tensorflow::extract_operation_name_and_port(tensorName, operation_name, port_idx, port_type);
if (m_op_places_map.find(operation_name) != m_op_places_map.end()) {
// new Tensor places must be constructed of dynamic rank and type
std::vector<std::string> names = {tensorName};
auto m_var_place =
std::make_shared<TensorPlace>(m_input_model, ov::PartialShape(), ov::element::undefined, names);
std::make_shared<TensorPlace>(m_input_model, ov::PartialShape::dynamic(), ov::element::undefined, names);
m_tensor_places[tensorName] = m_var_place;
return m_var_place;
}
@ -396,8 +405,14 @@ void InputModel::InputModelTFImpl::setTensorValue(ov::frontend::Place::Ptr place
auto tensor_place = castToTensorPlace(place);
auto p_shape = tensor_place->get_partial_shape();
auto type = tensor_place->get_element_type();
auto constant = opset7::Constant::create(type, p_shape.to_shape(), value);
FRONT_END_GENERAL_CHECK(tensor_place->get_names().size() > 0,
"TensorFlow Frontend: place to be frozen must have the name.");
auto name = tensor_place->get_names()[0];
FRONT_END_GENERAL_CHECK(p_shape.is_static(),
"TensorFlow Frontend: specify static shape for " + name + " to be frozen.");
FRONT_END_GENERAL_CHECK(type.is_static(),
"TensorFlow Frontend: define static size type for " + name + " to be frozen.");
auto constant = opset7::Constant::create(type, p_shape.to_shape(), value);
constant->set_friendly_name(name);
m_tensor_values[name] = constant;
}

View File

@ -147,9 +147,6 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
"Place (operation or tensor) with name {} is not found.".format(name))
place = node.get('node')
if node.get('shape'):
input_model.set_partial_shape(place, node['shape'])
if node.get('data_type'):
dtype = node['data_type']
ov_type = Type(dtype)
@ -177,8 +174,19 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
value = mo_array(casted_list, dtype=dtype)
else:
value = np_map_cast[dtype](value)
value = np.array(value, dtype=dtype)
ov_shape = input_model.get_partial_shape(place)
if node.get('shape'):
# set user defined shape
ov_shape = PartialShape(node['shape'])
input_model.set_partial_shape(place, ov_shape)
elif ov_shape.is_dynamic:
# in case of dynamic shape (dynamic rank or dynamic dimension)
# deduce it based on the value shape and set it
ov_shape = PartialShape(value.shape)
input_model.set_partial_shape(place, ov_shape)
input_model.set_tensor_value(place, value)
def shape_to_array(shape: PartialShape):

View File

@ -68,85 +68,20 @@ except ImportError:
@generator
class TestMoFreezePlaceholderTFFE(unittest.TestCase):
def setUp(self):
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
tm.Telemetry.__init__ = Mock(return_value=None)
tm.Telemetry.send_event = Mock()
FrontEnd.add_extension = Mock()
self.models = []
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [2, 2], 'in1')
y = tf.placeholder(tf.float32, [2, 2], 'in2')
tf.add(x, y, name="add")
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_fp32.pb', as_text=False)
self.models.append("model_fp32.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.int32, [2, 3], 'in1')
y = tf.placeholder(tf.int32, [2, 3], 'in2')
tf.multiply(x, y, name="add")
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_int32.pb', as_text=False)
self.models.append("model_int32.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.bool, [2, 3], 'in1')
y = tf.placeholder(tf.bool, [2, 3], 'in2')
tf.math.logical_and(x, y)
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_bool.pb', as_text=False)
self.models.append("model_bool.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [3], 'in1')
y = tf.placeholder(tf.float32, [3], 'in2')
cond = tf.placeholder(tf.bool, [], 'cond')
tf.where(cond, x, y)
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_bool2.pb', as_text=False)
self.models.append("model_bool2.pb")
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [3], 'x')
y = tf.placeholder(tf.float32, [3], 'y')
z = tf.placeholder(tf.float32, [3], 'z')
add = tf.add(x, y, name="add")
tf.multiply(add, z, name="multiply")
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'model_three_inputs.pb', as_text=False)
self.models.append("model_three_inputs.pb")
def tearDown(self):
for name in self.models:
os.remove(name)
def basic(self, input_model, argv_input, inputs, dtype, expected, freeze_placeholder_with_value=None,
input_shape=None, only_conversion=False):
input_shape=None, only_conversion=False, input_model_is_text=True):
path = os.path.dirname(__file__)
input_model = os.path.join(path, "test_models", input_model)
args = base_args_config()
args.input_model = input_model
args.input = argv_input
args.freeze_placeholder_with_value = freeze_placeholder_with_value
args.input_shape = input_shape
args.input_model_is_text = input_model_is_text
try:
_, model = prepare_ir(args)
@ -195,7 +130,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
)
def test_fp32(self, input_freezing_value, inputs, expected,
dtype):
self.basic("model_fp32.pb", input_freezing_value, inputs, dtype, expected)
self.basic("model_fp32.pbtxt", input_freezing_value, inputs, dtype, expected)
@generate(
*[
@ -215,7 +150,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
)
def test_int32(self, input_freezing_value, inputs, expected,
dtype=None):
self.basic("model_int32.pb", input_freezing_value, inputs, dtype, expected)
self.basic("model_int32.pbtxt", input_freezing_value, inputs, dtype, expected)
@generate(
*[
@ -241,7 +176,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
)
def test_bool(self, input_freezing_value, inputs, expected,
dtype=None):
self.basic("model_bool.pb", input_freezing_value, inputs, dtype, expected)
self.basic("model_bool.pbtxt", input_freezing_value, inputs, dtype, expected)
@generate(
*[
@ -276,7 +211,7 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
)
def test_bool2(self, input_freezing_value, inputs, expected,
dtype=None, freeze_placeholder_with_value=None, input_shape=None, only_conversion=False):
self.basic("model_bool2.pb", input_freezing_value, inputs, dtype, expected, freeze_placeholder_with_value,
self.basic("model_bool2.pbtxt", input_freezing_value, inputs, dtype, expected, freeze_placeholder_with_value,
input_shape, only_conversion)
@generate(
@ -299,6 +234,63 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
)
def test_cutting_fp32(self, input_freezing_value, inputs, expected,
dtype=None, freeze_placeholder_with_value=None, input_shape=None, only_conversion=False):
self.basic("model_three_inputs.pb", input_freezing_value, inputs, dtype, expected,
self.basic("model_three_inputs.pbtxt", input_freezing_value, inputs, dtype, expected,
freeze_placeholder_with_value,
input_shape, only_conversion)
input_shape, only_conversion, True)
@generate(
*[
(
"x[1,4],y[4]",
{"x": np.array([[3, 2, 1, 5]], dtype=np.int32), "y": np.array([0, -1, -7, 8], dtype=np.int32)},
np.array([[3, 1, -6, 13]], dtype=np.int32),
np.int32,
None
),
(
"x,y",
{"x": np.array([[-3, 20, 1]], dtype=np.int32), "y": np.array([[10, -11, -17]], dtype=np.int32)},
np.array([[7, 9, -16]], dtype=np.int32),
np.int32,
None
),
(
"x",
{"x": np.array([[-3, 20, 1]], dtype=np.int32)},
np.array([[-2, 22, 4], [1, 25, 7]], dtype=np.int32),
np.int32,
None
),
],
)
def test_placeholder_with_default(self, inputs, inputs_data, expected,
dtype=None, freeze_placeholder_with_value=None, input_shape=None,
only_conversion=False):
self.basic("placeholder_with_default.pbtxt", inputs, inputs_data, dtype, expected,
freeze_placeholder_with_value,
input_shape, only_conversion, True)
@generate(
*[
(
"x[4],y->2.0",
{"x": np.array([3, 2, 1, 5], dtype=np.float32)},
np.array([6, 4, 2, 10], dtype=np.float32),
np.float32,
None
),
(
"x[1],y->[2.0,3.0]",
{"x": np.array([3], dtype=np.float32)},
np.array([6, 9], dtype=np.float32),
np.float32,
None
),
],
)
def test_freeze_placeholder_with_unknown_rank(self, inputs, inputs_data, expected,
dtype=None, freeze_placeholder_with_value=None, input_shape=None,
only_conversion=False):
self.basic("mul_with_unknown_rank_y.pbtxt", inputs, inputs_data, dtype, expected,
freeze_placeholder_with_value,
input_shape, only_conversion, True)

View File

@ -0,0 +1,52 @@
node {
name: "in1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "in2"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "LogicalAnd"
op: "LogicalAnd"
input: "in1"
input: "in2"
}

View File

@ -0,0 +1,15 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.bool, [2, 3], 'in1')
y = tf.placeholder(tf.bool, [2, 3], 'in2')
tf.math.logical_and(x, y)
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'model_bool.pbtxt', True)

View File

@ -0,0 +1,70 @@
node {
name: "in1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "in2"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "cond"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "Select"
op: "Select"
input: "cond"
input: "in1"
input: "in2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}

View File

@ -0,0 +1,16 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [3], 'in1')
y = tf.placeholder(tf.float32, [3], 'in2')
cond = tf.placeholder(tf.bool, [], 'cond')
tf.where(cond, x, y)
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'model_bool2.pbtxt', True)

View File

@ -0,0 +1,58 @@
node {
name: "in1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 2
}
}
}
}
}
node {
name: "in2"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 2
}
}
}
}
}
node {
name: "add"
op: "AddV2"
input: "in1"
input: "in2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}

View File

@ -0,0 +1,15 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [2, 2], 'in1')
y = tf.placeholder(tf.float32, [2, 2], 'in2')
tf.add(x, y, name="add")
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'model_fp32.pbtxt', True)

View File

@ -0,0 +1,58 @@
node {
name: "in1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "in2"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "add"
op: "Mul"
input: "in1"
input: "in2"
attr {
key: "T"
value {
type: DT_INT32
}
}
}

View File

@ -0,0 +1,15 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.int32, [2, 3], 'in1')
y = tf.placeholder(tf.int32, [2, 3], 'in2')
tf.multiply(x, y, name="add")
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'model_int32.pbtxt', True)

View File

@ -0,0 +1,84 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "z"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "add"
op: "AddV2"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "multiply"
op: "Mul"
input: "add"
input: "z"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}

View File

@ -0,0 +1,17 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [3], 'x')
y = tf.placeholder(tf.float32, [3], 'y')
z = tf.placeholder(tf.float32, [3], 'z')
add = tf.add(x, y, name="add")
tf.multiply(add, z, name="multiply")
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'model_three_inputs.pbtxt', True)

View File

@ -0,0 +1,50 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "Mul"
op: "Mul"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}

View File

@ -0,0 +1,15 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [3], 'x')
keep_prob = tf.placeholder(tf.float32, None, 'y')
tf.multiply(x, keep_prob)
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'mul_with_unknown_rank_y.pbtxt', True)

View File

@ -0,0 +1,86 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 3
}
}
}
}
}
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 2
}
dim {
size: 3
}
}
tensor_content: "\001\000\000\000\002\000\000\000\003\000\000\000\004\000\000\000\005\000\000\000\006\000\000\000"
}
}
}
}
node {
name: "y"
op: "PlaceholderWithDefault"
input: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 3
}
}
}
}
}
node {
name: "Add"
op: "AddV2"
input: "x"
input: "y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}

View File

@ -0,0 +1,16 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(tf.int32, [None, 3], 'x')
y = tf.placeholder_with_default(tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int32),
[None, 3], 'y')
tf.add(x, y)
tf.global_variables_initializer()
tf_net = sess.graph_def
tf.io.write_graph(tf_net, './', 'placeholder_with_default.pbtxt', True)