Improve detectron2 support (#16011)

* Improve op support for detectron Mask R-CNN

* Initial commit

* Fix for reading processed list

* Format code

* Cleanup

* Cleanup

* Cleanup

* Cleanup test

* Add comment

* Add rt_info

* Fix type

* More fixes for detectron

* Fix build

* Add tests for if

* Revert changes in index

* Add comment

* Fix test

* Fix get_axes_range

* Add tests and fix if type alignment

* Fix code style

---------

Co-authored-by: Mateusz <mateusz.mikolajczyk@intel.com>
Maxim Vafin 2023-03-23 23:30:03 +01:00 committed by GitHub
parent 52b27d82c5
commit abaf61d059
11 changed files with 128 additions and 41 deletions


@@ -116,7 +116,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     // Usually if nn.Module.forward is given as a source model for conversion, there is the first Parameter
     // that represents original `self` argument in forward(self, ...). `self` shouldn't play any role in model
-    // inference if model is completelly frozed and all methods are inlined. So we check if it doesn't have any
+    // inference if model is completely frozen and all methods are inlined. So we check if it doesn't have any
     // consumers in the finally converted model and remove this parameter. This parameter should have index 0.
     if (model->get_parameters().size() > 0) {
         auto self = model->get_parameters()[0];


@@ -176,7 +176,7 @@ OutputVector translate_empty(const NodeContext& context) {
     // side, so just skip these parameters
     num_inputs_check(context, 1, 6);
     auto sizes = context.get_input(0);
-    // In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
+    // In OV uninitialized data is not supported, so we create a tensor filled with zeros with a given shape and type.
     auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
     int dtype_id = 1;
     Output<Node> empty;
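For context on the comment above: PyTorch's aten::empty intentionally returns uninitialized memory, so its values carry no meaning and a constant fill is a valid lowering. A minimal illustration (assumes a local PyTorch install; not part of this commit):

    import torch

    # torch.empty only guarantees shape and dtype; the contents are arbitrary,
    # which is why the frontend may safely materialize zeros instead.
    x = torch.empty((2, 3), dtype=torch.float32)
    print(x.shape, x.dtype)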


@@ -13,6 +13,31 @@ namespace frontend {
 namespace pytorch {
 namespace op {
 
+namespace {
+// TODO: Ticket 106627. This is a WA and will work only if both branches of If eventually go to an operation that
+// produces the same output type for either input type.
+void align_result_types(const NodeContext& context,
+                        std::shared_ptr<opset10::Result> r1,
+                        std::shared_ptr<opset10::Result> r2) {
+    auto r1_tensor = r1->input_value(0);
+    auto r2_tensor = r2->input_value(0);
+    auto r1_type = r1_tensor.get_element_type();
+    auto r2_type = r2_tensor.get_element_type();
+    if (r1_type.is_dynamic() || r2_type.is_dynamic())
+        return;
+    element::Type merged_type;
+    if (!element::Type::merge(merged_type, r1_type, r2_type)) {
+        if (r1_type.bitwidth() >= r2_type.bitwidth()) {
+            auto convert = std::make_shared<opset10::Convert>(r2_tensor, r1_type);
+            r2->set_argument(0, convert);
+        } else {
+            auto convert = std::make_shared<opset10::Convert>(r1_tensor, r2_type);
+            r1->set_argument(0, convert);
+        }
+    }
+}
+} // namespace
+
 OutputVector translate_if(const NodeContext& context) {
     auto if_node = std::make_shared<opset10::If>(context.get_input(0));
     context.mark_node(if_node);
@@ -62,6 +87,7 @@ OutputVector translate_if(const NodeContext& context) {
     FRONT_END_OP_CONVERSION_CHECK(then_results.size() >= num_outs && else_results.size() >= num_outs,
                                   "Else or then body have less outputs than prim::If requires.");
     for (size_t i = 0; i < num_outs; i++) {
+        align_result_types(context, then_results[i], else_results[i]);
         res.push_back(if_node->set_output(then_results[i], else_results[i]));
     }
     // Each body can have mutated outputs that are not included into pytorch node outputs.
@@ -136,6 +162,7 @@ OutputVector translate_if(const NodeContext& context) {
         }
     }
     for (const auto& output_idx : extra_output_idxs) {
+        align_result_types(context, extra_then_body_results.at(output_idx), extra_else_body_results.at(output_idx));
         context.add_tensor_to_context(
             output_idx,
             if_node->set_output(extra_then_body_results.at(output_idx), extra_else_body_results.at(output_idx)));
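The scenario align_result_types guards against is easy to reproduce from TorchScript: both branches of a prim::If stay in the scripted graph, and nothing forces them to yield the same element type. A minimal sketch (assumes PyTorch is available; it mirrors the new test added below, not the converter itself):

    import torch

    def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if bool(y > 0):
            res = x.new_empty((0, 10), dtype=torch.uint8)     # uint8 branch
        else:
            res = torch.zeros(x.shape[:2], dtype=torch.bool)  # bool branch
        return res.to(torch.bool)

    # The prim::If here has a uint8 Result in one body and a bool Result in the
    # other; the converter inserts Convert on the narrower branch so both bodies
    # expose one element type.
    print(torch.jit.script(f).graph)

Picking the larger bitwidth only matters when the two types cannot be merged; per the TODO, this is a workaround until proper type alignment lands (Ticket 106627).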


@@ -5,11 +5,7 @@
 #include "openvino/op/select.hpp"
 
 #include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/op/add.hpp"
-#include "openvino/op/constant.hpp"
-#include "openvino/op/less.hpp"
-#include "openvino/op/reshape.hpp"
-#include "openvino/op/slice.hpp"
+#include "openvino/op/gather.hpp"
 #include "openvino/op/squeeze.hpp"
 #include "utils.hpp"
@@ -21,22 +17,12 @@ namespace op {
 
 using namespace ov::op;
 
 OutputVector translate_select(const NodeContext& context) {
     // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
     num_inputs_check(context, 3, 3);
-    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
-    auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
-    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
-    auto input_tensor = context.get_input(0);
-    auto dim = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(1), const_1, false));
-    auto start = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(2), const_1, false));
-    auto less = context.mark_node(std::make_shared<v1::Less>(start, const_0));
-    auto const_1_signed = context.mark_node(std::make_shared<v1::Select>(less, const_minus_1, const_1));
-    auto stop = context.mark_node(std::make_shared<v1::Add>(start, const_1_signed));
-    auto slice_node = context.mark_node(std::make_shared<v8::Slice>(input_tensor, start, stop, const_1_signed, dim));
-    return {context.mark_node(std::make_shared<v0::Squeeze>(slice_node, dim))};
+    auto data = context.get_input(0);
+    auto dim = context.get_input(1);
+    auto index = context.get_input(2);
+    return {context.mark_node(std::make_shared<v8::Gather>(data, index, dim))};
 };
 
 } // namespace op
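The rewrite rests on the fact that gathering with a scalar index is exactly aten::select: the indexed dimension is dropped, and Gather-8 accepts negative indices just as PyTorch does, so the old Less/Select sign juggling is unnecessary. A quick numpy check of the equivalence (illustrative only):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    dim, index = 1, -2                      # negative index counts from the end
    expected = x[:, index, :]               # what torch.select(x, dim, index) yields
    gathered = np.take(x, index, axis=dim)  # scalar index removes the axis
    assert np.array_equal(gathered, expected)

Net effect: roughly ten nodes (Reshape/Less/Select/Add/Slice/Squeeze plus constants) collapse into a single Gather.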


@@ -17,6 +17,7 @@
 #include "openvino/op/tile.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/util/framework_node.hpp"
+#include "openvino/op/variadic_split.hpp"
 #include "openvino/pass/pattern/matcher.hpp"
 #include "openvino/pass/pattern/op/or.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -49,6 +50,8 @@ ListConstructReplacer::ListConstructReplacer() {
     auto tile_op = pattern::wrap_type<v0::Tile>({pattern::any_input(), list});
     // replace aten::permute(tensor, prim::ListConstruct)
     auto transpose_op = pattern::wrap_type<v1::Transpose>({pattern::any_input(), list});
+    // aten::split_with_sizes case
+    auto vsplit_op = pattern::wrap_type<v1::VariadicSplit>({pattern::any_input(), pattern::any_input(), list});
     auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
                                                                      roll_op,
                                                                      broadcast_op,
@@ -57,7 +60,8 @@ ListConstructReplacer::ListConstructReplacer() {
                                                                      equal_op,
                                                                      select_op,
                                                                      tile_op,
-                                                                     transpose_op});
+                                                                     transpose_op,
+                                                                     vsplit_op});
 
     ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
         auto& pattern_map = m.get_pattern_value_map();
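The new pattern covers graphs where split sizes are computed at runtime, so a prim::ListConstruct feeds the third input of VariadicSplit. A sketch of how such a graph arises (assumes PyTorch; compare the TestSplitWithSizes case added below):

    import torch

    def f(x: torch.Tensor, y: torch.Tensor):
        # the sizes list depends on another tensor, so tracing records
        # aten::size -> prim::ListConstruct -> aten::split_with_sizes
        return x.split([y.shape[0]], dim=0)

    traced = torch.jit.trace(f, (torch.zeros(20), torch.zeros(20)))
    print(traced.graph)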


@@ -49,7 +49,8 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
         auto step = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 1);
         auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, element::i32);
         auto rank = std::make_shared<ov::op::v3::ShapeOf>(shape, element::i32);
-        auto reduced_rank = std::make_shared<ov::op::v0::Squeeze>(rank);
+        auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
+        auto reduced_rank = std::make_shared<ov::op::v0::Squeeze>(rank, axis_0);
         auto axes = std::make_shared<ov::op::v4::Range>(start, reduced_rank, step, element::i32);
         std::shared_ptr<Node> reduce_op;
         if (!is_min) {


@@ -33,6 +33,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             if (rank.is_dynamic()) {
                 return false;
             }
+            std::shared_ptr<Node> split;
             if (rank.get_length() == 0) {
                 // Create split_lenghts tensor from split_size int,
                 // allow for last chunk to be smaller if data is not equally divisible.
@@ -45,18 +46,17 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
                 auto split_lenghts_m_1 = std::make_shared<opset10::Tile>(split_size, num_out_m_1);
                 NodeVector concat_inputs{split_lenghts_m_1, const_neg_1};
                 auto split_lenghts = std::make_shared<opset10::Concat>(concat_inputs, 0);
-                auto split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
-                                                                      torch_split->get_input_source_output(2),
-                                                                      split_lenghts);
-                copy_runtime_info({list_unpack, input_node}, split);
-                replace_node(list_unpack, split);
+                split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
+                                                                 torch_split->get_input_source_output(2),
+                                                                 split_lenghts);
             } else {
-                auto split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
-                                                                      torch_split->get_input_source_output(2),
-                                                                      torch_split->get_input_source_output(1));
-                copy_runtime_info({list_unpack, input_node}, split);
-                replace_node(list_unpack, split);
+                split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
+                                                                 torch_split->get_input_source_output(2),
+                                                                 torch_split->get_input_source_output(1));
             }
+            copy_runtime_info({list_unpack, input_node}, split);
+            split->set_friendly_name(input_node->get_friendly_name());
+            replace_node(list_unpack, split);
 
             return true;
         }
@@ -67,6 +67,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
                                                               split_with_sizes->get_input_source_output(1));
             copy_runtime_info({list_unpack, input_node}, split);
             split->set_friendly_name(input_node->get_friendly_name());
+            replace_node(list_unpack, split);
 
             return true;
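Hoisting copy_runtime_info, set_friendly_name and replace_node out of the two branches also fixes the split_with_sizes path above, which previously never called replace_node. For reference, the split-lengths trick in the rank-0 branch (tile the chunk size n-1 times, then append -1) matches torch.split semantics, where the last chunk may be smaller; in VariadicSplit, -1 means "whatever remains". A numpy sketch of the arithmetic (illustrative only):

    import numpy as np

    x = np.arange(10)
    k = 4
    n = -(-x.size // k)              # ceil(10 / 4) = 3 outputs
    lengths = [k] * (n - 1) + [-1]   # [4, 4, -1]; -1 takes the remainder
    splits = np.split(x, np.cumsum(lengths[:-1]))  # cut points [4, 8]
    print([s.size for s in splits])  # [4, 4, 2]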


@@ -66,7 +66,8 @@ std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context
     auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, output_type));
     Output<Node> rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, output_type));
     if (as_scalar) {
-        rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank));
+        auto axis_0 = context.mark_node(opset10::Constant::create(output_type, Shape{}, {0}));
+        rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank, axis_0));
     }
     return std::make_tuple(shape, rank);
 }
@@ -110,9 +111,8 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
     auto x = context.get_input(input_id);
     auto start = std::make_shared<opset10::Constant>(element::i32, Shape{}, 0);
     auto step = std::make_shared<opset10::Constant>(element::i32, Shape{}, 1);
-    auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, element::i32));
-    auto rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, element::i32));
-    auto reduced_rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank));
+    Output<Node> reduced_rank;
+    std::tie(std::ignore, reduced_rank) = get_shape_rank(context, x, true);
     return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
 };
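This mirrors the min/max replacer fix above: the rank tensor produced by ShapeOf(ShapeOf(x)) is now squeezed with an explicit axis 0, guaranteeing a scalar instead of relying on Squeeze's remove-all-size-1-dims behavior, and get_axes_range reuses get_shape_rank rather than duplicating the ShapeOf chain. The intent in numpy terms (illustrative sketch only):

    import numpy as np

    x = np.zeros((5, 7, 3))
    shape = np.array(x.shape, dtype=np.int32)     # ShapeOf(x)        -> [5 7 3]
    rank = np.array(shape.shape, dtype=np.int32)  # ShapeOf(ShapeOf)  -> [3]
    reduced_rank = np.squeeze(rank, axis=0)       # scalar 3, axis given explicitly
    axes = np.arange(0, reduced_rank, 1, dtype=np.int32)
    print(axes)                                   # [0 1 2] == get_axes_range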


@@ -11,7 +11,7 @@ def not_yet_supported(value):
     return pytest.param(
         value,
         marks = pytest.mark.xfail(
-            reason="Failed due to aten::sargsort not yet supporting stable sorting. Ticket 105242"
+            reason="Failed due to aten::argsort not yet supporting stable sorting. Ticket 105242"
         ),
     )


@@ -0,0 +1,40 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestIf(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(1, 3, 224, 224).astype(np.float32), self.y)
+
+    def create_model(self):
+        import torch
+        import torch.nn.functional as F
+
+        class prim_if(torch.nn.Module):
+            def __init__(self):
+                super(prim_if, self).__init__()
+
+            def forward(self, x, y):
+                if y > 0:
+                    res = x.new_empty((0, 10), dtype=torch.uint8)
+                else:
+                    res = torch.zeros(x.shape[:2], dtype=torch.bool)
+                return res.to(torch.bool)
+
+        ref_net = None
+
+        return prim_if(), ref_net, "prim::If"
+
+    @pytest.mark.parametrize("y", [np.array(1),
+                                   np.array(-1)
+                                   ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_if(self, y, ie_device, precision, ir_version):
+        self.y = y
+        self._test(*self.create_model(), ie_device, precision, ir_version)


@@ -47,7 +47,7 @@ class TestSplit(PytorchLayerTest):
         return aten_split(self.split_param, self.axis), ref_net, "aten::split"
 
-    # Test case - (split_param, axis), always split into 5 due to hardcoded number of outputs in ListUnpack test.
+    # Test case - (split_param, axis), always split into 5 due to hardcoded number of outputs in ListUnpack test.
     test_cases = [
         (2, 1),
         (45, 2),
@@ -64,7 +64,8 @@ class TestSplit(PytorchLayerTest):
     def test_split_getitem(self, params, getitem, ie_device, precision, ir_version):
         (self.split_param, self.axis) = params
         self.getitem = getitem
-        self._test(*self.create_model_split_getitem(), ie_device, precision, ir_version)
+        self._test(*self.create_model_split_getitem(),
+                   ie_device, precision, ir_version)
 
     @pytest.mark.parametrize("params", test_cases)
     @pytest.mark.nightly
@@ -74,3 +75,30 @@ class TestSplit(PytorchLayerTest):
         self._test(
             *self.create_model_split_listunpack(), ie_device, precision, ir_version
         )
+
+
+class TestSplitWithSizes(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(20).astype(np.float32), np.random.randn(20).astype(np.float32))
+
+    def create_model(self):
+        import torch
+
+        class aten_split_with_sizes(torch.nn.Module):
+            def __init__(self):
+                super(aten_split_with_sizes, self).__init__()
+                # self.sizes = 20
+
+            def forward(self, x, y):
+                return x.split([y.shape[0]], dim=0)
+
+        ref_net = None
+
+        return aten_split_with_sizes(), ref_net, ["aten::split_with_sizes", "prim::ListConstruct"]
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_split_with_sizes(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(),
+                   ie_device, precision, ir_version, trace_model=True)