Review RegionYolo class for shape inference aspects (#18741)

* Add static shape adapter
- Adapters holds CPU dimension which can be reference to it or vector
- Add ov::optional for holding optional result from shape inference
- Add new `infer` function in `IStaticShapeInfer`

* Temporary support of StaticShape

* Minor corrections in ShapeInferenceTA

* Migrate shape_infer to new interface version

* Replace StaticShape by adapter implementation

* Replace IShapeInferCommon by IStaticShapeInfer

* Correct code formatting

* Fix build issues

* NodeValidationFailure::create for StaticShapeRef

* Review RegionYolo for shape inference:
- Check dynamic shape and label propagation
- Check static shape inference
- Review shape_infer template implementation
- Update unit test

* Remove commented test code

* Correct flatten dim calculation
This commit is contained in:
Pawel Raasz 2023-07-26 18:42:41 +02:00 committed by GitHub
parent ab42ff1164
commit bb3c9aa9a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 368 additions and 139 deletions

View File

@ -4,17 +4,19 @@
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/region_yolo.hpp>
#include <iterator>
#include "openvino/core/validation_util.hpp"
#include "openvino/op/region_yolo.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace v0 {
template <class T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const RegionYolo* op, const std::vector<T>& input_shapes) {
using DimType = typename T::value_type;
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const RegionYolo* op, const std::vector<TShape>& input_shapes) {
using TDim = typename TShape::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 1));
const auto& input_shape = input_shapes[0];
@ -22,39 +24,37 @@ std::vector<TRShape> shape_infer(const RegionYolo* op, const std::vector<T>& inp
auto output_shapes = std::vector<TRShape>(1);
auto& output_shape = output_shapes[0];
NODE_VALIDATION_CHECK(op, input_rank.compatible(4), "Input must be a tensor of rank 4, but got ", input_rank);
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
input_rank.compatible(4),
"Input must be a tensor of rank 4, but got ",
input_rank);
if (input_rank.is_static()) {
int64_t end_axis = op->get_end_axis();
if (end_axis < 0) {
end_axis += static_cast<int>(input_shape.size());
}
const auto out_rank = input_shape.size();
output_shape.reserve(out_rank);
if (op->get_do_softmax()) {
output_shape.resize(0);
OPENVINO_SUPPRESS_DEPRECATED_START
auto axis = ov::normalize_axis(op, op->get_axis(), input_rank);
const auto axis = ov::normalize_axis(op, op->get_axis(), input_rank);
const auto end_axis = ov::normalize_axis(op, op->get_end_axis(), input_rank);
OPENVINO_SUPPRESS_DEPRECATED_END
DimType flat_dim = 1;
for (int64_t i = 0; i < axis; i++) {
output_shape.push_back(input_shape[i]);
}
for (int64_t i = axis; i < end_axis + 1; i++) {
flat_dim *= input_shape[i];
}
output_shape.push_back(flat_dim);
for (size_t i = end_axis + 1; i < input_shape.size(); i++) {
output_shape.push_back(input_shape[i]);
auto input_it = input_shape.cbegin();
auto out_it = std::copy_n(input_it, axis + 1, std::back_inserter(output_shape));
input_it += (axis + 1);
for (; input_it <= input_shape.cbegin() + end_axis; ++input_it) {
output_shape[axis] *= *input_it;
}
std::copy(input_it, input_shape.end(), out_it);
} else {
output_shape = TRShape({input_shape[0],
static_cast<typename DimType::value_type>(
(op->get_num_classes() + op->get_num_coords() + 1) * op->get_mask().size()),
input_shape[2],
input_shape[3]});
output_shape = input_shape;
output_shape[1] = TDim((op->get_num_classes() + op->get_num_coords() + 1) * op->get_mask().size());
}
} else {
output_shape = ov::PartialShape::dynamic(ov::Rank(1, 4));
output_shape = PartialShape::dynamic(Rank(1, 4));
}
return output_shapes;
}

View File

