[LPT] Checks to not transform layers with incorrect zero points (#4764)

* [LPT] Checks to not transform layers with incorrect zero points

* [LPT] Fold not transformed weights

* [LPT] Minor fixes; review from #5313
This commit is contained in:
Vladimir Zinoviev 2021-05-04 16:02:27 +03:00 committed by GitHub
parent 895b605c06
commit 866515184c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 244 additions and 25 deletions

View File

@ -21,6 +21,7 @@
#include "transformations/utils/utils.hpp"
#include "common/fake_quantize_dequantization.hpp"
#include "common/ie_lpt_exception.hpp"
#include "layer_transformation.hpp"
namespace ngraph {
namespace pass {
@ -177,6 +178,7 @@ public:
static FakeQuantizeDequantizationValues createEmptyValues(const FakeQuantizeDequantization& dequantization);
static bool isZeroConst(const std::shared_ptr<Node>& node);
static bool checkZeroPoint(const std::shared_ptr<Node>& node, const DataPrecision& dataPrecision = DataPrecision());
static std::shared_ptr<Node> toScalarIfPossible(std::shared_ptr<Node> node);

View File

@ -17,6 +17,7 @@ class TRANSFORMATIONS_API WeightableLayerTransformation : public LayerTransforma
public:
WeightableLayerTransformation(const Params& params);
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool canConvolutionBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const;
bool isQuantized(std::shared_ptr<Node> layer, bool reshapeIsRequired) const noexcept;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

View File

@ -36,28 +36,17 @@ bool ConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) const n
return WeightableLayerTransformation::isQuantized(layer, false);
}
bool ConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
auto convolution = m.get_match_root();
if (!WeightableLayerTransformation::canBeTransformed(context, convolution)) {
return false;
}
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(convolution);
if (!canSubtractBeHandled(convolution, dequantization)) {
return false;
}
if ((!supportAsymmetricQuantization) && getDataPrecisionOnWeights(convolution).hasZeroPoint) {
return false;
}
if (updatePrecisions && !dequantization.empty() && !dequantization.isLowPrecision()) {
if (!canConvolutionBeTransformed(context, convolution)) {
return false;
}
convolution = NetworkHelper::separateInStandaloneBranch(convolution);
dequantization = NetworkHelper::getDequantization(convolution);
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(convolution);
{
std::shared_ptr<opset1::Subtract> subtract;
@ -177,7 +166,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
std::shared_ptr<opset1::Reshape> reshapeFromWeights = as_type_ptr<opset1::Reshape>(convolution->input_value(1).get_node_shared_ptr());
const auto dequantization = reshapeFromWeights == nullptr ?
dequantization = reshapeFromWeights == nullptr ?
NetworkHelper::getDequantization(convolution, 1ul) :
NetworkHelper::getDequantization(reshapeFromWeights);
assert(!dequantization.empty());
@ -292,7 +281,6 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
}
return true;
}
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@ -18,12 +18,6 @@ GroupConvolutionTransformation::GroupConvolutionTransformation(const Params& par
}
void GroupConvolutionTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
    // NOTE(review): a composite pattern (GroupConvolution fed by Multiply + FakeQuantize)
    // registered via addPattern did not match here — open question to nGraph — so the
    // transformation matches on the GroupConvolution node alone and validates the
    // surrounding dequantization inside canBeTransformed/transform instead.
    addSingleNodePattern<opset1::GroupConvolution>(pass, context);
}

View File

@ -188,6 +188,10 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
return false;
}
}
if (!NetworkHelper::checkZeroPoint(dequantization1.subtract)) {
return false;
}
}
const auto dequantization2 = NetworkHelper::getDequantization(layer, 1);

View File

