[LPT] Add NormalizeDequantization function in NetworkHelper (#3458)

* [LPT] Add NormalizeDequantization function in NetworkHelper.

* [LPT] Handling subtract constant index in makeDequantization

* [LPT] Extend Add and Multiply transformations with normalizeDequantization.

* [LPT] Add/Subtract simplify normalizeDequantization call

* [LPT] normalizeDequantization: use replace_node instead of copy assignment

* [LPT] Update lpt paths

* [LPT] normalizeDequantization completion + refactoring

Co-authored-by: Aleksandr Pertovsky <aleksandr.pertovsky@intel.com>
This commit is contained in:
Edward Shogulin 2021-01-28 11:30:52 +03:00 committed by GitHub
parent 885a493336
commit 46f0775c09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 280 additions and 4 deletions

View File

@ -116,6 +116,8 @@ public:
static FakeQuantizeDequantization getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex = 0ul, const bool inPlace = false);
// Normalizes the operand order of a dequantization chain: when the Multiply and/or the Subtract
// has a Constant on input 0, that operation is replaced in the graph by a clone with its two
// inputs swapped. Returns the dequantization referring to the replacement nodes; an empty
// dequantization is returned unchanged.
static FakeQuantizeDequantization normalizeDequantization(FakeQuantizeDequantization dequantization);
static std::shared_ptr<Node> optimizeSubtract(std::shared_ptr<opset1::Subtract> add);
class InsertDequantizationResult {

View File

@ -94,6 +94,9 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
return false;
}
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, 1));
std::shared_ptr<Node> addNode = separateInStandaloneBranch(op);
std::shared_ptr<opset1::Add> add = as_type_ptr<opset1::Add>(addNode);

View File

@ -30,6 +30,9 @@ bool MultiplyTransformation::transform(TransformationContext& context, ngraph::p
return false;
}
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, 1));
multiply = separateInStandaloneBranch(multiply);
auto newMultiply = multiply;

View File

@ -807,6 +807,29 @@ FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_pt
return FakeQuantizeDequantization(dataNode, convert, subtract, multiply);
}
/**
 * @brief Moves the Constant operand of the dequantization Multiply / Subtract from input 0 to input 1.
 *
 * For each of the two operations that exists and has a Constant as its first input, a clone with
 * swapped inputs is created, the original node is replaced by the clone in the graph, and the
 * corresponding field of the dequantization is updated to point at the clone.
 *
 * @param dequantization Dequantization operations to normalize (taken by value, updated copy is returned).
 * @return The dequantization referring to the normalized nodes; returned unchanged when it is empty
 *         or when no operation has a Constant on input 0.
 */
