[TF FE] Refactor ExtractImagePatches and add tests (#15456)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
8e5ec19d05
commit
d86ba0742c
@ -16,41 +16,36 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_extract_image_patches_op(const NodeContext& node) {
|
||||
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 0, "ExtractImagePatches must have at least one input.");
|
||||
default_op_checks(node, 1, {"ExtractImagePatches"});
|
||||
auto images = node.get_input(0);
|
||||
|
||||
// retrieve attributes for ExtractImagePatches
|
||||
auto tf_ksizes = node.get_attribute<std::vector<int64_t>>("ksizes");
|
||||
auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
|
||||
auto tf_rates = node.get_attribute<std::vector<int64_t>>("rates");
|
||||
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
||||
ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
|
||||
auto padding = node.get_attribute<std::string>("padding");
|
||||
ov::op::PadType auto_pad = convert_tf_padding(node, padding);
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
auto_pad == ov::op::PadType::SAME_UPPER || auto_pad == ov::op::PadType::VALID,
|
||||
"Only SAME_UPPER and VALID padding modes are supported for ExtractImagePatches.");
|
||||
"[TensorFlow Frontend] Inconsistent model: only SAME and VALID padding modes are "
|
||||
"supported for ExtractImagePatches.");
|
||||
|
||||
// prepare attributes for OpenVINO ExtractImagePatches
|
||||
// prepare attributes for opset ExtractImagePatches
|
||||
Shape sizes(2);
|
||||
Shape rates(2);
|
||||
Strides strides(2);
|
||||
|
||||
// layout for this operation is always NHWC
|
||||
bool is_nhwc = true;
|
||||
convert_nhwc_to_hw(is_nhwc, tf_ksizes, sizes);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_rates, rates);
|
||||
convert_nhwc_to_hw(true, tf_ksizes, sizes);
|
||||
convert_nhwc_to_hw(true, tf_strides, strides);
|
||||
convert_nhwc_to_hw(true, tf_rates, rates);
|
||||
|
||||
// prepare input to ExtractImagePatches
|
||||
convert_nhwc_to_nchw(is_nhwc, images);
|
||||
convert_nhwc_to_nchw(true, images);
|
||||
|
||||
auto extract_image_patches = make_shared<ExtractImagePatches>(images, sizes, strides, rates, auto_pad);
|
||||
Output<Node> extract_image_patches = make_shared<ExtractImagePatches>(images, sizes, strides, rates, auto_pad);
|
||||
convert_nchw_to_nhwc(true, extract_image_patches);
|
||||
|
||||
// prepare output to return the original layout NHWC
|
||||
auto extract_image_patches_output = extract_image_patches->output(0);
|
||||
convert_nchw_to_nhwc(is_nhwc, extract_image_patches_output);
|
||||
|
||||
set_node_name(node.get_name(), extract_image_patches_output.get_node_shared_ptr());
|
||||
return {extract_image_patches_output};
|
||||
set_node_name(node.get_name(), extract_image_patches.get_node_shared_ptr());
|
||||
return {extract_image_patches};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,44 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestExtractImagePatches(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_info):
|
||||
# generate elements so that the input tensor may contain repeating elements
|
||||
assert 'images' in inputs_info, "Test error: inputs_info must contain `images`"
|
||||
images_shape = inputs_info['images']
|
||||
inputs_data = {}
|
||||
inputs_data['images'] = np.random.randint(-10, 10, images_shape).astype(np.float32)
|
||||
return inputs_data
|
||||
|
||||
def create_extract_image_patches_net(self, images_shape, ksizes, strides, rates, padding):
|
||||
tf.compat.v1.reset_default_graph()
|
||||
with tf.compat.v1.Session() as sess:
|
||||
images = tf.compat.v1.placeholder(tf.float32, images_shape, 'images')
|
||||
tf.raw_ops.ExtractImagePatches(images=images, ksizes=ksizes, strides=strides, rates=rates, padding=padding)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_basic = [
|
||||
# TensorFlow supports patching only across spatial dimensions
|
||||
dict(images_shape=[2, 110, 50, 4], ksizes=[1, 20, 30, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1]),
|
||||
dict(images_shape=[3, 30, 40, 3], ksizes=[1, 5, 10, 1], strides=[1, 3, 1, 1], rates=[1, 4, 3, 1]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_basic)
|
||||
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit_tf_fe
|
||||
def test_extract_image_patches_basic(self, params, padding, ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend,
|
||||
use_old_api):
|
||||
self._test(*self.create_extract_image_patches_net(**params, padding=padding),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
Loading…
Reference in New Issue
Block a user