diff --git a/inference-engine/src/low_precision_transformations/src/concat.cpp b/inference-engine/src/low_precision_transformations/src/concat.cpp
index 4988e29b1e2..f6d860ed172 100644
--- a/inference-engine/src/low_precision_transformations/src/concat.cpp
+++ b/inference-engine/src/low_precision_transformations/src/concat.cpp
@@ -43,19 +43,21 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
         return false;
     }
 
-    // precisions can be different
+    // Concat operation precision is defined by:
+    // 1. consumers after Concat
+    // 2. FakeQuantize precisions without zero point
     ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0];
     std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this());
     if (!NetworkHelper::isQuantizeSupported(fq)) {
         return false;
     }
-
-    std::vector<element::Type> concatParentsChildrensPrecisions = precisionsOnActivations;
-    fillAvailablePrecisions(subgraph.quantizationLayers[0], concatParentsChildrensPrecisions);
-    if (concatParentsChildrensPrecisions.empty()) {
+    DataPrecision dataPrecision = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
+    if (dataPrecision.precision == ngraph::element::undefined) {
         return false;
     }
 
+    std::vector<element::Type> concatChildrenPrecisions = precisionsOnActivations;
+
     for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
         fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(subgraph.quantizationLayers[i]);
         if (fq == nullptr) {
@@ -72,20 +74,28 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
         if (quantizationDetails.inputHighValues.size() != 1ul) {
             return false;
         }
-        std::vector<element::Type> fqChildrensPrecisions = precisionsOnActivations;
-        fillAvailablePrecisions(subgraph.quantizationLayers[i], fqChildrensPrecisions);
-        concatParentsChildrensPrecisions = NetworkHelper::precisionIntersection(concatParentsChildrensPrecisions, fqChildrensPrecisions);
-        if (concatParentsChildrensPrecisions.empty()) {
+
+        // define concatenation operation consumer precisions
+        std::vector<element::Type> fqChildrenPrecisions = precisionsOnActivations;
+        fillAvailablePrecisions(subgraph.quantizationLayers[i], fqChildrenPrecisions);
+        concatChildrenPrecisions = NetworkHelper::precisionIntersection(concatChildrenPrecisions, fqChildrenPrecisions);
+        if (concatChildrenPrecisions.empty()) {
             return false;
         }
+
+        // define FakeQuantize precisions without zero point
+        const DataPrecision dataPrecision2 = getDataPrecision(subgraph.quantizationLayers[i]->shared_from_this(), quantizationDetails, false);
+        if (dataPrecision2.precision == ngraph::element::undefined) {
+            return false;
+        }
+
+        if (dataPrecision.precision != dataPrecision2.precision) {
+            dataPrecision = dataPrecision.precision.is_signed() ? dataPrecision : dataPrecision2;
+        }
     }
 
-    DataPrecision dataPrecision;
-    if (std::find(concatParentsChildrensPrecisions.begin(), concatParentsChildrensPrecisions.end(), element::i8) != concatParentsChildrensPrecisions.end()) {
-        dataPrecision = DataPrecision(element::i8);
-    } else {
-        dataPrecision = DataPrecision(concatParentsChildrensPrecisions[0]);
+    if (std::find(concatChildrenPrecisions.begin(), concatChildrenPrecisions.end(), dataPrecision.precision) == concatChildrenPrecisions.end()) {
+        dataPrecision = DataPrecision(concatChildrenPrecisions[0]);
     }
 
     std::vector<QuantizationDetails> quantizationLayersDetails;
diff --git a/inference-engine/src/low_precision_transformations/src/concat_multi_channels.cpp b/inference-engine/src/low_precision_transformations/src/concat_multi_channels.cpp
index dc81d51cd71..e36c2b5aa74 100644
--- a/inference-engine/src/low_precision_transformations/src/concat_multi_channels.cpp
+++ b/inference-engine/src/low_precision_transformations/src/concat_multi_channels.cpp
@@ -64,14 +64,23 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
 
     DataPrecision dataPrecision;
     {
+        std::vector<element::Type> concatChildrenPrecisions = precisionsOnActivations;
         for (auto quantizationLayer : subgraph.quantizationLayers) {
             std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer->shared_from_this());
             if (!NetworkHelper::isQuantizeSupported(fq)) {
                 return false;
             }
 
-            const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
+            // define concatenation operation consumer precisions
+            std::vector<element::Type> fqChildrenPrecisions = precisionsOnActivations;
+            fillAvailablePrecisions(quantizationLayer, fqChildrenPrecisions);
+            concatChildrenPrecisions = NetworkHelper::precisionIntersection(concatChildrenPrecisions, fqChildrenPrecisions);
+            if (concatChildrenPrecisions.empty()) {
+                return false;
+            }
 
+            // define FakeQuantize precisions without zero point
+            const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
             if (dataPrecision.precision == ngraph::element::undefined) {
                 dataPrecision = tmp;
                 continue;
@@ -81,6 +90,10 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
                 dataPrecision = tmp;
             }
         }
+
+        if (std::find(concatChildrenPrecisions.begin(), concatChildrenPrecisions.end(), dataPrecision.precision) == concatChildrenPrecisions.end()) {
+            dataPrecision = DataPrecision(concatChildrenPrecisions[0]);
+        }
     }
 
     for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_intermediate_precision_selection_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_intermediate_precision_selection_transformation.cpp
new file mode 100644
index 00000000000..0d6b29d5fe5
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_intermediate_precision_selection_transformation.cpp
@@ -0,0 +1,317 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "layer_transformation.hpp"
+
+#include <string>
+#include <sstream>
+#include <memory>
+
+#include <gtest/gtest.h>
+
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+#include <low_precision/transformer.hpp>
+#include <low_precision/avg_pool.hpp>
+#include <low_precision/concat.hpp>
+#include <low_precision/concat_multi_channels.hpp>
+#include <low_precision/max_pool.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+#include "lpt_ngraph_functions/concat_function.hpp"
+#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
+#include "simple_low_precision_transformer.hpp"
+
+using namespace testing;
+using namespace ngraph;
+using namespace ngraph::pass;
+
+namespace {
+
+class ConcatTransformationActualValues {
+public:
+    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
+    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
+    return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
+}
+
+class ConcatTransformationResultValues {
+public:
+    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
+    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
+    ngraph::element::Type precisionBeforeOp;
+    ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1;
+    ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2;
+    ngraph::element::Type precisionAfterOperation;
+    ngraph::builder::subgraph::DequantizationOperations dequantizationAfter1;
+    ngraph::builder::subgraph::DequantizationOperations dequantizationAfter2;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
+    return out << "_" <<
+        values.fakeQuantize1 << "_" <<
+        values.fakeQuantize2 << "_" <<
+        values.dequantizationAfter1 << "_" <<
+        values.dequantizationAfter2;
+}
+
+class ConcatTransformationTestValues {
+public:
+    ngraph::pass::low_precision::LayerTransformation::Params params;
+    bool multiChannels;
+    ConcatTransformationActualValues actual;
+    ConcatTransformationResultValues result;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
+    return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result;
+}
+
+typedef std::tuple <
+    ngraph::element::Type,
+    ngraph::Shape,
+    ConcatTransformationTestValues
+> ConcatTransformationParams;
+
+class ConcatWithIntermediatePrecisionSelectionTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
+public:
+    void SetUp() override {
+        const ngraph::element::Type precision = std::get<0>(GetParam());
+        const ngraph::Shape shape = std::get<1>(GetParam());
+        ConcatTransformationTestValues testValues = std::get<2>(GetParam());
+
+        actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithIntermediateAvgPool(
+            precision,
+            shape,
+            testValues.actual.fakeQuantize1,
+            testValues.actual.fakeQuantize2);
+
+        SimpleLowPrecisionTransformer transform;
+        if (testValues.multiChannels) {
+            transform.addBranchSpecific<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
+        } else {
+            transform.addBranchSpecific<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
+        }
+        transform.add<ngraph::pass::low_precision::AvgPoolTransformation, ngraph::opset1::AvgPool>(testValues.params);
+        transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
+        transform.transform(actualFunction);
+
+        referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithIntermediateAvgPool(
+            precision,
+            shape,
+            testValues.result.fakeQuantize1,
+            testValues.result.fakeQuantize2,
+            testValues.result.precisionBeforeOp,
+            testValues.result.dequantizationBefore1,
+            testValues.result.dequantizationBefore2,
+            testValues.result.precisionAfterOperation,
+            testValues.result.dequantizationAfter1,
+            testValues.result.dequantizationAfter2);
+    }
+
+    static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
+        const ngraph::element::Type precision = std::get<0>(obj.param);
+        const ngraph::Shape shape = std::get<1>(obj.param);
+        const ConcatTransformationTestValues testValues = std::get<2>(obj.param);
+
+        std::ostringstream result;
+        result <<
+            LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
+            (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
+            testValues.actual << "_" <<
+            testValues.result << "_";
+        return result.str();
+    }
+};
+
+TEST_P(ConcatWithIntermediatePrecisionSelectionTransformation, CompareFunctions) {
+    actualFunction->validate_nodes_and_infer_types();
+    auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+const std::vector<ngraph::element::Type> precisions = {
+    ngraph::element::f32,
+    // ngraph::element::f16
+};
+
+const std::vector<ConcatTransformationTestValues> testValues = {
+    // Concat: FakeQuantize operations with signed intervals but consumer requires U8
+    {
+        LayerTransformation::createParamsU8I8(),
+        false,
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {64.f}, {192.f} },
+            ngraph::element::u8,
+            {{}, {}, {}},
+            {{}, {}, {}},
+            ngraph::element::u8,
+            { ngraph::element::f32, { 128.f }, { 0.01f } },
+            { {}, { 128.f }, { 0.01f } }
+        }
+    },
+
+    // Concat: FakeQuantize operations with unsigned intervals but consumer requires I8
+    {
+        LayerTransformation::createParamsI8I8(),
+        false,
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-128.f}, {127.f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {-128.f}, { -0.f} },
+            ngraph::element::i8,
+            {{}, {}, {}},
+            {{}, {}, {}},
+            ngraph::element::i8,
+            { ngraph::element::f32, { -128.f }, { 0.01f } },
+            { {}, { -128.f }, { 0.01f } }
+        }
+    },
+
+    // ConcatMultichannel: FakeQuantize operations with signed intervals but consumer requires U8
+    {
+        LayerTransformation::createParamsU8I8(),
+        true,
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {0.f}, { 255.f} },
+            ngraph::element::u8,
+            {},
+            {},
+            ngraph::element::u8,
+            { ngraph::element::f32, { 128.f }, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
+            { {}, { 128.f }, { 0.005f } }
+        }
+    },
+
+    // ConcatMultichannel: FakeQuantize operations with unsigned intervals but consumer requires I8
+    {
+        LayerTransformation::createParamsI8I8(),
+        true,
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-128.f}, {127.f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {-128.f}, { 127.f} },
+            ngraph::element::i8,
+            {{}, {}, {}},
+            {{}, {}, {}},
+            ngraph::element::i8,
+            { ngraph::element::f32, { -128.f }, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
+            { {}, { -128.f }, { 0.005f } }
+        }
+    },
+
+    // Concat: FakeQuantize operations with unsigned intervals, no consumer limitations: FQ were decomposed to U8 precision
+    {
+        LayerTransformation::createParamsU8I8AndI8(),
+        false,
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, { 128.f} },
+            ngraph::element::u8,
+            {{}, {}, {}},
+            {{}, {}, {}},
+            ngraph::element::u8,
+            { ngraph::element::f32, {}, { 0.01f } },
+            { {}, {}, { 0.01f } }
+        }
+    },
+
+    // Concat: FakeQuantize operations with signed intervals, no consumer limitations: FQ were decomposed to I8 precision
+    {
+        LayerTransformation::createParamsU8I8AndI8(),
+        false,
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-64.f}, {64.f} },
+            ngraph::element::i8,
+            {{}, {}, {}},
+            {{}, {}, {}},
+            ngraph::element::i8,
+            { ngraph::element::f32, {}, { 0.01f } },
+            { {}, {}, { 0.01f } }
+        }
+    },
+
+    // ConcatMultichannel: FakeQuantize operations with unsigned intervals, no consumer limitations: FQ were decomposed to U8 precision
+    {
+        LayerTransformation::createParamsU8I8AndI8(),
+        true,
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
+            { 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f} },
+            ngraph::element::u8,
+            {{}, {}, {}},
+            {{}, {}, {}},
+            ngraph::element::u8,
+            { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
+            { {}, {}, { 0.005f } }
+        }
+    },
+
+    // ConcatMultichannel: FakeQuantize operations with signed intervals, no consumer limitations: FQ were decomposed to I8 precision
+    {
+        LayerTransformation::createParamsU8I8AndI8(),
+        true,
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
+        },
+        {
+            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f} },
+            { 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-128.f}, {127.f} },
+            ngraph::element::i8,
+            {{}, {}, {}},
+            {{}, {}, {}},
+            ngraph::element::i8,
+            { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
+            { {}, {}, { 0.005f } }
+        }
+    }
+};
+
+const std::vector<ngraph::Shape> shapes = {
+    { 1, 3, 9, 9 },
+    { 4, 3, 9, 9 }
+};
+
+INSTANTIATE_TEST_CASE_P(
+    smoke_LPT,
+    ConcatWithIntermediatePrecisionSelectionTransformation,
+    ::testing::Combine(
+        ::testing::ValuesIn(precisions),
+        ::testing::ValuesIn(shapes),
+        ::testing::ValuesIn(testValues)),
+    ConcatWithIntermediatePrecisionSelectionTransformation::getTestCaseName);
+}  // namespace
diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.cpp
index 8ee17c8e39b..3c48d56be5b 100644
--- a/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.cpp
+++ b/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.cpp
@@ -49,19 +49,41 @@ bool SimpleLowPrecisionTransformer::isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& layer) const noexcept {
 }
 
 void SimpleLowPrecisionTransformer::transform(std::shared_ptr<ngraph::Function>& function) {
+    // initialization
+    for (auto it : branchSpecificTransformations) {
+        ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
+        transformation->setParamsManager(this);
+        transformation->setLayerTransformationsManager(this);
+    }
+
+    for (auto it : transformations) {
+        ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
+        transformation->setParamsManager(this);
+        transformation->setLayerTransformationsManager(this);
+    }
+
+    // transformation
     {
         ngraph::pass::low_precision::TypeRelaxedReplacer pass;
         pass.run_on_function(function);
     }
 
     ngraph::pass::low_precision::TransformationContext context(function);
-    GraphRewrite pass;
-    for (auto it : transformations) {
-        ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
-
-        transformation->setParamsManager(this);
-        transformation->setLayerTransformationsManager(this);
-        transformation->registerMatcherIn(pass, context);
+    {
+        GraphRewrite pass;
+        for (auto it : branchSpecificTransformations) {
+            ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
+            transformation->registerMatcherIn(pass, context);
+        }
+        pass.run_on_function(function);
+    }
+
+    {
+        GraphRewrite pass;
+        for (auto it : transformations) {
+            ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
+            transformation->registerMatcherIn(pass, context);
+        }
+        pass.run_on_function(function);
     }
-    pass.run_on_function(function);
 }
diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.hpp b/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.hpp
index b4bf3a9c978..c9582adf0f0 100644
--- a/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.hpp
+++ b/inference-engine/tests/functional/inference_engine/lp_transformations/simple_low_precision_transformer.hpp
@@ -28,9 +28,22 @@ public:
     bool isQuantized(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
     bool isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
 
+    template <class T, class Operation>
+    ngraph::pass::low_precision::LayerTransformationPtr addBranchSpecific(const ngraph::pass::low_precision::LayerTransformation::Params& params) {
+        const std::string typeName = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>();
+
+        const auto it = branchSpecificTransformations.find(typeName);
+        if (it != branchSpecificTransformations.end()) {
+            branchSpecificTransformations.erase(it);
+        }
+
+        auto transformation = std::make_shared<T>(params);
+        branchSpecificTransformations.emplace(typeName, transformation);
+        return transformation;
+    }
+
     template <class T, class Operation>
     ngraph::pass::low_precision::LayerTransformationPtr add(const ngraph::pass::low_precision::LayerTransformation::Params& params) {
-        // const std::string typeName = typeid(ngraph::op::TypeRelaxed<T>).name();
         const std::string typeName = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>();
 
         const auto it = transformations.find(typeName);
@@ -46,5 +59,6 @@ public:
     void transform(std::shared_ptr<ngraph::Function>& function);
 
 private:
+    std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> branchSpecificTransformations;
     std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> transformations;
 };
diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp
index c0c1686ca55..f70f653efe2 100644
--- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp
+++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp
@@ -51,6 +51,12 @@ public:
         const FakeQuantizeOnData& fqOnData1,
         const FakeQuantizeOnData& fqOnData2);
 
+    static std::shared_ptr<ngraph::Function> getOriginalWithIntermediateAvgPool(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const FakeQuantizeOnData& fqOnData1,
+        const FakeQuantizeOnData& fqOnData2);
+
     static std::shared_ptr<ngraph::Function> getOriginalWithSplitedIntermediate(
         const ngraph::element::Type precision,
         const ngraph::Shape& inputShape,
@@ -134,6 +140,7 @@ public:
         const std::string& neighborType,
         const std::string& additionalLayer);
 
+    // TODO: refactor: dequantizationBefore2 <=> dequantizationOperations2
     static std::shared_ptr<ngraph::Function> getReferenceWithIntermediate(
         const ngraph::element::Type precision,
         const ngraph::Shape& inputShape,
@@ -142,6 +149,18 @@ public:
         const FakeQuantizeOnData& fqOnData2,
         const ngraph::element::Type precisionBeforeOp,
         const DequantizationOperations& dequantizationBefore1,
+        const DequantizationOperations& dequantizationOperations2,
+        const ngraph::element::Type precisionAfterOperation,
+        const DequantizationOperations& dequantizationOperations1,
+        const DequantizationOperations& dequantizationBefore2);
+
+    static std::shared_ptr<ngraph::Function> getReferenceWithIntermediateAvgPool(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const FakeQuantizeOnData& fqOnData1,
+        const FakeQuantizeOnData& fqOnData2,
+        const ngraph::element::Type precisionBeforeOp,
+        const DequantizationOperations& dequantizationBefore1,
         const DequantizationOperations& dequantizationBefore2,
         const ngraph::element::Type precisionAfterOperation,
         const DequantizationOperations& dequantizationOperations1,
diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp
index 15108abb73e..37387977eb7 100644
--- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp
+++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp
@@ -272,6 +272,58 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediate(
     return function;
 }
 
+std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediateAvgPool(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const FakeQuantizeOnData& fqOnData1,
+    const FakeQuantizeOnData& fqOnData2) {
+    const std::vector<size_t> inputShape1 = { inputShape[0], inputShape[1], inputShape[2] - 2, inputShape[3] - 2 };
+
+    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape1));
+    input1->set_friendly_name("input1");
+    const auto fakeQuantize1 = makeFakeQuantize(input1, precision, fqOnData1);
+    fakeQuantize1->set_friendly_name("fakeQuantize1");
+
+    const std::vector<size_t> inputShape2 = { inputShape[0], inputShape[1], inputShape[2], inputShape[3] };
+    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
+    input2->set_friendly_name("input2");
+
+    const auto fakeQuantize2 = makeFakeQuantize(input2, precision, fqOnData2);
+    fakeQuantize2->set_friendly_name("fakeQuantize2");
+
+    std::shared_ptr<ngraph::Node> intermediateOp = makeMaxPool(fakeQuantize2->output(0), { 3, 3 });
+    intermediateOp->set_friendly_name("intermediate");
+
+    const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
+        ngraph::OutputVector{ fakeQuantize1->output(0), intermediateOp->output(0) }, 1);
+    concat->set_friendly_name("concat");
+
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
+    std::shared_ptr<ngraph::Node> parent2 = std::make_shared<ngraph::opset1::AvgPool>(
+        intermediateOp,
+        Strides{ 1, 1 },
+        Shape{ 1, 1 },
+        Shape{ 0, 0 },
+        Shape{ 2, 2 },
+        true,
+        op::RoundingType::FLOOR);
+    parent2->set_friendly_name("avgPool");
+
+    ngraph::ResultVector results {
+        std::make_shared<ngraph::opset1::Result>(concat),
+        std::make_shared<ngraph::opset1::Result>(parent2)
+    };
+
+    std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
+        results,
+        ngraph::ParameterVector{ input1, input2 },
+        "ConcatWithIntermediateTransformation");
+
+    return function;
+}
+
 std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermediate(
     const ngraph::element::Type precision,
     const ngraph::Shape& inputShape,
@@ -1056,6 +1108,77 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediate(
     return function;
 }
 
+std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediateAvgPool(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const FakeQuantizeOnData& fqOnData1,
+    const FakeQuantizeOnData& fqOnData2,
+    const ngraph::element::Type precisionBeforeOp,
+    const DequantizationOperations& dequantizationBefore1,
+    const DequantizationOperations& dequantizationBefore2,
+    const ngraph::element::Type precisionAfterOperation,
+    const DequantizationOperations& dequantizationAfter1,
+    const DequantizationOperations& dequantizationAfter2) {
+    const std::vector<size_t> inputShape1 = { inputShape[0], inputShape[1], inputShape[2] - 2, inputShape[3] - 2};
+    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape1));
+    input1->set_friendly_name("input1");
+
+    const auto fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input1, precision, fqOnData1);
+    low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize1, precisionBeforeOp);
+    fakeQuantize1->set_friendly_name("fakeQuantize1");
+    const auto deqBefore1 = makeDequantization(fakeQuantize1, dequantizationBefore1);
+
+    const std::vector<size_t> inputShape2 = { inputShape[0], inputShape[1], inputShape[2], inputShape[3] };
+    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
+    input2->set_friendly_name("input2");
+
+    const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(input2, precision, fqOnData2);
+    low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize2, precisionBeforeOp);
+    fakeQuantize2->set_friendly_name("fakeQuantize2");
+    const auto deqBefore2 = makeDequantization(fakeQuantize2, dequantizationBefore2);
+
+    std::shared_ptr<ngraph::Node> intermediateOp = makeMaxPool(deqBefore2, { 3, 3 });
+    intermediateOp->set_friendly_name("intermediate");
+
+    const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
+        ngraph::OutputVector { deqBefore1, intermediateOp },
+        1);
+    concat->set_friendly_name("concat");
+    low_precision::NetworkHelper::setOutDataPrecision(concat, precisionAfterOperation);
+
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
+    const std::shared_ptr<ngraph::Node> parent1 = makeDequantization(concat, dequantizationAfter1);
+    parent1->set_friendly_name("concat");
+
+    std::shared_ptr<ngraph::Node> parent2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::AvgPool>>(
+        std::vector<element::Type>{ element::f32, element::f32 },
+        std::vector<element::Type>{ element::f32 },
+        ngraph::op::TemporaryReplaceOutputType(intermediateOp, element::f32).get(),
+        Strides{ 1, 1 },
+        Shape{ 1, 1 },
+        Shape{ 0, 0 },
+        Shape{ 2, 2 },
+        true,
+        op::RoundingType::FLOOR);
+    parent2->set_friendly_name("avgPool");
+
+    parent2 = makeDequantization(parent2, dequantizationAfter2);
+
+    ngraph::ResultVector results {
+        std::make_shared<ngraph::opset1::Result>(parent1),
+        std::make_shared<ngraph::opset1::Result>(parent2)
+    };
+
+    std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
+        results,
+        ngraph::ParameterVector{ input1, input2 },
+        "ConcatWithIntermediateTransformation");
+
+    return function;
+}
+
 std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedIntermediate(
     const ngraph::element::Type precision,
     const ngraph::Shape& inputShape,