[CPU] Optimize fused FQ/Eltwise nodes to use less binary postOps (#13502)

This commit is contained in:
Tingqian Li
2022-11-16 05:26:02 +08:00
committed by GitHub
parent 8d3b382ca9
commit 80971ae8e8
31 changed files with 1507 additions and 497 deletions

View File

@@ -35,7 +35,7 @@ public:
bool isLowPrecision() const;
std::shared_ptr<Node> copyWithNewInput(const std::shared_ptr<Node>& input) const;
static bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise);
bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise) const;
static bool checkShape(const std::shared_ptr<ngraph::Node>& elementwise);
@@ -48,6 +48,7 @@ public:
const std::shared_ptr<ngraph::Node>& elementwise,
std::shared_ptr<ngraph::opset1::Constant>& constant);
size_t channelDimIndex;
Output<Node> data;
std::shared_ptr<opset1::Convert> convert;
std::shared_ptr<opset1::Subtract> subtract;

View File

@@ -43,8 +43,8 @@ bool EltwiseBaseTransformation::canBeTransformed(const TransformationContext& co
FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(operation, defaultPrecisions, 0ul);
FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(operation, defaultPrecisions, 1ul);
if ((dequantization1.empty() || ((dequantization1.multiply != nullptr) && !FakeQuantizeDequantization::checkElementwise(dequantization1.multiply))) &&
(dequantization2.empty() || ((dequantization2.multiply != nullptr) && !FakeQuantizeDequantization::checkElementwise(dequantization2.multiply)))) {
if ((dequantization1.empty() || ((dequantization1.multiply != nullptr) && !dequantization1.checkElementwise(dequantization1.multiply))) &&
(dequantization2.empty() || ((dequantization2.multiply != nullptr) && !dequantization2.checkElementwise(dequantization2.multiply)))) {
return false;
}

View File

@@ -32,6 +32,20 @@ FakeQuantizeDequantization::FakeQuantizeDequantization(
subtractConstant(subtractConstant),
multiply(multiply),
multiplyConstant(multiplyConstant) {
// for most nodes with layout NC, NCHW, NCDHW, the index of the channel dimension is 1
channelDimIndex = 1ul;
const auto rank = data.get_partial_shape().rank();
if (rank.is_static()) {
std::string data_src_type = data.get_node()->get_type_name();
if (data_src_type == "MatMul" && data.get_index() == 0) {
// for MatMul, index of channel dimension is the last one
channelDimIndex = static_cast<size_t>(rank.get_length()) - 1;
} else if (rank.get_length() == 1) {
// special 1D case: C
channelDimIndex = 0ul;
}
}
}
bool FakeQuantizeDequantization::empty() const noexcept {
@@ -100,7 +114,8 @@ bool FakeQuantizeDequantization::checkShape(const std::shared_ptr<ngraph::Node>&
return true;
}
bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::Node>& dequantizationElementwise) {
// check if the elementwise operation inside the dequantization subgraph satisfies the per-tensor/per-OC broadcast requirement
bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::Node>& dequantizationElementwise) const {
std::shared_ptr<ngraph::opset1::Convert> convert;
std::shared_ptr<ngraph::opset1::Constant> constant;
FakeQuantizeDequantization::fillDequantizationParams(dequantizationElementwise, convert, constant);
@@ -114,6 +129,7 @@ bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::
return false;
}
// scalar-like const tensor is broadcastable to any shape of data
if (ngraph::shape_size(constShape) == 1) {
return true;
}
@@ -123,40 +139,51 @@ bool FakeQuantizeDequantization::checkElementwise(const std::shared_ptr<ngraph::
return false;
}
const auto channelsDimension = partialShape[partialShape.size() > 1ul ? 1ul : 0ul];
auto dimc = channelDimIndex;
const auto channelsDimension = partialShape[dimc];
if (channelsDimension.is_dynamic()) {
return false;
}
const size_t channelsShapeVal = channelsDimension.get_length();
// special case: a 1D const tensor is considered to be per-channel without comparing actual shapes using broadcast rules,
// as long as its number of elements matches the channel dimension.
if (constShape.size() == 1ul) {
return constShape[0] == channelsShapeVal;
}
auto checkConstShape = [&constShape, &channelsShapeVal] (const size_t chDimIdx) {
for (size_t i = 0ul; i < constShape.size(); ++i) {
auto curDim = constShape[i];
if (curDim == 1ul)
continue;
if (i == chDimIdx && curDim == channelsShapeVal)
continue;
return false;
}
return true;
};
const size_t rank = partialShape.rank().get_length();
if (constShape.size() == rank) {
if ((constShape[0] != 1ul) || (constShape[1] != channelsShapeVal)) {
// special case: ND const tensor whose rank N matches the data rank;
// element-wise comparison works under any broadcast rule.
return checkConstShape(dimc);
} else if (constShape.size() < rank) {
// rank mismatch, we have to apply broadcast rules to align dimensions, all dequantization nodes are constructed
// by LPT itself, thus should have the default NUMPY broadcast type
if (dequantizationElementwise->get_autob() != ov::op::AutoBroadcastType::NUMPY)
return false;
}
for (size_t i = 2ul; i < constShape.size(); ++i) {
if (constShape[i] != 1ul) {
return false;
}
}
} else if (constShape.size() == (rank - 1)) {
if (constShape[0] != channelsShapeVal) {
// the prepended dimensions are all 1 and can be skipped;
// derive the index of the channel dimension in the const tensor after right alignment
if (dimc < rank - constShape.size())
return false;
}
for (size_t i = 1ul; i < constShape.size(); ++i) {
if (constShape[i] != 1ul) {
return false;
}
}
} else {
return false;
return checkConstShape(dimc - (rank - constShape.size()));
}
return true;
return false;
}
std::shared_ptr<Node> FakeQuantizeDequantization::copyWithNewInput(const std::shared_ptr<Node>& input) const {

View File

@@ -62,7 +62,7 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
const auto dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions);
if (!dequantization.empty()) {
auto perChannelQuantization = [](const PartialShape dataPShape, Shape constShape) {
auto perChannelQuantization = [](const PartialShape dataPShape, Shape constShape, size_t idxChannelDim) {
if (ngraph::shape_size(constShape) == 1ul) {
return true;
}
@@ -77,12 +77,16 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
constShape.insert(constShape.begin(), 1ul);
}
// special case: 1D const is assumed to imply per-channel
if (constShape.size() == 1)
return true;
if ((constShape.size() >= 2ul) && (constShape[0] != 1ul)) {
return false;
}
for (size_t i = 2; i < constShape.size(); ++i) {
if (constShape[i] != 1ul) {
for (size_t i = 0; i < constShape.size(); ++i) {
if ((constShape[i] != 1ul) && (i != idxChannelDim)) {
return false;
}
}
@@ -91,13 +95,15 @@ bool LayerTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& la
if ((dequantization.subtract != nullptr) && (!perChannelQuantization(
dequantization.subtract->get_output_partial_shape(0),
dequantization.subtractConstant->get_shape()))) {
dequantization.subtractConstant->get_shape(),
dequantization.channelDimIndex))) {
return false;
}
if ((dequantization.multiply != nullptr) && (!perChannelQuantization(
dequantization.multiply->get_output_partial_shape(0),
dequantization.multiplyConstant->get_shape()))) {
dequantization.multiplyConstant->get_shape(),
dequantization.channelDimIndex))) {
return false;
}
}

View File

@@ -82,28 +82,40 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s
if (ov::is_type<op::MVN>(mvn)) {
reduction_axes = ov::as_type_ptr<op::MVN>(mvn)->get_reduction_axes();
} else {
reduction_axes = ov::as_type_ptr<opset1::Constant>(mvn->get_input_node_shared_ptr(1))->get_axis_set_val();
}
// MVN-6 allows negative values in reduction axes: [-r, r-1]
// given static rank of input data of MVN node, we can recover the exact axis number
auto axis_set = ov::as_type_ptr<opset1::Constant>(mvn->get_input_node_shared_ptr(1))->cast_vector<int64_t>();
if (reduction_axes.count(1) == 0) {
return true;
}
Dimension::value_type ndims = 0;
if (std::any_of(axis_set.begin(), axis_set.end(), [](int64_t v) { return v < 0; })) {
const auto rank = mvn->get_input_partial_shape(0).rank();
// we need ndims to deduce exact axis if there are negative values
if (rank.is_dynamic()) {
return false;
}
ndims = rank.get_length();
}
bool perTensor = true;
const auto rank = mvn->get_input_partial_shape(0).rank();
if (rank.is_dynamic()) {
return false;
}
for (int i = 2; i < rank.get_length(); ++i) {
if (reduction_axes.count(i) == 0) {
perTensor = false;
break;
for (auto& axis : axis_set) {
reduction_axes.insert(axis >= 0 ? axis : axis + ndims);
}
}
bool isScalarScales = NetworkHelper::isScalarLike(dequantization.multiplyConstant);
return perTensor && isScalarScales;
// scale-only-dequantization may be per-channel or per-tensor,
// and it can be pushed through the MVN node only if there is one consistent
// scale applied within a single normalization slice.
// per-tensor scale-only-dequantization can always satisfy that
if (NetworkHelper::isScalarLike(dequantization.multiplyConstant)) {
return true;
}
// per-channel scale-only-dequantization can be pushed through MVN only
// if the channel dimension is not among the reduction_axes (so a single
// scale is applied to the whole normalization slice)
if (reduction_axes.count(dequantization.channelDimIndex) == 0)
return true;
return false;
}
bool MVNTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) {

View File

@@ -71,7 +71,7 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
return false;
}
if ((dequantization.multiply != nullptr) && !FakeQuantizeDequantization::checkElementwise(dequantization.multiply)) {
if ((dequantization.multiply != nullptr) && !dequantization.checkElementwise(dequantization.multiply)) {
return false;
}