[LPT] Concat precision selection fix (#6069)

This commit is contained in:
Edward Shogulin 2021-06-08 15:29:30 +03:00 committed by GitHub
parent a7a9364b41
commit 9e34622ac1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 542 additions and 24 deletions

View File

@ -43,19 +43,21 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
return false; return false;
} }
// precisions can be different // Concat operations precision is defined:
// 1. consumers after Concat
// 2. FakeQuantize precisions without zero point
ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0]; ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0];
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this()); std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this());
if (!NetworkHelper::isQuantizeSupported(fq)) { if (!NetworkHelper::isQuantizeSupported(fq)) {
return false; return false;
} }
DataPrecision dataPrecision = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
std::vector<element::Type> concatParentsChildrensPrecisions = precisionsOnActivations; if (dataPrecision.precision == ngraph::element::undefined) {
fillAvailablePrecisions(subgraph.quantizationLayers[0], concatParentsChildrensPrecisions);
if (concatParentsChildrensPrecisions.empty()) {
return false; return false;
} }
std::vector<element::Type> concatChildrenPrecisions = precisionsOnActivations;
for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) { for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(subgraph.quantizationLayers[i]); fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(subgraph.quantizationLayers[i]);
if (fq == nullptr) { if (fq == nullptr) {
@ -72,20 +74,28 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
if (quantizationDetails.inputHighValues.size() != 1ul) { if (quantizationDetails.inputHighValues.size() != 1ul) {
return false; return false;
} }
std::vector<element::Type> fqChildrensPrecisions = precisionsOnActivations;
fillAvailablePrecisions(subgraph.quantizationLayers[i], fqChildrensPrecisions);
concatParentsChildrensPrecisions = NetworkHelper::precisionIntersection(concatParentsChildrensPrecisions, fqChildrensPrecisions);
if (concatParentsChildrensPrecisions.empty()) { // define concatenation operation consumers precisions
std::vector<element::Type> fqChildrenPrecisions = precisionsOnActivations;
fillAvailablePrecisions(subgraph.quantizationLayers[i], fqChildrenPrecisions);
concatChildrenPrecisions = NetworkHelper::precisionIntersection(concatChildrenPrecisions, fqChildrenPrecisions);
if (concatChildrenPrecisions.empty()) {
return false; return false;
} }
// define FakeQuantize precisions without zero point
const DataPrecision dataPrecision2 = getDataPrecision(subgraph.quantizationLayers[i]->shared_from_this(), quantizationDetails, false);
if (dataPrecision2.precision == ngraph::element::undefined) {
return false;
}
if (dataPrecision.precision != dataPrecision2.precision) {
dataPrecision = dataPrecision.precision.is_signed() ? dataPrecision : dataPrecision2;
}
} }
DataPrecision dataPrecision; if (std::find(concatChildrenPrecisions.begin(), concatChildrenPrecisions.end(), dataPrecision.precision) == concatChildrenPrecisions.end()) {
if (std::find(concatParentsChildrensPrecisions.begin(), concatParentsChildrensPrecisions.end(), element::i8) != concatParentsChildrensPrecisions.end()) { dataPrecision = DataPrecision(concatChildrenPrecisions[0]);
dataPrecision = DataPrecision(element::i8);
} else {
dataPrecision = DataPrecision(concatParentsChildrensPrecisions[0]);
} }
std::vector<QuantizationDetails> quantizationLayersDetails; std::vector<QuantizationDetails> quantizationLayersDetails;

View File

@ -64,14 +64,23 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
DataPrecision dataPrecision; DataPrecision dataPrecision;
{ {
std::vector<element::Type> concatChildrenPrecisions = precisionsOnActivations;
for (auto quantizationLayer : subgraph.quantizationLayers) { for (auto quantizationLayer : subgraph.quantizationLayers) {
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer->shared_from_this()); std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer->shared_from_this());
if (!NetworkHelper::isQuantizeSupported(fq)) { if (!NetworkHelper::isQuantizeSupported(fq)) {
return false; return false;
} }
const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false); // define concatenation operation consumers precisions
std::vector<element::Type> fqChildrenPrecisions = precisionsOnActivations;
fillAvailablePrecisions(quantizationLayer, fqChildrenPrecisions);
concatChildrenPrecisions = NetworkHelper::precisionIntersection(concatChildrenPrecisions, fqChildrenPrecisions);
if (concatChildrenPrecisions.empty()) {
return false;
}
// define FakeQuantize precisions without zero point
const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
if (dataPrecision.precision == ngraph::element::undefined) { if (dataPrecision.precision == ngraph::element::undefined) {
dataPrecision = tmp; dataPrecision = tmp;
continue; continue;
@ -81,6 +90,10 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
dataPrecision = tmp; dataPrecision = tmp;
} }
} }
if (std::find(concatChildrenPrecisions.begin(), concatChildrenPrecisions.end(), dataPrecision.precision) == concatChildrenPrecisions.end()) {
dataPrecision = DataPrecision(concatChildrenPrecisions[0]);
}
} }
for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) { for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {

View File

@ -0,0 +1,317 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "layer_transformation.hpp"
#include <string>
#include <sstream>
#include <memory>
#include <gtest/gtest.h>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <low_precision/transformer.hpp>
#include <low_precision/avg_pool.hpp>
#include <low_precision/concat.hpp>
#include <low_precision/concat_multi_channels.hpp>
#include <low_precision/max_pool.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "lpt_ngraph_functions/concat_function.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
#include "simple_low_precision_transformer.hpp"
using namespace testing;
using namespace ngraph;
using namespace ngraph::pass;
namespace {
// Inputs of a single test case: the FakeQuantize settings used to build the
// two Concat branches of the original (not yet transformed) function.
struct ConcatTransformationActualValues {
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
};
// Writes both FakeQuantize descriptions to the stream, each one preceded by "_".
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
    out << "_" << values.fakeQuantize1;
    return out << "_" << values.fakeQuantize2;
}
// Expected state after the transformation; every field is forwarded to
// ConcatFunction::getReferenceWithIntermediateAvgPool by the fixture's SetUp().
struct ConcatTransformationResultValues {
    // FakeQuantize settings expected on the first and the second branch.
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
    // Output precision set on both FakeQuantize operations.
    ngraph::element::Type precisionBeforeOp;
    // Dequantization placed right after each FakeQuantize.
    ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1;
    ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2;
    // Output precision set on the Concat operation.
    ngraph::element::Type precisionAfterOperation;
    // Dequantization after Concat (1) and after AvgPool (2).
    ngraph::builder::subgraph::DequantizationOperations dequantizationAfter1;
    ngraph::builder::subgraph::DequantizationOperations dequantizationAfter2;
};
// Writes the expected FakeQuantize settings and the two "after" dequantizations,
// each preceded by "_". Precisions and "before" dequantizations are not printed.
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
    out << "_" << values.fakeQuantize1;
    out << "_" << values.fakeQuantize2;
    out << "_" << values.dequantizationAfter1;
    return out << "_" << values.dequantizationAfter2;
}
// One complete test case: transformation parameters, which Concat
// transformation to run, the input description and the expected result.
struct ConcatTransformationTestValues {
    ngraph::pass::low_precision::LayerTransformation::Params params;
    // true: ConcatMultiChannelsTransformation is used, false: ConcatTransformation.
    bool multiChannels;
    ConcatTransformationActualValues actual;
    ConcatTransformationResultValues result;
};
// Writes the multiChannels flag followed by the actual and the expected values,
// each preceded by "_". The transformation params are not printed here.
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
    out << "_" << values.multiChannels;
    out << "_" << values.actual;
    return out << "_" << values.result;
}
// Test parameter tuple: <network precision, input shape, test case values>.
using ConcatTransformationParams = std::tuple<ngraph::element::Type, ngraph::Shape, ConcatTransformationTestValues>;
// Parameterized fixture: builds a Concat graph with an intermediate MaxPool and
// an AvgPool consumer, runs the low precision transformations on it and prepares
// the reference graph the result is compared with.
class ConcatWithIntermediatePrecisionSelectionTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
public:
    // Fills actualFunction (original graph after the transformations were applied)
    // and referenceFunction (expected graph) from the current test parameter.
    void SetUp() override {
        const ConcatTransformationParams& param = GetParam();
        const ngraph::element::Type& netPrecision = std::get<0>(param);
        const ngraph::Shape& inputShape = std::get<1>(param);
        const ConcatTransformationTestValues& values = std::get<2>(param);

        actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithIntermediateAvgPool(
            netPrecision,
            inputShape,
            values.actual.fakeQuantize1,
            values.actual.fakeQuantize2);

        SimpleLowPrecisionTransformer transformer;
        // The Concat transformation is registered as branch specific; which
        // implementation is used depends on the multiChannels flag.
        if (!values.multiChannels) {
            transformer.addBranchSpecific<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(values.params);
        } else {
            transformer.addBranchSpecific<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(values.params);
        }
        transformer.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(values.params);
        transformer.add<ngraph::pass::low_precision::AvgPoolTransformation, ngraph::opset1::AvgPool>(values.params);
        transformer.transform(actualFunction);

        const ConcatTransformationResultValues& expected = values.result;
        referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithIntermediateAvgPool(
            netPrecision,
            inputShape,
            expected.fakeQuantize1,
            expected.fakeQuantize2,
            expected.precisionBeforeOp,
            expected.dequantizationBefore1,
            expected.dequantizationBefore2,
            expected.precisionAfterOperation,
            expected.dequantizationAfter1,
            expected.dequantizationAfter2);
    }

    // Composes the test name from the common parameter description, the
    // multiChannels mode and the streamed actual / expected values.
    static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
        const ngraph::element::Type& netPrecision = std::get<0>(obj.param);
        const ngraph::Shape& inputShape = std::get<1>(obj.param);
        const ConcatTransformationTestValues& values = std::get<2>(obj.param);

        std::ostringstream name;
        name << LayerTransformation::getTestCaseNameByParams(netPrecision, inputShape, values.params) << "_";
        name << (values.multiChannels ? "multiChannels_" : "notMultiChannels_");
        name << values.actual << "_";
        name << values.result << "_";
        return name.str();
    }
};
// Checks that the transformed function matches the reference function.
TEST_P(ConcatWithIntermediatePrecisionSelectionTransformation, CompareFunctions) {
    actualFunction->validate_nodes_and_infer_types();

    // compare_functions returns a pair: the verdict and a message describing the mismatch.
    const auto comparison = compare_functions(referenceFunction, actualFunction, true, false, true);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Network precisions the suite is instantiated for; f16 is currently disabled.
const std::vector<ngraph::element::Type> precisions{
    ngraph::element::f32
    // , ngraph::element::f16
};
// Test cases. Every entry is laid out as ConcatTransformationTestValues:
//   1. transformation params,
//   2. multiChannels flag (true: ConcatMultiChannelsTransformation, false: ConcatTransformation),
//   3. actual values: FakeQuantize on the first and on the second branch,
//   4. expected values: FakeQuantize 1, FakeQuantize 2, precisionBeforeOp,
//      dequantizationBefore1, dequantizationBefore2, precisionAfterOperation,
//      dequantizationAfter1 (after Concat), dequantizationAfter2 (after AvgPool).
// NOTE(review): a FakeQuantizeOnData initializer presumably reads as
// { levels, constant shape, input low, input high, output low, output high } - confirm
// against the FakeQuantizeOnData definition.
const std::vector<ConcatTransformationTestValues> testValues = {
// Concat: FakeQuantize operations with signed intervals but consumer requires U8
{
LayerTransformation::createParamsU8I8(),
false,
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {64.f}, {192.f} },
ngraph::element::u8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::u8,
{ ngraph::element::f32, { 128.f }, { 0.01f } },
{ {}, { 128.f }, { 0.01f } }
}
},
// Concat: FakeQuantize operations with unsigned intervals but consumer requires I8
{
LayerTransformation::createParamsI8I8(),
false,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-128.f}, {127.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {-128.f}, { -0.f} },
ngraph::element::i8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::i8,
{ ngraph::element::f32, { -128.f }, { 0.01f } },
{ {}, { -128.f }, { 0.01f } }
}
},
// ConcatMultichannel: FakeQuantize operations with signed intervals but consumer requires U8
{
LayerTransformation::createParamsU8I8(),
true,
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {0.f}, { 255.f} },
ngraph::element::u8,
{},
{},
ngraph::element::u8,
{ ngraph::element::f32, { 128.f }, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ {}, { 128.f }, { 0.005f } }
}
},
// ConcatMultichannel: FakeQuantize operations with unsigned intervals but consumer requires I8
{
LayerTransformation::createParamsI8I8(),
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-128.f}, {127.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {-128.f}, { 127.f} },
ngraph::element::i8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::i8,
{ ngraph::element::f32, { -128.f }, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ {}, { -128.f }, { 0.005f } }
}
},
// Concat: FakeQuantize operations with unsigned intervals, no consumer limitations: FQ were decomposed to U8 precision
{
LayerTransformation::createParamsU8I8AndI8(),
false,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, { 128.f} },
ngraph::element::u8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::u8,
{ ngraph::element::f32, {}, { 0.01f } },
{ {}, {}, { 0.01f } }
}
},
// Concat: FakeQuantize operations with signed intervals, no consumer limitations: FQ were decomposed to I8 precision
{
LayerTransformation::createParamsU8I8AndI8(),
false,
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-64.f}, {64.f} },
ngraph::element::i8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::i8,
{ ngraph::element::f32, {}, { 0.01f } },
{ {}, {}, { 0.01f } }
}
},
// ConcatMultichannel: FakeQuantize operations with unsigned intervals, no consumer limitations: FQ were decomposed to U8 precision
{
LayerTransformation::createParamsU8I8AndI8(),
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f} },
ngraph::element::u8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::u8,
{ ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ {}, {}, { 0.005f } }
}
},
// ConcatMultichannel: FakeQuantize operations with signed intervals, no consumer limitations: FQ were decomposed to I8 precision
{
LayerTransformation::createParamsU8I8AndI8(),
true,
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f} },
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-128.f}, {127.f} },
ngraph::element::i8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::i8,
{ ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ {}, {}, { 0.005f } }
}
}
};
const std::vector<ngraph::Shape> shapes = {
{ 1, 3, 9, 9 },
{ 4, 3, 9, 9 }
};
// Instantiates the suite for every combination of precisions, shapes and
// testValues; test names are produced by getTestCaseName.
INSTANTIATE_TEST_CASE_P(
smoke_LPT,
ConcatWithIntermediatePrecisionSelectionTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(shapes),
::testing::ValuesIn(testValues)),
ConcatWithIntermediatePrecisionSelectionTransformation::getTestCaseName);
} // namespace

