[GPU] Add additional rule for oneDNN eltwise memory reuse (#7924)

This commit is contained in:
Sergey Shlyapnikov 2021-10-15 10:42:00 +03:00 committed by GitHub
parent a4cc31c0b9
commit fa38103e5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -483,7 +483,8 @@ void network::allocate_primitives() {
for (auto& fused_op : node->get_fused_primitives()) { for (auto& fused_op : node->get_fused_primitives()) {
if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) { if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
auto eltw_in_layout = node->get_dependency(fused_op.dep_start_idx).get_output_layout(); auto& eltw_in = node->get_dependency(fused_op.dep_start_idx);
auto eltw_in_layout = eltw_in.get_output_layout();
auto out_layout = node->get_output_layout(); auto out_layout = node->get_output_layout();
if (eltw_in_layout.size == out_layout.size && if (eltw_in_layout.size == out_layout.size &&
@ -497,6 +498,17 @@ void network::allocate_primitives() {
can_reuse_eltwise_mem = true; can_reuse_eltwise_mem = true;
} }
if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node->id()) != _primitives.end()) {
auto& eltw_inst = _primitives.at(eltw_in.id());
auto& prim_inst = _primitives.at(node->id());
auto eltw_mem_type = eltw_inst->output_memory().get_allocation_type();
auto prim_mem_type = prim_inst->output_memory().get_allocation_type();
// Keep lockable memory type for `prim_inst` output if needed
if (eltw_mem_type != prim_mem_type && eltw_mem_type != allocation_type::cl_mem && eltw_mem_type != allocation_type::usm_host)
can_reuse_eltwise_mem = false;
}
if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout) && !can_reuse_eltwise_mem) { if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout) && !can_reuse_eltwise_mem) {
throw std::runtime_error("Buffer reuse is required for onednn sum post operation."); throw std::runtime_error("Buffer reuse is required for onednn sum post operation.");
} }