Cleanup & refactoring (#15588)
This commit is contained in:
parent
1c20005b2f
commit
b8a7b3bb43
@ -132,15 +132,6 @@ public:
|
||||
const ngraph::Output<ngraph::Node>& parent,
|
||||
const ngraph::Output<ngraph::Node>& subtract_constant);
|
||||
|
||||
static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
|
||||
std::shared_ptr<opset1::FakeQuantize> fq,
|
||||
element::Type precision,
|
||||
float min,
|
||||
float max,
|
||||
const bool hasZeroPoint,
|
||||
const bool updatePrecision,
|
||||
const element::Type deqPrecision = element::f32);
|
||||
|
||||
static bool areQuantizeAndDequantizeSupportedForSubtract(const std::shared_ptr<const ngraph::Node>& node,
|
||||
const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support);
|
||||
|
||||
@ -259,7 +250,7 @@ public:
|
||||
|
||||
static ov::Output<ov::Node> getSingleConsumerConstant(const ov::Output<ov::Node>& output);
|
||||
|
||||
static bool checkConstantOnInf(const std::shared_ptr<Node> constant_node);
|
||||
static bool checkConstantNotInf(const std::shared_ptr<Node> constant_node);
|
||||
|
||||
private:
|
||||
static std::shared_ptr<Node> foldFakeQuantize(
|
||||
|
@ -189,8 +189,8 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
|
||||
auto newMultiplyFullPathValues = fold<opset1::Divide>(multiplyFullPathValues, multiplyEmptyPathValues);
|
||||
|
||||
// Transformation can't be applied if new full path values brake accuracy because of Inf values
|
||||
if (!NetworkHelper::checkConstantOnInf(newSubtractFullPathValues) ||
|
||||
!NetworkHelper::checkConstantOnInf(newMultiplyFullPathValues)) {
|
||||
if (!NetworkHelper::checkConstantNotInf(newSubtractFullPathValues) ||
|
||||
!NetworkHelper::checkConstantNotInf(newMultiplyFullPathValues)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -179,8 +179,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
||||
|
||||
inputLowConst_f32 = fold<opset1::Divide>(inputLowConst_f32, value);
|
||||
inputHighConst_f32 = fold<opset1::Divide>(inputHighConst_f32, value);
|
||||
if (!NetworkHelper::checkConstantOnInf(inputLowConst_f32) ||
|
||||
!NetworkHelper::checkConstantOnInf(inputHighConst_f32)) {
|
||||
if (!NetworkHelper::checkConstantNotInf(inputLowConst_f32) ||
|
||||
!NetworkHelper::checkConstantNotInf(inputHighConst_f32)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -1200,72 +1200,6 @@ std::shared_ptr<ov::Node> NetworkHelper::makeDequantizationSubtract(
|
||||
: std::make_shared<opset1::Subtract>(parent, subtract_constant);
|
||||
}
|
||||
|
||||
FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
|
||||
std::shared_ptr<opset1::FakeQuantize> fq,
|
||||
element::Type precision,
|
||||
float min,
|
||||
float max,
|
||||
const bool hasZeroPoint,
|
||||
const bool updatePrecision,
|
||||
const element::Type deqPrecision) {
|
||||
const ngraph::element::Type_t fqPrecision = fq->get_output_element_type(0);
|
||||
auto newMin = std::make_shared<opset1::Constant>(fqPrecision, Shape{}, min);
|
||||
auto newMax = std::make_shared<opset1::Constant>(fqPrecision, Shape{}, max);
|
||||
|
||||
auto outputLow = fq->input_value(3);
|
||||
auto outputHigh = fq->input_value(4);
|
||||
|
||||
// TODO: threshold values have to used here to avoid shifts
|
||||
|
||||
const std::shared_ptr<opset1::Constant> scale = ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>(
|
||||
fold<opset1::Subtract>(outputHigh, outputLow),
|
||||
fold<opset1::Subtract>(newMax->output(0), newMin->output(0))), deqPrecision));
|
||||
assert(scale != nullptr);
|
||||
|
||||
std::shared_ptr<opset1::Constant> shift = hasZeroPoint ?
|
||||
ov::as_type_ptr<opset1::Constant>(foldConvert(fold<opset1::Divide>(
|
||||
fold<opset1::Subtract>(fold<opset1::Multiply>(newMin->output(0), outputHigh), fold<opset1::Multiply>(newMax->output(0), outputLow)),
|
||||
fold<opset1::Subtract>(outputHigh, outputLow)), deqPrecision)) :
|
||||
nullptr;
|
||||
assert((!hasZeroPoint) || (hasZeroPoint && shift != nullptr));
|
||||
|
||||
if (shift != nullptr) {
|
||||
std::shared_ptr<opset1::Constant> shiftConst = ov::as_type_ptr<opset1::Constant>(shift);
|
||||
if (isScalarLike(shiftConst)) {
|
||||
auto scalar = toScalar(shiftConst);
|
||||
if (ov::op::util::constantIsEqualTo(scalar, 0)) {
|
||||
shift = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(
|
||||
updatePrecision ? precision : fq->get_output_element_type(0),
|
||||
fq->get_output_partial_shape(0));
|
||||
std::shared_ptr<ngraph::Node> parent = input;
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Convert> convert;
|
||||
if (updatePrecision || (parent->output(0).get_element_type() != deqPrecision)) {
|
||||
convert = std::make_shared<opset1::Convert>(parent, deqPrecision);
|
||||
parent = convert;
|
||||
} else {
|
||||
convert = nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Subtract> subtract;
|
||||
if (shift != nullptr) {
|
||||
subtract = std::make_shared<ov::op::TypeRelaxed<opset1::Subtract>>(parent, shift);
|
||||
subtract->set_output_type(0, deqPrecision, subtract->get_output_partial_shape(0));
|
||||
parent = subtract;
|
||||
} else {
|
||||
subtract = nullptr;
|
||||
}
|
||||
const std::shared_ptr<ngraph::opset1::Multiply> multiply = std::make_shared<opset1::Multiply>(parent, scale);
|
||||
multiply->set_output_type(0, fq->get_output_element_type(0), multiply->get_output_partial_shape(0));
|
||||
|
||||
return FakeQuantizeDequantization(fq, convert, subtract, nullptr, shift, multiply, scale);
|
||||
}
|
||||
|
||||
bool NetworkHelper::areQuantizeAndDequantizeSupportedForSubtract(const std::shared_ptr<const ngraph::Node>& node,
|
||||
const std::vector<ngraph::element::Type>& defaultPrecisions) {
|
||||
if (!ov::is_type<opset1::Subtract>(node)) {
|
||||
@ -2011,7 +1945,7 @@ ov::Output<ov::Node> NetworkHelper::getSingleConsumerConstant(const ov::Output<o
|
||||
: node->clone_with_new_inputs(node->input_values())->output(0);
|
||||
}
|
||||
|
||||
bool NetworkHelper::checkConstantOnInf(const std::shared_ptr<Node> constant_node) {
|
||||
bool NetworkHelper::checkConstantNotInf(const std::shared_ptr<Node> constant_node) {
|
||||
const auto constant = ov::as_type_ptr<opset1::Constant>(constant_node);
|
||||
if (constant == nullptr)
|
||||
return false;
|
||||
|
Loading…
Reference in New Issue
Block a user