[PT FE]: Add aten::view transformations (#15339)
This commit is contained in:
parent
cab559b478
commit
a769cfe7e8
@ -14,6 +14,7 @@
|
||||
#include "transforms/append_list_unpack_replacer.hpp"
|
||||
#include "transforms/aten_cat_replacer.hpp"
|
||||
#include "transforms/aten_getitem_replacer.hpp"
|
||||
#include "transforms/listconstruct_reshape_replacer.hpp"
|
||||
#include "transforms/max_prim_list_construct_replacer.hpp"
|
||||
#include "transforms/prim_list_construct_pad.hpp"
|
||||
#include "transforms/prim_list_unpack_replacer.hpp"
|
||||
@ -88,6 +89,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::MaxPrimListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReshapeReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::DecomposeTupleResults>();
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
|
@ -12,10 +12,10 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_int(NodeContext& context) {
|
||||
return {context.mark_node(std::make_shared<opset10::Convert>(context.get_input(0), element::i64))};
|
||||
return {context.mark_node(std::make_shared<opset10::Convert>(context.get_input(0), element::i32))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
} // namespace ov
|
||||
|
@ -13,25 +13,12 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_reshape(NodeContext& context) {
|
||||
auto shape_node = context.get_input(1).get_node();
|
||||
auto shape_node_fw_node = dynamic_cast<PtFrameworkNode*>(shape_node);
|
||||
std::shared_ptr<ov::Node> reshape;
|
||||
// TODO: move this to transform stage
|
||||
if (shape_node_fw_node && shape_node_fw_node->get_decoder()->get_op_type() == "prim::ListConstruct") {
|
||||
OutputVector inputs;
|
||||
auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
for (auto& input : shape_node->inputs()) {
|
||||
auto rank = input.get_partial_shape().rank();
|
||||
FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
|
||||
auto unsqueeze = context.mark_node(std::make_shared<opset10::Unsqueeze>(input.get_source_output(), axis_0));
|
||||
inputs.push_back(unsqueeze);
|
||||
}
|
||||
auto concat = context.mark_node(std::make_shared<opset10::Concat>(inputs, 0));
|
||||
reshape = context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), concat, false));
|
||||
} else {
|
||||
reshape =
|
||||
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), context.get_input(1), false));
|
||||
}
|
||||
// Translation is used by both aten::view and aten::reshape.
|
||||
// Schema: aten::view(Tensor input, int[] shape) -> Tensor
|
||||
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
|
||||
// For shape parameter, int[] is converted into single dimensional Tensor.
|
||||
auto reshape =
|
||||
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), context.get_input(1), false));
|
||||
return {reshape};
|
||||
};
|
||||
|
||||
|
@ -1,45 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_view(NodeContext& context) {
|
||||
auto shape_node = context.get_input(1).get_node();
|
||||
auto shape_node_fw_node = dynamic_cast<PtFrameworkNode*>(shape_node);
|
||||
std::shared_ptr<ov::Node> reshape;
|
||||
// TODO: move this to transform stage
|
||||
if (shape_node_fw_node && shape_node_fw_node->get_decoder()->get_op_type() == "prim::ListConstruct") {
|
||||
OutputVector inputs;
|
||||
auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
for (auto& input : shape_node->inputs()) {
|
||||
auto rank = input.get_partial_shape().rank();
|
||||
FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
|
||||
auto unsqueeze = context.mark_node(std::make_shared<opset10::Unsqueeze>(input.get_source_output(), axis_0));
|
||||
inputs.push_back(unsqueeze);
|
||||
}
|
||||
auto concat = context.mark_node(std::make_shared<opset10::Concat>(inputs, 0));
|
||||
reshape = context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), concat, false));
|
||||
// TODO: fix rt_info
|
||||
// auto list_set = shape_node_fw_node->get_rt_info()["pt_node"].as<std::set<const Node*>>();
|
||||
// reshape->get_rt_info()["pt_node"].as<std::set<const Node*>>().insert(list_set.begin(),
|
||||
// list_set.end());
|
||||
} else {
|
||||
reshape =
|
||||
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), context.get_input(1), false));
|
||||
}
|
||||
return {reshape};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -108,7 +108,6 @@ OP_CONVERTER(translate_upsample_bilinear2d);
|
||||
OP_CONVERTER(translate_upsample_nearest2d);
|
||||
OP_CONVERTER(translate_var);
|
||||
OP_CONVERTER(translate_var_mean);
|
||||
OP_CONVERTER(translate_view);
|
||||
OP_CONVERTER(translate_where);
|
||||
OP_CONVERTER(translate_zeros);
|
||||
OP_CONVERTER(translate_zeros_like);
|
||||
@ -296,7 +295,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::upsample_nearest2d", op::translate_upsample_nearest2d},
|
||||
{"aten::var", op::translate_var},
|
||||
{"aten::var_mean", op::translate_var_mean},
|
||||
{"aten::view", op::translate_view},
|
||||
{"aten::view", op::translate_reshape},
|
||||
{"aten::where", op::translate_where},
|
||||
{"aten::zeros", op::translate_zeros},
|
||||
{"aten::zeros_like", op::translate_zeros_like},
|
||||
|
@ -0,0 +1,55 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "listconstruct_reshape_replacer.hpp"
|
||||
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
ListConstructReshapeReplacer::ListConstructReshapeReplacer() {
|
||||
// Transformation for torch operators aten::view and aten::reshape for cases where second input is
|
||||
// prim::ListConstruct.
|
||||
auto reshape_op = ov::pass::pattern::wrap_type<opset10::Reshape>();
|
||||
// Both aten::view and aten::reshape are using same translation returning opset10::Reshape operator.
|
||||
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
|
||||
auto reshape_op = std::dynamic_pointer_cast<opset10::Reshape>(m.get_match_root());
|
||||
if (!reshape_op) {
|
||||
return false;
|
||||
}
|
||||
auto shape_node = reshape_op->input_value(1).get_node_shared_ptr();
|
||||
if (auto list_unpack_node = cast_fw_node(shape_node, "prim::ListConstruct")) {
|
||||
// Check if second input to operator is prim::ListConstruct, and if so, concatenate it inputs into single
|
||||
// Tensor. Concatenation is possible because all elements in list should be scalar intigers.
|
||||
OutputVector inputs;
|
||||
auto axis_0 = opset10::Constant::create(element::i64, Shape{}, {0});
|
||||
for (auto& input : shape_node->inputs()) {
|
||||
auto rank = input.get_partial_shape().rank();
|
||||
FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
|
||||
auto unsqueeze = std::make_shared<opset10::Unsqueeze>(input.get_source_output(), axis_0);
|
||||
inputs.push_back(unsqueeze);
|
||||
}
|
||||
auto concat = std::make_shared<opset10::Concat>(inputs, 0);
|
||||
copy_runtime_info({shape_node}, concat);
|
||||
replace_node(shape_node, concat);
|
||||
return true;
|
||||
};
|
||||
return false;
|
||||
};
|
||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(reshape_op,
|
||||
"ov::frontend::pytorch::pass::ListConstructReshapeReplacer");
|
||||
this->register_matcher(m, callback);
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
class ListConstructReshapeReplacer : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::frontend::pytorch::pass::ListConstructReshapeReplacer");
|
||||
ListConstructReshapeReplacer();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -31,9 +31,54 @@ class TestViewListConstruct(PytorchLayerTest):
|
||||
self.input_data = input_data
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.parametrize('input_data', [(np.random.randn(4), np.array(2))])
|
||||
class TestViewDtype(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return self.input_data
|
||||
|
||||
def create_model(self):
|
||||
class aten_view_dtype(torch.nn.Module):
|
||||
|
||||
def forward(self, input_tensor, dtype):
|
||||
return input_tensor.view(torch.int64)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_view_dtype(), ref_net, "aten::view"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_view_dtype(self, ie_device, precision, ir_version, input_data):
|
||||
self.input_data = input_data
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('input_data', [(np.random.randn(4), np.random.randn(2, 2))])
|
||||
class TestViewSize(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return self.input_data
|
||||
|
||||
def create_model(self):
|
||||
class aten_view_size(torch.nn.Module):
|
||||
|
||||
def forward(self, input_tensor, input_size):
|
||||
return input_tensor.view(input_size.size()[:])
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_view_size(), ref_net, "aten::view"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_view_size(self, ie_device, precision, ir_version, input_data):
|
||||
self.input_data = input_data
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 2), 2, 6),
|
||||
(np.random.randn(4), 2, 2)])
|
||||
(np.random.randn(4), 2, 2),
|
||||
(np.random.randn(4), 2, 2.1)])
|
||||
class TestView(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
@ -48,7 +93,7 @@ class TestView(PytorchLayerTest):
|
||||
self.dim2 = input_data[2]
|
||||
|
||||
def forward(self, input_tensor):
|
||||
return input_tensor.view(self.dim1, self.dim2)
|
||||
return input_tensor.view(self.dim1, int(self.dim2))
|
||||
|
||||
ref_net = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user