[TF FE] Refactor translators for Resize operations and correct Pooling (#12721)

* [TF FE] Refactor translators for Resize operations and correct Pooling

It allows to convert magenta_arbitrary-image-stylization model

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Align TF FE tranlator for Resize with legacy frontend

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Do minor fix for MaxPool

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-08-24 14:55:10 +03:00 committed by GitHub
parent 39b743f7a0
commit 3595f195f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 23 deletions

View File

@ -57,7 +57,8 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
auto avg_pool_output = avg_pool->output(0);
convert_nchw_to_nhwc(is_nhwc, avg_pool_output, ov::Rank(spatial_dim + 2));
set_node_name(node.get_name(), avg_pool);
return {avg_pool};
return {avg_pool_output};
}
} // namespace op
} // namespace tensorflow

View File

@ -13,34 +13,65 @@ namespace frontend {
namespace tensorflow {
namespace op {
ov::OutputVector translate_interpolate_op(const NodeContext& node) {
auto input = node.get_input(0);
auto input_sizes = node.get_input(1);
default_op_checks(node, 2, {"ResizeBilinear", "ResizeNearestNeighbor"});
auto images = node.get_input(0);
auto size = node.get_input(1);
auto op_name = node.get_name();
auto op_type = node.get_op_type();
// retrieve optional attribute
auto tf_align_corners = node.get_attribute<bool>("align_corners", false);
auto tf_half_pixel_centers = node.get_attribute<bool>("half_pixel_centers", false);
TENSORFLOW_OP_VALIDATION(node,
!tf_half_pixel_centers || (tf_half_pixel_centers && !tf_align_corners),
"If half_pixel_centers attribute of the node" + op_name + " with op " + op_type +
" is True, the attribute align_corners must be False.");
// prepare attributes for OpenVINO Interpolate operation
Interpolate::InterpolateAttrs interpolate_attrs;
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR;
interpolate_attrs.shape_calculation_mode = Interpolate::ShapeCalcMode::SIZES;
if (node.get_attribute<bool>("align_corners", false))
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
if (node.get_op_type() == "ResizeNearestNeighbor") {
if (op_type == "ResizeNearestNeighbor") {
interpolate_attrs.mode = Interpolate::InterpolateMode::NEAREST;
interpolate_attrs.nearest_mode = Interpolate::NearestMode::FLOOR;
} else if (op_type == "ResizeBilinear") {
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR;
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_FLOOR;
}
// TODO: do we need this .get_shape() actually?
auto input_shape = input.get_shape();
std::vector<float> spatial_shape = {static_cast<float>(input_shape[1]), static_cast<float>(input_shape[2])};
auto ng_spatial_shape = make_shared<Constant>(element::f32, Shape{2}, spatial_shape);
if (tf_align_corners) {
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_CEIL;
}
} else if (tf_half_pixel_centers) {
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
interpolate_attrs.coordinate_transformation_mode =
Interpolate::CoordinateTransformMode::TF_HALF_PIXEL_FOR_NN;
} else {
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::HALF_PIXEL;
}
} else {
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ASYMMETRIC;
}
auto ng_sizes = make_shared<Convert>(input_sizes, element::f32);
auto ng_scales = make_shared<Divide>(ng_sizes, ng_spatial_shape);
auto ng_axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({2, 3}));
// prepare scales input
auto images_shape = make_shared<ShapeOf>(images, ov::element::i32);
auto spatial_shape = make_shared<Slice>(images_shape,
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{3}),
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{0}));
auto scales = make_shared<Divide>(make_shared<Convert>(size, element::f32),
make_shared<Convert>(spatial_shape, element::f32));
input = make_transpose(input, {0, 3, 1, 2});
auto res = make_shared<Interpolate>(input, input_sizes, ng_scales, ng_axes, interpolate_attrs)->output(0);
res = make_transpose(res, {0, 2, 3, 1});
set_node_name(node.get_name(), res.get_node_shared_ptr());
return {res};
// since Interpolate is layout agnostic
// we can avoid Transpose operation by specifying axes = {1, 2} for original NHWC layout
auto axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({1, 2}));
auto interpolate = make_shared<Interpolate>(images, size, scales, axes, interpolate_attrs);
set_node_name(node.get_name(), interpolate);
return {interpolate};
}
} // namespace op
} // namespace tensorflow

View File

@ -19,10 +19,10 @@ OutputVector translate_max_pool_util(const NodeContext& node,
size_t spatial_dims_num,
const std::vector<int64_t>& tf_kernel_sizes,
const std::vector<int64_t>& tf_strides) {
default_op_checks(node, 1, {"MaxPool2D", "MaxPool3D"});
default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D"});
TENSORFLOW_OP_VALIDATION(node,
spatial_dims_num == 2 || spatial_dims_num == 3,
"Only MaxPool2D and MaxPool3D are supported.");
"Only MaxPool, MaxPoolV2 and MaxPool3D are supported.");
auto input = node.get_input(0);
auto tf_padding_type = node.get_attribute<std::string>("padding");
@ -38,7 +38,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
if (spatial_dims_num == 2) {
TENSORFLOW_OP_VALIDATION(node,
tf_data_format == "NHWC" || tf_data_format == "NCHW",
"MaxPool2D or MaxPoolV2 data format is neither NHWC nor NCHW");
"MaxPool or MaxPoolV2 data format is neither NHWC nor NCHW");
is_nhwc = (tf_data_format == "NHWC");
} else {
TENSORFLOW_OP_VALIDATION(node,