[GPU] Serialize GPU memory reuse flag (#14269)

* serialize gpu memory reuse flag

* added c-tor of reduce_impl

* renamed memory_reuse_by_user to can_reuse_memory, and simplified the memory allocation logic
This commit is contained in:
Eddy Kim 2022-11-30 04:13:00 +09:00 committed by GitHub
parent 115a9071e4
commit cb15d281f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 16 deletions

View File

@ -78,6 +78,14 @@ struct reduce_impl : typed_primitive_impl_ocl<reduce> {
return make_unique<reduce_impl>(*this);
}
// Default constructor: only default-constructs the parent implementation.
reduce_impl() : parent() {}
// Copy constructor: forwards `other` to the parent's copy constructor.
// Declared explicit, so copies must be spelled out as direct initialization.
explicit reduce_impl(const reduce_impl& other) : parent(other) {}
// Builds the implementation from a reduce node and the kernel data chosen for it.
// `arg` and `kd` are forwarded to the parent constructor unchanged.
reduce_impl(const reduce_node& arg, const kernel_selector::kernel_data& kd) : parent(arg, kd) {
// Carry the kernel's memory-reuse flag over to this implementation object.
this->can_reuse_memory = kd.can_reuse_memory;
}
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) {
const auto& primitive = impl_param.typed_desc<reduce>();
auto params = get_default_params<kernel_selector::reduce_params>(impl_param);

View File

@ -795,6 +795,21 @@ event::ptr primitive_inst::update_weights() {
return nullptr;
}
// Walks the users of `node` and reports whether any of them rules out memory reuse.
//
// Returns true as soon as one user
//   - is dynamic, or
//   - has a selected implementation whose `can_reuse_memory` flag is false, or
//   - has no selected implementation and one of its own users (checked
//     recursively) meets any of these conditions.
// Returns false when no user matches, including when `node` has no users.
static bool user_requesting_mem_reuse_false(const program_node& node) {
    for (const auto& consumer : node.get_users()) {
        if (consumer->is_dynamic()) {
            return true;
        }

        if (consumer->get_selected_impl() == nullptr) {
            // No implementation selected for this user: look through it at its own users.
            if (user_requesting_mem_reuse_false(*consumer)) {
                return true;
            }
            continue;
        }

        // An implementation is selected: its flag decides.
        if (!consumer->get_selected_impl()->can_reuse_memory) {
            return true;
        }
    }
    return false;
}
memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, const program_node& _node, const kernel_impl_params& impl_params,
uint32_t net_id, bool is_internal, size_t idx) {
auto get_memory_from_pool = [&](engine& _engine, const layout& layout, const primitive_id id, std::set<primitive_id> dependencies,
@ -822,21 +837,6 @@ memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool,
bool memory_reuse_by_user = true;
std::function<bool(const program_node&)> user_requesting_mem_reuse_false = [&user_requesting_mem_reuse_false](const program_node& node) {
for (auto& user : node.get_users()) {
if (user->is_dynamic())
return true;
if ((user->get_selected_impl() != nullptr) && (user->get_selected_impl()->can_reuse_memory == false)) {
return true;
} else if (user->get_selected_impl() == nullptr) {
if (user_requesting_mem_reuse_false(*user)) {
return true;
}
}
}
return false;
};
if (user_requesting_mem_reuse_false(_node)) {
memory_reuse_by_user = false;
}
@ -1147,6 +1147,12 @@ void primitive_inst::save(cldnn::BinaryOutputBuffer& ob) const {
const auto _allocation_type = _outputs[0]->get_allocation_type();
ob << make_data(&_allocation_type, sizeof(_allocation_type));
bool can_reuse_memory = true;
if (user_requesting_mem_reuse_false(*_node)) {
can_reuse_memory = false;
}
ob << can_reuse_memory;
ob << _node->get_memory_dependencies();
ob << _deps.size();
@ -1251,6 +1257,9 @@ void primitive_inst::load(cldnn::BinaryInputBuffer& ib) {
allocation_type _allocation_type;
ib >> make_data(&_allocation_type, sizeof(_allocation_type));
bool can_reuse_memory;
ib >> can_reuse_memory;
std::set<primitive_id> _node_mem_deps;
ib >> _node_mem_deps;
@ -1280,7 +1289,7 @@ void primitive_inst::load(cldnn::BinaryInputBuffer& ib) {
if ((!can_share_buffer()) || can_be_optimized() || is_output()) {
_outputs[0] = get_network().get_engine().allocate_memory(output_layout, _allocation_type);
} else {
_outputs[0] = get_network().get_memory_pool().get_memory(output_layout, id(), get_network_id(), _node_mem_deps, _allocation_type, true);
_outputs[0] = get_network().get_memory_pool().get_memory(output_layout, id(), get_network_id(), _node_mem_deps, _allocation_type, can_reuse_memory);
}
}
_output_changed = false;