diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.cpp index 79c2b045d2b..e03d9255b9a 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.cpp @@ -8,14 +8,10 @@ MKLDNNPlugin::MKLDNNAsyncInferRequest::MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr& inferRequest, const InferenceEngine::ITaskExecutor::Ptr& taskExecutor, const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor) - : InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor) { + : InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor) { static_cast(inferRequest.get())->SetAsyncRequest(this); } -void MKLDNNPlugin::MKLDNNAsyncInferRequest::Infer_ThreadUnsafe() { - InferUsingAsync(); -} - MKLDNNPlugin::MKLDNNAsyncInferRequest::~MKLDNNAsyncInferRequest() { StopAndWait(); } diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.h b/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.h index 386c53dea0f..b06210627a9 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_async_infer_request.h @@ -16,9 +16,6 @@ public: MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest, const InferenceEngine::ITaskExecutor::Ptr &taskExecutor, const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor); - - void Infer_ThreadUnsafe() override; - ~MKLDNNAsyncInferRequest() override; }; diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp index efa8170f98e..72ba9a1eceb 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp @@ -42,7 +42,8 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network, InferenceEngine::ExecutableNetworkThreadSafeDefault{nullptr, nullptr}, extensionManager(extMgr), _cfg{cfg}, - _name{network.getName()} { + _name{network.getName()}, + _numaNodesWeights(numaNodesWeights) { OV_ITT_TASK_CHAIN(taskChain, MKLDNNPlugin::itt::domains::MKLDNN_LT, "MKLDNNExecNetwork", "cloneNet"); // we are cloning network if we have statistics and we can transform network. @@ -239,33 +240,25 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network, _callbackExecutor = _taskExecutor; } - _graphs = decltype(_graphs) {[&] { - // TODO: Remove `cloneNet` to `localNetwork` when `MKLDNNGraph::CreateGraph` - // is fixed and does not change content of network passed (CVS-26420) - auto localNetwork = cloneNetwork(_clonedNetwork); - - auto graph = std::make_shared(); - { - std::unique_lock lock{_cfgMutex}; - graph->setConfig(_cfg); + int streams = std::max(1, _cfg.streamExecutorConfig._streams); + std::vector tasks; tasks.resize(streams); + _graphs.resize(streams); + if (_cfg.streamExecutorConfig._streams != 0) { + for (auto&& task : tasks) { + task = [this] { + MKLDNNExecNetwork::GetGraph(); + }; } - int numaNode = 0; - auto* streamExecutor = dynamic_cast(_taskExecutor.get()); - if (nullptr != streamExecutor) { - numaNode = streamExecutor->GetNumaNodeId(); - } - - graph->CreateGraph(localNetwork, extensionManager, numaNodesWeights[numaNode]); - return graph; - }}; - - _taskExecutor->runAndWait({std::thread::hardware_concurrency(), [this] {_graphs.local();}}); + _taskExecutor->runAndWait(tasks); + } else { + MKLDNNExecNetwork::GetGraph(); + } // Save all MemoryLayer data tensors. Will use insight about mechanics // of MemoryLayer implementation. It uses output edge of MemoryLayer // producer as storage for tensor to keep it between infer calls. if (_graphs.size() == 1) { - for (auto &node : _graphs.begin()->get()->GetNodes()) { + for (auto &node : GetGraph()._graph.GetNodes()) { if (node->getType() == MemoryInput) { auto memoryNode = dynamic_cast(node.get()); auto state_store = memoryNode->getStore(); @@ -282,13 +275,51 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network, } } +MKLDNNExecNetwork::Graph::Lock MKLDNNExecNetwork::GetGraph() { + int streamId = 0; + int numaNodeId = 0; + auto streamsExecutor = dynamic_cast(_taskExecutor.get()); + if (nullptr != streamsExecutor) { + streamId = streamsExecutor->GetStreamId(); + numaNodeId = streamsExecutor->GetNumaNodeId(); + } + auto graphLock = Graph::Lock(_graphs[streamId % _graphs.size()]); + if (!graphLock._graph.IsReady()) { + std::exception_ptr exception; + auto makeGraph = [&] { + try { + auto localNetwork = cloneNetwork(_clonedNetwork); + { + std::lock_guard lock{_cfgMutex}; + graphLock._graph.setConfig(_cfg); + } + graphLock._graph.CreateGraph(localNetwork, extensionManager, _numaNodesWeights[numaNodeId]); + } catch(...) { + exception = std::current_exception(); + } + }; + if (nullptr != streamsExecutor) { + streamsExecutor->Execute(makeGraph); + } else { + makeGraph(); + } + if (exception) { + std::rethrow_exception(exception); + } + } + return graphLock; +} + void MKLDNNExecNetwork::setProperty(const std::map &properties) { { std::lock_guard lock{_cfgMutex}; _cfg.readProperties(properties); } - for (auto g : _graphs) { - g->setProperty(properties); + for (auto& g : _graphs) { + auto graphLock = Graph::Lock(g); + if (graphLock._graph.IsReady()) { + graphLock._graph.setProperty(properties); + } } } @@ -300,16 +331,16 @@ InferenceEngine::CNNNetwork MKLDNNExecNetwork::GetExecGraphInfo() { if (_graphs.size() == 0) THROW_IE_EXCEPTION << "No graph was found"; - return _graphs.begin()->get()->dump(); + return GetGraph()._graph.dump(); } Parameter MKLDNNExecNetwork::GetConfig(const std::string &name) const { if (_graphs.size() == 0) THROW_IE_EXCEPTION << "No graph was found"; - Config engConfig = _graphs.begin()->get()->getProperty(); - auto it = engConfig._config.find(name); - if (it != engConfig._config.end()) { - return it->second; + Config engConfig = const_cast(this)->GetGraph()._graph.getProperty(); + auto option = engConfig._config.find(name); + if (option != engConfig._config.end()) { + return option->second; } else { THROW_IE_EXCEPTION << "Unsupported ExecutableNetwork config key: " << name; } @@ -320,7 +351,8 @@ InferenceEngine::Parameter MKLDNNExecNetwork::GetMetric(const std::string &name) THROW_IE_EXCEPTION << "No graph was found"; if (name == METRIC_KEY(NETWORK_NAME)) { - IE_SET_METRIC_RETURN(NETWORK_NAME, _graphs.begin()->get()->GetName()); + IE_SET_METRIC_RETURN(NETWORK_NAME, + const_cast(this)->GetGraph()._graph.dump().getName()); } else if (name == METRIC_KEY(SUPPORTED_METRICS)) { std::vector metrics; metrics.push_back(METRIC_KEY(NETWORK_NAME)); @@ -330,12 +362,12 @@ InferenceEngine::Parameter MKLDNNExecNetwork::GetMetric(const std::string &name) IE_SET_METRIC_RETURN(SUPPORTED_METRICS, metrics); } else if (name == METRIC_KEY(SUPPORTED_CONFIG_KEYS)) { std::vector configKeys; - for (auto && key : _graphs.begin()->get()->getProperty()._config) { + for (auto && key : const_cast(this)->GetGraph()._graph.getProperty()._config) { configKeys.push_back(key.first); } IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, configKeys); } else if (name == METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)) { - Config engConfig = _graphs.begin()->get()->getProperty(); + Config engConfig = const_cast(this)->GetGraph()._graph.getProperty(); auto option = engConfig._config.find(CONFIG_KEY(CPU_THROUGHPUT_STREAMS)); IE_ASSERT(option != engConfig._config.end()); auto streams = std::stoi(option->second); diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h index 6066aee4505..bef52bcf218 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h @@ -45,8 +45,6 @@ public: INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead") std::vector QueryState() override; - InferenceEngine::ThreadLocal _graphs; - protected: friend class MKLDNNInferRequest; MKLDNNExtensionManager::Ptr extensionManager; @@ -56,7 +54,22 @@ protected: Config _cfg; std::atomic_int _numRequests = {0}; std::string _name; + struct Graph : public MKLDNNGraph { + std::mutex _mutex; + struct Lock : public std::unique_lock { + explicit Lock(Graph& graph) : std::unique_lock(graph._mutex), _graph(graph) {} + Graph& _graph; + }; + }; + // WARNING: Do not use _graphs directly. + std::deque _graphs; + NumaNodesWeights& _numaNodesWeights; + /* WARNING: Use GetGraph() function to get access to graph in current stream. + * NOTE: Main thread is interpreted as master thread of external stream so use this function to get access to graphs + * even from main thread + */ + Graph::Lock GetGraph(); bool CanProcessDynBatch(const InferenceEngine::CNNNetwork &network) const; }; diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp index df071706ce4..eb8ba8f6987 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp @@ -992,7 +992,7 @@ void MKLDNNGraph::setProperty(const std::map& properti config.readProperties(properties); } -Config MKLDNNGraph::getProperty() { +Config MKLDNNGraph::getProperty() const { return config; } diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph.h b/inference-engine/src/mkldnn_plugin/mkldnn_graph.h index 7834d08a7d3..f47a643a288 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph.h @@ -42,7 +42,7 @@ public: void setConfig(const Config &cfg); void setProperty(const std::map &properties); - Config getProperty(); + Config getProperty() const; void getInputBlobs(InferenceEngine::BlobMap &in_map); void getOutputBlobs(InferenceEngine::BlobMap &out_map); diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp index c5892675597..4f1c14b311f 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp @@ -30,7 +30,7 @@ MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsData if (execNetwork->_graphs.size() == 0) THROW_IE_EXCEPTION << "No graph was found"; - graph = execNetwork->_graphs.begin()->get(); + graph = &(execNetwork->GetGraph()._graph); for (const auto& it : _networkInputs) { MKLDNNInferRequest::GetBlob(it.first); } @@ -182,8 +182,8 @@ void MKLDNNPlugin::MKLDNNInferRequest::PullStates() { void MKLDNNPlugin::MKLDNNInferRequest::InferImpl() { using namespace openvino::itt; OV_ITT_SCOPED_TASK(itt::domains::MKLDNNPlugin, profilingTask); - - graph = execNetwork->_graphs.local().get(); + auto graphLock = execNetwork->GetGraph(); + graph = &(graphLock._graph); ThrowIfCanceled(); diff --git a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_leaks_test.cpp b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_leaks_test.cpp index a45d7e8751e..9a8b136be51 100644 --- a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_leaks_test.cpp +++ b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_leaks_test.cpp @@ -16,7 +16,7 @@ using namespace mkldnn; class MKLDNNTestExecNetwork: public MKLDNNPlugin::MKLDNNExecNetwork { public: MKLDNNPlugin::MKLDNNGraph& getGraph() { - return *(_graphs.begin()->get()); + return _graphs.front(); } };