[PT FE] Recognize empty non-frozen lists (#21224)
* [PT FE] Recognize empty non-frozen lists * Do not produce alias for aten::clone
This commit is contained in:
parent
97381e0b63
commit
1a288f0e9a
@ -372,7 +372,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
return False
|
||||
|
||||
def may_produce_alias(self, in_index: int, out_index: int) -> bool:
|
||||
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul"]:
|
||||
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul", "aten::clone"]:
|
||||
# AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
|
||||
return False
|
||||
try:
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/frontend/exception.hpp"
|
||||
#include "openvino/frontend/pytorch/decoder.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
@ -149,58 +150,81 @@ std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::shared_ptr<v0::Constant> get_constant_at_input(const NodeContext& ctx, size_t index) {
|
||||
std::shared_ptr<v0::Constant> get_constant_at_input(const NodeContext& ctx, size_t index, bool allow_empty = true) {
|
||||
FRONT_END_GENERAL_CHECK(!ctx.input_is_none(index), "Input with index: ", index, " is none.");
|
||||
auto input_node = ctx.get_input_from_visible_context(index).get_node_shared_ptr();
|
||||
auto input = std::dynamic_pointer_cast<v0::Constant>(input_node);
|
||||
FRONT_END_GENERAL_CHECK(input, "Input with index ", index, " cannot be interpreted as Constant: ", input_node);
|
||||
return input;
|
||||
auto input_val = ctx.get_input_from_visible_context(index);
|
||||
if (ctx.get_input_type(index).is<type::List>()) {
|
||||
if (allow_empty && is_empty_list(input_val))
|
||||
return {};
|
||||
input_val = concat_list_construct(input_val);
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
auto constant = get_constant_from_source(input_val);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
FRONT_END_GENERAL_CHECK(constant, "Input with index ", index, " cannot be interpreted as Constant: ", input_val);
|
||||
return constant;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <>
|
||||
std::vector<int64_t> NodeContext::const_input<std::vector<int64_t>>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<int64_t>();
|
||||
auto c = get_constant_at_input(*this, index);
|
||||
if (c)
|
||||
return c->cast_vector<int64_t>();
|
||||
else
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
Strides NodeContext::const_input<Strides>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<Strides::value_type>();
|
||||
auto c = get_constant_at_input(*this, index);
|
||||
if (c)
|
||||
return c->cast_vector<Strides::value_type>();
|
||||
else
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
CoordinateDiff NodeContext::const_input<CoordinateDiff>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<CoordinateDiff::value_type>();
|
||||
auto c = get_constant_at_input(*this, index);
|
||||
if (c)
|
||||
return c->cast_vector<CoordinateDiff::value_type>();
|
||||
else
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
Shape NodeContext::const_input<Shape>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<Shape::value_type>();
|
||||
auto c = get_constant_at_input(*this, index);
|
||||
if (c)
|
||||
return c->cast_vector<Shape::value_type>();
|
||||
else
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
int32_t NodeContext::const_input<int32_t>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<int32_t>()[0];
|
||||
return get_constant_at_input(*this, index, false)->cast_vector<int32_t>()[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
int64_t NodeContext::const_input<int64_t>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<int64_t>()[0];
|
||||
return get_constant_at_input(*this, index, false)->cast_vector<int64_t>()[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
bool NodeContext::const_input<bool>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<bool>()[0];
|
||||
return get_constant_at_input(*this, index, false)->cast_vector<bool>()[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
double NodeContext::const_input<double>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<double>()[0];
|
||||
return get_constant_at_input(*this, index, false)->cast_vector<double>()[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
float NodeContext::const_input<float>(size_t index) const {
|
||||
return get_constant_at_input(*this, index)->cast_vector<float>()[0];
|
||||
return get_constant_at_input(*this, index, false)->cast_vector<float>()[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -233,13 +257,21 @@ Any NodeContext::get_values_from_const_input(int index) const {
|
||||
"Input with index: ",
|
||||
index,
|
||||
" does not exist.");
|
||||
|
||||
if (input_is_none(index)) {
|
||||
if (input_is_none(index))
|
||||
return {};
|
||||
auto input_val = get_input_from_visible_context(index);
|
||||
if (auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_val.get_node_shared_ptr())) {
|
||||
const auto& attrs = input->get_attrs();
|
||||
if (attrs.find("none_value") != attrs.end()) {
|
||||
return {};
|
||||
}
|
||||
auto it = attrs.find("string_value");
|
||||
if (it != attrs.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto input_node = get_input_from_visible_context(index).get_node_shared_ptr();
|
||||
if (auto constant = as_type_ptr<v0::Constant>(input_node)) {
|
||||
auto constant = get_constant_at_input(*this, index);
|
||||
if (constant) {
|
||||
switch (constant->get_element_type()) {
|
||||
case element::f32:
|
||||
return get_constant_data<float>(constant);
|
||||
@ -266,18 +298,8 @@ Any NodeContext::get_values_from_const_input(int index) const {
|
||||
default:
|
||||
FRONT_END_GENERAL_CHECK(false, "Input with index: ", index, " has unsupported type.");
|
||||
}
|
||||
} else if (auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_node)) {
|
||||
const auto& attrs = input->get_attrs();
|
||||
if (attrs.find("none_value") != attrs.end()) {
|
||||
return {};
|
||||
}
|
||||
auto it = attrs.find("string_value");
|
||||
if (it != attrs.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_node);
|
||||
FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_val);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ class TestPooling(PytorchLayerTest):
|
||||
reason='Ticket - 122715')
|
||||
def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model("avg_pool2d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
|
||||
ie_device, precision, ir_version, trace_model=True, dynamic_shapes=False)
|
||||
ie_device, precision, ir_version, trace_model=True, freeze_model=False, dynamic_shapes=False)
|
||||
|
||||
@pytest.mark.parametrize("params", d3_params)
|
||||
@pytest.mark.parametrize("ceil_mode", [True, False])
|
||||
|
Loading…
Reference in New Issue
Block a user