[LPT] Handle empty dequantization in MultiplyToGroupConvolution (#3818)
Add const
This commit is contained in:
parent
f88840d500
commit
885a493336
@ -123,6 +123,12 @@ bool MultiplyToGroupConvolutionTransformation::canBeTransformed(const Transforma
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto dequantization = NetworkHelper::getDequantization(operation, inputIndex);
|
||||
|
||||
if (dequantization.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Shape outShape = operation->get_output_shape(0);
|
||||
if (outShape[1] % groupSize != 0) {
|
||||
return false;
|
||||
@ -135,15 +141,12 @@ bool MultiplyToGroupConvolutionTransformation::canBeTransformed(const Transforma
|
||||
}
|
||||
|
||||
if (updatePrecisions) {
|
||||
auto dequantization = NetworkHelper::getDequantization(operation, inputIndex);
|
||||
const element::Type parentPrecision = dequantization.data.get_element_type();
|
||||
if (std::find(precisionsOnActivations.begin(), precisionsOnActivations.end(), parentPrecision) == precisionsOnActivations.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -42,6 +42,7 @@ public:
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
bool transformed;
|
||||
bool haveMultiplyWithNoConstBeforeDequantization;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
};
|
||||
@ -69,8 +70,8 @@ public:
|
||||
actualFunction = ngraph::builder::subgraph::MultiplyToGroupConvolutionFunction::getOriginal(
|
||||
testValues.inputShape,
|
||||
testValues.actual.precisionBeforeDequantization,
|
||||
testValues.actual.dequantization);
|
||||
|
||||
testValues.actual.dequantization,
|
||||
testValues.haveMultiplyWithNoConstBeforeDequantization);
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.add<ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation, ngraph::opset1::Multiply>(testValues.params);
|
||||
transformer.transform(actualFunction);
|
||||
@ -86,7 +87,8 @@ public:
|
||||
referenceFunction = ngraph::builder::subgraph::MultiplyToGroupConvolutionFunction::getOriginal(
|
||||
testValues.inputShape,
|
||||
testValues.actual.precisionBeforeDequantization,
|
||||
testValues.actual.dequantization);
|
||||
testValues.actual.dequantization,
|
||||
testValues.haveMultiplyWithNoConstBeforeDequantization);
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,6 +99,8 @@ public:
|
||||
result <<
|
||||
testValues.inputShape << "_" <<
|
||||
testValues.actual.precisionBeforeDequantization << "_" <<
|
||||
testValues.transformed << "_" <<
|
||||
testValues.haveMultiplyWithNoConstBeforeDequantization << "_" <<
|
||||
testValues.actual.dequantization;
|
||||
return result.str();
|
||||
}
|
||||
@ -108,6 +112,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
ngraph::Shape{ 1, 4, 1, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
true,
|
||||
false,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
@ -132,6 +137,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
ngraph::Shape{ 1, 4, 1, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
true,
|
||||
false,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
@ -156,6 +162,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
ngraph::Shape{ 1, 4, 1, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
true,
|
||||
false,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
@ -180,6 +187,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
ngraph::Shape{ 1, 4, 1, 1, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
true,
|
||||
false,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
@ -204,6 +212,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
ngraph::Shape{ 1, 4, 1, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
false,
|
||||
false,
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{
|
||||
@ -219,6 +228,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
ngraph::Shape{ 1, 1, 2, 2 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
false,
|
||||
false,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
@ -234,6 +244,7 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
ngraph::Shape{ 1, 4, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
false,
|
||||
false,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
@ -244,6 +255,30 @@ const std::vector<MultiplyToGroupConvolutionTransformationTestValues> testValues
|
||||
},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::Shape{ 1, 4, 1, 1 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
false,
|
||||
true,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{},
|
||||
{},
|
||||
{{0.45f, 0.82f, 0.71f, 0.37f}}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
std::make_shared<ngraph::opset1::Constant>(ngraph::element::i8, ngraph::Shape{4, 1, 1, 1, 1}, std::vector<float>{1.f, 1.f, 1.f, 1.f}),
|
||||
nullptr,
|
||||
{
|
||||
{},
|
||||
{},
|
||||
{{0.45f, 0.82f, 0.71f, 0.37f}}
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
TEST_P(MultiplyToGroupConvolutionTransformation, CompareFunctions) {
|
||||
|
@ -20,7 +20,8 @@ public:
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::element::Type& precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization);
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const bool haveMultiplyWithNoConstBeforeDequantization);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
|
@ -15,16 +15,28 @@ namespace subgraph {
|
||||
std::shared_ptr<ngraph::Function> MultiplyToGroupConvolutionFunction::getOriginal(
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::element::Type& precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization) {
|
||||
const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const bool haveMultiplyWithNoConstBeforeDequantization) {
|
||||
std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
|
||||
precisionBeforeDequantization,
|
||||
ngraph::Shape(inputShape));
|
||||
|
||||
const auto dequantizationOp = makeDequantization(input, dequantization);
|
||||
std::shared_ptr<ngraph::op::Op> parent = input;
|
||||
std::shared_ptr<ngraph::op::Parameter> secondInput;
|
||||
if (haveMultiplyWithNoConstBeforeDequantization) {
|
||||
secondInput = std::make_shared<ngraph::opset1::Parameter>(
|
||||
precisionBeforeDequantization,
|
||||
ngraph::Shape(inputShape));
|
||||
parent = std::make_shared<ngraph::opset1::Multiply>(input, secondInput);
|
||||
}
|
||||
const auto dequantizationOp = makeDequantization(parent, dequantization);
|
||||
dequantizationOp->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOp) };
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MultiplyToGroupConvolutionFunction");
|
||||
ngraph::ParameterVector params{input};
|
||||
if (haveMultiplyWithNoConstBeforeDequantization) {
|
||||
params.push_back(secondInput);
|
||||
}
|
||||
return std::make_shared<ngraph::Function>(results, params, "MultiplyToGroupConvolutionFunction");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> MultiplyToGroupConvolutionFunction::getOriginal(
|
||||
|
Loading…
Reference in New Issue
Block a user