[CPU] Optimize TBB usage in the parallel dynamic shapes processing (#16517)

This commit is contained in:
Maksim Kutakov
2023-04-18 22:25:03 +02:00
committed by GitHub
parent d4ac0b0e79
commit 531b5a3657

View File

@@ -52,8 +52,8 @@
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include <common/primitive_desc.hpp>
#include <common/primitive_desc_iface.hpp>
#if (IE_THREAD == IE_THREAD_TBB || IE_THREAD == IE_THREAD_TBB_AUTO)
# include <tbb/task_group.h>
#if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO)
# include <tbb/task.h>
#endif
using namespace dnnl;
@@ -1056,6 +1056,202 @@ void Graph::InferStatic(InferRequestBase* request) {
}
}
namespace {
// Common interface of the node update strategies defined in this file.
// The implementations below walk the executable nodes and refresh the
// dynamic ones before the nodes are executed.
class IUpdateNodes {
public:
    // Virtual destructor: instances are destroyed through a pointer to this interface.
    virtual ~IUpdateNodes() = default;

    // Performs the update. The implementations in this file handle the nodes
    // located before position stopIndx that were not handled by a previous call.
    virtual void run(size_t stopIndx) = 0;
};
// Single-threaded update strategy: shapes and dynamic parameters of every
// dynamic node are refreshed one after another on the calling thread.
class UpdateNodesSeq : public IUpdateNodes {
public:
    explicit UpdateNodesSeq(std::vector<NodePtr>& executableGraphNodes)
        : m_executableGraphNodes(executableGraphNodes) {}

    // Processes the nodes in [prepareCounter, stopIndx). The position is kept
    // between calls, so every call resumes where the previous one stopped.
    void run(size_t stopIndx) override {
        while (prepareCounter < stopIndx) {
            const auto& current = m_executableGraphNodes[prepareCounter];
            if (current->isDynamicNode()) {
                current->updateShapes();
                current->updateDynamicParams();
            }
            // Advance only after the node has been fully processed.
            ++prepareCounter;
        }
    }

private:
    size_t prepareCounter = 0;                     // index of the next node to process
    std::vector<NodePtr>& m_executableGraphNodes;  // not owned, must outlive this object
};
#if (OV_THREAD == OV_THREAD_SEQ)
using UpdateNodes = UpdateNodesSeq;
#endif
#if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO || OV_THREAD == OV_THREAD_OMP)
// Shared part of the parallel update strategies. The work is split between a
// producer (updateShapes) and a consumer (updateDynParams) that are meant to
// run concurrently and communicate through two atomics:
//   * m_prepareCounter - the producer publishes how far it has got; the
//     consumer may process every node with an index below that value;
//   * m_completion     - set by the producer once it has finished (normally or
//     because of an exception), so the consumer knows when to stop polling.
// Derived classes implement run() and decide how the two parts are scheduled.
class UpdateNodesBase : public IUpdateNodes {
public:
    explicit UpdateNodesBase(std::vector<NodePtr>& executableGraphNodes)
        : m_executableGraphNodes(executableGraphNodes) {}

    // Producer: calls updateShapes() on the dynamic nodes in [node_indx, stop_indx)
    // and publishes the progress after every node. Any exception is rethrown
    // after the completion flag has been raised, so the consumer does not spin forever.
    void updateShapes(size_t node_indx, size_t stop_indx) {
        try {
            for (size_t i = node_indx; i < stop_indx; i++) {
                const auto& node = m_executableGraphNodes[i];
                if (node->isDynamicNode()) {
                    node->updateShapes();
                }
                m_prepareCounter.store(i, std::memory_order_release);
            }
        } catch (...) {
            m_completion.store(true, std::memory_order_release);
            throw;
        }
        m_prepareCounter.store(stop_indx, std::memory_order_release);
        // Release ordering is required here: the flag must not become visible
        // before the final counter value, otherwise the consumer could observe
        // "completed" together with a stale counter and skip the last nodes.
        m_completion.store(true, std::memory_order_release);
    }

