[LPT] ReshapeTransformation: non-constant reshape pattern supported in case of per-tensor dequantization (#8095)

This commit is contained in:
Vladislav Golubev 2021-10-25 22:27:16 +03:00 committed by GitHub
parent b754013aec
commit 64a0e3dbd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 12 deletions

View File

@ -12,6 +12,7 @@
#include <vector>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/network_helper.hpp"
@ -23,13 +24,29 @@ namespace low_precision {
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::ReshapeTransformation, "ReshapeTransformation", 0);
ReshapeTransformation::ReshapeTransformation(const Params& params) : LayerTransformation(params) {
auto matcher = pattern::wrap_type<opset1::Reshape>({ pattern::wrap_type<opset1::Multiply>(), pattern::wrap_type<opset1::Constant>() });
auto input = pattern::any_input();
auto mul_const_m = pattern::wrap_type<opset1::Constant>();
auto mul_m = pattern::wrap_type<opset1::Multiply>({ input, mul_const_m });
auto reshape_pattern_const = pattern::wrap_type<opset1::Constant>();
auto reshape_pattern_nonconst = pattern::any_input();
auto reshape_pattern = std::make_shared<pattern::op::Or>(OutputVector{ reshape_pattern_const, reshape_pattern_nonconst });
auto matcher = pattern::wrap_type<opset1::Reshape>({ mul_m, reshape_pattern });
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
return false;
}
// we can propagate only per-tensor dq through reshape with non-const reshape_pattern
const auto& pattern_map = m.get_pattern_value_map();
if (pattern_map.count(reshape_pattern_nonconst)) {
const auto mul_const = as_type_ptr<opset1::Constant>(pattern_map.at(mul_const_m).get_node_shared_ptr());
if (!mul_const || ngraph::shape_size(mul_const->get_shape()) != 1) {
return false;
}
}
return transform(*context, m);
};

View File

@ -42,7 +42,7 @@ public:
};
ngraph::PartialShape inputShape;
std::vector<int> reshapeConstValues;
std::vector<int> reshapeConstValues; // if empty then create shapeOf
TestTransformationParams params;
Actual actual;
Expected expected;
@ -962,6 +962,38 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
{}
}
},
// U8: non-const reshape pattern and per-tensor dequantization
{
{ -1, -1, -1, -1 },
{},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {0.1f}}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {0.1f}}
}
},
// U8: non-const reshape pattern and per-channel dequantization
{
{ -1, 3, -1, -1 },
{},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {{128.f, 124.f, 120.f}}, {{0.1f, 1.f, 10.f}}}
},
{
ngraph::element::u8,
{{ngraph::element::f32}, {{128.f, 124.f, 120.f}}, {{0.1f, 1.f, 10.f}}},
ngraph::element::f32,
{}
}
},
};
INSTANTIATE_TEST_SUITE_P(

View File

@ -20,10 +20,14 @@ std::shared_ptr<ngraph::Function> ReshapeFunction::getOriginal(
const std::shared_ptr<Node> dequantizationOp = makeDequantization(input, dequantization);
const std::shared_ptr<Node> reshape = std::make_shared<ngraph::opset1::Reshape>(
dequantizationOp,
std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ reshapeConstValues.size() }, reshapeConstValues),
true);
std::shared_ptr<Node> reshape_pattern;
if (!reshapeConstValues.empty()) {
reshape_pattern = opset1::Constant::create(element::i64, Shape{ reshapeConstValues.size() }, reshapeConstValues);
} else {
reshape_pattern = std::make_shared<opset1::ShapeOf>(dequantizationOp);
}
const auto reshape = std::make_shared<ngraph::opset1::Reshape>(dequantizationOp, reshape_pattern, true);
reshape->set_friendly_name("output");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(reshape) };
@ -61,11 +65,15 @@ std::shared_ptr<ngraph::Function> ReshapeFunction::getReference(
const std::shared_ptr<Node> quantizationOpBefore = makeDequantization(input, dequantizationBefore);
const std::shared_ptr<ngraph::opset1::Constant> reshapeConstant = std::make_shared<ngraph::opset1::Constant>(
ngraph::element::i64,
ngraph::Shape{ reshapeConstValues.size() },
reshapeConstValues);
const std::shared_ptr<ngraph::opset1::Reshape> reshape = std::make_shared<ngraph::opset1::Reshape>(quantizationOpBefore, reshapeConstant, true);
std::shared_ptr<Node> reshape_pattern;
if (!reshapeConstValues.empty()) {
reshape_pattern = opset1::Constant::create(element::i64, Shape{ reshapeConstValues.size() }, reshapeConstValues);
} else {
reshape_pattern = makeDequantization(quantizationOpBefore, dequantizationAfter);
reshape_pattern = std::make_shared<opset1::ShapeOf>(reshape_pattern);
}
const auto reshape = std::make_shared<opset1::Reshape>(quantizationOpBefore, reshape_pattern, true);
if (quantizationOpBefore->get_output_element_type(0) != precisionAfterOperation) {
THROW_IE_LPT_EXCEPTION(*quantizationOpBefore) << "unexpected precision '" << precisionAfterOperation << "' after operation";
}