[GPU] Fix a bug of reorder optimized-out (#13180)

This commit is contained in:
Jade Cho 2022-10-05 18:04:27 +09:00 committed by GitHub
parent b73d3370d8
commit f7e05ad402
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -607,32 +607,17 @@ void network::allocate_primitives() {
return (lhs->get_output_layout().bytes_count() > rhs->get_output_layout().bytes_count()); return (lhs->get_output_layout().bytes_count() > rhs->get_output_layout().bytes_count());
}); });
// Move layers that can be optimized to the back of nodes_to_allocate
std::stable_partition(nodes_to_allocate.begin(),
nodes_to_allocate.end(),
[&po](std::shared_ptr<program_node> const& e) {
return !e->can_be_optimized();
});
for (auto const& node : nodes_to_allocate) { for (auto const& node : nodes_to_allocate) {
allocate_primitive_instance(*node); allocate_primitive_instance(*node);
} }
for (auto const& node : po) {
if (node->get_preferred_impl_type() == impl_types::onednn) {
size_t eltw_dep = 0;
for (auto& fused_op : node->get_fused_primitives()) {
if (fused_op.is_type<eltwise>() && fused_op.deps.size() == 1) {
// If it is first sum, reuse the buffer
auto fusing_type = onednn_add_fusing_helpers::get_add_fusing_type(*node, fused_op);
if (fusing_type != add_fusing_type::sum || eltw_dep != 0)
continue;
eltw_dep = fused_op.dep_start_idx;
auto& eltw_in = node->get_dependency(eltw_dep);
if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node->id()) != _primitives.end()) {
auto& eltw_inst = _primitives.at(eltw_in.id());
auto& prim_inst = _primitives.at(node->id());
auto& eltw_mem = eltw_inst->output_memory();
auto new_mem = eltw_mem.get_engine()->reinterpret_buffer(eltw_mem, node->get_output_layout());
prim_inst->set_output_memory(new_mem);
}
}
}
}
}
// allocate intermediate buffers // allocate intermediate buffers
for (auto const& node : po) { for (auto const& node : po) {
auto prim = _primitives[node->id()]; auto prim = _primitives[node->id()];
@ -965,6 +950,27 @@ void network::allocate_primitive_instance(program_node const& node) {
_variable_state_primitives.push_back(inst); _variable_state_primitives.push_back(inst);
if (node.is_constant()) if (node.is_constant())
transfer_memory_to_device(inst, node); transfer_memory_to_device(inst, node);
if (node.get_preferred_impl_type() == impl_types::onednn) {
size_t eltw_dep = 0;
for (auto& fused_op : node.get_fused_primitives()) {
if (fused_op.is_type<eltwise>() && fused_op.deps.size() == 1) {
// If it is first sum, reuse the buffer
auto fusing_type = onednn_add_fusing_helpers::get_add_fusing_type(node, fused_op);
if (fusing_type != add_fusing_type::sum || eltw_dep != 0)
continue;
eltw_dep = fused_op.dep_start_idx;
auto& eltw_in = node.get_dependency(eltw_dep);
if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node.id()) != _primitives.end()) {
auto& eltw_inst = _primitives.at(eltw_in.id());
auto& prim_inst = _primitives.at(node.id());
auto& eltw_mem = eltw_inst->output_memory();
auto new_mem = eltw_mem.get_engine()->reinterpret_buffer(eltw_mem, node.get_output_layout());
prim_inst->set_output_memory(new_mem);
}
}
}
}
} }
void network::transfer_memory_to_device(std::shared_ptr<primitive_inst> instance, program_node const& node) { void network::transfer_memory_to_device(std::shared_ptr<primitive_inst> instance, program_node const& node) {