From 1a288f0e9acfcd7e84eb989b9240acfaa4ef1a6f Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Wed, 22 Nov 2023 11:58:53 +0100 Subject: [PATCH] [PT FE] Recognize empty non-frozen lists (#21224) * [PT FE] Recognize empty non-frozen lists * Do not produce alias for aten::clone --- .../openvino/frontend/pytorch/ts_decoder.py | 2 +- src/frontends/pytorch/src/node_context.cpp | 82 ++++++++++++------- .../layer_tests/pytorch_tests/test_pooling.py | 2 +- 3 files changed, 54 insertions(+), 32 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index 04259234298..5723d06d395 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -372,7 +372,7 @@ class TorchScriptPythonDecoder (Decoder): return False def may_produce_alias(self, in_index: int, out_index: int) -> bool: - if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul"]: + if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul", "aten::clone"]: # AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that return False try: diff --git a/src/frontends/pytorch/src/node_context.cpp b/src/frontends/pytorch/src/node_context.cpp index 4c6bf2e9e50..db2d520c98c 100644 --- a/src/frontends/pytorch/src/node_context.cpp +++ b/src/frontends/pytorch/src/node_context.cpp @@ -4,6 +4,7 @@ #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/core/validation_util.hpp" #include "openvino/frontend/exception.hpp" #include "openvino/frontend/pytorch/decoder.hpp" #include "openvino/op/constant.hpp" @@ -149,58 +150,81 @@ std::shared_ptr NodeContext::convert_subgraph(size_t index) const { } namespace { -std::shared_ptr get_constant_at_input(const NodeContext& ctx, size_t index) { +std::shared_ptr get_constant_at_input(const NodeContext& ctx, size_t index, bool allow_empty = true) { FRONT_END_GENERAL_CHECK(!ctx.input_is_none(index), "Input with index: ", index, " is none."); - auto input_node = ctx.get_input_from_visible_context(index).get_node_shared_ptr(); - auto input = std::dynamic_pointer_cast(input_node); - FRONT_END_GENERAL_CHECK(input, "Input with index ", index, " cannot be interpreted as Constant: ", input_node); - return input; + auto input_val = ctx.get_input_from_visible_context(index); + if (ctx.get_input_type(index).is()) { + if (allow_empty && is_empty_list(input_val)) + return {}; + input_val = concat_list_construct(input_val); + } + OPENVINO_SUPPRESS_DEPRECATED_START + auto constant = get_constant_from_source(input_val); + OPENVINO_SUPPRESS_DEPRECATED_END + FRONT_END_GENERAL_CHECK(constant, "Input with index ", index, " cannot be interpreted as Constant: ", input_val); + return constant; } } // namespace template <> std::vector NodeContext::const_input>(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector(); + auto c = get_constant_at_input(*this, index); + if (c) + return c->cast_vector(); + else + return {}; } template <> Strides NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector(); + auto c = get_constant_at_input(*this, index); + if (c) + return c->cast_vector(); + else + return {}; } template <> CoordinateDiff NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector(); + auto c = get_constant_at_input(*this, index); + if (c) + return c->cast_vector(); + else + return {}; } template <> Shape NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector(); + auto c = get_constant_at_input(*this, index); + if (c) + return c->cast_vector(); + else + return {}; } template <> int32_t NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector()[0]; + return get_constant_at_input(*this, index, false)->cast_vector()[0]; } template <> int64_t NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector()[0]; + return get_constant_at_input(*this, index, false)->cast_vector()[0]; } template <> bool NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector()[0]; + return get_constant_at_input(*this, index, false)->cast_vector()[0]; } template <> double NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector()[0]; + return get_constant_at_input(*this, index, false)->cast_vector()[0]; } template <> float NodeContext::const_input(size_t index) const { - return get_constant_at_input(*this, index)->cast_vector()[0]; + return get_constant_at_input(*this, index, false)->cast_vector()[0]; } template <> @@ -233,13 +257,21 @@ Any NodeContext::get_values_from_const_input(int index) const { "Input with index: ", index, " does not exist."); - - if (input_is_none(index)) { + if (input_is_none(index)) return {}; + auto input_val = get_input_from_visible_context(index); + if (auto input = std::dynamic_pointer_cast(input_val.get_node_shared_ptr())) { + const auto& attrs = input->get_attrs(); + if (attrs.find("none_value") != attrs.end()) { + return {}; + } + auto it = attrs.find("string_value"); + if (it != attrs.end()) { + return it->second; + } } - - auto input_node = get_input_from_visible_context(index).get_node_shared_ptr(); - if (auto constant = as_type_ptr(input_node)) { + auto constant = get_constant_at_input(*this, index); + if (constant) { switch (constant->get_element_type()) { case element::f32: return get_constant_data(constant); @@ -266,18 +298,8 @@ Any NodeContext::get_values_from_const_input(int index) const { default: FRONT_END_GENERAL_CHECK(false, "Input with index: ", index, " has unsupported type."); } - } else if (auto input = std::dynamic_pointer_cast(input_node)) { - const auto& attrs = input->get_attrs(); - if (attrs.find("none_value") != attrs.end()) { - return {}; - } - auto it = attrs.find("string_value"); - if (it != attrs.end()) { - return it->second; - } } - - FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_node); + FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_val); return 0; } diff --git a/tests/layer_tests/pytorch_tests/test_pooling.py b/tests/layer_tests/pytorch_tests/test_pooling.py index 14e606e5771..ac069567d69 100644 --- a/tests/layer_tests/pytorch_tests/test_pooling.py +++ b/tests/layer_tests/pytorch_tests/test_pooling.py @@ -159,7 +159,7 @@ class TestPooling(PytorchLayerTest): reason='Ticket - 122715') def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version): self._test(*self.create_model("avg_pool2d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad), - ie_device, precision, ir_version, trace_model=True, dynamic_shapes=False) + ie_device, precision, ir_version, trace_model=True, freeze_model=False, dynamic_shapes=False) @pytest.mark.parametrize("params", d3_params) @pytest.mark.parametrize("ceil_mode", [True, False])