[PT FE]: Add aten::view transformations (#15339)

This commit is contained in:
Mateusz Mikolajczyk 2023-02-01 12:14:17 +01:00 committed by GitHub
parent cab559b478
commit a769cfe7e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 137 additions and 70 deletions

View File

@ -14,6 +14,7 @@
#include "transforms/append_list_unpack_replacer.hpp"
#include "transforms/aten_cat_replacer.hpp"
#include "transforms/aten_getitem_replacer.hpp"
#include "transforms/listconstruct_reshape_replacer.hpp"
#include "transforms/max_prim_list_construct_replacer.hpp"
#include "transforms/prim_list_construct_pad.hpp"
#include "transforms/prim_list_unpack_replacer.hpp"
@ -88,6 +89,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReshapeReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeTupleResults>();
manager.register_pass<ov::pass::ConstantFolding>();

View File

@ -12,10 +12,10 @@ namespace pytorch {
namespace op {
OutputVector translate_int(NodeContext& context) {
return {context.mark_node(std::make_shared<opset10::Convert>(context.get_input(0), element::i64))};
return {context.mark_node(std::make_shared<opset10::Convert>(context.get_input(0), element::i32))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov

View File

@ -13,25 +13,12 @@ namespace pytorch {
namespace op {
OutputVector translate_reshape(NodeContext& context) {
auto shape_node = context.get_input(1).get_node();
auto shape_node_fw_node = dynamic_cast<PtFrameworkNode*>(shape_node);
std::shared_ptr<ov::Node> reshape;
// TODO: move this to transform stage
if (shape_node_fw_node && shape_node_fw_node->get_decoder()->get_op_type() == "prim::ListConstruct") {
OutputVector inputs;
auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
for (auto& input : shape_node->inputs()) {
auto rank = input.get_partial_shape().rank();
FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
auto unsqueeze = context.mark_node(std::make_shared<opset10::Unsqueeze>(input.get_source_output(), axis_0));
inputs.push_back(unsqueeze);
}
auto concat = context.mark_node(std::make_shared<opset10::Concat>(inputs, 0));
reshape = context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), concat, false));
} else {
reshape =
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), context.get_input(1), false));
}
// Translation is used by both aten::view and aten::reshape.
// Schema: aten::view(Tensor input, int[] shape) -> Tensor
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
// For shape parameter, int[] is converted into single dimensional Tensor.
auto reshape =
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), context.get_input(1), false));
return {reshape};
};

View File

@ -1,45 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
// Translates aten::view into an opset10::Reshape.
// Schema: aten::view(Tensor input, int[] shape) -> Tensor
OutputVector translate_view(NodeContext& context) {
    auto shape_node = context.get_input(1).get_node();
    // The shape input may arrive as a prim::ListConstruct framework node rather than a single tensor.
    auto shape_node_fw_node = dynamic_cast<PtFrameworkNode*>(shape_node);
    std::shared_ptr<ov::Node> reshape;
    // TODO: move this to transform stage
    if (shape_node_fw_node && shape_node_fw_node->get_decoder()->get_op_type() == "prim::ListConstruct") {
        // Each ListConstruct input is expected to be a scalar dimension; unsqueeze each to a
        // 1-element 1-D tensor and concatenate them into the single shape tensor Reshape expects.
        OutputVector inputs;
        auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
        for (auto& input : shape_node->inputs()) {
            auto rank = input.get_partial_shape().rank();
            // Scalars only (or dynamic rank, which cannot be checked statically).
            FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
            auto unsqueeze = context.mark_node(std::make_shared<opset10::Unsqueeze>(input.get_source_output(), axis_0));
            inputs.push_back(unsqueeze);
        }
        auto concat = context.mark_node(std::make_shared<opset10::Concat>(inputs, 0));
        reshape = context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), concat, false));
        // TODO: fix rt_info
        // auto list_set = shape_node_fw_node->get_rt_info()["pt_node"].as<std::set<const Node*>>();
        // reshape->get_rt_info()["pt_node"].as<std::set<const Node*>>().insert(list_set.begin(),
        // list_set.end());
    } else {
        // Shape input is already a tensor; pass it to Reshape directly.
        reshape =
            context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), context.get_input(1), false));
    }
    return {reshape};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -108,7 +108,6 @@ OP_CONVERTER(translate_upsample_bilinear2d);
OP_CONVERTER(translate_upsample_nearest2d);
OP_CONVERTER(translate_var);
OP_CONVERTER(translate_var_mean);
OP_CONVERTER(translate_view);
OP_CONVERTER(translate_where);
OP_CONVERTER(translate_zeros);
OP_CONVERTER(translate_zeros_like);
@ -296,7 +295,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::upsample_nearest2d", op::translate_upsample_nearest2d},
{"aten::var", op::translate_var},
{"aten::var_mean", op::translate_var_mean},
{"aten::view", op::translate_view},
{"aten::view", op::translate_reshape},
{"aten::where", op::translate_where},
{"aten::zeros", op::translate_zeros},
{"aten::zeros_like", op::translate_zeros_like},