@ -2,24 +2,24 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/region_yolo.hpp"
#include "openvino/op/region_yolo.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "openvino/core/attribute_visitor.hpp"
#include "region_yolo_shape_inference.hpp"
using namespace std;
using namespace ngraph;
op::RegionYolo::RegionYolo(const Output<Node>& input,
const size_t coords,
const size_t classes,
const size_t regions,
const bool do_softmax,
const vector<int64_t>& mask,
const int axis,
const int end_axis,
const vector<float>& anchors)
namespace ov {
namespace op {
namespace v0 {
RegionYolo::RegionYolo(const Output<Node>& input,
const size_t coords,
const size_t classes,
const size_t regions,
const bool do_softmax,
const std::vector<int64_t>& mask,
const int axis,
const int end_axis,
const std::vector<float>& anchors)
: Op({input}),
m_num_coords(coords),
m_num_classes(classes),
@ -32,7 +32,7 @@ op::RegionYolo::RegionYolo(const Output<Node>& input,
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor) {
bool RegionYolo::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v0_RegionYolo_visit_attributes);
visitor.on_attribute("anchors", m_anchors);
visitor.on_attribute("axis", m_axis);
@ -45,30 +45,34 @@ bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor) {
return true;
}
void op::RegionYolo::validate_and_infer_types() {
void RegionYolo::validate_and_infer_types() {
OV_OP_SCOPE(v0_RegionYolo_validate_and_infer_types);
auto input_et = get_input_element_type(0);
const auto& input_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_et.is_real(),
"Type of input is expected to be a floating point type. Got: ",
input_et);
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0)};
std::vector<ov::PartialShape> output_shapes = shape_infer(this, input_shapes);
const auto input_shapes = std::vector<PartialShape>{get_input_partial_shape(0)};
const auto output_shapes = ov::op::v0::shape_infer(this, input_shapes);
set_output_type(0, input_et, output_shapes[0]);
}
shared_ptr<Node> op::RegionYolo::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> RegionYolo::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v0_RegionYolo_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<RegionYolo>(new_args.at(0),
m_num_coords,
m_num_classes,
m_num_regions,
m_do_softmax,
m_mask,
m_axis,
m_end_axis,
m_anchors);
return std::make_shared<RegionYolo>(new_args.at(0),
m_num_coords,
m_num_classes,
m_num_regions,
m_do_softmax,
m_mask,
m_axis,
m_end_axis,
m_anchors);
}
} // namespace v0
} // namespace op
} // namespace ov

View File

