[LPT] Some LP Transformations improvements (#6434)
* [LPT] LayerTransformation::canBeTransformed: replaced legacy code * [LPT] NetworkHelper::moveDequantizationAfter refactoring * [LPT] ReshapeTransformation improvement * [LPT] Squeeze/UnsqueezeTransformation improvement
This commit is contained in:
parent
0834ae2e6d
commit
9d9a24c914
@ -1560,14 +1560,14 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
|||||||
if (updatePrecision) {
|
if (updatePrecision) {
|
||||||
op->set_overridden_output_type(newOperation->get_input_element_type(0));
|
op->set_overridden_output_type(newOperation->get_input_element_type(0));
|
||||||
} else if (dequantization.multiply) {
|
} else if (dequantization.multiply) {
|
||||||
op->set_overridden_output_type(dequantization.multiply->get_input_element_type(1));
|
op->set_overridden_output_type(dequantization.multiplyConstant->get_element_type());
|
||||||
} else if (dequantization.subtract) {
|
} else if (dequantization.subtract) {
|
||||||
op->set_overridden_output_type(dequantization.subtract->get_input_element_type(1));
|
op->set_overridden_output_type(dequantization.subtractConstant->get_element_type());
|
||||||
}
|
}
|
||||||
std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
|
std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
const element::Type deqPrecision = dequantization.multiply->get_input_node_shared_ptr(1)->get_output_element_type(0);
|
const element::Type deqPrecision = dequantization.multiplyConstant->get_element_type();
|
||||||
const bool shouldConvert = (newOperation->get_output_element_type(0) != deqPrecision);
|
const bool shouldConvert = (newOperation->get_output_element_type(0) != deqPrecision);
|
||||||
|
|
||||||
auto parent = newOperation;
|
auto parent = newOperation;
|
||||||
@ -1582,11 +1582,11 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
|||||||
if (moveSubtract && (dequantization.subtract != nullptr)) {
|
if (moveSubtract && (dequantization.subtract != nullptr)) {
|
||||||
if (dequantization.subtractConvert == nullptr) {
|
if (dequantization.subtractConvert == nullptr) {
|
||||||
const element::Type parentPrecision = parent->get_output_element_type(0);
|
const element::Type parentPrecision = parent->get_output_element_type(0);
|
||||||
if (parentPrecision.bitwidth() < dequantization.subtractConstant->output(0).get_element_type().bitwidth()) {
|
if (parentPrecision.bitwidth() < dequantization.subtractConstant->get_element_type().bitwidth()) {
|
||||||
THROW_IE_LPT_EXCEPTION(*parent) <<
|
THROW_IE_LPT_EXCEPTION(*parent) <<
|
||||||
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
|
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
|
||||||
", subtract dequantization constant " << dequantization.subtractConstant->get_friendly_name() << ":" <<
|
", subtract dequantization constant " << dequantization.subtractConstant->get_friendly_name() << ":" <<
|
||||||
dequantization.subtractConstant->output(0).get_element_type();
|
dequantization.subtractConstant->get_element_type();
|
||||||
}
|
}
|
||||||
|
|
||||||
parent = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
|
parent = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
|
||||||
@ -1604,12 +1604,12 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (dequantization.multiply != nullptr) {
|
if (dequantization.multiply != nullptr) {
|
||||||
auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
|
auto multiplyConstant = dequantization.multiplyConstant;
|
||||||
const element::Type parentPrecision = parent->get_output_element_type(0);
|
const element::Type parentPrecision = parent->get_output_element_type(0);
|
||||||
if (parentPrecision.bitwidth() < multiplyConstant->output(0).get_element_type().bitwidth()) {
|
if (parentPrecision.bitwidth() < multiplyConstant->get_element_type().bitwidth()) {
|
||||||
THROW_IE_LPT_EXCEPTION(*parent) <<
|
THROW_IE_LPT_EXCEPTION(*parent) <<
|
||||||
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
|
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
|
||||||
", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->output(0).get_element_type();
|
", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->get_element_type();
|
||||||
}
|
}
|
||||||
|
|
||||||
parent = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
|
parent = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
|
||||||
|
@ -42,9 +42,14 @@ bool SqueezeTransformation::transform(TransformationContext& context, ngraph::pa
|
|||||||
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
|
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
|
||||||
const ngraph::PartialShape& inputShape) {
|
const ngraph::PartialShape& inputShape) {
|
||||||
const size_t inputRankValue = inputShape.rank().get_length();
|
const size_t inputRankValue = inputShape.rank().get_length();
|
||||||
if (dequantizationOpConstant->get_shape().size() == inputRankValue) {
|
const auto constantShape = dequantizationOpConstant->get_shape();
|
||||||
|
if (shape_size(constantShape) == 1ul) {
|
||||||
|
return NetworkHelper::toScalar(dequantizationOpConstant);
|
||||||
|
}
|
||||||
|
if (constantShape.size() == inputRankValue) {
|
||||||
return as_type_ptr<opset1::Constant>(fold<opset1::Squeeze>(dequantizationOpConstant, squeeze->get_input_node_shared_ptr(1)));
|
return as_type_ptr<opset1::Constant>(fold<opset1::Squeeze>(dequantizationOpConstant, squeeze->get_input_node_shared_ptr(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
return dequantizationOpConstant;
|
return dequantizationOpConstant;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -42,12 +42,19 @@ bool UnsqueezeTransformation::transform(TransformationContext& context, ngraph::
|
|||||||
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
|
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
|
||||||
const ngraph::PartialShape& inputShape) {
|
const ngraph::PartialShape& inputShape) {
|
||||||
const size_t inputRankValue = inputShape.rank().get_length();
|
const size_t inputRankValue = inputShape.rank().get_length();
|
||||||
if (dequantizationOpConstant->get_shape().size() == inputRankValue) {
|
const auto constantShape = dequantizationOpConstant->get_shape();
|
||||||
|
if (shape_size(constantShape) == 1ul) {
|
||||||
|
return NetworkHelper::toScalar(dequantizationOpConstant);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (constantShape.size() == inputRankValue) {
|
||||||
return as_type_ptr<opset1::Constant>(fold<opset1::Unsqueeze>(dequantizationOpConstant, unsqueeze->get_input_node_shared_ptr(1)));
|
return as_type_ptr<opset1::Constant>(fold<opset1::Unsqueeze>(dequantizationOpConstant, unsqueeze->get_input_node_shared_ptr(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
return dequantizationOpConstant;
|
return dequantizationOpConstant;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
const std::shared_ptr<Node> unsqueeze = NetworkHelper::separateInStandaloneBranch(m.get_match_root());
|
const std::shared_ptr<Node> unsqueeze = NetworkHelper::separateInStandaloneBranch(m.get_match_root());
|
||||||
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(unsqueeze);
|
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(unsqueeze);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user