[LPT] FuseConvert transformation extension (#10558)

* [LPT] FuseConvert transformation extension

* [LPT] Tests

* [LPT] Cleanup & tests refactoring
This commit is contained in:
Edward Shogulin 2022-02-22 02:02:11 +03:00 committed by GitHub
parent d7ad1bd9cd
commit 5be402750a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 87 additions and 245 deletions

View File

@ -7,7 +7,6 @@
#include <memory>
#include <ngraph/ngraph.hpp>
#include "layer_transformation.hpp"
#include "low_precision/fuse_fake_quantize.hpp"
namespace ngraph {
namespace pass {

View File

@ -7,7 +7,6 @@
#include <memory>
#include <ngraph/ngraph.hpp>
#include "layer_transformation.hpp"
#include "low_precision/fuse_fake_quantize.hpp"
namespace ngraph {
namespace pass {

View File

@ -1,30 +0,0 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/ngraph.hpp>
#include "low_precision/layer_transformation.hpp"
namespace ngraph {
namespace pass {
namespace low_precision {
class LP_TRANSFORMATIONS_API FuseFakeQuantizeTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
FuseFakeQuantizeTransformation(const Params& params);
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
private:
std::shared_ptr<opset1::FakeQuantize> handle(
TransformationContext& context,
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const;
};
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@ -175,7 +175,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
return nullptr;
}
const auto data = fq::getDataNode(eltwise);
// issue #79980
const auto data = eltwise->get_input_size() == 1ul ? eltwise->get_input_node_shared_ptr(0) : fq::getDataNode(eltwise);
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
const auto newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({

View File

@ -23,8 +23,14 @@ FuseConvertTransformation::FuseConvertTransformation(const Params& params) : Lay
auto multiply = pattern::wrap_type<opset1::Multiply>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
auto subtract = pattern::wrap_type<opset1::Subtract>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
auto add = pattern::wrap_type<opset1::Add>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
auto fakeQuantize = pattern::wrap_type<opset1::FakeQuantize>({
pattern::wrap_type<opset1::Convert>({pattern::wrap_type<opset1::Constant>()}),
pattern::any_input(),
pattern::any_input(),
pattern::any_input(),
pattern::any_input()});
auto matcher = std::make_shared<ngraph::pattern::Matcher>(
std::make_shared<pattern::op::Or>(OutputVector{ multiply, subtract, add }),
std::make_shared<pattern::op::Or>(OutputVector{ multiply, subtract, add, fakeQuantize }),
"FuseConvertTransformation");
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {

View File

@ -1,193 +0,0 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/fuse_fake_quantize.hpp"
#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/network_helper.hpp"
namespace ngraph {
namespace pass {
namespace low_precision {
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::FuseFakeQuantizeTransformation, "FuseFakeQuantizeTransformation", 0);
FuseFakeQuantizeTransformation::FuseFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) {
auto matcher = pattern::wrap_type<opset1::FakeQuantize>();
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
return false;
}
return transform(*context, m);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matcher, "FuseFakeQuantizeTransformation");
this->register_matcher(m, callback);
}
bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
auto fakeQuantize = ov::as_type_ptr<ngraph::opset1::FakeQuantize>(m.get_match_root());
if (!fakeQuantize)
return false;
do {
fakeQuantize = handle(context, fakeQuantize);
} while (fakeQuantize != nullptr);
return true;
}
namespace fuse_fq {
namespace {
std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const PartialShape& targetPShape) {
assert(targetPShape.is_static());
assert(op->get_output_partial_shape(0).is_static());
const Shape targetShape = targetPShape.to_shape();
const Shape shape = op->get_output_shape(0);
if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
op = fold<opset1::Unsqueeze>(
op,
std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
}
return op;
}
std::shared_ptr<Node> getDataNode(const std::shared_ptr<Node>& eltwise) {
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
return eltwise->get_input_node_shared_ptr(0);
}
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
return eltwise->get_input_node_shared_ptr(1);
}
return nullptr;
}
std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
if (eltwise->get_input_size() != 2) {
return nullptr;
}
std::shared_ptr<opset1::Constant> constant = ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
if (constant != nullptr) {
return constant;
}
return ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
}
bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
if (constant == nullptr) {
return false;
}
Shape shape = constant->get_shape();
if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
const auto eltwisePShape = eltwise->get_output_partial_shape(0);
if (eltwisePShape.rank().is_dynamic()) {
return false;
}
const size_t eltwiseOutRank = eltwisePShape.rank().get_length();
if ((eltwiseOutRank - shape.size()) > 1) {
return false;
}
if ((eltwiseOutRank - shape.size()) == 1ul) {
shape.insert(shape.begin(), 1ul);
}
for (size_t i = 2ul; i < shape.size(); ++i) {
if (shape[i] != 1ul) {
return false;
}
}
}
return getDataNode(eltwise) != nullptr;
}
} // namespace
} // namespace fuse_fq
std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
TransformationContext& context,
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
std::shared_ptr<opset1::Constant> constant = fuse_fq::getConstant(eltwise);
if (ov::is_type<opset1::Multiply>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
foldConvert(constant, eltwise->get_output_element_type(0));
inputLowConst = fuse_fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
} else if (ov::is_type<opset1::Divide>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
foldConvert(constant, eltwise->get_output_element_type(0));
inputLowConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
} else if (ov::is_type<opset1::Subtract>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
foldConvert(constant, eltwise->get_output_element_type(0));
inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
} else if (ov::is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
if (ov::is_type<opset1::Convolution>(fuse_fq::getDataNode(eltwise)) ||
ov::is_type<opset1::GroupConvolution>(fuse_fq::getDataNode(eltwise))) {
return nullptr;
}
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
foldConvert(constant, eltwise->get_output_element_type(0));
inputLowConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
} else if (ov::is_type<opset1::Convert>(eltwise)) {
// issue #40611
if ((eltwise->get_input_element_type(0) == element::i32) && (eltwise->get_output_element_type(0) == element::f32)) {
return nullptr;
}
} else {
return nullptr;
}
const auto data = fuse_fq::getDataNode(eltwise);
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
data->output(outputIdx),
inputLowConst,
inputHightConst,
fakeQuantize->input_value(3),
fakeQuantize->input_value(4) }));
replace_node(fakeQuantize, newFakeQuantize);
NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
return newFakeQuantize;
}
bool FuseFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return false;
}
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@ -74,7 +74,6 @@
#include "low_precision/convert.hpp"
#include "low_precision/fold_fake_quantize.hpp"
#include "low_precision/fuse_convert.hpp"
#include "low_precision/fuse_fake_quantize.hpp"
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"

