[TF FE] Refactor translators for Resize operations and correct Pooling (#12721)
* [TF FE] Refactor translators for Resize operations and correct Pooling It allows to convert magenta_arbitrary-image-stylization model Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Align TF FE tranlator for Resize with legacy frontend Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Do minor fix for MaxPool Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
39b743f7a0
commit
3595f195f1
@ -57,7 +57,8 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
|
|||||||
auto avg_pool_output = avg_pool->output(0);
|
auto avg_pool_output = avg_pool->output(0);
|
||||||
convert_nchw_to_nhwc(is_nhwc, avg_pool_output, ov::Rank(spatial_dim + 2));
|
convert_nchw_to_nhwc(is_nhwc, avg_pool_output, ov::Rank(spatial_dim + 2));
|
||||||
set_node_name(node.get_name(), avg_pool);
|
set_node_name(node.get_name(), avg_pool);
|
||||||
return {avg_pool};
|
|
||||||
|
return {avg_pool_output};
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -13,34 +13,65 @@ namespace frontend {
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace op {
|
namespace op {
|
||||||
ov::OutputVector translate_interpolate_op(const NodeContext& node) {
|
ov::OutputVector translate_interpolate_op(const NodeContext& node) {
|
||||||
auto input = node.get_input(0);
|
default_op_checks(node, 2, {"ResizeBilinear", "ResizeNearestNeighbor"});
|
||||||
auto input_sizes = node.get_input(1);
|
auto images = node.get_input(0);
|
||||||
|
auto size = node.get_input(1);
|
||||||
|
auto op_name = node.get_name();
|
||||||
|
auto op_type = node.get_op_type();
|
||||||
|
|
||||||
|
// retrieve optional attribute
|
||||||
|
auto tf_align_corners = node.get_attribute<bool>("align_corners", false);
|
||||||
|
auto tf_half_pixel_centers = node.get_attribute<bool>("half_pixel_centers", false);
|
||||||
|
|
||||||
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
|
!tf_half_pixel_centers || (tf_half_pixel_centers && !tf_align_corners),
|
||||||
|
"If half_pixel_centers attribute of the node" + op_name + " with op " + op_type +
|
||||||
|
" is True, the attribute align_corners must be False.");
|
||||||
|
|
||||||
|
// prepare attributes for OpenVINO Interpolate operation
|
||||||
Interpolate::InterpolateAttrs interpolate_attrs;
|
Interpolate::InterpolateAttrs interpolate_attrs;
|
||||||
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR;
|
|
||||||
interpolate_attrs.shape_calculation_mode = Interpolate::ShapeCalcMode::SIZES;
|
interpolate_attrs.shape_calculation_mode = Interpolate::ShapeCalcMode::SIZES;
|
||||||
if (node.get_attribute<bool>("align_corners", false))
|
if (op_type == "ResizeNearestNeighbor") {
|
||||||
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
|
|
||||||
|
|
||||||
if (node.get_op_type() == "ResizeNearestNeighbor") {
|
|
||||||
interpolate_attrs.mode = Interpolate::InterpolateMode::NEAREST;
|
interpolate_attrs.mode = Interpolate::InterpolateMode::NEAREST;
|
||||||
|
interpolate_attrs.nearest_mode = Interpolate::NearestMode::FLOOR;
|
||||||
|
} else if (op_type == "ResizeBilinear") {
|
||||||
|
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR;
|
||||||
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_FLOOR;
|
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_FLOOR;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: do we need this .get_shape() actually?
|
if (tf_align_corners) {
|
||||||
auto input_shape = input.get_shape();
|
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
|
||||||
std::vector<float> spatial_shape = {static_cast<float>(input_shape[1]), static_cast<float>(input_shape[2])};
|
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
|
||||||
auto ng_spatial_shape = make_shared<Constant>(element::f32, Shape{2}, spatial_shape);
|
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_CEIL;
|
||||||
|
}
|
||||||
|
} else if (tf_half_pixel_centers) {
|
||||||
|
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
|
||||||
|
interpolate_attrs.coordinate_transformation_mode =
|
||||||
|
Interpolate::CoordinateTransformMode::TF_HALF_PIXEL_FOR_NN;
|
||||||
|
} else {
|
||||||
|
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::HALF_PIXEL;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ASYMMETRIC;
|
||||||
|
}
|
||||||
|
|
||||||
auto ng_sizes = make_shared<Convert>(input_sizes, element::f32);
|
// prepare scales input
|
||||||
auto ng_scales = make_shared<Divide>(ng_sizes, ng_spatial_shape);
|
auto images_shape = make_shared<ShapeOf>(images, ov::element::i32);
|
||||||
auto ng_axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({2, 3}));
|
auto spatial_shape = make_shared<Slice>(images_shape,
|
||||||
|
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
|
||||||
|
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{3}),
|
||||||
|
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
|
||||||
|
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{0}));
|
||||||
|
auto scales = make_shared<Divide>(make_shared<Convert>(size, element::f32),
|
||||||
|
make_shared<Convert>(spatial_shape, element::f32));
|
||||||
|
|
||||||
input = make_transpose(input, {0, 3, 1, 2});
|
// since Interpolate is layout agnostic
|
||||||
auto res = make_shared<Interpolate>(input, input_sizes, ng_scales, ng_axes, interpolate_attrs)->output(0);
|
// we can avoid Transpose operation by specifying axes = {1, 2} for original NHWC layout
|
||||||
res = make_transpose(res, {0, 2, 3, 1});
|
auto axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({1, 2}));
|
||||||
set_node_name(node.get_name(), res.get_node_shared_ptr());
|
|
||||||
return {res};
|
auto interpolate = make_shared<Interpolate>(images, size, scales, axes, interpolate_attrs);
|
||||||
|
set_node_name(node.get_name(), interpolate);
|
||||||
|
return {interpolate};
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,10 +19,10 @@ OutputVector translate_max_pool_util(const NodeContext& node,
|
|||||||
size_t spatial_dims_num,
|
size_t spatial_dims_num,
|
||||||
const std::vector<int64_t>& tf_kernel_sizes,
|
const std::vector<int64_t>& tf_kernel_sizes,
|
||||||
const std::vector<int64_t>& tf_strides) {
|
const std::vector<int64_t>& tf_strides) {
|
||||||
default_op_checks(node, 1, {"MaxPool2D", "MaxPool3D"});
|
default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D"});
|
||||||
TENSORFLOW_OP_VALIDATION(node,
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
spatial_dims_num == 2 || spatial_dims_num == 3,
|
spatial_dims_num == 2 || spatial_dims_num == 3,
|
||||||
"Only MaxPool2D and MaxPool3D are supported.");
|
"Only MaxPool, MaxPoolV2 and MaxPool3D are supported.");
|
||||||
auto input = node.get_input(0);
|
auto input = node.get_input(0);
|
||||||
|
|
||||||
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
||||||
@ -38,7 +38,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
|
|||||||
if (spatial_dims_num == 2) {
|
if (spatial_dims_num == 2) {
|
||||||
TENSORFLOW_OP_VALIDATION(node,
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
tf_data_format == "NHWC" || tf_data_format == "NCHW",
|
tf_data_format == "NHWC" || tf_data_format == "NCHW",
|
||||||
"MaxPool2D or MaxPoolV2 data format is neither NHWC nor NCHW");
|
"MaxPool or MaxPoolV2 data format is neither NHWC nor NCHW");
|
||||||
is_nhwc = (tf_data_format == "NHWC");
|
is_nhwc = (tf_data_format == "NHWC");
|
||||||
} else {
|
} else {
|
||||||
TENSORFLOW_OP_VALIDATION(node,
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
|
Loading…
Reference in New Issue
Block a user