[GPU] Onednn integration for handling reorders (#7764)

Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
Min, Byungil 2021-10-13 15:06:53 +09:00 committed by GitHub
parent e72200dbe1
commit d23ec24fd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 10 deletions

View File

@ -162,16 +162,24 @@ void add_required_reorders::run(program& p) {
}; };
} }
if (usr->get_preferred_impl_type() == impl_types::onednn) {
usr->set_preferred_impl_type(impl_types::ocl);
usr->set_output_layout(original_layout, false);
if (usr->type()->does_possible_implementation_exist(*usr)) {
correct_layout_selected = true;
}
}
if (!correct_layout_selected) {
for (auto new_layout_format : preffered_layout_formats) { for (auto new_layout_format : preffered_layout_formats) {
layout current_layout(original_layout.data_type, layout current_layout(original_layout.data_type, new_layout_format, original_layout.size);
new_layout_format,
original_layout.size);
usr->set_output_layout(current_layout, false); usr->set_output_layout(current_layout, false);
if (usr->type()->does_possible_implementation_exist(*usr)) { if (usr->type()->does_possible_implementation_exist(*usr)) {
correct_layout_selected = true; correct_layout_selected = true;
break; break;
} }
} }
}
if (!correct_layout_selected) { if (!correct_layout_selected) {
// goal of this section is to use int32 implementation // goal of this section is to use int32 implementation

View File

@ -54,10 +54,15 @@ void remove_redundant_reorders::run(program& p) {
if (!node.get_fused_activations_funcs().empty()) if (!node.get_fused_activations_funcs().empty())
continue; continue;
// Avoid different data types between input and output
auto same_data_type = input.get_output_layout().data_type == output_layout.data_type; auto same_data_type = input.get_output_layout().data_type == output_layout.data_type;
if (!same_data_type) if (!same_data_type)
continue; continue;
// Avoid optimization of nv12 reorder
if (node.get_dependencies().size() != 1)
continue;
bool all_users_fuse = true; bool all_users_fuse = true;
std::vector<program_node*> recalc_list; std::vector<program_node*> recalc_list;
@ -334,7 +339,7 @@ void remove_redundant_reorders::run(program& p) {
p.remove_if_dangling(node); p.remove_if_dangling(node);
} }
// Remove reorder for Convolution bfyx -> fs_b_yx_fsv32 // Remove reorder for Convolution bfyx -> fs_b_yx_fsv32 (+ onednn: bfyx -> b_fs_yx_fsv32)
auto try_fuse_reorder_bfyx_to_fsv32 = [&](reorder_node* node) { auto try_fuse_reorder_bfyx_to_fsv32 = [&](reorder_node* node) {
if (node->get_users().size() != 1) if (node->get_users().size() != 1)
return; return;
@ -342,9 +347,14 @@ void remove_redundant_reorders::run(program& p) {
auto& usr = node->get_users().front(); auto& usr = node->get_users().front();
auto& dep = node->get_dependency(0); auto& dep = node->get_dependency(0);
if (!(usr->is_type<convolution>()) || if (!(usr->is_type<convolution>()) ||
(usr->get_output_layout().data_type != dep.get_output_layout().data_type) || usr->get_output_layout().data_type != dep.get_output_layout().data_type ||
(dep.get_output_layout().format != format::bfyx) || dep.get_output_layout().format != format::bfyx)
(usr->get_output_layout().format != format::fs_b_yx_fsv32)) return;
if (usr->as<convolution>().get_preferred_impl_type() == impl_types::ocl &&
usr->get_output_layout().format != format::fs_b_yx_fsv32)
return;
if (usr->as<convolution>().get_preferred_impl_type() == impl_types::onednn &&
usr->get_output_layout().format != format::b_fs_yx_fsv32)
return; return;
if (dep.is_type<input_layout>()) if (dep.is_type<input_layout>())
@ -377,6 +387,10 @@ void remove_redundant_reorders::run(program& p) {
if (input.as<convolution>().get_primitive()->groups != 1) if (input.as<convolution>().get_primitive()->groups != 1)
return; return;
// Avoid the onednn convolution selecting a ref kernel for fsv16 -> bfyx
if (input.as<convolution>().get_preferred_impl_type() == impl_types::onednn)
return;
if (input.get_users().size() != 1) if (input.get_users().size() != 1)
return; return;
@ -475,4 +489,11 @@ void remove_redundant_reorders::run(program& p) {
p.extract_and_remove(reshape_node); p.extract_and_remove(reshape_node);
} }
} }
for (auto n : p.get_processing_order()) {
if (n->is_in_data_flow() && n->is_type<reorder>()) {
auto preferred_impl = lo.get_preferred_impl_type(*n, n->get_dependency(0).get_output_layout().format);
n->set_preferred_impl_type(preferred_impl);
}
}
} }