[CPU] Optimize fused FQ/Eltwise nodes to use less binary postOps (#13502)
This commit is contained in:
@@ -35,7 +35,7 @@ public:
|
||||
bool isLowPrecision() const;
|
||||
std::shared_ptr<Node> copyWithNewInput(const std::shared_ptr<Node>& input) const;
|
||||
|
||||
static bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise);
|
||||
bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise) const;
|
||||
|
||||
static bool checkShape(const std::shared_ptr<ngraph::Node>& elementwise);
|
||||
|
||||
@@ -48,6 +48,7 @@ public:
|
||||
const std::shared_ptr<ngraph::Node>& elementwise,
|
||||
std::shared_ptr<ngraph::opset1::Constant>& constant);
|
||||
|
||||
size_t channelDimIndex;
|
||||
Output<Node> data;
|
||||
std::shared_ptr<opset1::Convert> convert;
|
||||
std::shared_ptr<opset1::Subtract> subtract;
|
||||
|
||||
@@ -43,8 +43,8 @@ bool EltwiseBaseTransformation::canBeTransformed(const TransformationContext& co
|
||||
|
||||
FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(operation, defaultPrecisions, 0ul);
|
||||
FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(operation, defaultPrecisions, 1ul);
|
||||
if ((dequantization1.empty() || ((dequantization1.multiply != nullptr) && !FakeQuantizeDequantization::checkElementwise(dequantization1.multiply))) &&
|
||||
(dequantization2.empty() || ((dequantization2.multiply != nullptr) && !FakeQuantizeDequantization::checkElementwise(dequantization2.multiply)))) {
|
||||
if ((dequantization1.empty() || ((dequantization1.multiply != nullptr) && !dequantization1.checkElementwise(dequantization1.multiply))) &&
|
||||
(dequantization2.empty() || ((dequantization2.multiply != nullptr) && !dequantization2.checkElementwise(dequantization2.multiply)))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,20 @@ FakeQuantizeDequantization::FakeQuantizeDequantization(
|
||||
subtractConstant(subtractConstant),
|
||||
multiply(multiply),
|
||||
multiplyConstant(multiplyConstant) {
|
||||
// for most nodes with layout NC, NCHW, NCDHW, index of channel dimension is 1
|
||||
channelDimIndex = 1ul;
|
||||
|
||||
const auto rank = data.get_partial_shape().rank();
|
||||
if (rank.is_static()) {
|
||||
std::string data_src_type = data.get_node()->get_type_name();
|
||||
if (data_src_type == "MatMul" && data.get_index() == 0) {
|
||||
// for MatMul, index of channel dimension is the last one
|
||||
channelDimIndex = static_cast<size_t>(rank.get_length()) - 1;
|
||||
} else if (rank.get_length() == 1) {
|
||||
// special 1D case: C
|
||||
channelDimIndex = 0ul;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool FakeQuantizeDequantization::empty() const noexcept {
|
||||
@@ -100,7 +114,8 @@ bool FakeQuantizeDequantization::checkShape(const std::shared_ptr<ngraph::Node>&
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::Node>& dequantizationElementwise) {
|
||||
// check if elementwise operation inside dequantization subgraph satisfies per-tensor/per-OC broadcast requirement
|
||||
bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::Node>& dequantizationElementwise) const {
|
||||
std::shared_ptr<ngraph::opset1::Convert> convert;
|
||||
std::shared_ptr<ngraph::opset1::Constant> constant;
|
||||
FakeQuantizeDequantization::fillDequantizationParams(dequantizationElementwise, convert, constant);
|
||||
@@ -114,6 +129,7 @@ bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::
|
||||
return false;
|
||||
}
|
||||
|
||||
// scalar-like const tensor is broadcastable to any shape of data
|
||||
if (ngraph::shape_size(constShape) == 1) {
|
||||
return true;
|
||||
}
|
||||
@@ -123,40 +139,51 @@ bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto channelsDimension = partialShape[partialShape.size() > 1ul ? 1ul : 0ul];
|
||||
auto dimc = channelDimIndex;
|
||||
|
||||
const auto channelsDimension = partialShape[dimc];
|
||||
if (channelsDimension.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t channelsShapeVal = channelsDimension.get_length();
|
||||
|
||||
// special case: 1D const tensor is considered to be per-channel w/o comparing actual shapes using broadcast rules
|
||||
// as long as the number of elements matches channel dimension.
|
||||
if (constShape.size() == 1ul) {
|
||||
return constShape[0] == channelsShapeVal;
|
||||
}
|
||||
|
||||
auto checkConstShape = [&constShape, &channelsShapeVal] (const size_t chDimIdx) {
|
||||
for (size_t i = 0ul; i < constShape.size(); ++i) {
|
||||
auto curDim = constShape[i];
|
||||
if (curDim == 1ul)
|
||||
continue;
|
||||
if (i == chDimIdx && curDim == channelsShapeVal)
|
||||
continue;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
const size_t rank = partialShape.rank().get_length();
|
||||
if (constShape.size() == rank) {
|
||||
if ((constShape[0] != 1ul) || (constShape[1] != channelsShapeVal)) {
|
||||
// special case: ND const tensor with N matches data rank
|
||||
// element-wise comparing works under any broadcast rules.
|
||||
return checkConstShape(dimc);
|
||||
} else if (constShape.size() < rank) {
|
||||
// rank mismatch, we have to apply broadcast rules to align dimensions, all dequantization nodes are constructed
|
||||
// by LPT itself, thus should have default NUMPY type
|
||||
if (dequantizationElementwise->get_autob() != ov::op::AutoBroadcastType::NUMPY)
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 2ul; i < constShape.size(); ++i) {
|
||||
if (constShape[i] != 1ul) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (constShape.size() == (rank - 1)) {
|
||||
if (constShape[0] != channelsShapeVal) {
|
||||
// the prepended dimensions are all 1 and can be skipped;
|
||||
// derive index of channel dimension in const tensor after right alignment
|
||||
if (dimc < rank - constShape.size())
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 1ul; i < constShape.size(); ++i) {
|
||||
if (constShape[i] != 1ul) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
return checkConstShape(dimc - (rank - constShape.size()));
|
||||
}
|
||||
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> FakeQuantizeDequantization::copyWithNewInput(const std::shared_ptr<Node>& input) const {
|
||||
|
||||
@@ -62,7 +62,7 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
|
||||
|
||||
const auto dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions);
|
||||
if (!dequantization.empty()) {
|
||||
auto perChannelQuantization = [](const PartialShape dataPShape, Shape constShape) {
|
||||
auto perChannelQuantization = [](const PartialShape dataPShape, Shape constShape, size_t idxChannelDim) {
|
||||
if (ngraph::shape_size(constShape) == 1ul) {
|
||||
return true;
|
||||
}
|
||||
@@ -77,12 +77,16 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
|
||||
constShape.insert(constShape.begin(), 1ul);
|
||||
}
|
||||
|
||||
// special case: 1D const is assumed to imply per-channel
|
||||
if (constShape.size() == 1)
|
||||
return true;
|
||||
|
||||
if ((constShape.size() >= 2ul) && (constShape[0] != 1ul)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 2; i < constShape.size(); ++i) {
|
||||
if (constShape[i] != 1ul) {
|
||||
for (size_t i = 0; i < constShape.size(); ++i) {
|
||||
if ((constShape[i] != 1ul) && (i != idxChannelDim)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -91,13 +95,15 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
|
||||
|
||||
if ((dequantization.subtract != nullptr) && (!perChannelQuantization(
|
||||
dequantization.subtract->get_output_partial_shape(0),
|
||||
dequantization.subtractConstant->get_shape()))) {
|
||||
dequantization.subtractConstant->get_shape(),
|
||||
dequantization.channelDimIndex))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((dequantization.multiply != nullptr) && (!perChannelQuantization(
|
||||
dequantization.multiply->get_output_partial_shape(0),
|
||||
dequantization.multiplyConstant->get_shape()))) {
|
||||
dequantization.multiplyConstant->get_shape(),
|
||||
dequantization.channelDimIndex))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,28 +82,40 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s
|
||||
if (ov::is_type<op::MVN>(mvn)) {
|
||||
reduction_axes = ov::as_type_ptr<op::MVN>(mvn)->get_reduction_axes();
|
||||
} else {
|
||||
reduction_axes = ov::as_type_ptr<opset1::Constant>(mvn->get_input_node_shared_ptr(1))->get_axis_set_val();
|
||||
}
|
||||
// MVN-6 allows negative values in reduction axes: [-r, r-1]
|
||||
// given static rank of input data of MVN node, we can recover the exact axis number
|
||||
auto axis_set = ov::as_type_ptr<opset1::Constant>(mvn->get_input_node_shared_ptr(1))->cast_vector<int64_t>();
|
||||
|
||||
if (reduction_axes.count(1) == 0) {
|
||||
return true;
|
||||
}
|
||||
Dimension::value_type ndims = 0;
|
||||
if (std::any_of(axis_set.begin(), axis_set.end(), [](int64_t v) { return v < 0; })) {
|
||||
const auto rank = mvn->get_input_partial_shape(0).rank();
|
||||
// we need ndims to deduce exact axis if there are negative values
|
||||
if (rank.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
ndims = rank.get_length();
|
||||
}
|
||||
|
||||
bool perTensor = true;
|
||||
const auto rank = mvn->get_input_partial_shape(0).rank();
|
||||
if (rank.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 2; i < rank.get_length(); ++i) {
|
||||
if (reduction_axes.count(i) == 0) {
|
||||
perTensor = false;
|
||||
break;
|
||||
for (auto& axis : axis_set) {
|
||||
reduction_axes.insert(axis >= 0 ? axis : axis + ndims);
|
||||
}
|
||||
}
|
||||
|
||||
bool isScalarScales = NetworkHelper::isScalarLike(dequantization.multiplyConstant);
|
||||
return perTensor && isScalarScales;
|
||||
// scale-only-dequantization may be per-channel or per-tensor
|
||||
// and it can be pushed-through MVN node only if there is one consistent
|
||||
// scale applied within a single normalization slice.
|
||||
// per-tensor scale-only-dequantization can always satisfy that
|
||||
if (NetworkHelper::isScalarLike(dequantization.multiplyConstant)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// per-channel scale-only-dequantization can be pushed through MVN only
|
||||
// if the channel dimension is not among the reduction_axes (so a single
|
||||
// scale is applied to the whole normalization slice)
|
||||
if (reduction_axes.count(dequantization.channelDimIndex) == 0)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MVNTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) {
|
||||
|
||||
@@ -71,7 +71,7 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((dequantization.multiply != nullptr) && !FakeQuantizeDequantization::checkElementwise(dequantization.multiply)) {
|
||||
if ((dequantization.multiply != nullptr) && !dequantization.checkElementwise(dequantization.multiply)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user