fix PR14422 regression for using arg_max_min opt kernel (#14487)
This commit is contained in:
parent
fac45d79e4
commit
c16cee4624
@ -11,6 +11,7 @@
|
||||
#include "reorder_inst.h"
|
||||
#include "resample_inst.h"
|
||||
#include "reshape_inst.h"
|
||||
#include "arg_max_min_inst.h"
|
||||
#include "generic_layer.hpp"
|
||||
#include <sstream>
|
||||
|
||||
@ -1790,6 +1791,10 @@ format layout_optimizer::get_preferred_format(program_node& node) {
|
||||
else if (input_layout.format.dimension() == 4)
|
||||
expected = format::bfyx;
|
||||
}
|
||||
} else if (node.is_type<arg_max_min>()) {
|
||||
// Set default format for issue 92967/98750
|
||||
// TODO: will remove when arg_max_min_ref supports blocked format
|
||||
expected = format::get_default_format(node.get_input_layouts()[0].get_rank(), false, false);
|
||||
}
|
||||
|
||||
return expected;
|
||||
|
@ -1487,6 +1487,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
||||
|| (prim.as<mvn>().input().get_output_layout().data_type != data_types::u8 &&
|
||||
prim.as<mvn>().input().get_output_layout().data_type != data_types::i8)
|
||||
|| prim.as<mvn>().get_primitive()->across_channels) &&
|
||||
prim.type() != cldnn::arg_max_min::type_id() &&
|
||||
prim.type() != cldnn::dft::type_id() &&
|
||||
prim.type() != cldnn::grid_sample::type_id() &&
|
||||
prim.type() != cldnn::mutable_data::type_id() &&
|
||||
|
Loading…
Reference in New Issue
Block a user