Add support for concatenation in Loop (#15899)

* Add support for concatenation in Loop

* Apply suggestions from code review

* Fix win build

* Fix issues with propagation shapes and types in Loop

* Fix einsum

* Set type and shape of count in frontend
This commit is contained in:
Maxim Vafin 2023-02-28 21:31:33 +01:00 committed by GitHub
parent 62ff31df8a
commit 87e714eb5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 159 additions and 19 deletions

View File

@ -186,7 +186,7 @@ void op::v7::Einsum::validate_and_infer_types() {
for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx) {
const auto& input_type_i = get_input_element_type(input_idx);
NODE_VALIDATION_CHECK(this,
input_type_0 == input_type_i,
input_type_0.compatible(input_type_i),
"Inputs to Einsum operation must have the same type.");
}

View File

@ -162,6 +162,8 @@ void op::v5::Loop::validate_and_infer_types() {
if (auto slice_input_description = ov::as_type_ptr<SliceInputDescription>(input_description)) {
auto body_parameter = m_bodies[0]->get_parameters().at(slice_input_description->m_body_parameter_index);
const auto& input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
const auto& input_type = inputs().at(index).get_source_output().get_element_type();
body_parameter->set_element_type(input_type);
if (input_partial_shape.rank().is_dynamic()) {
body_parameter->set_partial_shape(ov::PartialShape::dynamic());
} else {
@ -176,19 +178,21 @@ void op::v5::Loop::validate_and_infer_types() {
auto body_parameter = m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index);
auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = input(index).get_partial_shape();
auto input_type = input(index).get_element_type();
body_parameter->set_partial_shape(input_partial_shape);
body_parameter->set_element_type(input_type);
back_edges[merged_input_description->m_body_value_index] = merged_input_description->m_body_parameter_index;
} else if (auto invariant_input_description =
ov::as_type_ptr<v0::TensorIterator::InvariantInputDescription>(input_description)) {
auto body_parameter = m_bodies[0]->get_parameters().at(invariant_input_description->m_body_parameter_index);
auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = input(index).get_partial_shape();
auto input_type = input(index).get_element_type();
body_parameter->set_partial_shape(input_partial_shape);
body_parameter->set_element_type(input_type);
}
}

View File

@ -0,0 +1,28 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_cat(NodeContext& context) {
// This translator is only needed to get axis as constant from external scope
num_inputs_check(context, 2, 2);
auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
auto attrs = fw_node->get_attrs();
// If this fails it means axis is dynamic and aten::cat will be converted to fw node in regular pipeline
attrs["axis"] = std::to_string(context.const_input<int64_t>(1));
fw_node->set_attrs(attrs);
return {context.mark_node(std::dynamic_pointer_cast<Node>(fw_node))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -31,7 +31,7 @@ OutputVector translate_list_construct(NodeContext& context) {
consts.push_back(unsqueezed_c_node);
}
}
auto list_construct = std::make_shared<v0::Concat>(consts, 0);
auto list_construct = context.mark_node(std::make_shared<v0::Concat>(consts, 0));
if (list_construct->has_evaluate()) {
OutputVector replacements(list_construct->get_output_size());
@ -39,7 +39,7 @@ OutputVector translate_list_construct(NodeContext& context) {
return replacements;
}
}
return {context.mark_output(list_construct)};
return {list_construct};
};
} // namespace op

View File

@ -26,7 +26,12 @@ OutputVector translate_loop(NodeContext& context) {
loop->set_special_body_ports(spec_ports);
auto body_parameters = body->get_parameters();
// #0 body parameter is counter; #0 loop input is counter, #1 loop input is condition
// #0 body parameter is counter;
FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required");
// Set counter type and shape
body_parameters[0]->set_element_type(element::i32);
body_parameters[0]->set_partial_shape(PartialShape{});
// #0 loop input is trip_count, #1 loop input is condition
// Connect other inputs
for (size_t i = 2; i < inputs.size(); i++) {
loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]});
@ -39,7 +44,6 @@ OutputVector translate_loop(NodeContext& context) {
auto external_output = context.get_tensor_from_model_or_create_input(input_idx);
loop->set_invariant_inputs(external_output, {param});
}
// TODO: Connect back edges (merged inputs)
auto body_results = body->get_results();
FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition.");
std::set<size_t> output_idxs;

View File

@ -24,6 +24,7 @@ OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_bool);
OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_cat);
OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv_transposend);
@ -160,7 +161,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::batch_norm", op::translate_batch_norm},
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::Bool", op::translate_bool},
// {"aten::cat", done as transformation},
{"aten::cat", op::translate_cat},
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
{"aten::clamp", op::translate_clamp},

View File

