[LPT] Concat*Transformation: handled FQ with unexpected quantization levels (#5111)
* [LPT] TRANSFORMATION_EXCEPTION segfault fix * [LPT] Concat*Transformation: handled FQ with unexpected quantization levels * [TESTS] Concat*Transformation: added test-cases with unexpected quant levels
This commit is contained in:
parent
f438a3a321
commit
a478475386
@ -23,6 +23,10 @@ class TRANSFORMATIONS_API Exception : std::exception {
|
||||
std::shared_ptr<std::ostringstream> buffer;
|
||||
mutable std::string buffer_str;
|
||||
public:
|
||||
Exception() {
|
||||
buffer = std::make_shared<std::ostringstream>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Exception& operator<< (const T& x) {
|
||||
*buffer << x;
|
||||
|
@ -46,6 +46,10 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
// precisions can be different
|
||||
ngraph::Node& quantizationLayer = *subgraph.quantizationLayers[0];
|
||||
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer.shared_from_this());
|
||||
if (!NetworkHelper::isQuantizeSupported(fq)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DataPrecision dataPrecision = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
|
||||
if (dataPrecision.precision == ngraph::element::undefined) {
|
||||
return false;
|
||||
|
@ -64,6 +64,10 @@ bool ConcatMultiChannelsTransformation::transform(TransformationContext& context
|
||||
{
|
||||
for (auto quantizationLayer : subgraph.quantizationLayers) {
|
||||
std::shared_ptr<ngraph::opset1::FakeQuantize> fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(quantizationLayer->shared_from_this());
|
||||
if (!NetworkHelper::isQuantizeSupported(fq)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const DataPrecision tmp = getDataPrecision(fq, QuantizationDetails::getDetails(fq), false);
|
||||
|
||||
if (dataPrecision.precision == ngraph::element::undefined) {
|
||||
|
@ -610,6 +610,52 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
ngraph::element::f32,
|
||||
{ {element::f32}, {}, { 0.01f } },
|
||||
}
|
||||
},
|
||||
// unexpected quantization levels, concat
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
false,
|
||||
{
|
||||
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
|
||||
{},
|
||||
{},
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
|
||||
{},
|
||||
{},
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
{},
|
||||
ngraph::element::f32,
|
||||
{},
|
||||
}
|
||||
},
|
||||
// unexpected quantization levels, concat multi channels
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
true,
|
||||
{
|
||||
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
|
||||
{},
|
||||
{},
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 16ul, {}, {0.f}, {1.5f}, {0.f}, {15.f} },
|
||||
{},
|
||||
{},
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
{},
|
||||
ngraph::element::f32,
|
||||
{},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -36,7 +36,12 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
{
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
|
||||
}
|
||||
},
|
||||
// FQ with unexpected quantizationLevels
|
||||
{
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
|
||||
},
|
||||
};
|
||||
|
||||
const std::vector<ngraph::Shape> shapes = {
|
||||
|
@ -31,7 +31,12 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
{
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
|
||||
}
|
||||
},
|
||||
// FQ with unexpected quantizationLevels
|
||||
{
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} },
|
||||
{ 16ul, ngraph::Shape({}), {0.f}, {15.f}, {0.f}, {1.5f} }
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_LPT, ConcatTransformation,
|
||||
|
@ -72,9 +72,15 @@ void ConcatTransformation::validate() {
|
||||
const auto transformed = transformNGraph(params, getLowPrecisionTransformationsNGraph(params));
|
||||
|
||||
const auto output = transformed->get_output_op(0);
|
||||
const auto scaleShift = output->get_input_node_shared_ptr(0);
|
||||
const std::string typeName = scaleShift->get_type_name();
|
||||
ASSERT_EQ("ScaleShiftIE", typeName);
|
||||
const auto previousLayer = output->get_input_node_shared_ptr(0);
|
||||
const std::string typeName = previousLayer->get_type_name();
|
||||
|
||||
if (testValues.fqOnData1.quantizationLevel != 256ul ||
|
||||
testValues.fqOnData2.quantizationLevel != 256ul) {
|
||||
ASSERT_EQ("Concat", typeName);
|
||||
} else {
|
||||
ASSERT_EQ("ScaleShiftIE", typeName);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(ConcatTransformation, CompareWithRefImpl) {
|
||||
|
Loading…
Reference in New Issue
Block a user