[GPU] Modify perferred format of Reduce (#13977)
+ Do not select a planar format if keep_dims is false and blocked axis is reduced. Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
parent
711feac6d8
commit
bc90ed740f
@ -97,8 +97,7 @@ static bool is_reduce_blocked_axes(reduce_node const& node) {
|
||||
auto num_spatial = format::spatial_num(node.get_output_layout().format);
|
||||
auto dims = node.get_output_layout().format.dimension();
|
||||
|
||||
if (prim->keep_dims == false &&
|
||||
(count(reduce_axes.begin(), reduce_axes.end(), 1) > 0 ||
|
||||
if ((count(reduce_axes.begin(), reduce_axes.end(), 1) > 0 ||
|
||||
(count(reduce_axes.begin(), reduce_axes.end(), 0) > 0 && input_layout.batch() > 1))) {
|
||||
for (size_t idx_spatial = dims - num_spatial ; idx_spatial < dims ; idx_spatial++) {
|
||||
if (count(reduce_axes.begin(), reduce_axes.end(), idx_spatial) == 0)
|
||||
@ -823,7 +822,7 @@ static bool is_node_for_onednn(reduce_node const& node, format preferred_format)
|
||||
|
||||
// Onednn reduction does NOT support reordering of unreduced-axes.
|
||||
// Currently, an Onednn reduce layer which contains reduction of blocked axes(b-f) is expected to select planar format.
|
||||
if (is_reduce_blocked_axes(node))
|
||||
if (reduce_prim->keep_dims == false && is_reduce_blocked_axes(node))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
@ -1751,19 +1750,6 @@ format layout_optimizer::get_preferred_format(program_node& node) {
|
||||
expected = format::b_fs_yx_fsv32;
|
||||
}
|
||||
}
|
||||
} else if (node.is_type<reduce>()) {
|
||||
auto& reduce_node = node.as<reduce>();
|
||||
// if blocked axes are reduced, it will have huge memory overhead. A clDNN reduce reorders un-reduced axes to b-f and w-x axis for this.
|
||||
// But oneDNN does not allow this. So planar format is used for this case.
|
||||
if (is_reduce_blocked_axes(reduce_node) && use_onednn_impls) {
|
||||
auto input_layout = reduce_node.input().get_output_layout();
|
||||
if (input_layout.format.dimension() == 6)
|
||||
expected = format::bfwzyx;
|
||||
else if (input_layout.format.dimension() == 5)
|
||||
expected = format::bfzyx;
|
||||
else if (input_layout.format.dimension() == 4)
|
||||
expected = format::bfyx;
|
||||
}
|
||||
}
|
||||
|
||||
return expected;
|
||||
|
Loading…
Reference in New Issue
Block a user