[CPU] Refactoring. Avoid using align arg when appending post ops (#9225)

Always align legacy scale/shift post ops
Egor Duplensky 2021-12-20 10:23:32 +03:00 committed by GitHub
parent fab4448ebd
commit abee3ea4d4
14 changed files with 21 additions and 27 deletions
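
In short: the `align` argument is dropped from the virtual appendPostOps() interface, so callers no longer choose an alignment. The two implementations that actually need aligned buffers (legacy Eltwise scale/shift data and FakeQuantize post-op data) now hard-code an alignment of 16 internally. The interface change, condensed from the hunks below:

    // before: every caller had to pass (or default) an alignment
    virtual void appendPostOps(mkldnn::post_ops& ops, const VectorDims& postOpDims, int align = -1);

    // after: alignment is an implementation detail of the node that needs it
    virtual void appendPostOps(mkldnn::post_ops& ops, const VectorDims& postOpDims);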

@@ -1102,7 +1102,7 @@ Layout MKLDNNNode::getWeightsLayoutByDims(SizeVector dims, bool isGrouped) {
     }
 }
-void MKLDNNNode::appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims, int align) {
+void MKLDNNNode::appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims) {
     IE_THROW() << "Fusing of " << this->getType() << " operation is not implemented";
 }

@@ -602,7 +602,7 @@ protected:
      * Seed node should call this routine and pass its post operations list as parameter.
      * @param ops List of fused post operations
      */
-    virtual void appendPostOps(mkldnn::post_ops& ops, const VectorDims& postOpDims, int align = -1);
+    virtual void appendPostOps(mkldnn::post_ops& ops, const VectorDims& postOpDims);
     virtual void appendBinPostOps(mkldnn::post_ops& ops, const VectorDims& postOpDims, std::vector<MKLDNNMemoryPtr>& binaryPostOpsMem);
     virtual std::shared_ptr<mkldnn::primitive_attr> initPrimitiveAttr() { return nullptr; }
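
For context, a minimal sketch of how a seed node drives this routine, pieced together from the setPostOps() hunks below (the exact control flow differs per node; this is illustrative, not the literal plugin code):

    mkldnn::post_ops ops;
    for (auto& node : fusedWith) {
        if (auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode*>(node.get())) {
            eltwiseNode->appendPostOps(ops, dims);  // no per-caller align argument anymore
            continue;
        }
        IE_THROW() << "Fusing of " << NameFromType(node->getType()) << " operation is not implemented";
    }
    attr.set_post_ops(ops);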

@@ -1132,8 +1132,7 @@ void MKLDNNBinaryConvolutionNode::setPostOps(mkldnn::primitive_attr &attr) {
                 ops.append_sum(1.0);
             } else {
                 // TODO [DS]: change to shape from memory
-                constexpr int align = 16;
-                eltwiseNode->appendPostOps(ops, getOutputShapeAtPort(0).getStaticDims(), align);
+                eltwiseNode->appendPostOps(ops, getOutputShapeAtPort(0).getStaticDims());
             }
             continue;
         }

@@ -352,8 +352,7 @@ void MKLDNNConvolutionNode::setPostOps(mkldnn::primitive_attr &attr, const Vecto
                 ops.append_sum(1.0, MKLDNNExtensionUtils::IEPrecisionToDataType(eltwisePrecision));
             } else {
                 if (useLegacyPostOps || eltwiseNode->getMKLDNNAlgorithm() != mkldnn::algorithm::undef) {
-                    constexpr int align = 16;
-                    eltwiseNode->appendPostOps(ops, dims, align);
+                    eltwiseNode->appendPostOps(ops, dims);
                 } else {
                     eltwiseNode->appendBinPostOps(ops, getBinPostOpShape(), binaryPostOpsArgs);
                 }

@@ -365,9 +365,8 @@ void MKLDNNDeconvolutionNode::setPostOps(mkldnn::primitive_attr &attr, const Vec
     for (auto &node : fusedWith) {
         if (auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get())) {
             // TODO [DS]: change to shape from memory
-            constexpr int align = 16;
             // use legacy depthwise since backprop convolution does not support binary post ops
-            eltwiseNode->appendPostOps(ops, dims, align);
+            eltwiseNode->appendPostOps(ops, dims);
             continue;
         }
         if (auto* fakeQuantizeNode = dynamic_cast<MKLDNNFakeQuantizeNode *>(node.get())) {

@@ -1744,7 +1744,7 @@ void MKLDNNEltwiseNode::fuseInto(MKLDNNNodePtr& parentNode) {
     MKLDNNNode::fuseInto(parentNode);
 }
-void MKLDNNEltwiseNode::appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims, int align) {
+void MKLDNNEltwiseNode::appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims) {
     const std::string errorPrefix = "Appending Eltwise node with name '" + getName() + "' ";
     if (getMKLDNNAlgorithm() != mkldnn::algorithm::undef) {
@@ -1775,11 +1775,11 @@ void MKLDNNEltwiseNode::appendPostOps(mkldnn::post_ops& ops, const VectorDims &p
         }
     } else {
         const size_t chIdx = postOpDims.size() > 1 ? getFusingAxis() : 0;
+        constexpr int align = 16; // always align for legacy scale/shift post ops
         scalesBuffer = makeAlignedBuffer(postOpDims[chIdx], scales, align);
         if (getAlgorithm() != EltwisePrelu) {
             shiftsBuffer = makeAlignedBuffer(postOpDims[chIdx], shifts, align);
         }
         /* @todo legacy depthwise post ops are kept for now
          * for performance reasons
          */
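
makeAlignedBuffer itself is not part of this diff; what follows is a sketch of its plausible semantics, inferred only from the call sites above (the broadcast and tail-padding rules are assumptions): it broadcasts a scalar or per-channel vector to the channel count, then pads the length up to a multiple of align so the JIT kernels can always load full vector registers without over-reading.

    #include <algorithm>
    #include <vector>

    // Assumed semantics only — not the plugin's actual implementation.
    static std::vector<float> makeAlignedBuffer(size_t channels,
                                                const std::vector<float>& data,
                                                int align) {
        // Round the buffer length up to the next multiple of `align`.
        const size_t alignedSize = (channels + align - 1) / align * align;
        std::vector<float> buffer(alignedSize);
        for (size_t i = 0; i < alignedSize; ++i) {
            // Broadcast a per-tensor scalar; pad the tail by repeating the last
            // real value so vector loads never hit uninitialized (denormal) data.
            const size_t src = std::min(i, channels - 1);
            buffer[i] = (data.size() == 1) ? data[0] : data[src];
        }
        return buffer;
    }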

@@ -75,7 +75,7 @@ public:
     bool created() const override;
     bool canBeInPlace() const override;
     bool canFuse(const MKLDNNNodePtr& node) const override;
-    void appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims, int align = -1) override;
+    void appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims) override;
     void appendBinPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims, std::vector<MKLDNNMemoryPtr>& binaryPostOpsMem) override;
     void fuseInto(MKLDNNNodePtr& parentNode) override;
     InferenceEngine::Precision getRuntimePrecision() const override;

@@ -1706,8 +1706,13 @@ void MKLDNNFakeQuantizeNode::initializePostOpData(const VectorDims &dims, const
     isPostOpDataInitialized = true;
 }
-void MKLDNNFakeQuantizeNode::appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims, int align) {
-    initializePostOpData(postOpDims, align);
+void MKLDNNFakeQuantizeNode::appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims) {
+    // MKLDNN quantization_injectors assumes that quantization data memory is always aligned on 16
+    // by length of AVX512 vector register which is also enough for AVX2 and SSE42 implementations.
+    // Otherwise it can lead to buffer over-read and performance penalties due to denormals.
+    const size_t bufferAlignment = 16;
+    initializePostOpData(postOpDims, bufferAlignment);
     if (getAlgorithm() == FQBinarization) {
         ops.append_binarization(mkldnn::algorithm::binarization_depthwise, (const float*)&binarizationThresholds[0], (const float*)&binarizationOutputMask[0]);
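
The constant matches one ZMM register: 16 float32 values = 64 bytes, so a vectorized loop consuming 16 floats per iteration never reads past the end of the buffer (AVX2 and SSE4.2 read 8 and 4 floats per iteration, so they are covered a fortiori). Rounding a channel count up to that boundary is plain integer arithmetic, e.g.:

    // 20 channels -> 32 floats of post-op data; 16 -> 16; 1 -> 16.
    constexpr size_t bufferAlignment = 16;
    inline size_t alignedLength(size_t channels) {
        return (channels + bufferAlignment - 1) / bufferAlignment * bufferAlignment;
    }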

@@ -120,10 +120,7 @@ public:
     InferenceEngine::Precision getInputPrecision() const { return inputPrecision; }
     InferenceEngine::Precision getOutputPrecision() const { return outputPrecision; }
-    // MKLDNN quantization_injectors assumes that quantization data memory is always aligned on 16
-    // by length of AVX512 vector register which is also enough for AVX2 and SSE42 implementations.
-    // Otherwise it can lead to buffer over-read and performance penalties due to denormals.
-    void appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims = {}, int align = 16) override;
+    void appendPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims = {}) override;
     void appendBinPostOps(mkldnn::post_ops& ops, const VectorDims &postOpDims, std::vector<MKLDNNMemoryPtr>& binaryPostOpsMem) override;
     static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;

@@ -198,9 +198,8 @@ void MKLDNNFullyConnectedNode::setPostOps(mkldnn::primitive_attr &attr, bool ini
         if (auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get())) {
             // TODO [DS]: change to shape from memory
-            constexpr int align = -1;
             if (eltwiseNode->getMKLDNNAlgorithm() != mkldnn::algorithm::undef) {
-                eltwiseNode->appendPostOps(ops, getOutputShapeAtPort(0).getStaticDims(), align);
+                eltwiseNode->appendPostOps(ops, getOutputShapeAtPort(0).getStaticDims());
             } else {
                 eltwiseNode->appendBinPostOps(ops, getBinPostOpShape(), binaryPostOpsArgs);
             }

@@ -2102,8 +2102,7 @@ void MKLDNNInterpolateNode::setPostOps(mkldnn::primitive_attr &attr, const Vecto
         auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
         if (eltwiseNode) {
-            constexpr int align = 16;
-            eltwiseNode->appendPostOps(ops, dims, align);
+            eltwiseNode->appendPostOps(ops, dims);
             continue;
         }

@@ -891,8 +891,7 @@ void MKLDNNMVNNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) {
         auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
         if (eltwiseNode) {
-            constexpr int align = 16;
-            eltwiseNode->appendPostOps(ops, postOpDims, align);
+            eltwiseNode->appendPostOps(ops, postOpDims);
             continue;
         }
         IE_THROW() << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented";

@@ -813,8 +813,7 @@ void MKLDNNNormalizeL2Node::setPostOps(mkldnn::primitive_attr& kernel_attrs, con
         auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
         if (eltwiseNode) {
-            constexpr int align = 16;
-            eltwiseNode->appendPostOps(ops, dims, align);
+            eltwiseNode->appendPostOps(ops, dims);
             continue;
         }

@@ -2779,8 +2779,7 @@ void MKLDNNReduceNode::setPostOps(mkldnn::primitive_attr &attr, const VectorDims
         auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
        if (eltwiseNode) {
-            constexpr int align = 16;
-            eltwiseNode->appendPostOps(ops, postOpDims, align);
+            eltwiseNode->appendPostOps(ops, postOpDims);
             continue;
         }
         IE_THROW() << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented";