[CPU] Optimize TBB usage in the parallel dynamic shapes processing (#16517)
This commit is contained in:
@@ -52,8 +52,8 @@
|
||||
#include "memory_desc/dnnl_blocked_memory_desc.h"
|
||||
#include <common/primitive_desc.hpp>
|
||||
#include <common/primitive_desc_iface.hpp>
|
||||
#if (IE_THREAD == IE_THREAD_TBB || IE_THREAD == IE_THREAD_TBB_AUTO)
|
||||
# include <tbb/task_group.h>
|
||||
#if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO)
|
||||
# include <tbb/task.h>
|
||||
#endif
|
||||
|
||||
using namespace dnnl;
|
||||
@@ -1056,6 +1056,202 @@ void Graph::InferStatic(InferRequestBase* request) {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class IUpdateNodes {
|
||||
public:
|
||||
virtual void run(size_t stopIndx) = 0;
|
||||
virtual ~IUpdateNodes() = default;
|
||||
};
|
||||
|
||||
class UpdateNodesSeq : public IUpdateNodes {
|
||||
public:
|
||||
explicit UpdateNodesSeq(std::vector<NodePtr>& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {}
|
||||
void run(size_t stopIndx) override {
|
||||
for (; prepareCounter < stopIndx; ++prepareCounter) {
|
||||
const auto& node = m_executableGraphNodes[prepareCounter];
|
||||
if (node->isDynamicNode()) {
|
||||
node->updateShapes();
|
||||
node->updateDynamicParams();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
size_t prepareCounter = 0;
|
||||
std::vector<NodePtr>& m_executableGraphNodes;
|
||||
};
|
||||
|
||||
#if (OV_THREAD == OV_THREAD_SEQ)
|
||||
using UpdateNodes = UpdateNodesSeq;
|
||||
#endif
|
||||
|
||||
#if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO || OV_THREAD == OV_THREAD_OMP)
|
||||
class UpdateNodesBase : public IUpdateNodes {
|
||||
public:
|
||||
explicit UpdateNodesBase(std::vector<NodePtr>& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {}
|
||||
void updateShapes(size_t node_indx, size_t stop_indx) {
|
||||
try {
|
||||
for (size_t i = node_indx; i < stop_indx; i++) {
|
||||
const auto& node = m_executableGraphNodes[i];
|
||||
if (node->isDynamicNode()) {
|
||||
node->updateShapes();
|
||||
}
|
||||
m_prepareCounter.store(i, std::memory_order::memory_order_release);
|
||||
}
|
||||
}
|
||||
catch(...) {
|
||||
m_completion.store(true, std::memory_order::memory_order_relaxed);
|
||||
throw;
|
||||
}
|
||||
m_prepareCounter.store(stop_indx, std::memory_order::memory_order_release);
|
||||
m_completion.store(true, std::memory_order::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void updateDynParams(size_t node_indx, size_t /*unused*/) {
|
||||
size_t local_counter = node_indx;
|
||||
while (true) {
|
||||
bool completion = m_completion.load(std::memory_order::memory_order_relaxed);
|
||||
size_t prepareCounter = m_prepareCounter.load(std::memory_order::memory_order_acquire);
|
||||
if (completion && local_counter == prepareCounter) {
|
||||
break;
|
||||
}
|
||||
while (local_counter < prepareCounter) {
|
||||
const auto& node = m_executableGraphNodes[local_counter++];
|
||||
if (node->isDynamicNode()) {
|
||||
node->updateDynamicParams();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
std::atomic<size_t> m_prepareCounter{0};
|
||||
std::atomic<bool> m_completion{false};
|
||||
std::vector<NodePtr>& m_executableGraphNodes;
|
||||
};
|
||||
|
||||
#if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO)
|
||||
#if (TBB_VERSION_MAJOR > 2020)
|
||||
template <typename Body>
|
||||
class AsyncTask : public tbb::detail::d1::task {
|
||||
public:
|
||||
AsyncTask(Body& body, tbb::detail::d1::wait_context& wait, size_t node_indx, size_t stop_indx) :
|
||||
m_body(body), m_wait(wait), m_node_indx(node_indx), m_stop_indx(stop_indx) {}
|
||||
task* execute(tbb::detail::d1::execution_data&) override {
|
||||
m_body(m_node_indx, m_stop_indx);
|
||||
m_wait.release();
|
||||
return nullptr;
|
||||
}
|
||||
task* cancel(tbb::detail::d1::execution_data&) override {
|
||||
m_wait.release();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
Body& m_body;
|
||||
tbb::detail::d1::wait_context& m_wait;
|
||||
size_t m_node_indx;
|
||||
size_t m_stop_indx;
|
||||
};
|
||||
|
||||
class UpdateNodes : public UpdateNodesBase {
|
||||
public:
|
||||
using UpdateNodesBase::UpdateNodesBase;
|
||||
void run(size_t stopIndx) override {
|
||||
m_completion.store(false);
|
||||
auto startCounter = m_prepareCounter.load();
|
||||
tbb::detail::d1::wait_context wait_ctx(2);
|
||||
|
||||
auto task1 = [this](size_t start, size_t stop) {
|
||||
this->updateShapes(start, stop);
|
||||
};
|
||||
AsyncTask<decltype(task1)> t1(task1, wait_ctx, startCounter, stopIndx);
|
||||
|
||||
auto task2 = [this](size_t start, size_t stop) {
|
||||
this->updateDynParams(start, stop);
|
||||
};
|
||||
AsyncTask<decltype(task2)> t2(task2, wait_ctx, startCounter, stopIndx);
|
||||
|
||||
tbb::detail::d1::spawn(t2, ctx, /* always submit the task to a thread that occupies the first slot */ 1);
|
||||
tbb::detail::d1::execute_and_wait(t1, ctx, wait_ctx, ctx);
|
||||
}
|
||||
|
||||
private:
|
||||
tbb::task_group_context ctx;
|
||||
};
|
||||
#else
|
||||
template <typename Body>
|
||||
class AsyncTask : public tbb::task {
|
||||
public:
|
||||
AsyncTask(Body& body, size_t node_indx, size_t stop_indx) : m_body(body), m_node_indx(node_indx), m_stop_indx(stop_indx) {}
|
||||
task* execute() override {
|
||||
m_body(m_node_indx, m_stop_indx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
Body& m_body;
|
||||
size_t m_node_indx;
|
||||
size_t m_stop_indx;
|
||||
};
|
||||
|
||||
class UpdateNodes : public UpdateNodesBase {
|
||||
public:
|
||||
using UpdateNodesBase::UpdateNodesBase;
|
||||
void run(size_t stopIndx) override {
|
||||
m_completion.store(false);
|
||||
auto startCounter = m_prepareCounter.load();
|
||||
tbb::task& root = *new(tbb::task::allocate_root()) tbb::empty_task;
|
||||
root.set_ref_count(3); // two for children and one preserved
|
||||
|
||||
auto task1 = [this](size_t start, size_t stop) {
|
||||
this->updateShapes(start, stop);
|
||||
};
|
||||
AsyncTask<decltype(task1)>& a = *new (root.allocate_child()) AsyncTask<decltype(task1)>(task1, startCounter, stopIndx);
|
||||
|
||||
auto task2 = [this](size_t start, size_t stop) {
|
||||
this->updateDynParams(start, stop);
|
||||
};
|
||||
AsyncTask<decltype(task2)>& b = *new (root.allocate_child()) AsyncTask<decltype(task2)>(task2, startCounter, stopIndx);
|
||||
|
||||
b.set_affinity(2); // slot 1 plus 1
|
||||
tbb::task::spawn(b);
|
||||
root.spawn_and_wait_for_all(a);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if (OV_THREAD == OV_THREAD_OMP)
|
||||
class UpdateNodes : public UpdateNodesBase {
|
||||
public:
|
||||
using UpdateNodesBase::UpdateNodesBase;
|
||||
void run(size_t stopIndx) override {
|
||||
m_completion.store(false);
|
||||
auto startCounter = m_prepareCounter.load();
|
||||
|
||||
#pragma omp parallel
|
||||
#pragma omp single
|
||||
{
|
||||
#pragma omp task
|
||||
{
|
||||
updateDynParams(startCounter, stopIndx);
|
||||
}
|
||||
#pragma omp task
|
||||
{
|
||||
updateShapes(startCounter, stopIndx);
|
||||
}
|
||||
#pragma omp taskwait
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
|
||||
void Graph::InferDynamic(InferRequestBase* request) {
|
||||
dnnl::stream stream(getEngine());
|
||||
|
||||
@@ -1068,72 +1264,16 @@ void Graph::InferDynamic(InferRequestBase* request) {
|
||||
}
|
||||
syncIndsWorkSet.insert(executableGraphNodes.size());
|
||||
|
||||
std::function<void(size_t)> updateNodes;
|
||||
|
||||
#if (IE_THREAD == IE_THREAD_TBB || IE_THREAD == IE_THREAD_TBB_AUTO)
|
||||
std::atomic<size_t> prepareCounter(0);
|
||||
std::vector<std::atomic<uint8_t>> waveFrontCount(executableGraphNodes.size());
|
||||
waveFrontCount.front().store(1);
|
||||
for (size_t i = 1; i < waveFrontCount.size(); ++i) {
|
||||
waveFrontCount[i].store(2);
|
||||
std::unique_ptr<IUpdateNodes> updateNodes{};
|
||||
if (parallel_get_max_threads() > 1) {
|
||||
updateNodes.reset(new UpdateNodes(executableGraphNodes));
|
||||
} else {
|
||||
updateNodes.reset(new UpdateNodesSeq(executableGraphNodes));
|
||||
}
|
||||
|
||||
tbb::task_group tg;
|
||||
std::function<void(size_t, size_t)> updateShapes;
|
||||
std::function<void(size_t, size_t)> updateDynParams;
|
||||
|
||||
updateShapes = [&](size_t node_indx, size_t stop_indx) {
|
||||
prepareCounter.store(node_indx);
|
||||
if (node_indx >= stop_indx) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& node = executableGraphNodes[node_indx];
|
||||
if (node->isDynamicNode()) {
|
||||
node->updateShapes();
|
||||
}
|
||||
if (--waveFrontCount[node_indx] == 0) {
|
||||
tg.run([=, &updateDynParams](){ updateDynParams(node_indx, stop_indx); });
|
||||
}
|
||||
updateShapes(node_indx + 1, stop_indx);
|
||||
};
|
||||
|
||||
updateDynParams = [&](size_t node_indx, size_t stop_indx) {
|
||||
if (node_indx >= stop_indx) {
|
||||
prepareCounter.store(node_indx);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& node = executableGraphNodes[node_indx];
|
||||
if (node->isDynamicNode()) {
|
||||
node->updateDynamicParams();
|
||||
}
|
||||
if (node_indx + 1 < waveFrontCount.size() && --waveFrontCount[node_indx + 1] == 0) {
|
||||
tg.run([=, &updateDynParams](){ updateDynParams(node_indx + 1, stop_indx); });
|
||||
}
|
||||
};
|
||||
|
||||
updateNodes = [&](size_t stopIndx) {
|
||||
auto startCounter = prepareCounter.load();
|
||||
tg.run([=, &updateShapes](){ updateShapes(startCounter, stopIndx); });
|
||||
tg.wait();
|
||||
};
|
||||
#else
|
||||
size_t prepareCounter = 0;
|
||||
updateNodes = [&](size_t stopIndx) {
|
||||
for (; prepareCounter < stopIndx; ++prepareCounter) {
|
||||
const auto& node = executableGraphNodes[prepareCounter];
|
||||
if (node->isDynamicNode()) {
|
||||
node->updateShapes();
|
||||
node->updateDynamicParams();
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
size_t inferCounter = 0;
|
||||
|
||||
for (auto stopIndx : syncIndsWorkSet) {
|
||||
updateNodes(stopIndx);
|
||||
updateNodes->run(stopIndx);
|
||||
for (; inferCounter < stopIndx; ++inferCounter) {
|
||||
auto& node = executableGraphNodes[inferCounter];
|
||||
VERBOSE(node, getConfig().debugCaps.verbose);
|
||||
|
||||
Reference in New Issue
Block a user