From 06145e20fc2e306018bc4fb3839810dd20b179c7 Mon Sep 17 00:00:00 2001
From: Vladislav Golubev
Date: Fri, 12 Feb 2021 13:02:07 +0300
Subject: [PATCH] [LPT] MatMulTransformation: support Quantize/Dequantize on weights (#4206)

* [LPT] NetworkHelper::isConstantPath functional tests

* [LPT] matMul 3D: support Q/DQ on weights
---
 .../src/mat_mul.cpp                                |  76 +++++-----
 .../is_constant_path_transformation.cpp            | 141 ++++++++++++++++++
 2 files changed, 179 insertions(+), 38 deletions(-)
 create mode 100644 inference-engine/tests/functional/inference_engine/lp_transformations/is_constant_path_transformation.cpp
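
Context for the diff below (illustrative, not part of the patch): the newly supported case is a MatMul whose weights input is still quantized, either as an explicit FakeQuantize over a Constant or as an already-decomposed Convert/Subtract/Multiply dequantization chain. A minimal sketch of such a subgraph in the ngraph opset1 API this patch uses; the helper name makeQuantizedMatMul and all shapes and intervals are assumptions chosen for illustration:

#include <memory>

#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>

using namespace ngraph;

// Builds MatMul(FQ(activations), FQ(weights)): the Q/DQ-on-weights kind of
// subgraph that MatMulTransformation::transform can now decompose and fold.
std::shared_ptr<Function> makeQuantizedMatMul() {
    const auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{ 1, 16, 384 });

    // FakeQuantize helper; identical input/output intervals for brevity.
    const auto fq = [](const Output<Node>& parent, const float low, const float high, const size_t levels) {
        const auto lowConst = opset1::Constant::create(element::f32, Shape{}, { low });
        const auto highConst = opset1::Constant::create(element::f32, Shape{}, { high });
        return std::make_shared<opset1::FakeQuantize>(parent, lowConst, highConst, lowConst, highConst, levels);
    };

    const auto fqOnData = fq(input, 0.f, 2.55f, 256ul);          // u8-like activations
    const auto weights = opset1::Constant::create(element::f32, Shape{ 384, 64 }, { 1.f });
    const auto fqOnWeights = fq(weights, -1.27f, 1.27f, 255ul);  // i8-like weights on a constant path

    const auto matMul = std::make_shared<opset1::MatMul>(fqOnData, fqOnWeights, false, false);
    return std::make_shared<Function>(NodeVector{ matMul }, ParameterVector{ input });
}

After the pass runs, transform() decomposes the weights FakeQuantize via NetworkHelper::decomposeFakeQuantize, rebuilds the MatMul as TypeRelaxed so it executes on the low-precision inputs, and leaves a single folded dequantization Multiply (plus, when needed, a Subtract) on the output.
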
diff --git a/inference-engine/src/low_precision_transformations/src/mat_mul.cpp b/inference-engine/src/low_precision_transformations/src/mat_mul.cpp
index 212a8e8d11a..9526ee5831d 100644
--- a/inference-engine/src/low_precision_transformations/src/mat_mul.cpp
+++ b/inference-engine/src/low_precision_transformations/src/mat_mul.cpp
@@ -17,14 +17,16 @@ using namespace ngraph::pass;
 using namespace ngraph::pass::low_precision;
 
 bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
-    std::shared_ptr<ngraph::opset1::MatMul> matMul = as_type_ptr<ngraph::opset1::MatMul>(m.get_match_root());
+    std::shared_ptr<opset1::MatMul> matMul = as_type_ptr<opset1::MatMul>(m.get_match_root());
     if ((matMul == nullptr) || !canBeTransformed(context, matMul)) {
         return false;
     }
 
-    matMul = as_type_ptr<ngraph::opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));
+    matMul = as_type_ptr<opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));
+
+    const auto dequantization1 = NetworkHelper::getDequantization(matMul, 0);
+    auto dequantization2 = NetworkHelper::getDequantization(matMul, 1);
 
-    FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
     if (dequantization2.empty()) {
         const std::shared_ptr<opset1::FakeQuantize> fakeQuantize =
             as_type_ptr<opset1::FakeQuantize>(dequantization2.data.get_node_shared_ptr());
@@ -40,21 +42,19 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
                 dataPrecision.hasZeroPoint,
                 updatePrecisions);
 
-            dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
+            dequantization2 = NetworkHelper::getDequantization(matMul, 1);
         }
     }
 
-    const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0);
-
     if (dequantization2.subtract != nullptr) {
         NetworkHelper::optimizeSubtract(dequantization2.subtract);
-        dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
+        dequantization2 = NetworkHelper::getDequantization(matMul, 1);
     }
 
-    const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<ngraph::op::TypeRelaxed<opset1::MatMul>>(
+    const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<op::TypeRelaxed<opset1::MatMul>>(
         std::vector<element::Type>({ element::f32, element::f32 }), std::vector<element::Type>({}),
-        ngraph::op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
-        ngraph::op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
+        op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
+        op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
         matMul->get_transpose_a(),
         matMul->get_transpose_b());
     NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
@@ -64,15 +64,15 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
 
     // dequantization with subtract on activations & constant weights
     if (dequantization1.subtract) {
-        auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtract->get_input_node_shared_ptr(1))) ?
-            ngraph::Shape(dequantization1.subtract->get_shape().size(), 1) :
-            dequantization1.subtract->get_input_node_shared_ptr(1)->get_shape();
+        auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtractConstant)) ?
+            Shape(dequantization1.subtract->get_shape().size(), 1) :
+            dequantization1.subtractConstant->get_shape();
         const size_t lastIdx = matMul->get_transpose_a() ? broadcastShape.size() - 2 : broadcastShape.size() - 1;
         broadcastShape[lastIdx] = dequantization1.subtract->get_shape()[lastIdx];
 
         // broadcasted sub const to form [1, ..., 1, Y]
         const auto broadcastedConst = fold<opset1::Broadcast>(
-            dequantization1.subtract->get_input_node_shared_ptr(1),
+            dequantization1.subtractConstant,
             opset1::Constant::create(ngraph::element::i32, { broadcastShape.size() }, broadcastShape));
 
         // multiply by weights: [1, ..., 1, Y] x [Y, Z] => [1, ..., 1, Z]
@@ -84,7 +84,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
 
         const auto newSubtract = std::make_shared<DequantizationSubtract>(newMatMul, newSubConst);
         newSubtract->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationSubtract");
-        ngraph::copy_runtime_info({ newSubtract, matMul }, newSubtract);
+        copy_runtime_info({ newSubtract, matMul }, newSubtract);
 
         parent = newSubtract;
     }
@@ -100,17 +100,12 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
         std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2));
         auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant);
 
-        std::shared_ptr<Node> transposedConstant = fold<ngraph::opset1::Transpose>(node, order);
+        std::shared_ptr<Node> transposedConstant = fold<opset1::Transpose>(node, order);
         return transposedConstant;
     };
 
-    const auto mulConst1 = matMul->get_transpose_a() ?
-        transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) :
-        dequantization1.multiply->get_input_node_shared_ptr(1);
-
-    auto mulConst2 = matMul->get_transpose_b() ?
-        transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) :
-        dequantization2.multiply->get_input_node_shared_ptr(1);
+    const auto mulConst1 = matMul->get_transpose_a() ? transpose(dequantization1.multiplyConstant) : dequantization1.multiplyConstant;
+    auto mulConst2 = matMul->get_transpose_b() ? transpose(dequantization2.multiplyConstant) : dequantization2.multiplyConstant;
 
     if (NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(mulConst2))) {
         mulConst2 = NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(mulConst2));
@@ -125,16 +120,16 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
 
             mulConst2 = fold<opset1::Unsqueeze>(
                 mulConst2,
-                op::Constant::create(ngraph::element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
+                op::Constant::create(element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
         }
     }
 
-    const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<ngraph::opset1::Multiply>(mulConst1, mulConst2));
+    const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<opset1::Multiply>(mulConst1, mulConst2));
 
     const std::shared_ptr<opset1::Multiply> newMultiply = std::make_shared<DequantizationMultiply>(parent, newMulConst);
     newMultiply->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationMultiply");
     replace_node(matMul, newMultiply);
-    ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply);
+    copy_runtime_info({ newMultiply, matMul }, newMultiply);
 
     updateOutput(context, newMultiply, matMul);
@@ -145,12 +140,12 @@ void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationC
     addPattern(
         pass,
         context,
-        make_op_pattern<ngraph::opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::Multiply>() }));
+        make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Multiply>() }));
 
     addPattern(
         pass,
         context,
-        make_op_pattern<ngraph::opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::FakeQuantize>() }));
+        make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::FakeQuantize>() }));
 }
 
 bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
     return false;
 }
@@ -167,15 +162,14 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
         return false;
     }
 
-    const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer);
+    const auto dequantization1 = NetworkHelper::getDequantization(layer, 0);
     if (!dequantization1.empty()) {
         if (updatePrecisions && !dequantization1.isLowPrecision()) {
             return false;
         }
 
-        const auto mulConst = as_type_ptr<opset1::Constant>(dequantization1.multiply->get_input_node_shared_ptr(1));
-        if (!NetworkHelper::isScalarLike(mulConst)) {
-            const auto constantShape = mulConst->get_shape();
+        if (!NetworkHelper::isScalarLike(dequantization1.multiplyConstant)) {
+            const auto constantShape = dequantization1.multiplyConstant->get_shape();
             const auto mulShape = dequantization1.multiply->get_shape();
             const size_t columnsIdx = matMul->get_transpose_a() ? mulShape.size() - 2ul : mulShape.size() - 1ul;
@@ -186,15 +180,21 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
         }
     }
 
-    const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1);
+    const auto dequantization2 = NetworkHelper::getDequantization(layer, 1);
     if (!dequantization2.empty()) {
-        if ((updatePrecisions && !dequantization2.isLowPrecision()) || (dequantization2.subtract)) {
+        if ((updatePrecisions && !dequantization2.isLowPrecision())) {
             return false;
         }
 
-        const auto mulConst = as_type_ptr<opset1::Constant>(dequantization2.multiply->get_input_node_shared_ptr(1));
-        if (!NetworkHelper::isScalarLike(mulConst)) {
-            const auto constantShape = mulConst->get_shape();
+        if (dequantization2.subtract) {
+            const auto roundedConst = NetworkHelper::round(dequantization2.subtractConstant, dequantization2.data.get_element_type());
+            if (!NetworkHelper::isZeroConst(roundedConst)) {
+                return false;
+            }
+        }
+
+        if (!NetworkHelper::isScalarLike(dequantization2.multiplyConstant)) {
+            const auto constantShape = dequantization2.multiplyConstant->get_shape();
             const auto mulShape = dequantization2.multiply->get_shape();
             const size_t rowsIdx = matMul->get_transpose_b() ? mulShape.size() - 1ul : mulShape.size() - 2ul;
@@ -229,7 +229,7 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
         }
     }
 
-    if (fakeQuantize == nullptr && dequantization1.subtract) {
+    if ((!NetworkHelper::isConstantPath(layer->get_input_node_shared_ptr(1))) && (dequantization1.subtract)) {
         return false;
     }
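
An aside between the two files (not part of the patch): two behavioral changes in canBeTransformed above are easy to miss. First, a Subtract in the weights' dequantization no longer disqualifies the MatMul outright; it is now accepted when the zero point rounds to zero in the integer precision of the quantized data (the NetworkHelper::round plus isZeroConst pair). Second, the old fakeQuantize == nullptr guard is generalized to NetworkHelper::isConstantPath on input 1, so any constant-rooted weights branch qualifies, not only one ending in a FakeQuantize. A stand-alone restatement of the zero-point rule; zeroPointIsEffectivelyZero is a hypothetical name, not an LPT API:

#include <cmath>
#include <vector>

// Mirrors the new acceptance rule in MatMulTransformation::canBeTransformed:
// a weights dequantization Subtract passes only if every zero-point value,
// rounded to the quantized data's integer domain, equals zero.
bool zeroPointIsEffectivelyZero(const std::vector<float>& subtractConstant) {
    for (const float value : subtractConstant) {
        if (std::lround(value) != 0L) {
            return false;  // genuine (asymmetric) zero point: bail out
        }
    }
    return true;  // effectively symmetric: the Subtract can be optimized away
}

So a residual float zero point such as 0.3 on int8 weights is tolerated and later removed by NetworkHelper::optimizeSubtract, while a real asymmetric zero point such as 128 still makes canBeTransformed return false. (std::lround here stands in for NetworkHelper::round, whose exact rounding mode may differ.)
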
diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/is_constant_path_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/is_constant_path_transformation.cpp
new file mode 100644
index 00000000000..791838ee173
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/lp_transformations/is_constant_path_transformation.cpp
@@ -0,0 +1,141 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "ngraph_functions/subgraph_builders.hpp"
+#include "low_precision/network_helper.hpp"
+
+#include "lpt_ngraph_functions/common/builders.hpp"
+#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
+#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
+#include "lpt_ngraph_functions/common/fake_quantize_on_weights.hpp"
+
+using namespace testing;
+using namespace ngraph::pass;
+using namespace ngraph::builder::subgraph;
+
+TEST(LPT, isConstantPathFQAfterInputTransformation) {
+    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto fqOnActivations = makeFakeQuantize(input, ngraph::element::f32,
+        FakeQuantizeOnData{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(fqOnActivations);
+
+    ASSERT_EQ(false, result);
+}
+
+TEST(LPT, isConstantPathFQAfterWeightsTransformation) {
+    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
+    const auto fqOnWeights = makeFakeQuantize(weights, ngraph::element::f32,
+        FakeQuantizeOnWeights{ 255ul, {}, {0.f}, {254.f}, {-1.27f}, {1.27f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(fqOnWeights);
+
+    ASSERT_EQ(true, result);
+}
+
+TEST(LPT, isConstantPathDqAfterInputTransformation) {
+    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto dqOnActivations = makeDequantization(input, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(dqOnActivations);
+
+    ASSERT_EQ(false, result);
+}
+
+TEST(LPT, isConstantPathDqAfterWeightsTransformation) {
+    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
+    const auto dqOnWeights = makeDequantization(weights, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(dqOnWeights);
+
+    ASSERT_EQ(true, result);
+}
+
+TEST(LPT, isConstantPathTwoInputsTransformation) {
+    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(matmul);
+
+    ASSERT_EQ(false, result);
+}
+
+TEST(LPT, isConstantPathTwoConstantsTransformation) {
+    const auto constant1 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
+    const auto constant2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 1, 1 }, { 1.f });
+    const auto dq1 = makeDequantization(constant1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto dq2 = makeDequantization(constant2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto eltwise = std::make_shared<ngraph::opset1::Add>(dq1, dq2);
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(eltwise);
+
+    ASSERT_EQ(true, result);
+}
+
+TEST(LPT, isConstantPathMatMulParentFQTransformation) {
+    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
+    const auto fqAfterMatMul = makeFakeQuantize(matmul, ngraph::element::f32,
+        FakeQuantizeOnWeights{ 255ul, {}, {0.f}, {254.f}, {-1.27f}, {1.27f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(fqAfterMatMul);
+
+    ASSERT_EQ(false, result);
+}
+
+TEST(LPT, isConstantPathMatMulParentDqTransformation) {
+    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto dq1 = makeDequantization(input1, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto dq2 = makeDequantization(input2, DequantizationOperations{ ngraph::element::f32, {128.f}, {0.1f} });
+    const auto matmul = std::make_shared<ngraph::opset1::MatMul>(dq1, dq2);
+    const auto dqAfterMatMul = makeDequantization(matmul, DequantizationOperations{ {}, {}, {0.1f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterMatMul);
+
+    ASSERT_EQ(false, result);
+}
+
+TEST(LPT, isConstantPathConvParentDqTransformation) {
+    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 72, 16 });
+    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 6, 3, 1, 1 }, { 1.f });
+    const auto conv = std::make_shared<ngraph::opset1::Convolution>(
+        input,
+        weights,
+        ngraph::Strides{ 1, 1 },
+        ngraph::CoordinateDiff{ 0, 0 },
+        ngraph::CoordinateDiff{ 0, 0 },
+        ngraph::Strides{ 1, 1 });
+    const auto dqAfterConv = makeDequantization(conv, DequantizationOperations{ {}, {}, {0.1f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterConv);
+
+    ASSERT_EQ(false, result);
+}
+
+TEST(LPT, isConstantPathGroupConvParentDqTransformation) {
+    const auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 16, 16 });
+    const auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 6, 3, 1, 1 }, { 1.f });
+    const auto groupConv = std::make_shared<ngraph::opset1::GroupConvolution>(
+        input,
+        weights,
+        ngraph::Strides{ 1, 1 },
+        ngraph::CoordinateDiff{ 0, 0 },
+        ngraph::CoordinateDiff{ 0, 0 },
+        ngraph::Strides{ 1, 1 });
+    const auto dqAfterGroupConv = makeDequantization(groupConv, DequantizationOperations{ {}, {}, {0.1f} });
+
+    const bool result = low_precision::NetworkHelper::isConstantPath(dqAfterGroupConv);
+
+    ASSERT_EQ(false, result);
+}
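
Read together, the assertions above pin down the contract of NetworkHelper::isConstantPath: a node is on a constant path when nothing upstream depends on a Parameter, so FakeQuantize and dequantization chains rooted in Constants qualify, while anything fed directly or transitively by network inputs does not. A hypothetical sketch of that reachability rule, for reading the tests only; the real helper may additionally short-circuit on compute ops such as MatMul, Convolution, and GroupConvolution, a distinction these tests cannot observe because every such case above also contains a Parameter:

#include <memory>

#include <ngraph/node.hpp>
#include <ngraph/opsets/opset1.hpp>

using namespace ngraph;

// isConstantPathSketch is an illustrative name; the patch only adds tests for
// the real NetworkHelper::isConstantPath. Returns true iff every path upward
// from the node terminates in a Constant, i.e. no Parameter can reach it.
bool isConstantPathSketch(const std::shared_ptr<Node>& node) {
    if (is_type<opset1::Constant>(node)) {
        return true;
    }
    if (is_type<opset1::Parameter>(node)) {
        return false;
    }
    for (size_t i = 0; i < node->get_input_size(); ++i) {
        if (!isConstantPathSketch(node->get_input_node_shared_ptr(i))) {
            return false;
        }
    }
    return true;
}

Under this reading, the MatMul transformation's new guard on input 1 asks exactly the question these tests exercise: may the weights branch be treated as foldable constant data once its Q/DQ chain is decomposed.
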