[LPT] Some LP Transformations improvements (#6434)
* [LPT] LayerTransformation::canBeTransformed: replaced legacy code
* [LPT] NetworkHelper::moveDequantizationAfter refactoring
* [LPT] ReshapeTransformation improvement
* [LPT] Squeeze/UnsqueezeTransformation improvement
This commit is contained in:
parent
0834ae2e6d
commit
9d9a24c914
@@ -1560,14 +1560,14 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
||||
if (updatePrecision) {
|
||||
op->set_overridden_output_type(newOperation->get_input_element_type(0));
|
||||
} else if (dequantization.multiply) {
|
||||
op->set_overridden_output_type(dequantization.multiply->get_input_element_type(1));
|
||||
op->set_overridden_output_type(dequantization.multiplyConstant->get_element_type());
|
||||
} else if (dequantization.subtract) {
|
||||
op->set_overridden_output_type(dequantization.subtract->get_input_element_type(1));
|
||||
op->set_overridden_output_type(dequantization.subtractConstant->get_element_type());
|
||||
}
|
||||
std::dynamic_pointer_cast<ngraph::Node>(newOperation)->validate_and_infer_types();
|
||||
}
|
||||
|
||||
const element::Type deqPrecision = dequantization.multiply->get_input_node_shared_ptr(1)->get_output_element_type(0);
|
||||
const element::Type deqPrecision = dequantization.multiplyConstant->get_element_type();
|
||||
const bool shouldConvert = (newOperation->get_output_element_type(0) != deqPrecision);
|
||||
|
||||
auto parent = newOperation;
|
||||
@@ -1582,11 +1582,11 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
||||
if (moveSubtract && (dequantization.subtract != nullptr)) {
|
||||
if (dequantization.subtractConvert == nullptr) {
|
||||
const element::Type parentPrecision = parent->get_output_element_type(0);
|
||||
if (parentPrecision.bitwidth() < dequantization.subtractConstant->output(0).get_element_type().bitwidth()) {
|
||||
if (parentPrecision.bitwidth() < dequantization.subtractConstant->get_element_type().bitwidth()) {
|
||||
THROW_IE_LPT_EXCEPTION(*parent) <<
|
||||
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
|
||||
", subtract dequantization constant " << dequantization.subtractConstant->get_friendly_name() << ":" <<
|
||||
dequantization.subtractConstant->output(0).get_element_type();
|
||||
dequantization.subtractConstant->get_element_type();
|
||||
}
|
||||
|
||||
parent = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
|
||||
@@ -1604,12 +1604,12 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
||||
}
|
||||
|
||||
if (dequantization.multiply != nullptr) {
|
||||
auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
|
||||
auto multiplyConstant = dequantization.multiplyConstant;
|
||||
const element::Type parentPrecision = parent->get_output_element_type(0);
|
||||
if (parentPrecision.bitwidth() < multiplyConstant->output(0).get_element_type().bitwidth()) {
|
||||
if (parentPrecision.bitwidth() < multiplyConstant->get_element_type().bitwidth()) {
|
||||
THROW_IE_LPT_EXCEPTION(*parent) <<
|
||||
"unexpected precisions: on data " << parent->get_friendly_name() << ":" << parentPrecision <<
|
||||
", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->output(0).get_element_type();
|
||||
", multiply dequantization constant " << multiplyConstant->get_friendly_name() << ":" << multiplyConstant->get_element_type();
|
||||
}
|
||||
|
||||
parent = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
|
||||
|
@@ -42,9 +42,14 @@ bool SqueezeTransformation::transform(TransformationContext& context, ngraph::pa
|
||||
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
|
||||
const ngraph::PartialShape& inputShape) {
|
||||
const size_t inputRankValue = inputShape.rank().get_length();
|
||||
if (dequantizationOpConstant->get_shape().size() == inputRankValue) {
|
||||
const auto constantShape = dequantizationOpConstant->get_shape();
|
||||
if (shape_size(constantShape) == 1ul) {
|
||||
return NetworkHelper::toScalar(dequantizationOpConstant);
|
||||
}
|
||||
if (constantShape.size() == inputRankValue) {
|
||||
return as_type_ptr<opset1::Constant>(fold<opset1::Squeeze>(dequantizationOpConstant, squeeze->get_input_node_shared_ptr(1)));
|
||||
}
|
||||
|
||||
return dequantizationOpConstant;
|
||||
};
|
||||
|
||||
|
@@ -42,12 +42,19 @@ bool UnsqueezeTransformation::transform(TransformationContext& context, ngraph::
|
||||
const std::shared_ptr<ngraph::opset1::Constant>& dequantizationOpConstant,
|
||||
const ngraph::PartialShape& inputShape) {
|
||||
const size_t inputRankValue = inputShape.rank().get_length();
|
||||
if (dequantizationOpConstant->get_shape().size() == inputRankValue) {
|
||||
const auto constantShape = dequantizationOpConstant->get_shape();
|
||||
if (shape_size(constantShape) == 1ul) {
|
||||
return NetworkHelper::toScalar(dequantizationOpConstant);
|
||||
}
|
||||
|
||||
if (constantShape.size() == inputRankValue) {
|
||||
return as_type_ptr<opset1::Constant>(fold<opset1::Unsqueeze>(dequantizationOpConstant, unsqueeze->get_input_node_shared_ptr(1)));
|
||||
}
|
||||
|
||||
return dequantizationOpConstant;
|
||||
};
|
||||
|
||||
|
||||
const std::shared_ptr<Node> unsqueeze = NetworkHelper::separateInStandaloneBranch(m.get_match_root());
|
||||
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(unsqueeze);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user