Optimize memory depdendency analysis (Constant memory does not use pool : No need to add constant nodes to deps) (#11861)
This commit is contained in:
parent
cae0c924b6
commit
c73201c9e6
@ -204,7 +204,7 @@ private:
|
||||
bool _is_primary_stream;
|
||||
bool _reset_arguments;
|
||||
|
||||
std::map<primitive_id, std::shared_ptr<primitive_inst>> _primitives;
|
||||
std::unordered_map<primitive_id, std::shared_ptr<primitive_inst>> _primitives;
|
||||
std::vector<std::shared_ptr<primitive_inst>> _inputs;
|
||||
std::vector<std::shared_ptr<primitive_inst>> _outputs;
|
||||
std::list<std::shared_ptr<primitive_inst>> _exec_order;
|
||||
|
@ -78,14 +78,19 @@ void oooq_memory_dependencies::run(program& p) {
|
||||
// First create transitive closure of the graph,
|
||||
// giving us mapping of node to set of all users that can be reached from this node.
|
||||
auto& processing_order = p.get_processing_order();
|
||||
|
||||
// maps program nodes to bimap vector ids
|
||||
auto user_map = std::map<program_node*, unsigned int>();
|
||||
unsigned int processing_order_idx = 0;
|
||||
for (auto node : processing_order) {
|
||||
user_map[node] = processing_order_idx++;
|
||||
std::list<program_node*> processing_order_except_const;
|
||||
for (auto n : processing_order) {
|
||||
if (!n->is_type<data>()) {
|
||||
processing_order_except_const.push_back(n);
|
||||
}
|
||||
}
|
||||
|
||||
// maps program nodes to bimap vector ids
|
||||
auto user_map = std::unordered_map<program_node*, unsigned int>();
|
||||
unsigned int processing_order_idx = 0;
|
||||
for (auto node : processing_order_except_const) {
|
||||
user_map[node] = processing_order_idx++;
|
||||
}
|
||||
unsigned int num_nodes = static_cast<unsigned int>(user_map.size());
|
||||
|
||||
// full cross ref [node<->node] bitmap.
|
||||
@ -118,9 +123,7 @@ void oooq_memory_dependencies::run(program& p) {
|
||||
for (unsigned int n = 0; n < num_nodes; n++) {
|
||||
auto& users = user_bitmap[n];
|
||||
|
||||
// iterate over all users
|
||||
for (unsigned int user_id = 0; user_id < num_nodes; user_id++) {
|
||||
// if we have this user set, then add its sub-users to the map
|
||||
for (unsigned int user_id = n + 1; user_id < num_nodes; user_id++) {
|
||||
if (users.is_set(user_id)) {
|
||||
changed |= users._or(user_bitmap[user_id]);
|
||||
}
|
||||
@ -134,13 +137,15 @@ void oooq_memory_dependencies::run(program& p) {
|
||||
};
|
||||
|
||||
unsigned int A = 0;
|
||||
auto itr_A = processing_order.begin();
|
||||
auto itr_A = processing_order_except_const.begin();
|
||||
|
||||
while (itr_A != processing_order.end()) {
|
||||
while (itr_A != processing_order_except_const.end()) {
|
||||
if (suspect_nodes.is_set(A)) {
|
||||
std::vector<std::pair<program_node*, unsigned int>> deps;
|
||||
for (const auto& dep : (*itr_A)->get_dependencies()) {
|
||||
deps.emplace_back(dep, user_map.at(dep));
|
||||
if (!dep->is_type<data>()) {
|
||||
deps.emplace_back(dep, user_map.at(dep));
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(deps.begin(), deps.end(),
|
||||
@ -161,7 +166,7 @@ void oooq_memory_dependencies::run(program& p) {
|
||||
}
|
||||
unsigned int B = ++A;
|
||||
auto itr_B = ++itr_A;
|
||||
while (itr_B != processing_order.end()) {
|
||||
while (itr_B != processing_order_except_const.end()) {
|
||||
if (!are_connected(A, B)) {
|
||||
add_memory_dependency(*itr_A, *itr_B);
|
||||
add_memory_dependency(*itr_B, *itr_A);
|
||||
|
@ -27,6 +27,8 @@ void skipped_branch_memory_dependencies::run(program& p) {
|
||||
while (itrB != processing_order.end()) {
|
||||
auto& nodeB = *itrB;
|
||||
auto itrA = ++itrB;
|
||||
if (nodeB->is_constant())
|
||||
continue;
|
||||
if (nodeB->get_users().size() == 0)
|
||||
continue;
|
||||
|
||||
|
@ -144,7 +144,7 @@ public:
|
||||
|
||||
void allocate_internal_buffers();
|
||||
static memory::ptr allocate_output(engine& engine, memory_pool& pool,
|
||||
const program_node& _node, bool is_internal);
|
||||
const program_node& _node, uint32_t net_id, bool is_internal);
|
||||
|
||||
std::vector<memory::cptr> get_intermediates_memories() const { return _intermediates_memory; }
|
||||
|
||||
|
@ -263,12 +263,12 @@ void primitive_inst::allocate_internal_buffers(void) {
|
||||
_intermediates_memory.push_back(engine.allocate_memory(layout, allocation_type::usm_host));
|
||||
}
|
||||
}
|
||||
memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, const program_node& _node,
|
||||
memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, const program_node& _node, uint32_t net_id,
|
||||
bool is_internal) {
|
||||
auto get_memory_from_pool = [&](engine& _engine, const layout& layout, const primitive_id id, std::set<primitive_id> dependencies,
|
||||
allocation_type type, bool reusable) {
|
||||
if (_engine.configuration().use_memory_pool)
|
||||
return pool.get_memory(layout, id, 0, dependencies, type, reusable);
|
||||
return pool.get_memory(layout, id, net_id, dependencies, type, reusable);
|
||||
return pool.get_memory(layout, type);
|
||||
};
|
||||
|
||||
@ -332,7 +332,7 @@ memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool,
|
||||
}
|
||||
}
|
||||
memory::ptr primitive_inst::allocate_output() {
|
||||
return allocate_output(get_network().get_engine(), _network.get_memory_pool(), _node, _network.is_internal());
|
||||
return allocate_output(get_network().get_engine(), _network.get_memory_pool(), _node, get_network_id(), _network.is_internal());
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<primitive_inst>> primitive_inst::build_exec_deps(
|
||||
|
@ -579,7 +579,7 @@ void program::post_optimize_graph(bool is_internal) {
|
||||
}
|
||||
|
||||
if (options.get<build_option_type::optimize_data>()->enabled())
|
||||
apply_opt_pass<remove_redundant_reorders>(lo, false, true, true); // pass to remove output reorders while all others graph optimizations were done
|
||||
apply_opt_pass<remove_redundant_reorders>(lo, false, true, true); // pass to remove output reorders while all others graph optimizations were done
|
||||
|
||||
// update loop input/output primitive mappings
|
||||
apply_opt_pass<update_loop_primitive_map>();
|
||||
@ -1544,7 +1544,7 @@ std::pair<int64_t, int64_t> program::get_estimated_device_mem_usage() {
|
||||
} else if (node->is_type<mutable_data>() && node->get_dependencies().empty()) {
|
||||
continue;
|
||||
} else {
|
||||
allocated_mem_ptrs.insert(primitive_inst::allocate_output(engine, pool, *node, false));
|
||||
allocated_mem_ptrs.insert(primitive_inst::allocate_output(engine, pool, *node, 0, false));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user