Add support for concatenation in Loop (#15899)
* Add support for concatenation in Loop * Apply suggestions from code review * Fix win build * Fix issues with propagation shapes and types in Loop * Fix einsum * Set type and shape of count in frontend
This commit is contained in:
parent
62ff31df8a
commit
87e714eb5c
@ -186,7 +186,7 @@ void op::v7::Einsum::validate_and_infer_types() {
|
||||
for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx) {
|
||||
const auto& input_type_i = get_input_element_type(input_idx);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_type_0 == input_type_i,
|
||||
input_type_0.compatible(input_type_i),
|
||||
"Inputs to Einsum operation must have the same type.");
|
||||
}
|
||||
|
||||
|
@ -162,6 +162,8 @@ void op::v5::Loop::validate_and_infer_types() {
|
||||
if (auto slice_input_description = ov::as_type_ptr<SliceInputDescription>(input_description)) {
|
||||
auto body_parameter = m_bodies[0]->get_parameters().at(slice_input_description->m_body_parameter_index);
|
||||
const auto& input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
|
||||
const auto& input_type = inputs().at(index).get_source_output().get_element_type();
|
||||
body_parameter->set_element_type(input_type);
|
||||
if (input_partial_shape.rank().is_dynamic()) {
|
||||
body_parameter->set_partial_shape(ov::PartialShape::dynamic());
|
||||
} else {
|
||||
@ -176,19 +178,21 @@ void op::v5::Loop::validate_and_infer_types() {
|
||||
|
||||
auto body_parameter = m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index);
|
||||
|
||||
auto body_param_partial_shape = body_parameter->get_partial_shape();
|
||||
auto input_partial_shape = input(index).get_partial_shape();
|
||||
auto input_type = input(index).get_element_type();
|
||||
|
||||
body_parameter->set_partial_shape(input_partial_shape);
|
||||
body_parameter->set_element_type(input_type);
|
||||
back_edges[merged_input_description->m_body_value_index] = merged_input_description->m_body_parameter_index;
|
||||
} else if (auto invariant_input_description =
|
||||
ov::as_type_ptr<v0::TensorIterator::InvariantInputDescription>(input_description)) {
|
||||
auto body_parameter = m_bodies[0]->get_parameters().at(invariant_input_description->m_body_parameter_index);
|
||||
|
||||
auto body_param_partial_shape = body_parameter->get_partial_shape();
|
||||
auto input_partial_shape = input(index).get_partial_shape();
|
||||
auto input_type = input(index).get_element_type();
|
||||
|
||||
body_parameter->set_partial_shape(input_partial_shape);
|
||||
body_parameter->set_element_type(input_type);
|
||||
}
|
||||
}
|
||||
|
||||
|
28
src/frontends/pytorch/src/op/cat.cpp
Normal file
28
src/frontends/pytorch/src/op/cat.cpp
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_cat(NodeContext& context) {
|
||||
// This translator is only needed to get axis as constant from external scope
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
|
||||
auto attrs = fw_node->get_attrs();
|
||||
// If this fails it means axis is dynamic and aten::cat will be converted to fw node in regular pipeline
|
||||
attrs["axis"] = std::to_string(context.const_input<int64_t>(1));
|
||||
fw_node->set_attrs(attrs);
|
||||
return {context.mark_node(std::dynamic_pointer_cast<Node>(fw_node))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -31,7 +31,7 @@ OutputVector translate_list_construct(NodeContext& context) {
|
||||
consts.push_back(unsqueezed_c_node);
|
||||
}
|
||||
}
|
||||
auto list_construct = std::make_shared<v0::Concat>(consts, 0);
|
||||
auto list_construct = context.mark_node(std::make_shared<v0::Concat>(consts, 0));
|
||||
if (list_construct->has_evaluate()) {
|
||||
OutputVector replacements(list_construct->get_output_size());
|
||||
|
||||
@ -39,7 +39,7 @@ OutputVector translate_list_construct(NodeContext& context) {
|
||||
return replacements;
|
||||
}
|
||||
}
|
||||
return {context.mark_output(list_construct)};
|
||||
return {list_construct};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -26,7 +26,12 @@ OutputVector translate_loop(NodeContext& context) {
|
||||
loop->set_special_body_ports(spec_ports);
|
||||
|
||||
auto body_parameters = body->get_parameters();
|
||||
// #0 body parameter is counter; #0 loop input is counter, #1 loop input is condition
|
||||
// #0 body parameter is counter;
|
||||
FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required");
|
||||
// Set counter type and shape
|
||||
body_parameters[0]->set_element_type(element::i32);
|
||||
body_parameters[0]->set_partial_shape(PartialShape{});
|
||||
// #0 loop input is trip_count, #1 loop input is condition
|
||||
// Connect other inputs
|
||||
for (size_t i = 2; i < inputs.size(); i++) {
|
||||
loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]});
|
||||
@ -39,7 +44,6 @@ OutputVector translate_loop(NodeContext& context) {
|
||||
auto external_output = context.get_tensor_from_model_or_create_input(input_idx);
|
||||
loop->set_invariant_inputs(external_output, {param});
|
||||
}
|
||||
// TODO: Connect back edges (merged inputs)
|
||||
auto body_results = body->get_results();
|
||||
FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition.");
|
||||
std::set<size_t> output_idxs;
|
||||
|
@ -24,6 +24,7 @@ OP_CONVERTER(translate_as_tensor);
|
||||
OP_CONVERTER(translate_avg_poolnd);
|
||||
OP_CONVERTER(translate_bool);
|
||||
OP_CONVERTER(translate_batch_norm);
|
||||
OP_CONVERTER(translate_cat);
|
||||
OP_CONVERTER(translate_clamp);
|
||||
OP_CONVERTER(translate_constant);
|
||||
OP_CONVERTER(translate_conv_transposend);
|
||||
@ -160,7 +161,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
|
||||
{"aten::batch_norm", op::translate_batch_norm},
|
||||
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::Bool", op::translate_bool},
|
||||
// {"aten::cat", done as transformation},
|
||||
{"aten::cat", op::translate_cat},
|
||||
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
|
||||
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
|
||||
{"aten::clamp", op::translate_clamp},
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/loop.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
@ -37,17 +38,67 @@ AtenCatToConcat::AtenCatToConcat() {
|
||||
if (!cat)
|
||||
return false;
|
||||
|
||||
auto axis_node = cat->input(1).get_source_output().get_node_shared_ptr();
|
||||
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
|
||||
if (!axis_const)
|
||||
return false;
|
||||
auto axis = axis_const->cast_vector<int64_t>();
|
||||
if (axis.size() != 1)
|
||||
return false;
|
||||
int64_t axis;
|
||||
if (cat->get_input_size() > 1) {
|
||||
auto axis_node = cat->get_input_node_shared_ptr(1);
|
||||
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
|
||||
if (!axis_const)
|
||||
return false;
|
||||
auto _axis = axis_const->cast_vector<int64_t>();
|
||||
if (_axis.size() != 1)
|
||||
return false;
|
||||
axis = _axis[0];
|
||||
} else {
|
||||
const auto& attrs = cat->get_attrs();
|
||||
if (attrs.find("axis") == attrs.end())
|
||||
return false;
|
||||
axis = std::stoll(attrs.at("axis"));
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0);
|
||||
if (auto loop = std::dynamic_pointer_cast<ov::op::v5::Loop>(input_node)) {
|
||||
// case when concatenation is done inside the Loop
|
||||
auto body = loop->get_function();
|
||||
auto output_index = cat->input(0).get_source_output().get_index();
|
||||
int64_t body_result_index = -1;
|
||||
for (auto out_desc : loop->get_output_descriptions()) {
|
||||
if (out_desc->m_output_index == output_index) {
|
||||
body_result_index = static_cast<int64_t>(out_desc->m_body_value_index);
|
||||
break;
|
||||
}
|
||||
}
|
||||
FRONT_END_GENERAL_CHECK(body_result_index >= 0, "Couldn't find descriptor for output.");
|
||||
auto body_result = body->get_results()[body_result_index];
|
||||
auto append = cast_fw_node(body_result->get_input_node_shared_ptr(0), "aten::append");
|
||||
if (!append)
|
||||
return false;
|
||||
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(append->get_input_node_shared_ptr(0));
|
||||
if (!param)
|
||||
return false;
|
||||
auto body_param_index = body->get_parameter_index(param);
|
||||
FRONT_END_GENERAL_CHECK(body_param_index >= 0, "Couldn't find parameter in body parameters.");
|
||||
int64_t input_index = -1;
|
||||
for (auto in_desc : loop->get_input_descriptions()) {
|
||||
if (in_desc->m_body_parameter_index == static_cast<size_t>(body_param_index)) {
|
||||
input_index = static_cast<int64_t>(in_desc->m_input_index);
|
||||
break;
|
||||
}
|
||||
}
|
||||
FRONT_END_GENERAL_CHECK(input_index >= 0, "Couldn't find descriptor for input.");
|
||||
auto list_construct = cast_fw_node(loop->get_input_node_shared_ptr(input_index), "prim::ListConstruct");
|
||||
if (!list_construct || list_construct->get_input_size() > 0)
|
||||
return false;
|
||||
// TODO: Is unsqueeze needed?
|
||||
auto new_result = std::make_shared<ov::op::v0::Result>(append->input_value(1));
|
||||
body->add_results({new_result});
|
||||
auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis);
|
||||
copy_runtime_info(cat, loop);
|
||||
cat->output(0).replace(new_output);
|
||||
return true;
|
||||
}
|
||||
|
||||
OutputVector tmp_inputs;
|
||||
NodeVector rt_copy_from{cat};
|
||||
std::shared_ptr<Node> input_node = cat->input(0).get_source_output().get_node_shared_ptr();
|
||||
while (const auto& input_fw_node = cast_fw_node(input_node, "aten::append")) {
|
||||
rt_copy_from.push_back(input_fw_node);
|
||||
tmp_inputs.push_back(input_fw_node->input(1).get_source_output());
|
||||
@ -62,7 +113,7 @@ AtenCatToConcat::AtenCatToConcat() {
|
||||
inputs.push_back(input.get_source_output());
|
||||
}
|
||||
inputs.insert(inputs.end(), tmp_inputs.rbegin(), tmp_inputs.rend());
|
||||
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis[0]);
|
||||
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis);
|
||||
copy_runtime_info(rt_copy_from, result);
|
||||
replace_node(cat, result);
|
||||
|
||||
|
@ -52,7 +52,8 @@ class PytorchLayerTest:
|
||||
else:
|
||||
torch_inputs = [torch.from_numpy(inp) for inp in inputs]
|
||||
model = torch.jit.trace(model, torch_inputs)
|
||||
model = torch.jit.freeze(model)
|
||||
if kwargs.get('freeze_model', True):
|
||||
model = torch.jit.freeze(model)
|
||||
graph = model.inlined_graph
|
||||
print(graph)
|
||||
|
||||
|
@ -32,5 +32,5 @@ class TestBool(PytorchLayerTest):
|
||||
@pytest.mark.parametrize("input_type", ["tensor", "scalar"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_ceil(self, ie_device, precision, ir_version, input_type):
|
||||
def test_bool(self, ie_device, precision, ir_version, input_type):
|
||||
self._test(*self.create_model(input_type), ie_device, precision, ir_version)
|
||||
|
51
tests/layer_tests/pytorch_tests/test_cat.py
Normal file
51
tests/layer_tests/pytorch_tests/test_cat.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class aten_cat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.cat([x, x], 1)
|
||||
|
||||
|
||||
class aten_append_cat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
list = []
|
||||
list.append(x)
|
||||
list.append(x)
|
||||
return torch.cat(list, 1)
|
||||
|
||||
class aten_loop_append_cat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
list = []
|
||||
for i in range(3):
|
||||
list.append(x)
|
||||
return torch.cat(list, 1)
|
||||
|
||||
|
||||
class TestCat(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(2, 1, 3),)
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_cat(self, ie_device, precision, ir_version):
|
||||
self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_append_cat(self, ie_device, precision, ir_version):
|
||||
self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_loop_append_cat(self, ie_device, precision, ir_version):
|
||||
self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
|
||||
ie_device, precision, ir_version, freeze_model=False)
|
Loading…
Reference in New Issue
Block a user