diff --git a/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp b/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp index f7162d8c6fd..e08c27770d2 100644 --- a/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp +++ b/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp @@ -333,6 +333,13 @@ protected: const bool updatePrecision, const bool moveSubtract = true) const; + std::shared_ptr moveDequantizationBefore( + TransformationContext& context, + const std::shared_ptr& operation, + const FakeQuantizeDequantization& dequantization, + const bool updatePrecision, + const bool moveSubtract = true) const; + void updateOutput( TransformationContext &context, std::shared_ptr lastNode, diff --git a/src/common/low_precision_transformations/include/low_precision/move_fake_quantize.hpp b/src/common/low_precision_transformations/include/low_precision/move_fake_quantize.hpp index 4e0e8054e55..487f9bfa0bc 100644 --- a/src/common/low_precision_transformations/include/low_precision/move_fake_quantize.hpp +++ b/src/common/low_precision_transformations/include/low_precision/move_fake_quantize.hpp @@ -17,6 +17,7 @@ public: NGRAPH_RTTI_DECLARATION; MoveFakeQuantize(const Params& params = Params()); bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; }; diff --git a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp index 569cd79d506..16f23558785 100644 --- a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp +++ b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp @@ -173,6 +173,16 @@ public: const bool updatePrecision, const bool moveSubtract); + static InsertDequantizationResult moveDequantizationBefore( + const std::shared_ptr& operation, + const FakeQuantizeDequantization& dequantization, + const bool updatePrecision, + const bool moveSubtract); + + static std::vector>> split_consts_before_concat( + const std::shared_ptr concat, + const std::vector> currConstants); + static bool checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr& constant); static size_t getChildInputIndex(const std::shared_ptr& parent, const std::shared_ptr& child); diff --git a/src/common/low_precision_transformations/src/layer_transformation.cpp b/src/common/low_precision_transformations/src/layer_transformation.cpp index 1ba9700255e..6326c990ee3 100644 --- a/src/common/low_precision_transformations/src/layer_transformation.cpp +++ b/src/common/low_precision_transformations/src/layer_transformation.cpp @@ -387,6 +387,17 @@ std::shared_ptr LayerTransformation::moveDequantizationAfter( return result.newOperation; } +std::shared_ptr LayerTransformation::moveDequantizationBefore( + TransformationContext& context, + const std::shared_ptr& operation, + const FakeQuantizeDequantization& dequantization, + const bool updatePrecision, + const bool moveSubtract) const { + const auto result = ngraph::pass::low_precision::NetworkHelper::moveDequantizationBefore(operation, dequantization, updatePrecision, moveSubtract); + updateOutput(context, result.newOperation, result.lastDequantization); + return result.newOperation; +} + void LayerTransformation::updateOutput( TransformationContext &context, std::shared_ptr lastNode, diff --git a/src/common/low_precision_transformations/src/move_fake_quantize.cpp b/src/common/low_precision_transformations/src/move_fake_quantize.cpp index bc28f3737f2..caaaf14204b 100644 --- a/src/common/low_precision_transformations/src/move_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/move_fake_quantize.cpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include #include @@ -39,26 +39,12 @@ MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(p output_low, output_high }); - ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) { + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { auto op = m.get_match_root(); if (transformation_callback(op)) { return false; } - // workaround: only per-tensor quantization is allowed - const auto& pattern_map = m.get_pattern_value_map(); - const auto is_scalar = [&](const std::shared_ptr& wrapped_constant) { - return NetworkHelper::isScalarLike( - as_type_ptr(pattern_map.at(wrapped_constant).get_node_shared_ptr())); - }; - - if (!is_scalar(input_low) || - !is_scalar(input_high) || - !is_scalar(output_low) || - !is_scalar(output_high)) { - return false; - } - return transform(*context, m); }; @@ -70,49 +56,111 @@ MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(p bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { auto fq = m.get_match_root(); + if (!canBeTransformed(context, fq)) { + return false; + } auto operation = fq->get_input_node_shared_ptr(0); std::shared_ptr concat; - bool only_concat = true; - std::string fq_original_name = fq->get_friendly_name(), operation_original_name; + bool without_operation = true; + std::string fq_original_name = fq->get_friendly_name(), + operation_original_name, + convert_q_original_name; if (is_type(operation)) { concat = operation; } else { operation_original_name = operation->get_friendly_name(); concat = operation->get_input_node_shared_ptr(0); - only_concat = false; + without_operation = false; } if (!ConcatTransformation::isQuantizedStatic(concat)) { return false; } - std::vector> fqs; - size_t input_size = concat->get_input_size(); - for (size_t i{ 0 }; i < input_size; ++i) { + auto convert_q = (*fq->output(0).get_target_inputs().begin()).get_node()->shared_from_this(); + bool q_dq = is_type(convert_q); + std::vector> currConstants(4); + bool multi_chanels = false; + const auto number_of_concat_inputs = concat->get_input_size(); + const auto concatNode = as_type_ptr(concat); + const auto concat_axis = concatNode->get_concatenation_axis(); + for (size_t i = 0; i < 4; i++) { + currConstants[i] = as_type_ptr(fq->get_input_node_shared_ptr(i + 1)); + if (!multi_chanels && currConstants[i]->get_shape().size() > 1 && currConstants[i]->get_shape()[concat_axis] != 1) { + multi_chanels = true; + } + } + std::vector>> newConstants; + if (multi_chanels) { + newConstants = NetworkHelper::split_consts_before_concat(concat, currConstants); + } + std::vector> newNodes; + for (size_t i{ 0 }; i < number_of_concat_inputs; ++i) { std::shared_ptr fq_input; - if (only_concat) { + if (without_operation) { fq_input = concat->get_input_node_shared_ptr(i); } else { auto input = concat->get_input_node_shared_ptr(i); fq_input = operation->clone_with_new_inputs({ input }); fq_input->set_friendly_name(operation_original_name + "_" + std::to_string(i + 1)); } - auto newFq = fq->clone_with_new_inputs({ fq_input, - fq->get_input_node_shared_ptr(1)->clone_with_new_inputs({}), - fq->get_input_node_shared_ptr(2)->clone_with_new_inputs({}), - fq->get_input_node_shared_ptr(3)->clone_with_new_inputs({}), - fq->get_input_node_shared_ptr(4)->clone_with_new_inputs({}) }); + std::shared_ptr newFq; + if (multi_chanels) { + newFq = fq->clone_with_new_inputs({ fq_input, + newConstants[0][newConstants[0].size() == 1 ? 0 : i], + newConstants[1][newConstants[1].size() == 1 ? 0 : i], + newConstants[2][newConstants[2].size() == 1 ? 0 : i], + newConstants[3][newConstants[3].size() == 1 ? 0 : i] }); + } else { + newFq = fq->clone_with_new_inputs({ fq_input, + fq->get_input_node_ptr(1)->clone_with_new_inputs({}), + fq->get_input_node_ptr(2)->clone_with_new_inputs({}), + fq->get_input_node_ptr(3)->clone_with_new_inputs({}), + fq->get_input_node_ptr(4)->clone_with_new_inputs({}) }); + } + ngraph::copy_runtime_info(fq, newFq); newFq->set_friendly_name(fq_original_name + "_" + std::to_string(i + 1)); - fqs.push_back(newFq); + if (q_dq) { + auto newConvert_q = convert_q->clone_with_new_inputs({ newFq }); + ngraph::copy_runtime_info(convert_q, newConvert_q); + newConvert_q->set_friendly_name(convert_q->get_friendly_name() + "_" + std::to_string(i + 1)); + newNodes.push_back(newConvert_q); + } else { + newNodes.push_back(newFq); + } } - ngraph::copy_runtime_info(fq, fqs); - auto newConcat = concat->clone_with_new_inputs(ngraph::OutputVector(fqs.begin(), fqs.end())); + auto newConcat = concat->clone_with_new_inputs(ngraph::OutputVector(newNodes.begin(), newNodes.end())); + newConcat->set_friendly_name(concat->get_friendly_name()); - replace_node(fq, newConcat); NetworkHelper::copyInfo(concat, newConcat); + if (q_dq) { + auto dq = NetworkHelper::getDequantizationBelow(convert_q); + moveDequantizationBefore(context, newConcat, dq, false); + return true; + } + replace_node(fq, newConcat); updateOutput(context, newConcat, fq); return true; } -bool MoveFakeQuantize::isPrecisionPreserved(std::shared_ptr layer) const noexcept { +bool MoveFakeQuantize::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { + auto operation = layer->get_input_node_shared_ptr(0); + std::shared_ptr concat; + if (is_type(operation)) { + concat = operation; + } else { + concat = operation->get_input_node_shared_ptr(0); + } + if (!ConcatTransformation::isQuantizedStatic(concat)) { + return false; + } + auto convert_q = (*layer->output(0).get_target_inputs().begin()).get_node()->shared_from_this(); + bool q_dq = is_type(convert_q); + if (q_dq && (convert_q->get_output_size() != 1 || layer->get_output_size() != 1)) { + return false; + } + return true; +} + +bool MoveFakeQuantize::isPrecisionPreserved(std::shared_ptr) const noexcept { return true; } diff --git a/src/common/low_precision_transformations/src/network_helper.cpp b/src/common/low_precision_transformations/src/network_helper.cpp index 492de2f0e47..a74be7f6c48 100644 --- a/src/common/low_precision_transformations/src/network_helper.cpp +++ b/src/common/low_precision_transformations/src/network_helper.cpp @@ -1662,6 +1662,144 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter return InsertDequantizationResult(newOperation, parent); } +NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefore( + const std::shared_ptr& operation, + const FakeQuantizeDequantization& dequantization, + const bool updatePrecision, + const bool moveSubtract) { + assert( + (NetworkHelper::getDequantizationBelow(operation).subtractConstant == nullptr) || + (NetworkHelper::getDequantizationBelow(operation).subtractConstant.get() == dequantization.subtractConstant.get())); + + assert( + (NetworkHelper::getDequantizationBelow(operation).multiplyConstant == nullptr) || + (NetworkHelper::getDequantizationBelow(operation).multiplyConstant.get() == dequantization.multiplyConstant.get())); + std::vector>> multiplyConstants, subtractConstants; + if (is_type(operation)) { + const auto concatNode = as_type_ptr(operation); + auto axis = concatNode->get_concatenation_axis(); + if (dequantization.multiply && dequantization.multiplyConstant->get_shape().size() > 1 && dequantization.multiplyConstant->get_shape()[axis] != 1) { + multiplyConstants = NetworkHelper::split_consts_before_concat(operation, { dequantization.multiplyConstant }); + } + if (dequantization.subtract && dequantization.subtractConstant->get_shape().size() > 1 && dequantization.subtractConstant->get_shape()[axis] != 1) { + subtractConstants = NetworkHelper::split_consts_before_concat(operation, { dequantization.subtractConstant }); + } + } + std::vector> newNodes; + for (size_t i = 0; i < operation->get_input_size(); ++i) { + auto parent = operation->get_input_node_shared_ptr(i); + const element::Type deqPrecision = dequantization.multiplyConstant->get_element_type(); + const bool shouldConvert = (operation->get_output_element_type(0) != deqPrecision); + if (shouldConvert) { + const auto convertOutputPrecision = dequantization.convert != nullptr ? + dequantization.convert->get_output_element_type(0) : + deqPrecision; + parent = std::make_shared(parent, convertOutputPrecision); + parent->set_friendly_name(dequantization.convert->get_friendly_name() + "_" + std::to_string(i + 1)); + ngraph::copy_runtime_info(dequantization.convert, parent); + } + if (moveSubtract && (dequantization.subtract != nullptr)) { + if (dequantization.subtractConvert == nullptr) { + const element::Type parentPrecision = parent->get_output_element_type(0); + if (parentPrecision.bitwidth() < dequantization.subtractConstant->get_element_type().bitwidth()) { + THROW_IE_LPT_EXCEPTION(*parent) << + "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision << + ", subtract dequantization constant " << dequantization.subtractConstant->get_friendly_name() << ":" << + dequantization.subtractConstant->get_element_type(); + } + auto subtractConstant = subtractConstants.size() ? subtractConstants[0][i] : dequantization.subtractConstant; + parent = std::make_shared>( + std::vector{element::f32, element::f32}, std::vector{ element::f32 }, + ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType( + subtractConstant->output(0).get_element_type() == parentPrecision ? + subtractConstant : + foldConvert(subtractConstant, parentPrecision), element::f32).get()); + parent->set_friendly_name(dequantization.subtract->get_friendly_name() + "_" + std::to_string(i + 1)); + } else { + parent = std::make_shared(parent, dequantization.subtractConvert); + } + ngraph::copy_runtime_info(dequantization.subtract, parent); + } + + if (dequantization.multiply != nullptr) { + auto multiplyConstant = multiplyConstants.size() ? multiplyConstants[0][i] : dequantization.multiplyConstant; + const element::Type parentPrecision = parent->get_output_element_type(0); + if (parentPrecision.bitwidth() < multiplyConstant->get_element_type().bitwidth()) { + THROW_IE_LPT_EXCEPTION(*parent) << + "unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision << + ", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->get_element_type(); + } + + parent = std::make_shared>( + opset1::Multiply(parent, + multiplyConstant->output(0).get_element_type() == parentPrecision ? + multiplyConstant : + foldConvert(multiplyConstant->output(0), parentPrecision)), + dequantization.multiply->get_output_element_type(0)); + ngraph::copy_runtime_info(dequantization.multiply, parent); + parent->set_friendly_name(dequantization.multiply->get_friendly_name() + "_" + std::to_string(i + 1)); + } + if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) { + // issue #43088 + // NetworkHelper::optimizeElementwise(dequantization.subtract); + } + newNodes.push_back(parent); + } + auto newOperation = operation->clone_with_new_inputs(ngraph::OutputVector(newNodes.begin(), newNodes.end())); + NetworkHelper::copyInfo(operation, newOperation); + replace_node(dequantization.multiply, newOperation); + + auto op = std::dynamic_pointer_cast(newOperation); + if (op != nullptr) { + if (updatePrecision) { + op->set_overridden_output_type(newOperation->get_input_element_type(0)); + } else if (dequantization.multiply) { + op->set_overridden_output_type(dequantization.multiplyConstant->get_element_type()); + } else if (dequantization.subtract) { + op->set_overridden_output_type(dequantization.subtractConstant->get_element_type()); + } + std::dynamic_pointer_cast(newOperation)->validate_and_infer_types(); + } + return InsertDequantizationResult(newOperation, dequantization.multiply); +} + +std::vector>> NetworkHelper::split_consts_before_concat(const std::shared_ptr concat, + const std::vector> currConstants) { + std::vector>> newConstants(currConstants.size()); + auto number_of_concat_inputs = concat->get_input_size(); + const auto concatNode = as_type_ptr(concat); + const auto concat_axis = concatNode->get_concatenation_axis(); + std::vector shape_axis(number_of_concat_inputs); + for (size_t i{ 0 }; i < number_of_concat_inputs; ++i) { + auto shape = concat->get_input_shape(i); + shape_axis[i] = shape[concat_axis]; + } + for (size_t i = 0; i < currConstants.size(); ++i) { + std::vector> newConstant; + if (currConstants[i]->output(0).get_shape()[concat_axis] == 1) { + newConstant.push_back(currConstants[i]); + newConstants[i] = newConstant; + continue; + } + auto split = std::make_shared(currConstants[i], + opset1::Constant::create(element::i64, Shape{}, { concat_axis }), + opset1::Constant::create(element::i64, Shape{ number_of_concat_inputs }, shape_axis)); + OutputVector outputResults(split->get_output_size()); + auto foldResult = split->constant_fold(outputResults, split->input_values()); + if (!foldResult) { + // handle potential constant fold issue here + } + for (auto outputResult : outputResults) { + auto constant = as_type_ptr(outputResult.get_node_shared_ptr()); + newConstant.push_back(constant); + } + + newConstants[i] = newConstant; + } + return newConstants; +} + bool NetworkHelper::checkConstantValuePrecision(const element::Type expectedPrecision, const std::shared_ptr& constant) { if (expectedPrecision.is_signed()) { return true; diff --git a/src/tests/functional/inference_engine/lp_transformations/move_fake_quantize_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/move_fake_quantize_transformation.cpp index 6f3e9806f21..ab10755c402 100644 --- a/src/tests/functional/inference_engine/lp_transformations/move_fake_quantize_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/move_fake_quantize_transformation.cpp @@ -37,12 +37,10 @@ namespace { class MoveFakeQuantizeTransformationActualValues { public: - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore1; - ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore1; - ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1; - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore2; - ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore2; - ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2; + size_t number_of_operations; + std::vector fakeQuantizeBefore; + ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore; std::string operation; ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeAfter; ngraph::builder::subgraph::DequantizationOperations::Convert convertAfter; @@ -51,12 +49,9 @@ public: inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeTransformationActualValues& values) { return out << "_" << - values.fakeQuantizeBefore1 << "_" << - values.convertBefore1.outPrecision << "_" << - values.dequantizationBefore1 << "_" << - values.fakeQuantizeBefore2 << "_" << - values.convertBefore2.outPrecision << "_" << - values.dequantizationBefore2 << "_" << + values.number_of_operations << "_" << + values.convertBefore.outPrecision << "_" << + values.dequantizationBefore << "_" << values.operation << "_" << values.fakeQuantizeAfter << "_" << values.convertAfter.outPrecision << "_" << @@ -65,33 +60,25 @@ inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeTransfo class MoveFakeQuantizeTransformationResultValues { public: - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore1; - ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore1; - ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1; - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore2; - ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore2; - ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2; + size_t number_of_operations; + std::vector fakeQuantizeBefore; + ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore; std::string operation; ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeAfter; ngraph::builder::subgraph::DequantizationOperations::Convert convertAfter; ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; ngraph::element::Type precisionAfterOperation; - ngraph::builder::subgraph::DequantizationOperations dequantizationAfterNotFQ; }; inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeTransformationResultValues& values) { return out << "_" << - values.fakeQuantizeBefore1 << "_" << - values.convertBefore1.outPrecision << "_" << - values.dequantizationBefore1 << "_" << - values.fakeQuantizeBefore2 << "_" << - values.convertBefore2.outPrecision << "_" << - values.dequantizationBefore2 << "_" << + values.convertBefore.outPrecision << "_" << + values.dequantizationBefore << "_" << values.operation << "_" << values.fakeQuantizeAfter << "_" << values.convertAfter << "_" << - values.dequantizationAfter << "_" << - values.dequantizationAfterNotFQ; + values.dequantizationAfter; } class MoveFakeQuantizeTransformationTestValues { @@ -126,7 +113,7 @@ inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeTransfo typedef std::tuple < ngraph::element::Type, - ngraph::PartialShape, + std::vector, MoveFakeQuantizeTransformationTestValues > MoveFakeQuantizeTransformationParams; @@ -134,16 +121,13 @@ class MoveFakeQuantizeTransformation : public LayerTransformation, public testin public: void SetUp() override { const ngraph::element::Type precision = std::get<0>(GetParam()); - const ngraph::PartialShape shape = std::get<1>(GetParam()); + const std::vector shape = std::get<1>(GetParam()); + //const auto shape = std::get<1>(GetParam()); MoveFakeQuantizeTransformationTestValues testValues = std::get<2>(GetParam()); - // dequantization output precision depends on input precision // to avoid huge amount of tests cases let's define dequantization output precision as input precision - if (!testValues.actual.dequantizationBefore1.multiply.empty()) { - testValues.actual.dequantizationBefore1.multiply.outPrecision = precision; - } - if (!testValues.actual.dequantizationBefore2.multiply.empty()) { - testValues.actual.dequantizationBefore2.multiply.outPrecision = precision; + if (!testValues.actual.dequantizationBefore.multiply.empty()) { + testValues.actual.dequantizationBefore.multiply.outPrecision = precision; } IntervalsAlignmentSharedValue::Interval interval{ -1.28f, 2.55f }; @@ -151,12 +135,10 @@ public: actualFunction = ngraph::builder::subgraph::MoveFakeQuantize::get( precision, shape, - testValues.actual.fakeQuantizeBefore1, - testValues.actual.convertBefore1, - testValues.actual.dequantizationBefore1, - testValues.actual.fakeQuantizeBefore2, - testValues.actual.convertBefore2, - testValues.actual.dequantizationBefore2, + testValues.actual.number_of_operations, + testValues.actual.fakeQuantizeBefore, + testValues.actual.convertBefore, + testValues.actual.dequantizationBefore, testValues.actual.operation, testValues.actual.fakeQuantizeAfter, testValues.actual.convertAfter, @@ -167,7 +149,6 @@ public: QuantizationAlignmentAttribute(false) }, ngraph::element::undefined, - {}, testValues.axis); auto supportedPrecisionsOnActivation = std::vector({ ngraph::pass::low_precision::OperationPrecisionRestriction::create({{0, testValues.params.precisionsOnActivations}}) @@ -183,6 +164,7 @@ public: ov::pass::Manager manager; manager.register_pass(params); manager.run_passes(actualFunction); + // dequantization output precision depends on input precision // to avoid huge amount of tests cases let's define dequantization output precision as input precision if (!testValues.result.dequantizationAfter.multiply.empty()) { @@ -198,12 +180,10 @@ public: referenceFunction = ngraph::builder::subgraph::MoveFakeQuantize::get( precision, shape, - testValues.result.fakeQuantizeBefore1, - testValues.result.convertBefore1, - testValues.result.dequantizationBefore1, - testValues.result.fakeQuantizeBefore2, - testValues.result.convertBefore2, - testValues.result.dequantizationBefore2, + testValues.result.number_of_operations, + testValues.result.fakeQuantizeBefore, + testValues.result.convertBefore, + testValues.result.dequantizationBefore, testValues.result.operation, testValues.result.fakeQuantizeAfter, testValues.result.convertAfter, @@ -214,18 +194,16 @@ public: QuantizationAlignmentAttribute(false) }, testValues.result.precisionAfterOperation, - {}, testValues.axis); } - static std::string getTestCaseName(testing::TestParamInfo obj) { const ngraph::element::Type precision = std::get<0>(obj.param); - const ngraph::PartialShape shape = std::get<1>(obj.param); + const std::vector shape = std::get<1>(obj.param); const MoveFakeQuantizeTransformationTestValues testValues = std::get<2>(obj.param); std::ostringstream result; result << - LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" << + LayerTransformation::getTestCaseNameByParams(precision, shape[0], testValues.params) << "_" << (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") << "axis_" << testValues.axis << "_" << testValues.actual << "_" << @@ -236,7 +214,7 @@ public: TEST_P(MoveFakeQuantizeTransformation, CompareFunctions) { actualFunction->validate_nodes_and_infer_types(); - auto res = compare_functions(referenceFunction, actualFunction, true, true, true, true, true); + auto res = compare_functions(referenceFunction, actualFunction, true, true, true, true, false); ASSERT_TRUE(res.first) << res.second; ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique"; @@ -252,21 +230,19 @@ const std::vector precisions = { }; namespace testValues1 { -const std::vector shapes = { - { 1, 3, 9, 9 }, - { 4, 3, 9, 9 }, - { Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic() } +const std::vector> shapes = { + {{ 1, 3, 9, 9 }}, + {{ 4, 3, 9, 9 }}, + {{ Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic() }} }; const std::vector testValues = { - // U8: concat + // without operation { LayerTransformation::createParamsU8I8(), false, 1, { - {}, - {}, - {}, + 2, {}, {}, {}, @@ -276,28 +252,23 @@ const std::vector testValues = { {} }, { - { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {}, - { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + 2, + {{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}}, {}, {}, "", {}, {}, {}, - }, - false, - false + } }, + // with ReLU { LayerTransformation::createParamsU8I8(), false, 1, { - {}, - {}, - {}, + 2, {}, {}, {}, @@ -307,28 +278,23 @@ const std::vector testValues = { {} }, { - { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {}, - { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + 2, + {{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}}, {}, {}, "relu", {}, {}, {}, - }, - false, - false + } }, + // negative test { LayerTransformation::createParamsU8I8(), false, 0, { - {}, - {}, - {}, + 2, {}, {}, {}, @@ -338,9 +304,7 @@ const std::vector testValues = { {} }, { - {}, - {}, - {}, + 2, {}, {}, {}, @@ -348,9 +312,109 @@ const std::vector testValues = { { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, {}, {} - }, + } + }, + // Q/DQ + { + LayerTransformation::createParamsU8I8(), false, - false + 1, + { + 2, + {}, + {}, + {}, + "", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + { ngraph::element::u8 }, + { + { element::f32 }, + {}, + { 0.01f } + }, + }, + { + 2, + {{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}}, + { ngraph::element::u8 }, + { + { element::f32 }, + {}, + { 0.01f } + }, + "", + {}, + {}, + {}, + } + }, + // Q/DQ with ReLU + { + LayerTransformation::createParamsU8I8(), + false, + 1, + { + 2, + {}, + {}, + {}, + "relu", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + { ngraph::element::u8 }, + { + { element::f32 }, + {}, + { 0.01f } + }, + }, + { + 2, + {{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}}, + { ngraph::element::u8 }, + { + { element::f32 }, + {}, + { 0.01f } + }, + "relu", + {}, + {}, + {}, + } + }, + // Q/DQ with subtract + { + LayerTransformation::createParamsU8I8(), + false, + 1, + { + 2, + {}, + {}, + {}, + "", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + { ngraph::element::u8 }, + { + { element::f32 }, + { 0.01f }, + { 0.01f } + }, + }, + { + 2, + {{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}}, + { ngraph::element::u8 }, + { + { element::f32 }, + { 0.01f }, + { 0.01f } + }, + "", + {}, + {}, + {}, + } }, }; @@ -363,4 +427,66 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(testValues)), MoveFakeQuantizeTransformation::getTestCaseName); } // namespace testValues1 +namespace testValues2 { +const std::vector precisions = { +ngraph::element::f32, +ngraph::element::f16 +}; + +const std::vector> shapes = { + {{ 1, 1, 224, 224 }, { 1, 2, 224, 224 }}, + {{ 4, 1, 9, 9 }, { 4, 2, 9, 9 }} +}; +const std::vector testValues = { + // multi-chanels + { + LayerTransformation::createParamsU8I8(), + true, + 1, + { + 2, + {}, + {}, + {}, + "", + { + 256ul, + {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}}, + {-2.66068696975708f}, {2.6399004459381104f}, + {-31.695816040039062f, -35.69844055175781f, -49.126914978027344f}, + {277.8320007324219f, 267.07110595703125f, 254.99429321289062f} + }, + {}, + {} + }, + { + 2, + { + {256ul, + {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, + {-2.66068696975708f}, {2.6399004459381104f}, {-31.695816040039062f}, {277.8320007324219f}}, + {256ul, + {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 2, 1, 1}, {1, 2, 1, 1}}, + {-2.66068696975708f}, {2.6399004459381104f}, + {-35.69844055175781f, -49.126914978027344f}, + {267.07110595703125f, 254.99429321289062f}} + }, + {}, + {}, + "", + {}, + {}, + {}, + } + }, +}; +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MoveFakeQuantizeTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::ValuesIn(shapes), + ::testing::ValuesIn(testValues)), + MoveFakeQuantizeTransformation::getTestCaseName); +} // namespace testValues2 } // namespace diff --git a/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp b/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp index 09ac0e229a1..2dc602fa656 100644 --- a/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp +++ b/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp @@ -9,7 +9,6 @@ using namespace LayerTestsDefinitions; -namespace { const std::vector netPrecisions = { ngraph::element::f32, //ngraph::element::f16 @@ -19,12 +18,12 @@ const std::vector tras LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(true) }; +namespace testValues1 { + const std::vector params = { - // without operation - { - {}, - {}, - {}, + // without operation + { + 3, {}, {}, {}, @@ -38,9 +37,7 @@ const std::vector pa }, // with ReLU operation { - {}, - {}, - {}, + 3, {}, {}, {}, @@ -52,27 +49,117 @@ const std::vector pa "U8", 1 }, - // negative axis + // Q/DQ { - {}, - {}, - {}, + 3, {}, {}, {}, "", - {256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}}, + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + {}, + { 0.01f } + }, + "Concatenation", + "U8", + 1 + }, + // Q/DQ with ReLU + { + 3, + {}, + {}, + {}, + "relu", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + {}, + { 0.01f } + }, + "Concatenation", + "U8", + 1 + }, + // multi-chanels + { + 3, + {}, + {}, + {}, + "relu", + { + 256ul, + {{1, 6, 1, 1}, {1, 6, 1, 1}, {1, 6, 1, 1}, {1, 6, 1, 1}}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + {2.55f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f, 2.55f / 5.f, 2.55f / 6.f}, + {-128.f, -128.f, -128.f, -128.f, -128.f, -128.f}, + {127.f, 127.f, 127.f, 127.f, 127.f, 127.f} + }, {}, {}, "Concatenation", - "FP32", - 0 - } + "I8", + 1 + }, + // Q/DQ with multi-channels multiply + { + 3, + {}, + {}, + {}, + "", + { + 256ul, + {{1, 6, 1, 1}, {1, 6, 1, 1}, {1, 6, 1, 1}, {1, 6, 1, 1}}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + {2.55f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f, 2.55f / 5.f, 2.55f / 6.f}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + {255.f, 255.f / 2.f, 255.f / 3.f, 255.f / 4.f, 255.f / 5.f, 255.f / 6.f}, + }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + {}, + { {0.01f, 0.02f, 0.03f, 0.04f, 0.05f, 0.06f}, ngraph::element::f32, {1, 6, 1, 1} }, + }, + "Concatenation", + "U8", + 1 + }, + // Q/DQ with multi-channels subtract + { + 3, + {}, + {}, + {}, + "", + { + 256ul, + {{1, 6, 1, 1}, {1, 6, 1, 1}, {1, 6, 1, 1}, {1, 6, 1, 1}}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + {2.55f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f, 2.55f / 5.f, 2.55f / 6.f}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + {255.f, 255.f / 2.f, 255.f / 3.f, 255.f / 4.f, 255.f / 5.f, 255.f / 6.f}, + }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + { {-127.f, -127.f / 2.f, -127.f / 3.f, -127.f / 4.f, -127.f / 5.f, -127.f / 6.f}, ngraph::element::f32, {1, 6, 1, 1} }, + { 0.01f }, + }, + "Concatenation", + "U8", + 1 + }, }; -const std::vector shapes = { - { 1, 3, 16, 16 }, - { 4, 3, 16, 16 } +const std::vector> shapes = { + {{ 1, 1, 16, 16 }, { 1, 2, 16, 16 }, { 1, 3, 16, 16 }} }; INSTANTIATE_TEST_SUITE_P(smoke_LPT, MoveFakeQuantizeTransformation, @@ -83,4 +170,36 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, MoveFakeQuantizeTransformation, ::testing::ValuesIn(trasformationParamValues), ::testing::ValuesIn(params)), MoveFakeQuantizeTransformation::getTestCaseName); -} // namespace +} // namespace testValues1 + +namespace testValues2 { + + const std::vector params = { + // negative axis + { + 3, + {}, + {}, + {}, + "", + {256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}}, + {}, + {}, + "Concatenation", + "FP32", + -1 + }, + }; + const std::vector> shapes = { + {{ 1, 1, 16, 16 }} + }; + + INSTANTIATE_TEST_SUITE_P(smoke_LPT, MoveFakeQuantizeTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(shapes), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), + MoveFakeQuantizeTransformation::getTestCaseName); +} // namespace testValues2 diff --git a/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp index 86b44f3b248..da08a05e15d 100644 --- a/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp +++ b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp @@ -15,72 +15,171 @@ const std::vector netPrecisions = { ngraph::element::f16 }; - const std::vector trasformationParamValues = { - LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams(), - }; +const std::vector trasformationParamValues = { + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams(), +}; - const std::vector params = { +const std::vector params = { // without operation { - {}, - {}, - {}, + 2, + {}, + {}, + {}, + "", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {}, + {}, + "Concat", + "U8", + 1, + }, + // with ReLU operation + { + 2, + {}, + {}, + {}, + "relu", + { 256ul, {}, { -12.7f }, { 12.7f }, { -12.7f }, { 12.7f }}, + {}, + {}, + "Concat", + "U8", + 1 + }, + // negative axis + { + 2, + {}, + {}, + {}, + "", + {256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}}, + {}, + {}, + "Concat", + "FP32", + 0 + }, + // Q/DQ + { + 2, + {}, + {}, + {}, + "", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + {}, + { 0.01f } + }, + "Concat", + "U8", + 1 + }, + // Q/DQ with ReLU + { + 2, + {}, + {}, + {}, + "relu", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + {}, + { 0.01f } + }, + "Concat", + "U8", + 1 + }, + // multi chanel + { + 3, + {}, + {}, + {}, + "relu", + { 256ul, + {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}}, + {-2.66068696975708f}, {2.6399004459381104f}, + {-31.695816040039062f, -35.69844055175781f, -49.126914978027344f}, + {277.8320007324219f, 267.07110595703125f, 254.99429321289062f} + }, + {}, + {}, + "Concat", + "U8", + 1 + }, + // Q/DQ with multi-channels + { + 3, + {}, + {}, + {}, + "", + { + 256ul, + {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}}, + {0.f, 0.f, 0.f}, + {2.55f, 2.55f, 2.55f}, + {0.f, 0.f, 0.f}, + {255.f, 255.f, 255.f} + }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + {}, + { {0.01f, 0.01f, 0.01f}, ngraph::element::f32, {1, 3, 1, 1} } + }, + "Concat", + "U8", + 1 + }, + // Q/DQ with multi-channels subtruct + { + 3, {}, {}, {}, "", - { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {}, - "Concat", - "U8", - 1, - }, - // with ReLU operation - { - {}, - {}, - {}, - {}, - {}, - {}, - "relu", - { 256ul, {}, { -12.7f }, { 12.7f }, { -12.7f }, { 12.7f }}, - {}, - {}, + { + 256ul, + {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}}, + {0.f, 0.f, 0.f}, + {2.55f, 2.55f, 2.55f}, + {0.f, 0.f, 0.f}, + {255.f, 255.f, 255.f} + }, + { ngraph::element::u8 }, + { + { ngraph::element::f32 }, + { {0.01f, 0.01f, 0.01f}, ngraph::element::f32, {1, 3, 1, 1} }, + { 0.01f } + }, "Concat", "U8", 1 }, - // negative axis - { - {}, - {}, - {}, - {}, - {}, - {}, - "", - {256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}}, - {}, - {}, - "Concat", - "FP32", - 0 - } - }; +}; - const std::vector shapes = { - { 1, 3, 16, 16 }, - { 4, 3, 16, 16 } - }; +const std::vector> shapes = { + {{ 1, 1, 16, 16 }}, + {{ 4, 1, 16, 16 }} +}; - INSTANTIATE_TEST_SUITE_P(smoke_LPT, MoveFakeQuantizeTransformation, - ::testing::Combine( - ::testing::ValuesIn(netPrecisions), - ::testing::ValuesIn(shapes), - ::testing::Values(CommonTestUtils::DEVICE_GPU), - ::testing::ValuesIn(trasformationParamValues), - ::testing::ValuesIn(params)), - MoveFakeQuantizeTransformation::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_LPT, MoveFakeQuantizeTransformation, +::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(shapes), + ::testing::Values(CommonTestUtils::DEVICE_GPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), +MoveFakeQuantizeTransformation::getTestCaseName); } // namespace diff --git a/src/tests/functional/plugin/shared/include/low_precision_transformations/move_fake_quantize_transformation.hpp b/src/tests/functional/plugin/shared/include/low_precision_transformations/move_fake_quantize_transformation.hpp index e53eef8b048..f6bd4b827ac 100644 --- a/src/tests/functional/plugin/shared/include/low_precision_transformations/move_fake_quantize_transformation.hpp +++ b/src/tests/functional/plugin/shared/include/low_precision_transformations/move_fake_quantize_transformation.hpp @@ -19,12 +19,10 @@ namespace LayerTestsDefinitions { class MoveFakeQuantizeTransformationParam { public: - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore1; - ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore1; - ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1; - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore2; - ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore2; - ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2; + size_t number_of_operations; + std::vector fakeQuantizeBefore; + ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore; std::string operation; ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeAfter; ngraph::builder::subgraph::DequantizationOperations::Convert convertAfter; @@ -36,7 +34,7 @@ public: typedef std::tuple < ngraph::element::Type, - ngraph::Shape, + std::vector, std::string, ngraph::pass::low_precision::LayerTransformation::Params, MoveFakeQuantizeTransformationParam diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/move_fake_quantize_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/move_fake_quantize_transformation.cpp index d92181e34c2..fec0efdaa23 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/move_fake_quantize_transformation.cpp +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/move_fake_quantize_transformation.cpp @@ -20,21 +20,21 @@ namespace LayerTestsDefinitions { std::string MoveFakeQuantizeTransformation::getTestCaseName(testing::TestParamInfo obj) { ngraph::element::Type netPrecision; - ngraph::PartialShape inputShape; + std::vector inputShape; std::string targetDevice; ngraph::pass::low_precision::LayerTransformation::Params params; MoveFakeQuantizeTransformationParam param; std::tie(netPrecision, inputShape, targetDevice, params, param) = obj.param; std::ostringstream result; - result << getTestCaseNameByParams(netPrecision, inputShape, targetDevice, params) << - param.operation << param.fakeQuantizeAfter; + result << getTestCaseNameByParams(netPrecision, inputShape[0], targetDevice, params) << + param.operation << param.fakeQuantizeAfter << param.dequantizationAfter; return result.str(); } void MoveFakeQuantizeTransformation::SetUp() { ngraph::element::Type netPrecision; - ngraph::PartialShape inputShape; + std::vector inputShape; ngraph::pass::low_precision::LayerTransformation::Params params; MoveFakeQuantizeTransformationParam param; std::tie(netPrecision, inputShape, targetDevice, params, param) = this->GetParam(); @@ -42,19 +42,16 @@ void MoveFakeQuantizeTransformation::SetUp() { function = ngraph::builder::subgraph::MoveFakeQuantize::get( netPrecision, inputShape, - param.fakeQuantizeBefore1, - param.convertBefore1, - param.dequantizationBefore1, - param.fakeQuantizeBefore2, - param.convertBefore2, - param.dequantizationBefore2, + param.number_of_operations, + param.fakeQuantizeBefore, + param.convertBefore, + param.dequantizationBefore, param.operation, param.fakeQuantizeAfter, param.convertAfter, param.dequantizationAfter, {}, {}, - {}, param.axis); } diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp index bdd2f6fbc8c..3e543b8f372 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp @@ -19,20 +19,17 @@ class MoveFakeQuantize { public: static std::shared_ptr get( const ngraph::element::Type inputPrecision, - const ngraph::PartialShape& inputShape, - const FakeQuantizeOnDataWithConstant& fqOnData1, - const DequantizationOperations::Convert& convert1, - const DequantizationOperations& dequantization1, - const FakeQuantizeOnDataWithConstant& fqOnData2, - const DequantizationOperations::Convert& convert2, - const DequantizationOperations& dequantization2, + const std::vector& inputShape, + const size_t number_of_operations, + const std::vector& fqBefore, + const DequantizationOperations::Convert& convertBefore, + const DequantizationOperations& dequantizationBefore, const std::string& operation, - const FakeQuantizeOnDataWithConstant& fqOnData3, - const DequantizationOperations::Convert& convert3, - const DequantizationOperations& dequantization3, + const FakeQuantizeOnDataWithConstant& fqOnDataAfter, + const DequantizationOperations::Convert& convertAfter, + const DequantizationOperations& dequantizationAfter, const std::vector& concatAttributes, const ngraph::element::Type precisionAfterOperation, - const DequantizationOperations& dequantizationAfter, const std::int64_t& axis); }; diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp index 6c3ea3a09b9..31b3854c4c2 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp @@ -21,83 +21,77 @@ using namespace ngraph::pass; std::shared_ptr MoveFakeQuantize::get( const ngraph::element::Type inputPrecision, - const ngraph::PartialShape& inputShape, - const FakeQuantizeOnDataWithConstant& fqOnData1, - const DequantizationOperations::Convert& convert1, - const DequantizationOperations& dequantization1, - const FakeQuantizeOnDataWithConstant& fqOnData2, - const DequantizationOperations::Convert& convert2, - const DequantizationOperations& dequantization2, + const std::vector& inputShape, + const size_t number_of_operations, + const std::vector& fqOnDataBefore, + const DequantizationOperations::Convert& convertBefore, + const DequantizationOperations& dequantizationBefore, const std::string& operation, - const FakeQuantizeOnDataWithConstant& fqOnData3, - const DequantizationOperations::Convert& convert3, - const DequantizationOperations& dequantization3, + const FakeQuantizeOnDataWithConstant& fqOnDataAfter, + const DequantizationOperations::Convert& convertAfter, + const DequantizationOperations& dequantizationAfter, const std::vector& concatAttributes, const ngraph::element::Type precisionAfterOperation, - const DequantizationOperations& dequantizationAfter, const std::int64_t& axis) { - const auto input1 = std::make_shared(inputPrecision, inputShape); - input1->set_friendly_name("input1"); - - const auto input2 = std::make_shared(inputPrecision, inputShape); - input2->set_friendly_name("input2"); - std::shared_ptr parent1 = input1, parent2 = input2; - if (!fqOnData1.empty()) { - if (operation == "relu") { - auto relu1 = std::make_shared(input1->output(0)); - parent1 = makeFakeQuantize(relu1, inputPrecision, fqOnData1); - } else { - parent1 = makeFakeQuantize(input1, inputPrecision, fqOnData1); + std::vector > inputs(number_of_operations); + std::vector > parents(number_of_operations); + for (size_t i = 0; i < number_of_operations; i++) { + auto ind = 0; + if (inputShape.size() != 1) { + ind = i; } - parent1->set_friendly_name("concat_fq1"); - if (!convert1.empty()) { - parent1 = std::make_shared(parent1, convert1.outPrecision); - } - if (!dequantization1.empty()) { - parent1 = makeDequantization(parent1, dequantization1); + inputs[i] = std::make_shared(inputPrecision, inputShape[ind]); + inputs[i]->set_friendly_name(std::string("input") + "_" + std::to_string(i + 1)); + parents[i] = inputs[i]; + } + if (!fqOnDataBefore.empty()) { + for (size_t i = 0; i < number_of_operations; i++) { + size_t ind = i; + if (fqOnDataBefore.size() == 1) { + ind = 0; + } + if (operation == "relu") { + auto relu = std::make_shared(parents[i]->output(0)); + parents[i] = makeFakeQuantize(relu, inputPrecision, fqOnDataBefore[ind]); + } else { + parents[i] = makeFakeQuantize(parents[i], inputPrecision, fqOnDataBefore[ind]); + } + parents[i]->set_friendly_name(std::string("concat_fq") + "_" + std::to_string(i + 1)); + if (!convertBefore.empty()) { + parents[i] = std::make_shared(parents[i], convertBefore.outPrecision); + } + if (!dequantizationBefore.empty()) { + parents[i] = makeDequantization(parents[i], dequantizationBefore); + } } } - if (!fqOnData2.empty()) { - if (operation == "relu") { - auto relu2 = std::make_shared(input2->output(0)); - parent2 = makeFakeQuantize(relu2, inputPrecision, fqOnData2); - } else { - parent2 = makeFakeQuantize(input1, inputPrecision, fqOnData2); - } - parent2->set_friendly_name("concat_fq2"); - if (!convert2.empty()) { - parent1 = std::make_shared(parent2, convert2.outPrecision); - } - if (!dequantization1.empty()) { - parent2 = makeDequantization(parent2, dequantization2); - } - } - const std::shared_ptr concat = std::make_shared(ngraph::OutputVector{ parent1, parent2 }, axis); + const std::shared_ptr concat = std::make_shared(ngraph::OutputVector(parents.begin(), parents.end()), axis); concat->set_friendly_name("concat"); std::shared_ptr parent = concat; - if (!dequantizationAfter.empty()) { - const auto lastDequantization = makeDequantization(concat, dequantizationAfter); - lastDequantization->set_friendly_name("multiply"); - parent = lastDequantization; - } addAttributes({ parent }, concatAttributes); - if (!fqOnData3.empty()) { - std::shared_ptr fq; + if (!fqOnDataAfter.empty()) { + std::shared_ptr fq; if (operation == "relu") { auto relu = std::make_shared(concat->output(0)); - fq = makeFakeQuantize(relu, inputPrecision, fqOnData3); + fq = makeFakeQuantize(relu, inputPrecision, fqOnDataAfter); } else { - fq = makeFakeQuantize(concat, inputPrecision, fqOnData3); + fq = makeFakeQuantize(concat, inputPrecision, fqOnDataAfter); } fq->set_friendly_name("fakeQuantizeAfter"); parent = fq; + if (!convertAfter.empty()) { + parent = std::make_shared(parent, convertAfter.outPrecision); + } + if (!dequantizationAfter.empty()) { + parent = makeDequantization(parent, dequantizationAfter); + } } parent->set_friendly_name("output"); ngraph::ResultVector results{ std::make_shared(parent) }; std::shared_ptr function = std::make_shared( results, - ngraph::ParameterVector{ input1, input2 }, + ngraph::ParameterVector(inputs.begin(), inputs.end()), "MoveFakeQuantize"); return function; }