View File

@ -49,19 +49,41 @@ bool SimpleLowPrecisionTransformer::isPrecisionPreserved(const std::shared_ptr<n
} }
void SimpleLowPrecisionTransformer::transform(std::shared_ptr<ngraph::Function>& function) { void SimpleLowPrecisionTransformer::transform(std::shared_ptr<ngraph::Function>& function) {
// initialization
for (auto it : branchSpecificTransformations) {
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
transformation->setParamsManager(this);
transformation->setLayerTransformationsManager(this);
}
for (auto it : transformations) {
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
transformation->setParamsManager(this);
transformation->setLayerTransformationsManager(this);
}
// transformation
{ {
ngraph::pass::low_precision::TypeRelaxedReplacer pass; ngraph::pass::low_precision::TypeRelaxedReplacer pass;
pass.run_on_function(function); pass.run_on_function(function);
} }
ngraph::pass::low_precision::TransformationContext context(function); ngraph::pass::low_precision::TransformationContext context(function);
GraphRewrite pass; {
for (auto it : transformations) { GraphRewrite pass;
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second; for (auto it : branchSpecificTransformations) {
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
transformation->setParamsManager(this); transformation->registerMatcherIn(pass, context);
transformation->setLayerTransformationsManager(this); }
transformation->registerMatcherIn(pass, context); pass.run_on_function(function);
}
{
GraphRewrite pass;
for (auto it : transformations) {
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
transformation->registerMatcherIn(pass, context);
}
pass.run_on_function(function);
} }
pass.run_on_function(function);
} }

