[GPU] Fix incomplete condition for NMS shape inference (#16960)
Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
parent
9b9c31d46b
commit
7282728cec
@ -46,7 +46,17 @@ std::vector<layout> non_max_suppression_inst::calc_output_layouts(non_max_suppre
|
|||||||
auto max_output_boxes_per_class_tensor = make_host_tensor(max_output_boxes_per_class_mem->get_layout(),
|
auto max_output_boxes_per_class_tensor = make_host_tensor(max_output_boxes_per_class_mem->get_layout(),
|
||||||
max_output_boxes_per_class_lock.data());
|
max_output_boxes_per_class_lock.data());
|
||||||
const_data.emplace(2, max_output_boxes_per_class_tensor);
|
const_data.emplace(2, max_output_boxes_per_class_tensor);
|
||||||
ov::op::v9::shape_infer(&op, input_shapes, output_shapes, true, const_data);
|
|
||||||
|
const auto& boxes = input_shapes[0];
|
||||||
|
const auto& scores = input_shapes[1];
|
||||||
|
// To produce a static output, we need to check dynamism of input tensor's dimensions
|
||||||
|
// Output tensor has the following shape: [min(num_boxes, max_output_boxes_per_class) * num_batches * num_classes, 3]
|
||||||
|
// The first dimension is an upper bound for the number of possible selected boxes
|
||||||
|
bool static_output = boxes[1].is_static() && scores[0].is_static() && scores[1].is_static();
|
||||||
|
ov::op::v9::shape_infer(&op, input_shapes, output_shapes, static_output, const_data);
|
||||||
|
} else {
|
||||||
|
output_shapes[0] = output_shapes[1] = ShapeType{ov::Dimension::dynamic(), 3};
|
||||||
|
output_shapes[2] = ShapeType{1};
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < desc->num_outputs; ++i) {
|
for (size_t i = 0; i < desc->num_outputs; ++i) {
|
||||||
|
@ -21,14 +21,11 @@ using namespace ::tests;
|
|||||||
namespace shape_infer_tests {
|
namespace shape_infer_tests {
|
||||||
|
|
||||||
struct non_max_suppression_test_params {
|
struct non_max_suppression_test_params {
|
||||||
layout in0_layout;
|
std::vector<layout> in_layouts;
|
||||||
layout in1_layout;
|
|
||||||
layout data_layout;
|
|
||||||
float max_output_boxes_per_class;
|
float max_output_boxes_per_class;
|
||||||
int32_t selected_indices_num;
|
int32_t selected_indices_num;
|
||||||
bool center_point_box;
|
bool center_point_box;
|
||||||
bool sort_result_descending;
|
bool sort_result_descending;
|
||||||
std::vector<input_info> inputs;
|
|
||||||
size_t num_outputs;
|
size_t num_outputs;
|
||||||
std::vector<layout> expected_layouts;
|
std::vector<layout> expected_layouts;
|
||||||
};
|
};
|
||||||
@ -40,18 +37,31 @@ TEST_P(non_max_suppression_test, shape_infer) {
|
|||||||
|
|
||||||
auto& engine = get_test_engine();
|
auto& engine = get_test_engine();
|
||||||
|
|
||||||
auto input0_layout_prim = std::make_shared<input_layout>("input0", p.in0_layout);
|
std::vector<std::shared_ptr<primitive>> input_prims;
|
||||||
auto input1_layout_prim = std::make_shared<input_layout>("input1", p.in1_layout);
|
std::vector<input_info> input_prim_ids;
|
||||||
auto data_mem = engine.allocate_memory(p.data_layout);
|
for (size_t i = 0; i < 2; i++) {
|
||||||
|
auto prim_id = "input" + std::to_string(i);
|
||||||
|
auto input_layout_prim = std::make_shared<input_layout>(prim_id, p.in_layouts[i]);
|
||||||
|
input_prims.push_back(input_layout_prim);
|
||||||
|
input_prim_ids.push_back(input_info(prim_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 2; i < p.in_layouts.size(); i++) {
|
||||||
|
auto prim_id = "const" + std::to_string(i);
|
||||||
|
auto data_mem = engine.allocate_memory(p.in_layouts[i]);
|
||||||
set_values(data_mem, {p.max_output_boxes_per_class});
|
set_values(data_mem, {p.max_output_boxes_per_class});
|
||||||
auto data_prim = std::make_shared<data>("const", data_mem);
|
auto data_prim = std::make_shared<data>(prim_id, data_mem);
|
||||||
|
input_prims.push_back(data_prim);
|
||||||
|
input_prim_ids.push_back(input_info(prim_id));
|
||||||
|
}
|
||||||
|
|
||||||
auto non_max_suppression_prim = std::make_shared<non_max_suppression>("output",
|
auto non_max_suppression_prim = std::make_shared<non_max_suppression>("output",
|
||||||
p.inputs[0],
|
input_prim_ids[0],
|
||||||
p.inputs[1],
|
input_prim_ids[1],
|
||||||
p.selected_indices_num,
|
p.selected_indices_num,
|
||||||
p.center_point_box,
|
p.center_point_box,
|
||||||
p.sort_result_descending,
|
p.sort_result_descending,
|
||||||
"const",
|
primitive_id(),
|
||||||
primitive_id(),
|
primitive_id(),
|
||||||
primitive_id(),
|
primitive_id(),
|
||||||
primitive_id(),
|
primitive_id(),
|
||||||
@ -59,17 +69,18 @@ TEST_P(non_max_suppression_test, shape_infer) {
|
|||||||
primitive_id(),
|
primitive_id(),
|
||||||
p.num_outputs);
|
p.num_outputs);
|
||||||
non_max_suppression_prim->output_paddings = {padding(), padding(), padding()};
|
non_max_suppression_prim->output_paddings = {padding(), padding(), padding()};
|
||||||
non_max_suppression_prim->output_data_types = {optional_data_type{}, optional_data_type{p.in1_layout.data_type}, optional_data_type{}};
|
non_max_suppression_prim->output_data_types = {optional_data_type{}, optional_data_type{p.in_layouts[1].data_type}, optional_data_type{}};
|
||||||
|
if (p.in_layouts.size() > 2) {
|
||||||
|
non_max_suppression_prim->num_select_per_class = input_prim_ids[2].pid;
|
||||||
|
}
|
||||||
|
|
||||||
cldnn::program prog(engine);
|
cldnn::program prog(engine);
|
||||||
|
|
||||||
auto& input0_layout_node = prog.get_or_create(input0_layout_prim);
|
|
||||||
auto& input1_layout_node = prog.get_or_create(input1_layout_prim);
|
|
||||||
auto& data_node = prog.get_or_create(data_prim);
|
|
||||||
auto& non_max_suppression_node = prog.get_or_create(non_max_suppression_prim);
|
auto& non_max_suppression_node = prog.get_or_create(non_max_suppression_prim);
|
||||||
program_wrapper::add_connection(prog, input0_layout_node, non_max_suppression_node);
|
for (auto& prim : input_prims) {
|
||||||
program_wrapper::add_connection(prog, input1_layout_node, non_max_suppression_node);
|
auto& input_layout_node = prog.get_or_create(prim);
|
||||||
program_wrapper::add_connection(prog, data_node, non_max_suppression_node);
|
program_wrapper::add_connection(prog, input_layout_node, non_max_suppression_node);
|
||||||
|
}
|
||||||
|
|
||||||
auto params = non_max_suppression_node.get_kernel_impl_params();
|
auto params = non_max_suppression_node.get_kernel_impl_params();
|
||||||
auto res = non_max_suppression_inst::calc_output_layouts<ov::PartialShape>(non_max_suppression_node, *params);
|
auto res = non_max_suppression_inst::calc_output_layouts<ov::PartialShape>(non_max_suppression_node, *params);
|
||||||
@ -82,19 +93,36 @@ TEST_P(non_max_suppression_test, shape_infer) {
|
|||||||
INSTANTIATE_TEST_SUITE_P(smoke, non_max_suppression_test,
|
INSTANTIATE_TEST_SUITE_P(smoke, non_max_suppression_test,
|
||||||
testing::ValuesIn(std::vector<non_max_suppression_test_params>{
|
testing::ValuesIn(std::vector<non_max_suppression_test_params>{
|
||||||
{
|
{
|
||||||
layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
|
{layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
|
||||||
layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx},
|
layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx},
|
||||||
layout{ov::PartialShape{1}, data_types::f32, format::bfyx},
|
layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
|
||||||
1.f, 4, false, true, {input_info("input0", 0), input_info("input1", 0)}, 3,
|
1.f, 4, false, true, 3,
|
||||||
{layout{ov::PartialShape{4, 3}, data_types::i32, format::bfyx},
|
{layout{ov::PartialShape{4, 3}, data_types::i32, format::bfyx},
|
||||||
layout{ov::PartialShape{4, 3}, data_types::f32, format::bfyx},
|
layout{ov::PartialShape{4, 3}, data_types::f32, format::bfyx},
|
||||||
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
|
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
{layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
|
||||||
|
layout{ov::PartialShape{2, 2, 3}, data_types::f32, format::bfyx}},
|
||||||
|
1.f, 4, false, true, 3,
|
||||||
|
{layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
|
||||||
|
layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
|
||||||
|
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
|
||||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||||
|
layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
|
||||||
|
1.f, 4, false, true, 3,
|
||||||
|
{layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
|
||||||
|
layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
|
||||||
|
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||||
layout{ov::PartialShape{1}, data_types::f32, format::bfyx},
|
layout{ov::PartialShape{1}, data_types::f32, format::bfyx}},
|
||||||
1.f, 4, false, true, {input_info("input0", 0), input_info("input1", 0)}, 3,
|
1.f, 4, false, true, 3,
|
||||||
{layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
|
{layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::i32, format::bfyx},
|
||||||
layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
|
layout{ov::PartialShape{ov::Dimension::dynamic(), 3}, data_types::f32, format::bfyx},
|
||||||
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
|
layout{ov::PartialShape{1}, data_types::i32, format::bfyx}}
|
||||||
|
Loading…
Reference in New Issue
Block a user