[PT FE] Recognize empty non-frozen lists (#21224)
* [PT FE] Recognize empty non-frozen lists * Do not produce alias for aten::clone
This commit is contained in:
parent
97381e0b63
commit
1a288f0e9a
@ -372,7 +372,7 @@ class TorchScriptPythonDecoder (Decoder):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def may_produce_alias(self, in_index: int, out_index: int) -> bool:
|
def may_produce_alias(self, in_index: int, out_index: int) -> bool:
|
||||||
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul"]:
|
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul", "aten::clone"]:
|
||||||
# AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
|
# AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||||
|
|
||||||
|
#include "openvino/core/validation_util.hpp"
|
||||||
#include "openvino/frontend/exception.hpp"
|
#include "openvino/frontend/exception.hpp"
|
||||||
#include "openvino/frontend/pytorch/decoder.hpp"
|
#include "openvino/frontend/pytorch/decoder.hpp"
|
||||||
#include "openvino/op/constant.hpp"
|
#include "openvino/op/constant.hpp"
|
||||||
@ -149,58 +150,81 @@ std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
std::shared_ptr<v0::Constant> get_constant_at_input(const NodeContext& ctx, size_t index) {
|
std::shared_ptr<v0::Constant> get_constant_at_input(const NodeContext& ctx, size_t index, bool allow_empty = true) {
|
||||||
FRONT_END_GENERAL_CHECK(!ctx.input_is_none(index), "Input with index: ", index, " is none.");
|
FRONT_END_GENERAL_CHECK(!ctx.input_is_none(index), "Input with index: ", index, " is none.");
|
||||||
auto input_node = ctx.get_input_from_visible_context(index).get_node_shared_ptr();
|
auto input_val = ctx.get_input_from_visible_context(index);
|
||||||
auto input = std::dynamic_pointer_cast<v0::Constant>(input_node);
|
if (ctx.get_input_type(index).is<type::List>()) {
|
||||||
FRONT_END_GENERAL_CHECK(input, "Input with index ", index, " cannot be interpreted as Constant: ", input_node);
|
if (allow_empty && is_empty_list(input_val))
|
||||||
return input;
|
return {};
|
||||||
|
input_val = concat_list_construct(input_val);
|
||||||
|
}
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
auto constant = get_constant_from_source(input_val);
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
FRONT_END_GENERAL_CHECK(constant, "Input with index ", index, " cannot be interpreted as Constant: ", input_val);
|
||||||
|
return constant;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
std::vector<int64_t> NodeContext::const_input<std::vector<int64_t>>(size_t index) const {
|
std::vector<int64_t> NodeContext::const_input<std::vector<int64_t>>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<int64_t>();
|
auto c = get_constant_at_input(*this, index);
|
||||||
|
if (c)
|
||||||
|
return c->cast_vector<int64_t>();
|
||||||
|
else
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Strides NodeContext::const_input<Strides>(size_t index) const {
|
Strides NodeContext::const_input<Strides>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<Strides::value_type>();
|
auto c = get_constant_at_input(*this, index);
|
||||||
|
if (c)
|
||||||
|
return c->cast_vector<Strides::value_type>();
|
||||||
|
else
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
CoordinateDiff NodeContext::const_input<CoordinateDiff>(size_t index) const {
|
CoordinateDiff NodeContext::const_input<CoordinateDiff>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<CoordinateDiff::value_type>();
|
auto c = get_constant_at_input(*this, index);
|
||||||
|
if (c)
|
||||||
|
return c->cast_vector<CoordinateDiff::value_type>();
|
||||||
|
else
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Shape NodeContext::const_input<Shape>(size_t index) const {
|
Shape NodeContext::const_input<Shape>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<Shape::value_type>();
|
auto c = get_constant_at_input(*this, index);
|
||||||
|
if (c)
|
||||||
|
return c->cast_vector<Shape::value_type>();
|
||||||
|
else
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
int32_t NodeContext::const_input<int32_t>(size_t index) const {
|
int32_t NodeContext::const_input<int32_t>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<int32_t>()[0];
|
return get_constant_at_input(*this, index, false)->cast_vector<int32_t>()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
int64_t NodeContext::const_input<int64_t>(size_t index) const {
|
int64_t NodeContext::const_input<int64_t>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<int64_t>()[0];
|
return get_constant_at_input(*this, index, false)->cast_vector<int64_t>()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool NodeContext::const_input<bool>(size_t index) const {
|
bool NodeContext::const_input<bool>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<bool>()[0];
|
return get_constant_at_input(*this, index, false)->cast_vector<bool>()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
double NodeContext::const_input<double>(size_t index) const {
|
double NodeContext::const_input<double>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<double>()[0];
|
return get_constant_at_input(*this, index, false)->cast_vector<double>()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
float NodeContext::const_input<float>(size_t index) const {
|
float NodeContext::const_input<float>(size_t index) const {
|
||||||
return get_constant_at_input(*this, index)->cast_vector<float>()[0];
|
return get_constant_at_input(*this, index, false)->cast_vector<float>()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -233,13 +257,21 @@ Any NodeContext::get_values_from_const_input(int index) const {
|
|||||||
"Input with index: ",
|
"Input with index: ",
|
||||||
index,
|
index,
|
||||||
" does not exist.");
|
" does not exist.");
|
||||||
|
if (input_is_none(index))
|
||||||
if (input_is_none(index)) {
|
|
||||||
return {};
|
return {};
|
||||||
|
auto input_val = get_input_from_visible_context(index);
|
||||||
|
if (auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_val.get_node_shared_ptr())) {
|
||||||
|
const auto& attrs = input->get_attrs();
|
||||||
|
if (attrs.find("none_value") != attrs.end()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
auto it = attrs.find("string_value");
|
||||||
|
if (it != attrs.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
auto constant = get_constant_at_input(*this, index);
|
||||||
auto input_node = get_input_from_visible_context(index).get_node_shared_ptr();
|
if (constant) {
|
||||||
if (auto constant = as_type_ptr<v0::Constant>(input_node)) {
|
|
||||||
switch (constant->get_element_type()) {
|
switch (constant->get_element_type()) {
|
||||||
case element::f32:
|
case element::f32:
|
||||||
return get_constant_data<float>(constant);
|
return get_constant_data<float>(constant);
|
||||||
@ -266,18 +298,8 @@ Any NodeContext::get_values_from_const_input(int index) const {
|
|||||||
default:
|
default:
|
||||||
FRONT_END_GENERAL_CHECK(false, "Input with index: ", index, " has unsupported type.");
|
FRONT_END_GENERAL_CHECK(false, "Input with index: ", index, " has unsupported type.");
|
||||||
}
|
}
|
||||||
} else if (auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_node)) {
|
|
||||||
const auto& attrs = input->get_attrs();
|
|
||||||
if (attrs.find("none_value") != attrs.end()) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
auto it = attrs.find("string_value");
|
|
||||||
if (it != attrs.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_val);
|
||||||
FRONT_END_GENERAL_CHECK(false, "Input node with index ", index, " cannot be interpreted as constant", input_node);
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -159,7 +159,7 @@ class TestPooling(PytorchLayerTest):
|
|||||||
reason='Ticket - 122715')
|
reason='Ticket - 122715')
|
||||||
def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
|
def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
|
||||||
self._test(*self.create_model("avg_pool2d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
|
self._test(*self.create_model("avg_pool2d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
|
||||||
ie_device, precision, ir_version, trace_model=True, dynamic_shapes=False)
|
ie_device, precision, ir_version, trace_model=True, freeze_model=False, dynamic_shapes=False)
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", d3_params)
|
@pytest.mark.parametrize("params", d3_params)
|
||||||
@pytest.mark.parametrize("ceil_mode", [True, False])
|
@pytest.mark.parametrize("ceil_mode", [True, False])
|
||||||
|
Loading…
Reference in New Issue
Block a user