[TF FE] [TF Hub] Support MaxPoolWithArgmax operation (#19085)
* [TF FE] Support MaxPoolWithArgmax operation
* Add ticket number for TS crash
* Correct error message
* Skip crashing tests
* Set additional tensor name for MaxPool

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
parent c44df9907b
commit 1939dd1df0
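For context on the operation being mapped: TensorFlow's MaxPoolWithArgmax returns both the pooled values and the flattened index of each maximum. A minimal sketch of its semantics (toy shape and values chosen for illustration, not taken from the commit):

import numpy as np
import tensorflow as tf

# one 2x2 single-channel image; a single 2x2/stride-2 pooling window covers it all
x = np.array([[[[1.0], [3.0]], [[2.0], [4.0]]]], dtype=np.float32)
pooled, argmax = tf.raw_ops.MaxPoolWithArgmax(input=x,
                                              ksize=[1, 2, 2, 1],
                                              strides=[1, 2, 2, 1],
                                              padding='VALID')
print(pooled.numpy().reshape(-1))  # [4.]
# with include_batch_in_index=False (the default), an index is (h * W + w) * C + c
# over the NHWC input; the maximum 4.0 sits at h=1, w=1, c=0 -> (1 * 2 + 1) * 1 + 0 = 3
print(argmax.numpy().reshape(-1))  # [3]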
@@ -190,6 +190,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"MaxPool", CreatorFunction(translate_max_pool_op)},
         {"MaxPoolV2", CreatorFunction(translate_max_pool_op)},
         {"MaxPool3D", CreatorFunction(translate_max_pool_op)},
+        {"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_op)},
         {"Merge", CreatorFunction(translate_merge_op)},
         {"MirrorPad", CreatorFunction(translate_mirror_pad_op)},
         {"MutableHashTable", CreatorFunction(translate_hash_table_op)},
@@ -2,12 +2,21 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
+#include "openvino/op/max_pool.hpp"
+
 #include "common_op_table.hpp"
-#include "openvino/opsets/opset8.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/subtract.hpp"
 #include "utils.hpp"
 
 using namespace std;
 using namespace ov;
+using namespace ov::op;
+using namespace ov::frontend::tensorflow;
 
 namespace ov {
@@ -17,21 +26,25 @@ namespace op {
 
 OutputVector translate_max_pool_util(const NodeContext& node,
                                      size_t spatial_dims_num,
-                                     const std::vector<int64_t>& tf_kernel_sizes,
-                                     const std::vector<int64_t>& tf_strides) {
-    default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D"});
+                                     const vector<int64_t>& tf_kernel_sizes,
+                                     const vector<int64_t>& tf_strides,
+                                     element::Type indices_element_type = element::i64,
+                                     int64_t axis = 0,
+                                     bool set_friendly_name = true,
+                                     bool with_indices = false) {
+    default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D", "MaxPoolWithArgmax"});
     TENSORFLOW_OP_VALIDATION(node,
                              spatial_dims_num == 2 || spatial_dims_num == 3,
-                             "Only MaxPool, MaxPoolV2 and MaxPool3D are supported.");
+                             "Only MaxPool, MaxPoolV2, MaxPool3D and MaxPoolWithArgmax are supported.");
     auto input = node.get_input(0);
 
-    auto tf_padding_type = node.get_attribute<std::string>("padding");
-    ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
-    auto tf_data_format = node.get_attribute<std::string>("data_format", spatial_dims_num == 2 ? "NHWC" : "NDHWC");
+    auto tf_padding_type = node.get_attribute<string>("padding");
+    PadType auto_pad = convert_tf_padding(node, tf_padding_type);
+    auto tf_data_format = node.get_attribute<string>("data_format", spatial_dims_num == 2 ? "NHWC" : "NDHWC");
 
-    auto tf_explicit_paddings = std::vector<int64_t>{};
-    if (auto_pad == ov::op::PadType::EXPLICIT) {
-        tf_explicit_paddings = node.get_attribute<std::vector<int64_t>>("explicit_paddings", {});
+    auto tf_explicit_paddings = vector<int64_t>{};
+    if (auto_pad == PadType::EXPLICIT) {
+        tf_explicit_paddings = node.get_attribute<vector<int64_t>>("explicit_paddings", {});
     }
 
     bool is_nhwc = true;
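The four new parameters expose MaxPool-8's optional second output: indices_element_type and axis are forwarded to the MaxPool node (next hunk), and axis controls whether the flattened index counts from the batch dimension (axis = 0) or restarts per sample (axis = 1). A hedged pure-Python illustration of that flattening rule for an NCHW tensor (helper name and values are mine, not the commit's):

def flatten_index_nchw(n, c, h, w, shape, axis):
    # mimics how MaxPool-8 flattens the coordinates of a maximum, per its spec
    N, C, H, W = shape
    if axis == 0:  # count from the batch dimension (batch included in the index)
        return ((n * C + c) * H + h) * W + w
    if axis == 1:  # count from the channel dimension (index restarts per sample)
        return (c * H + h) * W + w
    raise ValueError('only axis 0 and 1 are illustrated here')

# maximum at n=1, c=0, h=2, w=3 in a [2, 3, 4, 5] tensor
print(flatten_index_nchw(1, 0, 2, 3, (2, 3, 4, 5), axis=0))  # 73
print(flatten_index_nchw(1, 0, 2, 3, (2, 3, 4, 5), axis=1))  # 13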
@@ -48,40 +61,51 @@ OutputVector translate_max_pool_util(const NodeContext& node,
     }
 
     // prepare attributes for OpenVINO MaxPool operation
-    ov::Strides strides(spatial_dims_num);
-    ov::Strides dilations = (spatial_dims_num == 2 ? ov::Strides({1, 1}) : ov::Strides({1, 1, 1}));
-    ov::Shape kernel_sizes(spatial_dims_num);
-    ov::frontend::tensorflow::convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
-    ov::frontend::tensorflow::convert_nhwc_to_hw(is_nhwc, tf_kernel_sizes, kernel_sizes);
+    Strides strides(spatial_dims_num);
+    Strides dilations = (spatial_dims_num == 2 ? Strides({1, 1}) : Strides({1, 1, 1}));
+    Shape kernel_sizes(spatial_dims_num);
+    convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
+    convert_nhwc_to_hw(is_nhwc, tf_kernel_sizes, kernel_sizes);
 
-    ov::CoordinateDiff pads_begin;
-    ov::CoordinateDiff pads_end;
-    if (auto_pad == ov::op::PadType::EXPLICIT) {
+    CoordinateDiff pads_begin;
+    CoordinateDiff pads_end;
+    if (auto_pad == PadType::EXPLICIT) {
         fill_explicit_pads_vectors(node, is_nhwc, spatial_dims_num, tf_explicit_paddings, pads_begin, pads_end);
     }
 
     // prepare input to MaxPool
-    convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dims_num + 2));
+    convert_nhwc_to_nchw(is_nhwc, input, Rank(spatial_dims_num + 2));
 
-    auto max_pool_node = std::make_shared<ov::opset8::MaxPool>(input,
-                                                               strides,
-                                                               dilations,
-                                                               ov::Shape(pads_begin.begin(), pads_begin.end()),
-                                                               ov::Shape(pads_end.begin(), pads_end.end()),
-                                                               kernel_sizes,
-                                                               ov::op::RoundingType::FLOOR,
-                                                               auto_pad);
+    auto max_pool_node = make_shared<v8::MaxPool>(input,
+                                                  strides,
+                                                  dilations,
+                                                  Shape(pads_begin.begin(), pads_begin.end()),
+                                                  Shape(pads_end.begin(), pads_end.end()),
+                                                  kernel_sizes,
+                                                  RoundingType::FLOOR,
+                                                  auto_pad,
+                                                  indices_element_type,
+                                                  axis);
     auto max_pool = max_pool_node->output(0);
-    ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, max_pool, ov::Rank(spatial_dims_num + 2));
-    ov::frontend::tensorflow::set_node_name(node.get_name(), max_pool.get_node_shared_ptr());
+    convert_nchw_to_nhwc(is_nhwc, max_pool, Rank(spatial_dims_num + 2));
+    if (set_friendly_name) {
+        set_node_name(node.get_name(), max_pool.get_node_shared_ptr());
+    } else {
+        set_out_name(node.get_name() + ":0", max_pool);
+    }
+
+    if (with_indices) {
+        auto output_indices = max_pool_node->output(1);
+        return OutputVector{max_pool, output_indices};
+    }
     return {max_pool};
 }
 
 OutputVector translate_max_pool(const NodeContext& node, size_t spatial_dims_num) {
     // MaxPool2D and MaxPool3D have ksize and strides as attributes
     // retrieve attributes
-    auto strides = node.get_attribute<std::vector<int64_t>>("strides");
-    auto kernel_sizes = node.get_attribute<std::vector<int64_t>>("ksize");
+    auto strides = node.get_attribute<vector<int64_t>>("strides");
+    auto kernel_sizes = node.get_attribute<vector<int64_t>>("ksize");
     return translate_max_pool_util(node, spatial_dims_num, kernel_sizes, strides);
 }
@@ -104,6 +128,81 @@ OutputVector translate_max_pool_v2(const NodeContext& node) {
     return translate_max_pool_util(node, 2, ksize_vector, strides_vector);
 }
 
+OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
+    // MaxPoolWithArgmax has just one input; ksize and strides are attributes
+    TENSORFLOW_OP_VALIDATION(node,
+                             node.get_input_size() > 0,
+                             "MaxPoolWithArgmax operation must have at least one input.");
+    auto include_batch_in_index = node.get_attribute<bool>("include_batch_in_index", false);
+    auto targmax = node.get_attribute<element::Type>("Targmax", element::i64);
+    auto ksize = node.get_attribute<vector<int64_t>>("ksize");
+    auto strides = node.get_attribute<vector<int64_t>>("strides");
+    auto images = node.get_input(0);
+    auto node_name = node.get_name();
+
+    // indicates from which dimension to count output indices
+    int64_t axis = include_batch_in_index ? 0 : 1;
+
+    auto max_pool_with_indices = translate_max_pool_util(node, 2, ksize, strides, targmax, axis, false, true);
+    TENSORFLOW_OP_VALIDATION(node,
+                             max_pool_with_indices.size() == 2,
+                             "[TensorFlow Frontend] internal error: expect two outputs for MaxPoolWithArgmax.");
+    auto max_pool = max_pool_with_indices[0];
+    auto output_indices_nchw = max_pool_with_indices[1];
+
+    auto tf_data_format = node.get_attribute<string>("data_format", "NHWC");
+    Output<Node> output_indices;
+    if (tf_data_format != "NHWC") {
+        output_indices = output_indices_nchw;
+    } else {
+        // adjust output indices to the NHWC layout
+        // since they are originally computed for the NCHW layout
+        // 1. compute all dimensions N, H, W, C
+        auto images_shape = make_shared<v3::ShapeOf>(images, targmax);
+        auto const_zero = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
+        auto const_one = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
+        auto const_two = make_shared<v0::Constant>(element::i32, Shape{1}, 2);
+        auto const_three = make_shared<v0::Constant>(element::i32, Shape{1}, 3);
+        auto N = make_shared<v8::Gather>(images_shape, const_zero, const_zero);
+        auto H = make_shared<v8::Gather>(images_shape, const_one, const_zero);
+        auto W = make_shared<v8::Gather>(images_shape, const_two, const_zero);
+        auto C = make_shared<v8::Gather>(images_shape, const_three, const_zero);
+
+        // 2. decompose the NCHW index into the complex index (n, h, w, c)
+        auto HW = make_shared<v1::Multiply>(H, W);
+        Output<Node> n;
+        if (include_batch_in_index) {
+            auto CHW = make_shared<v1::Multiply>(C, HW);
+            n = make_shared<v1::Divide>(output_indices_nchw, CHW);
+            auto nCHW = make_shared<v1::Multiply>(n, CHW);
+            output_indices_nchw = make_shared<v1::Subtract>(output_indices_nchw, nCHW);
+        } else {
+            n = make_shared<v0::Constant>(targmax, Shape{1}, 0);
+        }
+        auto c = make_shared<v1::Divide>(output_indices_nchw, HW);
+        auto cHW = make_shared<v1::Multiply>(c, HW);
+        output_indices_nchw = make_shared<v1::Subtract>(output_indices_nchw, cHW);
+        auto h = make_shared<v1::Divide>(output_indices_nchw, W);
+        auto hW = make_shared<v1::Multiply>(h, W);
+        auto w = make_shared<v1::Subtract>(output_indices_nchw, hW);
+
+        // 3. transform the complex index into flattened form for the NHWC layout
+        auto WC = make_shared<v1::Multiply>(W, C);
+        auto HWC = make_shared<v1::Multiply>(H, WC);
+        output_indices = make_shared<v1::Multiply>(n, HWC);
+        auto hWC = make_shared<v1::Multiply>(h, WC);
+        output_indices = make_shared<v1::Add>(output_indices, hWC);
+        auto wC = make_shared<v1::Multiply>(w, C);
+        output_indices = make_shared<v1::Add>(output_indices, wC);
+        output_indices = make_shared<v1::Add>(output_indices, c);
+        convert_nchw_to_nhwc(true, output_indices, 4);
+    }
+
+    set_out_name(node_name + ":1", output_indices);
+    return {max_pool, output_indices};
+}
+
 OutputVector translate_max_pool_op(const NodeContext& node) {
     if (node.get_op_type() == "MaxPool") {
         return translate_max_pool(node, 2);
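The decompose/recompose chain above is the heart of the translator: MaxPool-8 reports indices flattened for the NCHW layout, while TF expects them flattened for NHWC. A NumPy sketch of the same arithmetic for one per-sample index (include_batch_in_index=False, so axis = 1 and the batch term n is zero; shapes and values are illustrative):

import numpy as np

N, H, W, C = 1, 2, 2, 3   # NHWC dimensions of the input image
idx_nchw = 7              # per-sample NCHW index, i.e. c * H * W + h * W + w

# 1. decompose the NCHW index into (c, h, w), matching the Divide/Multiply/Subtract ops
c, rem = divmod(idx_nchw, H * W)
h, w = divmod(rem, W)

# 2. recompose as the flattened NHWC index: (h * W + w) * C + c
idx_nhwc = (h * W + w) * C + c
print(c, h, w, idx_nhwc)  # 1 1 1 10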
@@ -111,6 +210,8 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
         return translate_max_pool_v2(node);
     } else if (node.get_op_type() == "MaxPool3D") {
         return translate_max_pool(node, 3);
+    } else if (node.get_op_type() == "MaxPoolWithArgmax") {
+        return translate_max_pool_with_argmax(node);
     } else {
-        TENSORFLOW_OP_VALIDATION(node, false, "Only MaxPool2D, MaxPoolV2 and MaxPool3D are supported.");
+        TENSORFLOW_OP_VALIDATION(node, false, "Only MaxPool, MaxPoolV2, MaxPool3D and MaxPoolWithArgmax are supported.");
     }
@@ -42,6 +42,7 @@ PadType convert_tf_padding(const frontend::NodeContext& node, const string& tf_p
         "MaxPool",
         "MaxPoolV2",
         "MaxPool3D",
+        "MaxPoolWithArgmax",
         "ExtractImagePatches",
         "DepthwiseConv2dNative",
         "AvgPool",
@@ -68,8 +69,8 @@ PadType convert_tf_padding(const frontend::NodeContext& node, const string& tf_p
             return PadType::SAME_LOWER;
         }
     } else if (op_type == "Conv2D" || op_type == "Conv3D" || op_type == "MaxPool" || op_type == "MaxPoolV2" ||
-               op_type == "MaxPool3D" || op_type == "ExtractImagePatches" || op_type == "DepthwiseConv2dNative" ||
-               op_type == "AvgPool" || op_type == "AvgPool3D") {
+               op_type == "MaxPool3D" || op_type == "MaxPoolWithArgmax" || op_type == "ExtractImagePatches" ||
+               op_type == "DepthwiseConv2dNative" || op_type == "AvgPool" || op_type == "AvgPool3D") {
         if (tf_padding == "SAME") {
             // According to the formulas for calculating auto_pad values of the
             // Conv layer in the Operation specification,
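The SAME branch this hunk leads into follows TF's padding arithmetic, which OpenVINO expresses as SAME_UPPER (extra padding placed at the end). A small illustrative sketch of that formula for one spatial dimension (pure Python, not code from the commit):

import math

def same_padding_1d(in_size, kernel, stride):
    # TF SAME: output size is ceil(in / stride); pad just enough to reach it
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    return total // 2, total - total // 2  # pads_begin, pads_end

print(same_padding_1d(10, 3, 2))  # (0, 1) -> output size 5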
@@ -0,0 +1,75 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import tensorflow as tf
+from common.tf_layer_test_class import CommonTFLayerTest
+
+
+class TestMaxPoolWithArgmax(CommonTFLayerTest):
+    def _prepare_input(self, inputs_info):
+        assert 'input' in inputs_info
+        input_shape = inputs_info['input']
+        inputs_data = {}
+        inputs_data['input'] = np.random.randint(-5, 5, input_shape).astype(self.input_type)
+        return inputs_data
+
+    def create_max_pool_with_argmax_net(self, input_shape, ksize, strides, input_type, padding, targmax,
+                                        include_batch_in_index, with_second_output):
+        self.input_type = input_type
+        tf.compat.v1.reset_default_graph()
+        # Create the graph and model
+        with tf.compat.v1.Session() as sess:
+            input = tf.compat.v1.placeholder(input_type, input_shape, 'input')
+            max_pool_with_argmax = tf.raw_ops.MaxPoolWithArgmax(input=input, ksize=ksize, strides=strides,
+                                                                padding=padding, Targmax=targmax,
+                                                                include_batch_in_index=include_batch_in_index)
+            tf.identity(max_pool_with_argmax[0], name='max_pool')
+            if with_second_output:
+                tf.identity(max_pool_with_argmax[1], name='output_indices')
+            tf.compat.v1.global_variables_initializer()
+            tf_net = sess.graph_def
+
+        return tf_net, None
+
+    test_data_basic = [
+        dict(input_shape=[1, 25, 24, 3],
+             ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1]),
+        dict(input_shape=[1, 10, 20, 3],
+             ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1]),
+    ]
+
+    @pytest.mark.parametrize("params", test_data_basic)
+    @pytest.mark.parametrize("input_type", [
+        np.float32, np.int32
+    ])
+    @pytest.mark.parametrize("padding", [
+        'VALID', 'SAME'
+    ])
+    @pytest.mark.parametrize("targmax", [
+        tf.int32, tf.int64
+    ])
+    @pytest.mark.parametrize("include_batch_in_index", [
+        True, False
+    ])
+    @pytest.mark.parametrize("with_second_output", [
+        pytest.param(
+            True,
+            marks=pytest.mark.skip(reason="117415: TransposeSinking crash")
+        ),
+        False
+    ])
+    @pytest.mark.precommit_tf_fe
+    @pytest.mark.nightly
+    def test_max_pool_with_argmax_basic(self, params, input_type, padding, targmax,
+                                        include_batch_in_index, with_second_output,
+                                        ie_device, precision, ir_version, temp_dir,
+                                        use_new_frontend, use_old_api):
+        self._test(
+            *self.create_max_pool_with_argmax_net(**params, input_type=input_type, padding=padding, targmax=targmax,
+                                                  include_batch_in_index=include_batch_in_index,
+                                                  with_second_output=with_second_output),
+            ie_device, precision, ir_version, temp_dir=temp_dir,
+            use_new_frontend=use_new_frontend, use_old_api=use_old_api)