AvgPool3D translator, fix of MaxPool translator in TF FE (#10530)

* Fixed MaxPool translator, added AvgPool3D translator.

* Update src/frontends/tensorflow/src/op/avg_pool.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Code style.

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Anastasia Popova 2022-02-19 02:47:01 +03:00 committed by GitHub
parent fb6359586d
commit 2e164b4ddc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 9 deletions

View File

@ -21,15 +21,21 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
auto tf_padding_type = node.get_attribute<std::string>("padding");
auto tf_data_format = node.get_attribute<std::string>("data_format");
TENSORFLOW_OP_VALIDATION(node,
tf_data_format == "NHWC" || tf_data_format == "NCHW",
"AvgPool data format is neither NHWC nor NCHW");
TENSORFLOW_OP_VALIDATION(
node,
tf_data_format == "NHWC" || tf_data_format == "NCHW" || tf_data_format == "NDHWC" || tf_data_format == "NCDHW",
"AvgPool data format is neither NHWC (NDHWC) nor NCHW (NCDHW)");
bool is_nhwc = (tf_data_format == "NHWC");
bool is_nhwc = (tf_data_format == "NHWC") || (tf_data_format == "NDHWC");
Strides ng_strides(2);
Shape ng_image_shape(2);
Shape ng_kernel_shape(2);
int N = 2;
if (node.get_op_type() == "AvgPool3D") {
N = 3;
}
Strides ng_strides(N);
Shape ng_image_shape(N);
Shape ng_kernel_shape(N);
convert_nhwc_to_hw(is_nhwc, tf_strides, ng_strides);
convert_nhwc_to_hw(is_nhwc, ng_input.get_shape(), ng_image_shape);
convert_nhwc_to_hw(is_nhwc, tf_ksize, ng_kernel_shape);
@ -37,7 +43,7 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
CoordinateDiff padding_below;
CoordinateDiff padding_above;
Shape ng_dilations{1, 1};
Shape ng_dilations(N, 1);
make_padding(tf_padding_type,
ng_image_shape,
ng_kernel_shape,

View File

@ -25,7 +25,7 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
bool is_nhwc = (tf_data_format == "NHWC") || (tf_data_format == "NDHWC");
int N = 2;
if (node.get_name() == "MaxPool3D") {
if (node.get_op_type() == "MaxPool3D") {
N = 3;
}
Strides ng_strides(N);

View File

@ -157,6 +157,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"ArgMax", translate_arg_max_op},
{"ArgMin", translate_arg_min_op},
{"AvgPool", translate_avg_pool_op},
{"AvgPool3D", translate_avg_pool_op},
{"BatchToSpaceND", translate_batch_nd_and_space_nd_op},
{"BiasAdd", translate_bias_add_op},
{"Cast", translate_cast_op},