[TF FE] Support TF2 Object Detection models (#14979)

* [TF FE] Support TF2 Object detection models

For support of OOB conversion of OD models (Faster RCNN, SSD models) several fixes were done
for Select, BroadcastArgs, Slice, and Concat operations.
Implement tests for each case

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Switch off Transpose Sinking that breaks some model conversion

* Apply code-review feedback: copyright and extra commented out code

* Mention that for concat this is workaround

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-01-09 18:36:42 +04:00 committed by GitHub
parent 028cf7a34d
commit af6ed211d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 327 additions and 32 deletions

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -252,7 +252,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
manager.register_pass<pass::GRUBlockCellReplacer>();
// TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
manager.register_pass<ov::frontend::tensorflow::pass::TransposeSinking>();
// manager.register_pass<ov::frontend::tensorflow::pass::TransposeSinking>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.run_passes(function);
}

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -103,10 +103,16 @@ ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion::EmbeddingSe
auto const_one = make_shared<Constant>(element::i32, Shape{1}, 1);
auto new_subshape = make_shared<Broadcast>(const_one, num_new_axes);
auto cond_shape = make_shared<ShapeOf>(tile, element::i32);
auto new_cond_shape = make_shared<Concat>(OutputVector{cond_shape, new_subshape}, 0);
// use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat
// remove this workaround once 100671 is resolved
auto const_1 = make_shared<Constant>(element::i32, Shape{1}, 1);
auto new_cond_shape = make_shared<Concat>(OutputVector{const_1, cond_shape, new_subshape}, 0);
// prepare the condition to have the same rank as operands `x` and `y`
auto prep_cond = make_shared<Reshape>(tile, new_cond_shape, false);
auto prep_cond = make_shared<Reshape>(tile, new_cond_shape, false)->output(0);
// squeeze prep_cond by one extra dimension specially added
auto const_0 = make_shared<Constant>(element::i32, Shape{1}, 0);
prep_cond = make_shared<Squeeze>(prep_cond, const_0);
auto select_pattern = make_shared<Select>(prep_cond, zeros_like, sparse_segment_op);

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -18,8 +18,8 @@ OutputVector translate_broadcast_args_op(const NodeContext& node) {
auto s1 = node.get_input(1);
// compute a number of shape elements to append for broadcasting
auto size0 = make_shared<Squeeze>(make_shared<ShapeOf>(s0));
auto size1 = make_shared<Squeeze>(make_shared<ShapeOf>(s1));
auto size0 = make_shared<ShapeOf>(s0);
auto size1 = make_shared<ShapeOf>(s1);
auto max_size = make_shared<Maximum>(size0, size1);
auto diff1 = make_shared<Subtract>(max_size, size0);
auto diff2 = make_shared<Subtract>(max_size, size1);
@ -28,14 +28,14 @@ OutputVector translate_broadcast_args_op(const NodeContext& node) {
// to take dynamic shapes into account
auto padded_s0 =
make_shared<Pad>(s0,
make_shared<Constant>(diff1->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
diff1,
make_shared<Constant>(diff1->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
make_shared<Constant>(s0.get_element_type(), Shape{}, std::vector<int64_t>{-1}),
ov::op::PadMode::CONSTANT);
auto padded_s1 =
make_shared<Pad>(s1,
make_shared<Constant>(diff2->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
diff2,
make_shared<Constant>(diff2->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
make_shared<Constant>(s1.get_element_type(), Shape{}, std::vector<int64_t>{-1}),
ov::op::PadMode::CONSTANT);
@ -46,4 +46,4 @@ OutputVector translate_broadcast_args_op(const NodeContext& node) {
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
} // namespace ov

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -20,7 +20,7 @@ OutputVector translate_concat_op(const NodeContext& node) {
// axis is the first input for Concat
// and is the last input to ConcatV2
default_op_checks(node, 2, {"Concat", "ConcatV2"});
auto input_size = node.get_input_size();
auto input_size = static_cast<int>(node.get_input_size());
int64_t axis;
OutputVector inputs;
@ -33,7 +33,7 @@ OutputVector translate_concat_op(const NodeContext& node) {
axis_vector.size() == 1,
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
axis = axis_vector[0];
for (size_t input_idx = 1; input_idx < input_size; ++input_idx) {
for (int input_idx = 1; input_idx < input_size; ++input_idx) {
inputs.push_back(node.get_input(input_idx));
}
} else if (node.get_op_type() == "ConcatV2") {
@ -44,7 +44,7 @@ OutputVector translate_concat_op(const NodeContext& node) {
axis_vector.size() == 1,
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
axis = axis_vector[0];
for (size_t input_idx = 0; input_idx < input_size - 1; ++input_idx) {
for (int input_idx = 0; input_idx < input_size - 1; ++input_idx) {
inputs.push_back(node.get_input(input_idx));
}
} else {

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -57,13 +57,18 @@ OutputVector translate_select_op(const NodeContext& node) {
auto num_new_axes = make_shared<Subtract>(x_rank, cond_rank);
// generate a new shape for the condition
auto const_one = make_shared<opset10::Constant>(element::i32, Shape{1}, 1);
auto new_subshape = make_shared<opset10::Broadcast>(const_one, num_new_axes);
auto cond_shape = make_shared<opset10::ShapeOf>(condition, element::i32);
auto new_cond_shape = make_shared<opset10::Concat>(OutputVector{cond_shape, new_subshape}, 0);
auto const_one = make_shared<Constant>(element::i32, Shape{1}, 1);
auto new_subshape = make_shared<Broadcast>(const_one, num_new_axes);
auto cond_shape = make_shared<ShapeOf>(condition, element::i32);
// use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat
auto const_1 = make_shared<Constant>(element::i32, Shape{1}, 1);
auto new_cond_shape = make_shared<Concat>(OutputVector{const_1, cond_shape, new_subshape}, 0);
// prepare the condition to have the same rank as operands `x` and `y`
auto prep_cond = make_shared<opset10::Reshape>(condition, new_cond_shape, false);
auto prep_cond = make_shared<Reshape>(condition, new_cond_shape, false)->output(0);
// squeeze prep_cond by one extra dimension specially added
auto const_0 = make_shared<Constant>(element::i32, Shape{1}, 0);
prep_cond = make_shared<Squeeze>(prep_cond, const_0);
return translate_select_base_op(node, prep_cond, x, y);
}

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -14,6 +14,7 @@ namespace tensorflow {
namespace op {
OutputVector translate_slice_op(const NodeContext& node) {
default_op_checks(node, 3, {"Slice"});
auto input = node.get_input(0);
auto start = node.get_input(1);
auto size = node.get_input(2);
@ -24,8 +25,7 @@ OutputVector translate_slice_op(const NodeContext& node) {
// compute stop values in case negative sizes
// since TensorFlow supports only -1 among negative sizes
// assign stop values to the data shape
auto input_shape = make_shared<ShapeOf>(input);
auto stop_neg = make_shared<ConvertLike>(input_shape, size);
auto stop_neg = make_shared<ShapeOf>(input, size.get_element_type());
// select the correct stop value based on a sign of size value
auto zeros = make_shared<Constant>(size.get_element_type(), Shape{}, 0);

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -26,13 +26,21 @@ void set_node_name(const std::string& node_name, const std::shared_ptr<Node>& no
bool is_conditional_edge(const std::string& input_tensor_name);
template <typename T>
void get_const_input(const NodeContext& node, int64_t input_index, std::vector<T>* vector) {
auto ng_input = node.get_input(static_cast<int>(input_index));
if (auto constant = std::dynamic_pointer_cast<opset8::Constant>(ng_input.get_node_shared_ptr())) {
void get_const_input(const NodeContext& node, int input_index, std::vector<T>* vector) {
auto input_size = static_cast<int>(node.get_input_size());
auto node_name = node.get_name();
auto node_type = node.get_op_type();
FRONT_END_GENERAL_CHECK(0 <= input_index && input_index < input_size,
"[TensorFlow Frontend] Internal error: Node " + node_name + " has " +
std::to_string(input_size) + " inputs, but requested input port index to be " +
std::to_string(input_size));
auto ov_input = node.get_input(input_index);
if (auto constant = get_constant_from_source(ov_input)) {
*vector = constant->cast_vector<T>();
return;
}
FRONT_END_THROW("Node must be converted to Constant.");
FRONT_END_THROW("[TensorFlow Frontend] Internal error: Input " + std::to_string(input_index) +
" cannot be folded to Constant for node " + node_name + " of type " + node_type);
}
ov::op::PadType convert_tf_padding(const NodeContext& node, const std::string& tf_padding);

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -15,7 +15,7 @@ static const std::vector<std::string> models{
std::string("2in_2out/2in_2out.pb"),
std::string("forward_edge_model/forward_edge_model.pb"),
std::string("forward_edge_model2/forward_edge_model2.pb"),
};
std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pb")};
INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,
FrontEndConvertModelTest,

View File

@ -0,0 +1,156 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 3
}
}
}
}
}
node {
name: "z"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
dim {
size: 3
}
}
}
}
}
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: -1
}
}
}
}
node {
name: "Const_1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
}
node {
name: "axis"
op: "AddV2"
input: "Const"
input: "Const_1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "concat"
op: "ConcatV2"
input: "x"
input: "y"
input: "z"
input: "axis"
attr {
key: "N"
value {
i: 3
}
}
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "Tidx"
value {
type: DT_INT32
}
}
}
node {
name: "init"
op: "NoOp"
}
versions {
producer: 808
}

View File

@ -0,0 +1,19 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import tensorflow.compat.v1 as tf
tf.reset_default_graph()
with tf.Session() as sess:
x = tf.placeholder(dtype=tf.float32, shape=[4, 3], name='x')
y = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='y')
z = tf.placeholder(dtype=tf.float32, shape=[1, 3], name='z')
const1 = tf.constant(-1, dtype=tf.int32)
const2 = tf.constant(1, dtype=tf.int32)
axis = tf.add(const1, const2, name="axis")
tf.concat([x, y, z], axis)
tf.global_variables_initializer()
tf.io.write_graph(sess.graph, '.', 'concat_with_non_constant_axis.pbtxt', as_text=True)

View File

@ -0,0 +1,63 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestBroadcastArgs(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 's0' in inputs_info, "Test error: inputs_info must contain `s0`"
assert 's1' in inputs_info, "Test error: inputs_info must contain `s1`"
s0_shape = inputs_info['s0']
s1_shape = inputs_info['s1']
inputs_data = {}
inputs_data['s0'] = np.random.randint(1, 6, s0_shape)
inputs_data['s1'] = np.random.randint(1, 6, s1_shape)
# compute mask where we need to change dimension size in s1
# so that s1 will be broadcastable to s0
non_one_mask = inputs_data['s0'] != 1
diff_size = len(inputs_data['s1']) - len(inputs_data['s0'])
if diff_size > 0:
# pad False elements to non_one_mask to the begin
pad = np.full([diff_size], False, dtype=bool)
non_one_mask = np.concatenate((pad, non_one_mask), axis=0)
else:
# cut extra mask elements
diff_size = abs(diff_size)
non_one_mask = non_one_mask[diff_size:]
update_inds = np.argwhere(non_one_mask)
inputs_data['s1'][update_inds] = 1
print("inputs_data = ", inputs_data)
return inputs_data
def create_broadcast_args_net(self, s0_shape, s1_shape, input_type):
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
s0 = tf.compat.v1.placeholder(input_type, s0_shape, 's0')
s1 = tf.compat.v1.placeholder(input_type, s1_shape, 's1')
tf.raw_ops.BroadcastArgs(s0=s0, s1=s1)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
ref_net = None
return tf_net, ref_net
test_data_basic = [
dict(s0_shape=[6], s1_shape=[6], input_type=tf.int32),
dict(s0_shape=[2], s1_shape=[5], input_type=tf.int64),
dict(s0_shape=[7], s1_shape=[1], input_type=tf.int32),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_broadcast_args_basic(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api):
self._test(*self.create_broadcast_args_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -1,4 +1,4 @@
# Copyright (C) 2018-2022 Intel Corporation
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
@ -36,6 +36,7 @@ class TestSelect(CommonTFLayerTest):
return tf_net, None
test_data_basic = [
dict(cond_shape=[], x_shape=[], y_shape=[]),
dict(cond_shape=[], x_shape=[3, 2, 4], y_shape=[3, 2, 4]),
dict(cond_shape=[2], x_shape=[2, 4, 5], y_shape=[2, 4, 5]),
dict(cond_shape=[2, 3, 4], x_shape=[2, 3, 4], y_shape=[2, 3, 4]),

View File

@ -0,0 +1,37 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestSlice(CommonTFLayerTest):
def create_slice_net(self, input_shape, input_type, begin_value, size_value):
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
input_x = tf.compat.v1.placeholder(input_type, input_shape, 'input_x')
begin = tf.constant(begin_value, tf.int32)
size = tf.constant(size_value, tf.int32)
tf.raw_ops.Slice(input=input_x, begin=begin, size=size)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
ref_net = None
return tf_net, ref_net
test_data_basic = [
dict(input_shape=[6], input_type=tf.float32, begin_value=[2], size_value=[2]),
dict(input_shape=[2, 5, 3], input_type=tf.int32, begin_value=[0, 1, 0], size_value=[-1, 1, -1]),
dict(input_shape=[10, 5, 1, 5], input_type=tf.float32, begin_value=[5, 1, 0, 3], size_value=[2, 4, -1, -1]),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_slice_basic(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api):
self._test(*self.create_slice_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)