LPT: element-wise ops fuse extension to support updatePrecisions (#14141)

This commit is contained in:
Edward Shogulin
2022-11-25 09:49:03 +00:00
committed by GitHub
parent 04df4195a0
commit cdb711a069
5 changed files with 103 additions and 16 deletions

View File

@@ -32,7 +32,8 @@ public:
static std::shared_ptr<opset1::FakeQuantize> fuseElementwise(
TransformationContext& context,
MatcherPass* matcherPass,
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize);
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize,
const bool updatePrecisions);
};
} // namespace low_precision

View File

@@ -100,7 +100,7 @@ bool AssignAndReadValueTransformation::transform(TransformationContext& context,
return true;
}
FakeQuantizeTransformation::fuseElementwise(context, this, fakeQuantize);
FakeQuantizeTransformation::fuseElementwise(context, this, fakeQuantize, updatePrecisions);
return true;
}

View File

@@ -42,7 +42,7 @@ bool FakeQuantizeTransformation::transform(TransformationContext& context, ngrap
bool wasHandled = false;
std::shared_ptr<opset1::FakeQuantize> fakeQuantize = layer;
do {
fakeQuantize = fuseElementwise(context, this, fakeQuantize);
fakeQuantize = fuseElementwise(context, this, fakeQuantize, updatePrecisions);
wasHandled = wasHandled || (fakeQuantize != nullptr);
} while (fakeQuantize != nullptr);
@@ -88,6 +88,41 @@ static std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>
return ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
}
// Returns true when every input and every output of `node` share one element
// type. Callers use this to bail out of FakeQuantize element-wise fusion when
// updatePrecisions is disabled, since fusing across a mixed-precision eltwise
// would require a precision change the caller has forbidden.
bool all_precisions_equal(const std::shared_ptr<Node>& node) {
    const auto& inputs = node->inputs();
    // element::undefined is the "no inputs" sentinel: it disables the
    // input-vs-output cross check further down for input-less nodes.
    const auto first_input_precision = inputs.empty() ? element::undefined : inputs[0].get_element_type();
    if (!inputs.empty()) {
        // Fix: the original re-declared `first_input_precision` here, shadowing
        // the function-scope variable with an identical value (-Wshadow noise).
        // All inputs must match the first input's precision.
        if (std::any_of(
                inputs.begin(),
                inputs.end(),
                [first_input_precision](const ov::Input<ov::Node>& input) {
                    return input.get_element_type() != first_input_precision;
                })) {
            return false;
        }
    }

    const auto& outputs = node->outputs();
    if (!outputs.empty()) {
        const auto first_output_precision = outputs[0].get_element_type();
        // Inputs (when present) must also agree with the outputs.
        if ((first_input_precision != element::undefined) && (first_input_precision != first_output_precision)) {
            return false;
        }
        // All outputs must match the first output's precision.
        if (std::any_of(
                outputs.begin(),
                outputs.end(),
                [first_output_precision](const ov::Output<ov::Node>& output) {
                    return output.get_element_type() != first_output_precision;
                })) {
            return false;
        }
    }

    return true;
}
} // namespace fq
bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& eltwise) {
@@ -121,9 +156,14 @@ bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& e
std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
TransformationContext& context,
MatcherPass* matcherPass,
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) {
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize,
const bool updatePrecisions) {
const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
if (!updatePrecisions && !fq::all_precisions_equal(eltwise)) {
return nullptr;
}
std::shared_ptr<Node> inputLowConst_f32 = foldConvert(fakeQuantize->input_value(1), element::f32);
std::shared_ptr<Node> inputHighConst_f32 = foldConvert(fakeQuantize->input_value(2), element::f32);

View File

@@ -123,6 +123,50 @@ TEST_P(FuseFakeQuantizeTransformation, CompareFunctions) {
}
const std::vector<FuseFakeQuantizeTransformationTestValues> testValues = {
// Convert: U8 -> FP32, updatePrecisions = true
{
{1, 3, 16, 16},
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),
{
element::f32,
{},
element::u8,
{ {element::f32}, {}, {} },
element::f32,
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } }
},
{
element::f32,
{},
element::u8,
{{}, {}, {}},
element::f32,
element::f32,
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 255.f }, element::u8 }
}
},
// Convert: U8 -> FP32, updatePrecisions = false
{
{1, 3, 16, 16},
TestTransformationParams(false, {ngraph::element::u8}, {ngraph::element::i8}),
{
element::f32,
{},
element::u8,
{ {element::f32}, {}, {} },
element::f32,
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } }
},
{
element::f32,
{},
element::u8,
{{element::f32}, {}, {}},
element::f32,
element::f32,
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 255.f }, element::f32 }
}
},
// 1) Multiply
{
{1, 3, 16, 16},

View File

@@ -66,19 +66,21 @@ std::shared_ptr<ngraph::Function> FuseFakeQuantizeFunction::getOriginal(
const std::shared_ptr<Node> lastDequantization = makeDequantization(parent, dequantization);
std::shared_ptr<Node> lastNode;
auto fqOnDataCopy = fqOnData;
fqOnDataCopy.outputHighValues = {255.f};
fqOnDataCopy.outputPrecision = fqOnData.outputPrecision == element::undefined ? ngraph::element::u8 : fqOnData.outputPrecision;
if (fqOnData.outputLowValues == std::vector<float>{0.f} &&
fqOnData.outputHighValues == std::vector<float>{2.55f}) {
auto fqOnDataCopy = fqOnData;
fqOnDataCopy.outputHighValues = {255.f};
fqOnDataCopy.outputPrecision = ngraph::element::u8;
lastNode = makeFakeQuantizeTypeRelaxed(lastDequantization, precisionFqOnData, fqOnDataCopy);
lastNode = makeDequantization(lastNode, { {element::f32}, {}, {{0.01f}, precisionFqOnData} });
} else {
throw std::runtime_error("Unknown parameter on output intervals!");
}
std::shared_ptr<Node> lastNode = makeFakeQuantizeTypeRelaxed(lastDequantization, precisionFqOnData, fqOnDataCopy);
lastNode = makeDequantization(
lastNode,
{
lastNode->output(0).get_element_type() != element::f32 ?
DequantizationOperations::Convert{element::f32} :
DequantizationOperations::Convert{},
{},
{{0.01f},
precisionFqOnData}
});
lastNode->set_friendly_name("output");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(lastNode) };