From a89c4cfc3f58bfb61ee883bb49130322b21bd794 Mon Sep 17 00:00:00 2001 From: Egor Duplenskii Date: Wed, 3 Aug 2022 18:36:01 +0200 Subject: [PATCH] [LPT] Correct a check for whether model is quantized (#12364) Look inside subgraph operations, such as TensorIterator, Loop, If, etc --- .../src/low_precision.cpp | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index 507879aa90f..8cdb7c0ba94 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -13,6 +13,8 @@ #include #include #include +#include "ngraph/op/util/multi_subgraph_base.hpp" + #include #include @@ -260,31 +262,39 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p bool ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(const std::shared_ptr& function) { std::set> handledNodes; std::deque> nodes; - for (auto result : function->get_results()) { + for (const auto result : function->get_results()) { nodes.push_front(result); } while (!nodes.empty()) { - auto node = nodes.front(); + const auto node = nodes.front(); nodes.pop_front(); for (size_t i = 0; i < node->inputs().size(); ++i) { - auto parent = node->get_input_node_shared_ptr(i); + const auto parent = node->get_input_node_shared_ptr(i); if (handledNodes.find(parent) != handledNodes.end()) { continue; } - const std::shared_ptr fakeQuantize = ov::as_type_ptr(parent); - if ((fakeQuantize != nullptr) && - QuantizationDetails::outputLayoutIsSupported(fakeQuantize, true) && - QuantizationDetails::isSupportedLevel(fakeQuantize->get_levels())) { - return true; + if (const auto fakeQuantize = ov::as_type_ptr(parent)) { + if (QuantizationDetails::outputLayoutIsSupported(fakeQuantize, true) && + QuantizationDetails::isSupportedLevel(fakeQuantize->get_levels())) { + return true; + } + } else if (const auto multiSubGraph = ov::as_type_ptr(parent)) { + // Look inside subraph operations, such as TensorIterator, Loop, If, etc + for (int i = 0; i < multiSubGraph->get_internal_subgraphs_size(); i++) { + if (isFunctionQuantized(multiSubGraph->get_function(i))) { + return true; + } + } } nodes.push_front(parent); handledNodes.insert(parent); } } + return false; }