[TF FE] Support TF2 Object Detection models (#14979)
* [TF FE] Support TF2 Object detection models For support of OOB conversion of OD models (Faster RCNN, SSD models) several fixes were done for Select, BroadcastArgs, Slice, and Concat operations. Implement tests for each case Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Switch off Transpose Sinking that breaks some model conversion * Apply code-review feedback: copyright and extra commented out code * Mention that for concat this is workaround Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
028cf7a34d
commit
af6ed211d6
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -252,7 +252,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
|
||||
manager.register_pass<pass::GRUBlockCellReplacer>();
|
||||
|
||||
// TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
|
||||
manager.register_pass<ov::frontend::tensorflow::pass::TransposeSinking>();
|
||||
// manager.register_pass<ov::frontend::tensorflow::pass::TransposeSinking>();
|
||||
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
|
||||
manager.run_passes(function);
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -103,10 +103,16 @@ ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion::EmbeddingSe
|
||||
auto const_one = make_shared<Constant>(element::i32, Shape{1}, 1);
|
||||
auto new_subshape = make_shared<Broadcast>(const_one, num_new_axes);
|
||||
auto cond_shape = make_shared<ShapeOf>(tile, element::i32);
|
||||
auto new_cond_shape = make_shared<Concat>(OutputVector{cond_shape, new_subshape}, 0);
|
||||
// use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat
|
||||
// remove this workaround once 100671 is resolved
|
||||
auto const_1 = make_shared<Constant>(element::i32, Shape{1}, 1);
|
||||
auto new_cond_shape = make_shared<Concat>(OutputVector{const_1, cond_shape, new_subshape}, 0);
|
||||
|
||||
// prepare the condition to have the same rank as operands `x` and `y`
|
||||
auto prep_cond = make_shared<Reshape>(tile, new_cond_shape, false);
|
||||
auto prep_cond = make_shared<Reshape>(tile, new_cond_shape, false)->output(0);
|
||||
// squeeze prep_cond by one extra dimension specially added
|
||||
auto const_0 = make_shared<Constant>(element::i32, Shape{1}, 0);
|
||||
prep_cond = make_shared<Squeeze>(prep_cond, const_0);
|
||||
|
||||
auto select_pattern = make_shared<Select>(prep_cond, zeros_like, sparse_segment_op);
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -18,8 +18,8 @@ OutputVector translate_broadcast_args_op(const NodeContext& node) {
|
||||
auto s1 = node.get_input(1);
|
||||
|
||||
// compute a number of shape elements to append for broadcasting
|
||||
auto size0 = make_shared<Squeeze>(make_shared<ShapeOf>(s0));
|
||||
auto size1 = make_shared<Squeeze>(make_shared<ShapeOf>(s1));
|
||||
auto size0 = make_shared<ShapeOf>(s0);
|
||||
auto size1 = make_shared<ShapeOf>(s1);
|
||||
auto max_size = make_shared<Maximum>(size0, size1);
|
||||
auto diff1 = make_shared<Subtract>(max_size, size0);
|
||||
auto diff2 = make_shared<Subtract>(max_size, size1);
|
||||
@ -28,14 +28,14 @@ OutputVector translate_broadcast_args_op(const NodeContext& node) {
|
||||
// to take dynamic shapes into account
|
||||
auto padded_s0 =
|
||||
make_shared<Pad>(s0,
|
||||
make_shared<Constant>(diff1->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
|
||||
diff1,
|
||||
make_shared<Constant>(diff1->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
|
||||
make_shared<Constant>(s0.get_element_type(), Shape{}, std::vector<int64_t>{-1}),
|
||||
ov::op::PadMode::CONSTANT);
|
||||
auto padded_s1 =
|
||||
make_shared<Pad>(s1,
|
||||
make_shared<Constant>(diff2->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
|
||||
diff2,
|
||||
make_shared<Constant>(diff2->get_element_type(), Shape{1}, std::vector<int64_t>{0}),
|
||||
make_shared<Constant>(s1.get_element_type(), Shape{}, std::vector<int64_t>{-1}),
|
||||
ov::op::PadMode::CONSTANT);
|
||||
|
||||
@ -46,4 +46,4 @@ OutputVector translate_broadcast_args_op(const NodeContext& node) {
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
} // namespace ov
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -20,7 +20,7 @@ OutputVector translate_concat_op(const NodeContext& node) {
|
||||
// axis is the first input for Concat
|
||||
// and is the last input to ConcatV2
|
||||
default_op_checks(node, 2, {"Concat", "ConcatV2"});
|
||||
auto input_size = node.get_input_size();
|
||||
auto input_size = static_cast<int>(node.get_input_size());
|
||||
|
||||
int64_t axis;
|
||||
OutputVector inputs;
|
||||
@ -33,7 +33,7 @@ OutputVector translate_concat_op(const NodeContext& node) {
|
||||
axis_vector.size() == 1,
|
||||
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
|
||||
axis = axis_vector[0];
|
||||
for (size_t input_idx = 1; input_idx < input_size; ++input_idx) {
|
||||
for (int input_idx = 1; input_idx < input_size; ++input_idx) {
|
||||
inputs.push_back(node.get_input(input_idx));
|
||||
}
|
||||
} else if (node.get_op_type() == "ConcatV2") {
|
||||
@ -44,7 +44,7 @@ OutputVector translate_concat_op(const NodeContext& node) {
|
||||
axis_vector.size() == 1,
|
||||
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
|
||||
axis = axis_vector[0];
|
||||
for (size_t input_idx = 0; input_idx < input_size - 1; ++input_idx) {
|
||||
for (int input_idx = 0; input_idx < input_size - 1; ++input_idx) {
|
||||
inputs.push_back(node.get_input(input_idx));
|
||||
}
|
||||
} else {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -57,13 +57,18 @@ OutputVector translate_select_op(const NodeContext& node) {
|
||||
auto num_new_axes = make_shared<Subtract>(x_rank, cond_rank);
|
||||
|
||||
// generate a new shape for the condition
|
||||
auto const_one = make_shared<opset10::Constant>(element::i32, Shape{1}, 1);
|
||||
auto new_subshape = make_shared<opset10::Broadcast>(const_one, num_new_axes);
|
||||
auto cond_shape = make_shared<opset10::ShapeOf>(condition, element::i32);
|
||||
auto new_cond_shape = make_shared<opset10::Concat>(OutputVector{cond_shape, new_subshape}, 0);
|
||||
auto const_one = make_shared<Constant>(element::i32, Shape{1}, 1);
|
||||
auto new_subshape = make_shared<Broadcast>(const_one, num_new_axes);
|
||||
auto cond_shape = make_shared<ShapeOf>(condition, element::i32);
|
||||
// use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat
|
||||
auto const_1 = make_shared<Constant>(element::i32, Shape{1}, 1);
|
||||
auto new_cond_shape = make_shared<Concat>(OutputVector{const_1, cond_shape, new_subshape}, 0);
|
||||
|
||||
// prepare the condition to have the same rank as operands `x` and `y`
|
||||
auto prep_cond = make_shared<opset10::Reshape>(condition, new_cond_shape, false);
|
||||
auto prep_cond = make_shared<Reshape>(condition, new_cond_shape, false)->output(0);
|
||||
// squeeze prep_cond by one extra dimension specially added
|
||||
auto const_0 = make_shared<Constant>(element::i32, Shape{1}, 0);
|
||||
prep_cond = make_shared<Squeeze>(prep_cond, const_0);
|
||||
|
||||
return translate_select_base_op(node, prep_cond, x, y);
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -14,6 +14,7 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_slice_op(const NodeContext& node) {
|
||||
default_op_checks(node, 3, {"Slice"});
|
||||
auto input = node.get_input(0);
|
||||
auto start = node.get_input(1);
|
||||
auto size = node.get_input(2);
|
||||
@ -24,8 +25,7 @@ OutputVector translate_slice_op(const NodeContext& node) {
|
||||
// compute stop values in case negative sizes
|
||||
// since TensorFlow supports only -1 among negative sizes
|
||||
// assign stop values to the data shape
|
||||
auto input_shape = make_shared<ShapeOf>(input);
|
||||
auto stop_neg = make_shared<ConvertLike>(input_shape, size);
|
||||
auto stop_neg = make_shared<ShapeOf>(input, size.get_element_type());
|
||||
|
||||
// select the correct stop value based on a sign of size value
|
||||
auto zeros = make_shared<Constant>(size.get_element_type(), Shape{}, 0);
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -26,13 +26,21 @@ void set_node_name(const std::string& node_name, const std::shared_ptr<Node>& no
|
||||
bool is_conditional_edge(const std::string& input_tensor_name);
|
||||
|
||||
template <typename T>
|
||||
void get_const_input(const NodeContext& node, int64_t input_index, std::vector<T>* vector) {
|
||||
auto ng_input = node.get_input(static_cast<int>(input_index));
|
||||
if (auto constant = std::dynamic_pointer_cast<opset8::Constant>(ng_input.get_node_shared_ptr())) {
|
||||
void get_const_input(const NodeContext& node, int input_index, std::vector<T>* vector) {
|
||||
auto input_size = static_cast<int>(node.get_input_size());
|
||||
auto node_name = node.get_name();
|
||||
auto node_type = node.get_op_type();
|
||||
FRONT_END_GENERAL_CHECK(0 <= input_index && input_index < input_size,
|
||||
"[TensorFlow Frontend] Internal error: Node " + node_name + " has " +
|
||||
std::to_string(input_size) + " inputs, but requested input port index to be " +
|
||||
std::to_string(input_size));
|
||||
auto ov_input = node.get_input(input_index);
|
||||
if (auto constant = get_constant_from_source(ov_input)) {
|
||||
*vector = constant->cast_vector<T>();
|
||||
return;
|
||||
}
|
||||
FRONT_END_THROW("Node must be converted to Constant.");
|
||||
FRONT_END_THROW("[TensorFlow Frontend] Internal error: Input " + std::to_string(input_index) +
|
||||
" cannot be folded to Constant for node " + node_name + " of type " + node_type);
|
||||
}
|
||||
|
||||
ov::op::PadType convert_tf_padding(const NodeContext& node, const std::string& tf_padding);
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -15,7 +15,7 @@ static const std::vector<std::string> models{
|
||||
std::string("2in_2out/2in_2out.pb"),
|
||||
std::string("forward_edge_model/forward_edge_model.pb"),
|
||||
std::string("forward_edge_model2/forward_edge_model2.pb"),
|
||||
};
|
||||
std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pb")};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,
|
||||
FrontEndConvertModelTest,
|
||||
|
@ -0,0 +1,156 @@
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 4
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "z"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const_1"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "axis"
|
||||
op: "AddV2"
|
||||
input: "Const"
|
||||
input: "Const_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "concat"
|
||||
op: "ConcatV2"
|
||||
input: "x"
|
||||
input: "y"
|
||||
input: "z"
|
||||
input: "axis"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 3
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tidx"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "init"
|
||||
op: "NoOp"
|
||||
}
|
||||
versions {
|
||||
producer: 808
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
with tf.Session() as sess:
|
||||
x = tf.placeholder(dtype=tf.float32, shape=[4, 3], name='x')
|
||||
y = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='y')
|
||||
z = tf.placeholder(dtype=tf.float32, shape=[1, 3], name='z')
|
||||
const1 = tf.constant(-1, dtype=tf.int32)
|
||||
const2 = tf.constant(1, dtype=tf.int32)
|
||||
axis = tf.add(const1, const2, name="axis")
|
||||
tf.concat([x, y, z], axis)
|
||||
|
||||
tf.global_variables_initializer()
|
||||
tf.io.write_graph(sess.graph, '.', 'concat_with_non_constant_axis.pbtxt', as_text=True)
|
63
tests/layer_tests/tensorflow_tests/test_tf_BroadcastArgs.py
Normal file
63
tests/layer_tests/tensorflow_tests/test_tf_BroadcastArgs.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestBroadcastArgs(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_info):
|
||||
assert 's0' in inputs_info, "Test error: inputs_info must contain `s0`"
|
||||
assert 's1' in inputs_info, "Test error: inputs_info must contain `s1`"
|
||||
s0_shape = inputs_info['s0']
|
||||
s1_shape = inputs_info['s1']
|
||||
inputs_data = {}
|
||||
inputs_data['s0'] = np.random.randint(1, 6, s0_shape)
|
||||
inputs_data['s1'] = np.random.randint(1, 6, s1_shape)
|
||||
# compute mask where we need to change dimension size in s1
|
||||
# so that s1 will be broadcastable to s0
|
||||
non_one_mask = inputs_data['s0'] != 1
|
||||
diff_size = len(inputs_data['s1']) - len(inputs_data['s0'])
|
||||
if diff_size > 0:
|
||||
# pad False elements to non_one_mask to the begin
|
||||
pad = np.full([diff_size], False, dtype=bool)
|
||||
non_one_mask = np.concatenate((pad, non_one_mask), axis=0)
|
||||
else:
|
||||
# cut extra mask elements
|
||||
diff_size = abs(diff_size)
|
||||
non_one_mask = non_one_mask[diff_size:]
|
||||
update_inds = np.argwhere(non_one_mask)
|
||||
inputs_data['s1'][update_inds] = 1
|
||||
|
||||
print("inputs_data = ", inputs_data)
|
||||
return inputs_data
|
||||
|
||||
def create_broadcast_args_net(self, s0_shape, s1_shape, input_type):
|
||||
tf.compat.v1.reset_default_graph()
|
||||
with tf.compat.v1.Session() as sess:
|
||||
s0 = tf.compat.v1.placeholder(input_type, s0_shape, 's0')
|
||||
s1 = tf.compat.v1.placeholder(input_type, s1_shape, 's1')
|
||||
tf.raw_ops.BroadcastArgs(s0=s0, s1=s1)
|
||||
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
ref_net = None
|
||||
return tf_net, ref_net
|
||||
|
||||
test_data_basic = [
|
||||
dict(s0_shape=[6], s1_shape=[6], input_type=tf.int32),
|
||||
dict(s0_shape=[2], s1_shape=[5], input_type=tf.int64),
|
||||
dict(s0_shape=[7], s1_shape=[1], input_type=tf.int32),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
def test_broadcast_args_basic(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||
use_old_api):
|
||||
self._test(*self.create_broadcast_args_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
@ -1,4 +1,4 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
@ -36,6 +36,7 @@ class TestSelect(CommonTFLayerTest):
|
||||
return tf_net, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(cond_shape=[], x_shape=[], y_shape=[]),
|
||||
dict(cond_shape=[], x_shape=[3, 2, 4], y_shape=[3, 2, 4]),
|
||||
dict(cond_shape=[2], x_shape=[2, 4, 5], y_shape=[2, 4, 5]),
|
||||
dict(cond_shape=[2, 3, 4], x_shape=[2, 3, 4], y_shape=[2, 3, 4]),
|
||||
|
37
tests/layer_tests/tensorflow_tests/test_tf_Slice.py
Normal file
37
tests/layer_tests/tensorflow_tests/test_tf_Slice.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestSlice(CommonTFLayerTest):
|
||||
def create_slice_net(self, input_shape, input_type, begin_value, size_value):
|
||||
tf.compat.v1.reset_default_graph()
|
||||
with tf.compat.v1.Session() as sess:
|
||||
input_x = tf.compat.v1.placeholder(input_type, input_shape, 'input_x')
|
||||
begin = tf.constant(begin_value, tf.int32)
|
||||
size = tf.constant(size_value, tf.int32)
|
||||
tf.raw_ops.Slice(input=input_x, begin=begin, size=size)
|
||||
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
ref_net = None
|
||||
return tf_net, ref_net
|
||||
|
||||
test_data_basic = [
|
||||
dict(input_shape=[6], input_type=tf.float32, begin_value=[2], size_value=[2]),
|
||||
dict(input_shape=[2, 5, 3], input_type=tf.int32, begin_value=[0, 1, 0], size_value=[-1, 1, -1]),
|
||||
dict(input_shape=[10, 5, 1, 5], input_type=tf.float32, begin_value=[5, 1, 0, 3], size_value=[2, 4, -1, -1]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
def test_slice_basic(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
|
||||
use_old_api):
|
||||
self._test(*self.create_slice_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
Loading…
Reference in New Issue
Block a user