[PT FE] Add torchvision::roi_align operator with layer test (#15821)

This commit is contained in:
Leonard Sikorski 2023-02-23 09:26:17 +01:00 committed by GitHub
parent 288a750bc6
commit bc663878eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 0 deletions

View File

@ -142,6 +142,11 @@ ngraph::Shape NodeContext::const_input<ngraph::Shape>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<ngraph::Shape::value_type>();
}
// Specialization for int32_t: reads the constant found at input `index`,
// casts its contents to int32_t and returns the first element.
template <>
int32_t NodeContext::const_input<int32_t>(size_t index) const {
    const auto values = get_constant_at_input(*this, index)->cast_vector<int32_t>();
    return values[0];
}
template <>
int64_t NodeContext::const_input<int64_t>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<int64_t>()[0];

View File

@ -0,0 +1,68 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/roi_align.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
/// Translates `torchvision::roi_align` into an OpenVINO `v9::ROIAlign` node.
///
/// Exactly 7 inputs are required on the node context:
///   0 - feature map tensor (converted to f32 before it is fed to ROIAlign)
///   1 - boxes tensor; along axis 1, element 0 is used as the batch index and
///       elements 1..4 are used as the ROI coordinates
///   2 - spatial_scale   (constant, read as float)
///   3 - output height   (constant, read as int32)
///   4 - output width    (constant, read as int32)
///   5 - sampling_ratio  (constant, read as int32)
///   6 - aligned         (constant, read as bool)
///
/// @param context  Node context of the operation being converted.
/// @return A single-element vector holding the ROIAlign output.
OutputVector translate_roi_align(NodeContext& context) {
    num_inputs_check(context, 7, 7);

    // Helper constants: gather axis (1), flatten target shape (-1), batch-index column (0)
    // and the four coordinate columns (1..4) of the boxes tensor.
    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    auto const_rois_indices = context.mark_node(v0::Constant::create(element::i32, Shape{4}, {1, 2, 3, 4}));

    auto input = context.get_input(0);
    auto boxes_input = context.get_input(1);

    // The feature map is converted to f32 and the boxes are aligned to the same element type.
    auto input_real_type = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
    auto boxes = context.mark_node(std::make_shared<v1::ConvertLike>(boxes_input, input_real_type));

    const auto spatial_scale = context.const_input<float>(2);
    const int output_size_h = context.const_input<int32_t>(3);
    const int output_size_w = context.const_input<int32_t>(4);

    // torchvision documents every sampling_ratio <= 0 (its default is -1) as "adaptive",
    // while the ROIAlign-9 specification documents only 0 as the adaptive value.
    // Map all non-positive values to 0 so both sides agree.
    const int sampling_ratio_raw = context.const_input<int32_t>(5);
    const int sampling_ratio = sampling_ratio_raw > 0 ? sampling_ratio_raw : 0;

    const auto aligned = context.const_input<bool>(6);

    // Split the boxes along axis 1 into coordinates and the batch index column.
    auto rois = context.mark_node(std::make_shared<v8::Gather>(boxes, const_rois_indices, const_1));
    auto batch_indices_gather = context.mark_node(std::make_shared<v8::Gather>(boxes, const_0, const_1));
    // Flatten the gathered batch column to 1D and convert it to i32.
    auto batch_indices_reshape =
        context.mark_node(std::make_shared<v1::Reshape>(batch_indices_gather, const_neg_1, false));
    auto batch_indices = context.mark_node(std::make_shared<v0::Convert>(batch_indices_reshape, element::i32));

    const v9::ROIAlign::AlignedMode aligned_mode =
        aligned ? v9::ROIAlign::AlignedMode::HALF_PIXEL_FOR_NN : v9::ROIAlign::AlignedMode::ASYMMETRIC;

    auto roi_align = context.mark_node(std::make_shared<v9::ROIAlign>(input_real_type,
                                                                      rois,
                                                                      batch_indices,
                                                                      output_size_h,
                                                                      output_size_w,
                                                                      sampling_ratio,
                                                                      spatial_scale,
                                                                      v9::ROIAlign::PoolingMode::AVG,
                                                                      aligned_mode));
    return {roi_align};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -89,6 +89,7 @@ OP_CONVERTER(translate_repeat);
OP_CONVERTER(translate_repeat_interleave);
OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_reshape_as);
OP_CONVERTER(translate_roi_align);
OP_CONVERTER(translate_roll);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_rsub);
@ -327,6 +328,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"prim::NumToTensor", op::skip_node}, // In openvino we already store number as tensor with shape []
{"prim::requires_grad", op::return_false_scalar},
{"torchvision::nms", op::translate_nms},
{"torchvision::roi_align", op::translate_roi_align},
};
};

View File

@ -0,0 +1,58 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
from torchvision.ops import roi_align
class TestROIAlign(PytorchLayerTest):
    """Layer test exercising conversion of ``torchvision.ops.roi_align``."""

    def _prepare_input(self):
        # Both arrays are stored on the instance by test_roi_align before _test is invoked.
        return (self.input_tensor, self.boxes)

    def create_model(self, output_size, spatial_scale, sampling_ratio, aligned):
        """Build a module wrapping ``roi_align`` with the given fixed attributes.

        Returns a tuple of (model, reference net, operation kind); the
        reference net is always ``None``.
        """

        class torchvision_roi_align(torch.nn.Module):
            def __init__(self, output_size, spatial_scale, sampling_ratio, aligned):
                super().__init__()
                self.output_size = output_size
                self.spatial_scale = spatial_scale
                self.sampling_ratio = sampling_ratio
                self.aligned = aligned

            def forward(self, input_tensor, rois):
                # The boxes are cast to the dtype of the feature map before the call.
                typed_rois = rois.to(dtype=input_tensor.dtype)
                return roi_align(
                    input_tensor,
                    typed_rois,
                    self.output_size,
                    self.spatial_scale,
                    self.sampling_ratio,
                    self.aligned,
                )

        model = torchvision_roi_align(output_size, spatial_scale, sampling_ratio, aligned)
        return (model, None, "torchvision::roi_align")

    @pytest.mark.parametrize('input_tensor', (np.random.randn(4, 5, 6, 7).astype(np.float32),))
    @pytest.mark.parametrize(
        'boxes',
        (
            np.array([[1, 2, 2, 3, 3]]).astype(np.float32),
            np.array([[0, 1, 2, 5, 4],
                      [2, 1, 2, 5, 4],
                      [3, 1, 2, 5, 4]]).astype(np.float32),
        ),
    )
    @pytest.mark.parametrize('output_size', ((4, 5), (3, 2), 3))
    @pytest.mark.parametrize('spatial_scale', (0.5, 1.0))
    @pytest.mark.parametrize('sampling_ratio', (0, 1))
    @pytest.mark.parametrize('aligned', (True, False))
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_roi_align(self, ie_device, precision, ir_version, input_tensor, boxes, output_size,
                       spatial_scale, sampling_ratio, aligned):
        # Stash the inputs so _prepare_input can hand them to the test harness.
        self.input_tensor = input_tensor
        self.boxes = boxes
        model, ref_net, op_kind = self.create_model(output_size, spatial_scale, sampling_ratio, aligned)
        self._test(model, ref_net, op_kind, ie_device, precision, ir_version, trace_model=True)