[LPT] Concat precision selection fix (#6069)
This commit is contained in:
parent
a7a9364b41
commit
9e34622ac1
@ -43,19 +43,21 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// precisions can be different
|
// Concat operations precision is defined:
|
||||||
|
// 1. consumers after Concat
|
||||||
|
// 2. FakeQuantize precisions without zero point
|
||||||
ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0];
|
ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0];
|
||||||
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this());
|
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this());
|
||||||
if (!NetworkHelper::isQuantizeSupported(fq)) {
|
if (!NetworkHelper::isQuantizeSupported(fq)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
DataPrecision dataPrecision = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
|
||||||
std::vector<element::Type> concatParentsChildrensPrecisions = precisionsOnActivations;
|
if (dataPrecision.precision == ngraph::element::undefined) {
|
||||||
fillAvailablePrecisions(subgraph.quantizationLayers[0], concatParentsChildrensPrecisions);
|
|
||||||
if (concatParentsChildrensPrecisions.empty()) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<element::Type> concatChildrenPrecisions = precisionsOnActivations;
|
||||||
|
|
||||||
for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
|
for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
|
||||||
fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(subgraph.quantizationLayers[i]);
|
fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(subgraph.quantizationLayers[i]);
|
||||||
if (fq == nullptr) {
|
if (fq == nullptr) {
|
||||||
@ -72,20 +74,28 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
|||||||
if (quantizationDetails.inputHighValues.size() != 1ul) {
|
if (quantizationDetails.inputHighValues.size() != 1ul) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::vector<element::Type> fqChildrensPrecisions = precisionsOnActivations;
|
|
||||||
fillAvailablePrecisions(subgraph.quantizationLayers[i], fqChildrensPrecisions);
|
|
||||||
concatParentsChildrensPrecisions = NetworkHelper::precisionIntersection(concatParentsChildrensPrecisions, fqChildrensPrecisions);
|
|
||||||
|
|
||||||
if (concatParentsChildrensPrecisions.empty()) {
|
// define concatenation operation consumers precisions
|
||||||
|
std::vector<element::Type> fqChildrenPrecisions = precisionsOnActivations;
|
||||||
|
fillAvailablePrecisions(subgraph.quantizationLayers[i], fqChildrenPrecisions);
|
||||||
|
concatChildrenPrecisions = NetworkHelper::precisionIntersection(concatChildrenPrecisions, fqChildrenPrecisions);
|
||||||
|
if (concatChildrenPrecisions.empty()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// define FakeQuantize precisions without zero point
|
||||||
|
const DataPrecision dataPrecision2 = getDataPrecision(subgraph.quantizationLayers[i]->shared_from_this(), quantizationDetails, false);
|
||||||
|
if (dataPrecision2.precision == ngraph::element::undefined) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
DataPrecision dataPrecision;
|
if (dataPrecision.precision != dataPrecision2.precision) {
|
||||||
if (std::find(concatParentsChildrensPrecisions.begin(), concatParentsChildrensPrecisions.end(), element::i8) != concatParentsChildrensPrecisions.end()) {
|
dataPrecision = dataPrecision.precision.is_signed() ? dataPrecision : dataPrecision2;
|
||||||
dataPrecision = DataPrecision(element::i8);
|
}
|
||||||
} else {
|
}
|
||||||
dataPrecision = DataPrecision(concatParentsChildrensPrecisions[0]);
|
|
||||||
|
if (std::find(concatChildrenPrecisions.begin(), concatChildrenPrecisions.end(), dataPrecision.precision) == concatChildrenPrecisions.end()) {
|
||||||
|
dataPrecision = DataPrecision(concatChildrenPrecisions[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<QuantizationDetails> quantizationLayersDetails;
|
std::vector<QuantizationDetails> quantizationLayersDetails;
|
||||||
|
@ -64,14 +64,23 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
|
|||||||
|
|
||||||
DataPrecision dataPrecision;
|
DataPrecision dataPrecision;
|
||||||
{
|
{
|
||||||
|
std::vector<element::Type> concatChildrenPrecisions = precisionsOnActivations;
|
||||||
for (auto quantizationLayer : subgraph.quantizationLayers) {
|
for (auto quantizationLayer : subgraph.quantizationLayers) {
|
||||||
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer->shared_from_this());
|
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer->shared_from_this());
|
||||||
if (!NetworkHelper::isQuantizeSupported(fq)) {
|
if (!NetworkHelper::isQuantizeSupported(fq)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
|
// define concatenation operation consumers precisions
|
||||||
|
std::vector<element::Type> fqChildrenPrecisions = precisionsOnActivations;
|
||||||
|
fillAvailablePrecisions(quantizationLayer, fqChildrenPrecisions);
|
||||||
|
concatChildrenPrecisions = NetworkHelper::precisionIntersection(concatChildrenPrecisions, fqChildrenPrecisions);
|
||||||
|
if (concatChildrenPrecisions.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// define FakeQuantize precisions without zero point
|
||||||
|
const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
|
||||||
if (dataPrecision.precision == ngraph::element::undefined) {
|
if (dataPrecision.precision == ngraph::element::undefined) {
|
||||||
dataPrecision = tmp;
|
dataPrecision = tmp;
|
||||||
continue;
|
continue;
|
||||||
@ -81,6 +90,10 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
|
|||||||
dataPrecision = tmp;
|
dataPrecision = tmp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (std::find(concatChildrenPrecisions.begin(), concatChildrenPrecisions.end(), dataPrecision.precision) == concatChildrenPrecisions.end()) {
|
||||||
|
dataPrecision = DataPrecision(concatChildrenPrecisions[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
|
for (size_t i = 0; i < subgraph.quantizationLayers.size(); ++i) {
|
||||||
|
@ -0,0 +1,317 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "layer_transformation.hpp"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <low_precision/transformer.hpp>
|
||||||
|
#include <low_precision/avg_pool.hpp>
|
||||||
|
#include <low_precision/concat.hpp>
|
||||||
|
#include <low_precision/concat_multi_channels.hpp>
|
||||||
|
#include <low_precision/max_pool.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include "lpt_ngraph_functions/concat_function.hpp"
|
||||||
|
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||||
|
#include "simple_low_precision_transformer.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
using namespace ngraph;
|
||||||
|
using namespace ngraph::pass;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class ConcatTransformationActualValues {
|
||||||
|
public:
|
||||||
|
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
|
||||||
|
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
|
||||||
|
return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConcatTransformationResultValues {
|
||||||
|
public:
|
||||||
|
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
|
||||||
|
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
|
||||||
|
ngraph::element::Type precisionBeforeOp;
|
||||||
|
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1;
|
||||||
|
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2;
|
||||||
|
ngraph::element::Type precisionAfterOperation;
|
||||||
|
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter1;
|
||||||
|
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter2;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
|
||||||
|
return out << "_" <<
|
||||||
|
values.fakeQuantize1 << "_" <<
|
||||||
|
values.fakeQuantize2 << "_" <<
|
||||||
|
values.dequantizationAfter1 << "_" <<
|
||||||
|
values.dequantizationAfter2;
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConcatTransformationTestValues {
|
||||||
|
public:
|
||||||
|
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||||
|
bool multiChannels;
|
||||||
|
ConcatTransformationActualValues actual;
|
||||||
|
ConcatTransformationResultValues result;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
|
||||||
|
return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef std::tuple <
|
||||||
|
ngraph::element::Type,
|
||||||
|
ngraph::Shape,
|
||||||
|
ConcatTransformationTestValues
|
||||||
|
> ConcatTransformationParams;
|
||||||
|
|
||||||
|
class ConcatWithIntermediatePrecisionSelectionTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
|
||||||
|
public:
|
||||||
|
void SetUp() override {
|
||||||
|
const ngraph::element::Type precision = std::get<0>(GetParam());
|
||||||
|
const ngraph::Shape shape = std::get<1>(GetParam());
|
||||||
|
ConcatTransformationTestValues testValues = std::get<2>(GetParam());
|
||||||
|
|
||||||
|
actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithIntermediateAvgPool(
|
||||||
|
precision,
|
||||||
|
shape,
|
||||||
|
testValues.actual.fakeQuantize1,
|
||||||
|
testValues.actual.fakeQuantize2);
|
||||||
|
|
||||||
|
SimpleLowPrecisionTransformer transform;
|
||||||
|
if (testValues.multiChannels) {
|
||||||
|
transform.addBranchSpecific<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
|
||||||
|
} else {
|
||||||
|
transform.addBranchSpecific<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
|
||||||
|
}
|
||||||
|
transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
|
||||||
|
transform.add<ngraph::pass::low_precision::AvgPoolTransformation, ngraph::opset1::AvgPool>(testValues.params);
|
||||||
|
transform.transform(actualFunction);
|
||||||
|
|
||||||
|
referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithIntermediateAvgPool(
|
||||||
|
precision,
|
||||||
|
shape,
|
||||||
|
testValues.result.fakeQuantize1,
|
||||||
|
testValues.result.fakeQuantize2,
|
||||||
|
testValues.result.precisionBeforeOp,
|
||||||
|
testValues.result.dequantizationBefore1,
|
||||||
|
testValues.result.dequantizationBefore2,
|
||||||
|
testValues.result.precisionAfterOperation,
|
||||||
|
testValues.result.dequantizationAfter1,
|
||||||
|
testValues.result.dequantizationAfter2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
|
||||||
|
const ngraph::element::Type precision = std::get<0>(obj.param);
|
||||||
|
const ngraph::Shape shape = std::get<1>(obj.param);
|
||||||
|
const ConcatTransformationTestValues testValues = std::get<2>(obj.param);
|
||||||
|
|
||||||
|
std::ostringstream result;
|
||||||
|
result <<
|
||||||
|
LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
|
||||||
|
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
|
||||||
|
testValues.actual << "_" <<
|
||||||
|
testValues.result << "_";
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(ConcatWithIntermediatePrecisionSelectionTransformation, CompareFunctions) {
|
||||||
|
actualFunction->validate_nodes_and_infer_types();
|
||||||
|
auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<ngraph::element::Type> precisions = {
|
||||||
|
ngraph::element::f32,
|
||||||
|
// ngraph::element::f16
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<ConcatTransformationTestValues> testValues = {
|
||||||
|
// Concat: FakeQuantize operations with signed intervals but consumer requires U8
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsU8I8(),
|
||||||
|
false,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {64.f}, {192.f} },
|
||||||
|
ngraph::element::u8,
|
||||||
|
{{}, {}, {}},
|
||||||
|
{{}, {}, {}},
|
||||||
|
ngraph::element::u8,
|
||||||
|
{ ngraph::element::f32, { 128.f }, { 0.01f } },
|
||||||
|
{ {}, { 128.f }, { 0.01f } }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Concat: FakeQuantize operations with unsigned intervals but consumer requires I8
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsI8I8(),
|
||||||
|
false,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-128.f}, {127.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {-128.f}, { -0.f} },
|
||||||
|
ngraph::element::i8,
|
||||||
|
{{}, {}, {}},
|
||||||
|
{{}, {}, {}},
|
||||||
|
ngraph::element::i8,
|
||||||
|
{ ngraph::element::f32, { -128.f }, { 0.01f } },
|
||||||
|
{ {}, { -128.f }, { 0.01f } }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// ConcatMultichannel: FakeQuantize operations with signed intervals but consumer requires U8
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsU8I8(),
|
||||||
|
true,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {0.f}, { 255.f} },
|
||||||
|
ngraph::element::u8,
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
ngraph::element::u8,
|
||||||
|
{ ngraph::element::f32, { 128.f }, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
|
||||||
|
{ {}, { 128.f }, { 0.005f } }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// ConcatMultichannel: FakeQuantize operations with unsigned intervals but consumer requires I8
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsI8I8(),
|
||||||
|
true,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-128.f}, {127.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {-128.f}, { 127.f} },
|
||||||
|
ngraph::element::i8,
|
||||||
|
{{}, {}, {}},
|
||||||
|
{{}, {}, {}},
|
||||||
|
ngraph::element::i8,
|
||||||
|
{ ngraph::element::f32, { -128.f }, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
|
||||||
|
{ {}, { -128.f }, { 0.005f } }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Concat: FakeQuantize operations with unsigned intervals, no consumer limitations: FQ were decomposed to U8 precision
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsU8I8AndI8(),
|
||||||
|
false,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, { 128.f} },
|
||||||
|
ngraph::element::u8,
|
||||||
|
{{}, {}, {}},
|
||||||
|
{{}, {}, {}},
|
||||||
|
ngraph::element::u8,
|
||||||
|
{ ngraph::element::f32, {}, { 0.01f } },
|
||||||
|
{ {}, {}, { 0.01f } }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Concat: FakeQuantize operations with signed intervals, no consumer limitations: FQ were decomposed to I8 precision
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsU8I8AndI8(),
|
||||||
|
false,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-64.f}, {64.f} },
|
||||||
|
ngraph::element::i8,
|
||||||
|
{{}, {}, {}},
|
||||||
|
{{}, {}, {}},
|
||||||
|
ngraph::element::i8,
|
||||||
|
{ ngraph::element::f32, {}, { 0.01f } },
|
||||||
|
{ {}, {}, { 0.01f } }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// ConcatMultichannel: FakeQuantize operations with unsigned intervals, no consumer limitations: FQ were decomposed to U8 precision
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsU8I8AndI8(),
|
||||||
|
true,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {255.f} },
|
||||||
|
ngraph::element::u8,
|
||||||
|
{{}, {}, {}},
|
||||||
|
{{}, {}, {}},
|
||||||
|
ngraph::element::u8,
|
||||||
|
{ ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
|
||||||
|
{ {}, {}, { 0.005f } }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// ConcatMultichannel: FakeQuantize operations with signed intervals, no consumer limitations: FQ were decomposed to I8 precision
|
||||||
|
{
|
||||||
|
LayerTransformation::createParamsU8I8AndI8(),
|
||||||
|
true,
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-1.28f / 2.f}, {1.27f / 2.f} }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f} },
|
||||||
|
{ 256ul, ngraph::Shape({}), {-1.28f / 2.f}, {1.27f / 2.f}, {-128.f}, {127.f} },
|
||||||
|
ngraph::element::i8,
|
||||||
|
{{}, {}, {}},
|
||||||
|
{{}, {}, {}},
|
||||||
|
ngraph::element::i8,
|
||||||
|
{ ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
|
||||||
|
{ {}, {}, { 0.005f } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<ngraph::Shape> shapes = {
|
||||||
|
{ 1, 3, 9, 9 },
|
||||||
|
{ 4, 3, 9, 9 }
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
smoke_LPT,
|
||||||
|
ConcatWithIntermediatePrecisionSelectionTransformation,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(precisions),
|
||||||
|
::testing::ValuesIn(shapes),
|
||||||
|
::testing::ValuesIn(testValues)),
|
||||||
|
ConcatWithIntermediatePrecisionSelectionTransformation::getTestCaseName);
|
||||||
|
} // namespace
|
@ -49,19 +49,41 @@ bool SimpleLowPrecisionTransformer::isPrecisionPreserved(const std::shared_ptr<n
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SimpleLowPrecisionTransformer::transform(std::shared_ptr<ngraph::Function>& function) {
|
void SimpleLowPrecisionTransformer::transform(std::shared_ptr<ngraph::Function>& function) {
|
||||||
|
// initialization
|
||||||
|
for (auto it : branchSpecificTransformations) {
|
||||||
|
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
|
||||||
|
transformation->setParamsManager(this);
|
||||||
|
transformation->setLayerTransformationsManager(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto it : transformations) {
|
||||||
|
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
|
||||||
|
transformation->setParamsManager(this);
|
||||||
|
transformation->setLayerTransformationsManager(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
// transformation
|
||||||
{
|
{
|
||||||
ngraph::pass::low_precision::TypeRelaxedReplacer pass;
|
ngraph::pass::low_precision::TypeRelaxedReplacer pass;
|
||||||
pass.run_on_function(function);
|
pass.run_on_function(function);
|
||||||
}
|
}
|
||||||
|
|
||||||
ngraph::pass::low_precision::TransformationContext context(function);
|
ngraph::pass::low_precision::TransformationContext context(function);
|
||||||
|
{
|
||||||
GraphRewrite pass;
|
GraphRewrite pass;
|
||||||
for (auto it : transformations) {
|
for (auto it : branchSpecificTransformations) {
|
||||||
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
|
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
|
||||||
|
|
||||||
transformation->setParamsManager(this);
|
|
||||||
transformation->setLayerTransformationsManager(this);
|
|
||||||
transformation->registerMatcherIn(pass, context);
|
transformation->registerMatcherIn(pass, context);
|
||||||
}
|
}
|
||||||
pass.run_on_function(function);
|
pass.run_on_function(function);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
GraphRewrite pass;
|
||||||
|
for (auto it : transformations) {
|
||||||
|
ngraph::pass::low_precision::LayerTransformationPtr transformation = it.second;
|
||||||
|
transformation->registerMatcherIn(pass, context);
|
||||||
|
}
|
||||||
|
pass.run_on_function(function);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -28,9 +28,22 @@ public:
|
|||||||
bool isQuantized(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
|
bool isQuantized(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
|
||||||
bool isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
|
bool isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& layer) const noexcept override;
|
||||||
|
|
||||||
|
template <class T, class Operation>
|
||||||
|
ngraph::pass::low_precision::LayerTransformationPtr addBranchSpecific(const ngraph::pass::low_precision::LayerTransformation::Params& params) {
|
||||||
|
const std::string typeName = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>();
|
||||||
|
|
||||||
|
const auto it = branchSpecificTransformations.find(typeName);
|
||||||
|
if (it != branchSpecificTransformations.end()) {
|
||||||
|
branchSpecificTransformations.erase(it);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transformation = std::make_shared<T>(params);
|
||||||
|
branchSpecificTransformations.emplace(typeName, transformation);
|
||||||
|
return transformation;
|
||||||
|
}
|
||||||
|
|
||||||
template <class T, class Operation>
|
template <class T, class Operation>
|
||||||
ngraph::pass::low_precision::LayerTransformationPtr add(const ngraph::pass::low_precision::LayerTransformation::Params& params) {
|
ngraph::pass::low_precision::LayerTransformationPtr add(const ngraph::pass::low_precision::LayerTransformation::Params& params) {
|
||||||
// const std::string typeName = typeid(ngraph::op::TypeRelaxed<Operation>).name();
|
|
||||||
const std::string typeName = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>();
|
const std::string typeName = ngraph::pass::low_precision::LowPrecisionTransformations::getType<Operation>();
|
||||||
|
|
||||||
const auto it = transformations.find(typeName);
|
const auto it = transformations.find(typeName);
|
||||||
@ -46,5 +59,6 @@ public:
|
|||||||
void transform(std::shared_ptr<ngraph::Function>& function);
|
void transform(std::shared_ptr<ngraph::Function>& function);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> branchSpecificTransformations;
|
||||||
std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> transformations;
|
std::map<std::string, ngraph::pass::low_precision::LayerTransformationPtr> transformations;
|
||||||
};
|
};
|
||||||
|
@ -51,6 +51,12 @@ public:
|
|||||||
const FakeQuantizeOnData& fqOnData1,
|
const FakeQuantizeOnData& fqOnData1,
|
||||||
const FakeQuantizeOnData& fqOnData2);
|
const FakeQuantizeOnData& fqOnData2);
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Function> getOriginalWithIntermediateAvgPool(
|
||||||
|
const ngraph::element::Type precision,
|
||||||
|
const ngraph::Shape& inputShape,
|
||||||
|
const FakeQuantizeOnData& fqOnData1,
|
||||||
|
const FakeQuantizeOnData& fqOnData2);
|
||||||
|
|
||||||
static std::shared_ptr<ngraph::Function> getOriginalWithSplitedIntermediate(
|
static std::shared_ptr<ngraph::Function> getOriginalWithSplitedIntermediate(
|
||||||
const ngraph::element::Type precision,
|
const ngraph::element::Type precision,
|
||||||
const ngraph::Shape& inputShape,
|
const ngraph::Shape& inputShape,
|
||||||
@ -134,6 +140,7 @@ public:
|
|||||||
const std::string& neighborType,
|
const std::string& neighborType,
|
||||||
const std::string& additionalLayer);
|
const std::string& additionalLayer);
|
||||||
|
|
||||||
|
// TODO: refactor: dequantizationBefore2 <=> dequantizationOperations2
|
||||||
static std::shared_ptr<ngraph::Function> getReferenceWithIntermediate(
|
static std::shared_ptr<ngraph::Function> getReferenceWithIntermediate(
|
||||||
const ngraph::element::Type precision,
|
const ngraph::element::Type precision,
|
||||||
const ngraph::Shape& inputShape,
|
const ngraph::Shape& inputShape,
|
||||||
@ -142,6 +149,18 @@ public:
|
|||||||
const FakeQuantizeOnData& fqOnData2,
|
const FakeQuantizeOnData& fqOnData2,
|
||||||
const ngraph::element::Type precisionBeforeOp,
|
const ngraph::element::Type precisionBeforeOp,
|
||||||
const DequantizationOperations& dequantizationBefore1,
|
const DequantizationOperations& dequantizationBefore1,
|
||||||
|
const DequantizationOperations& dequantizationOperations2,
|
||||||
|
const ngraph::element::Type precisionAfterOperation,
|
||||||
|
const DequantizationOperations& dequantizationOperations1,
|
||||||
|
const DequantizationOperations& dequantizationBefore2);
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Function> getReferenceWithIntermediateAvgPool(
|
||||||
|
const ngraph::element::Type precision,
|
||||||
|
const ngraph::Shape& inputShape,
|
||||||
|
const FakeQuantizeOnData& fqOnData1,
|
||||||
|
const FakeQuantizeOnData& fqOnData2,
|
||||||
|
const ngraph::element::Type precisionBeforeOp,
|
||||||
|
const DequantizationOperations& dequantizationBefore1,
|
||||||
const DequantizationOperations& dequantizationBefore2,
|
const DequantizationOperations& dequantizationBefore2,
|
||||||
const ngraph::element::Type precisionAfterOperation,
|
const ngraph::element::Type precisionAfterOperation,
|
||||||
const DequantizationOperations& dequantizationOperations1,
|
const DequantizationOperations& dequantizationOperations1,
|
||||||
|
@ -272,6 +272,58 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediate(
|
|||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediateAvgPool(
|
||||||
|
const ngraph::element::Type precision,
|
||||||
|
const ngraph::Shape& inputShape,
|
||||||
|
const FakeQuantizeOnData& fqOnData1,
|
||||||
|
const FakeQuantizeOnData& fqOnData2) {
|
||||||
|
const std::vector<size_t> inputShape1 = { inputShape[0], inputShape[1], inputShape[2] - 2, inputShape[3] - 2 };
|
||||||
|
|
||||||
|
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape1));
|
||||||
|
input1->set_friendly_name("input1");
|
||||||
|
const auto fakeQuantize1 = makeFakeQuantize(input1, precision, fqOnData1);
|
||||||
|
fakeQuantize1->set_friendly_name("fakeQuantize1");
|
||||||
|
|
||||||
|
const std::vector<size_t> inputShape2 = { inputShape[0], inputShape[1], inputShape[2], inputShape[3] };
|
||||||
|
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
|
||||||
|
input2->set_friendly_name("input2");
|
||||||
|
|
||||||
|
const auto fakeQuantize2 = makeFakeQuantize(input2, precision, fqOnData2);
|
||||||
|
fakeQuantize2->set_friendly_name("fakeQuantize2");
|
||||||
|
|
||||||
|
std::shared_ptr<Node> intermediateOp = makeMaxPool(fakeQuantize2->output(0), { 3, 3 });
|
||||||
|
intermediateOp->set_friendly_name("intermediate");
|
||||||
|
|
||||||
|
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
|
||||||
|
ngraph::OutputVector{ fakeQuantize1->output(0), intermediateOp->output(0) }, 1);
|
||||||
|
concat->set_friendly_name("concat");
|
||||||
|
|
||||||
|
auto& rtInfo = concat->get_rt_info();
|
||||||
|
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
|
||||||
|
|
||||||
|
std::shared_ptr<Node> parent2 = std::make_shared<ngraph::opset1::AvgPool>(
|
||||||
|
intermediateOp,
|
||||||
|
Strides{ 1, 1 },
|
||||||
|
Shape{ 1, 1 },
|
||||||
|
Shape{ 0, 0 },
|
||||||
|
Shape{ 2, 2 },
|
||||||
|
true,
|
||||||
|
op::RoundingType::FLOOR);
|
||||||
|
parent2->set_friendly_name("avgPool");
|
||||||
|
|
||||||
|
ngraph::ResultVector results {
|
||||||
|
std::make_shared<ngraph::opset1::Result>(concat),
|
||||||
|
std::make_shared<ngraph::opset1::Result>(parent2)
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||||
|
results,
|
||||||
|
ngraph::ParameterVector{ input1, input2 },
|
||||||
|
"ConcatWithIntermediateTransformation");
|
||||||
|
|
||||||
|
return function;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermediate(
|
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermediate(
|
||||||
const ngraph::element::Type precision,
|
const ngraph::element::Type precision,
|
||||||
const ngraph::Shape& inputShape,
|
const ngraph::Shape& inputShape,
|
||||||
@ -1056,6 +1108,77 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediate(
|
|||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediateAvgPool(
|
||||||
|
const ngraph::element::Type precision,
|
||||||
|
const ngraph::Shape& inputShape,
|
||||||
|
const FakeQuantizeOnData& fqOnData1,
|
||||||
|
const FakeQuantizeOnData& fqOnData2,
|
||||||
|
const ngraph::element::Type precisionBeforeOp,
|
||||||
|
const DequantizationOperations& dequantizationBefore1,
|
||||||
|
const DequantizationOperations& dequantizationBefore2,
|
||||||
|
const ngraph::element::Type precisionAfterOperation,
|
||||||
|
const DequantizationOperations& dequantizationAfter1,
|
||||||
|
const DequantizationOperations& dequantizationAfter2) {
|
||||||
|
const std::vector<size_t> inputShape1 = { inputShape[0], inputShape[1], inputShape[2] - 2, inputShape[3] - 2};
|
||||||
|
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape1));
|
||||||
|
input1->set_friendly_name("input1");
|
||||||
|
|
||||||
|
const auto fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input1, precision, fqOnData1);
|
||||||
|
low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize1, precisionBeforeOp);
|
||||||
|
fakeQuantize1->set_friendly_name("fakeQuantize1");
|
||||||
|
const auto deqBefore1 = makeDequantization(fakeQuantize1, dequantizationBefore1);
|
||||||
|
|
||||||
|
const std::vector<size_t> inputShape2 = { inputShape[0], inputShape[1], inputShape[2], inputShape[3] };
|
||||||
|
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
|
||||||
|
input2->set_friendly_name("input2");
|
||||||
|
|
||||||
|
const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(input2, precision, fqOnData2);
|
||||||
|
low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize2, precisionBeforeOp);
|
||||||
|
fakeQuantize2->set_friendly_name("fakeQuantize2");
|
||||||
|
const auto deqBefore2 = makeDequantization(fakeQuantize2, dequantizationBefore2);
|
||||||
|
|
||||||
|
std::shared_ptr<Node> intermediateOp = makeMaxPool(deqBefore2, { 3, 3 });
|
||||||
|
intermediateOp->set_friendly_name("intermediate");
|
||||||
|
|
||||||
|
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
|
||||||
|
ngraph::OutputVector { deqBefore1, intermediateOp },
|
||||||
|
1);
|
||||||
|
concat->set_friendly_name("concat");
|
||||||
|
low_precision::NetworkHelper::setOutDataPrecision(concat, precisionAfterOperation);
|
||||||
|
|
||||||
|
auto& rtInfo = concat->get_rt_info();
|
||||||
|
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
|
||||||
|
|
||||||
|
const std::shared_ptr<ngraph::Node> parent1 = makeDequantization(concat, dequantizationAfter1);
|
||||||
|
parent1->set_friendly_name("concat");
|
||||||
|
|
||||||
|
std::shared_ptr<Node> parent2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::AvgPool>>(
|
||||||
|
std::vector<ngraph::element::Type>{ element::f32, element::f32 },
|
||||||
|
std::vector<ngraph::element::Type>{ element::f32 },
|
||||||
|
ngraph::op::TemporaryReplaceOutputType(intermediateOp, element::f32).get(),
|
||||||
|
Strides{ 1, 1 },
|
||||||
|
Shape{ 1, 1 },
|
||||||
|
Shape{ 0, 0 },
|
||||||
|
Shape{ 2, 2 },
|
||||||
|
true,
|
||||||
|
op::RoundingType::FLOOR);
|
||||||
|
parent2->set_friendly_name("avgPool");
|
||||||
|
|
||||||
|
parent2 = makeDequantization(parent2, dequantizationAfter2);
|
||||||
|
|
||||||
|
ngraph::ResultVector results {
|
||||||
|
std::make_shared<ngraph::opset1::Result>(parent1),
|
||||||
|
std::make_shared<ngraph::opset1::Result>(parent2)
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||||
|
results,
|
||||||
|
ngraph::ParameterVector{ input1, input2 },
|
||||||
|
"ConcatWithIntermediateTransformation");
|
||||||
|
|
||||||
|
return function;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedIntermediate(
|
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedIntermediate(
|
||||||
const ngraph::element::Type precision,
|
const ngraph::element::Type precision,
|
||||||
const ngraph::Shape& inputShape,
|
const ngraph::Shape& inputShape,
|
||||||
|
Loading…
Reference in New Issue
Block a user