[LPT] MatMulTransformation: support Quantize/Dequantize on weights (#4206)
* [LPT] NetworkHelper::isConstantPath functional tests * [LPT] matMul 3D: support Q/DQ on weights
This commit is contained in:
committed by
GitHub
parent
3f6dbb8a00
commit
06145e20fc
@@ -17,14 +17,16 @@ using namespace ngraph::pass;
|
||||
using namespace ngraph::pass::low_precision;
|
||||
|
||||
bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
|
||||
std::shared_ptr<ngraph::opset1::MatMul> matMul = as_type_ptr<ngraph::opset1::MatMul>(m.get_match_root());
|
||||
std::shared_ptr<opset1::MatMul> matMul = as_type_ptr<opset1::MatMul>(m.get_match_root());
|
||||
if ((matMul == nullptr) || !canBeTransformed(context, matMul)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
matMul = as_type_ptr<ngraph::opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));
|
||||
matMul = as_type_ptr<opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));
|
||||
|
||||
const auto dequantization1 = NetworkHelper::getDequantization(matMul, 0);
|
||||
auto dequantization2 = NetworkHelper::getDequantization(matMul, 1);
|
||||
|
||||
FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
|
||||
if (dequantization2.empty()) {
|
||||
const std::shared_ptr<opset1::FakeQuantize> fakeQuantize =
|
||||
as_type_ptr<opset1::FakeQuantize>(dequantization2.data.get_node_shared_ptr());
|
||||
@@ -40,21 +42,19 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
|
||||
dataPrecision.hasZeroPoint,
|
||||
updatePrecisions);
|
||||
|
||||
dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
|
||||
dequantization2 = NetworkHelper::getDequantization(matMul, 1);
|
||||
}
|
||||
}
|
||||
|
||||
const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0);
|
||||
|
||||
if (dequantization2.subtract != nullptr) {
|
||||
NetworkHelper::optimizeSubtract(dequantization2.subtract);
|
||||
dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
|
||||
dequantization2 = NetworkHelper::getDequantization(matMul, 1);
|
||||
}
|
||||
|
||||
const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<ngraph::op::TypeRelaxed<opset1::MatMul>>(
|
||||
const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<op::TypeRelaxed<opset1::MatMul>>(
|
||||
std::vector<element::Type>({ element::f32, element::f32 }), std::vector<element::Type>({}),
|
||||
ngraph::op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
|
||||
op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
|
||||
op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
|
||||
matMul->get_transpose_a(),
|
||||
matMul->get_transpose_b());
|
||||
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
|
||||
@@ -64,15 +64,15 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
|
||||
|
||||
// dequantization with subtract on activations & constant weights
|
||||
if (dequantization1.subtract) {
|
||||
auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtract->get_input_node_shared_ptr(1))) ?
|
||||
ngraph::Shape(dequantization1.subtract->get_shape().size(), 1) :
|
||||
dequantization1.subtract->get_input_node_shared_ptr(1)->get_shape();
|
||||
auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtractConstant)) ?
|
||||
Shape(dequantization1.subtract->get_shape().size(), 1) :
|
||||
dequantization1.subtractConstant->get_shape();
|
||||
const size_t lastIdx = matMul->get_transpose_a() ? broadcastShape.size() - 2 : broadcastShape.size() - 1;
|
||||
broadcastShape[lastIdx] = dequantization1.subtract->get_shape()[lastIdx];
|
||||
|
||||
// broadcasted sub const to form [1, ..., 1, Y]
|
||||
const auto broadcastedConst = fold<opset1::Broadcast>(
|
||||
dequantization1.subtract->get_input_node_shared_ptr(1),
|
||||
dequantization1.subtractConstant,
|
||||
opset1::Constant::create(ngraph::element::i32, { broadcastShape.size() }, broadcastShape));
|
||||
|
||||
// multiply by weights: [1, ..., 1, Y] x [Y, Z] => [1, ..., 1, Z]
|
||||
@@ -84,7 +84,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
|
||||
|
||||
const auto newSubtract = std::make_shared<DequantizationSubtract>(newMatMul, newSubConst);
|
||||
newSubtract->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationSubtract");
|
||||
ngraph::copy_runtime_info({ newSubtract, matMul }, newSubtract);
|
||||
copy_runtime_info({ newSubtract, matMul }, newSubtract);
|
||||
|
||||
parent = newSubtract;
|
||||
}
|
||||
@@ -100,17 +100,12 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
|
||||
std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2));
|
||||
|
||||
auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant);
|
||||
std::shared_ptr<Node> transposedConstant = fold<ngraph::opset1::Transpose>(node, order);
|
||||
std::shared_ptr<Node> transposedConstant = fold<opset1::Transpose>(node, order);
|
||||
return transposedConstant;
|
||||
};
|
||||
|
||||
const auto mulConst1 = matMul->get_transpose_a() ?
|
||||
transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) :
|
||||
dequantization1.multiply->get_input_node_shared_ptr(1);
|
||||
|
||||
auto mulConst2 = matMul->get_transpose_b() ?
|
||||
transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) :
|
||||
dequantization2.multiply->get_input_node_shared_ptr(1);
|
||||
const auto mulConst1 = matMul->get_transpose_a() ? transpose(dequantization1.multiplyConstant) : dequantization1.multiplyConstant;
|
||||
auto mulConst2 = matMul->get_transpose_b() ? transpose(dequantization2.multiplyConstant) : dequantization2.multiplyConstant;
|
||||
|
||||
if (NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(mulConst2))) {
|
||||
mulConst2 = NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(mulConst2));
|
||||
@@ -125,16 +120,16 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
|
||||
|
||||
mulConst2 = fold<opset1::Unsqueeze>(
|
||||
mulConst2,
|
||||
op::Constant::create(ngraph::element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
|
||||
op::Constant::create(element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
|
||||
}
|
||||
}
|
||||
|
||||
const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<ngraph::opset1::Multiply>(mulConst1, mulConst2));
|
||||
const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<opset1::Multiply>(mulConst1, mulConst2));
|
||||
const std::shared_ptr<opset1::Multiply> newMultiply = std::make_shared<DequantizationMultiply>(parent, newMulConst);
|
||||
newMultiply->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationMultiply");
|
||||
|
||||
replace_node(matMul, newMultiply);
|
||||
ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply);
|
||||
copy_runtime_info({ newMultiply, matMul }, newMultiply);
|
||||
|
||||
updateOutput(context, newMultiply, matMul);
|
||||
|
||||
@@ -145,12 +140,12 @@ void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationC
|
||||
addPattern(
|
||||
pass,
|
||||
context,
|
||||
make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::Multiply>() }));
|
||||
make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Multiply>() }));
|
||||
|
||||
addPattern(
|
||||
pass,
|
||||
context,
|
||||
make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::FakeQuantize>() }));
|
||||
make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::FakeQuantize>() }));
|
||||
}
|
||||
|
||||
bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
|
||||
@@ -167,15 +162,14 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer);
|
||||
const auto dequantization1 = NetworkHelper::getDequantization(layer, 0);
|
||||
if (!dequantization1.empty()) {
|
||||
if (updatePrecisions && !dequantization1.isLowPrecision()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto mulConst = as_type_ptr<opset1::Constant>(dequantization1.multiply->get_input_node_shared_ptr(1));
|
||||
if (!NetworkHelper::isScalarLike(mulConst)) {
|
||||
const auto constantShape = mulConst->get_shape();
|
||||
if (!NetworkHelper::isScalarLike(dequantization1.multiplyConstant)) {
|
||||
const auto constantShape = dequantization1.multiplyConstant->get_shape();
|
||||
const auto mulShape = dequantization1.multiply->get_shape();
|
||||
const size_t columnsIdx = matMul->get_transpose_a() ? mulShape.size() - 2ul : mulShape.size() - 1ul;
|
||||
|
||||
@@ -186,15 +180,21 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
|
||||
}
|
||||
}
|
||||
|
||||
const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1);
|
||||
const auto dequantization2 = NetworkHelper::getDequantization(layer, 1);
|
||||
if (!dequantization2.empty()) {
|
||||
if ((updatePrecisions && !dequantization2.isLowPrecision()) || (dequantization2.subtract)) {
|
||||
if ((updatePrecisions && !dequantization2.isLowPrecision())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto mulConst = as_type_ptr<opset1::Constant>(dequantization2.multiply->get_input_node_shared_ptr(1));
|
||||
if (!NetworkHelper::isScalarLike(mulConst)) {
|
||||
const auto constantShape = mulConst->get_shape();
|
||||
if (dequantization2.subtract) {
|
||||
const auto roundedConst = NetworkHelper::round(dequantization2.subtractConstant, dequantization2.data.get_element_type());
|
||||
if (!NetworkHelper::isZeroConst(roundedConst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!NetworkHelper::isScalarLike(dequantization2.multiplyConstant)) {
|
||||
const auto constantShape = dequantization2.multiplyConstant->get_shape();
|
||||
const auto mulShape = dequantization2.multiply->get_shape();
|
||||
const size_t rowsIdx = matMul->get_transpose_b() ? mulShape.size() - 1ul : mulShape.size() - 2ul;
|
||||
|
||||
@@ -229,7 +229,7 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
|
||||
}
|
||||
}
|
||||
|
||||
if (fakeQuantize == nullptr && dequantization1.subtract) {
|
||||
if ((!NetworkHelper::isConstantPath(layer->get_input_node_shared_ptr(1))) && (dequantization1.subtract)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_weights.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph::pass;
|
||||
using namespace ngraph::builder::subgraph;
|
||||
|
||||
TEST(LPT, isConstantPathFQAfterInputTransformation) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto fqOnActivations = makeFakeQuantize(input, ngraph::element::f32,
|
||||
FakeQuantizeOnData{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(fqOnActivations);
|
||||
|
||||
ASSERT_EQ(false, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathFQAfterWeightsTransformation) {
|
||||
const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
|
||||
const auto fqOnWeights = makeFakeQuantize(weights, ngraph::element::f32,
|
||||
FakeQuantizeOnWeights{ 255ul, {}, {0.f}, {254.f}, {-1.27f}, {1.27f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(fqOnWeights);
|
||||
|
||||
ASSERT_EQ(true, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathDqAfterInputTransformation) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto dqOnActivations = makeDequantization(input, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(dqOnActivations);
|
||||
|
||||
ASSERT_EQ(false, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathDqAfterWeightsTransformation) {
|
||||
const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
|
||||
const auto dqOnWeights = makeDequantization(weights, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(dqOnWeights);
|
||||
|
||||
ASSERT_EQ(true, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathTwoInputsTransformation) {
|
||||
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(matmul);
|
||||
|
||||
ASSERT_EQ(false, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathTwoConsantsTransformation) {
|
||||
const auto constant1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
|
||||
const auto constant2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
|
||||
const auto dq1 = makeDequantization(constant1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto dq2 = makeDequantization(constant2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto eltwise = std::make_shared<ngraph::opset1::Add>(dq1, dq2);
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(eltwise);
|
||||
|
||||
ASSERT_EQ(true, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathMatMulParentFQTransformation) {
|
||||
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
|
||||
const auto fqAfterMatMul = makeFakeQuantize(matmul, ngraph::element::f32,
|
||||
FakeQuantizeOnWeights{ 255ul, {}, {0.f}, {254.f}, {-1.27f}, {1.27f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(fqAfterMatMul);
|
||||
|
||||
ASSERT_EQ(false, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathMatMulParentDqTransformation) {
|
||||
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
|
||||
const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
|
||||
const auto dqAfterMatMul = makeDequantization(matmul, DequantizationOperations{ {}, {}, {0.1f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterMatMul);
|
||||
|
||||
ASSERT_EQ(false, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathConvParentDqTransformation) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 72, 16 });
|
||||
const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 6, 3, 1, 1 }, { 1.f });
|
||||
const auto conv = std::make_shared<ngraph::opset1::Convolution>(
|
||||
input,
|
||||
weights,
|
||||
ngraph::Strides{ 1, 1 },
|
||||
ngraph::CoordinateDiff{ 0, 0 },
|
||||
ngraph::CoordinateDiff{ 0, 0 },
|
||||
ngraph::Strides{ 1, 1 });
|
||||
const auto dqAfterConv = makeDequantization(conv, DequantizationOperations{ {}, {}, {0.1f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterConv);
|
||||
|
||||
ASSERT_EQ(false, result);
|
||||
}
|
||||
|
||||
TEST(LPT, isConstantPathGroupConvParentDqTransformation) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
|
||||
const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 6, 3, 1, 1 }, { 1.f });
|
||||
const auto groupConv = std::make_shared<ngraph::opset1::GroupConvolution>(
|
||||
input,
|
||||
weights,
|
||||
ngraph::Strides{ 1, 1 },
|
||||
ngraph::CoordinateDiff{ 0, 0 },
|
||||
ngraph::CoordinateDiff{ 0, 0 },
|
||||
ngraph::Strides{ 1, 1 });
|
||||
const auto dqAfterGroupConv = makeDequantization(groupConv, DequantizationOperations{ {}, {}, {0.1f} });
|
||||
|
||||
const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterGroupConv);
|
||||
|
||||
ASSERT_EQ(false, result);
|
||||
}
|
||||
Reference in New Issue
Block a user