fix PR14422 regression for using arg_max_min opt kernel (#14487)

This commit is contained in:
Wilson Seok 2022-12-09 17:37:29 +09:00 committed by GitHub
parent fac45d79e4
commit c16cee4624
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 0 deletions

View File

@ -11,6 +11,7 @@
#include "reorder_inst.h"
#include "resample_inst.h"
#include "reshape_inst.h"
#include "arg_max_min_inst.h"
#include "generic_layer.hpp"
#include <sstream>
@ -1790,6 +1791,10 @@ format layout_optimizer::get_preferred_format(program_node& node) {
else if (input_layout.format.dimension() == 4)
expected = format::bfyx;
}
} else if (node.is_type<arg_max_min>()) {
// Set default format for issue 92967/98750
// TODO: will remove when arg_max_min_ref supports blocked format
expected = format::get_default_format(node.get_input_layouts()[0].get_rank(), false, false);
}
return expected;

View File

@ -1487,6 +1487,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|| (prim.as<mvn>().input().get_output_layout().data_type != data_types::u8 &&
prim.as<mvn>().input().get_output_layout().data_type != data_types::i8)
|| prim.as<mvn>().get_primitive()->across_channels) &&
prim.type() != cldnn::arg_max_min::type_id() &&
prim.type() != cldnn::dft::type_id() &&
prim.type() != cldnn::grid_sample::type_id() &&
prim.type() != cldnn::mutable_data::type_id() &&