[PT FE]: support transformation for case aten::size + aten::__getitem__ (#15368)
This commit is contained in:
parent
1dd84e2074
commit
d57862edee
@ -156,7 +156,9 @@ void ov::op::util::MultiSubGraphOp::validate_and_infer_type_body(
|
||||
|
||||
auto body_parameter = params.at(input_description->m_body_parameter_index);
|
||||
auto input_partial_shape = input_value(index).get_partial_shape();
|
||||
auto dtype = input_value(index).get_element_type();
|
||||
body_parameter->set_partial_shape(input_partial_shape);
|
||||
body_parameter->set_element_type(dtype);
|
||||
}
|
||||
body->validate_nodes_and_infer_types();
|
||||
}
|
||||
|
@ -335,3 +335,40 @@ TEST(type_prop, if_scalar_and_1d_static_union) {
|
||||
auto sh = result0->get_output_partial_shape(0);
|
||||
EXPECT_EQ(sh, out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, if_element_type_dynamic) {
|
||||
// That which we iterate over
|
||||
auto X = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
|
||||
auto Y = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
|
||||
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto Xt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
auto Yt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
auto Xe = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
auto Ye = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
// Body
|
||||
auto then_op = std::make_shared<op::v1::Add>(Xt, Yt);
|
||||
auto then_op_res = std::make_shared<op::Result>(then_op);
|
||||
|
||||
auto then_body = make_shared<ngraph::Function>(OutputVector{then_op_res}, ParameterVector{Xt, Yt});
|
||||
|
||||
auto else_op = std::make_shared<op::v1::Maximum>(Xe, Ye);
|
||||
auto else_op_res = std::make_shared<op::Result>(else_op);
|
||||
auto else_body = make_shared<ngraph::Function>(OutputVector{else_op_res}, ParameterVector{Xe, Ye});
|
||||
auto if_op = make_shared<op::v8::If>(cond);
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
if_op->set_input(X, Xt, Xe);
|
||||
if_op->set_input(Y, Yt, Ye);
|
||||
auto res = if_op->set_output(then_op_res, else_op_res);
|
||||
auto result0 = make_shared<op::Result>(res);
|
||||
Shape out0_shape{32, 40, 10};
|
||||
auto sh = result0->get_output_shape(0);
|
||||
EXPECT_EQ(sh, out0_shape);
|
||||
// Check that If validation validates both bodies
|
||||
if_op->validate_and_infer_types();
|
||||
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
|
||||
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
|
||||
}
|
||||
|
29
src/frontends/pytorch/src/op/getitem.cpp
Normal file
29
src/frontends/pytorch/src/op/getitem.cpp
Normal file
@ -0,0 +1,29 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_getitem(NodeContext& context) {
|
||||
auto input = context.get_input(0);
|
||||
FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct") == nullptr,
|
||||
"unsupported case for aten::getitem");
|
||||
FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "aten::split") == nullptr,
|
||||
"unsupported case for aten::getitem");
|
||||
auto getitem_idx = context.get_input(1);
|
||||
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
return {context.mark_node(std::make_shared<ov::op::v8::Gather>(input, getitem_idx, zero))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -44,6 +44,7 @@ OP_CONVERTER(translate_full);
|
||||
OP_CONVERTER(translate_full_like);
|
||||
OP_CONVERTER(translate_gelu);
|
||||
OP_CONVERTER(translate_get_attr);
|
||||
OP_CONVERTER(translate_getitem);
|
||||
OP_CONVERTER(translate_glu);
|
||||
OP_CONVERTER(translate_grid_sampler);
|
||||
OP_CONVERTER(translate_group_norm);
|
||||
@ -117,6 +118,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
return {
|
||||
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
|
||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||
{"aten::__getitem__", op::translate_getitem},
|
||||
{"aten::_convolution", op::translate_convolution},
|
||||
{"aten::_convolution_mode", op::translate_convolution_mode},
|
||||
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
|
||||
|
@ -95,6 +95,15 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
|
||||
auto input_concat = concat_list_construct(list_construct);
|
||||
auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
|
||||
auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
|
||||
auto gather = std::make_shared<opset10::Gather>(input_concat, getitem_idx, zero);
|
||||
copy_runtime_info({getitem, input_node}, gather);
|
||||
replace_node(getitem, gather);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
54
tests/layer_tests/pytorch_tests/test_getitem.py
Normal file
54
tests/layer_tests/pytorch_tests/test_getitem.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestGetItem(PytorchLayerTest):
|
||||
def _prepare_input(self, input_shape):
|
||||
import numpy as np
|
||||
return (np.random.randn(*input_shape).astype(np.float32),)
|
||||
|
||||
def create_model(self, idx, case="size_with_getitem"):
|
||||
import torch
|
||||
|
||||
class aten_size_get_item(torch.nn.Module):
|
||||
def __init__(self, idx):
|
||||
super().__init__()
|
||||
self.idx = idx
|
||||
def forward(self, x):
|
||||
return x.shape[self.idx]
|
||||
|
||||
class aten_size_get_item_with_if(torch.nn.Module):
|
||||
def __init__(self, idx):
|
||||
super().__init__()
|
||||
self.idx:int = idx
|
||||
def forward(self, x):
|
||||
if x.shape[self.idx] > self.idx:
|
||||
res = x.shape[self.idx]
|
||||
else:
|
||||
res = x.shape[-self.idx]
|
||||
return res
|
||||
|
||||
ref_net = None
|
||||
op_cls = {
|
||||
"getitem": (aten_size_get_item, ["aten::size", "aten::__getitem__"]),
|
||||
"getitem_with_if": (aten_size_get_item_with_if, ["aten::size", "aten::__getitem__", "prim::If"])
|
||||
}
|
||||
op, op_in_graph = op_cls[case]
|
||||
|
||||
return op(idx), ref_net, op_in_graph
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize(("input_shape", "idx"), [
|
||||
([1,], 0), ([1,], -1),
|
||||
([1, 2], 0), ([1, 2], 1), ([1, 2], -1), ([1, 2], -2),
|
||||
([1, 2, 3], 0), ([1, 2, 3], 1), ([1, 2, 3], 2), ([1, 2, 3], -1), ([1, 2, 3], -2), ([1, 2, 3], -3),
|
||||
([1, 2, 3, 4], 0), ([1, 2, 3, 4], 1), ([1, 2, 3, 4], 2), ([1, 2, 3, 4], 3), ([1, 2, 3, 4], -1), ([1, 2, 3, 4], -2), ([1, 2, 3, 4], -3), ([1, 2, 3, 4], -4),
|
||||
([1, 2, 3, 4, 5], 0), ([1, 2, 3, 4, 5], 1), ([1, 2, 3, 4, 5], 2), ([1, 2, 3, 4, 5], 3), ([1, 2, 3, 4, 5], 4), ([1, 2, 3, 4, 5], -1), ([1, 2, 3, 4, 5], -2), ([1, 2, 3, 4, 5], -3), ([1, 2, 3, 4, 5], -4), ([1, 2, 3, 4, 5], -5)])
|
||||
@pytest.mark.parametrize("case", ["getitem", "getitem_with_if"])
|
||||
def test_getitem(self, input_shape, idx, case, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(idx, case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape})
|
30
tests/layer_tests/pytorch_tests/test_size.py
Normal file
30
tests/layer_tests/pytorch_tests/test_size.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestSize(PytorchLayerTest):
|
||||
def _prepare_input(self, input_shape):
|
||||
import numpy as np
|
||||
return (np.random.randn(*input_shape).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_size(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.shape
|
||||
|
||||
ref_net = None
|
||||
|
||||
op = aten_size()
|
||||
|
||||
return op, ref_net, "aten::size"
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("input_shape", [[1,], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]])
|
||||
def test_size(self, input_shape, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape})
|
Loading…
Reference in New Issue
Block a user