[PT FE]: fix constant folding for dequantization (#18190)

* [PT FE]: fix constant folding for dequantization

* add test
This commit is contained in:
Ekaterina Aidova 2023-06-23 08:41:32 +04:00 committed by GitHub
parent d13adf7ae8
commit df0bd18ed2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 103 additions and 1 deletions

View File

@ -16,6 +16,7 @@
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/control_flow/unroll_if.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transforms.hpp"
#include "transforms/append_list_unpack_replacer.hpp"
@ -35,6 +36,7 @@
#include "transforms/prim_list_unpack_replacer.hpp"
#include "transforms/rfftn_complex_replacer.hpp"
#include "transforms/string_equality_replacer.hpp"
#include "transforms/tuple_unpack_replacer.hpp"
#include "translate_session.hpp"
namespace ov {
@ -155,6 +157,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
element::TypeVector{element::u8, element::i8, element::u4, element::i4});
manager.register_pass<ov::pass::ConstantFolding>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
manager.register_pass<ov::pass::UnrollIf>();
@ -172,6 +176,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();

View File

@ -0,0 +1,47 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "tuple_unpack_replacer.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() {
auto tuple_unpack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
auto tuple_unpack = cast_fw_node(m.get_match_root(), "prim::TupleUnpack");
if (!tuple_unpack)
return false;
OutputVector outputs;
auto input_node = tuple_unpack->get_input_node_shared_ptr(0);
auto tuple_construct = cast_fw_node(input_node, "prim::TupleConstruct");
if (!tuple_construct) {
return false;
}
for (const auto& input : input_node->inputs()) {
const auto& out = input.get_source_output();
outputs.push_back(out);
}
replace_node(tuple_unpack, outputs);
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(tuple_unpack,
"ov::frontend::pytorch::pass::PrimTupleUnpackReplacer");
this->register_matcher(m, callback);
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
class PrimTupleUnpackReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::PrimTupleUnpackReplacer");
PrimTupleUnpackReplacer();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -54,4 +54,30 @@ class TestTupleConstruct(PytorchLayerTest):
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "list_and_tuple"])
@pytest.mark.nightly
def test_tuple_construct(self, case, ie_device, precision, ir_version):
self._test(*self.create_model(case), ie_device, precision, ir_version)
self._test(*self.create_model(case), ie_device, precision, ir_version)
class TestTupleConstructTupleUnpack(PytorchLayerTest):
def _prepare_input(self):
return (np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32),)
def create_model(self):
import torch
class prim_tuple_construct_tuple_unpack(torch.nn.Module):
def forward(self, x):
x1, x2, x3, x4, x5 = self.prepare_input(x)
return x1, x2, x3, x4, x5
def prepare_input(self, x):
return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)
ref_net = None
return prim_tuple_construct_tuple_unpack(), ref_net, ["prim::TupleConstruct", "prim::TupleUnpack"]
@pytest.mark.nightly
def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)