[TF FE] [TF Hub] Support MaxPoolWithArgmax operation (#19085)

* [TF FE] Support MaxPoolWithArgmax operation

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Add ticket number for TS crash

* Correct error message

* Skip crashing tests

* Set additional tensor name for MaxPool

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev 2023-08-09 20:18:27 +04:00 committed by GitHub
parent c44df9907b
commit 1939dd1df0
4 changed files with 212 additions and 34 deletions
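For context on the newly supported operation: tf.raw_ops.MaxPoolWithArgmax performs NHWC max pooling and additionally returns the flattened index of each selected maximum. A minimal, illustrative TensorFlow 2.x sketch (shapes and values are examples only, not part of this commit):

import numpy as np
import tensorflow as tf

# MaxPoolWithArgmax produces two outputs: the pooled values and the flattened
# (argmax) indices of the chosen elements in the NHWC input tensor.
x = np.random.rand(1, 4, 4, 3).astype(np.float32)  # NHWC input
pooled, argmax = tf.raw_ops.MaxPoolWithArgmax(
    input=x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="VALID",
    Targmax=tf.int64,
    include_batch_in_index=False,
)
print(pooled.shape)  # (1, 2, 2, 3)
print(argmax.shape)  # (1, 2, 2, 3), flattened NHWC indices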


@@ -190,6 +190,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"MaxPool", CreatorFunction(translate_max_pool_op)},
{"MaxPoolV2", CreatorFunction(translate_max_pool_op)},
{"MaxPool3D", CreatorFunction(translate_max_pool_op)},
{"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_op)},
{"Merge", CreatorFunction(translate_merge_op)},
{"MirrorPad", CreatorFunction(translate_mirror_pad_op)},
{"MutableHashTable", CreatorFunction(translate_hash_table_op)},


@@ -2,12 +2,21 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/max_pool.hpp"
#include "common_op_table.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::op;
using namespace ov::frontend::tensorflow;
namespace ov {
@@ -17,21 +26,25 @@ namespace op {
OutputVector translate_max_pool_util(const NodeContext& node,
size_t spatial_dims_num,
const std::vector<int64_t>& tf_kernel_sizes,
const std::vector<int64_t>& tf_strides) {
default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D"});
const vector<int64_t>& tf_kernel_sizes,
const vector<int64_t>& tf_strides,
element::Type indices_element_type = element::i64,
int64_t axis = 0,
bool set_friendly_name = true,
bool with_indices = false) {
default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D", "MaxPoolWithArgmax"});
TENSORFLOW_OP_VALIDATION(node,
spatial_dims_num == 2 || spatial_dims_num == 3,
"Only MaxPool, MaxPoolV2 and MaxPool3D are supported.");
"Only MaxPool, MaxPoolV2, MaxPool3D and MaxPoolWithArgmax are supported.");
auto input = node.get_input(0);
auto tf_padding_type = node.get_attribute<std::string>("padding");
ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
auto tf_data_format = node.get_attribute<std::string>("data_format", spatial_dims_num == 2 ? "NHWC" : "NDHWC");
auto tf_padding_type = node.get_attribute<string>("padding");
PadType auto_pad = convert_tf_padding(node, tf_padding_type);
auto tf_data_format = node.get_attribute<string>("data_format", spatial_dims_num == 2 ? "NHWC" : "NDHWC");
auto tf_explicit_paddings = std::vector<int64_t>{};
if (auto_pad == ov::op::PadType::EXPLICIT) {
tf_explicit_paddings = node.get_attribute<std::vector<int64_t>>("explicit_paddings", {});
auto tf_explicit_paddings = vector<int64_t>{};
if (auto_pad == PadType::EXPLICIT) {
tf_explicit_paddings = node.get_attribute<vector<int64_t>>("explicit_paddings", {});
}
bool is_nhwc = true;
@@ -48,40 +61,51 @@ OutputVector translate_max_pool_util(const NodeContext& node,
}
// prepare attributes for OpenVINO MaxPool operation
ov::Strides strides(spatial_dims_num);
ov::Strides dilations = (spatial_dims_num == 2 ? ov::Strides({1, 1}) : ov::Strides({1, 1, 1}));
ov::Shape kernel_sizes(spatial_dims_num);
ov::frontend::tensorflow::convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
ov::frontend::tensorflow::convert_nhwc_to_hw(is_nhwc, tf_kernel_sizes, kernel_sizes);
Strides strides(spatial_dims_num);
Strides dilations = (spatial_dims_num == 2 ? Strides({1, 1}) : Strides({1, 1, 1}));
Shape kernel_sizes(spatial_dims_num);
convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
convert_nhwc_to_hw(is_nhwc, tf_kernel_sizes, kernel_sizes);
ov::CoordinateDiff pads_begin;
ov::CoordinateDiff pads_end;
if (auto_pad == ov::op::PadType::EXPLICIT) {
CoordinateDiff pads_begin;
CoordinateDiff pads_end;
if (auto_pad == PadType::EXPLICIT) {
fill_explicit_pads_vectors(node, is_nhwc, spatial_dims_num, tf_explicit_paddings, pads_begin, pads_end);
}
// prepare input to MaxPool
convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dims_num + 2));
convert_nhwc_to_nchw(is_nhwc, input, Rank(spatial_dims_num + 2));
auto max_pool_node = std::make_shared<ov::opset8::MaxPool>(input,
strides,
dilations,
ov::Shape(pads_begin.begin(), pads_begin.end()),
ov::Shape(pads_end.begin(), pads_end.end()),
kernel_sizes,
ov::op::RoundingType::FLOOR,
auto_pad);
auto max_pool_node = make_shared<v8::MaxPool>(input,
strides,
dilations,
Shape(pads_begin.begin(), pads_begin.end()),
Shape(pads_end.begin(), pads_end.end()),
kernel_sizes,
RoundingType::FLOOR,
auto_pad,
indices_element_type,
axis);
auto max_pool = max_pool_node->output(0);
ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, max_pool, ov::Rank(spatial_dims_num + 2));
ov::frontend::tensorflow::set_node_name(node.get_name(), max_pool.get_node_shared_ptr());
convert_nchw_to_nhwc(is_nhwc, max_pool, Rank(spatial_dims_num + 2));
if (set_friendly_name) {
set_node_name(node.get_name(), max_pool.get_node_shared_ptr());
} else {
set_out_name(node.get_name() + ":0", max_pool);
}
if (with_indices) {
auto output_indices = max_pool_node->output(1);
return OutputVector{max_pool, output_indices};
}
return {max_pool};
}
OutputVector translate_max_pool(const NodeContext& node, size_t spatial_dims_num) {
// MaxPool2D and MaxPool3D have ksize and strides as attributes
// retrieve attributes
auto strides = node.get_attribute<std::vector<int64_t>>("strides");
auto kernel_sizes = node.get_attribute<std::vector<int64_t>>("ksize");
auto strides = node.get_attribute<vector<int64_t>>("strides");
auto kernel_sizes = node.get_attribute<vector<int64_t>>("ksize");
return translate_max_pool_util(node, spatial_dims_num, kernel_sizes, strides);
}
@@ -104,6 +128,81 @@ OutputVector translate_max_pool_v2(const NodeContext& node) {
return translate_max_pool_util(node, 2, ksize_vector, strides_vector);
}
OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
// MaxPoolWithArgmax has just one input. ksize and strides are attributes
TENSORFLOW_OP_VALIDATION(node,
node.get_input_size() > 0,
"MaxPoolWithArgmax operation must have at least one input.");
auto include_batch_in_index = node.get_attribute<bool>("include_batch_in_index", false);
auto targmax = node.get_attribute<element::Type>("Targmax", element::i64);
auto ksize = node.get_attribute<vector<int64_t>>("ksize");
auto strides = node.get_attribute<vector<int64_t>>("strides");
auto images = node.get_input(0);
auto node_name = node.get_name();
// indices from which dimension to count output indices
int64_t axis = include_batch_in_index ? 0 : 1;
auto max_pool_with_indices = translate_max_pool_util(node, 2, ksize, strides, targmax, axis, false, true);
TENSORFLOW_OP_VALIDATION(node,
max_pool_with_indices.size() == 2,
"[TensorFlow Frontend] internal error: expect two outputs for MaxPoolWithArgmax.");
auto max_pool = max_pool_with_indices[0];
auto output_indices_nchw = max_pool_with_indices[1];
auto tf_data_format = node.get_attribute<string>("data_format", "NHWC");
Output<Node> output_indices;
if (tf_data_format != "NHWC") {
output_indices = output_indices_nchw;
} else {
output_indices = output_indices_nchw;
// adjust output indices to have them for NHWC layout
// now it is computed for NCHW layout
// 1. compute all dimensions N, H, W, C
auto images_shape = make_shared<v3::ShapeOf>(images, targmax);
auto const_zero = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto const_one = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
auto const_two = make_shared<v0::Constant>(element::i32, Shape{1}, 2);
auto const_three = make_shared<v0::Constant>(element::i32, Shape{1}, 3);
auto N = make_shared<v8::Gather>(images_shape, const_zero, const_zero);
auto H = make_shared<v8::Gather>(images_shape, const_one, const_zero);
auto W = make_shared<v8::Gather>(images_shape, const_two, const_zero);
auto C = make_shared<v8::Gather>(images_shape, const_three, const_zero);
// 2. compute complex index for NCHW layout, i.e. n, h, w, c
auto HW = make_shared<v1::Multiply>(H, W);
Output<Node> n;
if (include_batch_in_index) {
auto CHW = make_shared<v1::Multiply>(C, HW);
n = make_shared<v1::Divide>(output_indices_nchw, CHW);
auto nCHW = make_shared<v1::Multiply>(n, CHW);
output_indices_nchw = make_shared<v1::Subtract>(output_indices_nchw, nCHW);
} else {
n = make_shared<v0::Constant>(targmax, Shape{1}, 0);
}
auto c = make_shared<v1::Divide>(output_indices_nchw, HW);
auto cHW = make_shared<v1::Multiply>(c, HW);
output_indices_nchw = make_shared<v1::Subtract>(output_indices_nchw, cHW);
auto h = make_shared<v1::Divide>(output_indices_nchw, W);
auto hW = make_shared<v1::Multiply>(h, W);
auto w = make_shared<v1::Subtract>(output_indices_nchw, hW);
// transform them into flatten form for NHWC layout
auto WC = make_shared<v1::Multiply>(W, C);
auto HWC = make_shared<v1::Multiply>(H, WC);
output_indices = make_shared<v1::Multiply>(n, HWC);
auto hWC = make_shared<v1::Multiply>(h, WC);
output_indices = make_shared<v1::Add>(output_indices, hWC);
auto wC = make_shared<v1::Multiply>(w, C);
output_indices = make_shared<v1::Add>(output_indices, wC);
output_indices = make_shared<v1::Add>(output_indices, c);
convert_nchw_to_nhwc(true, output_indices, 4);
}
set_out_name(node_name + ":1", output_indices);
return {max_pool, output_indices};
}
OutputVector translate_max_pool_op(const NodeContext& node) {
if (node.get_op_type() == "MaxPool") {
return translate_max_pool(node, 2);
@@ -111,6 +210,8 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
return translate_max_pool_v2(node);
} else if (node.get_op_type() == "MaxPool3D") {
return translate_max_pool(node, 3);
} else if (node.get_op_type() == "MaxPoolWithArgmax") {
return translate_max_pool_with_argmax(node);
} else {
TENSORFLOW_OP_VALIDATION(node, false, "Only MaxPool2D, MaxPoolV2, MaxPool3D and MaxPoolWithArgmax are supported.");
}
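A note on the index adjustment in translate_max_pool_with_argmax above: MaxPool-8 computes its argmax over the input after it has been converted to NCHW, while TensorFlow expects indices flattened over the original NHWC layout. The decompose-and-recompose arithmetic can be summarized with the following NumPy sketch (hypothetical helper name, illustration only):

import numpy as np

def nchw_flat_to_nhwc_flat(idx, N, C, H, W, include_batch_in_index):
    # Split a flat NCHW index into (n, c, h, w) and rebuild it as a flat NHWC
    # index, mirroring the Divide/Multiply/Subtract chain in the translator.
    idx = np.asarray(idx, dtype=np.int64)
    if include_batch_in_index:
        n = idx // (C * H * W)
        idx = idx - n * (C * H * W)
    else:
        n = np.zeros_like(idx)
    c = idx // (H * W)
    idx = idx - c * (H * W)
    h = idx // W
    w = idx - h * W
    return ((n * H + h) * W + w) * C + c

# Example: element (n=0, c=2, h=1, w=3) in a 1x3x4x5 (NCHW) tensor.
N, C, H, W = 1, 3, 4, 5
nchw_flat = 2 * (H * W) + 1 * W + 3  # 48
print(nchw_flat_to_nhwc_flat(nchw_flat, N, C, H, W, include_batch_in_index=False))  # 26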


@@ -42,6 +42,7 @@ PadType convert_tf_padding(const frontend::NodeContext& node, const string& tf_p
"MaxPool",
"MaxPoolV2",
"MaxPool3D",
"MaxPoolWithArgmax",
"ExtractImagePatches",
"DepthwiseConv2dNative",
"AvgPool",
@@ -68,8 +69,8 @@ PadType convert_tf_padding(const frontend::NodeContext& node, const string& tf_p
return PadType::SAME_LOWER;
}
} else if (op_type == "Conv2D" || op_type == "Conv3D" || op_type == "MaxPool" || op_type == "MaxPoolV2" ||
op_type == "MaxPool3D" || op_type == "ExtractImagePatches" || op_type == "DepthwiseConv2dNative" ||
op_type == "AvgPool" || op_type == "AvgPool3D") {
op_type == "MaxPool3D" || op_type == "MaxPoolWithArgmax" || op_type == "ExtractImagePatches" ||
op_type == "DepthwiseConv2dNative" || op_type == "AvgPool" || op_type == "AvgPool3D") {
if (tf_padding == "SAME") {
// According to the formulas for calculating auto_pad values of the
// Conv layer in the Operation specification,
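For reference, the padding amounts implied by TensorFlow's SAME mode can be computed per spatial dimension as below; TF places any odd padding at the end of a dimension, which corresponds to OpenVINO's SAME_UPPER. This is a sketch with an assumed helper name, not frontend code:

import math

def tf_same_padding(in_size, kernel, stride):
    # Per-dimension SAME padding as defined by TensorFlow: the output size is
    # ceil(in_size / stride) and any extra padding is added at the end.
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
    pad_begin = pad_total // 2
    pad_end = pad_total - pad_begin
    return out_size, pad_begin, pad_end

print(tf_same_padding(10, 2, 2))  # (5, 0, 0): evenly divisible, no padding needed
print(tf_same_padding(5, 2, 2))   # (3, 0, 1): odd total padding, extra element at the end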


@@ -0,0 +1,75 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest


class TestMaxPoolWithArgmax(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'input' in inputs_info
        input_shape = inputs_info['input']
        inputs_data = {}
        inputs_data['input'] = np.random.randint(-5, 5, input_shape).astype(self.input_type)
        return inputs_data

    def create_max_pool_with_argmax_net(self, input_shape, ksize, strides, input_type, padding, targmax,
                                        include_batch_in_index, with_second_output):
        self.input_type = input_type
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            input = tf.compat.v1.placeholder(input_type, input_shape, 'input')
            max_pool_with_argmax = tf.raw_ops.MaxPoolWithArgmax(input=input, ksize=ksize, strides=strides,
                                                                padding=padding, Targmax=targmax,
                                                                include_batch_in_index=include_batch_in_index
                                                                )
            tf.identity(max_pool_with_argmax[0], name='max_pool')
            if with_second_output:
                tf.identity(max_pool_with_argmax[1], name='output_indices')
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(input_shape=[1, 25, 24, 3],
             ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1]),
        dict(input_shape=[1, 10, 20, 3],
             ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1]),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.parametrize("input_type", [
        np.float32, np.int32
    ])
    @pytest.mark.parametrize("padding", [
        'VALID', 'SAME'
    ])
    @pytest.mark.parametrize("targmax", [
        tf.int32, tf.int64
    ])
    @pytest.mark.parametrize("include_batch_in_index", [
        True, False
    ])
    @pytest.mark.parametrize("with_second_output", [
        pytest.param(
            True,
            marks=pytest.mark.skip(reason="117415: TransposeSinking crash")
        ),
        False
    ])
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_max_pool_with_argmax_basic(self, params, input_type, padding, targmax,
                                        include_batch_in_index, with_second_output,
                                        ie_device, precision, ir_version, temp_dir,
                                        use_new_frontend, use_old_api):
        self._test(
            *self.create_max_pool_with_argmax_net(**params, input_type=input_type, padding=padding, targmax=targmax,
                                                  include_batch_in_index=include_batch_in_index,
                                                  with_second_output=with_second_output),
            ie_device, precision, ir_version, temp_dir=temp_dir,
            use_new_frontend=use_new_frontend, use_old_api=use_old_api)
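The argmax values exercised by this test follow TensorFlow's documented flattening rule: a maximum at position [b, y, x, c] is reported as ((b * H + y) * W + x) * C + c when include_batch_in_index is true, and (y * W + x) * C + c otherwise. A naive NumPy reference for VALID padding (illustrative helper, not part of the test):

import numpy as np

def max_pool_with_argmax_ref(x, ksize, stride, include_batch_in_index=False):
    # Naive NHWC max pooling with TF-style flattened argmax (VALID padding only,
    # square window and equal strides for brevity).
    n, h, w, c = x.shape
    oh, ow = (h - ksize) // stride + 1, (w - ksize) // stride + 1
    out = np.zeros((n, oh, ow, c), dtype=x.dtype)
    idx = np.zeros((n, oh, ow, c), dtype=np.int64)
    for b in range(n):
        for i in range(oh):
            for j in range(ow):
                for ch in range(c):
                    win = x[b, i * stride:i * stride + ksize,
                              j * stride:j * stride + ksize, ch]
                    di, dj = np.unravel_index(np.argmax(win), win.shape)
                    y, xx = i * stride + di, j * stride + dj
                    out[b, i, j, ch] = win[di, dj]
                    flat = (y * w + xx) * c + ch
                    if include_batch_in_index:
                        flat += b * h * w * c
                    idx[b, i, j, ch] = flat
    return out, idx

# Mirrors the second test case's pooling parameters on random data.
x = np.random.randint(-5, 5, (1, 10, 20, 3)).astype(np.float32)
values, indices = max_pool_with_argmax_ref(x, ksize=2, stride=2)
print(values.shape, indices.shape)  # (1, 5, 10, 3) (1, 5, 10, 3)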