[LPT] ReshapeTransformation: non-constant reshape pattern supported in case of per-tensor dequantization (#8095)
This commit is contained in:
parent
b754013aec
commit
64a0e3dbd0
@ -12,6 +12,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
|
||||
#include "low_precision/common/ie_lpt_exception.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
@ -23,13 +24,29 @@ namespace low_precision {
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::ReshapeTransformation, "ReshapeTransformation", 0);
|
||||
|
||||
ReshapeTransformation::ReshapeTransformation(const Params& params) : LayerTransformation(params) {
|
||||
auto matcher = pattern::wrap_type<opset1::Reshape>({ pattern::wrap_type<opset1::Multiply>(), pattern::wrap_type<opset1::Constant>() });
|
||||
auto input = pattern::any_input();
|
||||
auto mul_const_m = pattern::wrap_type<opset1::Constant>();
|
||||
auto mul_m = pattern::wrap_type<opset1::Multiply>({ input, mul_const_m });
|
||||
auto reshape_pattern_const = pattern::wrap_type<opset1::Constant>();
|
||||
auto reshape_pattern_nonconst = pattern::any_input();
|
||||
auto reshape_pattern = std::make_shared<pattern::op::Or>(OutputVector{ reshape_pattern_const, reshape_pattern_nonconst });
|
||||
auto matcher = pattern::wrap_type<opset1::Reshape>({ mul_m, reshape_pattern });
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) {
|
||||
auto op = m.get_match_root();
|
||||
if (transformation_callback(op)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only per-tensor dequantization (a single-element Multiply constant) can be propagated through a Reshape whose reshape pattern is not a Constant
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
if (pattern_map.count(reshape_pattern_nonconst)) {
|
||||
const auto mul_const = as_type_ptr<opset1::Constant>(pattern_map.at(mul_const_m).get_node_shared_ptr());
|
||||
if (!mul_const || ngraph::shape_size(mul_const->get_shape()) != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return transform(*context, m);
|
||||
};
|
||||
|
||||
|
@ -42,7 +42,7 @@ public:
|
||||
};
|
||||
|
||||
ngraph::PartialShape inputShape;
|
||||
std::vector<int> reshapeConstValues;
|
||||
std::vector<int> reshapeConstValues; // if empty, the reshape pattern is built from a ShapeOf node instead of a Constant
|
||||
TestTransformationParams params;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
@ -962,6 +962,38 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
|
||||
{}
|
||||
}
|
||||
},
|
||||
// U8: non-const reshape pattern and per-tensor dequantization
|
||||
{
|
||||
{ -1, -1, -1, -1 },
|
||||
{},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {0.1f}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {0.1f}}
|
||||
}
|
||||
},
|
||||
// U8: non-const reshape pattern and per-channel dequantization
|
||||
{
|
||||
{ -1, 3, -1, -1 },
|
||||
{},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128.f, 124.f, 120.f}}, {{0.1f, 1.f, 10.f}}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128.f, 124.f, 120.f}}, {{0.1f, 1.f, 10.f}}},
|
||||
ngraph::element::f32,
|
||||
{}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
|
@ -20,10 +20,14 @@ std::shared_ptr<ngraph::Function> ReshapeFunction::getOriginal(
|
||||
|
||||
const std::shared_ptr<Node> dequantizationOp = makeDequantization(input, dequantization);
|
||||
|
||||
const std::shared_ptr<Node> reshape = std::make_shared<ngraph::opset1::Reshape>(
|
||||
dequantizationOp,
|
||||
std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ reshapeConstValues.size() }, reshapeConstValues),
|
||||
true);
|
||||
std::shared_ptr<Node> reshape_pattern;
|
||||
if (!reshapeConstValues.empty()) {
|
||||
reshape_pattern = opset1::Constant::create(element::i64, Shape{ reshapeConstValues.size() }, reshapeConstValues);
|
||||
} else {
|
||||
reshape_pattern = std::make_shared<opset1::ShapeOf>(dequantizationOp);
|
||||
}
|
||||
|
||||
const auto reshape = std::make_shared<ngraph::opset1::Reshape>(dequantizationOp, reshape_pattern, true);
|
||||
reshape->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(reshape) };
|
||||
@ -61,11 +65,15 @@ std::shared_ptr<ngraph::Function> ReshapeFunction::getReference(
|
||||
|
||||
const std::shared_ptr<Node> quantizationOpBefore = makeDequantization(input, dequantizationBefore);
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::Constant> reshapeConstant = std::make_shared<ngraph::opset1::Constant>(
|
||||
ngraph::element::i64,
|
||||
ngraph::Shape{ reshapeConstValues.size() },
|
||||
reshapeConstValues);
|
||||
const std::shared_ptr<ngraph::opset1::Reshape> reshape = std::make_shared<ngraph::opset1::Reshape>(quantizationOpBefore, reshapeConstant, true);
|
||||
std::shared_ptr<Node> reshape_pattern;
|
||||
if (!reshapeConstValues.empty()) {
|
||||
reshape_pattern = opset1::Constant::create(element::i64, Shape{ reshapeConstValues.size() }, reshapeConstValues);
|
||||
} else {
|
||||
reshape_pattern = makeDequantization(quantizationOpBefore, dequantizationAfter);
|
||||
reshape_pattern = std::make_shared<opset1::ShapeOf>(reshape_pattern);
|
||||
}
|
||||
|
||||
const auto reshape = std::make_shared<opset1::Reshape>(quantizationOpBefore, reshape_pattern, true);
|
||||
if (quantizationOpBefore->get_output_element_type(0) != precisionAfterOperation) {
|
||||
THROW_IE_LPT_EXCEPTION(*quantizationOpBefore) << "unexpected precision '" << precisionAfterOperation << "' after operation";
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user