View File

@ -28,9 +28,22 @@ public:
bool isQuantized(const std::shared_ptr<ngraph::Node>& layer) const noexcept override; bool isQuantized(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
bool isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& layer) const noexcept override; bool isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
// Creates transformation T for operation type Operation and stores it among the
// branch specific transformations. A transformation registered earlier for the
// same operation type is dropped first. Returns the newly created transformation.
template <class T, class Operation>
ngraph::pass::low_precision::LayerTransformationPtr addBranchSpecific(const ngraph::pass::low_precision::LayerTransformation::Params& params) {
    const std::string operationType = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>();

    // Erasing by key is a no-op when nothing is registered for this type yet.
    branchSpecificTransformations.erase(operationType);

    const auto created = std::make_shared<T>(params);
    branchSpecificTransformations.emplace(operationType, created);
    return created;
}
template <class T, class Operation> template <class T, class Operation>
ngraph::pass::low_precision::LayerTransformationPtr add(const ngraph::pass::low_precision::LayerTransformation::Params& params) { ngraph::pass::low_precision::LayerTransformationPtr add(const ngraph::pass::low_precision::LayerTransformation::Params& params) {
// const std::string typeName = typeid(ngraph::op::TypeRelaxed<Operation>).name();
const std::string typeName = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>(); const std::string typeName = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>();
const auto it = transformations.find(typeName); const auto it = transformations.find(typeName);
@ -46,5 +59,6 @@ public:
void transform(std::shared_ptr<ngraph::Function>& function); void transform(std::shared_ptr<ngraph::Function>& function);
private: private:
// Transformations registered through addBranchSpecific, keyed by operation type name.
std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> branchSpecificTransformations;
std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> transformations; std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> transformations;
}; };

