[TF FE] Refactor ExtractImagePatches and add tests (#15456)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-02-02 18:28:31 +04:00 committed by GitHub
parent 8e5ec19d05
commit d86ba0742c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 19 deletions

View File

@ -16,41 +16,36 @@ namespace tensorflow {
namespace op {
OutputVector translate_extract_image_patches_op(const NodeContext& node) {
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 0, "ExtractImagePatches must have at least one input.");
default_op_checks(node, 1, {"ExtractImagePatches"});
auto images = node.get_input(0);
// retrieve attributes for ExtractImagePatches
auto tf_ksizes = node.get_attribute<std::vector<int64_t>>("ksizes");
auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
auto tf_rates = node.get_attribute<std::vector<int64_t>>("rates");
auto tf_padding_type = node.get_attribute<std::string>("padding");
ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
auto padding = node.get_attribute<std::string>("padding");
ov::op::PadType auto_pad = convert_tf_padding(node, padding);
TENSORFLOW_OP_VALIDATION(node,
auto_pad == ov::op::PadType::SAME_UPPER || auto_pad == ov::op::PadType::VALID,
"Only SAME_UPPER and VALID padding modes are supported for ExtractImagePatches.");
"[TensorFlow Frontend] Inconsistent model: only SAME and VALID padding modes are "
"supported for ExtractImagePatches.");
// prepare attributes for OpenVINO ExtractImagePatches
// prepare attributes for opset ExtractImagePatches
Shape sizes(2);
Shape rates(2);
Strides strides(2);
// layout for this operation is always NHWC
bool is_nhwc = true;
convert_nhwc_to_hw(is_nhwc, tf_ksizes, sizes);
convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
convert_nhwc_to_hw(is_nhwc, tf_rates, rates);
convert_nhwc_to_hw(true, tf_ksizes, sizes);
convert_nhwc_to_hw(true, tf_strides, strides);
convert_nhwc_to_hw(true, tf_rates, rates);
// prepare input to ExtractImagePatches
convert_nhwc_to_nchw(is_nhwc, images);
convert_nhwc_to_nchw(true, images);
auto extract_image_patches = make_shared<ExtractImagePatches>(images, sizes, strides, rates, auto_pad);
Output<Node> extract_image_patches = make_shared<ExtractImagePatches>(images, sizes, strides, rates, auto_pad);
convert_nchw_to_nhwc(true, extract_image_patches);
// prepare output to return the original layout NHWC
auto extract_image_patches_output = extract_image_patches->output(0);
convert_nchw_to_nhwc(is_nhwc, extract_image_patches_output);
set_node_name(node.get_name(), extract_image_patches_output.get_node_shared_ptr());
return {extract_image_patches_output};
set_node_name(node.get_name(), extract_image_patches.get_node_shared_ptr());
return {extract_image_patches};
}
} // namespace op
} // namespace tensorflow

View File

@ -0,0 +1,44 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestExtractImagePatches(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
# generate elements so that the input tensor may contain repeating elements
assert 'images' in inputs_info, "Test error: inputs_info must contain `images`"
images_shape = inputs_info['images']
inputs_data = {}
inputs_data['images'] = np.random.randint(-10, 10, images_shape).astype(np.float32)
return inputs_data
def create_extract_image_patches_net(self, images_shape, ksizes, strides, rates, padding):
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
images = tf.compat.v1.placeholder(tf.float32, images_shape, 'images')
tf.raw_ops.ExtractImagePatches(images=images, ksizes=ksizes, strides=strides, rates=rates, padding=padding)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_basic = [
# TensorFlow supports patching only across spatial dimensions
dict(images_shape=[2, 110, 50, 4], ksizes=[1, 20, 30, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1]),
dict(images_shape=[3, 30, 40, 3], ksizes=[1, 5, 10, 1], strides=[1, 3, 1, 1], rates=[1, 4, 3, 1]),
]
@pytest.mark.parametrize("params", test_basic)
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
@pytest.mark.nightly
@pytest.mark.precommit_tf_fe
def test_extract_image_patches_basic(self, params, padding, ie_device, precision, ir_version, temp_dir,
use_new_frontend,
use_old_api):
self._test(*self.create_extract_image_patches_net(**params, padding=padding),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)