[CPU] Removed optimized out nodes from infer stage (#7440)
This commit is contained in:
parent
f8f6e57c39
commit
ae3e3af521
@ -344,7 +344,7 @@ void MKLDNNGraph::InitGraph() {
|
|||||||
graphNode->cleanup();
|
graphNode->cleanup();
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
ExtractConstantNodes();
|
ExtractConstantAndExecutableNodes();
|
||||||
|
|
||||||
ExecuteConstantNodesOnly();
|
ExecuteConstantNodesOnly();
|
||||||
}
|
}
|
||||||
@ -389,13 +389,13 @@ void MKLDNNGraph::InitOptimalPrimitiveDescriptors() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MKLDNNGraph::ExtractConstantNodes() {
|
void MKLDNNGraph::ExtractConstantAndExecutableNodes() {
|
||||||
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::MKLDNN_LT, "MKLDNNGraph::ExtractConstantNodes");
|
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::MKLDNN_LT, "MKLDNNGraph::ExtractConstantAndExecutableNodes");
|
||||||
for (auto& graphNode : graphNodes) {
|
for (const auto& graphNode : graphNodes) {
|
||||||
if (graphNode->isConstant())
|
if (graphNode->isConstant())
|
||||||
constantGraphNodes.emplace_back(graphNode);
|
constantGraphNodes.emplace_back(graphNode);
|
||||||
else
|
else if (graphNode->isExecutable())
|
||||||
mutableGraphNodes.emplace_back(graphNode);
|
executableGraphNodes.emplace_back(graphNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -827,7 +827,7 @@ void MKLDNNGraph::Infer(MKLDNNInferRequest* request, int batch) {
|
|||||||
|
|
||||||
mkldnn::stream stream(eng);
|
mkldnn::stream stream(eng);
|
||||||
|
|
||||||
for (const auto& node : mutableGraphNodes) {
|
for (const auto& node : executableGraphNodes) {
|
||||||
PERF(config.collectPerfCounters, node);
|
PERF(config.collectPerfCounters, node);
|
||||||
if (request)
|
if (request)
|
||||||
request->ThrowIfCanceled();
|
request->ThrowIfCanceled();
|
||||||
|
@ -235,7 +235,7 @@ protected:
|
|||||||
void Allocate();
|
void Allocate();
|
||||||
void AllocateWithReuse();
|
void AllocateWithReuse();
|
||||||
void CreatePrimitives();
|
void CreatePrimitives();
|
||||||
void ExtractConstantNodes();
|
void ExtractConstantAndExecutableNodes();
|
||||||
void ExecuteNode(const MKLDNNNodePtr& node, const mkldnn::stream& stream) const;
|
void ExecuteNode(const MKLDNNNodePtr& node, const mkldnn::stream& stream) const;
|
||||||
void ExecuteConstantNodesOnly() const;
|
void ExecuteConstantNodesOnly() const;
|
||||||
|
|
||||||
@ -247,10 +247,12 @@ private:
|
|||||||
// TODO: change std::map to std::unordered_map
|
// TODO: change std::map to std::unordered_map
|
||||||
std::map<std::string, MKLDNNNodePtr> inputNodesMap;
|
std::map<std::string, MKLDNNNodePtr> inputNodesMap;
|
||||||
std::map<std::string, MKLDNNNodePtr> outputNodesMap;
|
std::map<std::string, MKLDNNNodePtr> outputNodesMap;
|
||||||
|
|
||||||
// these node pointers (from graphNodes) are to avoid regular checking for
|
// these node pointers (from graphNodes) are to avoid regular checking for
|
||||||
// constant node in ExecuteConstantNodesOnly and Infer methods
|
// constantness of nodes in ExecuteConstantNodesOnly, Infer methods and calls of
|
||||||
|
// non-executable (optimized out) nodes, such as Input, Reshape, etc.
|
||||||
std::vector<MKLDNNNodePtr> constantGraphNodes;
|
std::vector<MKLDNNNodePtr> constantGraphNodes;
|
||||||
std::vector<MKLDNNNodePtr> mutableGraphNodes;
|
std::vector<MKLDNNNodePtr> executableGraphNodes;
|
||||||
|
|
||||||
void EnforceBF16();
|
void EnforceBF16();
|
||||||
};
|
};
|
||||||
|
@ -195,6 +195,11 @@ public:
|
|||||||
return engine;
|
return engine;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// must be called only after MKLDNNGraph::InitEdges()
|
||||||
|
virtual bool isExecutable() const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool isConstant();
|
bool isConstant();
|
||||||
|
|
||||||
bool isInplace() const;
|
bool isInplace() const;
|
||||||
|
@ -27,6 +27,9 @@ public:
|
|||||||
bool isOptimized() const;
|
bool isOptimized() const;
|
||||||
|
|
||||||
InferenceEngine::Precision getRuntimePrecision() const override;
|
InferenceEngine::Precision getRuntimePrecision() const override;
|
||||||
|
bool isExecutable() const override {
|
||||||
|
return !isOptimized();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t axis = 0;
|
size_t axis = 0;
|
||||||
|
@ -26,6 +26,9 @@ public:
|
|||||||
MKLDNNMemoryCPtr getMemoryPtr() const;
|
MKLDNNMemoryCPtr getMemoryPtr() const;
|
||||||
|
|
||||||
void executeDynamicImpl(mkldnn::stream strm) override {}
|
void executeDynamicImpl(mkldnn::stream strm) override {}
|
||||||
|
bool isExecutable() const override {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<VectorDims> shapeInfer() const override {
|
std::vector<VectorDims> shapeInfer() const override {
|
||||||
return std::vector<VectorDims>();
|
return std::vector<VectorDims>();
|
||||||
|
@ -89,6 +89,9 @@ public:
|
|||||||
bool created() const override {
|
bool created() const override {
|
||||||
return getType() == MemoryInput;
|
return getType() == MemoryInput;
|
||||||
}
|
}
|
||||||
|
bool isExecutable() const override {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
void execute(mkldnn::stream strm) override;
|
void execute(mkldnn::stream strm) override;
|
||||||
|
|
||||||
void createPrimitive() override;
|
void createPrimitive() override;
|
||||||
|
@ -25,6 +25,10 @@ public:
|
|||||||
bool created() const override;
|
bool created() const override;
|
||||||
const std::vector<impl_desc_type>& getPrimitivesPriority() override;
|
const std::vector<impl_desc_type>& getPrimitivesPriority() override;
|
||||||
|
|
||||||
|
bool isExecutable() const override {
|
||||||
|
return !isOptimized;
|
||||||
|
}
|
||||||
|
|
||||||
void setDescs(const MemoryDesc& input, const MemoryDesc& output) {
|
void setDescs(const MemoryDesc& input, const MemoryDesc& output) {
|
||||||
this->input = input.clone();
|
this->input = input.clone();
|
||||||
inputShapes.clear();
|
inputShapes.clear();
|
||||||
|
@ -26,6 +26,9 @@ public:
|
|||||||
void initSupportedPrimitiveDescriptors() override;
|
void initSupportedPrimitiveDescriptors() override;
|
||||||
void createPrimitive() override;
|
void createPrimitive() override;
|
||||||
bool created() const override;
|
bool created() const override;
|
||||||
|
bool isExecutable() const override {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||||
};
|
};
|
||||||
|
@ -298,7 +298,7 @@ bool MKLDNNSplitNode::created() const {
|
|||||||
return getType() == Split;
|
return getType() == Split;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MKLDNNSplitNode::isOptimized() {
|
bool MKLDNNSplitNode::isOptimized() const {
|
||||||
return getSelectedPrimitiveDescriptor() && getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].inPlace >= 0;
|
return getSelectedPrimitiveDescriptor() && getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].inPlace >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,10 +22,13 @@ public:
|
|||||||
void execute(mkldnn::stream strm) override;
|
void execute(mkldnn::stream strm) override;
|
||||||
bool created() const override;
|
bool created() const override;
|
||||||
|
|
||||||
bool isOptimized();
|
bool isOptimized() const;
|
||||||
void initOptimalPrimitiveDescriptor() override;
|
void initOptimalPrimitiveDescriptor() override;
|
||||||
|
|
||||||
void setDynamicBatchLim(int lim) override;
|
void setDynamicBatchLim(int lim) override;
|
||||||
|
bool isExecutable() const override {
|
||||||
|
return !isOptimized();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void prepareOptimizedParams();
|
void prepareOptimizedParams();
|
||||||
|
Loading…
Reference in New Issue
Block a user