[TF FE] Refactor translators for Resize operations and correct Pooling (#12721)
* [TF FE] Refactor translators for Resize operations and correct Pooling It allows to convert magenta_arbitrary-image-stylization model Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Align TF FE tranlator for Resize with legacy frontend Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Do minor fix for MaxPool Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
39b743f7a0
commit
3595f195f1
@ -57,7 +57,8 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
|
||||
auto avg_pool_output = avg_pool->output(0);
|
||||
convert_nchw_to_nhwc(is_nhwc, avg_pool_output, ov::Rank(spatial_dim + 2));
|
||||
set_node_name(node.get_name(), avg_pool);
|
||||
return {avg_pool};
|
||||
|
||||
return {avg_pool_output};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -13,34 +13,65 @@ namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
ov::OutputVector translate_interpolate_op(const NodeContext& node) {
|
||||
auto input = node.get_input(0);
|
||||
auto input_sizes = node.get_input(1);
|
||||
default_op_checks(node, 2, {"ResizeBilinear", "ResizeNearestNeighbor"});
|
||||
auto images = node.get_input(0);
|
||||
auto size = node.get_input(1);
|
||||
auto op_name = node.get_name();
|
||||
auto op_type = node.get_op_type();
|
||||
|
||||
// retrieve optional attribute
|
||||
auto tf_align_corners = node.get_attribute<bool>("align_corners", false);
|
||||
auto tf_half_pixel_centers = node.get_attribute<bool>("half_pixel_centers", false);
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
!tf_half_pixel_centers || (tf_half_pixel_centers && !tf_align_corners),
|
||||
"If half_pixel_centers attribute of the node" + op_name + " with op " + op_type +
|
||||
" is True, the attribute align_corners must be False.");
|
||||
|
||||
// prepare attributes for OpenVINO Interpolate operation
|
||||
Interpolate::InterpolateAttrs interpolate_attrs;
|
||||
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR;
|
||||
interpolate_attrs.shape_calculation_mode = Interpolate::ShapeCalcMode::SIZES;
|
||||
if (node.get_attribute<bool>("align_corners", false))
|
||||
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
|
||||
|
||||
if (node.get_op_type() == "ResizeNearestNeighbor") {
|
||||
if (op_type == "ResizeNearestNeighbor") {
|
||||
interpolate_attrs.mode = Interpolate::InterpolateMode::NEAREST;
|
||||
interpolate_attrs.nearest_mode = Interpolate::NearestMode::FLOOR;
|
||||
} else if (op_type == "ResizeBilinear") {
|
||||
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR;
|
||||
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_FLOOR;
|
||||
}
|
||||
|
||||
// TODO: do we need this .get_shape() actually?
|
||||
auto input_shape = input.get_shape();
|
||||
std::vector<float> spatial_shape = {static_cast<float>(input_shape[1]), static_cast<float>(input_shape[2])};
|
||||
auto ng_spatial_shape = make_shared<Constant>(element::f32, Shape{2}, spatial_shape);
|
||||
if (tf_align_corners) {
|
||||
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
|
||||
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
|
||||
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_CEIL;
|
||||
}
|
||||
} else if (tf_half_pixel_centers) {
|
||||
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
|
||||
interpolate_attrs.coordinate_transformation_mode =
|
||||
Interpolate::CoordinateTransformMode::TF_HALF_PIXEL_FOR_NN;
|
||||
} else {
|
||||
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::HALF_PIXEL;
|
||||
}
|
||||
} else {
|
||||
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ASYMMETRIC;
|
||||
}
|
||||
|
||||
auto ng_sizes = make_shared<Convert>(input_sizes, element::f32);
|
||||
auto ng_scales = make_shared<Divide>(ng_sizes, ng_spatial_shape);
|
||||
auto ng_axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({2, 3}));
|
||||
// prepare scales input
|
||||
auto images_shape = make_shared<ShapeOf>(images, ov::element::i32);
|
||||
auto spatial_shape = make_shared<Slice>(images_shape,
|
||||
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
|
||||
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{3}),
|
||||
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
|
||||
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{0}));
|
||||
auto scales = make_shared<Divide>(make_shared<Convert>(size, element::f32),
|
||||
make_shared<Convert>(spatial_shape, element::f32));
|
||||
|
||||
input = make_transpose(input, {0, 3, 1, 2});
|
||||
auto res = make_shared<Interpolate>(input, input_sizes, ng_scales, ng_axes, interpolate_attrs)->output(0);
|
||||
res = make_transpose(res, {0, 2, 3, 1});
|
||||
set_node_name(node.get_name(), res.get_node_shared_ptr());
|
||||
return {res};
|
||||
// since Interpolate is layout agnostic
|
||||
// we can avoid Transpose operation by specifying axes = {1, 2} for original NHWC layout
|
||||
auto axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({1, 2}));
|
||||
|
||||
auto interpolate = make_shared<Interpolate>(images, size, scales, axes, interpolate_attrs);
|
||||
set_node_name(node.get_name(), interpolate);
|
||||
return {interpolate};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -19,10 +19,10 @@ OutputVector translate_max_pool_util(const NodeContext& node,
|
||||
size_t spatial_dims_num,
|
||||
const std::vector<int64_t>& tf_kernel_sizes,
|
||||
const std::vector<int64_t>& tf_strides) {
|
||||
default_op_checks(node, 1, {"MaxPool2D", "MaxPool3D"});
|
||||
default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D"});
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
spatial_dims_num == 2 || spatial_dims_num == 3,
|
||||
"Only MaxPool2D and MaxPool3D are supported.");
|
||||
"Only MaxPool, MaxPoolV2 and MaxPool3D are supported.");
|
||||
auto input = node.get_input(0);
|
||||
|
||||
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
||||
@ -38,7 +38,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
|
||||
if (spatial_dims_num == 2) {
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
tf_data_format == "NHWC" || tf_data_format == "NCHW",
|
||||
"MaxPool2D or MaxPoolV2 data format is neither NHWC nor NCHW");
|
||||
"MaxPool or MaxPoolV2 data format is neither NHWC nor NCHW");
|
||||
is_nhwc = (tf_data_format == "NHWC");
|
||||
} else {
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
|
Loading…
Reference in New Issue
Block a user