[LPT] Some LP Transformations improvements (#6434)

* [LPT] LayerTransformation::canBeTransformed: replaced legacy code

* [LPT] NetworkHelper::moveDequantizationAfter refactoring

* [LPT] ReshapeTransformation improvement

* [LPT] Squeeze/UnsqueezeTransformation improvement
This commit is contained in:
Vladislav Golubev 2021-08-11 10:10:56 +03:00 committed by GitHub
parent 0834ae2e6d
commit 9d9a24c914
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 10 deletions

View File

@ -1560,14 +1560,14 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
if (updatePrecision) {
op->set_overridden_output_type(newOperation->get_input_element_type(0));
} else if (dequantization.multiply) {
op->set_overridden_output_type(dequantization.multiply->get_input_element_type(1));
op->set_overridden_output_type(dequantization.multiplyConstant->get_element_type());
} else if (dequantization.subtract) {
op->set_overridden_output_type(dequantization.subtract->get_input_element_type(1));
op->set_overridden_output_type(dequantization.subtractConstant->get_element_type());
}
std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
}
const element::Type deqPrecision = dequantization.multiply->get_input_node_shared_ptr(1)->get_output_element_type(0);
const element::Type deqPrecision = dequantization.multiplyConstant->get_element_type();
const bool shouldConvert = (newOperation->get_output_element_type(0) != deqPrecision);
auto parent = newOperation;
@ -1582,11 +1582,11 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
if (moveSubtract && (dequantization.subtract != nullptr)) {
if (dequantization.subtractConvert == nullptr) {
const element::Type parentPrecision = parent->get_output_element_type(0);
if (parentPrecision.bitwidth() < dequantization.subtractConstant->output(0).get_element_type().bitwidth()) {
if (parentPrecision.bitwidth() < dequantization.subtractConstant->get_element_type().bitwidth()) {
THROW_IE_LPT_EXCEPTION(*parent) <<
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
", subtract dequantization constant " << dequantization.subtractConstant->get_friendly_name() << ":" <<
dequantization.subtractConstant->output(0).get_element_type();
dequantization.subtractConstant->get_element_type();
}
parent = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
@ -1604,12 +1604,12 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
}
if (dequantization.multiply != nullptr) {
auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
auto multiplyConstant = dequantization.multiplyConstant;
const element::Type parentPrecision = parent->get_output_element_type(0);
if (parentPrecision.bitwidth() < multiplyConstant->output(0).get_element_type().bitwidth()) {
if (parentPrecision.bitwidth() < multiplyConstant->get_element_type().bitwidth()) {
THROW_IE_LPT_EXCEPTION(*parent) <<
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->output(0).get_element_type();
", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->get_element_type();
}
parent = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(

View File

@ -42,9 +42,14 @@ bool SqueezeTransformation::transform(TransformationContext& context, ngraph::pa
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
const ngraph::PartialShape& inputShape) {
const size_t inputRankValue = inputShape.rank().get_length();
if (dequantizationOpConstant->get_shape().size() == inputRankValue) {
const auto constantShape = dequantizationOpConstant->get_shape();
if (shape_size(constantShape) == 1ul) {
return NetworkHelper::toScalar(dequantizationOpConstant);
}
if (constantShape.size() == inputRankValue) {
return as_type_ptr<opset1::Constant>(fold<opset1::Squeeze>(dequantizationOpConstant, squeeze->get_input_node_shared_ptr(1)));
}
return dequantizationOpConstant;
};

View File

@ -42,12 +42,19 @@ bool UnsqueezeTransformation::transform(TransformationContext& context, ngraph::
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
const ngraph::PartialShape& inputShape) {
const size_t inputRankValue = inputShape.rank().get_length();
if (dequantizationOpConstant->get_shape().size() == inputRankValue) {
const auto constantShape = dequantizationOpConstant->get_shape();
if (shape_size(constantShape) == 1ul) {
return NetworkHelper::toScalar(dequantizationOpConstant);
}
if (constantShape.size() == inputRankValue) {
return as_type_ptr<opset1::Constant>(fold<opset1::Unsqueeze>(dequantizationOpConstant, unsqueeze->get_input_node_shared_ptr(1)));
}
return dequantizationOpConstant;
};
const std::shared_ptr<Node> unsqueeze = NetworkHelper::separateInStandaloneBranch(m.get_match_root());
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(unsqueeze);