[CPU] Removed optimized out nodes from infer stage (#7440)

This commit is contained in:
Alexandra Sidorova 2021-09-24 09:31:41 +03:00 committed by GitHub
parent f8f6e57c39
commit ae3e3af521
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 38 additions and 12 deletions

View File

@ -344,7 +344,7 @@ void MKLDNNGraph::InitGraph() {
graphNode->cleanup(); graphNode->cleanup();
} }
#endif #endif
ExtractConstantNodes(); ExtractConstantAndExecutableNodes();
ExecuteConstantNodesOnly(); ExecuteConstantNodesOnly();
} }
@ -389,13 +389,13 @@ void MKLDNNGraph::InitOptimalPrimitiveDescriptors() {
} }
} }
void MKLDNNGraph::ExtractConstantNodes() { void MKLDNNGraph::ExtractConstantAndExecutableNodes() {
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::MKLDNN_LT, "MKLDNNGraph::ExtractConstantNodes"); OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::MKLDNN_LT, "MKLDNNGraph::ExtractConstantAndExecutableNodes");
for (auto& graphNode : graphNodes) { for (const auto& graphNode : graphNodes) {
if (graphNode->isConstant()) if (graphNode->isConstant())
constantGraphNodes.emplace_back(graphNode); constantGraphNodes.emplace_back(graphNode);
else else if (graphNode->isExecutable())
mutableGraphNodes.emplace_back(graphNode); executableGraphNodes.emplace_back(graphNode);
} }
} }
@ -827,7 +827,7 @@ void MKLDNNGraph::Infer(MKLDNNInferRequest* request, int batch) {
mkldnn::stream stream(eng); mkldnn::stream stream(eng);
for (const auto& node : mutableGraphNodes) { for (const auto& node : executableGraphNodes) {
PERF(config.collectPerfCounters, node); PERF(config.collectPerfCounters, node);
if (request) if (request)
request->ThrowIfCanceled(); request->ThrowIfCanceled();

View File

@ -235,7 +235,7 @@ protected:
void Allocate(); void Allocate();
void AllocateWithReuse(); void AllocateWithReuse();
void CreatePrimitives(); void CreatePrimitives();
void ExtractConstantNodes(); void ExtractConstantAndExecutableNodes();
void ExecuteNode(const MKLDNNNodePtr& node, const mkldnn::stream& stream) const; void ExecuteNode(const MKLDNNNodePtr& node, const mkldnn::stream& stream) const;
void ExecuteConstantNodesOnly() const; void ExecuteConstantNodesOnly() const;
@ -247,10 +247,12 @@ private:
// TODO: change std::map to std::unordered_map // TODO: change std::map to std::unordered_map
std::map<std::string, MKLDNNNodePtr> inputNodesMap; std::map<std::string, MKLDNNNodePtr> inputNodesMap;
std::map<std::string, MKLDNNNodePtr> outputNodesMap; std::map<std::string, MKLDNNNodePtr> outputNodesMap;
// these node pointers (from graphNodes) are to avoid regular checking for // these node pointers (from graphNodes) are to avoid regular checking for
// constant node in ExecuteConstantNodesOnly and Infer methods // constantness of nodes in ExecuteConstantNodesOnly, Infer methods and calls of
// non-executable (optimized out) nodes, such as Input, Reshape, etc.
std::vector<MKLDNNNodePtr> constantGraphNodes; std::vector<MKLDNNNodePtr> constantGraphNodes;
std::vector<MKLDNNNodePtr> mutableGraphNodes; std::vector<MKLDNNNodePtr> executableGraphNodes;
void EnforceBF16(); void EnforceBF16();
}; };

View File

@ -195,6 +195,11 @@ public:
return engine; return engine;
} }
// must be called only after MKLDNNGraph::InitEdges()
virtual bool isExecutable() const {
return true;
}
bool isConstant(); bool isConstant();
bool isInplace() const; bool isInplace() const;

View File

@ -27,6 +27,9 @@ public:
bool isOptimized() const; bool isOptimized() const;
InferenceEngine::Precision getRuntimePrecision() const override; InferenceEngine::Precision getRuntimePrecision() const override;
bool isExecutable() const override {
return !isOptimized();
}
private: private:
size_t axis = 0; size_t axis = 0;

View File

@ -26,6 +26,9 @@ public:
MKLDNNMemoryCPtr getMemoryPtr() const; MKLDNNMemoryCPtr getMemoryPtr() const;
void executeDynamicImpl(mkldnn::stream strm) override {} void executeDynamicImpl(mkldnn::stream strm) override {}
bool isExecutable() const override {
return false;
}
std::vector<VectorDims> shapeInfer() const override { std::vector<VectorDims> shapeInfer() const override {
return std::vector<VectorDims>(); return std::vector<VectorDims>();

View File

@ -89,6 +89,9 @@ public:
bool created() const override { bool created() const override {
return getType() == MemoryInput; return getType() == MemoryInput;
} }
bool isExecutable() const override {
return true;
}
void execute(mkldnn::stream strm) override; void execute(mkldnn::stream strm) override;
void createPrimitive() override; void createPrimitive() override;

View File

@ -25,6 +25,10 @@ public:
bool created() const override; bool created() const override;
const std::vector<impl_desc_type>& getPrimitivesPriority() override; const std::vector<impl_desc_type>& getPrimitivesPriority() override;
bool isExecutable() const override {
return !isOptimized;
}
void setDescs(const MemoryDesc& input, const MemoryDesc& output) { void setDescs(const MemoryDesc& input, const MemoryDesc& output) {
this->input = input.clone(); this->input = input.clone();
inputShapes.clear(); inputShapes.clear();

View File

@ -26,6 +26,9 @@ public:
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;
void createPrimitive() override; void createPrimitive() override;
bool created() const override; bool created() const override;
bool isExecutable() const override {
return false;
}
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
}; };

View File

@ -298,7 +298,7 @@ bool MKLDNNSplitNode::created() const {
return getType() == Split; return getType() == Split;
} }
bool MKLDNNSplitNode::isOptimized() { bool MKLDNNSplitNode::isOptimized() const {
return getSelectedPrimitiveDescriptor() && getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].inPlace >= 0; return getSelectedPrimitiveDescriptor() && getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].inPlace >= 0;
} }

View File

@ -22,10 +22,13 @@ public:
void execute(mkldnn::stream strm) override; void execute(mkldnn::stream strm) override;
bool created() const override; bool created() const override;
bool isOptimized(); bool isOptimized() const;
void initOptimalPrimitiveDescriptor() override; void initOptimalPrimitiveDescriptor() override;
void setDynamicBatchLim(int lim) override; void setDynamicBatchLim(int lim) override;
bool isExecutable() const override {
return !isOptimized();
}
private: private:
void prepareOptimizedParams(); void prepareOptimizedParams();