[PT FE]: support transformation for case aten::size + aten::__getitem__ (#15368)
This commit is contained in:
parent
1dd84e2074
commit
d57862edee
@ -156,7 +156,9 @@ void ov::op::util::MultiSubGraphOp::validate_and_infer_type_body(
|
|||||||
|
|
||||||
auto body_parameter = params.at(input_description->m_body_parameter_index);
|
auto body_parameter = params.at(input_description->m_body_parameter_index);
|
||||||
auto input_partial_shape = input_value(index).get_partial_shape();
|
auto input_partial_shape = input_value(index).get_partial_shape();
|
||||||
|
auto dtype = input_value(index).get_element_type();
|
||||||
body_parameter->set_partial_shape(input_partial_shape);
|
body_parameter->set_partial_shape(input_partial_shape);
|
||||||
|
body_parameter->set_element_type(dtype);
|
||||||
}
|
}
|
||||||
body->validate_nodes_and_infer_types();
|
body->validate_nodes_and_infer_types();
|
||||||
}
|
}
|
||||||
|
@ -335,3 +335,40 @@ TEST(type_prop, if_scalar_and_1d_static_union) {
|
|||||||
auto sh = result0->get_output_partial_shape(0);
|
auto sh = result0->get_output_partial_shape(0);
|
||||||
EXPECT_EQ(sh, out_shape);
|
EXPECT_EQ(sh, out_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, if_element_type_dynamic) {
|
||||||
|
// That which we iterate over
|
||||||
|
auto X = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
|
||||||
|
auto Y = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
|
||||||
|
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||||
|
|
||||||
|
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||||
|
// Body parameters
|
||||||
|
auto Xt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||||
|
auto Yt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||||
|
auto Xe = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||||
|
auto Ye = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||||
|
// Body
|
||||||
|
auto then_op = std::make_shared<op::v1::Add>(Xt, Yt);
|
||||||
|
auto then_op_res = std::make_shared<op::Result>(then_op);
|
||||||
|
|
||||||
|
auto then_body = make_shared<ngraph::Function>(OutputVector{then_op_res}, ParameterVector{Xt, Yt});
|
||||||
|
|
||||||
|
auto else_op = std::make_shared<op::v1::Maximum>(Xe, Ye);
|
||||||
|
auto else_op_res = std::make_shared<op::Result>(else_op);
|
||||||
|
auto else_body = make_shared<ngraph::Function>(OutputVector{else_op_res}, ParameterVector{Xe, Ye});
|
||||||
|
auto if_op = make_shared<op::v8::If>(cond);
|
||||||
|
if_op->set_then_body(then_body);
|
||||||
|
if_op->set_else_body(else_body);
|
||||||
|
if_op->set_input(X, Xt, Xe);
|
||||||
|
if_op->set_input(Y, Yt, Ye);
|
||||||
|
auto res = if_op->set_output(then_op_res, else_op_res);
|
||||||
|
auto result0 = make_shared<op::Result>(res);
|
||||||
|
Shape out0_shape{32, 40, 10};
|
||||||
|
auto sh = result0->get_output_shape(0);
|
||||||
|
EXPECT_EQ(sh, out0_shape);
|
||||||
|
// Check that If validation validates both bodies
|
||||||
|
if_op->validate_and_infer_types();
|
||||||
|
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
|
||||||
|
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
|
||||||
|
}
|
||||||
|
29
src/frontends/pytorch/src/op/getitem.cpp
Normal file
29
src/frontends/pytorch/src/op/getitem.cpp
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||||
|
#include "openvino/op/constant.hpp"
|
||||||
|
#include "openvino/op/gather.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace pytorch {
|
||||||
|
namespace op {
|
||||||
|
|
||||||
|
OutputVector translate_getitem(NodeContext& context) {
|
||||||
|
auto input = context.get_input(0);
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct") == nullptr,
|
||||||
|
"unsupported case for aten::getitem");
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "aten::split") == nullptr,
|
||||||
|
"unsupported case for aten::getitem");
|
||||||
|
auto getitem_idx = context.get_input(1);
|
||||||
|
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
|
||||||
|
return {context.mark_node(std::make_shared<ov::op::v8::Gather>(input, getitem_idx, zero))};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pytorch
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -44,6 +44,7 @@ OP_CONVERTER(translate_full);
|
|||||||
OP_CONVERTER(translate_full_like);
|
OP_CONVERTER(translate_full_like);
|
||||||
OP_CONVERTER(translate_gelu);
|
OP_CONVERTER(translate_gelu);
|
||||||
OP_CONVERTER(translate_get_attr);
|
OP_CONVERTER(translate_get_attr);
|
||||||
|
OP_CONVERTER(translate_getitem);
|
||||||
OP_CONVERTER(translate_glu);
|
OP_CONVERTER(translate_glu);
|
||||||
OP_CONVERTER(translate_grid_sampler);
|
OP_CONVERTER(translate_grid_sampler);
|
||||||
OP_CONVERTER(translate_group_norm);
|
OP_CONVERTER(translate_group_norm);
|
||||||
@ -117,6 +118,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
return {
|
return {
|
||||||
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
|
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
|
||||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||||
|
{"aten::__getitem__", op::translate_getitem},
|
||||||
{"aten::_convolution", op::translate_convolution},
|
{"aten::_convolution", op::translate_convolution},
|
||||||
{"aten::_convolution_mode", op::translate_convolution_mode},
|
{"aten::_convolution_mode", op::translate_convolution_mode},
|
||||||
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
|
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
|
||||||
|
@ -95,6 +95,15 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
|
||||||
|
auto input_concat = concat_list_construct(list_construct);
|
||||||
|
auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
|
||||||
|
auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
|
||||||
|
auto gather = std::make_shared<opset10::Gather>(input_concat, getitem_idx, zero);
|
||||||
|
copy_runtime_info({getitem, input_node}, gather);
|
||||||
|
replace_node(getitem, gather);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
54
tests/layer_tests/pytorch_tests/test_getitem.py
Normal file
54
tests/layer_tests/pytorch_tests/test_getitem.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pytorch_layer_test_class import PytorchLayerTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetItem(PytorchLayerTest):
|
||||||
|
def _prepare_input(self, input_shape):
|
||||||
|
import numpy as np
|
||||||
|
return (np.random.randn(*input_shape).astype(np.float32),)
|
||||||
|
|
||||||
|
def create_model(self, idx, case="size_with_getitem"):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class aten_size_get_item(torch.nn.Module):
|
||||||
|
def __init__(self, idx):
|
||||||
|
super().__init__()
|
||||||
|
self.idx = idx
|
||||||
|
def forward(self, x):
|
||||||
|
return x.shape[self.idx]
|
||||||
|
|
||||||
|
class aten_size_get_item_with_if(torch.nn.Module):
|
||||||
|
def __init__(self, idx):
|
||||||
|
super().__init__()
|
||||||
|
self.idx:int = idx
|
||||||
|
def forward(self, x):
|
||||||
|
if x.shape[self.idx] > self.idx:
|
||||||
|
res = x.shape[self.idx]
|
||||||
|
else:
|
||||||
|
res = x.shape[-self.idx]
|
||||||
|
return res
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
op_cls = {
|
||||||
|
"getitem": (aten_size_get_item, ["aten::size", "aten::__getitem__"]),
|
||||||
|
"getitem_with_if": (aten_size_get_item_with_if, ["aten::size", "aten::__getitem__", "prim::If"])
|
||||||
|
}
|
||||||
|
op, op_in_graph = op_cls[case]
|
||||||
|
|
||||||
|
return op(idx), ref_net, op_in_graph
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
@pytest.mark.parametrize(("input_shape", "idx"), [
|
||||||
|
([1,], 0), ([1,], -1),
|
||||||
|
([1, 2], 0), ([1, 2], 1), ([1, 2], -1), ([1, 2], -2),
|
||||||
|
([1, 2, 3], 0), ([1, 2, 3], 1), ([1, 2, 3], 2), ([1, 2, 3], -1), ([1, 2, 3], -2), ([1, 2, 3], -3),
|
||||||
|
([1, 2, 3, 4], 0), ([1, 2, 3, 4], 1), ([1, 2, 3, 4], 2), ([1, 2, 3, 4], 3), ([1, 2, 3, 4], -1), ([1, 2, 3, 4], -2), ([1, 2, 3, 4], -3), ([1, 2, 3, 4], -4),
|
||||||
|
([1, 2, 3, 4, 5], 0), ([1, 2, 3, 4, 5], 1), ([1, 2, 3, 4, 5], 2), ([1, 2, 3, 4, 5], 3), ([1, 2, 3, 4, 5], 4), ([1, 2, 3, 4, 5], -1), ([1, 2, 3, 4, 5], -2), ([1, 2, 3, 4, 5], -3), ([1, 2, 3, 4, 5], -4), ([1, 2, 3, 4, 5], -5)])
|
||||||
|
@pytest.mark.parametrize("case", ["getitem", "getitem_with_if"])
|
||||||
|
def test_getitem(self, input_shape, idx, case, ie_device, precision, ir_version):
|
||||||
|
self._test(*self.create_model(idx, case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape})
|
30
tests/layer_tests/pytorch_tests/test_size.py
Normal file
30
tests/layer_tests/pytorch_tests/test_size.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pytorch_layer_test_class import PytorchLayerTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestSize(PytorchLayerTest):
|
||||||
|
def _prepare_input(self, input_shape):
|
||||||
|
import numpy as np
|
||||||
|
return (np.random.randn(*input_shape).astype(np.float32),)
|
||||||
|
|
||||||
|
def create_model(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class aten_size(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x.shape
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
op = aten_size()
|
||||||
|
|
||||||
|
return op, ref_net, "aten::size"
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
@pytest.mark.parametrize("input_shape", [[1,], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]])
|
||||||
|
def test_size(self, input_shape, ie_device, precision, ir_version):
|
||||||
|
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape})
|
Loading…
Reference in New Issue
Block a user