From bc663878eb5ecf2afb7a982ec85918285bc9c5b6 Mon Sep 17 00:00:00 2001
From: Leonard Sikorski
Date: Thu, 23 Feb 2023 09:26:17 +0100
Subject: [PATCH] [PT FE] Add torchvision::roi_align operator with layer test (#15821)

---
 src/frontends/pytorch/src/node_context.cpp         |  5 ++
 src/frontends/pytorch/src/op/roi_align.cpp          | 68 +++++++++++++++++++
 src/frontends/pytorch/src/op_table.cpp              |  2 +
 tests/layer_tests/pytorch_tests/test_roi_align.py   | 58 ++++++++++++++++
 4 files changed, 133 insertions(+)
 create mode 100644 src/frontends/pytorch/src/op/roi_align.cpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_roi_align.py

diff --git a/src/frontends/pytorch/src/node_context.cpp b/src/frontends/pytorch/src/node_context.cpp
index d8bb94305d8..a3e8c81633a 100644
--- a/src/frontends/pytorch/src/node_context.cpp
+++ b/src/frontends/pytorch/src/node_context.cpp
@@ -142,6 +142,11 @@ ngraph::Shape NodeContext::const_input<ngraph::Shape>(size_t index) const {
     return get_constant_at_input(*this, index)->cast_vector<size_t>();
 }
 
+template <>
+int32_t NodeContext::const_input<int32_t>(size_t index) const {
+    return get_constant_at_input(*this, index)->cast_vector<int32_t>()[0];
+}
+
 template <>
 int64_t NodeContext::const_input<int64_t>(size_t index) const {
     return get_constant_at_input(*this, index)->cast_vector<int64_t>()[0];
diff --git a/src/frontends/pytorch/src/op/roi_align.cpp b/src/frontends/pytorch/src/op/roi_align.cpp
new file mode 100644
index 00000000000..d3a389c5965
--- /dev/null
+++ b/src/frontends/pytorch/src/op/roi_align.cpp
@@ -0,0 +1,68 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/roi_align.hpp"
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/convert_like.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/reshape.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_roi_align(NodeContext& context) {
+    num_inputs_check(context, 7, 7);
+    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
+    auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
+    auto const_rois_indices = context.mark_node(v0::Constant::create(element::i32, Shape{4}, {1, 2, 3, 4}));
+
+    auto input = context.get_input(0);
+    auto boxes_input = context.get_input(1);
+
+    auto input_real_type = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
+    auto boxes = context.mark_node(std::make_shared<v1::ConvertLike>(boxes_input, input_real_type));
+
+    auto spatial_scale = context.const_input<float>(2);
+    int output_size_h = context.const_input<int32_t>(3);
+    int output_size_w = context.const_input<int32_t>(4);
+    int sampling_ratio = context.const_input<int32_t>(5);
+
+    auto aligned = context.const_input<bool>(6);
+
+    auto rois = context.mark_node(std::make_shared<v8::Gather>(boxes, const_rois_indices, const_1));
+
+    auto batch_indices_gather = context.mark_node(std::make_shared<v8::Gather>(boxes, const_0, const_1));
+    auto batch_indices_reshape =
+        context.mark_node(std::make_shared<v1::Reshape>(batch_indices_gather, const_neg_1, false));
+    auto batch_indices = context.mark_node(std::make_shared<v0::Convert>(batch_indices_reshape, element::i32));
+
+    v9::ROIAlign::AlignedMode aligned_mode =
+        aligned ? v9::ROIAlign::AlignedMode::HALF_PIXEL_FOR_NN : v9::ROIAlign::AlignedMode::ASYMMETRIC;
+
+    auto roi_align = context.mark_node(std::make_shared<v9::ROIAlign>(input_real_type,
+                                                                      rois,
+                                                                      batch_indices,
+                                                                      output_size_h,
+                                                                      output_size_w,
+                                                                      sampling_ratio,
+                                                                      spatial_scale,
+                                                                      v9::ROIAlign::PoolingMode::AVG,
+                                                                      aligned_mode));
+
+    return {roi_align};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index b32ce37c55f..bd2e9bf0564 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -89,6 +89,7 @@ OP_CONVERTER(translate_repeat);
 OP_CONVERTER(translate_repeat_interleave);
 OP_CONVERTER(translate_reshape);
 OP_CONVERTER(translate_reshape_as);
+OP_CONVERTER(translate_roi_align);
 OP_CONVERTER(translate_roll);
 OP_CONVERTER(translate_rsqrt);
 OP_CONVERTER(translate_rsub);
@@ -327,6 +328,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"prim::NumToTensor", op::skip_node},  // In openvino we already store number as tensor with shape []
         {"prim::requires_grad", op::return_false_scalar},
         {"torchvision::nms", op::translate_nms},
+        {"torchvision::roi_align", op::translate_roi_align},
     };
 };
 
diff --git a/tests/layer_tests/pytorch_tests/test_roi_align.py b/tests/layer_tests/pytorch_tests/test_roi_align.py
new file mode 100644
index 00000000000..fb03c51b091
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_roi_align.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+from torchvision.ops import roi_align
+
+
+class TestROIAlign(PytorchLayerTest):
+    def _prepare_input(self):
+        return (self.input_tensor, self.boxes)
+
+    def create_model(self, output_size, spatial_scale, sampling_ratio, aligned):
+
+        class torchvision_roi_align(torch.nn.Module):
+            def __init__(self, output_size, spatial_scale, sampling_ratio, aligned):
+                super().__init__()
+                self.output_size = output_size
+                self.spatial_scale = spatial_scale
+                self.sampling_ratio = sampling_ratio
+                self.aligned = aligned
+
+            def forward(self, input_tensor, rois):
+                return roi_align(
+                    input_tensor,
+                    rois.to(dtype=input_tensor.dtype),
+                    self.output_size,
+                    self.spatial_scale,
+                    self.sampling_ratio,
+                    self.aligned,
+                )
+
+        ref_net = None
+
+        return (torchvision_roi_align(output_size, spatial_scale, sampling_ratio, aligned),
+                ref_net, "torchvision::roi_align")
+
+    @pytest.mark.parametrize('input_tensor', (np.random.randn(4, 5, 6, 7).astype(np.float32),))
+    @pytest.mark.parametrize('boxes', (np.array([[1, 2, 2, 3, 3]]).astype(np.float32),
+                                       np.array([[0, 1, 2, 5, 4],
+                                                 [2, 1, 2, 5, 4],
+                                                 [3, 1, 2, 5, 4]]).astype(np.float32)))
+    @pytest.mark.parametrize('output_size', ((4, 5), (3, 2), 3))
+    @pytest.mark.parametrize('spatial_scale', (0.5, 1.0))
+    @pytest.mark.parametrize('sampling_ratio', (0, 1))
+    @pytest.mark.parametrize('aligned', (True, False))
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_roi_align(self, ie_device, precision, ir_version, input_tensor, boxes, output_size,
+                       spatial_scale, sampling_ratio, aligned):
+        self.input_tensor = input_tensor
+        self.boxes = boxes
+        self._test(*self.create_model(output_size, spatial_scale, sampling_ratio, aligned),
+                   ie_device, precision, ir_version, trace_model=True)
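
Usage note (not part of the patch): below is a minimal sketch of how the converter added above can be exercised end to end outside the layer-test harness. It assumes an OpenVINO build that already contains this patch and a recent Python package exposing openvino.convert_model; the RoiAlignModel wrapper, tensor shapes, and parameter values are illustrative only.

# Minimal end-to-end sketch for torchvision::roi_align conversion (illustrative).
import numpy as np
import openvino as ov
import torch
from torchvision.ops import roi_align


class RoiAlignModel(torch.nn.Module):
    def forward(self, feature_map, boxes):
        # boxes has shape [K, 5]: column 0 is the batch index, columns 1..4 are
        # x1, y1, x2, y2 - the layout the converter splits into batch_indices and rois.
        return roi_align(feature_map, boxes, output_size=(3, 3),
                         spatial_scale=0.5, sampling_ratio=2, aligned=True)


feature_map = torch.randn(2, 4, 16, 16)
boxes = torch.tensor([[0, 1.0, 1.0, 10.0, 10.0],
                      [1, 2.0, 2.0, 12.0, 12.0]])

# convert_model traces the module when example_input is provided, which triggers
# the torchvision::roi_align converter registered in op_table.cpp.
ov_model = ov.convert_model(RoiAlignModel(), example_input=(feature_map, boxes))
compiled = ov.compile_model(ov_model, "CPU")

result = compiled([feature_map.numpy(), boxes.numpy()])[compiled.output(0)]
expected = RoiAlignModel()(feature_map, boxes).numpy()
print(np.max(np.abs(result - expected)))  # should be close to 0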