    // Consumer: polls the progress published by updateShapes() and calls
    // updateDynamicParams() on every dynamic node that has become available,
    // starting from node_indx. Returns once the producer has reported
    // completion and all published nodes have been processed.
    void updateDynParams(size_t node_indx, size_t /*unused*/) {
        size_t local_counter = node_indx;
        while (true) {
            // The flag is read first and with acquire ordering, so when it is
            // observed as set, the counter load below is guaranteed to return
            // the last value stored by the producer.
            const bool completion = m_completion.load(std::memory_order_acquire);
            const size_t prepareCounter = m_prepareCounter.load(std::memory_order_acquire);
            if (completion && local_counter == prepareCounter) {
                break;
            }
            while (local_counter < prepareCounter) {
                const auto& node = m_executableGraphNodes[local_counter++];
                if (node->isDynamicNode()) {
                    node->updateDynamicParams();
                }
            }
        }
    }

protected:
    std::atomic<size_t> m_prepareCounter{0};       // nodes below this index have their shapes updated
    std::atomic<bool> m_completion{false};         // true once updateShapes() has finished or failed
    std::vector<NodePtr>& m_executableGraphNodes;  // not owned, must outlive this object
};
#if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO)
#if (TBB_VERSION_MAJOR > 2020)
// Task wrapper that invokes a callable with a fixed (node_indx, stop_indx)
// pair and releases one reference of the given wait context when it is done.
// The callable and the wait context are stored by reference, so both must
// outlive the task.
template <typename Body>
class AsyncTask : public tbb::detail::d1::task {
public:
    AsyncTask(Body& body, tbb::detail::d1::wait_context& wait, size_t node_indx, size_t stop_indx)
        : m_body(body),
          m_wait(wait),
          m_node_indx(node_indx),
          m_stop_indx(stop_indx) {}

    // Runs the callable, then signals the wait context.
    task* execute(tbb::detail::d1::execution_data&) override {
        m_body(m_node_indx, m_stop_indx);
        return signalDone();
    }

    // The callable is not run on this path, but the wait context is still signalled.
    task* cancel(tbb::detail::d1::execution_data&) override {
        return signalDone();
    }

private:
    // Releases one reference of the wait context; there is no follow-up task to return.
    task* signalDone() {
        m_wait.release();
        return nullptr;
    }

    Body& m_body;
    tbb::detail::d1::wait_context& m_wait;
    size_t m_node_indx;
    size_t m_stop_indx;
};
// Update strategy used when TBB_VERSION_MAJOR > 2020. It schedules the two halves of
// UpdateNodesBase (updateShapes and updateDynParams) as two TBB tasks and
// returns once the wait context shared by both tasks has been released twice.
class UpdateNodes : public UpdateNodesBase {
public:
using UpdateNodesBase::UpdateNodesBase;
// Updates the nodes starting from the position reached by the previous call
// (the current value of m_prepareCounter) up to stopIndx.
void run(size_t stopIndx) override {
// Reset the completion flag left set by the previous run.
m_completion.store(false);
auto startCounter = m_prepareCounter.load();
// Two references: each AsyncTask releases one when it is executed or cancelled.
tbb::detail::d1::wait_context wait_ctx(2);
auto task1 = [this](size_t start, size_t stop) {
this->updateShapes(start, stop);
};
// Both the lambdas and the tasks live on this stack frame; AsyncTask keeps
// references to them, so run() must not return before both tasks are done.
AsyncTask<decltype(task1)> t1(task1, wait_ctx, startCounter, stopIndx);
auto task2 = [this](size_t start, size_t stop) {
this->updateDynParams(start, stop);
};
AsyncTask<decltype(task2)> t2(task2, wait_ctx, startCounter, stopIndx);
// The updateDynParams task is spawned first; the updateShapes task is then
// handed to execute_and_wait together with the shared wait context.
// NOTE(review): presumably execute_and_wait runs t1 on the calling thread - confirm against the TBB sources.
tbb::detail::d1::spawn(t2, ctx, /* always submit the task to a thread that occupies the first slot */ 1);
tbb::detail::d1::execute_and_wait(t1, ctx, wait_ctx, ctx);
}
private:
// Task group context passed to both spawn and execute_and_wait; reused by every run() call.
tbb::task_group_context ctx;
};
#else
// Task wrapper for the legacy tbb::task interface: invokes a callable with a
// fixed (node_indx, stop_indx) pair. The callable is stored by reference and
// therefore has to outlive the task.
template <typename Body>
class AsyncTask : public tbb::task {
public:
    AsyncTask(Body& body, size_t node_indx, size_t stop_indx)
        : m_body(body),
          m_node_indx(node_indx),
          m_stop_indx(stop_indx) {}

