[GPU] Fix incomplete condition for NMS shape inference (#16960)

Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park 2023-04-17 14:41:57 +09:00 committed by GitHub
parent 9b9c31d46b
commit 7282728cec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 28 deletions

View File

@ -46,7 +46,17 @@ std::vector<layout> non_max_suppression_inst::calc_output_layouts(non_max_suppre
auto max_output_boxes_per_class_tensor = make_host_tensor(max_output_boxes_per_class_mem->get_layout(), auto max_output_boxes_per_class_tensor = make_host_tensor(max_output_boxes_per_class_mem->get_layout(),
max_output_boxes_per_class_lock.data()); max_output_boxes_per_class_lock.data());
const_data.emplace(2, max_output_boxes_per_class_tensor); const_data.emplace(2, max_output_boxes_per_class_tensor);
ov::op::v9::shape_infer(&op, input_shapes, output_shapes, true, const_data);
const auto& boxes = input_shapes[0];
const auto& scores = input_shapes[1];
// To produce a static output, we need to check dynamism of input tensor's dimensions
// Output tensor has the following shape: [min(num_boxes, max_output_boxes_per_class) * num_batches * num_classes, 3]
// The first dimension is an upper bound for the number of possible selected boxes
bool static_output = boxes[1].is_static() && scores[0].is_static() && scores[1].is_static();
ov::op::v9::shape_infer(&op, input_shapes, output_shapes, static_output, const_data);
} else {
output_shapes[0] = output_shapes[1] = ShapeType{ov::Dimension::dynamic(), 3};
output_shapes[2] = ShapeType{1};
} }
for (size_t i = 0; i < desc->num_outputs; ++i) { for (size_t i = 0; i < desc->num_outputs; ++i) {

View File

@ -21,14 +21,11 @@ using namespace ::tests;
namespace shape_infer_tests { namespace shape_infer_tests {
struct non_max_suppression_test_params { struct non_max_suppression_test_params {
layout in0_layout; std::vector<layout> in_layouts;
layout in1_layout;
layout data_layout;
float max_output_boxes_per_class; float max_output_boxes_per_class;
int32_t selected_indices_num; int32_t selected_indices_num;
bool center_point_box; bool center_point_box;
bool sort_result_descending; bool sort_result_descending;
std::vector<input_info> inputs;
size_t num_outputs; size_t num_outputs;
std::vector<layout> expected_layouts; std::vector<layout> expected_layouts;
}; };
@ -40,18 +37,31 @@ TEST_P(non_max_suppression_test, shape_infer) {
auto& engine = get_test_engine(); auto& engine = get_test_engine();
auto input0_layout_prim = std::make_shared<input_layout>("input0", p.in0_layout); std::vector<std::shared_ptr<primitive>> input_prims;
auto input1_layout_prim = std::make_shared<input_layout>("input1", p.in1_layout); std::vector<input_info> input_prim_ids;
auto data_mem = engine.allocate_memory(p.data_layout); for (size_t i = 0; i < 2; i++) {
auto prim_id = "input" + std::to_string(i);
auto input_layout_prim = std::make_shared<input_layout>(prim_id, p.in_layouts[i]);
input_prims.push_back(input_layout_prim);
input_prim_ids.push_back(input_info(prim_id));
}
for (size_t i = 2; i < p.in_layouts.size(); i++) {
auto prim_id = "const" + std::to_string(i);
auto data_mem = engine.allocate_memory(p.in_layouts[i]);
set_values(data_mem, {p.max_output_boxes_per_class}); set_values(data_mem, {p.max_output_boxes_per_class});
auto data_prim = std::make_shared<data>("const", data_mem); auto data_prim = std::make_shared<data>(prim_id, data_mem);
input_prims.push_back(data_prim);
input_prim_ids.push_back(input_info(prim_id));
}
auto non_max_suppression_prim = std::make_shared<non_max_suppression>("output", auto non_max_suppression_prim = std::make_shared<non_max_suppression>("output",
p.inputs[0], input_prim_ids[0],
p.inputs[1], input_prim_ids[1],
p.selected_indices_num, p.selected_indices_num,
p.center_point_box, p.center_point_box,
p.sort_result_descending, p.sort_result_descending,
"const", primitive_id(),
primitive_id(), primitive_id(),
primitive_id(), primitive_id(),
primitive_id(), primitive_id(),
@ -59,17 +69,18 @@ TEST_P(non_max_suppression_test, shape_infer) {
primitive_id(), primitive_id(),
p.num_outputs); p.num_outputs);
non_max_suppression_prim->output_paddings = {padding(), padding(), padding()}; non_max_suppression_prim->output_paddings = {padding(), padding(), padding()};
non_max_suppression_prim->output_data_types = {optional_data_type{}, optional_data_type{p.in1_layout.data_type}, optional_data_type{}}; non_max_suppression_prim->output_data_types = {optional_data_type{}, optional_data_type{p.in_layouts[1].data_type}, optional_data_type{}};
if (p.in_layouts.size() > 2) {
non_max_suppression_prim->num_select_per_class = input_prim_ids[2].pid;
}
cldnn::program prog(engine); cldnn::program prog(engine);
auto& input0_layout_node = prog.get_or_create(input0_layout_prim);
auto& input1_layout_node = prog.get_or_create(input1_layout_prim);
auto& data_node = prog.get_or_create(data_prim);
auto& non_max_suppression_node = prog.get_or_create(non_max_suppression_prim); auto& non_max_suppression_node = prog.get_or_create(non_max_suppression_prim);
program_wrapper::add_connection(prog, input0_layout_node, non_max_suppression_node); for (auto& prim : input_prims) {
program_wrapper::add_connection(prog, input1_layout_node, non_max_suppression_node); auto& input_layout_node = prog.get_or_create(prim);
program_wrapper::add_connection(prog, data_node, non_max_suppression_node); program_wrapper::add_connection(prog, input_layout_node, non_max_suppression_node);
}
auto params = non_max_suppression_node.get_kernel_impl_params(); auto params = non_max_suppression_node.get_kernel_impl_params();
auto res = non_max_suppression_inst::calc_output_layouts<ov::PartialShape>(non_max_suppression_node, *params); auto res = non_max_suppression_inst::calc_output_layouts<ov::PartialShape>(non_max_suppression_node, *params);
@ -82,19 +93,36 @@ TEST_P(non_max_suppression_test, shape_infer) {
INSTANTIATE_TEST_SUITE_P(smoke, non_max_suppression_test, INSTANTIATE_TEST_SUITE_P(smoke, non_max_suppression_test,
testing::ValuesIn(std::vector<non_max_suppression_test_params>{ testing::ValuesIn(std::vector<non_max_suppression_test_params>{
{ {
layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx}, {layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx}, layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx},
layout{ov::PartialShape{1}, data_types::f32, format::bfyx}, layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
1.f, 4, false, true, {input_info("input0", 0), input_info("input1", 0)}, 3, 1.f, 4, false, true, 3,
{layout{ov::PartialShape{4, 3}, data_types::i32, format::bfyx}, {layout{ov::PartialShape{4, 3}, data_types::i32, format::bfyx},
layout{ov::PartialShape{4, 3}, data_types::f32, format::bfyx}, layout{ov::PartialShape{4, 3}, data_types::f32, format::bfyx},
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}} layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
}, },
{ {
{layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx}},
1.f, 4, false, true, 3,
{layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
},
{
{layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}, layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
1.f, 4, false, true, 3,
{layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
},
{
{layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}, layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
layout{ov::PartialShape{1}, data_types::f32, format::bfyx}, layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
1.f, 4, false, true, {input_info("input0", 0), input_info("input1", 0)}, 3, 1.f, 4, false, true, 3,
{layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx}, {layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx}, layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}} layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}