From 5be402750ad458a390c4dc225d906648a4b7c9ac Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 22 Feb 2022 02:02:11 +0300 Subject: [PATCH] [LPT] FuseConvert transformation extension (#10558) * [LPT] FuseConvert transformation extension * [LPT] Tests * [LPT] Cleanup & tests refactoring --- .../include/low_precision/fake_quantize.hpp | 1 - .../fake_quantize_decomposition.hpp | 1 - .../low_precision/fuse_fake_quantize.hpp | 30 --- .../src/fake_quantize.cpp | 3 +- .../src/fuse_convert.cpp | 8 +- .../src/fuse_fake_quantize.cpp | 193 ------------------ .../src/low_precision.cpp | 1 - .../fuse_convert_transformation.cpp | 59 +++++- ...ntize_with_multi_inputs_transformation.cpp | 4 +- ...sformations_after_split_transformation.cpp | 1 - .../common/dequantization_operations.hpp | 12 ++ .../common/fake_quantize_on_data.hpp | 6 + .../fuse_convert_function.hpp | 1 + .../src/fuse_convert_function.cpp | 12 +- 14 files changed, 87 insertions(+), 245 deletions(-) delete mode 100644 src/common/low_precision_transformations/include/low_precision/fuse_fake_quantize.hpp delete mode 100644 src/common/low_precision_transformations/src/fuse_fake_quantize.cpp diff --git a/src/common/low_precision_transformations/include/low_precision/fake_quantize.hpp b/src/common/low_precision_transformations/include/low_precision/fake_quantize.hpp index 1df89215758..e04626f057c 100644 --- a/src/common/low_precision_transformations/include/low_precision/fake_quantize.hpp +++ b/src/common/low_precision_transformations/include/low_precision/fake_quantize.hpp @@ -7,7 +7,6 @@ #include #include #include "layer_transformation.hpp" -#include "low_precision/fuse_fake_quantize.hpp" namespace ngraph { namespace pass { diff --git a/src/common/low_precision_transformations/include/low_precision/fake_quantize_decomposition.hpp b/src/common/low_precision_transformations/include/low_precision/fake_quantize_decomposition.hpp index bf6bdbce879..171e2515a75 100644 --- a/src/common/low_precision_transformations/include/low_precision/fake_quantize_decomposition.hpp +++ b/src/common/low_precision_transformations/include/low_precision/fake_quantize_decomposition.hpp @@ -7,7 +7,6 @@ #include #include #include "layer_transformation.hpp" -#include "low_precision/fuse_fake_quantize.hpp" namespace ngraph { namespace pass { diff --git a/src/common/low_precision_transformations/include/low_precision/fuse_fake_quantize.hpp b/src/common/low_precision_transformations/include/low_precision/fuse_fake_quantize.hpp deleted file mode 100644 index fc5aa7ce130..00000000000 --- a/src/common/low_precision_transformations/include/low_precision/fuse_fake_quantize.hpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (C) 2018-2022 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include "low_precision/layer_transformation.hpp" - -namespace ngraph { -namespace pass { -namespace low_precision { - -class LP_TRANSFORMATIONS_API FuseFakeQuantizeTransformation : public LayerTransformation { -public: - NGRAPH_RTTI_DECLARATION; - FuseFakeQuantizeTransformation(const Params& params); - bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; - bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; - -private: - std::shared_ptr handle( - TransformationContext& context, - const std::shared_ptr& fakeQuantize) const; -}; - -} // namespace low_precision -} // namespace pass -} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/fake_quantize.cpp b/src/common/low_precision_transformations/src/fake_quantize.cpp index 25787b894c9..72628c3b999 100644 --- a/src/common/low_precision_transformations/src/fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fake_quantize.cpp @@ -175,7 +175,8 @@ std::shared_ptr FakeQuantizeTransformation::fuseElementwis return nullptr; } - const auto data = fq::getDataNode(eltwise); + // issue #79980 + const auto data = eltwise->get_input_size() == 1ul ? eltwise->get_input_node_shared_ptr(0) : fq::getDataNode(eltwise); const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise); const auto newFakeQuantize = ov::as_type_ptr(fakeQuantize->clone_with_new_inputs({ diff --git a/src/common/low_precision_transformations/src/fuse_convert.cpp b/src/common/low_precision_transformations/src/fuse_convert.cpp index a6b0c713981..003dc5098f2 100644 --- a/src/common/low_precision_transformations/src/fuse_convert.cpp +++ b/src/common/low_precision_transformations/src/fuse_convert.cpp @@ -23,8 +23,14 @@ FuseConvertTransformation::FuseConvertTransformation(const Params& params) : Lay auto multiply = pattern::wrap_type({ pattern::wrap_type(), pattern::wrap_type() }); auto subtract = pattern::wrap_type({ pattern::wrap_type(), pattern::wrap_type() }); auto add = pattern::wrap_type({ pattern::wrap_type(), pattern::wrap_type() }); + auto fakeQuantize = pattern::wrap_type({ + pattern::wrap_type({pattern::wrap_type()}), + pattern::any_input(), + pattern::any_input(), + pattern::any_input(), + pattern::any_input()}); auto matcher = std::make_shared( - std::make_shared(OutputVector{ multiply, subtract, add }), + std::make_shared(OutputVector{ multiply, subtract, add, fakeQuantize }), "FuseConvertTransformation"); ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { diff --git a/src/common/low_precision_transformations/src/fuse_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_fake_quantize.cpp deleted file mode 100644 index fa9ba5b1c27..00000000000 --- a/src/common/low_precision_transformations/src/fuse_fake_quantize.cpp +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (C) 2018-2022 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "low_precision/fuse_fake_quantize.hpp" -#include -#include -#include -#include "low_precision/common/ie_lpt_exception.hpp" -#include "low_precision/network_helper.hpp" - -namespace ngraph { -namespace pass { -namespace low_precision { - -NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::FuseFakeQuantizeTransformation, "FuseFakeQuantizeTransformation", 0); - -FuseFakeQuantizeTransformation::FuseFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) { - auto matcher = pattern::wrap_type(); - - ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { - auto op = m.get_match_root(); - if (transformation_callback(op)) { - return false; - } - return transform(*context, m); - }; - - auto m = std::make_shared(matcher, "FuseFakeQuantizeTransformation"); - this->register_matcher(m, callback); -} - -bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) { - auto fakeQuantize = ov::as_type_ptr(m.get_match_root()); - if (!fakeQuantize) - return false; - - do { - fakeQuantize = handle(context, fakeQuantize); - } while (fakeQuantize != nullptr); - return true; -} - -namespace fuse_fq { -namespace { - -std::shared_ptr updateShape(std::shared_ptr op, const PartialShape& targetPShape) { - assert(targetPShape.is_static()); - assert(op->get_output_partial_shape(0).is_static()); - const Shape targetShape = targetPShape.to_shape(); - const Shape shape = op->get_output_shape(0); - - if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) { - op = fold( - op, - std::make_shared(ngraph::element::i32, Shape{ 1 }, std::vector({ 0ul }))); - } - return op; -} - -std::shared_ptr getDataNode(const std::shared_ptr& eltwise) { - if (!ov::is_type(eltwise->get_input_node_shared_ptr(0))) { - return eltwise->get_input_node_shared_ptr(0); - } - - if (!ov::is_type(eltwise->get_input_node_shared_ptr(1))) { - return eltwise->get_input_node_shared_ptr(1); - } - - return nullptr; -} - -std::shared_ptr getConstant(const std::shared_ptr& eltwise) { - if (eltwise->get_input_size() != 2) { - return nullptr; - } - - std::shared_ptr constant = ov::as_type_ptr(eltwise->get_input_node_shared_ptr(1)); - if (constant != nullptr) { - return constant; - } - - return ov::as_type_ptr(eltwise->get_input_node_shared_ptr(0)); -} - -bool eltwiseWithConstant(const std::shared_ptr& eltwise) { - std::shared_ptr constant = getConstant(eltwise); - if (constant == nullptr) { - return false; - } - - Shape shape = constant->get_shape(); - if ((!shape.empty()) && (shape_size(shape) != 1ul)) { - const auto eltwisePShape = eltwise->get_output_partial_shape(0); - if (eltwisePShape.rank().is_dynamic()) { - return false; - } - - const size_t eltwiseOutRank = eltwisePShape.rank().get_length(); - if ((eltwiseOutRank - shape.size()) > 1) { - return false; - } - - if ((eltwiseOutRank - shape.size()) == 1ul) { - shape.insert(shape.begin(), 1ul); - } - - for (size_t i = 2ul; i < shape.size(); ++i) { - if (shape[i] != 1ul) { - return false; - } - } - } - - return getDataNode(eltwise) != nullptr; -} - -} // namespace -} // namespace fuse_fq - -std::shared_ptr FuseFakeQuantizeTransformation::handle( - TransformationContext& context, - const std::shared_ptr& fakeQuantize) const { - const std::shared_ptr eltwise = fakeQuantize->get_input_node_shared_ptr(0); - - std::shared_ptr inputLowConst = fakeQuantize->get_input_node_shared_ptr(1); - std::shared_ptr inputHightConst = fakeQuantize->get_input_node_shared_ptr(2); - - std::shared_ptr constant = fuse_fq::getConstant(eltwise); - if (ov::is_type(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) { - const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ? - constant : - foldConvert(constant, eltwise->get_output_element_type(0)); - - inputLowConst = fuse_fq::updateShape(fold(inputLowConst, value), fakeQuantize->get_output_partial_shape(0)); - inputHightConst = fuse_fq::updateShape(fold(inputHightConst, value), fakeQuantize->get_output_partial_shape(0)); - } else if (ov::is_type(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) { - const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ? - constant : - foldConvert(constant, eltwise->get_output_element_type(0)); - - inputLowConst = fuse_fq::updateShape(fold(inputLowConst, value), fakeQuantize->get_output_partial_shape(0)); - inputHightConst = fuse_fq::updateShape(fold(inputHightConst, value), fakeQuantize->get_output_partial_shape(0)); - } else if (ov::is_type(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) { - const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ? - constant : - foldConvert(constant, eltwise->get_output_element_type(0)); - - inputLowConst = fuse_fq::updateShape(fold(inputLowConst, value), fakeQuantize->get_output_partial_shape(0)); - inputHightConst = fuse_fq::updateShape(fold(inputHightConst, value), fakeQuantize->get_output_partial_shape(0)); - } else if (ov::is_type(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) { - if (ov::is_type(fuse_fq::getDataNode(eltwise)) || - ov::is_type(fuse_fq::getDataNode(eltwise))) { - return nullptr; - } - - const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ? - constant : - foldConvert(constant, eltwise->get_output_element_type(0)); - - inputLowConst = fuse_fq::updateShape(fold(inputLowConst, value), fakeQuantize->get_output_partial_shape(0)); - inputHightConst = fuse_fq::updateShape(fold(inputHightConst, value), fakeQuantize->get_output_partial_shape(0)); - } else if (ov::is_type(eltwise)) { - // issue #40611 - if ((eltwise->get_input_element_type(0) == element::i32) && (eltwise->get_output_element_type(0) == element::f32)) { - return nullptr; - } - } else { - return nullptr; - } - - const auto data = fuse_fq::getDataNode(eltwise); - const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise); - - std::shared_ptr newFakeQuantize = ov::as_type_ptr(fakeQuantize->clone_with_new_inputs({ - data->output(outputIdx), - inputLowConst, - inputHightConst, - fakeQuantize->input_value(3), - fakeQuantize->input_value(4) })); - - replace_node(fakeQuantize, newFakeQuantize); - NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize); - return newFakeQuantize; -} - -bool FuseFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr layer) const noexcept { - return false; -} - -} // namespace low_precision -} // namespace pass -} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index 038003bfa8c..e91373b1e0f 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -74,7 +74,6 @@ #include "low_precision/convert.hpp" #include "low_precision/fold_fake_quantize.hpp" #include "low_precision/fuse_convert.hpp" -#include "low_precision/fuse_fake_quantize.hpp" #include "low_precision/fuse_subtract_to_fake_quantize.hpp" #include "low_precision/fuse_multiply_to_fake_quantize.hpp" #include "low_precision/multiply_to_group_convolution.hpp" diff --git a/src/tests/functional/inference_engine/lp_transformations/fuse_convert_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/fuse_convert_transformation.cpp index 4d766f589c1..972ac81c8e4 100644 --- a/src/tests/functional/inference_engine/lp_transformations/fuse_convert_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/fuse_convert_transformation.cpp @@ -30,12 +30,14 @@ public: public: ngraph::element::Type inputPrecision; ngraph::builder::subgraph::DequantizationOperations dequantization; + ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize; }; class Expected { public: ngraph::element::Type inputPrecision; ngraph::builder::subgraph::DequantizationOperations dequantization; + ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize; }; bool constInput; @@ -58,6 +60,7 @@ public: inputShape, testValues.actual.inputPrecision, testValues.actual.dequantization, + testValues.actual.fakeQuantize, testValues.constInput); SimpleLowPrecisionTransformer transformer; @@ -68,6 +71,7 @@ public: inputShape, testValues.expected.inputPrecision, testValues.expected.dequantization, + testValues.expected.fakeQuantize, testValues.constInput); } @@ -77,9 +81,13 @@ public: std::ostringstream result; result << - inputShape << "_" << - testValues.actual.inputPrecision << "_" << - testValues.actual.dequantization << "_" << + "IS_" << inputShape << "_" << + "AIP_" << testValues.actual.inputPrecision << "_" << + "ADEQ_" << testValues.actual.dequantization << "_" << + "AFQ_" << testValues.actual.fakeQuantize << "_" << + "EIP_" << testValues.expected.inputPrecision << "_" << + "EDEQ_" << testValues.expected.dequantization << "_" << + "EFQ_" << testValues.expected.fakeQuantize << "_" << testValues.constInput; return result.str(); } @@ -111,7 +119,8 @@ const std::vector testValues = { { ngraph::element::f32 }, {1.f}, {0.45f} - } + }, + {} }, { ngraph::element::u8, @@ -119,7 +128,8 @@ const std::vector testValues = { {}, DequantizationOperations::Subtract({1.f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32), {0.45f} - } + }, + {} } }, // fuse to multiply @@ -132,7 +142,8 @@ const std::vector testValues = { { ngraph::element::f32 }, {}, {0.45f} - } + }, + {} }, { ngraph::element::u8, @@ -140,7 +151,8 @@ const std::vector testValues = { {}, {}, DequantizationOperations::Multiply({0.45f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32) - } + }, + {} } }, // Convert with unexpected precision @@ -149,11 +161,13 @@ const std::vector testValues = { LayerTransformation::createParamsU8I8(), { ngraph::element::f32, - {{ ngraph::element::i32 }, {}, {3.f}} + {{ ngraph::element::i32 }, {}, {3.f}}, + {} }, { ngraph::element::f32, - {{ ngraph::element::i32 }, {}, {3.f}} + {{ ngraph::element::i32 }, {}, {3.f}}, + {} } }, }; @@ -173,6 +187,27 @@ const std::vector inputShapes = { }; const std::vector testValuesWithConstant = { + // Constant + // | + // Convert Const Const Const Const + // \ \ | / / + // \ \ | / / + // FakeQuantize + // + { + true, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + {{ngraph::element::f32}, {}, {}}, + { 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32} + }, + { + ngraph::element::f32, + {}, + { 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32} + } + }, // fuse to const { true, @@ -183,7 +218,8 @@ const std::vector testValuesWithConstant = { ngraph::element::f32 }, {1.f}, {0.45f} - } + }, + {} }, { ngraph::element::f32, @@ -191,7 +227,8 @@ const std::vector testValuesWithConstant = {}, {1.f}, {0.45f} - } + }, + {} } }, }; diff --git a/src/tests/functional/inference_engine/lp_transformations/fuse_fake_quantize_with_multi_inputs_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/fuse_fake_quantize_with_multi_inputs_transformation.cpp index 0df3320562a..2f949a05c1c 100644 --- a/src/tests/functional/inference_engine/lp_transformations/fuse_fake_quantize_with_multi_inputs_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/fuse_fake_quantize_with_multi_inputs_transformation.cpp @@ -12,7 +12,7 @@ #include #include -#include +#include #include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" #include "lpt_ngraph_functions/common/dequantization_operations.hpp" @@ -62,7 +62,7 @@ public: testValues.actual.fakeQuantizeOnData); SimpleLowPrecisionTransformer transformer; - transformer.add(testValues.params); + transformer.add(testValues.params); transformer.transform(actualFunction); referenceFunction = ngraph::builder::subgraph::FuseFakeQuantizeFunction::get( diff --git a/src/tests/functional/inference_engine/lp_transformations/transformations_after_split_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/transformations_after_split_transformation.cpp index 89ce7756452..d054d36f912 100644 --- a/src/tests/functional/inference_engine/lp_transformations/transformations_after_split_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/transformations_after_split_transformation.cpp @@ -35,7 +35,6 @@ // cleanup transformations #include "low_precision/fuse_convert.hpp" -#include "low_precision/fuse_fake_quantize.hpp" #include "low_precision/fuse_subtract_to_fake_quantize.hpp" #include "low_precision/fuse_multiply_to_fake_quantize.hpp" #include "low_precision/multiply_to_group_convolution.hpp" diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp index bfd8c44eadd..0930f48337d 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp @@ -117,10 +117,16 @@ public: }; inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Convert& convert) { + if (convert.empty()) { + return out << "{}"; + } return out << "_" << (convert.outPrecision != element::undefined ? convert.outPrecision.get_type_name() : ""); } inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Subtract& subtract) { + if (subtract.empty()) { + return out << "{}"; + } return out << "_" << subtract.values << "_" << subtract.outPrecision << "_" << @@ -132,6 +138,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation } inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Multiply& multiply) { + if (multiply.empty()) { + return out << "{}"; + } return out << "_" << multiply.values << "_" << multiply.outPrecision << "_" << @@ -142,6 +151,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation } inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations& data) { + if (data.empty()) { + return out << "{}"; + } return out << "_" << data.convert << "_" << data.subtract << "_" << data.multiply; } diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/fake_quantize_on_data.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/fake_quantize_on_data.hpp index ce0a816b90d..6612c978a15 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/fake_quantize_on_data.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/fake_quantize_on_data.hpp @@ -54,6 +54,9 @@ inline std::ostream& operator<<(std::ostream& os, const std::vector& valu } inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnData& data) { + if (data.empty()) { + return out << "{}"; + } return out << "_" << data.quantizationLevel << data.constantShape << "_" << data.inputLowValues << "_" << data.inputHighValues << "_" << data.outputLowValues << "_" << data.outputHighValues << "_" << (data.outputPrecision == ngraph::element::undefined ? "" : data.outputPrecision.get_type_name()); @@ -89,6 +92,9 @@ public: }; inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnDataWithConstant& data) { + if (data.empty()) { + return out << "{}"; + } return out << "_" << data.quantizationLevel << (data.constantShapes.empty() ? ngraph::Shape{} : data.constantShapes[0]) << "_" << data.inputLowValues << "_" << data.inputHighValues << "_" << diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/fuse_convert_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/fuse_convert_function.hpp index 793948fd991..ecb203d6672 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/fuse_convert_function.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/fuse_convert_function.hpp @@ -20,6 +20,7 @@ public: const ngraph::PartialShape& inputShape, const ngraph::element::Type inputPrecision, const ngraph::builder::subgraph::DequantizationOperations& dequantization, + const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize, const bool constInput); static std::shared_ptr getWithFQ( diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/fuse_convert_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/fuse_convert_function.cpp index 48bce4c1bad..6f20e1bea9d 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/fuse_convert_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/fuse_convert_function.cpp @@ -16,6 +16,7 @@ std::shared_ptr FuseConvertFunction::get( const ngraph::PartialShape& inputShape, const ngraph::element::Type inputPrecision, const ngraph::builder::subgraph::DequantizationOperations& dequantization, + const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize, const bool constInput) { std::shared_ptr parent; std::shared_ptr input; @@ -28,14 +29,19 @@ std::shared_ptr FuseConvertFunction::get( parent = input; } - const std::shared_ptr dequantizationOp = makeDequantization(parent, dequantization); - dequantizationOp->set_friendly_name("output"); + parent = makeDequantization(parent, dequantization); + + if (!fakeQuantize.empty()) { + parent = makeFakeQuantize(parent, fakeQuantize.outputPrecision, fakeQuantize); + } + + parent->set_friendly_name("output"); auto parameters = constInput ? ngraph::ParameterVector{}: ngraph::ParameterVector{ input }; - ngraph::ResultVector results{ std::make_shared(dequantizationOp) }; + ngraph::ResultVector results{std::make_shared(parent)}; return std::make_shared(results, parameters, "FuseConvertFunction"); }