LPT: element-wise ops fuse extension to support updatePrecisions (#14141)
This commit is contained in:
@@ -32,7 +32,8 @@ public:
|
||||
static std::shared_ptr<opset1::FakeQuantize> fuseElementwise(
|
||||
TransformationContext& context,
|
||||
MatcherPass* matcherPass,
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize);
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize,
|
||||
const bool updatePrecisions);
|
||||
};
|
||||
|
||||
} // namespace low_precision
|
||||
|
||||
@@ -100,7 +100,7 @@ bool AssignAndReadValueTransformation::transform(TransformationContext& context,
|
||||
return true;
|
||||
}
|
||||
|
||||
FakeQuantizeTransformation::fuseElementwise(context, this, fakeQuantize);
|
||||
FakeQuantizeTransformation::fuseElementwise(context, this, fakeQuantize, updatePrecisions);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ bool FakeQuantizeTransformation::transform(TransformationContext& context, ngrap
|
||||
bool wasHandled = false;
|
||||
std::shared_ptr<opset1::FakeQuantize> fakeQuantize = layer;
|
||||
do {
|
||||
fakeQuantize = fuseElementwise(context, this, fakeQuantize);
|
||||
fakeQuantize = fuseElementwise(context, this, fakeQuantize, updatePrecisions);
|
||||
wasHandled = wasHandled || (fakeQuantize != nullptr);
|
||||
} while (fakeQuantize != nullptr);
|
||||
|
||||
@@ -88,6 +88,41 @@ static std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>
|
||||
return ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
|
||||
}
|
||||
|
||||
bool all_precisions_equal(const std::shared_ptr<Node>& node) {
|
||||
const auto& inputs = node->inputs();
|
||||
const auto first_input_precision = inputs.empty() ? element::undefined : inputs[0].get_element_type();
|
||||
if (!inputs.empty()) {
|
||||
const auto first_input_precision = inputs[0].get_element_type();
|
||||
if (std::any_of(
|
||||
inputs.begin(),
|
||||
inputs.end(),
|
||||
[first_input_precision](const ov::Input<ov::Node>& input) {
|
||||
return input.get_element_type() != first_input_precision;
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& outputs = node->outputs();
|
||||
if (!outputs.empty()) {
|
||||
const auto first_output_precision = outputs[0].get_element_type();
|
||||
if ((first_input_precision != element::undefined) && (first_input_precision != first_output_precision)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (std::any_of(
|
||||
outputs.begin(),
|
||||
outputs.end(),
|
||||
[first_output_precision](const ov::Output<ov::Node>& output) {
|
||||
return output.get_element_type() != first_output_precision;
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace fq
|
||||
|
||||
bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& eltwise) {
|
||||
@@ -121,9 +156,14 @@ bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& e
|
||||
std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
|
||||
TransformationContext& context,
|
||||
MatcherPass* matcherPass,
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) {
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize,
|
||||
const bool updatePrecisions) {
|
||||
const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
|
||||
|
||||
if (!updatePrecisions && !fq::all_precisions_equal(eltwise)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> inputLowConst_f32 = foldConvert(fakeQuantize->input_value(1), element::f32);
|
||||
std::shared_ptr<Node> inputHighConst_f32 = foldConvert(fakeQuantize->input_value(2), element::f32);
|
||||
|
||||
|
||||
@@ -123,6 +123,50 @@ TEST_P(FuseFakeQuantizeTransformation, CompareFunctions) {
|
||||
}
|
||||
|
||||
const std::vector<FuseFakeQuantizeTransformationTestValues> testValues = {
|
||||
// Convert: U8 -> FP32, updatePrecisions = true
|
||||
{
|
||||
{1, 3, 16, 16},
|
||||
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),
|
||||
{
|
||||
element::f32,
|
||||
{},
|
||||
element::u8,
|
||||
{ {element::f32}, {}, {} },
|
||||
element::f32,
|
||||
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } }
|
||||
},
|
||||
{
|
||||
element::f32,
|
||||
{},
|
||||
element::u8,
|
||||
{{}, {}, {}},
|
||||
element::f32,
|
||||
element::f32,
|
||||
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 255.f }, element::u8 }
|
||||
}
|
||||
},
|
||||
// Convert: U8 -> FP32, updatePrecisions = false
|
||||
{
|
||||
{1, 3, 16, 16},
|
||||
TestTransformationParams(false, {ngraph::element::u8}, {ngraph::element::i8}),
|
||||
{
|
||||
element::f32,
|
||||
{},
|
||||
element::u8,
|
||||
{ {element::f32}, {}, {} },
|
||||
element::f32,
|
||||
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } }
|
||||
},
|
||||
{
|
||||
element::f32,
|
||||
{},
|
||||
element::u8,
|
||||
{{element::f32}, {}, {}},
|
||||
element::f32,
|
||||
element::f32,
|
||||
{ 256ul, {}, { 0.f }, { 2.55f }, { 0.f }, { 255.f }, element::f32 }
|
||||
}
|
||||
},
|
||||
// 1) Multiply
|
||||
{
|
||||
{1, 3, 16, 16},
|
||||
|
||||
@@ -66,19 +66,21 @@ std::shared_ptr<ngraph::Function> FuseFakeQuantizeFunction::getOriginal(
|
||||
|
||||
const std::shared_ptr<Node> lastDequantization = makeDequantization(parent, dequantization);
|
||||
|
||||
std::shared_ptr<Node> lastNode;
|
||||
auto fqOnDataCopy = fqOnData;
|
||||
fqOnDataCopy.outputHighValues = {255.f};
|
||||
fqOnDataCopy.outputPrecision = fqOnData.outputPrecision == element::undefined ? ngraph::element::u8 : fqOnData.outputPrecision;
|
||||
|
||||
if (fqOnData.outputLowValues == std::vector<float>{0.f} &&
|
||||
fqOnData.outputHighValues == std::vector<float>{2.55f}) {
|
||||
auto fqOnDataCopy = fqOnData;
|
||||
fqOnDataCopy.outputHighValues = {255.f};
|
||||
fqOnDataCopy.outputPrecision = ngraph::element::u8;
|
||||
lastNode = makeFakeQuantizeTypeRelaxed(lastDequantization, precisionFqOnData, fqOnDataCopy);
|
||||
lastNode = makeDequantization(lastNode, { {element::f32}, {}, {{0.01f}, precisionFqOnData} });
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("Unknown parameter on output intervals!");
|
||||
}
|
||||
std::shared_ptr<Node> lastNode = makeFakeQuantizeTypeRelaxed(lastDequantization, precisionFqOnData, fqOnDataCopy);
|
||||
lastNode = makeDequantization(
|
||||
lastNode,
|
||||
{
|
||||
lastNode->output(0).get_element_type() != element::f32 ?
|
||||
DequantizationOperations::Convert{element::f32} :
|
||||
DequantizationOperations::Convert{},
|
||||
{},
|
||||
{{0.01f},
|
||||
precisionFqOnData}
|
||||
});
|
||||
lastNode->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(lastNode) };
|
||||
|
||||
Reference in New Issue
Block a user