[PT FE]: support transformation for case aten::size + aten::__getitem__ (#15368)

This commit is contained in:
Ekaterina Aidova 2023-01-31 22:08:13 +04:00 committed by GitHub
parent 1dd84e2074
commit d57862edee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 163 additions and 0 deletions

View File

@ -156,7 +156,9 @@ void ov::op::util::MultiSubGraphOp::validate_and_infer_type_body(
auto body_parameter = params.at(input_description->m_body_parameter_index);
auto input_partial_shape = input_value(index).get_partial_shape();
auto dtype = input_value(index).get_element_type();
body_parameter->set_partial_shape(input_partial_shape);
body_parameter->set_element_type(dtype);
}
body->validate_nodes_and_infer_types();
}

View File

@ -335,3 +335,40 @@ TEST(type_prop, if_scalar_and_1d_static_union) {
auto sh = result0->get_output_partial_shape(0);
EXPECT_EQ(sh, out_shape);
}
TEST(type_prop, if_element_type_dynamic) {
// That which we iterate over
auto X = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
auto Y = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, true);
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
// Body parameters
auto Xt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto Yt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto Xe = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto Ye = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
// Body
auto then_op = std::make_shared<op::v1::Add>(Xt, Yt);
auto then_op_res = std::make_shared<op::Result>(then_op);
auto then_body = make_shared<ngraph::Function>(OutputVector{then_op_res}, ParameterVector{Xt, Yt});
auto else_op = std::make_shared<op::v1::Maximum>(Xe, Ye);
auto else_op_res = std::make_shared<op::Result>(else_op);
auto else_body = make_shared<ngraph::Function>(OutputVector{else_op_res}, ParameterVector{Xe, Ye});
auto if_op = make_shared<op::v8::If>(cond);
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
if_op->set_input(X, Xt, Xe);
if_op->set_input(Y, Yt, Ye);
auto res = if_op->set_output(then_op_res, else_op_res);
auto result0 = make_shared<op::Result>(res);
Shape out0_shape{32, 40, 10};
auto sh = result0->get_output_shape(0);
EXPECT_EQ(sh, out0_shape);
// Check that If validation validates both bodies
if_op->validate_and_infer_types();
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
}

View File

@ -0,0 +1,29 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_getitem(NodeContext& context) {
auto input = context.get_input(0);
FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct") == nullptr,
"unsupported case for aten::getitem");
FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "aten::split") == nullptr,
"unsupported case for aten::getitem");
auto getitem_idx = context.get_input(1);
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
return {context.mark_node(std::make_shared<ov::op::v8::Gather>(input, getitem_idx, zero))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -44,6 +44,7 @@ OP_CONVERTER(translate_full);
OP_CONVERTER(translate_full_like);
OP_CONVERTER(translate_gelu);
OP_CONVERTER(translate_get_attr);
OP_CONVERTER(translate_getitem);
OP_CONVERTER(translate_glu);
OP_CONVERTER(translate_grid_sampler);
OP_CONVERTER(translate_group_norm);
@ -117,6 +118,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
return {
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
{"aten::__getitem__", op::translate_getitem},
{"aten::_convolution", op::translate_convolution},
{"aten::_convolution_mode", op::translate_convolution_mode},
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},

View File

@ -95,6 +95,15 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
}
return true;
}
if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
auto input_concat = concat_list_construct(list_construct);
auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<opset10::Gather>(input_concat, getitem_idx, zero);
copy_runtime_info({getitem, input_node}, gather);
replace_node(getitem, gather);
return true;
}
return false;
};

View File

@ -0,0 +1,54 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestGetItem(PytorchLayerTest):
def _prepare_input(self, input_shape):
import numpy as np
return (np.random.randn(*input_shape).astype(np.float32),)
def create_model(self, idx, case="size_with_getitem"):
import torch
class aten_size_get_item(torch.nn.Module):
def __init__(self, idx):
super().__init__()
self.idx = idx
def forward(self, x):
return x.shape[self.idx]
class aten_size_get_item_with_if(torch.nn.Module):
def __init__(self, idx):
super().__init__()
self.idx:int = idx
def forward(self, x):
if x.shape[self.idx] > self.idx:
res = x.shape[self.idx]
else:
res = x.shape[-self.idx]
return res
ref_net = None
op_cls = {
"getitem": (aten_size_get_item, ["aten::size", "aten::__getitem__"]),
"getitem_with_if": (aten_size_get_item_with_if, ["aten::size", "aten::__getitem__", "prim::If"])
}
op, op_in_graph = op_cls[case]
return op(idx), ref_net, op_in_graph
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize(("input_shape", "idx"), [
([1,], 0), ([1,], -1),
([1, 2], 0), ([1, 2], 1), ([1, 2], -1), ([1, 2], -2),
([1, 2, 3], 0), ([1, 2, 3], 1), ([1, 2, 3], 2), ([1, 2, 3], -1), ([1, 2, 3], -2), ([1, 2, 3], -3),
([1, 2, 3, 4], 0), ([1, 2, 3, 4], 1), ([1, 2, 3, 4], 2), ([1, 2, 3, 4], 3), ([1, 2, 3, 4], -1), ([1, 2, 3, 4], -2), ([1, 2, 3, 4], -3), ([1, 2, 3, 4], -4),
([1, 2, 3, 4, 5], 0), ([1, 2, 3, 4, 5], 1), ([1, 2, 3, 4, 5], 2), ([1, 2, 3, 4, 5], 3), ([1, 2, 3, 4, 5], 4), ([1, 2, 3, 4, 5], -1), ([1, 2, 3, 4, 5], -2), ([1, 2, 3, 4, 5], -3), ([1, 2, 3, 4, 5], -4), ([1, 2, 3, 4, 5], -5)])
@pytest.mark.parametrize("case", ["getitem", "getitem_with_if"])
def test_getitem(self, input_shape, idx, case, ie_device, precision, ir_version):
self._test(*self.create_model(idx, case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape})

View File

@ -0,0 +1,30 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestSize(PytorchLayerTest):
def _prepare_input(self, input_shape):
import numpy as np
return (np.random.randn(*input_shape).astype(np.float32),)
def create_model(self):
import torch
class aten_size(torch.nn.Module):
def forward(self, x):
return x.shape
ref_net = None
op = aten_size()
return op, ref_net, "aten::size"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("input_shape", [[1,], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]])
def test_size(self, input_shape, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape})