Cleanup & refactoring (#15588)

Vladislav Golubev 2023-02-13 10:50:46 +01:00 committed by GitHub
parent 1c20005b2f
commit b8a7b3bb43
4 changed files with 6 additions and 81 deletions

View File

@@ -132,15 +132,6 @@ public:
         const ngraph::Output<ngraph::Node>& parent,
         const ngraph::Output<ngraph::Node>& subtract_constant);
-    static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
-        std::shared_ptr<opset1::FakeQuantize> fq,
-        element::Type precision,
-        float min,
-        float max,
-        const bool hasZeroPoint,
-        const bool updatePrecision,
-        const element::Type deqPrecision = element::f32);
     static bool areQuantizeAndDequantizeSupportedForSubtract(const std::shared_ptr<const ngraph::Node>& node,
         const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);
@@ -259,7 +250,7 @@ public:
     static ov::Output<ov::Node> getSingleConsumerConstant(const ov::Output<ov::Node>& output);
-    static bool checkConstantOnInf(const std::shared_ptr<Node> constant_node);
+    static bool checkConstantNotInf(const std::shared_ptr<Node> constant_node);
 private:
     static std::shared_ptr<Node> foldFakeQuantize(

View File

@@ -189,8 +189,8 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
         auto newMultiplyFullPathValues = fold<opset1::Divide>(multiplyFullPathValues, multiplyEmptyPathValues);
         // Transformation can't be applied if new full path values break accuracy because of Inf values
-        if (!NetworkHelper::checkConstantOnInf(newSubtractFullPathValues) ||
-            !NetworkHelper::checkConstantOnInf(newMultiplyFullPathValues)) {
+        if (!NetworkHelper::checkConstantNotInf(newSubtractFullPathValues) ||
+            !NetworkHelper::checkConstantNotInf(newMultiplyFullPathValues)) {
             return false;
         }
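For context on the rename: the helper is used as a guard and returns true only when the folded constant is free of Inf values, so the new name checkConstantNotInf states what a true result means. A minimal standalone sketch of that intent, assuming a plain std::vector<float> in place of the opset1::Constant payload (the hypothetical constant_not_inf below is not the actual NetworkHelper implementation):

    #include <cmath>
    #include <vector>

    // Hypothetical stand-in for NetworkHelper::checkConstantNotInf:
    // returns true only when every folded value is finite, so callers can
    // keep the transformation; any Inf makes the guard fail.
    bool constant_not_inf(const std::vector<float>& values) {
        for (const float v : values) {
            if (std::isinf(v)) {
                return false;
            }
        }
        return true;
    }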

View File

@@ -179,8 +179,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
             inputLowConst_f32 = fold<opset1::Divide>(inputLowConst_f32, value);
             inputHighConst_f32 = fold<opset1::Divide>(inputHighConst_f32, value);
-            if (!NetworkHelper::checkConstantOnInf(inputLowConst_f32) ||
-                !NetworkHelper::checkConstantOnInf(inputHighConst_f32)) {
+            if (!NetworkHelper::checkConstantNotInf(inputLowConst_f32) ||
+                !NetworkHelper::checkConstantNotInf(inputHighConst_f32)) {
                 return nullptr;
             }
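The FakeQuantize fusion path above hits the same hazard: the inputLow/inputHigh constants are divided by the fused value, so a zero in that constant folds straight to Inf and fusion has to be rejected. A tiny self-contained illustration with made-up numbers (IEEE-754 float semantics assumed; this is not OpenVINO code):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float input_low = -1.28f, input_high = 1.27f;
        const float fused_value = 0.0f;  // degenerate divisor, chosen only for the example
        const float new_low = input_low / fused_value;    // -inf under IEEE-754
        const float new_high = input_high / fused_value;  // +inf under IEEE-754
        // Mirrors the "return nullptr" guard: fusion is allowed only while both bounds stay finite.
        const bool fusion_allowed = !std::isinf(new_low) && !std::isinf(new_high);
        std::printf("fusion allowed: %s\n", fusion_allowed ? "yes" : "no");
        return 0;
    }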

View File

@@ -1200,72 +1200,6 @@ std::shared_ptr<ov::Node> NetworkHelper::makeDequantizationSubtract(
         : std::make_shared<opset1::Subtract>(parent, subtract_constant);
 }
-FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
-    std::shared_ptr<opset1::FakeQuantize> fq,
-    element::Type precision,
-    float min,
-    float max,
-    const bool hasZeroPoint,
-    const bool updatePrecision,
-    const element::Type deqPrecision) {
-    const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0);
-    auto newMin = std::make_shared<opset1::Constant>(fqPrecision, Shape{}, min);
-    auto newMax = std::make_shared<opset1::Constant>(fqPrecision, Shape{}, max);
-    auto outputLow = fq->input_value(3);
-    auto outputHigh = fq->input_value(4);
-    // TODO: threshold values have to be used here to avoid shifts
-    const std::shared_ptr<opset1::Constant> scale = ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>(
-        fold<opset1::Subtract>(outputHigh, outputLow),
-        fold<opset1::Subtract>(newMax->output(0), newMin->output(0))), deqPrecision));
-    assert(scale != nullptr);
-    std::shared_ptr<opset1::Constant> shift = hasZeroPoint ?
-        ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>(
-            fold<opset1::Subtract>(fold<opset1::Multiply>(newMin->output(0), outputHigh), fold<opset1::Multiply>(newMax->output(0), outputLow)),
-            fold<opset1::Subtract>(outputHigh, outputLow)), deqPrecision)) :
-        nullptr;
-    assert((!hasZeroPoint) || (hasZeroPoint && shift != nullptr));
-    if (shift != nullptr) {
-        std::shared_ptr<opset1::Constant> shiftConst = ov::as_type_ptr<opset1::Constant>(shift);
-        if (isScalarLike(shiftConst)) {
-            auto scalar = toScalar(shiftConst);
-            if (ov::op::util::constantIsEqualTo(scalar, 0)) {
-                shift = nullptr;
-            }
-        }
-    }
-    const auto input = std::make_shared<ngraph::opset1::Parameter>(
-        updatePrecision ? precision : fq->get_output_element_type(0),
-        fq->get_output_partial_shape(0));
-    std::shared_ptr<ngraph::Node> parent = input;
-    std::shared_ptr<ngraph::opset1::Convert> convert;
-    if (updatePrecision || (parent->output(0).get_element_type() != deqPrecision)) {
-        convert = std::make_shared<opset1::Convert>(parent, deqPrecision);
-        parent = convert;
-    } else {
-        convert = nullptr;
-    }
-    std::shared_ptr<ngraph::opset1::Subtract> subtract;
-    if (shift != nullptr) {
-        subtract = std::make_shared<ov::op::TypeRelaxed<opset1::Subtract>>(parent, shift);
-        subtract->set_output_type(0, deqPrecision, subtract->get_output_partial_shape(0));
-        parent = subtract;
-    } else {
-        subtract = nullptr;
-    }
-    const std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<opset1::Multiply>(parent, scale);
-    multiply->set_output_type(0, fq->get_output_element_type(0), multiply->get_output_partial_shape(0));
-    return FakeQuantizeDequantization(fq, convert, subtract, nullptr, shift, multiply, scale);
-}
 bool NetworkHelper::areQuantizeAndDequantizeSupportedForSubtract(const std::shared_ptr<const ngraph::Node>& node,
     const std::vector<ngraph::element::Type>& defaultPrecisions) {
     if (!ov::is_type<opset1::Subtract>(node)) {
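For reference, the deleted createDequantizationFromFakeQuantize derived the dequantization constants from the FakeQuantize output interval as scale = (outputHigh - outputLow) / (max - min) and, when a zero point is needed, shift = (min * outputHigh - max * outputLow) / (outputHigh - outputLow), then rebuilt the value as (x - shift) * scale. A small standalone sketch of just that arithmetic, with made-up u8-style numbers (plain floats, no OpenVINO types):

    #include <cstdio>

    int main() {
        // Made-up example: quantized interval [0, 255] mapped back to [-1.28, 1.27].
        const float min = 0.0f, max = 255.0f;                   // low-precision (quantized) interval
        const float output_low = -1.28f, output_high = 1.27f;   // FakeQuantize output interval
        const float scale = (output_high - output_low) / (max - min);
        const float shift = (min * output_high - max * output_low) / (output_high - output_low);
        // Dequantization as built by the removed helper: subtract the shift, then multiply by the scale.
        const float q = 128.0f;  // a quantized value
        std::printf("scale=%g shift=%g dequantized(128)=%g\n", scale, shift, (q - shift) * scale);
        return 0;
    }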
@@ -2011,7 +1945,7 @@ ov::Output<ov::Node> NetworkHelper::getSingleConsumerConstant(const ov::Output<o
         : node->clone_with_new_inputs(node->input_values())->output(0);
 }
-bool NetworkHelper::checkConstantOnInf(const std::shared_ptr<Node> constant_node) {
+bool NetworkHelper::checkConstantNotInf(const std::shared_ptr<Node> constant_node) {
     const auto constant = ov::as_type_ptr<opset1::Constant>(constant_node);
     if (constant == nullptr)
         return false;