Optimize memory depdendency analysis (Constant memory does not use pool : No need to add constant nodes to deps) (#11861)

This commit is contained in:
Taylor Yeonbok Lee 2022-06-14 13:46:24 +09:00 committed by GitHub
parent cae0c924b6
commit c73201c9e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 27 additions and 20 deletions

View File

@ -204,7 +204,7 @@ private:
bool _is_primary_stream;
bool _reset_arguments;
std::map<primitive_id, std::shared_ptr<primitive_inst>> _primitives;
std::unordered_map<primitive_id, std::shared_ptr<primitive_inst>> _primitives;
std::vector<std::shared_ptr<primitive_inst>> _inputs;
std::vector<std::shared_ptr<primitive_inst>> _outputs;
std::list<std::shared_ptr<primitive_inst>> _exec_order;

View File

@ -78,14 +78,19 @@ void oooq_memory_dependencies::run(program& p) {
// First create transitive closure of the graph,
// giving us mapping of node to set of all users that can be reached from this node.
auto& processing_order = p.get_processing_order();
// maps program nodes to bimap vector ids
auto user_map = std::map<program_node*, unsigned int>();
unsigned int processing_order_idx = 0;
for (auto node : processing_order) {
user_map[node] = processing_order_idx++;
std::list<program_node*> processing_order_except_const;
for (auto n : processing_order) {
if (!n->is_type<data>()) {
processing_order_except_const.push_back(n);
}
}
// maps program nodes to bimap vector ids
auto user_map = std::unordered_map<program_node*, unsigned int>();
unsigned int processing_order_idx = 0;
for (auto node : processing_order_except_const) {
user_map[node] = processing_order_idx++;
}
unsigned int num_nodes = static_cast<unsigned int>(user_map.size());
// full cross ref [node<->node] bitmap.
@ -118,9 +123,7 @@ void oooq_memory_dependencies::run(program& p) {
for (unsigned int n = 0; n < num_nodes; n++) {
auto& users = user_bitmap[n];
// iterate over all users
for (unsigned int user_id = 0; user_id < num_nodes; user_id++) {
// if we have this user set, then add its sub-users to the map
for (unsigned int user_id = n + 1; user_id < num_nodes; user_id++) {
if (users.is_set(user_id)) {
changed |= users._or(user_bitmap[user_id]);
}
@ -134,13 +137,15 @@ void oooq_memory_dependencies::run(program& p) {
};
unsigned int A = 0;
auto itr_A = processing_order.begin();
auto itr_A = processing_order_except_const.begin();
while (itr_A != processing_order.end()) {
while (itr_A != processing_order_except_const.end()) {
if (suspect_nodes.is_set(A)) {
std::vector<std::pair<program_node*, unsigned int>> deps;
for (const auto& dep : (*itr_A)->get_dependencies()) {
deps.emplace_back(dep, user_map.at(dep));
if (!dep->is_type<data>()) {
deps.emplace_back(dep, user_map.at(dep));
}
}
std::sort(deps.begin(), deps.end(),
@ -161,7 +166,7 @@ void oooq_memory_dependencies::run(program& p) {
}
unsigned int B = ++A;
auto itr_B = ++itr_A;
while (itr_B != processing_order.end()) {
while (itr_B != processing_order_except_const.end()) {
if (!are_connected(A, B)) {
add_memory_dependency(*itr_A, *itr_B);
add_memory_dependency(*itr_B, *itr_A);

View File

@ -27,6 +27,8 @@ void skipped_branch_memory_dependencies::run(program& p) {
while (itrB != processing_order.end()) {
auto& nodeB = *itrB;
auto itrA = ++itrB;
if (nodeB->is_constant())
continue;
if (nodeB->get_users().size() == 0)
continue;

View File

@ -144,7 +144,7 @@ public:
void allocate_internal_buffers();
static memory::ptr allocate_output(engine& engine, memory_pool& pool,
const program_node& _node, bool is_internal);
const program_node& _node, uint32_t net_id, bool is_internal);
std::vector<memory::cptr> get_intermediates_memories() const { return _intermediates_memory; }

View File

@ -263,12 +263,12 @@ void primitive_inst::allocate_internal_buffers(void) {
_intermediates_memory.push_back(engine.allocate_memory(layout, allocation_type::usm_host));
}
}
memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, const program_node& _node,
memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, const program_node& _node, uint32_t net_id,
bool is_internal) {
auto get_memory_from_pool = [&](engine& _engine, const layout& layout, const primitive_id id, std::set<primitive_id> dependencies,
allocation_type type, bool reusable) {
if (_engine.configuration().use_memory_pool)
return pool.get_memory(layout, id, 0, dependencies, type, reusable);
return pool.get_memory(layout, id, net_id, dependencies, type, reusable);
return pool.get_memory(layout, type);
};
@ -332,7 +332,7 @@ memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool,
}
}
memory::ptr primitive_inst::allocate_output() {
return allocate_output(get_network().get_engine(), _network.get_memory_pool(), _node, _network.is_internal());
return allocate_output(get_network().get_engine(), _network.get_memory_pool(), _node, get_network_id(), _network.is_internal());
}
std::vector<std::shared_ptr<primitive_inst>> primitive_inst::build_exec_deps(

View File

@ -579,7 +579,7 @@ void program::post_optimize_graph(bool is_internal) {
}
if (options.get<build_option_type::optimize_data>()->enabled())
apply_opt_pass<remove_redundant_reorders>(lo, false, true, true); // pass to remove output reorders while all others graph optimizations were done
apply_opt_pass<remove_redundant_reorders>(lo, false, true, true); // pass to remove output reorders while all others graph optimizations were done
// update loop input/output primitive mappings
apply_opt_pass<update_loop_primitive_map>();
@ -1544,7 +1544,7 @@ std::pair<int64_t, int64_t> program::get_estimated_device_mem_usage() {
} else if (node->is_type<mutable_data>() && node->get_dependencies().empty()) {
continue;
} else {
allocated_mem_ptrs.insert(primitive_inst::allocate_output(engine, pool, *node, false));
allocated_mem_ptrs.insert(primitive_inst::allocate_output(engine, pool, *node, 0, false));
}
}