Add aten stack transformation (#15311)
* add support for aten::stack * add new lines * updated aten stack transformation * add comments to the code --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
This commit is contained in:
parent
d8dfcac729
commit
566fae2b01
@@ -15,6 +15,7 @@
|
||||
#include "transforms/append_list_unpack_replacer.hpp"
|
||||
#include "transforms/aten_cat_replacer.hpp"
|
||||
#include "transforms/aten_getitem_replacer.hpp"
|
||||
#include "transforms/aten_stack_list_construct_replacer.hpp"
|
||||
#include "transforms/listconstruct_reshape_replacer.hpp"
|
||||
#include "transforms/max_prim_list_construct_replacer.hpp"
|
||||
#include "transforms/prim_list_construct_pad.hpp"
|
||||
@@ -86,6 +87,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.register_pass<ov::pass::UnrollIf>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenCatToConcat>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AppendListUnpackReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenStackListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::MaxPrimListConstructReplacer>();
|
||||
|
@@ -0,0 +1,66 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "aten_stack_list_construct_replacer.hpp"
|
||||
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
// Replaces aten::stack whose list input is a prim::ListConstruct with an
// Unsqueeze (on the stack axis) of every list element followed by a Concat
// on the same axis, which is the equivalent OpenVINO representation.
AtenStackListConstructReplacer::AtenStackListConstructReplacer() {
    auto list_construct = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
    auto axis = ov::pass::pattern::wrap_type<opset10::Constant>();

    // We search for a pattern: ListConstruct -> aten::stack <- Constant
    auto stack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>({list_construct, axis});

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        // Only rewrite FrameworkNodes that actually represent aten::stack.
        // NOTE: renamed from `stack`/`axis` so the locals no longer shadow
        // the captured pattern nodes of the same names.
        auto stack_node = cast_fw_node(m.get_match_root(), "aten::stack");
        if (!stack_node) {
            return false;
        }
        const auto& pattern_map = m.get_pattern_value_map();
        auto input_node = pattern_map.at(list_construct).get_node_shared_ptr();
        auto axis_const = std::dynamic_pointer_cast<opset10::Constant>(pattern_map.at(axis).get_node_shared_ptr());
        // Guard against a non-constant or empty axis: the pattern should
        // guarantee a Constant, but dereferencing unchecked would be UB.
        if (!axis_const) {
            return false;
        }
        auto axis_values = axis_const->cast_vector<int64_t>();
        if (axis_values.size() != 1) {
            return false;
        }
        // The transformation applies only when the input is a ListConstruct.
        if (auto list_construct_node = cast_fw_node(input_node, "prim::ListConstruct")) {
            const auto& list_inputs = list_construct_node->input_values();
            OutputVector node_vector;
            node_vector.reserve(list_inputs.size());
            // Unsqueeze every value in the ListConstruct on the stack axis.
            for (const auto& list_input : list_inputs) {
                auto node = concat_list_construct(list_input.get_node_shared_ptr());
                auto unsqueezed_node = std::make_shared<opset10::Unsqueeze>(node, axis_const);
                node_vector.push_back(unsqueezed_node);
            }
            // Concat the unsqueezed tensors on the provided axis.
            auto concat = std::make_shared<opset10::Concat>(node_vector, axis_values[0]);

            copy_runtime_info({stack_node, input_node}, concat);
            replace_node(stack_node, concat);
            return true;
        }
        return false;
    };

    auto m = std::make_shared<Matcher>(stack, "ov::frontend::pytorch::pass::AtenStackListConstructReplacer");
    this->register_matcher(m, callback);
}
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
// Matcher pass that replaces aten::stack fed by a prim::ListConstruct with
// Unsqueeze-per-element + Concat on the stack axis (see the .cpp for details).
class AtenStackListConstructReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenStackListConstructReplacer");
    AtenStackListConstructReplacer();
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
77
tests/layer_tests/pytorch_tests/test_stack.py
Normal file
77
tests/layer_tests/pytorch_tests/test_stack.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestStack2D(PytorchLayerTest):
    """Layer test for torch.stack of two tensors along a parametrized dim."""

    def _prepare_input(self):
        # Tensors are generated (seeded) inside the test body, not at pytest
        # collection time, so inputs are reproducible across runs.
        return self.input_tensors

    def create_model(self, dim):
        import torch

        class aten_stack(torch.nn.Module):
            def __init__(self, dim):
                super(aten_stack, self).__init__()
                self.dim = dim

            def forward(self, x, y):
                inputs = [x, y]
                return torch.stack(inputs, self.dim)

        ref_net = None

        return aten_stack(dim), ref_net, "aten::stack"

    @pytest.mark.parametrize("input_shape", ([
        (1, 3, 3),
        (4, 4, 2),
        (8, 1, 1, 9),
    ]))
    @pytest.mark.parametrize("dim", ([
        0, 1, 2,
    ]))
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_stack2D(self, input_shape, dim, ie_device, precision, ir_version):
        # Fixed seed: the original unseeded np.random.rand made failures
        # non-reproducible between runs.
        rng = np.random.default_rng(seed=0)
        self.input_tensors = [rng.random(input_shape) for _ in range(2)]
        self._test(*self.create_model(dim), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestStack3D(PytorchLayerTest):
    """Layer test for torch.stack of three tensors along a parametrized dim."""

    def _prepare_input(self):
        # Tensors are generated (seeded) inside the test body, not at pytest
        # collection time, so inputs are reproducible across runs.
        return self.input_tensors

    def create_model(self, dim):
        import torch

        class aten_stack(torch.nn.Module):
            def __init__(self, dim):
                super(aten_stack, self).__init__()
                self.dim = dim

            def forward(self, x, y, z):
                inputs = [x, y, z]
                return torch.stack(inputs, self.dim)

        ref_net = None

        return aten_stack(dim), ref_net, "aten::stack"

    @pytest.mark.parametrize("input_shape", ([
        (1, 3, 3),
        (4, 4, 2),
        (8, 1, 1, 9),
    ]))
    @pytest.mark.parametrize("dim", ([
        0, 1, 2,
    ]))
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_stack3D(self, input_shape, dim, ie_device, precision, ir_version):
        # Fixed seed: the original unseeded np.random.rand made failures
        # non-reproducible between runs.
        rng = np.random.default_rng(seed=0)
        self.input_tensors = [rng.random(input_shape) for _ in range(3)]
        self._test(*self.create_model(dim), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user