From cb15d281f1f59b45ae29a7f8375410b35dd540c0 Mon Sep 17 00:00:00 2001 From: Eddy Kim Date: Wed, 30 Nov 2022 04:13:00 +0900 Subject: [PATCH] [GPU] Serialize GPU memory reuse flag (#14269) * serialize gpu memory reuse flag * added c-tor of reduce_impl * renamed memory_reuse_by_user as can_reuse_memory, and modified memory allocation logic to be simpler --- .../intel_gpu/src/graph/impls/ocl/reduce.cpp | 8 ++++ .../intel_gpu/src/graph/primitive_inst.cpp | 41 +++++++++++-------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/reduce.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/reduce.cpp index 78d611a9b1e..894fd5013b3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/reduce.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/reduce.cpp @@ -78,6 +78,14 @@ struct reduce_impl : typed_primitive_impl_ocl { return make_unique(*this); } + reduce_impl() : parent() {} + + explicit reduce_impl(const reduce_impl& other) : parent(other) {} + + reduce_impl(const reduce_node& arg, const kernel_selector::kernel_data& kd) : parent(arg, kd) { + this->can_reuse_memory = kd.can_reuse_memory; + } + static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) { const auto& primitive = impl_param.typed_desc(); auto params = get_default_params(impl_param); diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 434be418849..d6c3d566ee7 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -795,6 +795,21 @@ event::ptr primitive_inst::update_weights() { return nullptr; } +static bool user_requesting_mem_reuse_false(const program_node& node) { + for (auto& user : node.get_users()) { + if (user->is_dynamic()) + return true; + if ((user->get_selected_impl() != nullptr) && (user->get_selected_impl()->can_reuse_memory == false)) { + return true; + } else if (user->get_selected_impl() == nullptr) { + if (user_requesting_mem_reuse_false(*user)) { + return true; + } + } + } + return false; +} + memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, const program_node& _node, const kernel_impl_params& impl_params, uint32_t net_id, bool is_internal, size_t idx) { auto get_memory_from_pool = [&](engine& _engine, const layout& layout, const primitive_id id, std::set dependencies, @@ -822,21 +837,6 @@ memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, bool memory_reuse_by_user = true; - std::function user_requesting_mem_reuse_false = [&user_requesting_mem_reuse_false](const program_node& node) { - for (auto& user : node.get_users()) { - if (user->is_dynamic()) - return true; - if ((user->get_selected_impl() != nullptr) && (user->get_selected_impl()->can_reuse_memory == false)) { - return true; - } else if (user->get_selected_impl() == nullptr) { - if (user_requesting_mem_reuse_false(*user)) { - return true; - } - } - } - return false; - }; - if (user_requesting_mem_reuse_false(_node)) { memory_reuse_by_user = false; } @@ -1147,6 +1147,12 @@ void primitive_inst::save(cldnn::BinaryOutputBuffer& ob) const { const auto _allocation_type = _outputs[0]->get_allocation_type(); ob << make_data(&_allocation_type, sizeof(_allocation_type)); + bool can_reuse_memory = true; + if (user_requesting_mem_reuse_false(*_node)) { + can_reuse_memory = false; + } + ob << can_reuse_memory; + ob << _node->get_memory_dependencies(); ob << _deps.size(); @@ -1251,6 +1257,9 @@ void primitive_inst::load(cldnn::BinaryInputBuffer& ib) { allocation_type _allocation_type; ib >> make_data(&_allocation_type, sizeof(_allocation_type)); + bool can_reuse_memory; + ib >> can_reuse_memory; + std::set _node_mem_deps; ib >> _node_mem_deps; @@ -1280,7 +1289,7 @@ void primitive_inst::load(cldnn::BinaryInputBuffer& ib) { if ((!can_share_buffer()) || can_be_optimized() || is_output()) { _outputs[0] = get_network().get_engine().allocate_memory(output_layout, _allocation_type); } else { - _outputs[0] = get_network().get_memory_pool().get_memory(output_layout, id(), get_network_id(), _node_mem_deps, _allocation_type, true); + _outputs[0] = get_network().get_memory_pool().get_memory(output_layout, id(), get_network_id(), _node_mem_deps, _allocation_type, can_reuse_memory); } } _output_changed = false;