@ -10,6 +10,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/loop.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
@ -37,17 +38,67 @@ AtenCatToConcat::AtenCatToConcat() {
if (!cat)
return false;
auto axis_node = cat->input(1).get_source_output().get_node_shared_ptr();
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
if (!axis_const)
return false;
auto axis = axis_const->cast_vector<int64_t>();
if (axis.size() != 1)
return false;
int64_t axis;
if (cat->get_input_size() > 1) {
auto axis_node = cat->get_input_node_shared_ptr(1);
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
if (!axis_const)
return false;
auto _axis = axis_const->cast_vector<int64_t>();
if (_axis.size() != 1)
return false;
axis = _axis[0];
} else {
const auto& attrs = cat->get_attrs();
if (attrs.find("axis") == attrs.end())
return false;
axis = std::stoll(attrs.at("axis"));
}
std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0);
if (auto loop = std::dynamic_pointer_cast<ov::op::v5::Loop>(input_node)) {
// case when concatenation is done inside the Loop
auto body = loop->get_function();
auto output_index = cat->input(0).get_source_output().get_index();
int64_t body_result_index = -1;
for (auto out_desc : loop->get_output_descriptions()) {
if (out_desc->m_output_index == output_index) {
body_result_index = static_cast<int64_t>(out_desc->m_body_value_index);
break;
}
}
FRONT_END_GENERAL_CHECK(body_result_index >= 0, "Couldn't find descriptor for output.");
auto body_result = body->get_results()[body_result_index];
auto append = cast_fw_node(body_result->get_input_node_shared_ptr(0), "aten::append");
if (!append)
return false;
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(append->get_input_node_shared_ptr(0));
if (!param)
return false;
auto body_param_index = body->get_parameter_index(param);
FRONT_END_GENERAL_CHECK(body_param_index >= 0, "Couldn't find parameter in body parameters.");
int64_t input_index = -1;
for (auto in_desc : loop->get_input_descriptions()) {
if (in_desc->m_body_parameter_index == static_cast<size_t>(body_param_index)) {
input_index = static_cast<int64_t>(in_desc->m_input_index);
break;
}
}
FRONT_END_GENERAL_CHECK(input_index >= 0, "Couldn't find descriptor for input.");
auto list_construct = cast_fw_node(loop->get_input_node_shared_ptr(input_index), "prim::ListConstruct");
if (!list_construct || list_construct->get_input_size() > 0)
return false;
// TODO: Is unsqueeze needed?
auto new_result = std::make_shared<ov::op::v0::Result>(append->input_value(1));
body->add_results({new_result});
auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis);
copy_runtime_info(cat, loop);
cat->output(0).replace(new_output);
return true;
}
OutputVector tmp_inputs;
NodeVector rt_copy_from{cat};
std::shared_ptr<Node> input_node = cat->input(0).get_source_output().get_node_shared_ptr();
while (const auto& input_fw_node = cast_fw_node(input_node, "aten::append")) {
rt_copy_from.push_back(input_fw_node);
tmp_inputs.push_back(input_fw_node->input(1).get_source_output());
@ -62,7 +113,7 @@ AtenCatToConcat::AtenCatToConcat() {
inputs.push_back(input.get_source_output());
}
inputs.insert(inputs.end(), tmp_inputs.rbegin(), tmp_inputs.rend());
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis[0]);
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis);
copy_runtime_info(rt_copy_from, result);
replace_node(cat, result);

View File

@ -52,7 +52,8 @@ class PytorchLayerTest:
else:
torch_inputs = [torch.from_numpy(inp) for inp in inputs]
model = torch.jit.trace(model, torch_inputs)
model = torch.jit.freeze(model)
if kwargs.get('freeze_model', True):
model = torch.jit.freeze(model)
graph = model.inlined_graph
print(graph)

View File

@ -32,5 +32,5 @@ class TestBool(PytorchLayerTest):
@pytest.mark.parametrize("input_type", ["tensor", "scalar"])
@pytest.mark.nightly
@pytest.mark.precommit
def test_ceil(self, ie_device, precision, ir_version, input_type):
def test_bool(self, ie_device, precision, ir_version, input_type):
self._test(*self.create_model(input_type), ie_device, precision, ir_version)

View File

@ -0,0 +1,51 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class aten_cat(torch.nn.Module):
def forward(self, x):
return torch.cat([x, x], 1)
class aten_append_cat(torch.nn.Module):
def forward(self, x):
list = []
list.append(x)
list.append(x)
return torch.cat(list, 1)
class aten_loop_append_cat(torch.nn.Module):
def forward(self, x):
list = []
for i in range(3):
list.append(x)
return torch.cat(list, 1)
class TestCat(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 1, 3),)
@pytest.mark.nightly
@pytest.mark.precommit
def test_cat(self, ie_device, precision, ir_version):
self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
ie_device, precision, ir_version)
@pytest.mark.nightly
@pytest.mark.precommit
def test_append_cat(self, ie_device, precision, ir_version):
self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
ie_device, precision, ir_version)
@pytest.mark.nightly
@pytest.mark.precommit
def test_loop_append_cat(self, ie_device, precision, ir_version):
self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
ie_device, precision, ir_version, freeze_model=False)