[LPT] Concat*Transformation: handled FQ with unexpected quantization levels (#5111)

* [LPT] TRANSFORMATION_EXCEPTION segfault fix

* [LPT] Concat*Transformation: handled FQ with unexpected quantization levels

* [TESTS] Concat*Transformation: added test-cases with unexpected quant levels
This commit is contained in:
Vladislav Golubev 2021-04-06 12:13:47 +03:00 committed by GitHub
parent f438a3a321
commit a478475386
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 5 deletions

View File

@ -23,6 +23,10 @@ class TRANSFORMATIONS_API Exception : std::exception {
std::shared_ptr<std::ostringstream> buffer;
mutable std::string buffer_str;
public:
Exception() {
buffer = std::make_shared<std::ostringstream>();
}
template <typename T>
Exception& operator<< (const T& x) {
*buffer << x;

View File

@ -46,6 +46,10 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
// precisions can be different
ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0];
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this());
if (!NetworkHelper::isQuantizeSupported(fq)) {
return false;
}
DataPrecision dataPrecision = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
if (dataPrecision.precision == ngraph::element::undefined) {
return false;

View File

@ -64,6 +64,10 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
{
for (auto quantizationLayer : subgraph.quantizationLayers) {
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer->shared_from_this());
if (!NetworkHelper::isQuantizeSupported(fq)) {
return false;
}
const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
if (dataPrecision.precision == ngraph::element::undefined) {

View File

@ -610,6 +610,52 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::f32,
{ {element::f32}, {}, { 0.01f } },
}
},
// unexpected quantization levels, concat
{
LayerTransformation::createParamsU8I8(),
false,
{
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
{},
{},
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{}
},
{
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
{},
{},
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
ngraph::element::f32,
{},
}
},
// unexpected quantization levels, concat multi channels
{
LayerTransformation::createParamsU8I8(),
true,
{
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
{},
{},
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{}
},
{
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
{},
{},
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
ngraph::element::f32,
{},
}
}
};

View File

@ -36,7 +36,12 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
}
},
// FQ with unexpected quantizationLevels
{
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
},
};
const std::vector<ngraph::Shape> shapes = {

View File

@ -31,7 +31,12 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
}
},
// FQ with unexpected quantizationLevels
{
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_LPT, ConcatTransformation,

View File

@ -72,9 +72,15 @@ void ConcatTransformation::validate() {
const auto transformed = transformNGraph(params, getLowPrecisionTransformationsNGraph(params));
const auto output = transformed->get_output_op(0);
const auto scaleShift = output->get_input_node_shared_ptr(0);
const std::string typeName = scaleShift->get_type_name();
ASSERT_EQ("ScaleShiftIE", typeName);
const auto previousLayer = output->get_input_node_shared_ptr(0);
const std::string typeName = previousLayer->get_type_name();
if (testValues.fqOnData1.quantizationLevel != 256ul ||
testValues.fqOnData2.quantizationLevel != 256ul) {
ASSERT_EQ("Concat", typeName);
} else {
ASSERT_EQ("ScaleShiftIE", typeName);
}
}
TEST_P(ConcatTransformation, CompareWithRefImpl) {