AvgPool3D translator, fix of MaxPool translator in TF FE (#10530)
* Fixed MaxPool translator, added AvgPool3D translator. * Update src/frontends/tensorflow/src/op/avg_pool.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> * Code style. Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
fb6359586d
commit
2e164b4ddc
@ -21,15 +21,21 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
|
||||
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
||||
auto tf_data_format = node.get_attribute<std::string>("data_format");
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
tf_data_format == "NHWC" || tf_data_format == "NCHW",
|
||||
"AvgPool data format is neither NHWC nor NCHW");
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
tf_data_format == "NHWC" || tf_data_format == "NCHW" || tf_data_format == "NDHWC" || tf_data_format == "NCDHW",
|
||||
"AvgPool data format is neither NHWC (NDHWC) nor NCHW (NCDHW)");
|
||||
|
||||
bool is_nhwc = (tf_data_format == "NHWC");
|
||||
bool is_nhwc = (tf_data_format == "NHWC") || (tf_data_format == "NDHWC");
|
||||
|
||||
Strides ng_strides(2);
|
||||
Shape ng_image_shape(2);
|
||||
Shape ng_kernel_shape(2);
|
||||
int N = 2;
|
||||
if (node.get_op_type() == "AvgPool3D") {
|
||||
N = 3;
|
||||
}
|
||||
|
||||
Strides ng_strides(N);
|
||||
Shape ng_image_shape(N);
|
||||
Shape ng_kernel_shape(N);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_strides, ng_strides);
|
||||
convert_nhwc_to_hw(is_nhwc, ng_input.get_shape(), ng_image_shape);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_ksize, ng_kernel_shape);
|
||||
@ -37,7 +43,7 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
|
||||
|
||||
CoordinateDiff padding_below;
|
||||
CoordinateDiff padding_above;
|
||||
Shape ng_dilations{1, 1};
|
||||
Shape ng_dilations(N, 1);
|
||||
make_padding(tf_padding_type,
|
||||
ng_image_shape,
|
||||
ng_kernel_shape,
|
||||
|
@ -25,7 +25,7 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
|
||||
bool is_nhwc = (tf_data_format == "NHWC") || (tf_data_format == "NDHWC");
|
||||
|
||||
int N = 2;
|
||||
if (node.get_name() == "MaxPool3D") {
|
||||
if (node.get_op_type() == "MaxPool3D") {
|
||||
N = 3;
|
||||
}
|
||||
Strides ng_strides(N);
|
||||
|
@ -157,6 +157,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"ArgMax", translate_arg_max_op},
|
||||
{"ArgMin", translate_arg_min_op},
|
||||
{"AvgPool", translate_avg_pool_op},
|
||||
{"AvgPool3D", translate_avg_pool_op},
|
||||
{"BatchToSpaceND", translate_batch_nd_and_space_nd_op},
|
||||
{"BiasAdd", translate_bias_add_op},
|
||||
{"Cast", translate_cast_op},
|
||||
|
Loading…
Reference in New Issue
Block a user