[PT FE]: fix constant folding for dequantization (#18190)
* [PT FE]: fix constant folding for dequantization
* [PT FE]: add test for tuple construct/unpack elimination
This commit is contained in:
parent
d13adf7ae8
commit
df0bd18ed2
@ -16,6 +16,7 @@
|
||||
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
|
||||
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
|
||||
#include "transformations/control_flow/unroll_if.hpp"
|
||||
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
|
||||
#include "transformations/op_conversions/convert_convertlike.hpp"
|
||||
#include "transforms.hpp"
|
||||
#include "transforms/append_list_unpack_replacer.hpp"
|
||||
@ -35,6 +36,7 @@
|
||||
#include "transforms/prim_list_unpack_replacer.hpp"
|
||||
#include "transforms/rfftn_complex_replacer.hpp"
|
||||
#include "transforms/string_equality_replacer.hpp"
|
||||
#include "transforms/tuple_unpack_replacer.hpp"
|
||||
#include "translate_session.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -155,6 +157,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.register_pass<ov::pass::ConvertConvertLike>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
|
||||
|
||||
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
|
||||
element::TypeVector{element::u8, element::i8, element::u4, element::i4});
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
manager.register_pass<ov::pass::PushConstantToSubgraph>();
|
||||
manager.register_pass<ov::pass::UnrollIf>();
|
||||
@ -172,6 +176,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
|
||||
|
@ -0,0 +1,47 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "tuple_unpack_replacer.hpp"
|
||||
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
// Eliminates a prim::TupleUnpack node that is fed directly by a
// prim::TupleConstruct: each TupleConstruct input is forwarded straight to the
// corresponding TupleUnpack output consumer, removing the unpack node from the
// graph (the construct node is left for later dead-code elimination).
PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() {
    // Match any FrameworkNode; the callback narrows it to prim::TupleUnpack.
    auto tuple_unpack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();

    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
        auto tuple_unpack = cast_fw_node(m.get_match_root(), "prim::TupleUnpack");
        if (!tuple_unpack)
            return false;
        // The producer must be prim::TupleConstruct so that its inputs map
        // one-to-one onto the TupleUnpack outputs.
        auto input_node = tuple_unpack->get_input_node_shared_ptr(0);
        auto tuple_construct = cast_fw_node(input_node, "prim::TupleConstruct");
        if (!tuple_construct) {
            return false;
        }
        OutputVector outputs;
        outputs.reserve(input_node->get_input_size());
        for (const auto& input : input_node->inputs()) {
            outputs.push_back(input.get_source_output());
        }
        // Guard against malformed graphs: replace_node requires the number of
        // replacement outputs to equal the node's output count; bail out
        // instead of asserting inside replace_node.
        if (outputs.size() != tuple_unpack->get_output_size()) {
            return false;
        }
        replace_node(tuple_unpack, outputs);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(tuple_unpack,
                                                          "ov::frontend::pytorch::pass::PrimTupleUnpackReplacer");
    this->register_matcher(m, callback);
}  // note: no trailing ';' — the original had a stray semicolon after the body
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
// Matcher pass that removes a prim::TupleUnpack framework node when its input
// comes directly from prim::TupleConstruct, rewiring the unpack's consumers to
// the constructed values instead (see tuple_unpack_replacer.cpp).
class PrimTupleUnpackReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::pytorch::pass::PrimTupleUnpackReplacer");
    PrimTupleUnpackReplacer();
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -54,4 +54,30 @@ class TestTupleConstruct(PytorchLayerTest):
|
||||
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "list_and_tuple"])
|
||||
@pytest.mark.nightly
|
||||
def test_tuple_construct(self, case, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version)
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestTupleConstructTupleUnpack(PytorchLayerTest):
    # Exercises a model whose graph contains prim::TupleConstruct followed by
    # prim::TupleUnpack (the op list returned by create_model pins both ops as
    # expected to appear), so the frontend's tuple-unpack replacement is hit.

    def _prepare_input(self):
        # Single float32 input tensor of shape (1, 2, 10) with values in [0, 50).
        return (np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32),)

    def create_model(self):
        import torch

        class prim_tuple_construct_tuple_unpack(torch.nn.Module):

            def forward(self, x):
                # Unpacking the tuple returned by a method call is what
                # produces the TupleConstruct/TupleUnpack pair in the graph
                # (per the expected-ops list below).
                x1, x2, x3, x4, x5 = self.prepare_input(x)
                return x1, x2, x3, x4, x5

            def prepare_input(self, x):
                # Heterogeneous tuple: plain tensor, arithmetic result, None,
                # reshaped view, and an int32-cast tensor.
                return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)

        # No reference network — validation relies on the op list instead.
        ref_net = None

        return prim_tuple_construct_tuple_unpack(), ref_net, ["prim::TupleConstruct", "prim::TupleUnpack"]

    @pytest.mark.nightly
    def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
        # NOTE(review): freeze_model=False presumably keeps the tuple ops in
        # the converted graph so the frontend pass must handle them — confirm
        # against PytorchLayerTest._test.
        self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
|
Loading…
Reference in New Issue
Block a user