ReorgYolo shape inference fix (#17728)

This commit is contained in:
Vladislav Golubev 2023-05-26 11:06:49 +02:00 committed by GitHub
parent 9e646bf446
commit 2df980aa9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 1 deletions

View File

@ -43,7 +43,7 @@ void shape_infer(const ReorgYolo* op, const std::vector<T>& input_shapes, std::v
const auto& interval = input_shape[i].get_interval();
if (interval.has_upper_bound()) {
output_shape.push_back(
ov::Dimension(interval.get_max_val() / strides[0], interval.get_min_val() / strides[0]));
ov::Dimension(interval.get_min_val() / strides[0], interval.get_max_val() / strides[0]));
} else {
output_shape.push_back(ov::Dimension::dynamic());
}

View File

@ -21,6 +21,29 @@ TEST(type_prop, reorg_yolo_stride_2) {
EXPECT_EQ(reorg_yolo->get_output_shape(0), expected_shape);
}
TEST(type_prop, reorg_yolo_stride_2_dynamic_shape) {
const auto in_shape = PartialShape{-1, -1, -1, -1};
size_t stride = 2;
auto data_param = make_shared<op::Parameter>(element::f32, in_shape);
auto reorg_yolo = make_shared<op::v0::ReorgYolo>(data_param, stride);
const auto expected_shape = PartialShape{-1, -1, -1, -1};
EXPECT_EQ(reorg_yolo->get_output_partial_shape(0), expected_shape);
}
TEST(type_prop, reorg_yolo_stride_2_dynamic_shape_ranges) {
const auto in_shape = PartialShape{{1, 4}, {3, 9}, {16, 32}, {16, 32}};
size_t stride = 2;
auto data_param = make_shared<op::Parameter>(element::f32, in_shape);
auto reorg_yolo = make_shared<op::v0::ReorgYolo>(data_param, stride);
// in_shape [N,C,H,W] -> out_shape [N, C*stride*stride, H/stride, W/stride]
const auto expected_shape = PartialShape{{1, 4}, {12, 36}, {8, 16}, {8, 16}};
EXPECT_EQ(reorg_yolo->get_output_partial_shape(0), expected_shape);
}
TEST(type_prop, reorg_yolo_stride_2_batch_2) {
const auto in_shape = Shape{2, 64, 26, 26};
size_t stride = 2;