[GPU] Modify perferred format of Reduce (#13977)

+ Do not select a planar format if keep_dims is false and blocked axis is reduced.

Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
Min, Byungil 2022-11-18 17:17:57 +09:00 committed by GitHub
parent 711feac6d8
commit bc90ed740f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -97,8 +97,7 @@ static bool is_reduce_blocked_axes(reduce_node const& node) {
auto num_spatial = format::spatial_num(node.get_output_layout().format);
auto dims = node.get_output_layout().format.dimension();
if (prim->keep_dims == false &&
(count(reduce_axes.begin(), reduce_axes.end(), 1) > 0 ||
if ((count(reduce_axes.begin(), reduce_axes.end(), 1) > 0 ||
(count(reduce_axes.begin(), reduce_axes.end(), 0) > 0 && input_layout.batch() > 1))) {
for (size_t idx_spatial = dims - num_spatial ; idx_spatial < dims ; idx_spatial++) {
if (count(reduce_axes.begin(), reduce_axes.end(), idx_spatial) == 0)
@ -823,7 +822,7 @@ static bool is_node_for_onednn(reduce_node const& node, format preferred_format)
// Onednn reduction does NOT support reordering of unreduced-axes.
// Currently, an Onednn reduce layer which contains reduction of blocked axes(b-f) is expected to select planar format.
if (is_reduce_blocked_axes(node))
if (reduce_prim->keep_dims == false && is_reduce_blocked_axes(node))
return false;
return true;
@ -1751,19 +1750,6 @@ format layout_optimizer::get_preferred_format(program_node& node) {
expected = format::b_fs_yx_fsv32;
}
}
} else if (node.is_type<reduce>()) {
auto& reduce_node = node.as<reduce>();
// if blocked axes are reduced, it will have huge memory overhead. A clDNN reduce reorders un-reduced axes to b-f and w-x axis for this.
// But oneDNN does not allow this. So planar format is used for this case.
if (is_reduce_blocked_axes(reduce_node) && use_onednn_impls) {
auto input_layout = reduce_node.input().get_output_layout();
if (input_layout.format.dimension() == 6)
expected = format::bfwzyx;
else if (input_layout.format.dimension() == 5)
expected = format::bfzyx;
else if (input_layout.format.dimension() == 4)
expected = format::bfyx;
}
}
return expected;