[GPU] Add additional rule for oneDNN eltwise memory reuse (#7924)
This commit is contained in:
parent
a4cc31c0b9
commit
fa38103e5b
@ -483,7 +483,8 @@ void network::allocate_primitives() {
|
||||
|
||||
for (auto& fused_op : node->get_fused_primitives()) {
|
||||
if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
|
||||
auto eltw_in_layout = node->get_dependency(fused_op.dep_start_idx).get_output_layout();
|
||||
auto& eltw_in = node->get_dependency(fused_op.dep_start_idx);
|
||||
auto eltw_in_layout = eltw_in.get_output_layout();
|
||||
auto out_layout = node->get_output_layout();
|
||||
|
||||
if (eltw_in_layout.size == out_layout.size &&
|
||||
@ -497,6 +498,17 @@ void network::allocate_primitives() {
|
||||
can_reuse_eltwise_mem = true;
|
||||
}
|
||||
|
||||
if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node->id()) != _primitives.end()) {
|
||||
auto& eltw_inst = _primitives.at(eltw_in.id());
|
||||
auto& prim_inst = _primitives.at(node->id());
|
||||
auto eltw_mem_type = eltw_inst->output_memory().get_allocation_type();
|
||||
auto prim_mem_type = prim_inst->output_memory().get_allocation_type();
|
||||
|
||||
// Keep lockable memory type for `prim_inst` output if needed
|
||||
if (eltw_mem_type != prim_mem_type && eltw_mem_type != allocation_type::cl_mem && eltw_mem_type != allocation_type::usm_host)
|
||||
can_reuse_eltwise_mem = false;
|
||||
}
|
||||
|
||||
if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout) && !can_reuse_eltwise_mem) {
|
||||
throw std::runtime_error("Buffer reuse is required for onednn sum post operation.");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user