diff --git a/inference-engine/thirdparty/clDNN/src/network.cpp b/inference-engine/thirdparty/clDNN/src/network.cpp
index abb6a7bfec3..d7cf8870155 100644
--- a/inference-engine/thirdparty/clDNN/src/network.cpp
+++ b/inference-engine/thirdparty/clDNN/src/network.cpp
@@ -483,7 +483,8 @@ void network::allocate_primitives() {
 
             for (auto& fused_op : node->get_fused_primitives()) {
                 if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
-                    auto eltw_in_layout = node->get_dependency(fused_op.dep_start_idx).get_output_layout();
+                    auto& eltw_in = node->get_dependency(fused_op.dep_start_idx);
+                    auto eltw_in_layout = eltw_in.get_output_layout();
                     auto out_layout = node->get_output_layout();
 
                     if (eltw_in_layout.size == out_layout.size &&
@@ -497,6 +498,17 @@ void network::allocate_primitives() {
                         can_reuse_eltwise_mem = true;
                     }
 
+                    if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node->id()) != _primitives.end()) {
+                        auto& eltw_inst = _primitives.at(eltw_in.id());
+                        auto& prim_inst = _primitives.at(node->id());
+                        auto eltw_mem_type = eltw_inst->output_memory().get_allocation_type();
+                        auto prim_mem_type = prim_inst->output_memory().get_allocation_type();
+
+                        // Keep lockable memory type for `prim_inst` output if needed
+                        if (eltw_mem_type != prim_mem_type && eltw_mem_type != allocation_type::cl_mem && eltw_mem_type != allocation_type::usm_host)
+                            can_reuse_eltwise_mem = false;
+                    }
+
                     if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout) && !can_reuse_eltwise_mem) {
                         throw std::runtime_error("Buffer reuse is required for onednn sum post operation.");
                     }
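
For context, here is a minimal standalone sketch of the reuse gate the new block adds: reuse of the fused-eltwise input buffer is refused when that buffer's allocation type differs from the primitive's output and is not one of the host-lockable types (cl_mem, usm_host), so the output keeps a lockable allocation. The `alloc_type` enum and `can_reuse_eltwise_buffer` helper below are hypothetical stand-ins for illustration, not the clDNN API.

```cpp
// Sketch only: `alloc_type` and `can_reuse_eltwise_buffer` are hypothetical names,
// assuming cl_mem and usm_host are the host-lockable (mappable) allocation types.
#include <iostream>

enum class alloc_type { cl_mem, usm_host, usm_device };

// Allow reusing the fused-eltwise input buffer as the primitive's output only if
// doing so would not replace a lockable output allocation with a non-lockable one.
bool can_reuse_eltwise_buffer(alloc_type eltw_mem_type, alloc_type prim_mem_type) {
    const bool eltw_is_lockable =
        eltw_mem_type == alloc_type::cl_mem || eltw_mem_type == alloc_type::usm_host;
    if (eltw_mem_type != prim_mem_type && !eltw_is_lockable)
        return false;  // keep the primitive's own (possibly lockable) output allocation
    return true;
}

int main() {
    std::cout << std::boolalpha
              // device-only eltwise input vs usm_host output: reuse is refused
              << can_reuse_eltwise_buffer(alloc_type::usm_device, alloc_type::usm_host) << '\n'
              // lockable cl_mem input: reuse stays allowed
              << can_reuse_eltwise_buffer(alloc_type::cl_mem, alloc_type::usm_host) << '\n';
}
```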