[TF Hub][TF FE] Support TensorListLength and TensorListResize operations (#19390)
* [TF Hub][TF FE] Support TensorListLength and TensorListResize operations
* Add test with empty tensor list
* Remove assert

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
commit 8df85badf8 (parent ba6cca8740)
@@ -263,10 +263,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"Switch", CreatorFunction(translate_switch_op)},
         {"TensorListFromTensor", CreatorFunction(translate_tensor_list_from_tensor_op)},
         {"TensorListGetItem", CreatorFunction(translate_tensor_list_get_item_op)},
+        {"TensorListLength", CreatorFunction(translate_tensor_list_length_op)},
         {"TensorListPushBack", CreatorFunction(translate_tensor_list_push_back_op)},
         {"TensorListSetItem", CreatorFunction(translate_tensor_list_set_item_op)},
         {"TensorListStack", CreatorFunction(translate_tensor_list_stack_op)},
         {"TensorListReserve", CreatorFunction(translate_tensor_list_reserve_op)},
+        {"TensorListResize", CreatorFunction(translate_tensor_list_resize_op)},
         {"Tile", CreatorFunction(translate_tile_op)},
         {"TopK", CreatorFunction(translate_top_k_op)},
         {"TopKV2", CreatorFunction(translate_top_k_v2_op)},
@@ -132,10 +132,12 @@ OP_CONVERTER(translate_strided_slice_op);
 OP_CONVERTER(translate_sqrt_op);
 OP_CONVERTER(translate_tensor_list_from_tensor_op);
 OP_CONVERTER(translate_tensor_list_get_item_op);
+OP_CONVERTER(translate_tensor_list_length_op);
 OP_CONVERTER(translate_tensor_list_push_back_op);
 OP_CONVERTER(translate_tensor_list_reserve_op);
 OP_CONVERTER(translate_tensor_list_set_item_op);
 OP_CONVERTER(translate_tensor_list_stack_op);
+OP_CONVERTER(translate_tensor_list_resize_op);
 OP_CONVERTER(translate_tile_op);
 OP_CONVERTER_NAMED(translate_top_k_op);
 OP_CONVERTER_NAMED(translate_top_k_v2_op);
@@ -3,12 +3,26 @@
 //
 
 #include "common_op_table.hpp"
-#include "openvino/opsets/opset10.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/convert_like.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/maximum.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/scatter_update.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/unsqueeze.hpp"
 #include "utils.hpp"
 
 using namespace std;
 using namespace ov;
-using namespace opset10;
+using namespace ov::op;
 
 namespace ov {
 namespace frontend {
@@ -22,7 +36,7 @@ OutputVector translate_tensor_list_reserve_op(const NodeContext& node) {
     // all tensor elements will be saved in the flatten form in the list
     // because we want to cover a case of dynamic rank tensor list
     // the real shape of the tensor elements will be restored by TensorListStack operations
-    auto empty_constant = make_shared<Constant>(element_dtype, Shape{0, 0});
+    auto empty_constant = make_shared<v0::Constant>(element_dtype, Shape{0, 0});
     set_node_name(node.get_name(), empty_constant);
    return {empty_constant};
 }
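As a quick illustration of the representation chosen above, here is a minimal NumPy sketch (the [0, 0] shape is taken from the converter; everything else is illustrative):

import numpy as np

# TensorListReserve maps to an empty [0, 0] constant: zero elements, each kept
# in flattened form so that dynamic-rank element shapes can be handled;
# TensorListStack later restores the real element shape
reserved_list = np.zeros((0, 0), dtype=np.float32)
print(reserved_list.shape)  # (0, 0)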
@@ -41,14 +55,14 @@ OutputVector translate_tensor_list_stack_op(const NodeContext& node) {
     auto element_shape = node.get_input(1);
 
     // compute number of tensor elements in the list
-    Output<Node> num_elements = make_shared<ShapeOf>(input_handle, element::i32);
-    auto zero_const = make_shared<Constant>(element::i32, Shape{1}, 0);
-    auto one_const = make_shared<Constant>(element::i32, Shape{1}, 1);
-    num_elements = make_shared<Slice>(num_elements, zero_const, one_const, one_const);
+    Output<Node> num_elements = make_shared<v3::ShapeOf>(input_handle, element::i32);
+    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
+    auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+    num_elements = make_shared<v8::Slice>(num_elements, zero_const, one_const, one_const);
 
     // restore the real shape of tensor elements
-    auto new_shape = make_shared<Concat>(OutputVector{num_elements, element_shape}, 0);
-    auto reshape = make_shared<Reshape>(input_handle, new_shape, false);
+    auto new_shape = make_shared<v0::Concat>(OutputVector{num_elements, element_shape}, 0);
+    auto reshape = make_shared<v1::Reshape>(input_handle, new_shape, false);
 
     set_node_name(node.get_name(), reshape);
     return {reshape};
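The decomposition above is essentially a reshape of the flattened storage; a small NumPy sketch of the same computation (sample shapes are assumptions):

import numpy as np

# the list stores elements flattened: [num_elements, prod(element_shape)]
flat_list = np.arange(12, dtype=np.float32).reshape(3, 4)
element_shape = np.array([2, 2], dtype=np.int32)

# ShapeOf + Slice extract num_elements; Concat + Reshape restore the real shape
num_elements = flat_list.shape[0:1]
new_shape = np.concatenate([num_elements, element_shape], axis=0)
stacked = flat_list.reshape(new_shape)
print(stacked.shape)  # (3, 2, 2)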
@@ -62,12 +76,12 @@ OutputVector translate_tensor_list_get_item_op(const NodeContext& node) {
     auto element_dtype = node.get_attribute<element::Type>("element_dtype");
 
     // squeeze index tensor to have a scalar
-    index = make_shared<Squeeze>(index);
+    index = make_shared<v0::Squeeze>(index);
 
     // gather tensor element by the required position
-    auto gather_axis = make_shared<Constant>(element::i32, Shape{1}, 0);
-    Output<Node> tensor_element = make_shared<Gather>(input_handle, index, gather_axis);
-    tensor_element = make_shared<Convert>(tensor_element, element_dtype);
+    auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
+    Output<Node> tensor_element = make_shared<v8::Gather>(input_handle, index, gather_axis);
+    tensor_element = make_shared<v0::Convert>(tensor_element, element_dtype);
 
     set_node_name(node.get_name(), tensor_element.get_node_shared_ptr());
     return {tensor_element};
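In NumPy terms the decomposition reads as follows (shapes and values are illustrative):

import numpy as np

flat_list = np.arange(12, dtype=np.float32).reshape(3, 4)  # 3 flattened elements
index = np.array([1], dtype=np.int32)

# Squeeze turns the index into a scalar, Gather picks the row on axis 0,
# and Convert casts the result to the requested element_dtype
item = np.take(flat_list, index.squeeze(), axis=0).astype(np.float32)
print(item)  # [4. 5. 6. 7.]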
@@ -80,44 +94,44 @@ OutputVector translate_tensor_list_set_item_op(const NodeContext& node) {
     auto item = node.get_input(2);
 
     // squeeze index tensor to have a scalar
-    index = make_shared<Squeeze>(index);
+    index = make_shared<v0::Squeeze>(index);
 
     // flatten item to be inserted since
     // the tensor list saves elements in the flatten form
-    auto new_item_shape = make_shared<Constant>(element::i32, Shape{1}, -1);
-    item = make_shared<Reshape>(item, new_item_shape, false);
-    auto item_shape = make_shared<ShapeOf>(item, element::i32);
+    auto new_item_shape = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
+    item = make_shared<v1::Reshape>(item, new_item_shape, false);
+    auto item_shape = make_shared<v3::ShapeOf>(item, element::i32);
 
     // reshape the tensor list to the shape [num_elements, -1]
     // that is because in the first iteration we have empty constant of a shape [0,0]
-    auto minus_one = make_shared<Constant>(element::i32, Shape{1}, -1);
-    auto new_input_handle_shape = make_shared<Concat>(OutputVector{minus_one, item_shape}, 0);
-    input_handle = make_shared<Reshape>(input_handle, new_input_handle_shape, false);
-    input_handle = make_shared<ConvertLike>(input_handle, item);
+    auto minus_one = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
+    auto new_input_handle_shape = make_shared<v0::Concat>(OutputVector{minus_one, item_shape}, 0);
+    input_handle = make_shared<v1::Reshape>(input_handle, new_input_handle_shape, false);
+    input_handle = make_shared<v1::ConvertLike>(input_handle, item);
 
     // compute the current length of the list
-    Output<Node> list_length = make_shared<ShapeOf>(input_handle, element::i32);
-    auto zero_const = make_shared<Constant>(element::i32, Shape{1}, 0);
-    auto one_const = make_shared<Constant>(element::i32, Shape{1}, 1);
-    list_length = make_shared<Slice>(list_length, zero_const, one_const, one_const);
+    Output<Node> list_length = make_shared<v3::ShapeOf>(input_handle, element::i32);
+    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
+    auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+    list_length = make_shared<v8::Slice>(list_length, zero_const, one_const, one_const);
 
     // compute a size of the dummy tensor that serves to fill holes in the list
     // if no tensor is inserted at this position
-    auto one_const_scalar = make_shared<Constant>(element::i32, Shape{1}, 1);
-    auto index_plus_one = make_shared<Add>(index, one_const_scalar);
-    Output<Node> max_length = make_shared<Maximum>(list_length, index_plus_one);
-    Output<Node> dummy_tensor_size = make_shared<Subtract>(max_length, list_length);
+    auto one_const_scalar = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+    auto index_plus_one = make_shared<v1::Add>(index, one_const_scalar);
+    Output<Node> max_length = make_shared<v1::Maximum>(list_length, index_plus_one);
+    Output<Node> dummy_tensor_size = make_shared<v1::Subtract>(max_length, list_length);
 
     // create dummy tensor and concatenate it
     auto zero_element = create_same_type_const_scalar<int32_t>(item, 0);
-    auto dummy_tensor_shape = make_shared<Concat>(OutputVector{dummy_tensor_size, item_shape}, 0);
-    auto dummy_tensor = make_shared<Broadcast>(zero_element, dummy_tensor_shape);
-    input_handle = make_shared<Concat>(OutputVector{input_handle, dummy_tensor}, 0);
+    auto dummy_tensor_shape = make_shared<v0::Concat>(OutputVector{dummy_tensor_size, item_shape}, 0);
+    auto dummy_tensor = make_shared<v3::Broadcast>(zero_element, dummy_tensor_shape);
+    input_handle = make_shared<v0::Concat>(OutputVector{input_handle, dummy_tensor}, 0);
 
     // update the resulted tensor using ScatterUpdate
-    index = make_shared<Unsqueeze>(index, zero_const);
-    item = make_shared<Unsqueeze>(item, zero_const);
-    auto scatter_update = make_shared<ScatterUpdate>(input_handle, index, item, zero_const);
+    index = make_shared<v0::Unsqueeze>(index, zero_const);
+    item = make_shared<v0::Unsqueeze>(item, zero_const);
+    auto scatter_update = make_shared<v3::ScatterUpdate>(input_handle, index, item, zero_const);
 
     set_node_name(node.get_name(), scatter_update);
     return {scatter_update};
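The whole decomposition can be summarized by a short NumPy sketch (tensor_list_set_item is a hypothetical helper name; shapes are illustrative):

import numpy as np

def tensor_list_set_item(flat_list, index, item):
    item = item.reshape(-1)                       # flatten the item
    flat_list = flat_list.reshape(-1, item.size)  # handles the initial [0, 0] constant
    # pad with a zero "dummy" tensor when the index is beyond the current length
    pad = max(index + 1 - flat_list.shape[0], 0)
    dummy = np.zeros((pad, item.size), dtype=flat_list.dtype)
    flat_list = np.concatenate([flat_list, dummy], axis=0)
    flat_list[index] = item                       # the ScatterUpdate step
    return flat_list

empty = np.zeros((0, 0), dtype=np.float32)        # what TensorListReserve produces
result = tensor_list_set_item(empty, 2, np.ones((2, 2), np.float32))
print(result.shape)  # (3, 4)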
@@ -132,29 +146,82 @@ OutputVector translate_tensor_list_push_back_op(const NodeContext& node) {
     // the tensor list saves elements in the flatten form
     // because we want to cover a case of dynamic rank tensor list
     // the real shape of the tensor elements will be restored by TensorListStack operations
-    auto new_tensor_shape = make_shared<Constant>(element::i32, Shape{1}, -1);
-    tensor = make_shared<Reshape>(tensor, new_tensor_shape, false);
-    auto tensor_shape = make_shared<ShapeOf>(tensor, element::i32);
+    auto new_tensor_shape = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
+    tensor = make_shared<v1::Reshape>(tensor, new_tensor_shape, false);
+    auto tensor_shape = make_shared<v3::ShapeOf>(tensor, element::i32);
 
     // reshape the tensor list to the shape [num_elements, -1]
     // that is because in the first iteration we have empty constant of a shape [0,0]
-    Output<Node> num_elements = make_shared<ShapeOf>(input_handle, element::i32);
-    auto zero_const = make_shared<Constant>(element::i32, Shape{1}, 0);
-    auto one_const = make_shared<Constant>(element::i32, Shape{1}, 1);
-    num_elements = make_shared<Slice>(num_elements, zero_const, one_const, one_const);
-    auto new_input_handle_shape = make_shared<Concat>(OutputVector{num_elements, tensor_shape}, 0);
-    input_handle = make_shared<Reshape>(input_handle, new_input_handle_shape, false);
+    Output<Node> num_elements = make_shared<v3::ShapeOf>(input_handle, element::i32);
+    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
+    auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+    num_elements = make_shared<v8::Slice>(num_elements, zero_const, one_const, one_const);
+    auto new_input_handle_shape = make_shared<v0::Concat>(OutputVector{num_elements, tensor_shape}, 0);
+    input_handle = make_shared<v1::Reshape>(input_handle, new_input_handle_shape, false);
 
     // unsqueeze tensor to be inserted into the list
-    tensor = make_shared<Unsqueeze>(tensor, zero_const);
+    tensor = make_shared<v0::Unsqueeze>(tensor, zero_const);
 
     // insert the tensor into the end
-    auto updated_list = make_shared<Concat>(OutputVector{input_handle, tensor}, 0);
+    auto updated_list = make_shared<v0::Concat>(OutputVector{input_handle, tensor}, 0);
 
     set_node_name(node.get_name(), updated_list);
     return {updated_list};
 }
 
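A NumPy sketch of the push-back decomposition, assuming a hypothetical helper name and sample shapes:

import numpy as np

def tensor_list_push_back(flat_list, tensor):
    tensor = tensor.reshape(-1)                     # flatten the new element
    flat_list = flat_list.reshape(-1, tensor.size)  # handles the initial [0, 0] constant
    # Unsqueeze + Concat append the element at the end of the list
    return np.concatenate([flat_list, tensor[np.newaxis, :]], axis=0)

lst = np.zeros((0, 0), dtype=np.float32)            # TensorListReserve result
lst = tensor_list_push_back(lst, np.ones((2, 2), np.float32))
lst = tensor_list_push_back(lst, np.full((2, 2), 2.0, np.float32))
print(lst.shape)  # (2, 4)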
+OutputVector translate_tensor_list_resize_op(const NodeContext& node) {
+    default_op_checks(node, 2, {"TensorListResize"});
+    auto input_handle = node.get_input(0);
+    auto size = node.get_input(1);
+
+    // create auxiliary constants
+    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
+    auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+    auto max_const = make_shared<v0::Constant>(element::i32, Shape{1}, numeric_limits<int32_t>::max());
+
+    // compute the current length of the list and item shape
+    auto tensor_list_shape = make_shared<v3::ShapeOf>(input_handle, element::i32);
+    auto list_length = make_shared<v8::Slice>(tensor_list_shape, zero_const, one_const, one_const);
+    auto item_shape = make_shared<v8::Slice>(tensor_list_shape, one_const, max_const, one_const);
+
+    // compute a size of the dummy tensor to resize
+    // and clip it by zero if it is negative
+    Output<Node> dummy_tensor_size = make_shared<v1::Subtract>(size, list_length);
+    dummy_tensor_size = make_shared<v1::Maximum>(dummy_tensor_size, zero_const);
+
+    // create dummy tensor and concatenate it
+    auto zero_const_same_type = create_same_type_const<float>(input_handle, vector<float>{0.0f}, Shape{});
+    auto dummy_tensor_shape = make_shared<v0::Concat>(OutputVector{dummy_tensor_size, item_shape}, 0);
+    auto dummy_tensor = make_shared<v3::Broadcast>(zero_const_same_type, dummy_tensor_shape);
+    input_handle = make_shared<v0::Concat>(OutputVector{input_handle, dummy_tensor}, 0);
+
+    // reshape size to have 1D tensor with one element
+    auto new_size_shape = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+    size = make_shared<v1::Reshape>(size, new_size_shape, false);
+
+    // resize can also shrink the input tensor list
+    input_handle = make_shared<v8::Slice>(input_handle, zero_const, size, one_const);
+
+    set_node_name(node.get_name(), input_handle.get_node_shared_ptr());
+    return {input_handle};
+}
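Semantically the new converter pads the list with zero "dummy" elements when growing and slices when shrinking; a NumPy sketch under the same assumptions (hypothetical helper name):

import numpy as np

def tensor_list_resize(flat_list, size):
    length, item_size = flat_list.shape
    # grow: append a zero-filled dummy tensor (clipped at zero when shrinking)
    pad = max(size - length, 0)
    dummy = np.zeros((pad, item_size), dtype=flat_list.dtype)
    flat_list = np.concatenate([flat_list, dummy], axis=0)
    # shrink (or no-op): Slice keeps only the first `size` elements
    return flat_list[0:size]

lst = np.arange(12, dtype=np.float32).reshape(3, 4)
print(tensor_list_resize(lst, 5).shape)  # (5, 4)
print(tensor_list_resize(lst, 2).shape)  # (2, 4)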
+
+OutputVector translate_tensor_list_length_op(const NodeContext& node) {
+    default_op_checks(node, 1, {"TensorListLength"});
+    auto input_handle = node.get_input(0);
+
+    // create auxiliary constants
+    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
+    auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+
+    // compute the current length of the list
+    auto tensor_list_shape = make_shared<v3::ShapeOf>(input_handle, element::i32);
+    auto list_length = make_shared<v8::Slice>(tensor_list_shape, zero_const, one_const, one_const);
+
+    set_node_name(node.get_name(), list_length);
+    return {list_length};
+}
+
 } // namespace op
 } // namespace tensorflow
 } // namespace frontend
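TensorListLength thus reduces to taking the first dimension of the flattened storage; in NumPy terms (shapes are illustrative):

import numpy as np

flat_list = np.zeros((5, 4), dtype=np.float32)  # 5 flattened elements

# ShapeOf + Slice keep only the leading dimension: the number of elements
list_length = np.array(flat_list.shape, dtype=np.int32)[0:1]
print(list_length)  # [5]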
@@ -0,0 +1,83 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import tensorflow as tf
+from common.tf_layer_test_class import CommonTFLayerTest
+
+
+class TestTensorListLength(CommonTFLayerTest):
+    def _prepare_input(self, inputs_info):
+        assert 'x' in inputs_info
+        x_shape = inputs_info['x']
+        inputs_data = {}
+        inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type)
+        return inputs_data
+
+    def create_tensor_list_length(self, input_shape, input_type):
+        self.input_type = input_type
+        tf.compat.v1.reset_default_graph()
+        # Create the graph and model
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(input_type, input_shape, 'x')
+            tensor_list = tf.raw_ops.TensorListFromTensor(tensor=x,
+                                                          element_shape=tf.constant(input_shape[1:], dtype=tf.int32))
+            tf.raw_ops.TensorListLength(input_handle=tensor_list)
+            tf.compat.v1.global_variables_initializer()
+            tf_net = sess.graph_def
+
+        return tf_net, None
+
+    test_data_basic = [
+        dict(input_shape=[7], input_type=np.float32),
+        dict(input_shape=[10, 20], input_type=np.float32),
+        dict(input_shape=[2, 3, 4], input_type=np.int32),
+    ]
+
+    @pytest.mark.parametrize("params", test_data_basic)
+    @pytest.mark.precommit_tf_fe
+    @pytest.mark.nightly
+    def test_tensor_list_length_basic(self, params, ie_device, precision, ir_version, temp_dir,
+                                      use_new_frontend, use_old_api):
+        self._test(*self.create_tensor_list_length(**params),
+                   ie_device, precision, ir_version, temp_dir=temp_dir,
+                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
+
+
+class TestTensorListLengthEmptyList(CommonTFLayerTest):
+    def _prepare_input(self, inputs_info):
+        inputs_data = {}
+        inputs_data['tensor_list_size'] = np.array([self.tensor_list_size], dtype=np.int32)
+        return inputs_data
+
+    def create_tensor_list_length_empty_list(self, tensor_list_size, element_shape):
+        self.tensor_list_size = tensor_list_size
+        tf.compat.v1.reset_default_graph()
+        # Create the graph and model
+        with tf.compat.v1.Session() as sess:
+            tensor_list_size = tf.compat.v1.placeholder(tf.int32, [1], 'tensor_list_size')
+            tf_element_shape = tf.constant(element_shape, dtype=tf.int32)
+            tensor_shape = tf.concat([tensor_list_size, tf_element_shape], 0)
+            tensor = tf.broadcast_to(tf.constant(0.0, dtype=tf.float32), tensor_shape)
+            tensor_list = tf.raw_ops.TensorListFromTensor(tensor=tensor,
+                                                          element_shape=tf_element_shape)
+            tf.raw_ops.TensorListLength(input_handle=tensor_list)
+            tf.compat.v1.global_variables_initializer()
+            tf_net = sess.graph_def
+
+        return tf_net, None
+
+    test_data_tensor_list_length_empty_list = [
+        dict(tensor_list_size=0, element_shape=[]),
+        dict(tensor_list_size=0, element_shape=[2, 3]),
+    ]
+
+    @pytest.mark.parametrize("params", test_data_tensor_list_length_empty_list)
+    @pytest.mark.precommit_tf_fe
+    @pytest.mark.nightly
+    def test_tensor_list_length_empty_list(self, params, ie_device, precision, ir_version, temp_dir,
+                                           use_new_frontend, use_old_api):
+        self._test(*self.create_tensor_list_length_empty_list(**params),
+                   ie_device, precision, ir_version, temp_dir=temp_dir,
+                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
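The same graph can be sanity-checked eagerly against native TensorFlow before running it through the test harness (a quick check, not part of the committed tests):

import numpy as np
import tensorflow as tf

x = np.random.randint(-10, 10, [10, 20]).astype(np.float32)
tensor_list = tf.raw_ops.TensorListFromTensor(tensor=x,
                                              element_shape=tf.constant([20], dtype=tf.int32))
length = tf.raw_ops.TensorListLength(input_handle=tensor_list)
print(length.numpy())  # 10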
@@ -0,0 +1,49 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import tensorflow as tf
+from common.tf_layer_test_class import CommonTFLayerTest
+
+
+class TestTensorListResize(CommonTFLayerTest):
+    def _prepare_input(self, inputs_info):
+        assert 'x' in inputs_info
+        x_shape = inputs_info['x']
+        inputs_data = {}
+        inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type)
+        return inputs_data
+
+    def create_tensor_list_resize(self, input_shape, input_type, new_size):
+        self.input_type = input_type
+        tf.compat.v1.reset_default_graph()
+        # Create the graph and model
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(input_type, input_shape, 'x')
+            tensor_list = tf.raw_ops.TensorListFromTensor(tensor=x,
+                                                          element_shape=tf.constant(input_shape[1:], dtype=tf.int32))
+            tf_new_size = tf.constant(new_size, dtype=tf.int32)
+            tensor_list_resize = tf.raw_ops.TensorListResize(input_handle=tensor_list, size=tf_new_size)
+            element_shape = tf.constant(input_shape[1:], dtype=tf.int32)
+            tf.raw_ops.TensorListStack(input_handle=tensor_list_resize, element_shape=element_shape,
+                                       element_dtype=input_type)
+            tf.compat.v1.global_variables_initializer()
+            tf_net = sess.graph_def
+
+        return tf_net, None
+
+    test_data_basic = [
+        dict(input_shape=[7], input_type=np.float32, new_size=3),
+        dict(input_shape=[10, 20], input_type=np.float32, new_size=20),
+        dict(input_shape=[2, 3, 4], input_type=np.int32, new_size=5),
+    ]
+
+    @pytest.mark.parametrize("params", test_data_basic)
+    @pytest.mark.precommit_tf_fe
+    @pytest.mark.nightly
+    def test_tensor_list_resize_basic(self, params, ie_device, precision, ir_version, temp_dir,
+                                      use_new_frontend, use_old_api):
+        self._test(*self.create_tensor_list_resize(**params),
+                   ie_device, precision, ir_version, temp_dir=temp_dir,
+                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
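And an eager-mode check of TensorListResize itself, showing both growing and shrinking (again just a sanity check, not part of the commit; the zero filling of grown positions is the behavior the converter's dummy tensor mirrors):

import numpy as np
import tensorflow as tf

x = np.arange(12, dtype=np.float32).reshape([3, 4])
element_shape = tf.constant([4], dtype=tf.int32)
tensor_list = tf.raw_ops.TensorListFromTensor(tensor=x, element_shape=element_shape)

# grow to 5 elements: unset positions come back as zero tensors on stacking
grown = tf.raw_ops.TensorListResize(input_handle=tensor_list, size=tf.constant(5))
print(tf.raw_ops.TensorListStack(input_handle=grown, element_shape=element_shape,
                                 element_dtype=tf.float32).shape)  # (5, 4)

# shrink to 2 elements: trailing elements are dropped
shrunk = tf.raw_ops.TensorListResize(input_handle=tensor_list, size=tf.constant(2))
print(tf.raw_ops.TensorListStack(input_handle=shrunk, element_shape=element_shape,
                                 element_dtype=tf.float32).shape)  # (2, 4)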