View File

@ -30,12 +30,14 @@ public:
public:
ngraph::element::Type inputPrecision;
ngraph::builder::subgraph::DequantizationOperations dequantization;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize;
};
class Expected {
public:
ngraph::element::Type inputPrecision;
ngraph::builder::subgraph::DequantizationOperations dequantization;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize;
};
bool constInput;
@ -58,6 +60,7 @@ public:
inputShape,
testValues.actual.inputPrecision,
testValues.actual.dequantization,
testValues.actual.fakeQuantize,
testValues.constInput);
SimpleLowPrecisionTransformer transformer;
@ -68,6 +71,7 @@ public:
inputShape,
testValues.expected.inputPrecision,
testValues.expected.dequantization,
testValues.expected.fakeQuantize,
testValues.constInput);
}
@ -77,9 +81,13 @@ public:
std::ostringstream result;
result <<
inputShape << "_" <<
testValues.actual.inputPrecision << "_" <<
testValues.actual.dequantization << "_" <<
"IS_" << inputShape << "_" <<
"AIP_" << testValues.actual.inputPrecision << "_" <<
"ADEQ_" << testValues.actual.dequantization << "_" <<
"AFQ_" << testValues.actual.fakeQuantize << "_" <<
"EIP_" << testValues.expected.inputPrecision << "_" <<
"EDEQ_" << testValues.expected.dequantization << "_" <<
"EFQ_" << testValues.expected.fakeQuantize << "_" <<
testValues.constInput;
return result.str();
}
@ -111,7 +119,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
{ ngraph::element::f32 },
{1.f},
{0.45f}
}
},
{}
},
{
ngraph::element::u8,
@ -119,7 +128,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
{},
DequantizationOperations::Subtract({1.f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32),
{0.45f}
}
},
{}
}
},
// fuse to multiply
@ -132,7 +142,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
{ ngraph::element::f32 },
{},
{0.45f}
}
},
{}
},
{
ngraph::element::u8,
@ -140,7 +151,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
{},
{},
DequantizationOperations::Multiply({0.45f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32)
}
},
{}
}
},
// Convert with unexpected precision
@ -149,11 +161,13 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
LayerTransformation::createParamsU8I8(),
{
ngraph::element::f32,
{{ ngraph::element::i32 }, {}, {3.f}}
{{ ngraph::element::i32 }, {}, {3.f}},
{}
},
{
ngraph::element::f32,
{{ ngraph::element::i32 }, {}, {3.f}}
{{ ngraph::element::i32 }, {}, {3.f}},
{}
}
},
};
@ -173,6 +187,27 @@ const std::vector<ngraph::PartialShape> inputShapes = {
};
const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant = {
// Constant
// |
// Convert Const Const Const Const
// \ \ | / /
// \ \ | / /
// FakeQuantize
//
{
true,
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {}, {}},
{ 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32}
},
{
ngraph::element::f32,
{},
{ 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32}
}
},
// fuse to const
{
true,
@ -183,7 +218,8 @@ const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant =
{ ngraph::element::f32 },
{1.f},
{0.45f}
}
},
{}
},
{
ngraph::element::f32,
@ -191,7 +227,8 @@ const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant =
{},
{1.f},
{0.45f}
}
},
{}
}
},
};