    // Runs the callable; no follow-up task is returned.
    task* execute() override {
        m_body(m_node_indx, m_stop_indx);
        return nullptr;
    }

private:
    Body& m_body;
    size_t m_node_indx;
    size_t m_stop_indx;
};
class UpdateNodes : public UpdateNodesBase {
public:
using UpdateNodesBase::UpdateNodesBase;
void run(size_t stopIndx) override {
m_completion.store(false);
auto startCounter = m_prepareCounter.load();
tbb::task& root = *new(tbb::task::allocate_root()) tbb::empty_task;
root.set_ref_count(3); // two for children and one preserved
auto task1 = [this](size_t start, size_t stop) {
this->updateShapes(start, stop);
};
AsyncTask<decltype(task1)>& a = *new (root.allocate_child()) AsyncTask<decltype(task1)>(task1, startCounter, stopIndx);
auto task2 = [this](size_t start, size_t stop) {
this->updateDynParams(start, stop);
};
AsyncTask<decltype(task2)>& b = *new (root.allocate_child()) AsyncTask<decltype(task2)>(task2, startCounter, stopIndx);
b.set_affinity(2); // slot 1 plus 1
tbb::task::spawn(b);
root.spawn_and_wait_for_all(a);
}
};
#endif
#endif
#if (OV_THREAD == OV_THREAD_OMP)
class UpdateNodes : public UpdateNodesBase {
public:
using UpdateNodesBase::UpdateNodesBase;
void run(size_t stopIndx) override {
m_completion.store(false);
auto startCounter = m_prepareCounter.load();
#pragma omp parallel
#pragma omp single
{
#pragma omp task
{
updateDynParams(startCounter, stopIndx);
}
#pragma omp task
{
updateShapes(startCounter, stopIndx);
}
#pragma omp taskwait
}
}
};
#endif
#endif
} // namespace
void Graph::InferDynamic(InferRequestBase* request) {
dnnl::stream stream(getEngine());
@@ -1068,72 +1264,16 @@ void Graph::InferDynamic(InferRequestBase* request) {
}
syncIndsWorkSet.insert(executableGraphNodes.size());
std::function<void(size_t)> updateNodes;
#if (IE_THREAD == IE_THREAD_TBB || IE_THREAD == IE_THREAD_TBB_AUTO)
std::atomic<size_t> prepareCounter(0);
std::vector<std::atomic<uint8_t>> waveFrontCount(executableGraphNodes.size());
waveFrontCount.front().store(1);
for (size_t i = 1; i < waveFrontCount.size(); ++i) {
waveFrontCount[i].store(2);
std::unique_ptr<IUpdateNodes> updateNodes{};
if (parallel_get_max_threads() > 1) {
updateNodes.reset(new UpdateNodes(executableGraphNodes));
} else {
updateNodes.reset(new UpdateNodesSeq(executableGraphNodes));
}
tbb::task_group tg;
std::function<void(size_t, size_t)> updateShapes;
std::function<void(size_t, size_t)> updateDynParams;
updateShapes = [&](size_t node_indx, size_t stop_indx) {
prepareCounter.store(node_indx);
if (node_indx >= stop_indx) {
return;
}
const auto& node = executableGraphNodes[node_indx];
if (node->isDynamicNode()) {
node->updateShapes();
}
if (--waveFrontCount[node_indx] == 0) {
tg.run([=, &updateDynParams](){ updateDynParams(node_indx, stop_indx); });
}
updateShapes(node_indx + 1, stop_indx);
};
updateDynParams = [&](size_t node_indx, size_t stop_indx) {
if (node_indx >= stop_indx) {
prepareCounter.store(node_indx);
return;
}
const auto& node = executableGraphNodes[node_indx];
if (node->isDynamicNode()) {
node->updateDynamicParams();
}
if (node_indx + 1 < waveFrontCount.size() && --waveFrontCount[node_indx + 1] == 0) {
tg.run([=, &updateDynParams](){ updateDynParams(node_indx + 1, stop_indx); });
}
};
updateNodes = [&](size_t stopIndx) {
auto startCounter = prepareCounter.load();
tg.run([=, &updateShapes](){ updateShapes(startCounter, stopIndx); });
tg.wait();
};
#else
size_t prepareCounter = 0;
updateNodes = [&](size_t stopIndx) {
for (; prepareCounter < stopIndx; ++prepareCounter) {
const auto& node = executableGraphNodes[prepareCounter];
if (node->isDynamicNode()) {
node->updateShapes();
node->updateDynamicParams();
}
}
};
#endif
size_t inferCounter = 0;
for (auto stopIndx : syncIndsWorkSet) {
updateNodes(stopIndx);
updateNodes->run(stopIndx);
for (; inferCounter < stopIndx; ++inferCounter) {
auto& node = executableGraphNodes[inferCounter];
VERBOSE(node, getConfig().debugCaps.verbose);