[PT FE] Fix issue when cat input is folded to tensor (#20090)

* [PT FE] Fix issue when cat input is folded to tensor

* Check real first input

* Update src/frontends/pytorch/src/op/cat.cpp
This commit is contained in:
Maxim Vafin
2023-09-28 10:52:09 +02:00
committed by GitHub
parent 1be993dd39
commit fea6db1a5f
3 changed files with 55 additions and 8 deletions

View File

@@ -141,7 +141,7 @@ void InputModel::override_all_inputs(const std::vector<Place::Ptr>& inputs) {
"Number of inputs provided is incorrect. Graph modification is not supported for "
"this model. Expected number of inputs: ",
m_inputs.size() - 1,
" recieved ",
" received ",
inputs.size());
auto self_place = m_inputs[0];
// Verify that no same place already in vector
@@ -158,7 +158,7 @@ void InputModel::override_all_inputs(const std::vector<Place::Ptr>& inputs) {
"Number of inputs provided is incorrect. Graph modification is not supported for "
"this model. Expected number of inputs: ",
m_inputs.size(),
" recieved ",
" received ",
inputs.size());
m_inputs = inputs;
}

View File

@@ -5,6 +5,10 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
#include "utils_quantize.hpp"
@@ -28,12 +32,27 @@ OutputVector translate_cat_common(const NodeContext& context,
attrs["axis"] = std::to_string(axis);
fw_node->set_attrs(attrs);
return {context.mark_node(fw_node)};
} else {
auto first_elem = list_elems.front().get_node_shared_ptr();
FRONT_END_OP_CONVERSION_CHECK(
list_elems.size() > 1 || !ov::as_type_ptr<v0::Parameter>(first_elem),
"<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
"This case is not supported.");
}
auto first_node = list_elems.front().get_node_shared_ptr();
FRONT_END_OP_CONVERSION_CHECK(
list_elems.size() > 1 || !ov::as_type_ptr<v0::Parameter>(first_node),
"<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
"This case is not supported.");
if (list_elems.size() == 1 &&
!std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr())) {
// Case when list was merged into tensor
auto tensor = list_elems[0];
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(tensor, element::i32));
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto axis_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {axis}));
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto int_max =
context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>().max()}));
auto shape_sliced = context.mark_node(std::make_shared<v8::Slice>(shape, one, int_max, one));
auto new_shape =
context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(shape_sliced, axis_const, neg_1, zero));
return {context.mark_node(std::make_shared<v1::Reshape>(tensor, new_shape, false))};
}
auto concat = std::make_shared<v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), axis);
return {context.mark_node(concat)};

View File

@@ -48,3 +48,31 @@ class TestMatMul(PytorchLayerTest):
def test_matmul(self, kwargs_to_prepare_input, ie_device, precision, ir_version):
self._test(*self.create_model(len(kwargs_to_prepare_input) == 3), ie_device, precision, ir_version,
kwargs_to_prepare_input=kwargs_to_prepare_input)
class TestLinearBiasList(PytorchLayerTest):
    """Regression test: aten::cat whose input list is folded to a constant tensor
    (the bias list below is concatenated from module attributes at trace time)."""

    def _prepare_input(self):
        # Two float32 inputs: activation (1, 15, 10) and weight-like tensor (66, 10).
        import numpy as np
        return (np.random.randn(1, 15, 10).astype(np.float32), np.random.randn(66, 10).astype(np.float32))

    def create_model(self):
        import torch

        class aten_mm(torch.nn.Module):
            def __init__(self):
                super(aten_mm, self).__init__()
                # Bias is a Python list of tensors; torch.cat over it is folded
                # into a single constant tensor during tracing (freeze_model=False).
                self.bias = [torch.randn(22),
                             torch.randn(22),
                             torch.randn(22)]

            def forward(self, m1, m2):
                # Reshape keeps 66 rows; cat(self.bias) yields the 66-element bias.
                m2 = m2.reshape([66, -1])
                return torch.nn.functional.linear(m1, m2, torch.cat(self.bias, 0))

        # No reference inputs conversion needed; checked op is aten::linear.
        return aten_mm(), None, "aten::linear"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_linear_bias_list(self, ie_device, precision, ir_version):
        # trace_model=True + freeze_model=False reproduces the folded-cat case.
        self._test(*self.create_model(), ie_device, precision, ir_version,
                   trace_model=True, freeze_model=False)