View File

@ -12,7 +12,7 @@
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <low_precision/fuse_fake_quantize.hpp>
#include <low_precision/fake_quantize.hpp>
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
@ -62,7 +62,7 @@ public:
testValues.actual.fakeQuantizeOnData);
SimpleLowPrecisionTransformer transformer;
transformer.add<ngraph::pass::low_precision::FuseFakeQuantizeTransformation, ngraph::opset1::FakeQuantize>(testValues.params);
transformer.add<ngraph::pass::low_precision::FakeQuantizeTransformation, ngraph::opset1::FakeQuantize>(testValues.params);
transformer.transform(actualFunction);
referenceFunction = ngraph::builder::subgraph::FuseFakeQuantizeFunction::get(

View File

@ -35,7 +35,6 @@
// cleanup transformations
#include "low_precision/fuse_convert.hpp"
#include "low_precision/fuse_fake_quantize.hpp"
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"

View File

@ -117,10 +117,16 @@ public:
};
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Convert& convert) {
if (convert.empty()) {
return out << "{}";
}
return out << "_" << (convert.outPrecision != element::undefined ? convert.outPrecision.get_type_name() : "");
}
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Subtract& subtract) {
if (subtract.empty()) {
return out << "{}";
}
return out << "_" <<
subtract.values << "_" <<
subtract.outPrecision << "_" <<
@ -132,6 +138,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
}
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Multiply& multiply) {
if (multiply.empty()) {
return out << "{}";
}
return out << "_" <<
multiply.values << "_" <<
multiply.outPrecision << "_" <<
@ -142,6 +151,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
}
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations& data) {
if (data.empty()) {
return out << "{}";
}
return out << "_" << data.convert << "_" << data.subtract << "_" << data.multiply;
}

View File

@ -54,6 +54,9 @@ inline std::ostream& operator<<(std::ostream& os, const std::vector<float>& valu
}
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnData& data) {
if (data.empty()) {
return out << "{}";
}
return out << "_" << data.quantizationLevel << data.constantShape << "_" << data.inputLowValues << "_" << data.inputHighValues <<
"_" << data.outputLowValues << "_" << data.outputHighValues << "_" <<
(data.outputPrecision == ngraph::element::undefined ? "" : data.outputPrecision.get_type_name());
@ -89,6 +92,9 @@ public:
};
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnDataWithConstant& data) {
if (data.empty()) {
return out << "{}";
}
return out << "_" << data.quantizationLevel <<
(data.constantShapes.empty() ? ngraph::Shape{} : data.constantShapes[0]) << "_" <<
data.inputLowValues << "_" << data.inputHighValues << "_" <<

View File

@ -20,6 +20,7 @@ public:
const ngraph::PartialShape& inputShape,
const ngraph::element::Type inputPrecision,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
const bool constInput);
static std::shared_ptr<ngraph::Function> getWithFQ(

View File

@ -16,6 +16,7 @@ std::shared_ptr<ngraph::Function> FuseConvertFunction::get(
const ngraph::PartialShape& inputShape,
const ngraph::element::Type inputPrecision,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
const bool constInput) {
std::shared_ptr<Node> parent;
std::shared_ptr<op::Parameter> input;
@ -28,14 +29,19 @@ std::shared_ptr<ngraph::Function> FuseConvertFunction::get(
parent = input;
}
const std::shared_ptr<Node> dequantizationOp = makeDequantization(parent, dequantization);
dequantizationOp->set_friendly_name("output");
parent = makeDequantization(parent, dequantization);
if (!fakeQuantize.empty()) {
parent = makeFakeQuantize(parent, fakeQuantize.outputPrecision, fakeQuantize);
}
parent->set_friendly_name("output");
auto parameters = constInput ?
ngraph::ParameterVector{}:
ngraph::ParameterVector{ input };
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOp) };
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(parent)};
return std::make_shared<ngraph::Function>(results, parameters, "FuseConvertFunction");
}