[PT FE] Fix issue when cat input is folded to tensor (#20090)
* [PT FE] Fix issue when cat input is folded to tensor * CHeck real first input * Update src/frontends/pytorch/src/op/cat.cpp
This commit is contained in:
@@ -141,7 +141,7 @@ void InputModel::override_all_inputs(const std::vector<Place::Ptr>& inputs) {
|
||||
"Number of inputs provided is incorrect. Graph modification is not supported for "
|
||||
"this model. Expected number of inputs: ",
|
||||
m_inputs.size() - 1,
|
||||
" recieved ",
|
||||
" received ",
|
||||
inputs.size());
|
||||
auto self_place = m_inputs[0];
|
||||
// Verify that no same place already in vector
|
||||
@@ -158,7 +158,7 @@ void InputModel::override_all_inputs(const std::vector<Place::Ptr>& inputs) {
|
||||
"Number of inputs provided is incorrect. Graph modification is not supported for "
|
||||
"this model. Expected number of inputs: ",
|
||||
m_inputs.size(),
|
||||
" recieved ",
|
||||
" received ",
|
||||
inputs.size());
|
||||
m_inputs = inputs;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/scatter_elements_update.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "utils_quantize.hpp"
|
||||
@@ -28,12 +32,27 @@ OutputVector translate_cat_common(const NodeContext& context,
|
||||
attrs["axis"] = std::to_string(axis);
|
||||
fw_node->set_attrs(attrs);
|
||||
return {context.mark_node(fw_node)};
|
||||
} else {
|
||||
auto first_elem = list_elems.front().get_node_shared_ptr();
|
||||
FRONT_END_OP_CONVERSION_CHECK(
|
||||
list_elems.size() > 1 || !ov::as_type_ptr<v0::Parameter>(first_elem),
|
||||
"<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
|
||||
"This case is not supported.");
|
||||
}
|
||||
auto first_node = list_elems.front().get_node_shared_ptr();
|
||||
FRONT_END_OP_CONVERSION_CHECK(
|
||||
list_elems.size() > 1 || !ov::as_type_ptr<v0::Parameter>(first_node),
|
||||
"<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
|
||||
"This case is not supported.");
|
||||
if (list_elems.size() == 1 &&
|
||||
!std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr())) {
|
||||
// Case when list was merged into tensor
|
||||
auto tensor = list_elems[0];
|
||||
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(tensor, element::i32));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto axis_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {axis}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto int_max =
|
||||
context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>().max()}));
|
||||
auto shape_sliced = context.mark_node(std::make_shared<v8::Slice>(shape, one, int_max, one));
|
||||
auto new_shape =
|
||||
context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(shape_sliced, axis_const, neg_1, zero));
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(tensor, new_shape, false))};
|
||||
}
|
||||
auto concat = std::make_shared<v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), axis);
|
||||
return {context.mark_node(concat)};
|
||||
|
||||
@@ -48,3 +48,31 @@ class TestMatMul(PytorchLayerTest):
|
||||
def test_matmul(self, kwargs_to_prepare_input, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(len(kwargs_to_prepare_input) == 3), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input=kwargs_to_prepare_input)
|
||||
|
||||
|
||||
class TestLinearBiasList(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 15, 10).astype(np.float32), np.random.randn(66, 10).astype(np.float32))
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_mm(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_mm, self).__init__()
|
||||
self.bias = [torch.randn(22),
|
||||
torch.randn(22),
|
||||
torch.randn(22)]
|
||||
|
||||
def forward(self, m1, m2):
|
||||
m2 = m2.reshape([66, -1])
|
||||
return torch.nn.functional.linear(m1, m2, torch.cat(self.bias, 0))
|
||||
|
||||
return aten_mm(), None, "aten::linear"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_linear_bias_list(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version,
|
||||
trace_model=True, freeze_model=False)
|
||||
|
||||
Reference in New Issue
Block a user