diff --git a/src/core/src/op/util/multi_subgraph_base.cpp b/src/core/src/op/util/multi_subgraph_base.cpp index 7d36b372bf4..fa7da5aadae 100644 --- a/src/core/src/op/util/multi_subgraph_base.cpp +++ b/src/core/src/op/util/multi_subgraph_base.cpp @@ -156,7 +156,9 @@ void ov::op::util::MultiSubGraphOp::validate_and_infer_type_body( auto body_parameter = params.at(input_description->m_body_parameter_index); auto input_partial_shape = input_value(index).get_partial_shape(); + auto dtype = input_value(index).get_element_type(); body_parameter->set_partial_shape(input_partial_shape); + body_parameter->set_element_type(dtype); } body->validate_nodes_and_infer_types(); } diff --git a/src/core/tests/type_prop/if.cpp b/src/core/tests/type_prop/if.cpp index 1355fa43539..97a45050972 100644 --- a/src/core/tests/type_prop/if.cpp +++ b/src/core/tests/type_prop/if.cpp @@ -335,3 +335,40 @@ TEST(type_prop, if_scalar_and_1d_static_union) { auto sh = result0->get_output_partial_shape(0); EXPECT_EQ(sh, out_shape); } + +TEST(type_prop, if_element_type_dynamic) { + // That which we iterate over + auto X = make_shared(element::f16, Shape{32, 40, 10}); + auto Y = make_shared(element::f16, Shape{32, 40, 10}); + auto cond = std::make_shared(ngraph::element::boolean, ngraph::Shape{1}, true); + + // Set up the cell body, a function from (Xi, Yi) -> (Zo) + // Body parameters + auto Xt = make_shared(element::dynamic, PartialShape::dynamic()); + auto Yt = make_shared(element::dynamic, PartialShape::dynamic()); + auto Xe = make_shared(element::dynamic, PartialShape::dynamic()); + auto Ye = make_shared(element::dynamic, PartialShape::dynamic()); + // Body + auto then_op = std::make_shared(Xt, Yt); + auto then_op_res = std::make_shared(then_op); + + auto then_body = make_shared(OutputVector{then_op_res}, ParameterVector{Xt, Yt}); + + auto else_op = std::make_shared(Xe, Ye); + auto else_op_res = std::make_shared(else_op); + auto else_body = make_shared(OutputVector{else_op_res}, ParameterVector{Xe, Ye}); + auto if_op = make_shared(cond); + if_op->set_then_body(then_body); + if_op->set_else_body(else_body); + if_op->set_input(X, Xt, Xe); + if_op->set_input(Y, Yt, Ye); + auto res = if_op->set_output(then_op_res, else_op_res); + auto result0 = make_shared(res); + Shape out0_shape{32, 40, 10}; + auto sh = result0->get_output_shape(0); + EXPECT_EQ(sh, out0_shape); + // Check that If validation validates both bodies + if_op->validate_and_infer_types(); + EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16); + EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16); +} diff --git a/src/frontends/pytorch/src/op/getitem.cpp b/src/frontends/pytorch/src/op/getitem.cpp new file mode 100644 index 00000000000..61c2b4e2540 --- /dev/null +++ b/src/frontends/pytorch/src/op/getitem.cpp @@ -0,0 +1,29 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/gather.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_getitem(NodeContext& context) { + auto input = context.get_input(0); + FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct") == nullptr, + "unsupported case for aten::getitem"); + FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "aten::split") == nullptr, + "unsupported case for aten::getitem"); + auto getitem_idx = context.get_input(1); + auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0})); + return {context.mark_node(std::make_shared(input, getitem_idx, zero))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index cd20cfc494e..b06361986ff 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -44,6 +44,7 @@ OP_CONVERTER(translate_full); OP_CONVERTER(translate_full_like); OP_CONVERTER(translate_gelu); OP_CONVERTER(translate_get_attr); +OP_CONVERTER(translate_getitem); OP_CONVERTER(translate_glu); OP_CONVERTER(translate_grid_sampler); OP_CONVERTER(translate_group_norm); @@ -117,6 +118,7 @@ const std::map get_supported_ops() { return { {"aten::__and__", op::translate_1to1_match_2_inputs}, // TODO: cover numerical cases {"aten::__not__", op::translate_1to1_match_1_inputs}, + {"aten::__getitem__", op::translate_getitem}, {"aten::_convolution", op::translate_convolution}, {"aten::_convolution_mode", op::translate_convolution_mode}, {"aten::abs", op::translate_1to1_match_1_inputs}, diff --git a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp index 2fdb0679f2e..84032d02fa3 100644 --- a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp @@ -95,6 +95,15 @@ AtenGetItemReplacer::AtenGetItemReplacer() { } return true; } + if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) { + auto input_concat = concat_list_construct(list_construct); + auto getitem_idx = getitem->input_value(1).get_node_shared_ptr(); + auto zero = opset10::Constant::create(element::i32, Shape{}, {0}); + auto gather = std::make_shared(input_concat, getitem_idx, zero); + copy_runtime_info({getitem, input_node}, gather); + replace_node(getitem, gather); + return true; + } return false; }; diff --git a/tests/layer_tests/pytorch_tests/test_getitem.py b/tests/layer_tests/pytorch_tests/test_getitem.py new file mode 100644 index 00000000000..4d1ba5bf80f --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_getitem.py @@ -0,0 +1,54 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestGetItem(PytorchLayerTest): + def _prepare_input(self, input_shape): + import numpy as np + return (np.random.randn(*input_shape).astype(np.float32),) + + def create_model(self, idx, case="size_with_getitem"): + import torch + + class aten_size_get_item(torch.nn.Module): + def __init__(self, idx): + super().__init__() + self.idx = idx + def forward(self, x): + return x.shape[self.idx] + + class aten_size_get_item_with_if(torch.nn.Module): + def __init__(self, idx): + super().__init__() + self.idx:int = idx + def forward(self, x): + if x.shape[self.idx] > self.idx: + res = x.shape[self.idx] + else: + res = x.shape[-self.idx] + return res + + ref_net = None + op_cls = { + "getitem": (aten_size_get_item, ["aten::size", "aten::__getitem__"]), + "getitem_with_if": (aten_size_get_item_with_if, ["aten::size", "aten::__getitem__", "prim::If"]) + } + op, op_in_graph = op_cls[case] + + return op(idx), ref_net, op_in_graph + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize(("input_shape", "idx"), [ + ([1,], 0), ([1,], -1), + ([1, 2], 0), ([1, 2], 1), ([1, 2], -1), ([1, 2], -2), + ([1, 2, 3], 0), ([1, 2, 3], 1), ([1, 2, 3], 2), ([1, 2, 3], -1), ([1, 2, 3], -2), ([1, 2, 3], -3), + ([1, 2, 3, 4], 0), ([1, 2, 3, 4], 1), ([1, 2, 3, 4], 2), ([1, 2, 3, 4], 3), ([1, 2, 3, 4], -1), ([1, 2, 3, 4], -2), ([1, 2, 3, 4], -3), ([1, 2, 3, 4], -4), + ([1, 2, 3, 4, 5], 0), ([1, 2, 3, 4, 5], 1), ([1, 2, 3, 4, 5], 2), ([1, 2, 3, 4, 5], 3), ([1, 2, 3, 4, 5], 4), ([1, 2, 3, 4, 5], -1), ([1, 2, 3, 4, 5], -2), ([1, 2, 3, 4, 5], -3), ([1, 2, 3, 4, 5], -4), ([1, 2, 3, 4, 5], -5)]) + @pytest.mark.parametrize("case", ["getitem", "getitem_with_if"]) + def test_getitem(self, input_shape, idx, case, ie_device, precision, ir_version): + self._test(*self.create_model(idx, case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape}) \ No newline at end of file diff --git a/tests/layer_tests/pytorch_tests/test_size.py b/tests/layer_tests/pytorch_tests/test_size.py new file mode 100644 index 00000000000..b861d0ee4d0 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_size.py @@ -0,0 +1,30 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestSize(PytorchLayerTest): + def _prepare_input(self, input_shape): + import numpy as np + return (np.random.randn(*input_shape).astype(np.float32),) + + def create_model(self): + import torch + + class aten_size(torch.nn.Module): + def forward(self, x): + return x.shape + + ref_net = None + + op = aten_size() + + return op, ref_net, "aten::size" + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_shape", [[1,], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]) + def test_size(self, input_shape, ie_device, precision, ir_version): + self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape}) \ No newline at end of file