[LPT] Add NormalizeDequantization function in NetworkHelper (#3458)
* [LPT] Add NormalizeDequantization function in NetworkHelper. * [LPT] Handling subtract constant index in makeDequantization * [LPT] Extend Add and Multiply transformations with normalizeDequantization. * [LPT] Add/Subtract simplify normalizeDequantization call * [LPT] normalizeDequantization: usage replace_node instead of copy assignment * [LPT] Update lpt paths * [LPT] normalizeDequantization completion + refactoring Co-authored-by: Aleksandr Pertovsky <aleksandr.pertovsky@intel.com>
This commit is contained in:
parent
885a493336
commit
46f0775c09
@ -116,6 +116,8 @@ public:
|
||||
|
||||
static FakeQuantizeDequantization getDequantization(const std::shared_ptr<Node> node, const size_t parentIndex = 0ul, const bool inPlace = false);
|
||||
|
||||
static FakeQuantizeDequantization normalizeDequantization(FakeQuantizeDequantization dequantization);
|
||||
|
||||
static std::shared_ptr<Node> optimizeSubtract(std::shared_ptr<opset1::Subtract> add);
|
||||
|
||||
class InsertDequantizationResult {
|
||||
|
@ -94,6 +94,9 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
|
||||
return false;
|
||||
}
|
||||
|
||||
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, 0));
|
||||
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, 1));
|
||||
|
||||
std::shared_ptr<Node> addNode = separateInStandaloneBranch(op);
|
||||
std::shared_ptr<opset1::Add> add = as_type_ptr<opset1::Add>(addNode);
|
||||
|
||||
|
@ -30,6 +30,9 @@ bool MultiplyTransformation::transform(TransformationContext& context, ngraph::p
|
||||
return false;
|
||||
}
|
||||
|
||||
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, 0));
|
||||
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, 1));
|
||||
|
||||
multiply = separateInStandaloneBranch(multiply);
|
||||
auto newMultiply = multiply;
|
||||
|
||||
|
@ -807,6 +807,29 @@ FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_pt
|
||||
return FakeQuantizeDequantization(dataNode, convert, subtract, multiply);
|
||||
}
|
||||
|
||||
/// Brings a dequantization chain to a canonical operand order: when the Multiply and/or the
/// Subtract operation has a Constant on input 0, the operation is cloned with its two inputs
/// swapped (constant moves to input 1) and the clone replaces the original node in the graph.
///
/// @param dequantization dequantization operations to normalize; taken by value, the local copy
///        is updated to reference the newly created nodes.
/// @return the dequantization referencing the normalized nodes. It is returned untouched when
///         dequantization.empty() is true or when no operation has a Constant on input 0.
FakeQuantizeDequantization NetworkHelper::normalizeDequantization(FakeQuantizeDequantization dequantization) {
    if (dequantization.empty()) {
        return dequantization;
    }

    if (dequantization.multiply != nullptr && as_type_ptr<ngraph::opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0))) {
        // Swap the connected outputs (not bare node pointers) so that the output index of the
        // producer is preserved when the data operand comes from a multi-output node.
        const auto constantInput = dequantization.multiply->input_value(0);
        const auto dataInput = dequantization.multiply->input_value(1);
        std::shared_ptr<opset1::Multiply> normalized_multiply = as_type_ptr<opset1::Multiply>(
            dequantization.multiply->clone_with_new_inputs({dataInput, constantInput}));
        replace_node(dequantization.multiply, normalized_multiply);
        dequantization.multiply = normalized_multiply;
    }

    if (dequantization.subtract != nullptr && as_type_ptr<ngraph::opset1::Constant>(dequantization.subtract->get_input_node_shared_ptr(0))) {
        // NOTE(review): subtraction is not commutative, so swapping operands turns (const - data)
        // into (data - const). Presumably the constant-first form is only a representation
        // artifact of the dequantization pattern - confirm with the LPT owners.
        const auto constantInput = dequantization.subtract->input_value(0);
        const auto dataInput = dequantization.subtract->input_value(1);
        std::shared_ptr<opset1::Subtract> normalized_subtract = as_type_ptr<opset1::Subtract>(
            dequantization.subtract->clone_with_new_inputs({dataInput, constantInput}));
        replace_node(dequantization.subtract, normalized_subtract);
        dequantization.subtract = normalized_subtract;
    }

    return dequantization;
}
|
||||
|
||||
FakeQuantizeDequantizationValues NetworkHelper::createEmptyValues(const FakeQuantizeDequantization& dequantization) {
|
||||
std::shared_ptr<Node> parent = dequantization.convert ? dequantization.convert : dequantization.data.get_node_shared_ptr();
|
||||
|
||||
|
@ -0,0 +1,170 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "layer_transformation.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <low_precision/network_helper.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
#include "lpt_ngraph_functions/normalize_dequantization_function.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph::pass;
|
||||
|
||||
class NormalizeDequantizationTestValues {
|
||||
public:
|
||||
class Actual {
|
||||
public:
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
};
|
||||
|
||||
class Expected {
|
||||
public:
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
};
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
ngraph::Shape inputShape;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
};
|
||||
|
||||
// Value-parameterized fixture: builds the graph described by `actual`, runs
// NetworkHelper::normalizeDequantization on the dequantization feeding the operation in front of
// the first function output, and builds the reference graph described by `expected`.
class NormalizeDequantizationTransformation : public LayerTransformation, public testing::WithParamInterface<NormalizeDequantizationTestValues> {
public:
    void SetUp() override {
        const NormalizeDequantizationTestValues caseValues = GetParam();
        const auto& source = caseValues.actual;
        const auto& reference = caseValues.expected;

        actualFunction = ngraph::builder::subgraph::NormalizeDequantizationFunction::getOriginal(
            source.precisionBeforeDequantization,
            caseValues.inputShape,
            source.dequantization);

        // The node under test is the producer of the first function output.
        const auto resultNode = actualFunction->get_output_op(0);
        const auto targetNode = resultNode->get_input_node_shared_ptr(0);
        low_precision::NetworkHelper::normalizeDequantization(
            low_precision::NetworkHelper::getDequantization(targetNode));

        referenceFunction = ngraph::builder::subgraph::NormalizeDequantizationFunction::getOriginal(
            reference.precisionBeforeDequantization,
            caseValues.inputShape,
            reference.dequantization);
    }

    // Composes the test name from the input shape, the precision before dequantization and both
    // dequantization descriptions, each followed by an underscore.
    static std::string getTestCaseName(testing::TestParamInfo<NormalizeDequantizationTestValues> obj) {
        const NormalizeDequantizationTestValues caseValues = obj.param;

        std::ostringstream name;
        name << caseValues.inputShape << "_";
        name << caseValues.actual.precisionBeforeDequantization << "_";
        name << caseValues.actual.dequantization << "_";
        name << caseValues.expected.dequantization << "_";
        return name.str();
    }
};
|
||||
|
||||
// Checks that the normalized graph built in SetUp matches the reference graph.
TEST_P(NormalizeDequantizationTransformation, CompareFunctions) {
    actualFunction->validate_nodes_and_infer_types();

    // compare_functions yields a (matched, message) pair; the message is reported on failure.
    const auto comparison = compare_functions(referenceFunction, actualFunction, true, false, true);
    const bool functionsMatch = comparison.first;
    ASSERT_TRUE(functionsMatch) << comparison.second;
}
|
||||
|
||||
const std::vector<NormalizeDequantizationTestValues> testValues = {
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ 1, 3, 16, 16 },
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{
|
||||
{},
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 },
|
||||
{ {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 }
|
||||
},
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{
|
||||
{},
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1},
|
||||
{{10.0f}, ngraph::element::f32, {1, 3, 16, 16}, true, 1 }
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ 1, 3, 16, 16 },
|
||||
{
|
||||
ngraph::element::i32,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
|
||||
{ {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 }
|
||||
},
|
||||
},
|
||||
{
|
||||
ngraph::element::i32,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
|
||||
{{10.0f}, ngraph::element::f32, {1, 3, 16, 16}, true, 1 }
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ 1, 3, 16, 16 },
|
||||
{
|
||||
ngraph::element::u32,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 0 },
|
||||
{ {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
|
||||
},
|
||||
},
|
||||
{
|
||||
ngraph::element::u32,
|
||||
{
|
||||
{ {ngraph::element::f32} },
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
|
||||
{{10.0f}, ngraph::element::f32, {1, 3, 16, 16}, true, 1 }
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
LayerTransformation::createParamsU8I8().setUpdatePrecisions(true),
|
||||
{ 1, 3, 16, 16 },
|
||||
{
|
||||
ngraph::element::u32,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
|
||||
{ {10.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 }
|
||||
},
|
||||
},
|
||||
{
|
||||
ngraph::element::u32,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{ {7.f}, ngraph::element::f32, { 1, 3, 16, 16 }, true, 1 },
|
||||
{{10.0f}, ngraph::element::f32, {1, 3, 16, 16}, true, 1 }
|
||||
}
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Instantiates the fixture once per entry of testValues under the smoke_LPT prefix; test names
// come from NormalizeDequantizationTransformation::getTestCaseName.
INSTANTIATE_TEST_CASE_P(
    smoke_LPT, NormalizeDequantizationTransformation,
    ::testing::ValuesIn(testValues),
    NormalizeDequantizationTransformation::getTestCaseName);
|
@ -97,9 +97,11 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
|
||||
data.subtract.values << "_" <<
|
||||
data.subtract.constantShape << "_" <<
|
||||
data.subtract.outPrecision << "_" <<
|
||||
data.subtract.constantIndex << "_" <<
|
||||
data.multiply.values << "_" <<
|
||||
data.multiply.constantShape << "_" <<
|
||||
data.multiply.outPrecision;
|
||||
data.multiply.outPrecision << "_" <<
|
||||
data.multiply.constantIndex;
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
|
@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
class NormalizeDequantizationFunction {
|
||||
public:
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations dequantization);
|
||||
};
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -51,17 +51,20 @@ std::shared_ptr<Node> makeDequantization(
|
||||
shape,
|
||||
dequantizationOperations.subtract.values);
|
||||
|
||||
Output<Node> leftBranchParent = dequantizationOperations.subtract.constantIndex == 1 ? parent : subtractConst;
|
||||
Output<Node> rightBranchParent = dequantizationOperations.subtract.constantIndex == 1 ? subtractConst : parent;
|
||||
|
||||
if (((dequantizationOperations.subtract.outPrecision == element::undefined) ||
|
||||
(dequantizationOperations.subtract.outPrecision == parent.get_element_type())) &&
|
||||
((dequantizationOperations.subtract.constantPrecision == element::undefined) ||
|
||||
(dequantizationOperations.subtract.constantPrecision == parent.get_element_type()))) {
|
||||
subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(parent, subtractConst);
|
||||
subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(leftBranchParent, rightBranchParent);
|
||||
} else {
|
||||
subtract = std::make_shared<op::TypeRelaxed<ngraph::pass::low_precision::DequantizationSubtract>>(
|
||||
std::vector<element::Type>{element::f32, element::f32},
|
||||
std::vector<element::Type>{ element::f32 },
|
||||
ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(subtractConst, element::f32).get());
|
||||
ngraph::op::TemporaryReplaceOutputType(leftBranchParent, element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(rightBranchParent, element::f32).get());
|
||||
ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision);
|
||||
}
|
||||
if (!dequantizationOperations.subtract.addDequantizationAttribute) {
|
||||
|
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "lpt_ngraph_functions/normalize_dequantization_function.hpp"
|
||||
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
#include "ngraph_ops/type_relaxed.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
std::shared_ptr<ngraph::Function> NormalizeDequantizationFunction::getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations dequantization) {
|
||||
const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
|
||||
|
||||
const auto deq = makeDequantization(input, dequantization);
|
||||
|
||||
const auto op = ngraph::opset1::MaxPool(
|
||||
deq,
|
||||
Strides{ 1, 1 },
|
||||
Shape{ 1, 1 },
|
||||
Shape{ 0, 0 },
|
||||
Shape{ 2, 2 },
|
||||
op::RoundingType::FLOOR);
|
||||
const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
|
||||
op,
|
||||
std::vector<element::Type>{ element::f32, element::f32 },
|
||||
std::vector<element::Type>{});
|
||||
auto& rtInfo = targetOp->get_rt_info();
|
||||
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");
|
||||
|
||||
return std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ std::make_shared<ngraph::opset1::Result>(targetOp) },
|
||||
ngraph::ParameterVector{ input },
|
||||
"NormalizeDequantizationFunction");
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
Loading…
Reference in New Issue
Block a user