View File

@ -51,6 +51,12 @@ public:
const FakeQuantizeOnData& fqOnData1, const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2); const FakeQuantizeOnData& fqOnData2);
// Builds the original (not yet transformed) function with two FakeQuantize
// branches: the first one goes straight into Concat, the second one goes through
// an intermediate MaxPool whose output feeds both Concat and an AvgPool.
// fqOnData1 / fqOnData2 describe the FakeQuantize of the first / second branch.
static std::shared_ptr<ngraph::Function> getOriginalWithIntermediateAvgPool(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2);
static std::shared_ptr<ngraph::Function> getOriginalWithSplitedIntermediate( static std::shared_ptr<ngraph::Function> getOriginalWithSplitedIntermediate(
const ngraph::element::Type precision, const ngraph::element::Type precision,
const ngraph::Shape& inputShape, const ngraph::Shape& inputShape,
@ -134,6 +140,7 @@ public:
const std::string& neighborType, const std::string& neighborType,
const std::string& additionalLayer); const std::string& additionalLayer);
// TODO: refactor: dequantizationBefore2 <=> dequantizationOperations2
static std::shared_ptr<ngraph::Function> getReferenceWithIntermediate( static std::shared_ptr<ngraph::Function> getReferenceWithIntermediate(
const ngraph::element::Type precision, const ngraph::element::Type precision,
const ngraph::Shape& inputShape, const ngraph::Shape& inputShape,
@ -142,6 +149,18 @@ public:
const FakeQuantizeOnData& fqOnData2, const FakeQuantizeOnData& fqOnData2,
const ngraph::element::Type precisionBeforeOp, const ngraph::element::Type precisionBeforeOp,
const DequantizationOperations& dequantizationBefore1, const DequantizationOperations& dequantizationBefore1,
const DequantizationOperations& dequantizationOperations2,
const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationOperations1,
const DequantizationOperations& dequantizationBefore2);
static std::shared_ptr<ngraph::Function> getReferenceWithIntermediateAvgPool(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2,
const ngraph::element::Type precisionBeforeOp,
const DequantizationOperations& dequantizationBefore1,
const DequantizationOperations& dequantizationBefore2, const DequantizationOperations& dequantizationBefore2,
const ngraph::element::Type precisionAfterOperation, const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationOperations1, const DequantizationOperations& dequantizationOperations1,

View File

@ -272,6 +272,58 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediate(
return function; return function;
} }
// Builds the original (not yet transformed) test function:
//
//   input1 -> fakeQuantize1 ------------------> concat -> Result
//   input2 -> fakeQuantize2 -> intermediate -/
//                              (MaxPool)     \-> avgPool -> Result
//
// precision  - element type of both Parameters (also passed to makeFakeQuantize)
// inputShape - shape of input2; indices 0..3 are read, so at least 4 dimensions are required
// fqOnData1  - FakeQuantize description for the first branch
// fqOnData2  - FakeQuantize description for the second branch
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediateAvgPool(
    const ngraph::element::Type precision,
    const ngraph::Shape& inputShape,
    const FakeQuantizeOnData& fqOnData1,
    const FakeQuantizeOnData& fqOnData2) {
    // First branch: spatial dimensions are smaller by 2, presumably to match the
    // output of the { 3, 3 } max pooling on the second branch.
    const ngraph::Shape firstBranchShape{ inputShape[0], inputShape[1], inputShape[2] - 2, inputShape[3] - 2 };
    const auto firstInput = std::make_shared<ngraph::opset1::Parameter>(precision, firstBranchShape);
    firstInput->set_friendly_name("input1");
    const auto firstFakeQuantize = makeFakeQuantize(firstInput, precision, fqOnData1);
    firstFakeQuantize->set_friendly_name("fakeQuantize1");

    // Second branch: only the first four dimensions of inputShape are used.
    const ngraph::Shape secondBranchShape{ inputShape[0], inputShape[1], inputShape[2], inputShape[3] };
    const auto secondInput = std::make_shared<ngraph::opset1::Parameter>(precision, secondBranchShape);
    secondInput->set_friendly_name("input2");
    const auto secondFakeQuantize = makeFakeQuantize(secondInput, precision, fqOnData2);
    secondFakeQuantize->set_friendly_name("fakeQuantize2");

    const std::shared_ptr<Node> maxPool = makeMaxPool(secondFakeQuantize->output(0), { 3, 3 });
    maxPool->set_friendly_name("intermediate");

    // Concatenation along axis 1.
    const auto concat = std::make_shared<ngraph::opset1::Concat>(
        ngraph::OutputVector{ firstFakeQuantize->output(0), maxPool->output(0) },
        1);
    concat->set_friendly_name("concat");
    concat->get_rt_info()["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");

    // Second consumer of the intermediate operation.
    const Strides strides{ 1, 1 };
    const Shape padsBegin{ 1, 1 };
    const Shape padsEnd{ 0, 0 };
    const Shape kernel{ 2, 2 };
    const std::shared_ptr<Node> avgPool = std::make_shared<ngraph::opset1::AvgPool>(
        maxPool,
        strides,
        padsBegin,
        padsEnd,
        kernel,
        true,
        op::RoundingType::FLOOR);
    avgPool->set_friendly_name("avgPool");

    const ngraph::ResultVector results{
        std::make_shared<ngraph::opset1::Result>(concat),
        std::make_shared<ngraph::opset1::Result>(avgPool)
    };

    return std::make_shared<ngraph::Function>(
        results,
        ngraph::ParameterVector{ firstInput, secondInput },
        "ConcatWithIntermediateTransformation");
}
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermediate( std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermediate(
const ngraph::element::Type precision, const ngraph::element::Type precision,
const ngraph::Shape& inputShape, const ngraph::Shape& inputShape,
@ -1056,6 +1108,77 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediate(
return function; return function;
} }
// Builds the reference (expected after transformation) function:
//
//   input1 -> fakeQuantize1 -> deqBefore1 ------------------> concat -> deqAfter1 -> Result
//   input2 -> fakeQuantize2 -> deqBefore2 -> intermediate -/
//                                            (MaxPool)     \-> avgPool -> deqAfter2 -> Result
//
// precision               - element type of both Parameters
// inputShape              - shape of input2; indices 0..3 are read, so at least 4 dimensions are required
// fqOnData1 / fqOnData2   - FakeQuantize descriptions for the first / second branch
// precisionBeforeOp       - output precision set on both FakeQuantize operations
// dequantizationBefore1/2 - dequantization placed right after FakeQuantize 1 / 2
// precisionAfterOperation - output precision set on Concat
// dequantizationAfter1    - dequantization placed after Concat
// dequantizationAfter2    - dequantization placed after AvgPool
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediateAvgPool(
    const ngraph::element::Type precision,
    const ngraph::Shape& inputShape,
    const FakeQuantizeOnData& fqOnData1,
    const FakeQuantizeOnData& fqOnData2,
    const ngraph::element::Type precisionBeforeOp,
    const DequantizationOperations& dequantizationBefore1,
    const DequantizationOperations& dequantizationBefore2,
    const ngraph::element::Type precisionAfterOperation,
    const DequantizationOperations& dequantizationAfter1,
    const DequantizationOperations& dequantizationAfter2) {
    // First branch: spatial dimensions are smaller by 2, presumably to match the
    // output of the { 3, 3 } max pooling on the second branch.
    const ngraph::Shape firstBranchShape{ inputShape[0], inputShape[1], inputShape[2] - 2, inputShape[3] - 2 };
    const auto firstInput = std::make_shared<ngraph::opset1::Parameter>(precision, firstBranchShape);
    firstInput->set_friendly_name("input1");
    const auto firstFakeQuantize = makeFakeQuantizeTypeRelaxed(firstInput, precision, fqOnData1);
    low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(firstFakeQuantize, precisionBeforeOp);
    firstFakeQuantize->set_friendly_name("fakeQuantize1");
    const auto firstDequantization = makeDequantization(firstFakeQuantize, dequantizationBefore1);

    // Second branch: only the first four dimensions of inputShape are used.
    const ngraph::Shape secondBranchShape{ inputShape[0], inputShape[1], inputShape[2], inputShape[3] };
    const auto secondInput = std::make_shared<ngraph::opset1::Parameter>(precision, secondBranchShape);
    secondInput->set_friendly_name("input2");
    const auto secondFakeQuantize = makeFakeQuantizeTypeRelaxed(secondInput, precision, fqOnData2);
    low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(secondFakeQuantize, precisionBeforeOp);
    secondFakeQuantize->set_friendly_name("fakeQuantize2");
    const auto secondDequantization = makeDequantization(secondFakeQuantize, dequantizationBefore2);

    const std::shared_ptr<Node> maxPool = makeMaxPool(secondDequantization, { 3, 3 });
    maxPool->set_friendly_name("intermediate");

    // Concatenation along axis 1 with the expected low precision output.
    const auto concat = std::make_shared<ngraph::opset1::Concat>(
        ngraph::OutputVector{ firstDequantization, maxPool },
        1);
    concat->set_friendly_name("concat");
    low_precision::NetworkHelper::setOutDataPrecision(concat, precisionAfterOperation);
    concat->get_rt_info()["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");

    // The last operation of the Concat branch takes over the "concat" name.
    const std::shared_ptr<ngraph::Node> concatBranch = makeDequantization(concat, dequantizationAfter1);
    concatBranch->set_friendly_name("concat");

    // AvgPool is created as TypeRelaxed with f32 types; the output type of the
    // intermediate operation is replaced only while the node is constructed, so the
    // TemporaryReplaceOutputType object has to stay inside this expression.
    const Strides strides{ 1, 1 };
    const Shape padsBegin{ 1, 1 };
    const Shape padsEnd{ 0, 0 };
    const Shape kernel{ 2, 2 };
    const std::shared_ptr<Node> avgPool = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::AvgPool>>(
        std::vector<ngraph::element::Type>{ element::f32, element::f32 },
        std::vector<ngraph::element::Type>{ element::f32 },
        ngraph::op::TemporaryReplaceOutputType(maxPool, element::f32).get(),
        strides,
        padsBegin,
        padsEnd,
        kernel,
        true,
        op::RoundingType::FLOOR);
    avgPool->set_friendly_name("avgPool");
    const std::shared_ptr<Node> avgPoolBranch = makeDequantization(avgPool, dequantizationAfter2);

    const ngraph::ResultVector results{
        std::make_shared<ngraph::opset1::Result>(concatBranch),
        std::make_shared<ngraph::opset1::Result>(avgPoolBranch)
    };

    return std::make_shared<ngraph::Function>(
        results,
        ngraph::ParameterVector{ firstInput, secondInput },
        "ConcatWithIntermediateTransformation");
}
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedIntermediate( std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedIntermediate(
const ngraph::element::Type precision, const ngraph::element::Type precision,
const ngraph::Shape& inputShape, const ngraph::Shape& inputShape,