Improve detectron2 support (#16011)
* Improve op support for detectron mask rcnn * Initial commit * Fix for reading processed list * Format code * Cleanup * cleanup * Cleanup * cleanup test * Add comment * Add rt_info * fix type * More fixes for detectron * Fix build * Add tests for if * Revert changes in index * Add comment * Fix test * Fix get_axes_range * Add tests and fix if type alignment * Fix code style --------- Co-authored-by: Mateusz <mateusz.mikolajczyk@intel.com>
This commit is contained in:
parent
52b27d82c5
commit
abaf61d059
@ -116,7 +116,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
|
||||
// Usually if nn.Module.forward is given as a source model for conversion, there is the first Parameter
|
||||
// that represents original `self` argument in forward(self, ...). `self` shouldn't play any role in model
|
||||
// inference if model is completelly frozed and all methods are inlined. So we check if it doesn't have any
|
||||
// inference if model is completely frozen and all methods are inlined. So we check if it doesn't have any
|
||||
// consumers in the finally converted model and remove this parameter. This parameter should have index 0.
|
||||
if (model->get_parameters().size() > 0) {
|
||||
auto self = model->get_parameters()[0];
|
||||
|
@ -176,7 +176,7 @@ OutputVector translate_empty(const NodeContext& context) {
|
||||
// side, so just skip these parameters
|
||||
num_inputs_check(context, 1, 6);
|
||||
auto sizes = context.get_input(0);
|
||||
// In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
|
||||
// In OV uninitialized data is not supported, so we create a tensor filled with zeros with a given shape and type.
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
int dtype_id = 1;
|
||||
Output<Node> empty;
|
||||
|
@ -13,6 +13,31 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
namespace {
|
||||
// TODO: Ticket 106627. This is a WA and will work only if both branches of if will eventually go to the operation that
|
||||
// will have same output type for both types
|
||||
void align_result_types(const NodeContext& context,
|
||||
std::shared_ptr<opset10::Result> r1,
|
||||
std::shared_ptr<opset10::Result> r2) {
|
||||
auto r1_tensor = r1->input_value(0);
|
||||
auto r2_tensor = r2->input_value(0);
|
||||
auto r1_type = r1_tensor.get_element_type();
|
||||
auto r2_type = r2_tensor.get_element_type();
|
||||
if (r1_type.is_dynamic() || r2_type.is_dynamic())
|
||||
return;
|
||||
element::Type merged_type;
|
||||
if (!element::Type::merge(merged_type, r1_type, r2_type)) {
|
||||
if (r1_type.bitwidth() >= r2_type.bitwidth()) {
|
||||
auto convert = std::make_shared<opset10::Convert>(r2_tensor, r1_type);
|
||||
r2->set_argument(0, convert);
|
||||
} else {
|
||||
auto convert = std::make_shared<opset10::Convert>(r1_tensor, r2_type);
|
||||
r1->set_argument(0, convert);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_if(const NodeContext& context) {
|
||||
auto if_node = std::make_shared<opset10::If>(context.get_input(0));
|
||||
context.mark_node(if_node);
|
||||
@ -62,6 +87,7 @@ OutputVector translate_if(const NodeContext& context) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(then_results.size() >= num_outs && else_results.size() >= num_outs,
|
||||
"Else or then body have less outputs than prim::If requires.");
|
||||
for (size_t i = 0; i < num_outs; i++) {
|
||||
align_result_types(context, then_results[i], else_results[i]);
|
||||
res.push_back(if_node->set_output(then_results[i], else_results[i]));
|
||||
}
|
||||
// Each body can have mutated outputs that are not included into pytorch node outputs.
|
||||
@ -136,6 +162,7 @@ OutputVector translate_if(const NodeContext& context) {
|
||||
}
|
||||
}
|
||||
for (const auto& output_idx : extra_output_idxs) {
|
||||
align_result_types(context, extra_then_body_results.at(output_idx), extra_else_body_results.at(output_idx));
|
||||
context.add_tensor_to_context(
|
||||
output_idx,
|
||||
if_node->set_output(extra_then_body_results.at(output_idx), extra_else_body_results.at(output_idx)));
|
||||
|
@ -5,11 +5,7 @@
|
||||
#include "openvino/op/select.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/less.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@ -21,22 +17,12 @@ namespace op {
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_select(const NodeContext& context) {
|
||||
// aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto dim = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(1), const_1, false));
|
||||
auto start = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(2), const_1, false));
|
||||
|
||||
auto less = context.mark_node(std::make_shared<v1::Less>(start, const_0));
|
||||
auto const_1_signed = context.mark_node(std::make_shared<v1::Select>(less, const_minus_1, const_1));
|
||||
auto stop = context.mark_node(std::make_shared<v1::Add>(start, const_1_signed));
|
||||
|
||||
auto slice_node = context.mark_node(std::make_shared<v8::Slice>(input_tensor, start, stop, const_1_signed, dim));
|
||||
|
||||
return {context.mark_node(std::make_shared<v0::Squeeze>(slice_node, dim))};
|
||||
auto data = context.get_input(0);
|
||||
auto dim = context.get_input(1);
|
||||
auto index = context.get_input(2);
|
||||
return {context.mark_node(std::make_shared<v8::Gather>(data, index, dim))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "openvino/op/tile.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/op/variadic_split.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
@ -49,6 +50,8 @@ ListConstructReplacer::ListConstructReplacer() {
|
||||
auto tile_op = pattern::wrap_type<v0::Tile>({pattern::any_input(), list});
|
||||
// replace aten::permute(tensor, prim::ListConstruct)
|
||||
auto transpose_op = pattern::wrap_type<v1::Transpose>({pattern::any_input(), list});
|
||||
// aten::split_with_sizes case
|
||||
auto vsplit_op = pattern::wrap_type<v1::VariadicSplit>({pattern::any_input(), pattern::any_input(), list});
|
||||
auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
|
||||
roll_op,
|
||||
broadcast_op,
|
||||
@ -57,7 +60,8 @@ ListConstructReplacer::ListConstructReplacer() {
|
||||
equal_op,
|
||||
select_op,
|
||||
tile_op,
|
||||
transpose_op});
|
||||
transpose_op,
|
||||
vsplit_op});
|
||||
|
||||
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto& pattern_map = m.get_pattern_value_map();
|
||||
|
@ -49,7 +49,8 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
|
||||
auto step = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 1);
|
||||
auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, element::i32);
|
||||
auto rank = std::make_shared<ov::op::v3::ShapeOf>(shape, element::i32);
|
||||
auto reduced_rank = std::make_shared<ov::op::v0::Squeeze>(rank);
|
||||
auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto reduced_rank = std::make_shared<ov::op::v0::Squeeze>(rank, axis_0);
|
||||
auto axes = std::make_shared<ov::op::v4::Range>(start, reduced_rank, step, element::i32);
|
||||
std::shared_ptr<Node> reduce_op;
|
||||
if (!is_min) {
|
||||
|
@ -33,6 +33,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
|
||||
if (rank.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
std::shared_ptr<Node> split;
|
||||
if (rank.get_length() == 0) {
|
||||
// Create split_lenghts tensor from split_size int,
|
||||
// allow for last chunk to be smaller if data is not equally divisible.
|
||||
@ -45,18 +46,17 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
|
||||
auto split_lenghts_m_1 = std::make_shared<opset10::Tile>(split_size, num_out_m_1);
|
||||
NodeVector concat_inputs{split_lenghts_m_1, const_neg_1};
|
||||
auto split_lenghts = std::make_shared<opset10::Concat>(concat_inputs, 0);
|
||||
auto split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
|
||||
split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
|
||||
torch_split->get_input_source_output(2),
|
||||
split_lenghts);
|
||||
copy_runtime_info({list_unpack, input_node}, split);
|
||||
replace_node(list_unpack, split);
|
||||
} else {
|
||||
auto split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
|
||||
split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
|
||||
torch_split->get_input_source_output(2),
|
||||
torch_split->get_input_source_output(1));
|
||||
copy_runtime_info({list_unpack, input_node}, split);
|
||||
replace_node(list_unpack, split);
|
||||
}
|
||||
copy_runtime_info({list_unpack, input_node}, split);
|
||||
split->set_friendly_name(input_node->get_friendly_name());
|
||||
replace_node(list_unpack, split);
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -67,6 +67,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
|
||||
split_with_sizes->get_input_source_output(1));
|
||||
|
||||
copy_runtime_info({list_unpack, input_node}, split);
|
||||
split->set_friendly_name(input_node->get_friendly_name());
|
||||
replace_node(list_unpack, split);
|
||||
|
||||
return true;
|
||||
|
@ -66,7 +66,8 @@ std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context
|
||||
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, output_type));
|
||||
Output<Node> rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, output_type));
|
||||
if (as_scalar) {
|
||||
rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank));
|
||||
auto axis_0 = context.mark_node(opset10::Constant::create(output_type, Shape{}, {0}));
|
||||
rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank, axis_0));
|
||||
}
|
||||
return std::make_tuple(shape, rank);
|
||||
}
|
||||
@ -110,9 +111,8 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
|
||||
auto x = context.get_input(input_id);
|
||||
auto start = std::make_shared<opset10::Constant>(element::i32, Shape{}, 0);
|
||||
auto step = std::make_shared<opset10::Constant>(element::i32, Shape{}, 1);
|
||||
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, element::i32));
|
||||
auto rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, element::i32));
|
||||
auto reduced_rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank));
|
||||
Output<Node> reduced_rank;
|
||||
std::tie(std::ignore, reduced_rank) = get_shape_rank(context, x, true);
|
||||
return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
|
||||
};
|
||||
|
||||
|
@ -11,7 +11,7 @@ def not_yet_supported(value):
|
||||
return pytest.param(
|
||||
value,
|
||||
marks = pytest.mark.xfail(
|
||||
reason="Failed due to aten::sargsort not yet supporting stable sorting. Ticket 105242"
|
||||
reason="Failed due to aten::argsort not yet supporting stable sorting. Ticket 105242"
|
||||
),
|
||||
)
|
||||
|
||||
|
40
tests/layer_tests/pytorch_tests/test_if.py
Normal file
40
tests/layer_tests/pytorch_tests/test_if.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestIf(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(1, 3, 224, 224).astype(np.float32), self.y)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class prim_if(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(prim_if, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
if y > 0:
|
||||
res = x.new_empty((0, 10), dtype=torch.uint8)
|
||||
else:
|
||||
res = torch.zeros(x.shape[:2], dtype=torch.bool)
|
||||
return res.to(torch.bool)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return prim_if(), ref_net, "prim::If"
|
||||
|
||||
@pytest.mark.parametrize("y", [np.array(1),
|
||||
np.array(-1)
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_if(self, y, ie_device, precision, ir_version):
|
||||
self.y = y
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
@ -64,7 +64,8 @@ class TestSplit(PytorchLayerTest):
|
||||
def test_split_getitem(self, params, getitem, ie_device, precision, ir_version):
|
||||
(self.split_param, self.axis) = params
|
||||
self.getitem = getitem
|
||||
self._test(*self.create_model_split_getitem(), ie_device, precision, ir_version)
|
||||
self._test(*self.create_model_split_getitem(),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.parametrize("params", test_cases)
|
||||
@pytest.mark.nightly
|
||||
@ -74,3 +75,30 @@ class TestSplit(PytorchLayerTest):
|
||||
self._test(
|
||||
*self.create_model_split_listunpack(), ie_device, precision, ir_version
|
||||
)
|
||||
|
||||
|
||||
class TestSplitWithSizes(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(20).astype(np.float32),np.random.randn(20).astype(np.float32))
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_split_with_sizes(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_split_with_sizes, self).__init__()
|
||||
#self.sizes = 20
|
||||
|
||||
def forward(self, x, y):
|
||||
return x.split([y.shape[0]], dim=0)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_split_with_sizes(), ref_net, ["aten::split_with_sizes", "prim::ListConstruct"]
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_relu(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(),
|
||||
ie_device, precision, ir_version, trace_model=True)
|
||||
|
Loading…
Reference in New Issue
Block a user