[LPT] MatMulTransformation: support Quantize/Dequantize on weights (#4206)

* [LPT] NetworkHelper::isConstantPath functional tests

* [LPT] matMul 3D: support Q/DQ on weights
Vladislav Golubev
2021-02-12 13:02:07 +03:00
committed by GitHub
parent 3f6dbb8a00
commit 06145e20fc
2 changed files with 179 additions and 38 deletions
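
The matcher patterns (unchanged below) already accept MatMul(Multiply, Multiply) and MatMul(Multiply, FakeQuantize); what this commit adds is transform()/canBeTransformed() support for a genuine Quantize/Dequantize chain on the weights branch, including 3D MatMul. A rough illustration of the kind of subgraph involved, built with plain opset1 calls (not code from this commit; shapes, ranges and the builder name are invented):

    #include <memory>
    #include <ngraph/ngraph.hpp>

    // Hypothetical builder: a MatMul whose weights input is a FakeQuantize over a Constant.
    std::shared_ptr<ngraph::Function> makeMatMulWithFqOnWeights() {
        using namespace ngraph;
        // 3D activations, so a batch dimension survives into the MatMul
        const auto activations = std::make_shared<opset1::Parameter>(element::f32, Shape{ 1, 384, 64 });
        const auto weights = opset1::Constant::create(element::f32, Shape{ 64, 1024 }, { 1.f });
        const auto limit = [](const float v) { return opset1::Constant::create(element::f32, Shape{}, { v }); };
        // 255-level FakeQuantize on the constant weights branch
        const auto fqOnWeights = std::make_shared<opset1::FakeQuantize>(
            weights, limit(-12.7f), limit(12.7f), limit(-12.7f), limit(12.7f), 255ul);
        const auto matMul = std::make_shared<opset1::MatMul>(activations, fqOnWeights, false, false);
        return std::make_shared<Function>(NodeVector{ matMul }, ParameterVector{ activations });
    }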


@@ -17,14 +17,16 @@ using namespace ngraph::pass;
using namespace ngraph::pass::low_precision;

bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
-    std::shared_ptr<ngraph::opset1::MatMul> matMul = as_type_ptr<ngraph::opset1::MatMul>(m.get_match_root());
+    std::shared_ptr<opset1::MatMul> matMul = as_type_ptr<opset1::MatMul>(m.get_match_root());
    if ((matMul == nullptr) || !canBeTransformed(context, matMul)) {
        return false;
    }

-    matMul = as_type_ptr<ngraph::opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));
+    matMul = as_type_ptr<opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));

-    FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
+    const auto dequantization1 = NetworkHelper::getDequantization(matMul, 0);
+    auto dequantization2 = NetworkHelper::getDequantization(matMul, 1);
    if (dequantization2.empty()) {
        const std::shared_ptr<opset1::FakeQuantize> fakeQuantize =
            as_type_ptr<opset1::FakeQuantize>(dequantization2.data.get_node_shared_ptr());
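
getDequantization walks up from the given input index and packs any recognized Convert -> Subtract -> Multiply chain into a FakeQuantizeDequantization; an empty result on input 1 therefore means the weights still carry a FakeQuantize, which the elided code between these hunks evidently decomposes. A small usage sketch of the members this diff leans on, assuming subtractConstant/multiplyConstant mirror the constant inputs of the Subtract/Multiply nodes:

    const auto dq = NetworkHelper::getDequantization(matMul, 1);
    if (!dq.empty()) {
        const auto scale = dq.multiplyConstant;      // constant input of the Multiply
        const auto zeroPoint = dq.subtractConstant;  // constant input of the Subtract, if any
    }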
@@ -40,21 +42,19 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
                dataPrecision.hasZeroPoint,
                updatePrecisions);
-            dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
+            dequantization2 = NetworkHelper::getDequantization(matMul, 1);
        }
    }

-    const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0);
    if (dequantization2.subtract != nullptr) {
        NetworkHelper::optimizeSubtract(dequantization2.subtract);
-        dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
+        dequantization2 = NetworkHelper::getDequantization(matMul, 1);
    }

-    const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<ngraph::op::TypeRelaxed<opset1::MatMul>>(
+    const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<op::TypeRelaxed<opset1::MatMul>>(
        std::vector<element::Type>({ element::f32, element::f32 }), std::vector<element::Type>({}),
-        ngraph::op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
-        ngraph::op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
+        op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
+        op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
        matMul->get_transpose_a(),
        matMul->get_transpose_b());
    NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
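
op::TypeRelaxed<opset1::MatMul> is the usual LPT device here: the TemporaryReplaceOutputType wrappers present both dequantized inputs as f32 just long enough for the relaxed MatMul to pass type validation, and setOutDataPrecisionForTypeRelaxed then pins the output element type back to the original MatMul's, so the low-precision tensors flow through the new node unchanged.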
@@ -64,15 +64,15 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
    // dequantization with subtract on activations & constant weights
    if (dequantization1.subtract) {
-        auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtract->get_input_node_shared_ptr(1))) ?
-            ngraph::Shape(dequantization1.subtract->get_shape().size(), 1) :
-            dequantization1.subtract->get_input_node_shared_ptr(1)->get_shape();
+        auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtractConstant)) ?
+            Shape(dequantization1.subtract->get_shape().size(), 1) :
+            dequantization1.subtractConstant->get_shape();
        const size_t lastIdx = matMul->get_transpose_a() ? broadcastShape.size() - 2 : broadcastShape.size() - 1;
        broadcastShape[lastIdx] = dequantization1.subtract->get_shape()[lastIdx];

        // broadcasted sub const to form [1, ..., 1, Y]
        const auto broadcastedConst = fold<opset1::Broadcast>(
-            dequantization1.subtract->get_input_node_shared_ptr(1),
+            dequantization1.subtractConstant,
            opset1::Constant::create(ngraph::element::i32, { broadcastShape.size() }, broadcastShape));

        // multiply by weights: [1, ..., 1, Y] x [Y, Z] => [1, ..., 1, Z]
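
The two comments above are the whole trick. For activations with subtracted zero point S and constant weights W,

    (X - S) x W = X x W - (S x W),

so the zero point can be pulled through the MatMul by pre-computing S x W at transformation time. Broadcasting S to [1, ..., 1, Y] first makes the constant fold well-defined, and S x W comes out with shape [1, ..., 1, Z], ready to be re-applied after the low-precision MatMul as the constant of the DequantizationSubtract built below.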
@@ -84,7 +84,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
        const auto newSubtract = std::make_shared<DequantizationSubtract>(newMatMul, newSubConst);
        newSubtract->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationSubtract");
-        ngraph::copy_runtime_info({ newSubtract, matMul }, newSubtract);
+        copy_runtime_info({ newSubtract, matMul }, newSubtract);
        parent = newSubtract;
    }
@@ -100,17 +100,12 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
        std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2));
        auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant);
-        std::shared_ptr<Node> transposedConstant = fold<ngraph::opset1::Transpose>(node, order);
+        std::shared_ptr<Node> transposedConstant = fold<opset1::Transpose>(node, order);
        return transposedConstant;
    };

-    const auto mulConst1 = matMul->get_transpose_a() ?
-        transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) :
-        dequantization1.multiply->get_input_node_shared_ptr(1);
-    auto mulConst2 = matMul->get_transpose_b() ?
-        transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) :
-        dequantization2.multiply->get_input_node_shared_ptr(1);
+    const auto mulConst1 = matMul->get_transpose_a() ? transpose(dequantization1.multiplyConstant) : dequantization1.multiplyConstant;
+    auto mulConst2 = matMul->get_transpose_b() ? transpose(dequantization2.multiplyConstant) : dequantization2.multiplyConstant;

    if (NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(mulConst2))) {
        mulConst2 = NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(mulConst2));
@@ -125,16 +120,16 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
            mulConst2 = fold<opset1::Unsqueeze>(
                mulConst2,
-                op::Constant::create(ngraph::element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
+                op::Constant::create(element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
        }
    }

-    const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<ngraph::opset1::Multiply>(mulConst1, mulConst2));
+    const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<opset1::Multiply>(mulConst1, mulConst2));
    const std::shared_ptr<opset1::Multiply> newMultiply = std::make_shared<DequantizationMultiply>(parent, newMulConst);
    newMultiply->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationMultiply");

    replace_node(matMul, newMultiply);
-    ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply);
+    copy_runtime_info({ newMultiply, matMul }, newMultiply);

    updateOutput(context, newMultiply, matMul);
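
The scale handling rests on the same algebra as the subtract folding above: (X * s1) x (W * s2) = (X x W) * (s1 * s2) whenever s1 varies only along the rows of the result and s2 only along its columns. That product is exactly what fold<opset1::Multiply>(mulConst1, mulConst2) pre-computes into newMulConst, and the shape checks in canBeTransformed below exist to guarantee the precondition.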
@@ -145,12 +140,12 @@ void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationC
    addPattern(
        pass,
        context,
-        make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::Multiply>() }));
+        make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Multiply>() }));
    addPattern(
        pass,
        context,
-        make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::FakeQuantize>() }));
+        make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::FakeQuantize>() }));
}

bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
@@ -167,15 +162,14 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
        return false;
    }

-    const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer);
+    const auto dequantization1 = NetworkHelper::getDequantization(layer, 0);
    if (!dequantization1.empty()) {
        if (updatePrecisions && !dequantization1.isLowPrecision()) {
            return false;
        }

-        const auto mulConst = as_type_ptr<opset1::Constant>(dequantization1.multiply->get_input_node_shared_ptr(1));
-        if (!NetworkHelper::isScalarLike(mulConst)) {
-            const auto constantShape = mulConst->get_shape();
+        if (!NetworkHelper::isScalarLike(dequantization1.multiplyConstant)) {
+            const auto constantShape = dequantization1.multiplyConstant->get_shape();
            const auto mulShape = dequantization1.multiply->get_shape();
            const size_t columnsIdx = matMul->get_transpose_a() ? mulShape.size() - 2ul : mulShape.size() - 1ul;
@@ -186,15 +180,21 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
        }
    }

-    const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1);
+    const auto dequantization2 = NetworkHelper::getDequantization(layer, 1);
    if (!dequantization2.empty()) {
-        if ((updatePrecisions && !dequantization2.isLowPrecision()) || (dequantization2.subtract)) {
+        if ((updatePrecisions && !dequantization2.isLowPrecision())) {
            return false;
        }

-        const auto mulConst = as_type_ptr<opset1::Constant>(dequantization2.multiply->get_input_node_shared_ptr(1));
-        if (!NetworkHelper::isScalarLike(mulConst)) {
-            const auto constantShape = mulConst->get_shape();
+        if (dequantization2.subtract) {
+            const auto roundedConst = NetworkHelper::round(dequantization2.subtractConstant, dequantization2.data.get_element_type());
+            if (!NetworkHelper::isZeroConst(roundedConst)) {
+                return false;
+            }
+        }
+
+        if (!NetworkHelper::isScalarLike(dequantization2.multiplyConstant)) {
+            const auto constantShape = dequantization2.multiplyConstant->get_shape();
            const auto mulShape = dequantization2.multiply->get_shape();
            const size_t rowsIdx = matMul->get_transpose_b() ? mulShape.size() - 1ul : mulShape.size() - 2ul;
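
This hunk is the behavioral change on weights: a Subtract no longer causes an outright rejection. It is accepted when the zero point is effectively zero, i.e. when rounding the subtract constant to the quantized element type yields an all-zero constant. For example (values invented), a weights dequantization (w - 0.4f) * 0.1f now passes, since 0.4f rounds to 0 in i8, while (w - 64.f) * 0.1f is still rejected.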
@@ -229,7 +229,7 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
        }
    }

-    if (fakeQuantize == nullptr && dequantization1.subtract) {
+    if ((!NetworkHelper::isConstantPath(layer->get_input_node_shared_ptr(1))) && (dequantization1.subtract)) {
        return false;
    }


@@ -0,0 +1,141 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>

#include <gtest/gtest.h>

#include "ngraph_functions/subgraph_builders.hpp"
#include "low_precision/network_helper.hpp"
#include "lpt_ngraph_functions/common/builders.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_weights.hpp"

using namespace testing;
using namespace ngraph::pass;
using namespace ngraph::builder::subgraph;
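
// The contract these tests pin down: NetworkHelper::isConstantPath(node) is
// expected to return true only when every path above `node` terminates in a
// Constant (a weights-like, const-foldable branch), and false as soon as any
// path reaches a Parameter, regardless of whether the chain on top is a
// FakeQuantize, a dequantization, or a multi-input op such as MatMul, Add,
// Convolution or GroupConvolution.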
TEST(LPT, isConstantPathFQAfterInputTransformation) {
    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto fqOnActivations = makeFakeQuantize(input, ngraph::element::f32,
        FakeQuantizeOnData{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(fqOnActivations);
    ASSERT_EQ(false, result);
}

TEST(LPT, isConstantPathFQAfterWeightsTransformation) {
    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
    const auto fqOnWeights = makeFakeQuantize(weights, ngraph::element::f32,
        FakeQuantizeOnWeights{ 255ul, {}, {0.f}, {254.f}, {-1.27f}, {1.27f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(fqOnWeights);
    ASSERT_EQ(true, result);
}

TEST(LPT, isConstantPathDqAfterInputTransformation) {
    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto dqOnActivations = makeDequantization(input, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(dqOnActivations);
    ASSERT_EQ(false, result);
}

TEST(LPT, isConstantPathDqAfterWeightsTransformation) {
    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
    const auto dqOnWeights = makeDequantization(weights, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(dqOnWeights);
    ASSERT_EQ(true, result);
}

TEST(LPT, isConstantPathTwoInputsTransformation) {
    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);

    const bool result = low_precision::NetworkHelper::isConstantPath(matmul);
    ASSERT_EQ(false, result);
}

TEST(LPT, isConstantPathTwoConstantsTransformation) {
    const auto constant1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
    const auto constant2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
    const auto dq1 = makeDequantization(constant1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto dq2 = makeDequantization(constant2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto eltwise = std::make_shared<ngraph::opset1::Add>(dq1, dq2);

    const bool result = low_precision::NetworkHelper::isConstantPath(eltwise);
    ASSERT_EQ(true, result);
}

TEST(LPT, isConstantPathMatMulParentFQTransformation) {
    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
    const auto fqAfterMatMul = makeFakeQuantize(matmul, ngraph::element::f32,
        FakeQuantizeOnWeights{ 255ul, {}, {0.f}, {254.f}, {-1.27f}, {1.27f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(fqAfterMatMul);
    ASSERT_EQ(false, result);
}

TEST(LPT, isConstantPathMatMulParentDqTransformation) {
    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
    const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
    const auto dqAfterMatMul = makeDequantization(matmul, DequantizationOperations{ {}, {}, {0.1f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterMatMul);
    ASSERT_EQ(false, result);
}

TEST(LPT, isConstantPathConvParentDqTransformation) {
    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 72, 16 });
    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 6, 3, 1, 1 }, { 1.f });
    const auto conv = std::make_shared<ngraph::opset1::Convolution>(
        input,
        weights,
        ngraph::Strides{ 1, 1 },
        ngraph::CoordinateDiff{ 0, 0 },
        ngraph::CoordinateDiff{ 0, 0 },
        ngraph::Strides{ 1, 1 });
    const auto dqAfterConv = makeDequantization(conv, DequantizationOperations{ {}, {}, {0.1f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterConv);
    ASSERT_EQ(false, result);
}

TEST(LPT, isConstantPathGroupConvParentDqTransformation) {
    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 6, 3, 1, 1 }, { 1.f });
    const auto groupConv = std::make_shared<ngraph::opset1::GroupConvolution>(
        input,
        weights,
        ngraph::Strides{ 1, 1 },
        ngraph::CoordinateDiff{ 0, 0 },
        ngraph::CoordinateDiff{ 0, 0 },
        ngraph::Strides{ 1, 1 });
    const auto dqAfterGroupConv = makeDequantization(groupConv, DequantizationOperations{ {}, {}, {0.1f} });

    const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterGroupConv);
    ASSERT_EQ(false, result);
}
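
Assuming these tests are compiled into the unit-test binary that hosts the lp_transformations suites (the exact target name is not part of this commit), the new cases can be run in isolation with a standard gtest filter:

    <test_binary> --gtest_filter=LPT.isConstantPath*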