@ -2,72 +2,265 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/region_yolo.hpp"
#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace testing;
using namespace ov;
TEST(type_prop, region_yolo_v2) {
const size_t num = 5;
const size_t coords = 4;
const size_t classes = 20;
const size_t batch = 1;
const size_t channels = 125;
const size_t width = 13;
const size_t height = 13;
class TypePropRegionYoloV0Test : public TypePropOpTest<op::v0::RegionYolo> {};
TEST_F(TypePropRegionYoloV0Test, default_ctor_do_softmax) {
const std::vector<int64_t> mask{0, 1, 2};
const int axis = 1;
const int end_axis = 3;
const auto in_shape = Shape{batch, channels, width, height};
auto data_param = make_shared<op::Parameter>(element::f32, in_shape);
auto region_yolo = make_shared<op::v0::RegionYolo>(data_param, coords, classes, num, true, mask, axis, end_axis);
// in_shape [N,C,H,W] -> out_shape [N, C*stride*stride, H/stride, W/stride]
Shape expected_shape = Shape{batch, channels * height * width};
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{10, 5, 16, 16});
const auto op = make_op();
op->set_argument(0, data);
op->set_do_softmax(true);
op->set_axis(-1);
op->set_end_axis(2);
op->validate_and_infer_types();
EXPECT_EQ(region_yolo->get_output_shape(0), expected_shape);
EXPECT_EQ(op->get_input_size(), 1);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({10, 5, 16, 16}));
}
TEST(type_prop, region_yolo_v3_1) {
const size_t num = 9;
const size_t coords = 4;
const size_t classes = 20;
const size_t batch = 1;
const size_t channels = 75;
const size_t width = 32;
const size_t height = 32;
TEST_F(TypePropRegionYoloV0Test, default_ctor_no_softmax) {
const std::vector<int64_t> mask{0, 1, 2};
const int axis = 1;
const int end_axis = 3;
const auto in_shape = Shape{batch, channels, width, height};
auto data_param = make_shared<op::Parameter>(element::f32, in_shape);
auto region_yolo = make_shared<op::v0::RegionYolo>(data_param, coords, classes, num, false, mask, axis, end_axis);
// in_shape [N,C,H,W] -> out_shape [N, C*stride*stride, H/stride, W/stride]
Shape expected_shape = Shape{batch, channels, height, width};
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{10, 5, 11, 12});
const auto op = make_op();
op->set_argument(0, data);
op->set_do_softmax(false);
op->set_num_classes(5);
op->set_num_coords(2);
op->set_mask({1, 2});
op->validate_and_infer_types();
EXPECT_EQ(region_yolo->get_output_shape(0), expected_shape);
EXPECT_EQ(op->get_input_size(), 1);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({10, 16, 11, 12}));
}
TEST(type_prop, region_yolo_v3_2) {
const size_t num = 1;
const size_t coords = 4;
const size_t classes = 1;
const size_t batch = 1;
const size_t channels = 8;
const size_t width = 2;
const size_t height = 2;
const std::vector<int64_t> mask{0};
const int axis = 1;
const int end_axis = 3;
const auto in_shape = Shape{batch, channels, width, height};
auto data_param = make_shared<op::Parameter>(element::f32, in_shape);
auto region_yolo = make_shared<op::v0::RegionYolo>(data_param, coords, classes, num, false, mask, axis, end_axis);
TEST_F(TypePropRegionYoloV0Test, data_input_dynamic_rank_do_not_softmax) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 1, end_axis = 3;
const std::vector<int64_t> mask{0, 1, 2};
// in_shape [N,C,H,W] -> out_shape [N, C*stride*stride, H/stride, W/stride]
Shape expected_shape = Shape{batch, (classes + coords + 1) * mask.size(), height, width};
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
const auto op = make_op(data, coords, classes, num, false, mask, axis, end_axis);
EXPECT_EQ(region_yolo->get_output_shape(0), expected_shape);
}
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic());
}
TEST_F(TypePropRegionYoloV0Test, data_input_dynamic_rank_do_softmax) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 1, end_axis = 3;
const std::vector<int64_t> mask{0, 1, 2};
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic());
}
TEST_F(TypePropRegionYoloV0Test, data_input_static_rank_do_softmax) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 1, end_axis = 3;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape::dynamic(4);
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f64, data_shape);
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic(2));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 0));
}
TEST_F(TypePropRegionYoloV0Test, do_softmax_end_axis_is_negative) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 1, end_axis = -1;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape::dynamic(4);
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f32, data_shape);
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic(2));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 0));
}
TEST_F(TypePropRegionYoloV0Test, do_softmax_axis_eq_end_axis) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 2, end_axis = 2;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape{5, 4, 10, 11};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f32, data_shape);
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({5, 4, 10, 11}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 12, 13));
}
TEST_F(TypePropRegionYoloV0Test, do_softmax_axis_gt_end_axis) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 3, end_axis = 1;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape{5, 4, 10, 11};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f32, data_shape);
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({5, 4, 10, 11}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 12, 13));
}
TEST_F(TypePropRegionYoloV0Test, do_softmax_axis_end_axis_on_last_dim) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = -1, end_axis = -1;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape{5, 4, 10, 11};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f32, data_shape);
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({5, 4, 10, 11}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 12, 13));
}
TEST_F(TypePropRegionYoloV0Test, data_input_interval_shape_with_labels_do_softmax) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 2, end_axis = 3;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape{{2, 4}, {5, 8}, -1, {0, 10}};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f16, data_shape);
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f16);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({{2, 4}, {5, 8}, -1}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 0));
}
TEST_F(TypePropRegionYoloV0Test, do_softmax_start_axis_negative) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = -2, end_axis = 3;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape{{2, 4}, {5, 8}, -1, {0, 10}};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f16, data_shape);
const auto op = make_op(data, coords, classes, num, true, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f16);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({{2, 4}, {5, 8}, -1}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 0));
}
TEST_F(TypePropRegionYoloV0Test, data_input_interval_shape_with_labels_no_softmax) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 2, end_axis = 3;
const std::vector<int64_t> mask{0, 1, 2};
auto data_shape = PartialShape{{2, 4}, {5, 8}, -1, {0, 10}};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<op::v0::Parameter>(element::f16, data_shape);
const auto op = make_op(data, coords, classes, num, false, mask, axis, end_axis);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f16);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({{2, 4}, 75, -1, {0, 10}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 0, 12, 13));
}
TEST_F(TypePropRegionYoloV0Test, data_input_not_4d) {
constexpr size_t num = 5, coords = 4, classes = 20;
constexpr int axis = 1, end_axis = 5;
const std::vector<int64_t> mask{0, 1, 2};
OV_EXPECT_THROW(std::ignore = make_op(std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3)),
coords,
classes,
num,
true,
mask,
axis,
end_axis),
NodeValidationFailure,
HasSubstr("Input must be a tensor of rank 4"));
OV_EXPECT_THROW(std::ignore = make_op(std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(5)),
coords,
classes,
num,
true,
mask,
axis,
end_axis),
NodeValidationFailure,
HasSubstr("Input must be a tensor of rank 4"));
}
TEST_F(TypePropRegionYoloV0Test, do_softmax_axis_not_valid_value) {
constexpr size_t num = 5, coords = 4, classes = 20;
const std::vector<int64_t> mask{0, 1, 2};
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(4));
OV_EXPECT_THROW(std::ignore = make_op(data, coords, classes, num, true, mask, 4, 2),
AssertFailure,
HasSubstr("out of the tensor rank range"));
OV_EXPECT_THROW(std::ignore = make_op(data, coords, classes, num, true, mask, -5, 2),
AssertFailure,
HasSubstr("out of the tensor rank range"));
}
TEST_F(TypePropRegionYoloV0Test, do_softmax_end_axis_not_valid_value) {
constexpr size_t num = 5, coords = 4, classes = 20;
const std::vector<int64_t> mask{0, 1, 2};
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(4));
OV_EXPECT_THROW(std::ignore = make_op(data, coords, classes, num, true, mask, 1, 4),
AssertFailure,
HasSubstr("out of the tensor rank range"));
OV_EXPECT_THROW(std::ignore = make_op(data, coords, classes, num, true, mask, 1, -5),
AssertFailure,
HasSubstr("out of the tensor rank range"));
}