@ -19,6 +19,7 @@
#include <ngraph/rt_info.hpp>
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/common/dequantization_op.hpp"
#include "low_precision/layer_transformation.hpp"
namespace ngraph {
namespace pass {
@ -1540,6 +1541,62 @@ std::shared_ptr<Node> NetworkHelper::toScalarIfPossible(std::shared_ptr<Node> no
return NetworkHelper::toScalar(constant);
}
bool NetworkHelper::checkZeroPoint(const std::shared_ptr<Node>& node, const DataPrecision& dataPrecision) {
if (!node) {
return true;
}
float min, max;
if (is_type<opset1::Subtract>(node)) {
const auto parent = node->get_input_node_shared_ptr(0);
const auto intNode = is_type<opset1::Convert>(parent) ? parent : node;
const auto intType = intNode->get_input_element_type(0);
if (intType == element::u8 || intType == element::i8) {
min = DataPrecision::getMinValue(intType, 256) - 0.5f;
max = DataPrecision::getMaxValue(intType, 256) + 0.5f;
} else {
return false;
}
auto subtract1input = node->get_input_node_shared_ptr(1);
if (is_type<opset1::Convert>(subtract1input)) {
return true;
}
auto subtractConst = as_type_ptr<opset1::Constant>(subtract1input);
if (!subtractConst) {
subtractConst = as_type_ptr<opset1::Constant>(node->get_input_node_shared_ptr(1)->get_input_node_shared_ptr(0));
if (subtractConst == nullptr) {
return false;
}
}
const auto subtractValues = subtractConst->cast_vector<float>();
if (std::any_of(subtractValues.begin(), subtractValues.end(), [min, max] (const float& val) {
return (val < min) || (val > max); })) {
return false;
}
} else if (is_type<opset1::FakeQuantize>(node)) {
if (!dataPrecision.hasZeroPoint) {
return true;
}
min = dataPrecision.min - 0.5f;
max = dataPrecision.max + 0.5f;
const auto quantizationDetails = QuantizationDetails::getDetails(as_type_ptr<opset1::FakeQuantize>(node));
for (size_t i = 0; i < quantizationDetails.outputIntervalsCount; ++i) {
float shift;
if (quantizationDetails.outputHighValues[i] != quantizationDetails.outputLowValues[i]) {
shift = (dataPrecision.min * quantizationDetails.outputHighValues[i] -
dataPrecision.max * quantizationDetails.outputLowValues[i]) /
(quantizationDetails.outputHighValues[i] - quantizationDetails.outputLowValues[i]);
} else {
shift = 0.f;
}
if (shift < min || shift > max) {
return false;
}
}
}
return true;
}
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@ -16,6 +16,61 @@ namespace low_precision {
WeightableLayerTransformation::WeightableLayerTransformation(const Params& params) : LayerTransformation(params) {}
// Extends the generic weightable-layer check with convolution-specific zero-point
// validation on both activations and weights. May mutate the graph as a side effect:
// when the zero point on weights is invalid, the weights subgraph is constant-folded
// in place before returning false, so the invalid dequantization does not linger.
bool WeightableLayerTransformation::canConvolutionBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
    // Base preconditions shared by all weightable layers.
    if (!WeightableLayerTransformation::canBeTransformed(context, layer)) {
        return false;
    }
    // Dequantization on activations (data input 0).
    FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(layer);
    if (!canSubtractBeHandled(layer, dequantization)) {
        return false;
    }
    // Zero point on activations must be representable (only matters when precisions are updated).
    if (updatePrecisions && !NetworkHelper::checkZeroPoint(dequantization.subtract)) {
        return false;
    }
    if (updatePrecisions && !dequantization.empty() && !dequantization.isLowPrecision()) {
        return false;
    }
    // Weights branch: the dequantization may sit behind a Reshape on input 1.
    std::shared_ptr<opset1::Reshape> reshapeFromWeights = as_type_ptr<opset1::Reshape>(layer->get_input_node_shared_ptr(1));
    dequantization = reshapeFromWeights == nullptr ?
        NetworkHelper::getDequantization(layer, 1ul) :
        NetworkHelper::getDequantization(reshapeFromWeights);
    if (dequantization.empty()) {
        // No dequantization on weights: validate the FakeQuantize on weights instead.
        const auto fqOnWeights = getFakeQuantizeOnWeights(layer);
        const auto dataPrecision = getDataPrecisionOnWeights(layer);
        if ((!supportAsymmetricQuantization) && dataPrecision.hasZeroPoint) {
            return false;
        }
        if (!NetworkHelper::checkZeroPoint(fqOnWeights, dataPrecision)) {
            // Invalid zero point: fold the FakeQuantize into a constant so the
            // untransformed weights are at least materialized, then refuse.
            const std::shared_ptr<ngraph::Node> resultConstant = NetworkHelper::fold_fake_quantize(fqOnWeights);
            if (as_type_ptr<opset1::Constant>(resultConstant)) {
                replace_node(fqOnWeights, resultConstant);
            }
            return false;
        }
    } else {
        // Dequantization present on weights: its Subtract carries the zero point.
        if (!NetworkHelper::checkZeroPoint(dequantization.subtract)) {
            // Invalid zero point: fold the weights dequantization (and the Reshape
            // above it, if folding the dequantization alone did nothing) before refusing.
            const auto resultDequantization = NetworkHelper::foldDequantization(dequantization.multiply, 0, true);
            if (resultDequantization.empty() && reshapeFromWeights) {
                const auto foldedReshape = fold<opset1::Reshape>(
                    reshapeFromWeights->get_input_node_shared_ptr(0),
                    reshapeFromWeights->get_input_node_shared_ptr(1),
                    reshapeFromWeights->get_special_zero());
                if (is_type<opset1::Constant>(foldedReshape)) {
                    replace_node(reshapeFromWeights, foldedReshape);
                }
            }
            return false;
        }
    }
    return true;
}
bool WeightableLayerTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
if (!LayerTransformation::canBeTransformed(context, layer)) {
return false;

View File

@ -404,7 +404,83 @@ const std::vector<ConvolutionQDqTransformationTestValues> testValues = {
ngraph::element::f32,
{{}, {}, {{ 0.0006f }, ngraph::element::f32, { 1, 1, 1, 1 }}}
}
}
},
// incorrect zero point on activations [not transformed]
{
LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {1000.f}, element::f32, {}, false },
{ {0.02f}, element::f32, {}, false }
},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false },
{ {0.03f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8},
{},
ngraph::element::f32,
{}
},
// ExpectedValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {1000.f}, element::f32, {}, false },
{ {0.02f}, element::f32, {}, false }
},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false },
{ {0.03f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8},
{},
ngraph::element::f32,
{}
}
},
// incorrect zero point on weights [not transformed, weights folded]
{
LayerTransformation::createParamsU8I8().setSupportAsymmetricQuantization(true),
// ActualValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{
{ ngraph::element::f32, false },
{ {1000.f}, element::f32, {}, false },
{ {0.03f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8},
{},
ngraph::element::f32,
{}
},
// ExpectedValues
{
ngraph::element::u8,
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{},
{ std::vector<float>{ -29.94f }, ngraph::element::f32},
{},
ngraph::element::f32,
{}
}
},
};
INSTANTIATE_TEST_CASE_P(

View File

@ -382,6 +382,46 @@ const std::vector<ConvolutionTransformationTestValues> testValues = {
{{}, {}, {{ 0.0002f }, ngraph::element::f32, { 1, 1, 1, 1 }}}
}
},
// incorrect zero point on activations [not transformed]
{
LayerTransformation::createParamsU8I8(),
// ActualValues
{
ngraph::element::u8,
{{element::f32}, { 1000.f }, { {0.02f}, element::f32 }},
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
{ 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
},
// ExpectedValues
{
ngraph::element::u8,
{{element::f32}, { 1000.f }, { {0.02f}, element::f32 }},
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
{ 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } },
ngraph::element::f32,
{}
}
},
// incorrect zero point on weights [not transformed, weights folded]
{
LayerTransformation::createParamsU8I8(),
// ActualValues
{
ngraph::element::u8,
{{element::f32}, {}, { {0.02f}, element::f32 }},
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 0.f }),
{ 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { 5.f }, { 6.f } }
},
// ExpectedValues
{
ngraph::element::u8,
{{element::f32}, {}, { {0.02f}, element::f32 }},
op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 5.f }),
{},
ngraph::element::f32,
{}
}
},
};
INSTANTIATE_TEST_CASE_P(

View File

@ -244,7 +244,9 @@ std::shared_ptr<ngraph::Function> ConvolutionFunction::getReference(
const auto convertedWeights = convertedOutput[0].get_node_shared_ptr();
std::shared_ptr<ngraph::Node> onWeights = fakeQuantizeOnWeights.empty() ?
std::dynamic_pointer_cast<ngraph::Node>(weights) :
(weights->get_output_element_type(0).is_real() ?
convertedWeights :
std::dynamic_pointer_cast<ngraph::Node>(weights)) :
ngraph::builder::makeFakeQuantize(
convertedWeights->output(0),
netPrecision,