diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index a50d41e333c..0b56350fc51 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -14,6 +14,7 @@ #include "transforms/append_list_unpack_replacer.hpp" #include "transforms/aten_cat_replacer.hpp" #include "transforms/aten_getitem_replacer.hpp" +#include "transforms/listconstruct_reshape_replacer.hpp" #include "transforms/max_prim_list_construct_replacer.hpp" #include "transforms/prim_list_construct_pad.hpp" #include "transforms/prim_list_unpack_replacer.hpp" @@ -88,6 +89,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/pytorch/src/op/int.cpp b/src/frontends/pytorch/src/op/int.cpp index f49bc30bf76..6c875c7afa3 100644 --- a/src/frontends/pytorch/src/op/int.cpp +++ b/src/frontends/pytorch/src/op/int.cpp @@ -12,10 +12,10 @@ namespace pytorch { namespace op { OutputVector translate_int(NodeContext& context) { - return {context.mark_node(std::make_shared(context.get_input(0), element::i64))}; + return {context.mark_node(std::make_shared(context.get_input(0), element::i32))}; }; } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/frontends/pytorch/src/op/reshape.cpp b/src/frontends/pytorch/src/op/reshape.cpp index 2da4cba5f19..01d2b1d348a 100644 --- a/src/frontends/pytorch/src/op/reshape.cpp +++ b/src/frontends/pytorch/src/op/reshape.cpp @@ -13,25 +13,12 @@ namespace pytorch { namespace op { OutputVector translate_reshape(NodeContext& context) { - auto shape_node = context.get_input(1).get_node(); - auto shape_node_fw_node = dynamic_cast(shape_node); - std::shared_ptr reshape; - // TODO: move this to transform stage - 
if (shape_node_fw_node && shape_node_fw_node->get_decoder()->get_op_type() == "prim::ListConstruct") { - OutputVector inputs; - auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0})); - for (auto& input : shape_node->inputs()) { - auto rank = input.get_partial_shape().rank(); - FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0"); - auto unsqueeze = context.mark_node(std::make_shared(input.get_source_output(), axis_0)); - inputs.push_back(unsqueeze); - } - auto concat = context.mark_node(std::make_shared(inputs, 0)); - reshape = context.mark_node(std::make_shared(context.get_input(0), concat, false)); - } else { - reshape = - context.mark_node(std::make_shared(context.get_input(0), context.get_input(1), false)); - } + // Translation is used by both aten::view and aten::reshape. + // Schema: aten::view(Tensor input, int[] shape) -> Tensor + // Schema: aten::reshape(Tensor input, int[] shape) -> Tensor + // For shape parameter, int[] is converted into single dimensional Tensor. 
+ auto reshape = + context.mark_node(std::make_shared(context.get_input(0), context.get_input(1), false)); return {reshape}; }; diff --git a/src/frontends/pytorch/src/op/view.cpp b/src/frontends/pytorch/src/op/view.cpp deleted file mode 100644 index 46a01d229a6..00000000000 --- a/src/frontends/pytorch/src/op/view.cpp +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/opsets/opset10.hpp" -#include "pt_framework_node.hpp" -#include "utils.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace op { - -OutputVector translate_view(NodeContext& context) { - auto shape_node = context.get_input(1).get_node(); - auto shape_node_fw_node = dynamic_cast(shape_node); - std::shared_ptr reshape; - // TODO: move this to transform stage - if (shape_node_fw_node && shape_node_fw_node->get_decoder()->get_op_type() == "prim::ListConstruct") { - OutputVector inputs; - auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0})); - for (auto& input : shape_node->inputs()) { - auto rank = input.get_partial_shape().rank(); - FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0"); - auto unsqueeze = context.mark_node(std::make_shared(input.get_source_output(), axis_0)); - inputs.push_back(unsqueeze); - } - auto concat = context.mark_node(std::make_shared(inputs, 0)); - reshape = context.mark_node(std::make_shared(context.get_input(0), concat, false)); - // TODO: fix rt_info - // auto list_set = shape_node_fw_node->get_rt_info()["pt_node"].as>(); - // reshape->get_rt_info()["pt_node"].as>().insert(list_set.begin(), - // list_set.end()); - } else { - reshape = - context.mark_node(std::make_shared(context.get_input(0), context.get_input(1), false)); - } - return {reshape}; -}; - -} // namespace op -} // namespace pytorch -} // namespace frontend -} // namespace ov 
\ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 4448041e1a0..b4b48668bc1 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -108,7 +108,6 @@ OP_CONVERTER(translate_upsample_bilinear2d); OP_CONVERTER(translate_upsample_nearest2d); OP_CONVERTER(translate_var); OP_CONVERTER(translate_var_mean); -OP_CONVERTER(translate_view); OP_CONVERTER(translate_where); OP_CONVERTER(translate_zeros); OP_CONVERTER(translate_zeros_like); @@ -296,7 +295,7 @@ const std::map get_supported_ops() { {"aten::upsample_nearest2d", op::translate_upsample_nearest2d}, {"aten::var", op::translate_var}, {"aten::var_mean", op::translate_var_mean}, - {"aten::view", op::translate_view}, + {"aten::view", op::translate_reshape}, {"aten::where", op::translate_where}, {"aten::zeros", op::translate_zeros}, {"aten::zeros_like", op::translate_zeros_like}, diff --git a/src/frontends/pytorch/src/transforms/listconstruct_reshape_replacer.cpp b/src/frontends/pytorch/src/transforms/listconstruct_reshape_replacer.cpp new file mode 100644 index 00000000000..5d5cff9800b --- /dev/null +++ b/src/frontends/pytorch/src/transforms/listconstruct_reshape_replacer.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "listconstruct_reshape_replacer.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +ListConstructReshapeReplacer::ListConstructReshapeReplacer() { + // Transformation for torch operators aten::view and aten::reshape for cases where second input is + // prim::ListConstruct. 
+    auto reshape_op = ov::pass::pattern::wrap_type<opset10::Reshape>();
+    // Both aten::view and aten::reshape use the same translation, returning an opset10::Reshape operator.
+    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
+        auto reshape_op = std::dynamic_pointer_cast<opset10::Reshape>(m.get_match_root());
+        if (!reshape_op) {
+            return false;
+        }
+        auto shape_node = reshape_op->input_value(1).get_node_shared_ptr();
+        if (auto list_unpack_node = cast_fw_node(shape_node, "prim::ListConstruct")) {
+            // Check if second input to operator is prim::ListConstruct, and if so, concatenate its inputs into a
+            // single Tensor. Concatenation is possible because all elements in the list should be scalar integers.
+            OutputVector inputs;
+            auto axis_0 = opset10::Constant::create(element::i64, Shape{}, {0});
+            for (auto& input : shape_node->inputs()) {
+                auto rank = input.get_partial_shape().rank();
+                FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
+                auto unsqueeze = std::make_shared<opset10::Unsqueeze>(input.get_source_output(), axis_0);
+                inputs.push_back(unsqueeze);
+            }
+            auto concat = std::make_shared<opset10::Concat>(inputs, 0);
+            copy_runtime_info({shape_node}, concat);
+            replace_node(shape_node, concat);
+            return true;
+        };
+        return false;
+    };
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(reshape_op,
+                                                          "ov::frontend::pytorch::pass::ListConstructReshapeReplacer");
+    this->register_matcher(m, callback);
+};
+
+} // namespace pass
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/pytorch/src/transforms/listconstruct_reshape_replacer.hpp b/src/frontends/pytorch/src/transforms/listconstruct_reshape_replacer.hpp
new file mode 100644
index 00000000000..86219287c83
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/listconstruct_reshape_replacer.hpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +class ListConstructReshapeReplacer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::ListConstructReshapeReplacer"); + ListConstructReshapeReplacer(); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/tests/layer_tests/pytorch_tests/test_view.py b/tests/layer_tests/pytorch_tests/test_view.py index 492c0d1d25b..2d2c80633f7 100644 --- a/tests/layer_tests/pytorch_tests/test_view.py +++ b/tests/layer_tests/pytorch_tests/test_view.py @@ -31,9 +31,54 @@ class TestViewListConstruct(PytorchLayerTest): self.input_data = input_data self._test(*self.create_model(), ie_device, precision, ir_version) +@pytest.mark.parametrize('input_data', [(np.random.randn(4), np.array(2))]) +class TestViewDtype(PytorchLayerTest): + + def _prepare_input(self): + return self.input_data + + def create_model(self): + class aten_view_dtype(torch.nn.Module): + + def forward(self, input_tensor, dtype): + return input_tensor.view(torch.int64) + + ref_net = None + + return aten_view_dtype(), ref_net, "aten::view" + + @pytest.mark.nightly + @pytest.mark.precommit + def test_view_dtype(self, ie_device, precision, ir_version, input_data): + self.input_data = input_data + self._test(*self.create_model(), ie_device, precision, ir_version) + + +@pytest.mark.parametrize('input_data', [(np.random.randn(4), np.random.randn(2, 2))]) +class TestViewSize(PytorchLayerTest): + + def _prepare_input(self): + return self.input_data + + def create_model(self): + class aten_view_size(torch.nn.Module): + + def forward(self, input_tensor, input_size): + return input_tensor.view(input_size.size()[:]) + + ref_net = None + + return aten_view_size(), ref_net, "aten::view" + + @pytest.mark.nightly + @pytest.mark.precommit + def test_view_size(self, ie_device, precision, ir_version, input_data): + self.input_data = input_data + 
self._test(*self.create_model(), ie_device, precision, ir_version) @pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 2), 2, 6), - (np.random.randn(4), 2, 2)]) + (np.random.randn(4), 2, 2), + (np.random.randn(4), 2, 2.1)]) class TestView(PytorchLayerTest): def _prepare_input(self): @@ -48,7 +93,7 @@ class TestView(PytorchLayerTest): self.dim2 = input_data[2] def forward(self, input_tensor): - return input_tensor.view(self.dim1, self.dim2) + return input_tensor.view(self.dim1, int(self.dim2)) ref_net = None