* [LPT] NetworkHelper::roundWithTolerance: removed tolerance & rename to round [LPT] NetworkHelper::round functional tests [LPT] ieFuncTests: updated some test-cases * [LPT] Subtract is not used * [LPT] AddTransformation: zero handling * [LPT] AddTransformation test
This commit is contained in:
parent
d02223c796
commit
b5d7f236f4
@ -17,6 +17,7 @@ public:
|
||||
~AddTransformation() override {}
|
||||
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
|
||||
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
|
||||
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
|
||||
};
|
||||
|
||||
} // namespace low_precision
|
||||
|
@ -26,6 +26,7 @@ public:
|
||||
std::shared_ptr<ngraph::opset1::Multiply> multiply);
|
||||
|
||||
bool empty() const;
|
||||
bool multiplyHasZero() const;
|
||||
bool isShared() const;
|
||||
bool isLowPrecision() const;
|
||||
static bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise);
|
||||
|
@ -81,7 +81,7 @@ public:
|
||||
// Optimizes the series of multiplies after a given output port
|
||||
static std::shared_ptr<ngraph::opset1::Multiply> optimizeMultipliesAfter(std::shared_ptr<Node> multiply);
|
||||
|
||||
static std::shared_ptr<opset1::Constant> roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance = 0.1);
|
||||
static std::shared_ptr<opset1::Constant> round(std::shared_ptr<Node> node, element::Type target_type);
|
||||
|
||||
static std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
|
||||
std::shared_ptr<opset1::FakeQuantize> fq,
|
||||
|
@ -199,6 +199,20 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AddTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
|
||||
const FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, 0ul);
|
||||
if (dequantization1.multiplyHasZero()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, 1ul);
|
||||
if (dequantization2.multiplyHasZero()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return EltwiseBaseTransformation::canBeTransformed(context, layer);
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -42,7 +42,8 @@ bool ClampTransformation::transform(TransformationContext& context, ngraph::patt
|
||||
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(clamp);
|
||||
|
||||
const bool moveSubtract = subWithTheSameValues(dequantization.subtract);
|
||||
if (!moveSubtract && !canSubtractBeHandled(clamp, dequantization)) {
|
||||
// issue #43136
|
||||
if (!moveSubtract && (dequantization.subtract != nullptr)) {
|
||||
return false;
|
||||
}
|
||||
const auto newClamp = as_type_ptr<opset1::Clamp>(moveDequantizationAfter(context, clamp, dequantization, false, moveSubtract));
|
||||
|
@ -30,6 +30,23 @@ bool FakeQuantizeDequantization::empty() const {
|
||||
return (convert == nullptr) && (subtract == nullptr) && (multiply == nullptr);
|
||||
}
|
||||
|
||||
bool FakeQuantizeDequantization::multiplyHasZero() const {
|
||||
if (multiply == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<opset1::Constant> multiplyConstant = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(1));
|
||||
if (multiplyConstant == nullptr) {
|
||||
multiplyConstant = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(0));
|
||||
}
|
||||
if (multiplyConstant == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto const values = multiplyConstant->cast_vector<float>();
|
||||
return std::any_of(values.begin(), values.end(), [](const float value) { return value == 0.f; });
|
||||
}
|
||||
|
||||
bool FakeQuantizeDequantization::isShared() const {
|
||||
if ((convert != nullptr) && (convert->get_output_target_inputs(0).size() > 1ul)) {
|
||||
return true;
|
||||
|
@ -33,6 +33,7 @@ bool GroupConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) co
|
||||
|
||||
bool GroupConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
|
||||
auto convolution = m.get_match_root();
|
||||
|
||||
if (!GroupConvolutionTransformation::canBeTransformed(context, convolution)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -138,9 +138,7 @@ bool LayerTransformation::canSubtractBeHandled(const std::shared_ptr<Node>& op,
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> zeroPoint = dequantization.subtract->input_value(1).get_node_shared_ptr();
|
||||
auto convertedZeroPoint = NetworkHelper::roundWithTolerance(zeroPoint, operationType);
|
||||
return convertedZeroPoint->output(0).get_element_type() == operationType;
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef LPT_PRINT_DEQUANTIZATION_INFO
|
||||
|
@ -41,7 +41,7 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!canSubtractBeHandled(operation)) {
|
||||
if (NetworkHelper::getDequantization(operation).subtract != nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -321,52 +321,15 @@ std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<opset1::Constant> NetworkHelper::roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance) {
|
||||
auto constant = as_type_ptr<opset1::Constant>(node);
|
||||
std::shared_ptr<opset1::Constant> NetworkHelper::round(std::shared_ptr<Node> node, element::Type target_type) {
|
||||
const auto constant = as_type_ptr<opset1::Constant>(node);
|
||||
assert(constant);
|
||||
auto values = constant->cast_vector<float>();
|
||||
|
||||
auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(constant, target_type));
|
||||
auto castedValues = castedConstant->cast_vector<float>();
|
||||
const auto castedConstant = as_type_ptr<ngraph::opset1::Constant>(fold<op::v0::Convert>(
|
||||
fold<ngraph::op::v5::Round>(constant->output(0), ngraph::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO),
|
||||
target_type));
|
||||
|
||||
// TODO: implement with constant folding when ReduceAnd constant folding is ready
|
||||
if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
|
||||
return castedConstant;
|
||||
}
|
||||
|
||||
auto round = [](
|
||||
const std::shared_ptr<opset1::Constant>& constant,
|
||||
element::Type target_type,
|
||||
float tolerance,
|
||||
std::vector<float>& values,
|
||||
float increaseValue) -> std::shared_ptr<opset1::Constant> {
|
||||
const auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(
|
||||
fold<opset1::Add>(constant, std::make_shared<opset1::Constant>(constant->get_output_element_type(0), Shape{ 1 }, increaseValue)),
|
||||
target_type));
|
||||
const auto castedValues = castedConstant->cast_vector<float>();
|
||||
if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
|
||||
return castedConstant;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
castedConstant = round(constant, target_type, tolerance, values, 0.5f);
|
||||
if (castedConstant != nullptr) {
|
||||
return castedConstant;
|
||||
}
|
||||
|
||||
castedConstant = round(constant, target_type, tolerance, values, -0.5f);
|
||||
if (castedConstant != nullptr) {
|
||||
return castedConstant;
|
||||
}
|
||||
|
||||
castedConstant = round(constant, target_type, tolerance, values, 1.f);
|
||||
if (castedConstant != nullptr) {
|
||||
return castedConstant;
|
||||
}
|
||||
|
||||
return constant;
|
||||
return castedConstant;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq) {
|
||||
@ -889,16 +852,13 @@ std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Su
|
||||
|
||||
auto data = convertOnSubtract->input_value(0);
|
||||
auto shift = subtract->input_value(1).get_node_shared_ptr();
|
||||
auto roundedShift = NetworkHelper::roundWithTolerance(shift, convertInputType);
|
||||
auto roundedShift = NetworkHelper::round(shift, convertInputType);
|
||||
|
||||
std::shared_ptr<Node> replacement;
|
||||
if (roundedShift->get_element_type() == convertInputType) {
|
||||
// Propagate convertInputType down
|
||||
replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
|
||||
NetworkHelper::copyInfo(subtract, replacement);
|
||||
NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
|
||||
replace_node(subtract, replacement);
|
||||
}
|
||||
// Propagate convertInputType down
|
||||
const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
|
||||
NetworkHelper::copyInfo(subtract, replacement);
|
||||
NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
|
||||
replace_node(subtract, replacement);
|
||||
|
||||
// We lose the tail conversion here; not needed if the next node is a TypeRelaxed
|
||||
// TODO: check cases when Convert should be preserved
|
||||
@ -992,7 +952,8 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
||||
|
||||
if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
|
||||
NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
|
||||
optimizeSubtract(dequantization.subtract);
|
||||
// issue #43088
|
||||
// NetworkHelper::optimizeElementwise(dequantization.subtract);
|
||||
}
|
||||
|
||||
return InsertDequantizationResult(newOperation, parent);
|
||||
|
@ -40,7 +40,7 @@ bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& co
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!canSubtractBeHandled(operation)) {
|
||||
if (NetworkHelper::getDequantization(operation).subtract != nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ bool PReluTransformation::isPrecisionPreserved(std::shared_ptr<Node> op) const n
|
||||
|
||||
bool PReluTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
|
||||
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
|
||||
if (dequantization.empty()) {
|
||||
if (dequantization.empty() || (dequantization.subtract != nullptr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -48,11 +48,7 @@ bool ReluTransformation::canBeTransformed(const TransformationContext& context,
|
||||
}
|
||||
|
||||
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
|
||||
if (dequantization.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!canSubtractBeHandled(op, dequantization)) {
|
||||
if (dequantization.empty() || (dequantization.subtract != nullptr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -72,12 +72,13 @@ bool SubtractTransformation::transform(TransformationContext& context, ngraph::p
|
||||
}
|
||||
|
||||
if (dequantization.convert != nullptr) {
|
||||
std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeSubtract(subtract);
|
||||
newSubtract->set_output_type(0, originalPrecision, newSubtract->get_output_partial_shape(0));
|
||||
// issue #43088
|
||||
// std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeElementwise(subtract);
|
||||
subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
|
||||
|
||||
replace_node(newSubtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
|
||||
newSubtract->get_input_node_shared_ptr(0),
|
||||
newSubtract->get_input_node_shared_ptr(1)));
|
||||
replace_node(subtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
|
||||
subtract->get_input_node_shared_ptr(0),
|
||||
subtract->get_input_node_shared_ptr(1)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -147,6 +147,54 @@ TEST_P(AddTransformation, CompareFunctions) {
|
||||
}
|
||||
|
||||
const std::vector<AddTransformationTestValues> addTransformationTestValues = {
|
||||
// Multiply with zero on the first branch
|
||||
{
|
||||
ngraph::element::f32,
|
||||
ngraph::Shape{1, 4, 16, 16},
|
||||
false,
|
||||
-1,
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{ },
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
|
||||
{ }
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{ },
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
|
||||
{ },
|
||||
{ }
|
||||
},
|
||||
""
|
||||
},
|
||||
// Multiply with zero on the second branch
|
||||
{
|
||||
ngraph::element::f32,
|
||||
ngraph::Shape{1, 4, 16, 16},
|
||||
false,
|
||||
-1,
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
|
||||
ngraph::element::f32,
|
||||
{ },
|
||||
{ }
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
|
||||
ngraph::element::f32,
|
||||
{ },
|
||||
{ },
|
||||
{ }
|
||||
},
|
||||
""
|
||||
},
|
||||
// U8
|
||||
{
|
||||
ngraph::element::f32,
|
||||
|
@ -331,9 +331,13 @@ const std::vector<ClampTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {{ 128.f, 0.f, 128.f }, ngraph::element::f32}, {}},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{ 128.f, 0.f, 128.f }},
|
||||
{{ 3.f, 3.f, 3.f }}
|
||||
},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {{3.f, 3.f, 3.f}}}
|
||||
{{}, {}, {}}
|
||||
}
|
||||
},
|
||||
// U8 without asymmetric quantization
|
||||
|
@ -154,7 +154,7 @@ const std::vector<ConvolutionTransformationTestValues> testValues = {
|
||||
// ActualValues
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{ngraph::element::f32}, { 128.f }, { 0.02f }},
|
||||
{{}, { 128.f }, { 0.02f }},
|
||||
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
|
||||
{ 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
|
||||
},
|
||||
@ -214,7 +214,7 @@ const std::vector<ConvolutionTransformationTestValues> testValues = {
|
||||
// ActualValues
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{ngraph::element::f32}, {}, { 0.02f }},
|
||||
{{}, {}, { 0.02f }},
|
||||
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
|
||||
{ 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
|
||||
},
|
||||
|
@ -165,7 +165,7 @@ const std::vector<GroupConvolutionTestValues> testValues = {
|
||||
// ActualValues
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{ngraph::element::f32}, { 128.f }, { 0.02f }},
|
||||
{{}, { 128.f }, { 0.02f }},
|
||||
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
|
||||
{ 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
|
||||
},
|
||||
@ -329,7 +329,7 @@ const std::vector<GroupConvolutionTestValues> testValues = {
|
||||
// ActualValues
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{ngraph::element::f32}, { 128.f }, { 0.02f }},
|
||||
{{}, { 128.f }, { 0.02f }},
|
||||
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
|
||||
{ 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
|
||||
},
|
||||
|
@ -218,12 +218,12 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, { 127.5f }, { 0.02f } },
|
||||
{ {}, {{128.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
|
||||
ngraph::element::i8,
|
||||
{ ngraph::element::f32, {}, { 0.03f } },
|
||||
{ },
|
||||
ngraph::element::f32,
|
||||
ngraph::element::f32,
|
||||
{},
|
||||
{ {}, {}, { 0.0006f } },
|
||||
}
|
||||
},
|
||||
// U8 + FP32
|
||||
|
@ -129,7 +129,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f } },
|
||||
},
|
||||
{
|
||||
{ {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
{ {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 10.f } },
|
||||
},
|
||||
@ -159,7 +159,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f } },
|
||||
},
|
||||
{
|
||||
{ {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
{ {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 10.f } },
|
||||
},
|
||||
@ -189,7 +189,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f } },
|
||||
},
|
||||
{
|
||||
{ {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
{ {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 10.f } },
|
||||
},
|
||||
@ -219,7 +219,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f } },
|
||||
},
|
||||
{
|
||||
{ {}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
{ {ngraph::element::f32}, { { 7.f }, ngraph::element::f32, {}, false }, {} },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 10.f } },
|
||||
},
|
||||
@ -234,12 +234,12 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
|
||||
{ {ngraph::element::f32}, { { 7.f, 7.f, 7.f } }, { { 10.f, 10.f, 10.f } } },
|
||||
},
|
||||
{
|
||||
{ {}, { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
|
||||
{ {ngraph::element::f32}, { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { { 10.f, 10.f, 10.f } } },
|
||||
},
|
||||
},
|
||||
// per-channel quantizations with the same values
|
||||
// per-channel quantizations with different values
|
||||
{
|
||||
ngraph::element::u8,
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
@ -249,7 +249,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
|
||||
{ {ngraph::element::f32}, { { 7.f, 8.f, 9.f } }, { { 10.f, 12.f, 16.f } } },
|
||||
},
|
||||
{
|
||||
{ {}, { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
|
||||
{ {ngraph::element::f32}, { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { { 10.f, 12.f, 16.f } } },
|
||||
},
|
||||
|
@ -91,6 +91,7 @@ public:
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
toString(testValues.params) << "_" <<
|
||||
testValues.inputShape << "_" <<
|
||||
testValues.reductionAxes << "_" <<
|
||||
testValues.normalizeVariance << "_" <<
|
||||
@ -145,9 +146,9 @@ const std::vector<MVNTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {127.f}, {}},
|
||||
{{ngraph::element::f32}, {127.f}, {0.45f}},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {1.f}}
|
||||
{{}, {}, {}}
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -163,7 +164,7 @@ const std::vector<MVNTransformationTestValues> testValues = {
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {12.5f}, {0.45f}},
|
||||
ngraph::element::f32,
|
||||
{}
|
||||
{{}, {}, {}}
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -53,7 +53,7 @@ public:
|
||||
low_precision::LayerTransformation::Params(params.transformationParams));
|
||||
transform.transform(actualFunction);
|
||||
|
||||
referenceFunction = (!params.transformationParams.supportAsymmetricQuantization) && (!params.expected.subtractValues.empty()) ?
|
||||
referenceFunction = !params.expected.subtractValues.empty() ?
|
||||
ngraph::builder::subgraph::NormalizeL2Function::getOriginal(
|
||||
precision,
|
||||
shape,
|
||||
|
@ -137,9 +137,9 @@ const std::vector<PReluTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, { {128}, ngraph::element::f32 }, {}},
|
||||
{{ngraph::element::f32}, { 128 }, {0.1f}},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {0.1f}}
|
||||
{{}, {}, {}}
|
||||
}
|
||||
},
|
||||
// I8: with positive subtract value
|
||||
@ -152,24 +152,9 @@ const std::vector<PReluTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{}, { {127}, ngraph::element::f32 }, {}},
|
||||
{{ngraph::element::f32}, { 127 }, {0.1f}},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {0.1f}}
|
||||
}
|
||||
},
|
||||
// U8: with negative subtract value: Convert is still here
|
||||
{
|
||||
ngraph::Shape({ 1, 3, 16, 16 }),
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, { -128 }, {0.1f}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, { {-128}, ngraph::element::f32 }, {}},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {0.1f}}
|
||||
{{}, {}, {}}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
@ -73,6 +73,7 @@ public:
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
toString(testValues.params) << "_" <<
|
||||
testValues.shape << "_" <<
|
||||
testValues.actual.precisionBeforeDequantization << "_" <<
|
||||
testValues.actual.dequantization << "_" <<
|
||||
@ -166,9 +167,9 @@ const std::vector<ReluTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, { {128}, ngraph::element::f32, {}, false }, {}},
|
||||
{{ngraph::element::f32}, { 128 }, {0.1f}},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {0.1f}}
|
||||
{{}, {}, {}}
|
||||
}
|
||||
},
|
||||
// I8: with subtract value
|
||||
@ -181,9 +182,9 @@ const std::vector<ReluTransformationTestValues> testValues = {
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{}, { {127}, ngraph::element::f32, {}, false }, {}},
|
||||
{{ngraph::element::f32}, { 127 }, {0.1f}},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {0.1f}}
|
||||
{{}, {}, {}}
|
||||
}
|
||||
},
|
||||
// I8: with subtract value
|
||||
|
@ -0,0 +1,111 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "layer_transformation.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/round_function.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
|
||||
|
||||
namespace {
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::pass;
|
||||
|
||||
class RoundTestValues {
|
||||
public:
|
||||
ngraph::element::Type inputPrecision;
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::builder::subgraph::DequantizationOperations actualDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations referenceDequantization;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class RoundTransformation : public LayerTransformation, public testing::WithParamInterface<RoundTestValues> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const auto testValues = this->GetParam();
|
||||
|
||||
actualFunction = ngraph::builder::subgraph::RoundWithToleranceFunction::getOriginal(
|
||||
testValues.inputPrecision,
|
||||
testValues.inputShape,
|
||||
testValues.actualDequantization);
|
||||
const auto lastNode = actualFunction->get_output_op(0)->get_input_node_shared_ptr(0);
|
||||
const auto dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(lastNode);
|
||||
const auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
|
||||
const auto roundedConst = ngraph::pass::low_precision::NetworkHelper::round(
|
||||
subtractConstant,
|
||||
testValues.inputPrecision);
|
||||
|
||||
if (roundedConst->get_element_type() == testValues.inputPrecision) {
|
||||
const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(dequantization.data, roundedConst);
|
||||
ngraph::pass::low_precision::NetworkHelper::copyInfo(dequantization.subtract, replacement);
|
||||
ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, dequantization.convert->get_element_type());
|
||||
replace_node(dequantization.subtract, replacement);
|
||||
}
|
||||
|
||||
referenceFunction = ngraph::builder::subgraph::RoundWithToleranceFunction::getReference(
|
||||
testValues.inputPrecision,
|
||||
testValues.inputShape,
|
||||
testValues.referenceDequantization);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<RoundTestValues> obj) {
|
||||
const auto testValues = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << testValues.inputPrecision << "_"
|
||||
<< testValues.actualDequantization << "_"
|
||||
<< testValues.referenceDequantization;
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<RoundTestValues> testValues = {
|
||||
{
|
||||
ngraph::element::u8,
|
||||
ngraph::Shape{ 1, 3, 16, 16 },
|
||||
{ { ngraph::element::f32 }, { 125.5f }, { 0.1f } },
|
||||
{ {}, { { 126.f }, ngraph::element::f32 }, { 0.1f } }
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
ngraph::Shape{ 1, 3, 16, 16 },
|
||||
{ { ngraph::element::f32 }, { { 128.3f, 64.5f, 31.7f } }, { { 0.1f, 0.1f, 0.1f } } },
|
||||
{ {}, { { 128.f, 65.f, 32.f }, ngraph::element::f32 }, { { 0.1f, 0.1f, 0.1f } } }
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
ngraph::Shape{ 1, 3, 16, 16 },
|
||||
{ { ngraph::element::f32 }, { 126.6f }, { 0.1f } },
|
||||
{ {}, { { 127.f }, ngraph::element::f32 }, { 0.1f } }
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
ngraph::Shape{ 1, 3, 16, 16 },
|
||||
{ { ngraph::element::f32 }, { { 126.5f, 32.25f, -127.5f } }, { { 0.1f, 0.1f, 0.1f } } },
|
||||
{ {}, { { 127.f, 32.f, -128.f }, ngraph::element::f32 }, { { 0.1f, 0.1f, 0.1f } } }
|
||||
},
|
||||
};
|
||||
|
||||
TEST_P(RoundTransformation, CompareFunctions) {
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(referenceFunction, actualFunction, true, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
LPT,
|
||||
RoundTransformation,
|
||||
::testing::ValuesIn(testValues),
|
||||
RoundTransformation::getTestCaseName);
|
||||
} // namespace
|
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
|
||||
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
class RoundWithToleranceFunction {
|
||||
public:
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations dequantization);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReference(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations dequantization);
|
||||
};
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph_functions/low_precision_transformations/round_function.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
|
||||
|
||||
using namespace ngraph::pass::low_precision;
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
std::shared_ptr<ngraph::Function> RoundWithToleranceFunction::getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations dequantization) {
|
||||
const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
|
||||
input->set_friendly_name("input");
|
||||
|
||||
const auto deq = makeDequantization(input, dequantization);
|
||||
deq->set_friendly_name("output");
|
||||
|
||||
const auto result = std::make_shared<ngraph::opset1::Result>(deq);
|
||||
result->set_friendly_name("result");
|
||||
|
||||
return std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ result },
|
||||
ngraph::ParameterVector{ input },
|
||||
"RoundWithToleranceFunction");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> RoundWithToleranceFunction::getReference(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations dequantization) {
|
||||
const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
|
||||
input->set_friendly_name("input");
|
||||
|
||||
const auto deq = makeDequantization(input, dequantization);
|
||||
deq->set_friendly_name("output");
|
||||
|
||||
const auto result = std::make_shared<ngraph::opset1::Result>(deq);
|
||||
result->set_friendly_name("result");
|
||||
|
||||
return std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ result },
|
||||
ngraph::ParameterVector{ input },
|
||||
"RoundWithToleranceFunction");
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
Loading…
Reference in New Issue
Block a user