Synchronous inference using stream executor affinity and parallelism constraints (#528)

* Synchronous inference using stream executor affinity and parallelism constraints

* Fixed review comments
This commit is contained in:
Anton Pankratv
2021-03-09 20:32:14 +03:00
committed by GitHub
parent b7471be5fb
commit a5e2497788
8 changed files with 86 additions and 48 deletions

View File

@@ -8,14 +8,10 @@
MKLDNNPlugin::MKLDNNAsyncInferRequest::MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr& inferRequest,
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor)
: InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor) {
: InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor) {
static_cast<MKLDNNInferRequest*>(inferRequest.get())->SetAsyncRequest(this);
}
// Synchronous Infer entry point (called with the request already locked,
// hence "_ThreadUnsafe"). Routes the call through InferUsingAsync() —
// presumably the base AsyncInferRequestThreadSafeDefault helper that runs
// the pipeline on the task executor — so sync inference honours the same
// stream-executor affinity as async inference. TODO confirm against base class.
void MKLDNNPlugin::MKLDNNAsyncInferRequest::Infer_ThreadUnsafe() {
InferUsingAsync();
}
// Destructor: StopAndWait() (inherited from the thread-safe default base)
// is called so outstanding pipeline tasks are drained before members are
// destroyed — NOTE(review): exact wait semantics live in the base class.
MKLDNNPlugin::MKLDNNAsyncInferRequest::~MKLDNNAsyncInferRequest() {
StopAndWait();
}

View File

@@ -16,9 +16,6 @@ public:
MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest,
const InferenceEngine::ITaskExecutor::Ptr &taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor);
void Infer_ThreadUnsafe() override;
~MKLDNNAsyncInferRequest() override;
};

View File

