[GPU] Fix a bug of reorder optimized-out (#13180)
This commit is contained in:
parent
b73d3370d8
commit
f7e05ad402
@ -607,32 +607,17 @@ void network::allocate_primitives() {
|
|||||||
return (lhs->get_output_layout().bytes_count() > rhs->get_output_layout().bytes_count());
|
return (lhs->get_output_layout().bytes_count() > rhs->get_output_layout().bytes_count());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Move layers that can be optimized to the back of nodes_to_allocate
|
||||||
|
std::stable_partition(nodes_to_allocate.begin(),
|
||||||
|
nodes_to_allocate.end(),
|
||||||
|
[&po](std::shared_ptr<program_node> const& e) {
|
||||||
|
return !e->can_be_optimized();
|
||||||
|
});
|
||||||
|
|
||||||
for (auto const& node : nodes_to_allocate) {
|
for (auto const& node : nodes_to_allocate) {
|
||||||
allocate_primitive_instance(*node);
|
allocate_primitive_instance(*node);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto const& node : po) {
|
|
||||||
if (node->get_preferred_impl_type() == impl_types::onednn) {
|
|
||||||
size_t eltw_dep = 0;
|
|
||||||
for (auto& fused_op : node->get_fused_primitives()) {
|
|
||||||
if (fused_op.is_type<eltwise>() && fused_op.deps.size() == 1) {
|
|
||||||
// If it is first sum, reuse the buffer
|
|
||||||
auto fusing_type = onednn_add_fusing_helpers::get_add_fusing_type(*node, fused_op);
|
|
||||||
if (fusing_type != add_fusing_type::sum || eltw_dep != 0)
|
|
||||||
continue;
|
|
||||||
eltw_dep = fused_op.dep_start_idx;
|
|
||||||
auto& eltw_in = node->get_dependency(eltw_dep);
|
|
||||||
if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node->id()) != _primitives.end()) {
|
|
||||||
auto& eltw_inst = _primitives.at(eltw_in.id());
|
|
||||||
auto& prim_inst = _primitives.at(node->id());
|
|
||||||
auto& eltw_mem = eltw_inst->output_memory();
|
|
||||||
auto new_mem = eltw_mem.get_engine()->reinterpret_buffer(eltw_mem, node->get_output_layout());
|
|
||||||
prim_inst->set_output_memory(new_mem);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// allocate intermediate buffers
|
// allocate intermediate buffers
|
||||||
for (auto const& node : po) {
|
for (auto const& node : po) {
|
||||||
auto prim = _primitives[node->id()];
|
auto prim = _primitives[node->id()];
|
||||||
@ -965,6 +950,27 @@ void network::allocate_primitive_instance(program_node const& node) {
|
|||||||
_variable_state_primitives.push_back(inst);
|
_variable_state_primitives.push_back(inst);
|
||||||
if (node.is_constant())
|
if (node.is_constant())
|
||||||
transfer_memory_to_device(inst, node);
|
transfer_memory_to_device(inst, node);
|
||||||
|
|
||||||
|
if (node.get_preferred_impl_type() == impl_types::onednn) {
|
||||||
|
size_t eltw_dep = 0;
|
||||||
|
for (auto& fused_op : node.get_fused_primitives()) {
|
||||||
|
if (fused_op.is_type<eltwise>() && fused_op.deps.size() == 1) {
|
||||||
|
// If it is first sum, reuse the buffer
|
||||||
|
auto fusing_type = onednn_add_fusing_helpers::get_add_fusing_type(node, fused_op);
|
||||||
|
if (fusing_type != add_fusing_type::sum || eltw_dep != 0)
|
||||||
|
continue;
|
||||||
|
eltw_dep = fused_op.dep_start_idx;
|
||||||
|
auto& eltw_in = node.get_dependency(eltw_dep);
|
||||||
|
if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node.id()) != _primitives.end()) {
|
||||||
|
auto& eltw_inst = _primitives.at(eltw_in.id());
|
||||||
|
auto& prim_inst = _primitives.at(node.id());
|
||||||
|
auto& eltw_mem = eltw_inst->output_memory();
|
||||||
|
auto new_mem = eltw_mem.get_engine()->reinterpret_buffer(eltw_mem, node.get_output_layout());
|
||||||
|
prim_inst->set_output_memory(new_mem);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void network::transfer_memory_to_device(std::shared_ptr<primitive_inst> instance, program_node const& node) {
|
void network::transfer_memory_to_device(std::shared_ptr<primitive_inst> instance, program_node const& node) {
|
||||||
|
Loading…
Reference in New Issue
Block a user