[GPU] Fix incomplete condition for NMS shape inference (#16960)
Signed-off-by: Andrew Park <andrew.park@intel.com>
commit 7282728cec
parent 9b9c31d46b
@@ -46,7 +46,17 @@ std::vector<layout> non_max_suppression_inst::calc_output_layouts(non_max_suppre
         auto max_output_boxes_per_class_tensor = make_host_tensor(max_output_boxes_per_class_mem->get_layout(),
                                                                    max_output_boxes_per_class_lock.data());
         const_data.emplace(2, max_output_boxes_per_class_tensor);
-        ov::op::v9::shape_infer(&op, input_shapes, output_shapes, true, const_data);
+
+        const auto& boxes = input_shapes[0];
+        const auto& scores = input_shapes[1];
+        // To produce a static output, we need to check dynamism of input tensor's dimensions
+        // Output tensor has the following shape: [min(num_boxes, max_output_boxes_per_class) * num_batches * num_classes, 3]
+        // The first dimension is an upper bound for the number of possible selected boxes
+        bool static_output = boxes[1].is_static() && scores[0].is_static() && scores[1].is_static();
+        ov::op::v9::shape_infer(&op, input_shapes, output_shapes, static_output, const_data);
     } else {
         output_shapes[0] = output_shapes[1] = ShapeType{ov::Dimension::dynamic(), 3};
         output_shapes[2] = ShapeType{1};
     }

     for (size_t i = 0; i < desc->num_outputs; ++i) {
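For context on the condition above: the selected-indices output of NMS-9 has shape [min(num_boxes, max_output_boxes_per_class) * num_batches * num_classes, 3], so a static first dimension can only be produced when num_boxes (boxes[1]), num_batches (scores[0]) and num_classes (scores[1]) are all static. Below is a minimal standalone sketch of that rule; the selected_indices_shape() helper is hypothetical and not part of this patch, and it only assumes the public ov::PartialShape / ov::Dimension API.

    #include <algorithm>
    #include <cstdint>
    #include <openvino/core/partial_shape.hpp>

    // Hypothetical helper illustrating the upper-bound shape rule used above.
    // boxes:  [num_batches, num_boxes, 4]
    // scores: [num_batches, num_classes, num_boxes]
    ov::PartialShape selected_indices_shape(const ov::PartialShape& boxes,
                                            const ov::PartialShape& scores,
                                            int64_t max_output_boxes_per_class) {
        const auto& num_boxes = boxes[1];
        const auto& num_batches = scores[0];
        const auto& num_classes = scores[1];
        // Mirrors the fixed condition: any dynamic dimension forces a dynamic first dim.
        if (!num_boxes.is_static() || !num_batches.is_static() || !num_classes.is_static())
            return ov::PartialShape{ov::Dimension::dynamic(), 3};
        const int64_t upper_bound = std::min(num_boxes.get_length(), max_output_boxes_per_class) *
                                    num_batches.get_length() * num_classes.get_length();
        return ov::PartialShape{upper_bound, 3};
    }

The old code always passed true for the static_output argument, which forced a static shape even when the inputs were dynamic; the fix computes the flag from the actual input dimensions.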
@@ -21,14 +21,11 @@ using namespace ::tests;
 namespace shape_infer_tests {

 struct non_max_suppression_test_params {
-    layout in0_layout;
-    layout in1_layout;
-    layout data_layout;
+    std::vector<layout> in_layouts;
     float max_output_boxes_per_class;
     int32_t selected_indices_num;
     bool center_point_box;
     bool sort_result_descending;
-    std::vector<input_info> inputs;
     size_t num_outputs;
     std::vector<layout> expected_layouts;
 };
@@ -40,18 +37,31 @@ TEST_P(non_max_suppression_test, shape_infer) {

     auto& engine = get_test_engine();

-    auto input0_layout_prim = std::make_shared<input_layout>("input0", p.in0_layout);
-    auto input1_layout_prim = std::make_shared<input_layout>("input1", p.in1_layout);
-    auto data_mem = engine.allocate_memory(p.data_layout);
-    set_values(data_mem, {p.max_output_boxes_per_class});
-    auto data_prim = std::make_shared<data>("const", data_mem);
+    std::vector<std::shared_ptr<primitive>> input_prims;
+    std::vector<input_info> input_prim_ids;
+    for (size_t i = 0; i < 2; i++) {
+        auto prim_id = "input" + std::to_string(i);
+        auto input_layout_prim = std::make_shared<input_layout>(prim_id, p.in_layouts[i]);
+        input_prims.push_back(input_layout_prim);
+        input_prim_ids.push_back(input_info(prim_id));
+    }
+
+    for (size_t i = 2; i < p.in_layouts.size(); i++) {
+        auto prim_id = "const" + std::to_string(i);
+        auto data_mem = engine.allocate_memory(p.in_layouts[i]);
+        set_values(data_mem, {p.max_output_boxes_per_class});
+        auto data_prim = std::make_shared<data>(prim_id, data_mem);
+        input_prims.push_back(data_prim);
+        input_prim_ids.push_back(input_info(prim_id));
+    }

     auto non_max_suppression_prim = std::make_shared<non_max_suppression>("output",
-                                                                          p.inputs[0],
-                                                                          p.inputs[1],
+                                                                          input_prim_ids[0],
+                                                                          input_prim_ids[1],
                                                                           p.selected_indices_num,
                                                                           p.center_point_box,
                                                                           p.sort_result_descending,
-                                                                          "const",
+                                                                          primitive_id(),
                                                                           primitive_id(),
                                                                           primitive_id(),
                                                                           primitive_id(),
                                                                           primitive_id(),
@@ -59,17 +69,18 @@ TEST_P(non_max_suppression_test, shape_infer) {
                                                                           primitive_id(),
                                                                           p.num_outputs);
     non_max_suppression_prim->output_paddings = {padding(), padding(), padding()};
-    non_max_suppression_prim->output_data_types = {optional_data_type{}, optional_data_type{p.in1_layout.data_type}, optional_data_type{}};
+    non_max_suppression_prim->output_data_types = {optional_data_type{}, optional_data_type{p.in_layouts[1].data_type}, optional_data_type{}};
+    if (p.in_layouts.size() > 2) {
+        non_max_suppression_prim->num_select_per_class = input_prim_ids[2].pid;
+    }

     cldnn::program prog(engine);

-    auto& input0_layout_node = prog.get_or_create(input0_layout_prim);
-    auto& input1_layout_node = prog.get_or_create(input1_layout_prim);
-    auto& data_node = prog.get_or_create(data_prim);
     auto& non_max_suppression_node = prog.get_or_create(non_max_suppression_prim);
-    program_wrapper::add_connection(prog, input0_layout_node, non_max_suppression_node);
-    program_wrapper::add_connection(prog, input1_layout_node, non_max_suppression_node);
-    program_wrapper::add_connection(prog, data_node, non_max_suppression_node);
+    for (auto& prim : input_prims) {
+        auto& input_layout_node = prog.get_or_create(prim);
+        program_wrapper::add_connection(prog, input_layout_node, non_max_suppression_node);
+    }

     auto params = non_max_suppression_node.get_kernel_impl_params();
     auto res = non_max_suppression_inst::calc_output_layouts<ov::PartialShape>(non_max_suppression_node, *params);
@@ -82,19 +93,36 @@ TEST_P(non_max_suppression_test, shape_infer) {
 INSTANTIATE_TEST_SUITE_P(smoke, non_max_suppression_test,
     testing::ValuesIn(std::vector<non_max_suppression_test_params>{
         {
-            layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
-            layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx},
-            layout{ov::PartialShape{1}, data_types::f32, format::bfyx},
-            1.f, 4, false, true, {input_info("input0", 0), input_info("input1", 0)}, 3,
+            {layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
+             layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx},
+             layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
+            1.f, 4, false, true, 3,
             {layout{ov::PartialShape{4, 3}, data_types::i32, format::bfyx},
              layout{ov::PartialShape{4, 3}, data_types::f32, format::bfyx},
              layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
         },
         {
-            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
-            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
-            layout{ov::PartialShape{1}, data_types::f32, format::bfyx},
-            1.f, 4, false, true, {input_info("input0", 0), input_info("input1", 0)}, 3,
+            {layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
+             layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx}},
+            1.f, 4, false, true, 3,
             {layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
              layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
              layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
+        },
+        {
+            {layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
+             layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
+             layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
+            1.f, 4, false, true, 3,
+            {layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
+             layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
+             layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
+        },
+        {
+            {layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
+             layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
+             layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
+            1.f, 4, false, true, 3,
+            {layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
+             layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
+             layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
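As a sanity check on the first, fully static case above (boxes {2, 3, 4}, scores {2, 2, 3}, max_output_boxes_per_class = 1): min(num_boxes, max_output_boxes_per_class) * num_batches * num_classes = min(3, 1) * 2 * 2 = 4, which matches the expected {4, 3} output layouts, while any dynamic boxes/scores dimension leaves the first dimension dynamic, as in the later cases. A quick usage sketch, reusing the hypothetical selected_indices_shape() helper shown after the first hunk (not part of the test suite):

    #include <iostream>

    int main() {
        // First test case: static boxes [2, 3, 4], static scores [2, 2, 3], max_output_boxes_per_class = 1.
        std::cout << selected_indices_shape(ov::PartialShape{2, 3, 4}, ov::PartialShape{2, 2, 3}, 1)
                  << std::endl;  // prints the static upper-bound shape, e.g. [4,3]
        // Dynamic boxes/scores keep the output dynamic, as in the later cases.
        std::cout << selected_indices_shape(ov::PartialShape::dynamic(3), ov::PartialShape::dynamic(3), 1)
                  << std::endl;  // prints a dynamic first dimension, e.g. [?,3]
        return 0;
    }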