[LPT] MoveFakeQuantize: Q/DQ pattern identification generalization (#18945)

* [LPT] MoveFakeQuantize Q/D pattern dequantization generalization

* [LPT] MoveFakeQuantize Q/D pattern dequantization generalization: quantize op convert
This commit is contained in:
Edward Shogulin 2023-08-17 17:07:51 +01:00 committed by GitHub
parent dea2310153
commit 318009f8d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -107,7 +107,8 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
return false; return false;
} }
const bool q_dq = is_type<opset1::Convert>(convert_q); const auto& dequantization = NetworkHelper::getDequantizationBelow(convert_q, true);
std::vector<std::shared_ptr<ngraph::Node>> newNodes; std::vector<std::shared_ptr<ngraph::Node>> newNodes;
for (size_t i = 0; i < concat->get_input_size(); ++i) { for (size_t i = 0; i < concat->get_input_size(); ++i) {
ov::Output<ov::Node> parent_output; ov::Output<ov::Node> parent_output;
@ -133,7 +134,7 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
ngraph::copy_runtime_info(fq, new_fq); ngraph::copy_runtime_info(fq, new_fq);
new_fq->set_friendly_name(fq_original_name + "_" + std::to_string(i + 1)); new_fq->set_friendly_name(fq_original_name + "_" + std::to_string(i + 1));
if (q_dq) { if (!dequantization.empty()) {
auto new_convert_q = convert_q->clone_with_new_inputs({new_fq}); auto new_convert_q = convert_q->clone_with_new_inputs({new_fq});
ngraph::copy_runtime_info(convert_q, new_convert_q); ngraph::copy_runtime_info(convert_q, new_convert_q);
new_convert_q->set_friendly_name(convert_q->get_friendly_name() + "_" + std::to_string(i + 1)); new_convert_q->set_friendly_name(convert_q->get_friendly_name() + "_" + std::to_string(i + 1));
@ -146,9 +147,8 @@ bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern
auto newConcat = concat->clone_with_new_inputs(ngraph::OutputVector(newNodes.begin(), newNodes.end())); auto newConcat = concat->clone_with_new_inputs(ngraph::OutputVector(newNodes.begin(), newNodes.end()));
newConcat->set_friendly_name(concat->get_friendly_name()); newConcat->set_friendly_name(concat->get_friendly_name());
NetworkHelper::copyInfo(concat, newConcat); NetworkHelper::copyInfo(concat, newConcat);
if (q_dq) { if (!dequantization.empty()) {
auto dq = NetworkHelper::getDequantizationBelow(convert_q); moveDequantizationBefore(context, newConcat, dequantization, false);
moveDequantizationBefore(context, newConcat, dq, false);
return true; return true;
} }
replace_node(fq, newConcat); replace_node(fq, newConcat);