[LPT] Handle empty dequantization in MultiplyToGroupConvolution (#3818)

Add const
This commit is contained in:
Aleksandr Pertovsky 2021-01-28 11:23:30 +03:00 committed by GitHub
parent f88840d500
commit 885a493336
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 12 deletions

View File

@ -123,6 +123,12 @@ bool MultiplyToGroupConvolutionTransformation::canBeTransformed(const Transforma
return false;
}
const auto dequantization = NetworkHelper::getDequantization(operation, inputIndex);
if (dequantization.empty()) {
return false;
}
const Shape outShape = operation->get_output_shape(0);
if (outShape[1] % groupSize != 0) {
return false;
@ -135,15 +141,12 @@ bool MultiplyToGroupConvolutionTransformation::canBeTransformed(const Transforma
}
if (updatePrecisions) {
auto dequantization = NetworkHelper::getDequantization(operation, inputIndex);
const element::Type parentPrecision = dequantization.data.get_element_type();
if (std::find(precisionsOnActivations.begin(), precisionsOnActivations.end(), parentPrecision) == precisionsOnActivations.end()) {
return false;
}
}
return true;
}

View File

@ -42,6 +42,7 @@ public:
ngraph::Shape inputShape;
ngraph::pass::low_precision::LayerTransformation::Params params;
bool transformed;
bool haveMultiplyWithNoConstBeforeDequantization;
Actual actual;
Expected expected;
};
@ -69,8 +70,8 @@ public:
actualFunction = ngraph::builder::subgraph::MultiplyToGroupConvolutionFunction::getOriginal(
testValues.inputShape,
testValues.actual.precisionBeforeDequantization,
testValues.actual.dequantization);
testValues.actual.dequantization,
testValues.haveMultiplyWithNoConstBeforeDequantization);
SimpleLowPrecisionTransformer transformer;
transformer.add<ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation, ngraph::opset1::Multiply>(testValues.params);
transformer.transform(actualFunction);
@ -86,7 +87,8 @@ public:
referenceFunction = ngraph::builder::subgraph::MultiplyToGroupConvolutionFunction::getOriginal(
testValues.inputShape,
testValues.actual.precisionBeforeDequantization,
testValues.actual.dequantization);
testValues.actual.dequantization,
testValues.haveMultiplyWithNoConstBeforeDequantization);
}
}
@ -97,6 +99,8 @@ public:
result <<
testValues.inputShape << "_" <<
testValues.actual.precisionBeforeDequantization << "_" <<
testValues.transformed << "_" <<
testValues.haveMultiplyWithNoConstBeforeDequantization << "_" <<
testValues.actual.dequantization;
return result.str();
}
@ -108,6 +112,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
ngraph::Shape{ 1, 4, 1, 1 },
LayerTransformation::createParamsU8I8(),
true,
false,
{
ngraph::element::u8,
{
@ -132,6 +137,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
ngraph::Shape{ 1, 4, 1, 1 },
LayerTransformation::createParamsU8I8(),
true,
false,
{
ngraph::element::u8,
{
@ -156,6 +162,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
ngraph::Shape{ 1, 4, 1, 1 },
LayerTransformation::createParamsU8I8(),
true,
false,
{
ngraph::element::u8,
{
@ -180,6 +187,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
ngraph::Shape{ 1, 4, 1, 1, 1 },
LayerTransformation::createParamsU8I8(),
true,
false,
{
ngraph::element::u8,
{
@ -204,6 +212,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
ngraph::Shape{ 1, 4, 1, 1 },
LayerTransformation::createParamsU8I8(),
false,
false,
{
ngraph::element::i8,
{
@ -219,6 +228,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
ngraph::Shape{ 1, 1, 2, 2 },
LayerTransformation::createParamsU8I8(),
false,
false,
{
ngraph::element::u8,
{
@ -234,6 +244,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
ngraph::Shape{ 1, 4, 1 },
LayerTransformation::createParamsU8I8(),
false,
false,
{
ngraph::element::u8,
{
@ -244,6 +255,30 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
},
{}
},
{
ngraph::Shape{ 1, 4, 1, 1 },
LayerTransformation::createParamsU8I8(),
false,
true,
{
ngraph::element::u8,
{
{},
{},
{{0.45f, 0.82f, 0.71f, 0.37f}}
}
},
{
ngraph::element::u8,
std::make_shared<ngraph::opset1::Constant>(ngraph::element::i8, ngraph::Shape{4, 1, 1, 1, 1}, std::vector<float>{1.f, 1.f, 1.f, 1.f}),
nullptr,
{
{},
{},
{{0.45f, 0.82f, 0.71f, 0.37f}}
}
}
},
};
TEST_P(MultiplyToGroupConvolutionTransformation, CompareFunctions) {

View File

@ -20,7 +20,8 @@ public:
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::Shape& inputShape,
const ngraph::element::Type& precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization);
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const bool haveMultiplyWithNoConstBeforeDequantization);
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type precision,

View File

@ -15,16 +15,28 @@ namespace subgraph {
std::shared_ptr<ngraph::Function> MultiplyToGroupConvolutionFunction::getOriginal(
const ngraph::Shape& inputShape,
const ngraph::element::Type& precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization) {
const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const bool haveMultiplyWithNoConstBeforeDequantization) {
std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
precisionBeforeDequantization,
ngraph::Shape(inputShape));
const auto dequantizationOp = makeDequantization(input, dequantization);
std::shared_ptr<ngraph::op::Op> parent = input;
std::shared_ptr<ngraph::op::Parameter> secondInput;
if (haveMultiplyWithNoConstBeforeDequantization) {
secondInput = std::make_shared<ngraph::opset1::Parameter>(
precisionBeforeDequantization,
ngraph::Shape(inputShape));
parent = std::make_shared<ngraph::opset1::Multiply>(input, secondInput);
}
const auto dequantizationOp = makeDequantization(parent, dequantization);
dequantizationOp->set_friendly_name("output");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOp) };
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MultiplyToGroupConvolutionFunction");
ngraph::ParameterVector params{input};
if (haveMultiplyWithNoConstBeforeDequantization) {
params.push_back(secondInput);
}
return std::make_shared<ngraph::Function>(results, params, "MultiplyToGroupConvolutionFunction");
}
std::shared_ptr<ngraph::Function> MultiplyToGroupConvolutionFunction::getOriginal(