View File

@ -0,0 +1,55 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "listconstruct_reshape_replacer.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
ListConstructReshapeReplacer::ListConstructReshapeReplacer() {
    // Transformation for torch operators aten::view and aten::reshape for cases where the second input is
    // prim::ListConstruct. Both aten::view and aten::reshape use the same translation, which returns an
    // opset10::Reshape operator, so a single pattern on Reshape covers both.
    auto reshape_op = ov::pass::pattern::wrap_type<opset10::Reshape>();
    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
        auto reshape_op = std::dynamic_pointer_cast<opset10::Reshape>(m.get_match_root());
        if (!reshape_op) {
            return false;
        }
        auto shape_node = reshape_op->input_value(1).get_node_shared_ptr();
        // Check if the second input to the operator is prim::ListConstruct, and if so, concatenate its
        // inputs into a single 1-D tensor. Concatenation is possible because all elements in the list
        // should be scalar integers.
        // NOTE: renamed from list_unpack_node — the node checked here is prim::ListConstruct.
        if (auto list_construct_node = cast_fw_node(shape_node, "prim::ListConstruct")) {
            OutputVector unsqueezed_inputs;
            auto axis_0 = opset10::Constant::create(element::i64, Shape{}, {0});
            for (auto& input : shape_node->inputs()) {
                auto rank = input.get_partial_shape().rank();
                // Each list element must be a scalar so Unsqueeze yields a 1-element 1-D tensor.
                FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
                auto unsqueeze = std::make_shared<opset10::Unsqueeze>(input.get_source_output(), axis_0);
                unsqueezed_inputs.push_back(unsqueeze);
            }
            auto concat = std::make_shared<opset10::Concat>(unsqueezed_inputs, 0);
            copy_runtime_info({shape_node}, concat);
            replace_node(shape_node, concat);
            return true;
        }
        return false;
    };
    auto m = std::make_shared<ov::pass::pattern::Matcher>(reshape_op,
                                                          "ov::frontend::pytorch::pass::ListConstructReshapeReplacer");
    this->register_matcher(m, callback);
}
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
// Matcher pass that rewrites the shape input of a Reshape (produced by the aten::view /
// aten::reshape translation) when that input is a prim::ListConstruct framework node,
// replacing it with a Concat of the list's unsqueezed scalar inputs.
class ListConstructReshapeReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::pytorch::pass::ListConstructReshapeReplacer");
    ListConstructReshapeReplacer();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -31,9 +31,54 @@ class TestViewListConstruct(PytorchLayerTest):
self.input_data = input_data
self._test(*self.create_model(), ie_device, precision, ir_version)
@pytest.mark.parametrize('input_data', [(np.random.randn(4), np.array(2))])
class TestViewDtype(PytorchLayerTest):
    # Tests aten::view with a dtype argument: tensor.view(torch.int64).

    def _prepare_input(self):
        # Returns the parametrized (tensor, second-input) pair as model inputs.
        return self.input_data

    def create_model(self):
        class aten_view_dtype(torch.nn.Module):
            def forward(self, input_tensor, dtype):
                # The second input is accepted but not used; the view target dtype is hardcoded.
                return input_tensor.view(torch.int64)

        ref_net = None

        return aten_view_dtype(), ref_net, "aten::view"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_view_dtype(self, ie_device, precision, ir_version, input_data):
        self.input_data = input_data
        self._test(*self.create_model(), ie_device, precision, ir_version)
@pytest.mark.parametrize('input_data', [(np.random.randn(4), np.random.randn(2, 2))])
class TestViewSize(PytorchLayerTest):
    # Tests aten::view where the target shape is taken from another tensor's size().

    def _prepare_input(self):
        # Returns the parametrized (tensor, shape-source tensor) pair as model inputs.
        return self.input_data

    def create_model(self):
        class aten_view_size(torch.nn.Module):
            def forward(self, input_tensor, input_size):
                # Reshape input_tensor to the shape of input_size.
                return input_tensor.view(input_size.size()[:])

        ref_net = None

        return aten_view_size(), ref_net, "aten::view"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_view_size(self, ie_device, precision, ir_version, input_data):
        self.input_data = input_data
        self._test(*self.create_model(), ie_device, precision, ir_version)
@pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 2), 2, 6),
(np.random.randn(4), 2, 2)])
(np.random.randn(4), 2, 2),
(np.random.randn(4), 2, 2.1)])
class TestView(PytorchLayerTest):
def _prepare_input(self):
@ -48,7 +93,7 @@ class TestView(PytorchLayerTest):
self.dim2 = input_data[2]
def forward(self, input_tensor):
return input_tensor.view(self.dim1, self.dim2)
return input_tensor.view(self.dim1, int(self.dim2))
ref_net = None