[LPT] ConvolutionTransformation with asymmetric quantization after Split fix (#5895)

This commit is contained in:
Vladislav Golubev 2021-06-01 16:58:33 +03:00 committed by GitHub
parent 28c10b1727
commit 22881aec5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 6 deletions

View File

@@ -90,14 +90,14 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
 broadcastShape[1] = subtract->get_output_shape(0)[1];
 std::shared_ptr<Node> newShift = fold<opset1::Broadcast>(
-subtract->input_value(1).get_node_shared_ptr(),
+subtract->input_value(1),
 std::make_shared<opset1::Constant>(
 element::i64,
 Shape{ length },
 broadcastShape));
 const auto newSubtract = as_type_ptr<opset1::Subtract>(subtract->clone_with_new_inputs({
-subtract->input_value(0).get_node_shared_ptr(),
+subtract->input_value(0),
 newShift }));
 NetworkHelper::copyInfo(subtract, newSubtract);
 replace_node(subtract, newSubtract);
@@ -176,7 +176,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
 if (is_type<opset1::Convert>(convolution->get_input_node_ptr(0))) {
 auto newConvolution = convolution->clone_with_new_inputs({
 convolution->get_input_node_ptr(0)->get_input_source_output(0),
-convolution->get_input_node_shared_ptr(1) });
+convolution->input_value(1)});
 replace_node(convolution, newConvolution);
 convolution = newConvolution;
 }
@@ -249,7 +249,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
 zeroPointShape[0] = weightsShape[0];
 auto zeroPointConstant = fold<opset1::Broadcast>(
-subtractFromWeights->get_input_node_shared_ptr(1),
+subtractFromWeights->input_value(1),
 std::make_shared<opset1::Constant>(element::i32, Shape{ zeroPointShape.size() }, zeroPointShape));
 NetworkHelper::copyInfo(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant);
 replace_node(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant);
@@ -266,7 +266,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
 auto newConvolution = convolution->clone_with_new_inputs({
 convolution->get_input_source_output(0),
 childNode.get() == convolution.get() ?
-convolution->get_input_node_ptr(1)->get_input_node_shared_ptr(0) :
+convolution->get_input_node_ptr(1)->input_value(0) :
 childNode->copy_with_new_inputs({convertFromWeights->input_value(0), childNode->input_value(1)})});
 replace_node(convolution, newConvolution);
 convolution = newConvolution;

View File

@@ -69,7 +69,7 @@ SimpleLowPrecisionTransformer getTransformerWithTransformationByName(
 transformer.add<ClampTransformation, opset1::Clamp>(params);
 return transformer;
 }
-if (name == "ConvolutionTransformation") {
+if (name == "ConvolutionTransformation" || name == "AsymmetricConvolutionTransformation") {
 transformer.add<ConvolutionTransformation, opset1::Convolution>(params);
 return transformer;
 }
@@ -190,6 +190,7 @@ const std::vector<std::string> transformationNames = {
 "AvgPoolTransformation",
 "ClampTransformation",
 "ConvolutionTransformation",
+"AsymmetricConvolutionTransformation",
 "DepthToSpaceTransformation",
 "FakeQuantizeTransformation",
 "InterpolateTransformation",

View File

@@ -78,6 +78,18 @@ std::shared_ptr<Node> TransformationsAfterSplitFunction::getLayerByTransformatio
 CoordinateDiff{ 0, 0 },
 Strides{ 1, 1 });
 }
+if (transformationName == "AsymmetricConvolutionTransformation") {
+const auto dequantizationOnData = makeDequantization(parent, { {element::f32}, { 128.f }, { 0.1f } });
+const auto weights = opset1::Constant::create(element::i8, Shape{ 3, 3, 1, 1 }, { 2 });
+const auto dequantizationOnWeights = makeDequantization(weights, { {element::f32}, {}, {0.3f} });
+return std::make_shared<opset1::Convolution>(
+dequantizationOnData,
+dequantizationOnWeights,
+Strides{ 1, 1 },
+CoordinateDiff{ 0, 0 },
+CoordinateDiff{ 0, 0 },
+Strides{ 1, 1 });
+}
 if (transformationName == "DepthToSpaceTransformation") {
 const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
 return std::make_shared<opset1::DepthToSpace>(dequantization, opset1::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 3);