[LPT] integration: issue #42391 & issue #43001 (cherry-pick to master) (#3255)

* [LPT] NetworkHelper::roundWithTolerance: removed tolerance & renamed to round
[LPT] NetworkHelper::round functional tests
[LPT] ieFuncTests: updated some test-cases

* [LPT] Subtract is not used

* [LPT] AddTransformation: zero handling

* [LPT] AddTransformation test
Edward Shogulin 2020-11-23 17:19:06 +03:00 committed by GitHub
parent d02223c796
commit b5d7f236f4
27 changed files with 343 additions and 114 deletions
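For context on the main behavioral change in the diff below: NetworkHelper::roundWithTolerance accepted a type-cast zero point only when every casted value stayed within a tolerance of the original (retrying with +0.5, -0.5 and +1.0 offsets) and otherwise returned the untouched f32 constant, while the new NetworkHelper::round always folds ngraph's v5 Round op with HALF_AWAY_FROM_ZERO mode before the Convert. A minimal standalone sketch of those rounding semantics, outside of ngraph (roundHalfAwayFromZero is an illustrative name, not OpenVINO API; std::round already rounds halfway cases away from zero):

// Minimal sketch: emulates the rounding NetworkHelper::round now applies to a
// dequantization constant before converting it to the target integer type.
#include <cmath>
#include <cstdio>
#include <vector>

float roundHalfAwayFromZero(float v) {
    return std::round(v);  // std::round rounds halfway cases away from zero
}

int main() {
    // Values mirror the new round_transformation test cases below:
    // 125.5 -> 126, 64.5 -> 65, 126.6 -> 127, -127.5 -> -128.
    const std::vector<float> shifts = { 125.5f, 64.5f, 126.6f, -127.5f };
    for (const float v : shifts) {
        std::printf("%7.2f -> %5.0f\n", v, roundHalfAwayFromZero(v));
    }
    return 0;
}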


@@ -17,6 +17,7 @@ public:
     ~AddTransformation() override {}
     void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
     bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
+    bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
 };
 } // namespace low_precision


@@ -26,6 +26,7 @@ public:
         std::shared_ptr<ngraph::opset1::Multiply> multiply);
     bool empty() const;
+    bool multiplyHasZero() const;
     bool isShared() const;
     bool isLowPrecision() const;
     static bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise);


@@ -81,7 +81,7 @@ public:
     // Optimizes the series of multiplies after a given output port
     static std::shared_ptr<ngraph::opset1::Multiply> optimizeMultipliesAfter(std::shared_ptr<Node> multiply);
-    static std::shared_ptr<opset1::Constant> roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance = 0.1);
+    static std::shared_ptr<opset1::Constant> round(std::shared_ptr<Node> node, element::Type target_type);
     static std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
         std::shared_ptr<opset1::FakeQuantize> fq,


@@ -199,6 +199,20 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
     return true;
 }
 
+bool AddTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
+    const FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, 0ul);
+    if (dequantization1.multiplyHasZero()) {
+        return false;
+    }
+
+    const FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, 1ul);
+    if (dequantization2.multiplyHasZero()) {
+        return false;
+    }
+
+    return EltwiseBaseTransformation::canBeTransformed(context, layer);
+}
+
 } // namespace low_precision
 } // namespace pass
 } // namespace ngraph

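The guard above matters because AddTransformation folds one branch's dequantization into the other by dividing through that branch's per-channel scale, roughly (a * scaleA) + (b * scaleB) == scaleA * (a + b * scaleB / scaleA); a zero in the Multiply constant would turn that fold into a division by zero. A reduced sketch of the hazard on plain floats (foldShift is a hypothetical helper for illustration, not the actual LPT code):

// Reduced sketch of the per-channel fold; a zero in scaleA makes
// scaleB / scaleA undefined, which is exactly what the new
// FakeQuantizeDequantization::multiplyHasZero() check guards against.
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<float> foldShift(const std::vector<float>& scaleB, const std::vector<float>& scaleA) {
    std::vector<float> folded(scaleB.size());
    for (std::size_t i = 0; i < scaleB.size(); ++i) {
        assert(scaleA[i] != 0.f && "rejected by AddTransformation::canBeTransformed");
        folded[i] = scaleB[i] / scaleA[i];
    }
    return folded;
}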

@@ -42,7 +42,8 @@ bool ClampTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(clamp);
     const bool moveSubtract = subWithTheSameValues(dequantization.subtract);
-    if (!moveSubtract && !canSubtractBeHandled(clamp, dequantization)) {
+    // issue #43136
+    if (!moveSubtract && (dequantization.subtract != nullptr)) {
         return false;
     }
     const auto newClamp = as_type_ptr<opset1::Clamp>(moveDequantizationAfter(context, clamp, dequantization, false, moveSubtract));


@@ -30,6 +30,23 @@ bool FakeQuantizeDequantization::empty() const {
     return (convert == nullptr) && (subtract == nullptr) && (multiply == nullptr);
 }
 
+bool FakeQuantizeDequantization::multiplyHasZero() const {
+    if (multiply == nullptr) {
+        return false;
+    }
+
+    std::shared_ptr<opset1::Constant> multiplyConstant = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(1));
+    if (multiplyConstant == nullptr) {
+        multiplyConstant = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(0));
+    }
+    if (multiplyConstant == nullptr) {
+        return false;
+    }
+
+    auto const values = multiplyConstant->cast_vector<float>();
+    return std::any_of(values.begin(), values.end(), [](const float value) { return value == 0.f; });
+}
+
 bool FakeQuantizeDequantization::isShared() const {
     if ((convert != nullptr) && (convert->get_output_target_inputs(0).size() > 1ul)) {
         return true;


@@ -33,6 +33,7 @@ bool GroupConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) const noexcept {
 bool GroupConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
     auto convolution = m.get_match_root();
+
     if (!GroupConvolutionTransformation::canBeTransformed(context, convolution)) {
         return false;
     }


@@ -138,9 +138,7 @@ bool LayerTransformation::canSubtractBeHandled(const std::shared_ptr<Node>& op, const FakeQuantizeDequantization& dequantization) const {
         return false;
     }
 
-    std::shared_ptr<Node> zeroPoint = dequantization.subtract->input_value(1).get_node_shared_ptr();
-    auto convertedZeroPoint = NetworkHelper::roundWithTolerance(zeroPoint, operationType);
-    return convertedZeroPoint->output(0).get_element_type() == operationType;
+    return true;
 }
 
 #ifdef LPT_PRINT_DEQUANTIZATION_INFO

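With the tolerance check gone, canSubtractBeHandled degenerates to the trivial checks before this point, and the per-operation transformations in this commit (Clamp above; MVN, NormalizeL2, PRelu and Relu below) instead refuse outright when the dequantization still carries a Subtract. A condensed sketch of that new guard pattern (DequantizationSketch is a stand-in for FakeQuantizeDequantization, not the real class):

// Condensed guard pattern adopted by the canBeTransformed overrides in this
// commit; asymmetric quantization (a Subtract zero point) is rejected for now,
// with follow-ups tracked in issues #43088 and #43136.
#include <memory>

struct DequantizationSketch {
    std::shared_ptr<void> convert;
    std::shared_ptr<void> subtract;
    std::shared_ptr<void> multiply;
    bool empty() const { return !convert && !subtract && !multiply; }
};

bool canBeTransformedSketch(const DequantizationSketch& dequantization) {
    return !dequantization.empty() && (dequantization.subtract == nullptr);
}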

@@ -41,7 +41,7 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
         return false;
     }
 
-    if (!canSubtractBeHandled(operation)) {
+    if (NetworkHelper::getDequantization(operation).subtract != nullptr) {
         return false;
     }


@@ -321,52 +321,15 @@ std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter(std::shared_ptr<Node> multiply) {
     return nullptr;
 }
 
-std::shared_ptr<opset1::Constant> NetworkHelper::roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance) {
-    auto constant = as_type_ptr<opset1::Constant>(node);
+std::shared_ptr<opset1::Constant> NetworkHelper::round(std::shared_ptr<Node> node, element::Type target_type) {
+    const auto constant = as_type_ptr<opset1::Constant>(node);
     assert(constant);
-    auto values = constant->cast_vector<float>();
 
-    auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(constant, target_type));
-    auto castedValues = castedConstant->cast_vector<float>();
+    const auto castedConstant = as_type_ptr<ngraph::opset1::Constant>(fold<op::v0::Convert>(
+        fold<ngraph::op::v5::Round>(constant->output(0), ngraph::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO),
+        target_type));
 
-    // TODO: implement with constant folding when ReduceAnd constant folding is ready
-    if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
-        return castedConstant;
-    }
-
-    auto round = [](
-        const std::shared_ptr<opset1::Constant>& constant,
-        element::Type target_type,
-        float tolerance,
-        std::vector<float>& values,
-        float increaseValue) -> std::shared_ptr<opset1::Constant> {
-        const auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(
-            fold<opset1::Add>(constant, std::make_shared<opset1::Constant>(constant->get_output_element_type(0), Shape{ 1 }, increaseValue)),
-            target_type));
-        const auto castedValues = castedConstant->cast_vector<float>();
-        if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
-            return castedConstant;
-        }
-        return nullptr;
-    };
-
-    castedConstant = round(constant, target_type, tolerance, values, 0.5f);
-    if (castedConstant != nullptr) {
-        return castedConstant;
-    }
-
-    castedConstant = round(constant, target_type, tolerance, values, -0.5f);
-    if (castedConstant != nullptr) {
-        return castedConstant;
-    }
-
-    castedConstant = round(constant, target_type, tolerance, values, 1.f);
-    if (castedConstant != nullptr) {
-        return castedConstant;
-    }
-
-    return constant;
+    return castedConstant;
 }
 
 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq) {
@@ -889,16 +852,13 @@ std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Subtract> subtract) {
     auto data = convertOnSubtract->input_value(0);
     auto shift = subtract->input_value(1).get_node_shared_ptr();
-    auto roundedShift = NetworkHelper::roundWithTolerance(shift, convertInputType);
+    auto roundedShift = NetworkHelper::round(shift, convertInputType);
 
-    std::shared_ptr<Node> replacement;
-    if (roundedShift->get_element_type() == convertInputType) {
-        // Propagate convertInputType down
-        replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
-        NetworkHelper::copyInfo(subtract, replacement);
-        NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
-        replace_node(subtract, replacement);
-    }
+    // Propagate convertInputType down
+    const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
+    NetworkHelper::copyInfo(subtract, replacement);
+    NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
+    replace_node(subtract, replacement);
 
     // We lose the tail conversion here; not needed if the next node is a TypeRelaxed
     // TODO: check cases when Convert should be preserved
@@ -992,7 +952,8 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
     if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
         NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
-        optimizeSubtract(dequantization.subtract);
+        // issue #43088
+        // NetworkHelper::optimizeElementwise(dequantization.subtract);
     }
 
     return InsertDequantizationResult(newOperation, parent);

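One consequence worth spelling out: converting f32 to an integer type truncates toward zero (assuming ngraph's Convert behaves like a C-style cast), so folding Round first is what makes a shift like 126.6f land on 127 rather than 126, matching the new tests below; it also means round always returns a constant already in the target type, which is why optimizeSubtract could drop the element-type check around the TypeRelaxed replacement. A standalone illustration under that cast-semantics assumption:

// Standalone illustration: truncation vs round-then-convert for a zero point.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const float shift = 126.6f;
    const auto truncated = static_cast<std::int8_t>(shift);            // 126
    const auto rounded = static_cast<std::int8_t>(std::round(shift));  // 127
    std::printf("truncate: %d, round then convert: %d\n", truncated, rounded);
    return 0;
}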

@@ -40,7 +40,7 @@ bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
         return false;
     }
 
-    if (!canSubtractBeHandled(operation)) {
+    if (NetworkHelper::getDequantization(operation).subtract != nullptr) {
         return false;
     }


@@ -40,7 +40,7 @@ bool PReluTransformation::isPrecisionPreserved(std::shared_ptr<Node> op) const noexcept {
 bool PReluTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
-    if (dequantization.empty()) {
+    if (dequantization.empty() || (dequantization.subtract != nullptr)) {
         return false;
     }


@@ -48,11 +48,7 @@ bool ReluTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
     }
 
     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
-    if (dequantization.empty()) {
-        return false;
-    }
-
-    if (!canSubtractBeHandled(op, dequantization)) {
+    if (dequantization.empty() || (dequantization.subtract != nullptr)) {
         return false;
     }


@@ -72,12 +72,13 @@ bool SubtractTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
     }
 
     if (dequantization.convert != nullptr) {
-        std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeSubtract(subtract);
-        newSubtract->set_output_type(0, originalPrecision, newSubtract->get_output_partial_shape(0));
+        // issue #43088
+        // std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeElementwise(subtract);
+        subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
 
-        replace_node(newSubtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
-            newSubtract->get_input_node_shared_ptr(0),
-            newSubtract->get_input_node_shared_ptr(1)));
+        replace_node(subtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
+            subtract->get_input_node_shared_ptr(0),
+            subtract->get_input_node_shared_ptr(1)));
     }
 
     return true;
 }


@@ -147,6 +147,54 @@ TEST_P(AddTransformation, CompareFunctions) {
 }
 
 const std::vector<AddTransformationTestValues> addTransformationTestValues = {
+    // Multiply with zero on the first branch
+    {
+        ngraph::element::f32,
+        ngraph::Shape{1, 4, 16, 16},
+        false,
+        -1,
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::f32,
+            { },
+            ngraph::element::u8,
+            { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            { }
+        },
+        {
+            ngraph::element::f32,
+            { },
+            ngraph::element::u8,
+            { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            { },
+            { }
+        },
+        ""
+    },
+    // Multiply with zero on the second branch
+    {
+        ngraph::element::f32,
+        ngraph::Shape{1, 4, 16, 16},
+        false,
+        -1,
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            ngraph::element::f32,
+            { },
+            { }
+        },
+        {
+            ngraph::element::u8,
+            { {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            ngraph::element::f32,
+            { },
+            { },
+            { }
+        },
+        ""
+    },
     // U8
     {
         ngraph::element::f32,


@@ -331,9 +331,13 @@ const std::vector<ClampTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{}, {{ 128.f, 0.f, 128.f }, ngraph::element::f32}, {}},
+            {
+                {ngraph::element::f32},
+                {{ 128.f, 0.f, 128.f }},
+                {{ 3.f, 3.f, 3.f }}
+            },
             ngraph::element::f32,
-            {{}, {}, {{3.f, 3.f, 3.f}}}
+            {{}, {}, {}}
         }
     },
     // U8 without asymmetric quantization


@@ -154,7 +154,7 @@ const std::vector<ConvolutionTransformationTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, { 128.f }, { 0.02f }},
+            {{}, { 128.f }, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },
@@ -214,7 +214,7 @@ const std::vector<ConvolutionTransformationTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, {}, { 0.02f }},
+            {{}, {}, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },


@@ -165,7 +165,7 @@ const std::vector<GroupConvolutionTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, { 128.f }, { 0.02f }},
+            {{}, { 128.f }, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },
@@ -329,7 +329,7 @@ const std::vector<GroupConvolutionTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, { 128.f }, { 0.02f }},
+            {{}, { 128.f }, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },


@@ -218,12 +218,12 @@ std::vector<MatMullTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            { ngraph::element::f32, { 127.5f }, { 0.02f } },
+            { {}, {{128.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
             ngraph::element::i8,
-            { ngraph::element::f32, {}, { 0.03f } },
+            { },
             ngraph::element::f32,
             ngraph::element::f32,
-            {},
+            { {}, {}, { 0.0006f } },
         }
     },
     // U8 + FP32


@@ -129,7 +129,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32}, { 7.f }, { 10.f } },
         },
         {
-            { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {}, {}, { 10.f } },
         },
@@ -159,7 +159,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32}, { 7.f }, { 10.f } },
         },
         {
-            { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {}, {}, { 10.f } },
         },
@@ -189,7 +189,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32}, { 7.f }, { 10.f } },
         },
         {
-            { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {}, {}, { 10.f } },
         },
@@ -219,7 +219,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32}, { 7.f }, { 10.f } },
         },
         {
-            { {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {}, {}, { 10.f } },
         },
@@ -234,12 +234,12 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32}, { { 7.f, 7.f, 7.f } }, { { 10.f, 10.f, 10.f } } },
         },
         {
-            { {}, { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
+            { {ngraph::element::f32}, { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
             ngraph::element::f32,
             { {}, {}, { { 10.f, 10.f, 10.f } } },
         },
     },
-    // per-channel quantizations with the same values
+    // per-channel quantizations with different values
     {
         ngraph::element::u8,
         LayerTransformation::createParamsU8I8(),
@@ -249,7 +249,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32}, { { 7.f, 8.f, 9.f } }, { { 10.f, 12.f, 16.f } } },
         },
         {
-            { {}, { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
+            { {ngraph::element::f32}, { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
             ngraph::element::f32,
             { {}, {}, { { 10.f, 12.f, 16.f } } },
         },


@@ -91,6 +91,7 @@ public:
         std::ostringstream result;
         result <<
             toString(testValues.params) << "_" <<
+            testValues.inputShape << "_" <<
             testValues.reductionAxes << "_" <<
             testValues.normalizeVariance << "_" <<
@@ -145,9 +146,9 @@ const std::vector<MVNTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{ngraph::element::f32}, {127.f}, {}},
+            {{ngraph::element::f32}, {127.f}, {0.45f}},
             ngraph::element::f32,
-            {{}, {}, {1.f}}
+            {{}, {}, {}}
         }
     },
     {
@@ -163,7 +164,7 @@ const std::vector<MVNTransformationTestValues> testValues = {
             ngraph::element::u8,
             {{ngraph::element::f32}, {12.5f}, {0.45f}},
             ngraph::element::f32,
-            {}
+            {{}, {}, {}}
         }
     },
     {


@@ -53,7 +53,7 @@ public:
             low_precision::LayerTransformation::Params(params.transformationParams));
         transform.transform(actualFunction);
 
-        referenceFunction = (!params.transformationParams.supportAsymmetricQuantization) && (!params.expected.subtractValues.empty()) ?
+        referenceFunction = !params.expected.subtractValues.empty() ?
             ngraph::builder::subgraph::NormalizeL2Function::getOriginal(
                 precision,
                 shape,


@@ -137,9 +137,9 @@ const std::vector<PReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{}, { {128}, ngraph::element::f32 }, {}},
+            {{ngraph::element::f32}, { 128 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
     // I8: with positive subtract value
@@ -152,24 +152,9 @@ const std::vector<PReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::i8,
-            {{}, { {127}, ngraph::element::f32 }, {}},
+            {{ngraph::element::f32}, { 127 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
-        }
-    },
-    // U8: with negative subtract value: Convert is still here
-    {
-        ngraph::Shape({ 1, 3, 16, 16 }),
-        LayerTransformation::createParamsU8I8(),
-        {
-            ngraph::element::u8,
-            {{ngraph::element::f32}, { -128 }, {0.1f}}
-        },
-        {
-            ngraph::element::u8,
-            {{ngraph::element::f32}, { {-128}, ngraph::element::f32 }, {}},
-            ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
 };


@@ -73,6 +73,7 @@ public:
         std::ostringstream result;
         result <<
             toString(testValues.params) << "_" <<
+            testValues.shape << "_" <<
             testValues.actual.precisionBeforeDequantization << "_" <<
             testValues.actual.dequantization << "_" <<
@@ -166,9 +167,9 @@ const std::vector<ReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{}, { {128}, ngraph::element::f32, {}, false }, {}},
+            {{ngraph::element::f32}, { 128 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
     // I8: with subtract value
@@ -181,9 +182,9 @@ const std::vector<ReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::i8,
-            {{}, { {127}, ngraph::element::f32, {}, false }, {}},
+            {{ngraph::element::f32}, { 127 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
     // I8: with subtract value


@@ -0,0 +1,111 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "layer_transformation.hpp"
+
+#include <string>
+#include <sstream>
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+#include "ngraph_functions/low_precision_transformations/round_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
+#include "low_precision/network_helper.hpp"
+
+namespace {
+using namespace testing;
+using namespace ngraph;
+using namespace ngraph::pass;
+
+class RoundTestValues {
+public:
+    ngraph::element::Type inputPrecision;
+    ngraph::Shape inputShape;
+    ngraph::builder::subgraph::DequantizationOperations actualDequantization;
+    ngraph::builder::subgraph::DequantizationOperations referenceDequantization;
+};
+
+class RoundTransformation : public LayerTransformation, public testing::WithParamInterface<RoundTestValues> {
+public:
+    void SetUp() override {
+        const auto testValues = this->GetParam();
+
+        actualFunction = ngraph::builder::subgraph::RoundWithToleranceFunction::getOriginal(
+            testValues.inputPrecision,
+            testValues.inputShape,
+            testValues.actualDequantization);
+
+        const auto lastNode = actualFunction->get_output_op(0)->get_input_node_shared_ptr(0);
+        const auto dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(lastNode);
+        const auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
+
+        const auto roundedConst = ngraph::pass::low_precision::NetworkHelper::round(
+            subtractConstant,
+            testValues.inputPrecision);
+
+        if (roundedConst->get_element_type() == testValues.inputPrecision) {
+            const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(dequantization.data, roundedConst);
+            ngraph::pass::low_precision::NetworkHelper::copyInfo(dequantization.subtract, replacement);
+            ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, dequantization.convert->get_element_type());
+            replace_node(dequantization.subtract, replacement);
+        }
+
+        referenceFunction = ngraph::builder::subgraph::RoundWithToleranceFunction::getReference(
+            testValues.inputPrecision,
+            testValues.inputShape,
+            testValues.referenceDequantization);
+    }
+
+    static std::string getTestCaseName(testing::TestParamInfo<RoundTestValues> obj) {
+        const auto testValues = obj.param;
+
+        std::ostringstream result;
+        result << testValues.inputPrecision << "_"
+            << testValues.actualDequantization << "_"
+            << testValues.referenceDequantization;
+        return result.str();
+    }
+};
+
+std::vector<RoundTestValues> testValues = {
+    {
+        ngraph::element::u8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { 125.5f }, { 0.1f } },
+        { {}, { { 126.f }, ngraph::element::f32 }, { 0.1f } }
+    },
+    {
+        ngraph::element::u8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { { 128.3f, 64.5f, 31.7f } }, { { 0.1f, 0.1f, 0.1f } } },
+        { {}, { { 128.f, 65.f, 32.f }, ngraph::element::f32 }, { { 0.1f, 0.1f, 0.1f } } }
+    },
+    {
+        ngraph::element::i8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { 126.6f }, { 0.1f } },
+        { {}, { { 127.f }, ngraph::element::f32 }, { 0.1f } }
+    },
+    {
+        ngraph::element::i8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { { 126.5f, 32.25f, -127.5f } }, { { 0.1f, 0.1f, 0.1f } } },
+        { {}, { { 127.f, 32.f, -128.f }, ngraph::element::f32 }, { { 0.1f, 0.1f, 0.1f } } }
+    },
+};
+
+TEST_P(RoundTransformation, CompareFunctions) {
+    actualFunction->validate_nodes_and_infer_types();
+    auto res = compare_functions(referenceFunction, actualFunction, true, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+INSTANTIATE_TEST_CASE_P(
+    LPT,
+    RoundTransformation,
+    ::testing::ValuesIn(testValues),
+    RoundTransformation::getTestCaseName);
+} // namespace


@@ -0,0 +1,32 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <ngraph/ngraph.hpp>
+
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+class RoundWithToleranceFunction {
+public:
+    static std::shared_ptr<ngraph::Function> getOriginal(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization);
+
+    static std::shared_ptr<ngraph::Function> getReference(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization);
+};
+
+} // namespace subgraph
+} // namespace builder
+} // namespace ngraph


@@ -0,0 +1,56 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph_functions/low_precision_transformations/round_function.hpp"
+
+#include <ngraph/opsets/opset1.hpp>
+
+#include "ngraph_functions/subgraph_builders.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
+
+using namespace ngraph::pass::low_precision;
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+std::shared_ptr<ngraph::Function> RoundWithToleranceFunction::getOriginal(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const ngraph::builder::subgraph::DequantizationOperations dequantization) {
+    const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
+    input->set_friendly_name("input");
+
+    const auto deq = makeDequantization(input, dequantization);
+    deq->set_friendly_name("output");
+
+    const auto result = std::make_shared<ngraph::opset1::Result>(deq);
+    result->set_friendly_name("result");
+
+    return std::make_shared<ngraph::Function>(
+        ngraph::ResultVector{ result },
+        ngraph::ParameterVector{ input },
+        "RoundWithToleranceFunction");
+}
+
+std::shared_ptr<ngraph::Function> RoundWithToleranceFunction::getReference(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const ngraph::builder::subgraph::DequantizationOperations dequantization) {
+    const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
+    input->set_friendly_name("input");
+
+    const auto deq = makeDequantization(input, dequantization);
+    deq->set_friendly_name("output");
+
+    const auto result = std::make_shared<ngraph::opset1::Result>(deq);
+    result->set_friendly_name("result");
+
+    return std::make_shared<ngraph::Function>(
+        ngraph::ResultVector{ result },
+        ngraph::ParameterVector{ input },
+        "RoundWithToleranceFunction");
+}
+
+} // namespace subgraph
+} // namespace builder
+} // namespace ngraph