[PT FE] Recognize empty non-frozen lists (#21224)

* [PT FE] Recognize empty non-frozen lists

* Do not produce alias for aten::clone
This commit is contained in:
Maxim Vafin 2023-11-22 11:58:53 +01:00 committed by GitHub
parent 97381e0b63
commit 1a288f0e9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 32 deletions

View File

@ -372,7 +372,7 @@ class TorchScriptPythonDecoder (Decoder):
return False
def may_produce_alias(self, in_index: int, out_index: int) -> bool:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul"]:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul", "aten::clone"]:
# AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
return False
try:

View File

@ -4,6 +4,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/op/constant.hpp"
@ -149,58 +150,81 @@ std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
}
namespace {
std::shared_ptr<v0::Constant> get_constant_at_input(const NodeContext& ctx, size_t index) {
std::shared_ptr<v0::Constant> get_constant_at_input(const NodeContext& ctx, size_t index, bool allow_empty = true) {
FRONT_END_GENERAL_CHECK(!ctx.input_is_none(index), "Input with index: ", index, " is none.");
auto input_node = ctx.get_input_from_visible_context(index).get_node_shared_ptr();
auto input = std::dynamic_pointer_cast<v0::Constant>(input_node);
FRONT_END_GENERAL_CHECK(input, "Input with index ", index, " cannot be interpreted as Constant: ", input_node);
return input;
auto input_val = ctx.get_input_from_visible_context(index);
if (ctx.get_input_type(index).is<type::List>()) {
if (allow_empty && is_empty_list(input_val))
return {};
input_val = concat_list_construct(input_val);
}
OPENVINO_SUPPRESS_DEPRECATED_START
auto constant = get_constant_from_source(input_val);
OPENVINO_SUPPRESS_DEPRECATED_END
FRONT_END_GENERAL_CHECK(constant, "Input with index ", index, " cannot be interpreted as Constant: ", input_val);
return constant;
}
} // namespace
template <>
std::vector<int64_t> NodeContext::const_input<std::vector<int64_t>>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<int64_t>();
auto c = get_constant_at_input(*this, index);
if (c)
return c->cast_vector<int64_t>();
else
return {};
}
template <>
Strides NodeContext::const_input<Strides>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<Strides::value_type>();
auto c = get_constant_at_input(*this, index);
if (c)
return c->cast_vector<Strides::value_type>();
else
return {};
}
template <>
CoordinateDiff NodeContext::const_input<CoordinateDiff>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<CoordinateDiff::value_type>();
auto c = get_constant_at_input(*this, index);
if (c)
return c->cast_vector<CoordinateDiff::value_type>();
else
return {};
}
template <>
Shape NodeContext::const_input<Shape>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<Shape::value_type>();
auto c = get_constant_at_input(*this, index);
if (c)
return c->cast_vector<Shape::value_type>();
else
return {};
}
template <>
int32_t NodeContext::const_input<int32_t>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<int32_t>()[0];
return get_constant_at_input(*this, index, false)->cast_vector<int32_t>()[0];
}
template <>
int64_t NodeContext::const_input<int64_t>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<int64_t>()[0];
return get_constant_at_input(*this, index, false)->cast_vector<int64_t>()[0];
}
template <>
bool NodeContext::const_input<bool>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<bool>()[0];
return get_constant_at_input(*this, index, false)->cast_vector<bool>()[0];
}
template <>
double NodeContext::const_input<double>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<double>()[0];
return get_constant_at_input(*this, index, false)->cast_vector<double>()[0];
}
template <>
float NodeContext::const_input<float>(size_t index) const {
return get_constant_at_input(*this, index)->cast_vector<float>()[0];
return get_constant_at_input(*this, index, false)->cast_vector<float>()[0];
}
template <>
@ -233,13 +257,21 @@ Any NodeContext::get_values_from_const_input(int index) const {
"Input with index: ",
index,
" does not exist.");
if (input_is_none(index)) {
if (input_is_none(index))
return {};
auto input_val = get_input_from_visible_context(index);
if (auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_val.get_node_shared_ptr())) {
const auto& attrs = input->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
return {};
}
auto it = attrs.find("string_value");
if (it != attrs.end()) {
return it->second;
}
}
auto input_node = get_input_from_visible_context(index).get_node_shared_ptr();
if (auto constant = as_type_ptr<v0::Constant>(input_node)) {
auto constant = get_constant_at_input(*this, index);
if (constant) {
switch (constant->get_element_type()) {
case element::f32:
return get_constant_data<float>(constant);
@ -266,18 +298,8 @@ Any NodeContext::get_values_from_const_input(int index) const {
default:
FRONT_END_GENERAL_CHECK(false, "Input with index: ", index, " has unsupported type.");
}
} else if (auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_node)) {
const auto& attrs = input->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
return {};
}
auto it = attrs.find("string_value");
if (it != attrs.end()) {
return it->second;
}
}
FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_node);
FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_val);
return 0;
}

View File

@ -159,7 +159,7 @@ class TestPooling(PytorchLayerTest):
reason='Ticket - 122715')
def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
self._test(*self.create_model("avg_pool2d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
ie_device, precision, ir_version, trace_model=True, dynamic_shapes=False)
ie_device, precision, ir_version, trace_model=True, freeze_model=False, dynamic_shapes=False)
@pytest.mark.parametrize("params", d3_params)
@pytest.mark.parametrize("ceil_mode", [True, False])