View File

@ -0,0 +1,60 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/test_assertions.hpp"
#include "region_yolo_shape_inference.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
class StaticShapeRegionYoloTest : public OpStaticShapeInferenceTest<op::v0::RegionYolo> {};
TEST_F(StaticShapeRegionYoloTest, default_ctor_do_soft_max_no_args) {
op = make_op();
op->set_do_softmax(true);
op->set_axis(-2);
op->set_end_axis(3);
input_shapes = ShapeVector{{10, 8, 12, 6}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({10, 8, 72}));
}
TEST_F(StaticShapeRegionYoloTest, data_input_is_dynamic_rank) {
const auto data = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
op = make_op(data, 0, 0, 0, true, std::vector<int64_t>(), 1, 3);
input_shapes = ShapeVector{{2, 2, 3, 4}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({2, 24}));
}
TEST_F(StaticShapeRegionYoloTest, data_input_is_static_rank) {
const auto data = std::make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic(4));
op = make_op(data, 5, 4, 20, false, std::vector<int64_t>{0, 1}, 1, 3);
input_shapes = ShapeVector{{2, 5, 6, 7}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({2, 20, 6, 7}));
}
TEST_F(StaticShapeRegionYoloTest, data_shape_not_compatible_rank_4) {
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
op = make_op(data, 5, 4, 20, false, std::vector<int64_t>{0, 1}, 1, 3);
OV_EXPECT_THROW(shape_inference(op.get(), ShapeVector{{2, 20, 12, 24, 1}}, output_shapes),
NodeValidationFailure,
HasSubstr("Input must be a tensor of rank 4, but got"));
}

View File

@ -1,28 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <region_yolo_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace std;
TEST(StaticShapeInferenceTest, RegionYoloV0) {
auto inputs = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1, -1, -1, -1});
auto op = make_shared<op::v0::RegionYolo>(inputs, 0, 0, 0, true, std::vector<int64_t>{}, 0, 1);
check_static_shape(op.get(), {StaticShape{1, 125, 13, 13}}, {StaticShape{1 * 125, 13, 13}});
}
TEST(StaticShapeInferenceTest, RegionYoloV0Dynamic) {
auto inputs = make_shared<op::v0::Parameter>(element::f32,
ov::PartialShape{{1, 11}, {2, 12}, ov::Dimension::dynamic(), {4, 14}});
auto op = make_shared<op::v0::RegionYolo>(inputs, 4, 80, 5, true, std::vector<int64_t>{}, 1, 3);
EXPECT_EQ(op->get_output_partial_shape(0), ov::PartialShape({{1, 11}, ov::Dimension::dynamic()}));
check_static_shape(op.get(), {StaticShape{10, 125, 13, 13}}, {StaticShape{10, 125 * 13 * 13}});
}