[TF FE] Refactor ExtractImagePatches and add tests (#15456)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
8e5ec19d05
commit
d86ba0742c
@ -16,41 +16,36 @@ namespace tensorflow {
|
|||||||
namespace op {
|
namespace op {
|
||||||
|
|
||||||
OutputVector translate_extract_image_patches_op(const NodeContext& node) {
|
OutputVector translate_extract_image_patches_op(const NodeContext& node) {
|
||||||
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 0, "ExtractImagePatches must have at least one input.");
|
default_op_checks(node, 1, {"ExtractImagePatches"});
|
||||||
auto images = node.get_input(0);
|
auto images = node.get_input(0);
|
||||||
|
|
||||||
// retrieve attributes for ExtractImagePatches
|
// retrieve attributes for ExtractImagePatches
|
||||||
auto tf_ksizes = node.get_attribute<std::vector<int64_t>>("ksizes");
|
auto tf_ksizes = node.get_attribute<std::vector<int64_t>>("ksizes");
|
||||||
auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
|
auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
|
||||||
auto tf_rates = node.get_attribute<std::vector<int64_t>>("rates");
|
auto tf_rates = node.get_attribute<std::vector<int64_t>>("rates");
|
||||||
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
auto padding = node.get_attribute<std::string>("padding");
|
||||||
ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
|
ov::op::PadType auto_pad = convert_tf_padding(node, padding);
|
||||||
TENSORFLOW_OP_VALIDATION(node,
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
auto_pad == ov::op::PadType::SAME_UPPER || auto_pad == ov::op::PadType::VALID,
|
auto_pad == ov::op::PadType::SAME_UPPER || auto_pad == ov::op::PadType::VALID,
|
||||||
"Only SAME_UPPER and VALID padding modes are supported for ExtractImagePatches.");
|
"[TensorFlow Frontend] Inconsistent model: only SAME and VALID padding modes are "
|
||||||
|
"supported for ExtractImagePatches.");
|
||||||
|
|
||||||
// prepare attributes for OpenVINO ExtractImagePatches
|
// prepare attributes for opset ExtractImagePatches
|
||||||
Shape sizes(2);
|
Shape sizes(2);
|
||||||
Shape rates(2);
|
Shape rates(2);
|
||||||
Strides strides(2);
|
Strides strides(2);
|
||||||
|
convert_nhwc_to_hw(true, tf_ksizes, sizes);
|
||||||
// layout for this operation is always NHWC
|
convert_nhwc_to_hw(true, tf_strides, strides);
|
||||||
bool is_nhwc = true;
|
convert_nhwc_to_hw(true, tf_rates, rates);
|
||||||
convert_nhwc_to_hw(is_nhwc, tf_ksizes, sizes);
|
|
||||||
convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
|
|
||||||
convert_nhwc_to_hw(is_nhwc, tf_rates, rates);
|
|
||||||
|
|
||||||
// prepare input to ExtractImagePatches
|
// prepare input to ExtractImagePatches
|
||||||
convert_nhwc_to_nchw(is_nhwc, images);
|
convert_nhwc_to_nchw(true, images);
|
||||||
|
|
||||||
auto extract_image_patches = make_shared<ExtractImagePatches>(images, sizes, strides, rates, auto_pad);
|
Output<Node> extract_image_patches = make_shared<ExtractImagePatches>(images, sizes, strides, rates, auto_pad);
|
||||||
|
convert_nchw_to_nhwc(true, extract_image_patches);
|
||||||
|
|
||||||
// prepare output to return the original layout NHWC
|
set_node_name(node.get_name(), extract_image_patches.get_node_shared_ptr());
|
||||||
auto extract_image_patches_output = extract_image_patches->output(0);
|
return {extract_image_patches};
|
||||||
convert_nchw_to_nhwc(is_nhwc, extract_image_patches_output);
|
|
||||||
|
|
||||||
set_node_name(node.get_name(), extract_image_patches_output.get_node_shared_ptr());
|
|
||||||
return {extract_image_patches_output};
|
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import tensorflow as tf
|
||||||
|
from common.tf_layer_test_class import CommonTFLayerTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractImagePatches(CommonTFLayerTest):
|
||||||
|
def _prepare_input(self, inputs_info):
|
||||||
|
# generate elements so that the input tensor may contain repeating elements
|
||||||
|
assert 'images' in inputs_info, "Test error: inputs_info must contain `images`"
|
||||||
|
images_shape = inputs_info['images']
|
||||||
|
inputs_data = {}
|
||||||
|
inputs_data['images'] = np.random.randint(-10, 10, images_shape).astype(np.float32)
|
||||||
|
return inputs_data
|
||||||
|
|
||||||
|
def create_extract_image_patches_net(self, images_shape, ksizes, strides, rates, padding):
|
||||||
|
tf.compat.v1.reset_default_graph()
|
||||||
|
with tf.compat.v1.Session() as sess:
|
||||||
|
images = tf.compat.v1.placeholder(tf.float32, images_shape, 'images')
|
||||||
|
tf.raw_ops.ExtractImagePatches(images=images, ksizes=ksizes, strides=strides, rates=rates, padding=padding)
|
||||||
|
tf.compat.v1.global_variables_initializer()
|
||||||
|
tf_net = sess.graph_def
|
||||||
|
|
||||||
|
return tf_net, None
|
||||||
|
|
||||||
|
test_basic = [
|
||||||
|
# TensorFlow supports patching only across spatial dimensions
|
||||||
|
dict(images_shape=[2, 110, 50, 4], ksizes=[1, 20, 30, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1]),
|
||||||
|
dict(images_shape=[3, 30, 40, 3], ksizes=[1, 5, 10, 1], strides=[1, 3, 1, 1], rates=[1, 4, 3, 1]),
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("params", test_basic)
|
||||||
|
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit_tf_fe
|
||||||
|
def test_extract_image_patches_basic(self, params, padding, ie_device, precision, ir_version, temp_dir,
|
||||||
|
use_new_frontend,
|
||||||
|
use_old_api):
|
||||||
|
self._test(*self.create_extract_image_patches_net(**params, padding=padding),
|
||||||
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
Loading…
Reference in New Issue
Block a user