FakeQuantizeDequantization NetworkHelper::normalizeDequantization(FakeQuantizeDequantization dequantization) {
    if (dequantization.empty()) {
        return dequantization;
    }

    if (dequantization.multiply != nullptr &&
        as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0))) {
        // Re-wire through Output<Node> ports rather than through the producing nodes: a node pointer
        // carries no output index, so a parent with several outputs would otherwise be reconnected
        // through its default output instead of the port it was actually consumed from.
        const Output<Node> constantInput = dequantization.multiply->input_value(0);
        const Output<Node> dataInput = dequantization.multiply->input_value(1);
        const std::shared_ptr<opset1::Multiply> normalizedMultiply = as_type_ptr<opset1::Multiply>(
            dequantization.multiply->clone_with_new_inputs({dataInput, constantInput}));
        replace_node(dequantization.multiply, normalizedMultiply);
        dequantization.multiply = normalizedMultiply;
    }

    if (dequantization.subtract != nullptr &&
        as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(0))) {
        // NOTE(review): Subtract is not commutative, so swapping the operands turns `const - data`
        // into `data - const`. The swap is kept exactly as before; confirm that this is intended.
        const Output<Node> constantInput = dequantization.subtract->input_value(0);
        const Output<Node> dataInput = dequantization.subtract->input_value(1);
        const std::shared_ptr<opset1::Subtract> normalizedSubtract = as_type_ptr<opset1::Subtract>(
            dequantization.subtract->clone_with_new_inputs({dataInput, constantInput}));
        replace_node(dequantization.subtract, normalizedSubtract);
        dequantization.subtract = normalizedSubtract;
    }

    return dequantization;
}
FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization) {
std::shared_ptr<Node> parent = dequantization.convert ? dequantization.convert : dequantization.data.get_node_shared_ptr();

View File

@ -0,0 +1,170 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "layer_transformation.hpp"
#include <string>
#include <memory>
#include <gtest/gtest.h>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <low_precision/network_helper.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "simple_low_precision_transformer.hpp"
#include "lpt_ngraph_functions/normalize_dequantization_function.hpp"
using namespace testing;
using namespace ngraph::pass;
// Parameters of one normalizeDequantization test case.
struct NormalizeDequantizationTestValues {
    // Description of the graph that is built and then passed through normalizeDequantization.
    struct Actual {
        // Element type given to the function builder for the graph input.
        ngraph::element::Type precisionBeforeDequantization;
        // Dequantization operations placed before the target operation.
        ngraph::builder::subgraph::DequantizationOperations dequantization;
    };

    // Description of the reference graph the normalized graph is compared with.
    struct Expected {
        // Element type given to the function builder for the graph input.
        ngraph::element::Type precisionBeforeDequantization;
        // Dequantization operations expected after normalization.
        ngraph::builder::subgraph::DequantizationOperations dequantization;
    };

    // Low precision transformation parameters of the case.
    ngraph::pass::low_precision::LayerTransformation::Params params;
    // Shape of the graph input.
    ngraph::Shape inputShape;
    Actual actual;
    Expected expected;
};
// Fixture for the normalizeDequantization tests: builds the actual graph, normalizes the
// dequantization in front of its target operation, and builds the reference graph.
class NormalizeDequantizationTransformation : public LayerTransformation, public testing::WithParamInterface<NormalizeDequantizationTestValues> {
public:
    void SetUp() override {
        using ngraph::builder::subgraph::NormalizeDequantizationFunction;

        const NormalizeDequantizationTestValues& testCase = GetParam();
        const auto& before = testCase.actual;
        const auto& after = testCase.expected;

        actualFunction = NormalizeDequantizationFunction::getOriginal(
            before.precisionBeforeDequantization,
            testCase.inputShape,
            before.dequantization);

        // The node feeding the first Result is the operation whose dequantization is normalized.
        const auto resultNode = actualFunction->get_output_op(0);
        const auto nodeUnderTest = resultNode->get_input_node_shared_ptr(0);
        const auto foundDequantization = low_precision::NetworkHelper::getDequantization(nodeUnderTest);
        // The returned value is not needed: the graph itself is modified in place.
        low_precision::NetworkHelper::normalizeDequantization(foundDequantization);

        // The reference graph is produced by the same builder from the expected description.
        referenceFunction = NormalizeDequantizationFunction::getOriginal(
            after.precisionBeforeDequantization,
            testCase.inputShape,
            after.dequantization);
    }

    // Composes the test case name from the input shape, the input precision and both
    // dequantization descriptions, each followed by an underscore.
    static std::string getTestCaseName(testing::TestParamInfo<NormalizeDequantizationTestValues> obj) {
        const NormalizeDequantizationTestValues& values = obj.param;

        std::ostringstream name;
        name << values.inputShape << "_";
        name << values.actual.precisionBeforeDequantization << "_";
        name << values.actual.dequantization << "_";
        name << values.expected.dequantization << "_";
        return name.str();
    }
};
// Re-validates the normalized graph and checks that it matches the reference graph.
TEST_P(NormalizeDequantizationTransformation, CompareFunctions) {
    actualFunction->validate_nodes_and_infer_types();

    const auto comparison = compare_functions(referenceFunction, actualFunction, true, false, true);
    const bool functionsMatch = comparison.first;
    // On mismatch the second element carries the message reported by the comparison.
    ASSERT_TRUE(functionsMatch) << comparison.second;
}
// Test cases. Each entry is: params, input shape, actual description, expected description.
// NOTE(review): the trailing integer of every subtract/multiply initializer is presumably the
// constant operand index (constantIndex) — confirm against DequantizationOperations.
const std::vector<NormalizeDequantizationTestValues> testValues = {
    // f32 input, no convert: both trailing indices 0 -> both 1
    {
        LayerTransformation::createParamsU8I8(),
        { 1, 3, 16, 16 },
        {
            ngraph::element::f32,
            {
                {},
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 }
            },
        },
        {
            ngraph::element::f32,
            {
                {},
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
            }
        },
    },
    // i32 input with convert to f32: only the multiply index differs (0 -> 1)
    {
        LayerTransformation::createParamsU8I8(),
        { 1, 3, 16, 16 },
        {
            ngraph::element::i32,
            {
                { ngraph::element::f32 },
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 }
            },
        },
        {
            ngraph::element::i32,
            {
                { ngraph::element::f32 },
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
            }
        },
    },
    // u32 input with convert to f32: only the subtract index differs (0 -> 1)
    {
        LayerTransformation::createParamsU8I8(),
        { 1, 3, 16, 16 },
        {
            ngraph::element::u32,
            {
                { ngraph::element::f32 },
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
            },
        },
        {
            ngraph::element::u32,
            {
                { {ngraph::element::f32} },
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
            }
        },
    },
    // u32 input with convert to f32, setUpdatePrecisions(true): actual and expected are identical
    {
        LayerTransformation::createParamsU8I8().setUpdatePrecisions(true),
        { 1, 3, 16, 16 },
        {
            ngraph::element::u32,
            {
                { ngraph::element::f32 },
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
            },
        },
        {
            ngraph::element::u32,
            {
                { ngraph::element::f32 },
                { {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
                { {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
            }
        },
    },
};
// Instantiates the fixture once per entry of testValues under the smoke_LPT prefix;
// test names are produced by the fixture's getTestCaseName.
INSTANTIATE_TEST_CASE_P(smoke_LPT, NormalizeDequantizationTransformation,
                        ::testing::ValuesIn(testValues),
                        NormalizeDequantizationTransformation::getTestCaseName);

View File

@ -97,9 +97,11 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
data.subtract.values << "_" <<
data.subtract.constantShape << "_" <<
data.subtract.outPrecision << "_" <<
data.subtract.constantIndex << "_" <<
data.multiply.values << "_" <<
data.multiply.constantShape << "_" <<
data.multiply.outPrecision;
data.multiply.outPrecision << "_" <<
data.multiply.constantIndex;
}
} // namespace subgraph

View File

@ -0,0 +1,25 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/ngraph.hpp>
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
namespace ngraph {
namespace builder {
namespace subgraph {
/// Builder of the ngraph function used by the normalizeDequantization tests.
class NormalizeDequantizationFunction {
public:
    /// @brief Builds the test function.
    /// @param precision       Element type of the function input.
    /// @param inputShape      Shape of the function input.
    /// @param dequantization  Dequantization operations to insert after the input.
    /// @return The constructed function.
    static std::shared_ptr<ngraph::Function> getOriginal(
        const ngraph::element::Type precision,
        const ngraph::Shape& inputShape,
        const DequantizationOperations dequantization);
};
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -51,17 +51,20 @@ std::shared_ptr<Node> makeDequantization(
shape,
dequantizationOperations.subtract.values);
Output<Node> leftBranchParent = dequantizationOperations.subtract.constantIndex == 1 ? parent : subtractConst;
Output<Node> rightBranchParent = dequantizationOperations.subtract.constantIndex == 1 ? subtractConst : parent;
if (((dequantizationOperations.subtract.outPrecision == element::undefined) ||
(dequantizationOperations.subtract.outPrecision == parent.get_element_type())) &&
((dequantizationOperations.subtract.constantPrecision == element::undefined) ||
(dequantizationOperations.subtract.constantPrecision == parent.get_element_type()))) {
subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(parent, subtractConst);
subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(leftBranchParent, rightBranchParent);
} else {
subtract = std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationSubtract>>(
std::vector<element::Type>{element::f32, element::f32},
std::vector<element::Type>{ element::f32 },
ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(subtractConst, element::f32).get());
ngraph::op::TemporaryReplaceOutputType(leftBranchParent, element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(rightBranchParent, element::f32).get());
ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision);
}
if (!dequantizationOperations.subtract.addDequantizationAttribute) {

View File

@ -0,0 +1,45 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "lpt_ngraph_functions/normalize_dequantization_function.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
#include "lpt_ngraph_functions/common/builders.hpp"
#include "ngraph_ops/type_relaxed.hpp"
namespace ngraph {
namespace builder {
namespace subgraph {
// Builds the function Parameter -> dequantization operations -> type-relaxed MaxPool -> Result.
// The MaxPool is tagged as "targetOp" in its runtime info.
std::shared_ptr<ngraph::Function> NormalizeDequantizationFunction::getOriginal(
    const ngraph::element::Type precision,
    const ngraph::Shape& inputShape,
    const ngraph::builder::subgraph::DequantizationOperations dequantization) {
    const auto parameter = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
    const auto dequantizationOutput = makeDequantization(parameter, dequantization);

    // Prototype MaxPool; only used as the base operation for the type-relaxed copy below.
    const Strides strides{ 1, 1 };
    const Shape padsBegin{ 1, 1 };
    const Shape padsEnd{ 0, 0 };
    const Shape kernel{ 2, 2 };
    const auto maxPoolPrototype = ngraph::opset1::MaxPool(
        dequantizationOutput,
        strides,
        padsBegin,
        padsEnd,
        kernel,
        op::RoundingType::FLOOR);

    const std::vector<element::Type> relaxedInputTypes{ element::f32, element::f32 };
    const std::vector<element::Type> relaxedOutputTypes{};
    const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
        maxPoolPrototype,
        relaxedInputTypes,
        relaxedOutputTypes);

    // Tag the target operation in its runtime info.
    auto& runtimeInfo = targetOp->get_rt_info();
    runtimeInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");

    const auto result = std::make_shared<ngraph::opset1::Result>(targetOp);
    return std::make_shared<ngraph::Function>(
        ngraph::ResultVector{ result },
        ngraph::ParameterVector{ parameter },
        "NormalizeDequantizationFunction");
}
} // namespace subgraph
} // namespace builder
} // namespace ngraph