@@ -42,7 +42,8 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network,
InferenceEngine::ExecutableNetworkThreadSafeDefault{nullptr, nullptr},
extensionManager(extMgr),
_cfg{cfg},
_name{network.getName()} {
_name{network.getName()},
_numaNodesWeights(numaNodesWeights) {
OV_ITT_TASK_CHAIN(taskChain, MKLDNNPlugin::itt::domains::MKLDNN_LT, "MKLDNNExecNetwork", "cloneNet");
// we are cloning network if we have statistics and we can transform network.
@@ -239,33 +240,25 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network,
_callbackExecutor = _taskExecutor;
}
_graphs = decltype(_graphs) {[&] {
// TODO: Remove `cloneNet` to `localNetwork` when `MKLDNNGraph::CreateGraph`
// is fixed and does not change content of network passed (CVS-26420)
auto localNetwork = cloneNetwork(_clonedNetwork);
auto graph = std::make_shared<MKLDNNGraph>();
{
std::unique_lock<std::mutex> lock{_cfgMutex};
graph->setConfig(_cfg);
int streams = std::max(1, _cfg.streamExecutorConfig._streams);
std::vector<Task> tasks; tasks.resize(streams);
_graphs.resize(streams);
if (_cfg.streamExecutorConfig._streams != 0) {
for (auto&& task : tasks) {
task = [this] {
MKLDNNExecNetwork::GetGraph();
};
}
int numaNode = 0;
auto* streamExecutor = dynamic_cast<InferenceEngine::IStreamsExecutor*>(_taskExecutor.get());
if (nullptr != streamExecutor) {
numaNode = streamExecutor->GetNumaNodeId();
}
graph->CreateGraph(localNetwork, extensionManager, numaNodesWeights[numaNode]);
return graph;
}};
_taskExecutor->runAndWait({std::thread::hardware_concurrency(), [this] {_graphs.local();}});
_taskExecutor->runAndWait(tasks);
} else {
MKLDNNExecNetwork::GetGraph();
}
// Save all MemoryLayer data tensors. Will use insight about mechanics
// of MemoryLayer implementation. It uses output edge of MemoryLayer
// producer as storage for tensor to keep it between infer calls.
if (_graphs.size() == 1) {
for (auto &node : _graphs.begin()->get()->GetNodes()) {
for (auto &node : GetGraph()._graph.GetNodes()) {
if (node->getType() == MemoryInput) {
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
auto state_store = memoryNode->getStore();
@@ -282,13 +275,51 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network,
}
}
// Returns the graph bound to the calling stream, wrapped in a Graph::Lock
// (a unique_lock holding the per-graph mutex for the lifetime of the
// returned object). The graph is lazily built on first use for its stream.
MKLDNNExecNetwork::Graph::Lock MKLDNNExecNetwork::GetGraph() {
// Default to stream/NUMA node 0 when not running inside a streams executor
// (e.g. when called from the application's main thread).
int streamId = 0;
int numaNodeId = 0;
auto streamsExecutor = dynamic_cast<InferenceEngine::IStreamsExecutor*>(_taskExecutor.get());
if (nullptr != streamsExecutor) {
streamId = streamsExecutor->GetStreamId();
numaNodeId = streamsExecutor->GetNumaNodeId();
}
// Map the stream id onto the fixed pool of per-stream graphs; the Lock
// ctor acquires that graph's mutex.
auto graphLock = Graph::Lock(_graphs[streamId % _graphs.size()]);
if (!graphLock._graph.IsReady()) {
// Lazy first-time build. Exceptions thrown on the executor thread are
// captured and rethrown on the caller's thread below.
std::exception_ptr exception;
auto makeGraph = [&] {
try {
// Clone so CreateGraph's mutation of the network (see CVS-26420
// note elsewhere in this file) does not corrupt _clonedNetwork.
auto localNetwork = cloneNetwork(_clonedNetwork);
{
// _cfg may be changed concurrently via setProperty().
std::lock_guard<std::mutex> lock{_cfgMutex};
graphLock._graph.setConfig(_cfg);
}
// Use the weights cache of the caller's NUMA node.
graphLock._graph.CreateGraph(localNetwork, extensionManager, _numaNodesWeights[numaNodeId]);
} catch(...) {
exception = std::current_exception();
}
};
// Build on the stream executor when available so allocation happens
// with the stream's thread/NUMA affinity; otherwise build inline.
if (nullptr != streamsExecutor) {
streamsExecutor->Execute(makeGraph);
} else {
makeGraph();
}
if (exception) {
std::rethrow_exception(exception);
}
}
return graphLock;
}
void MKLDNNExecNetwork::setProperty(const std::map<std::string, std::string> &properties) {
{
std::lock_guard<std::mutex> lock{_cfgMutex};
_cfg.readProperties(properties);
}
for (auto g : _graphs) {
g->setProperty(properties);
for (auto& g : _graphs) {
auto graphLock = Graph::Lock(g);
if (graphLock._graph.IsReady()) {
graphLock._graph.setProperty(properties);
}
}
}
@@ -300,16 +331,16 @@ InferenceEngine::CNNNetwork MKLDNNExecNetwork::GetExecGraphInfo() {
if (_graphs.size() == 0)
THROW_IE_EXCEPTION << "No graph was found";
return _graphs.begin()->get()->dump();
return GetGraph()._graph.dump();
}
Parameter MKLDNNExecNetwork::GetConfig(const std::string &name) const {
if (_graphs.size() == 0)
THROW_IE_EXCEPTION << "No graph was found";
Config engConfig = _graphs.begin()->get()->getProperty();
auto it = engConfig._config.find(name);
if (it != engConfig._config.end()) {
return it->second;
Config engConfig = const_cast<MKLDNNExecNetwork*>(this)->GetGraph()._graph.getProperty();
auto option = engConfig._config.find(name);
if (option != engConfig._config.end()) {
return option->second;
} else {
THROW_IE_EXCEPTION << "Unsupported ExecutableNetwork config key: " << name;
}
@@ -320,7 +351,8 @@ InferenceEngine::Parameter MKLDNNExecNetwork::GetMetric(const std::string &name)
THROW_IE_EXCEPTION << "No graph was found";
if (name == METRIC_KEY(NETWORK_NAME)) {
IE_SET_METRIC_RETURN(NETWORK_NAME, _graphs.begin()->get()->GetName());
IE_SET_METRIC_RETURN(NETWORK_NAME,
const_cast<MKLDNNExecNetwork*>(this)->GetGraph()._graph.dump().getName());
} else if (name == METRIC_KEY(SUPPORTED_METRICS)) {
std::vector<std::string> metrics;
metrics.push_back(METRIC_KEY(NETWORK_NAME));
@@ -330,12 +362,12 @@ InferenceEngine::Parameter MKLDNNExecNetwork::GetMetric(const std::string &name)
IE_SET_METRIC_RETURN(SUPPORTED_METRICS, metrics);
} else if (name == METRIC_KEY(SUPPORTED_CONFIG_KEYS)) {
std::vector<std::string> configKeys;
for (auto && key : _graphs.begin()->get()->getProperty()._config) {
for (auto && key : const_cast<MKLDNNExecNetwork*>(this)->GetGraph()._graph.getProperty()._config) {
configKeys.push_back(key.first);
}
IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, configKeys);
} else if (name == METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)) {
Config engConfig = _graphs.begin()->get()->getProperty();
Config engConfig = const_cast<MKLDNNExecNetwork*>(this)->GetGraph()._graph.getProperty();
auto option = engConfig._config.find(CONFIG_KEY(CPU_THROUGHPUT_STREAMS));
IE_ASSERT(option != engConfig._config.end());
auto streams = std::stoi(option->second);

View File

@@ -45,8 +45,6 @@ public:
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
InferenceEngine::ThreadLocal<MKLDNNGraph::Ptr> _graphs;
protected:
friend class MKLDNNInferRequest;
MKLDNNExtensionManager::Ptr extensionManager;
@@ -56,7 +54,22 @@ protected:
Config _cfg;
std::atomic_int _numRequests = {0};
std::string _name;
struct Graph : public MKLDNNGraph {
std::mutex _mutex;
struct Lock : public std::unique_lock<std::mutex> {
explicit Lock(Graph& graph) : std::unique_lock<std::mutex>(graph._mutex), _graph(graph) {}
Graph& _graph;
};
};
// WARNING: Do not use _graphs directly.
std::deque<Graph> _graphs;
NumaNodesWeights& _numaNodesWeights;
/* WARNING: Use GetGraph() function to get access to graph in current stream.
* NOTE: Main thread is interpreted as master thread of external stream so use this function to get access to graphs
* even from main thread
*/
Graph::Lock GetGraph();
bool CanProcessDynBatch(const InferenceEngine::CNNNetwork &network) const;
};

View File

@@ -992,7 +992,7 @@ void MKLDNNGraph::setProperty(const std::map<std::string, std::string>& properti
config.readProperties(properties);
}
Config MKLDNNGraph::getProperty() {
Config MKLDNNGraph::getProperty() const {
return config;
}

View File

@@ -42,7 +42,7 @@ public:
void setConfig(const Config &cfg);
void setProperty(const std::map<std::string, std::string> &properties);
Config getProperty();
Config getProperty() const;
void getInputBlobs(InferenceEngine::BlobMap &in_map);
void getOutputBlobs(InferenceEngine::BlobMap &out_map);

View File

@@ -30,7 +30,7 @@ MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsData
if (execNetwork->_graphs.size() == 0)
THROW_IE_EXCEPTION << "No graph was found";
graph = execNetwork->_graphs.begin()->get();
graph = &(execNetwork->GetGraph()._graph);
for (const auto& it : _networkInputs) {
MKLDNNInferRequest::GetBlob(it.first);
}
@@ -182,8 +182,8 @@ void MKLDNNPlugin::MKLDNNInferRequest::PullStates() {
void MKLDNNPlugin::MKLDNNInferRequest::InferImpl() {
using namespace openvino::itt;
OV_ITT_SCOPED_TASK(itt::domains::MKLDNNPlugin, profilingTask);
graph = execNetwork->_graphs.local().get();
auto graphLock = execNetwork->GetGraph();
graph = &(graphLock._graph);
ThrowIfCanceled();

View File

@@ -16,7 +16,7 @@ using namespace mkldnn;
class MKLDNNTestExecNetwork: public MKLDNNPlugin::MKLDNNExecNetwork {
public:
MKLDNNPlugin::MKLDNNGraph& getGraph() {
return *(_graphs.begin()->get());
return _graphs.front();
}
};