[shape_infer]Implement ExperimentalDetectronDetectionOutput shape infer. (#7903)
* Implement ExperimentalDetectronDetectionOutput shape infer. Signed-off-by: Luwei Zhou <luwei.zhou@intel.com> * Update on the review comments. * Update based on review
This commit is contained in:
parent
83fe59bd3e
commit
c2512f8dc5
@ -2,13 +2,14 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "utils/shape_inference/static_shape.hpp"
|
||||
#include <convolution_shape_inference.hpp>
|
||||
#include <experimental_detectron_detection_output_shape_inference.hpp>
|
||||
#include <gtest/gtest.h>
|
||||
#include <openvino/core/coordinate_diff.hpp>
|
||||
#include <openvino/op/convolution.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <convolution_shape_inference.hpp>
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include "utils/shape_inference/static_shape.hpp"
|
||||
#include <openvino/op/parameter.hpp>
|
||||
|
||||
using namespace ov;
|
||||
|
||||
@ -40,6 +41,67 @@ TEST(StaticShapeInferenceTest, ConvolutionTest) {
|
||||
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, ExperimentalDetectronDetectionOutputTest) {
|
||||
using Attrs = op::v6::ExperimentalDetectronDetectionOutput::Attributes;
|
||||
Attrs attrs;
|
||||
attrs.class_agnostic_box_regression = true;
|
||||
attrs.deltas_weights = {10.0f, 10.0f, 5.0f, 5.0f};
|
||||
attrs.max_delta_log_wh = 2.0f;
|
||||
attrs.max_detections_per_image = 5;
|
||||
attrs.nms_threshold = 0.2f;
|
||||
attrs.num_classes = 2;
|
||||
attrs.post_nms_count = 500;
|
||||
attrs.score_threshold = 0.01000000074505806f;
|
||||
int64_t rois_num = static_cast<int64_t>(attrs.max_detections_per_image);
|
||||
|
||||
auto rois = std::make_shared<ov::op::v0::Parameter>(element::f32,
|
||||
PartialShape{-1, -1});
|
||||
auto deltas = std::make_shared<ov::op::v0::Parameter>(element::f32,
|
||||
PartialShape{-1, -1});
|
||||
auto scores = std::make_shared<ov::op::v0::Parameter>(element::f32,
|
||||
PartialShape{-1, -1});
|
||||
auto im_info = std::make_shared<ov::op::v0::Parameter>(element::f32,
|
||||
PartialShape{-1, -1});
|
||||
|
||||
auto detection =
|
||||
std::make_shared<ov::op::v6::ExperimentalDetectronDetectionOutput>(
|
||||
rois, deltas, scores, im_info, attrs);
|
||||
std::vector<PartialShape> input_shapes = {
|
||||
PartialShape::dynamic(), PartialShape::dynamic(), PartialShape::dynamic(),
|
||||
PartialShape::dynamic()};
|
||||
std::vector<PartialShape> output_shapes = {PartialShape::dynamic(),
|
||||
PartialShape::dynamic(),
|
||||
PartialShape::dynamic()};
|
||||
shape_infer(detection.get(), input_shapes, output_shapes);
|
||||
ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4}));
|
||||
ASSERT_EQ(output_shapes[1], (PartialShape{rois_num}));
|
||||
ASSERT_EQ(output_shapes[2], (PartialShape{rois_num}));
|
||||
|
||||
input_shapes = {PartialShape{-1, -1}, PartialShape{-1, -1},
|
||||
PartialShape{-1, -1}, PartialShape{-1, -1}};
|
||||
output_shapes = {PartialShape{}, PartialShape{}, PartialShape{}};
|
||||
shape_infer(detection.get(), input_shapes, output_shapes);
|
||||
ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4}));
|
||||
ASSERT_EQ(output_shapes[1], (PartialShape{rois_num}));
|
||||
ASSERT_EQ(output_shapes[2], (PartialShape{rois_num}));
|
||||
|
||||
input_shapes = {PartialShape{16, 4}, PartialShape{16, 8}, PartialShape{16, 2}, PartialShape{1, 3}};
|
||||
output_shapes = {PartialShape{}, PartialShape{}, PartialShape{}};
|
||||
shape_infer(detection.get(), input_shapes, output_shapes);
|
||||
ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4}));
|
||||
ASSERT_EQ(output_shapes[1], (PartialShape{rois_num}));
|
||||
ASSERT_EQ(output_shapes[2], (PartialShape{rois_num}));
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 4}, StaticShape{16, 8}, StaticShape{16, 2}, StaticShape{1, 3}};
|
||||
std::vector<StaticShape> static_output_shapes = {StaticShape{}, StaticShape{},
|
||||
StaticShape{}};
|
||||
shape_infer(detection.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes[0],
|
||||
StaticShape({attrs.max_detections_per_image, 4}));
|
||||
ASSERT_EQ(static_output_shapes[1], StaticShape({attrs.max_detections_per_image}));
|
||||
ASSERT_EQ(static_output_shapes[2], StaticShape({attrs.max_detections_per_image}));
|
||||
}
|
||||
|
||||
#if 0
|
||||
TEST(StaticShapeInferenceTest, ConvolutionTimeTest) {
|
||||
Strides strides{1, 1};
|
||||
|
@ -69,6 +69,10 @@ public:
|
||||
|
||||
private:
|
||||
Attributes m_attrs;
|
||||
template <class T>
|
||||
friend void shape_infer(const ExperimentalDetectronDetectionOutput* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes);
|
||||
};
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
|
@ -0,0 +1,93 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <openvino/op/experimental_detectron_detection_output.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v6 {
|
||||
|
||||
template <class T>
|
||||
void shape_infer(const ExperimentalDetectronDetectionOutput* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes) {
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4 && output_shapes.size() == 3);
|
||||
|
||||
const auto& rois_shape = input_shapes[0];
|
||||
const auto& deltas_shape = input_shapes[1];
|
||||
const auto& scores_shape = input_shapes[2];
|
||||
const auto& im_info_shape = input_shapes[3];
|
||||
|
||||
auto& output_box_shape = output_shapes[0];
|
||||
auto& output_det_shape = output_shapes[1];
|
||||
auto& output_score_shape = output_shapes[2];
|
||||
|
||||
output_box_shape.resize(2);
|
||||
output_det_shape.resize(1);
|
||||
output_score_shape.resize(1);
|
||||
|
||||
const auto rois_shape_rank_is_static = rois_shape.rank().is_static();
|
||||
|
||||
if (rois_shape_rank_is_static) {
|
||||
NODE_VALIDATION_CHECK(op, rois_shape.size() == 2, "Input rois rank must be equal to 2.");
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
rois_shape[1].compatible(4),
|
||||
"The last dimension of the 'input_rois' input must be compatible with 4. "
|
||||
"Got: ",
|
||||
rois_shape[1]);
|
||||
}
|
||||
const auto deltas_shape_rank_is_static = deltas_shape.rank().is_static();
|
||||
if (deltas_shape_rank_is_static) {
|
||||
NODE_VALIDATION_CHECK(op, deltas_shape.size() == 2, "Input deltas rank must be equal to 2.");
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
deltas_shape[1].compatible(op->m_attrs.num_classes * 4),
|
||||
"The last dimension of the 'input_deltas' input be compatible with "
|
||||
"the value of the attribute 'num_classes' * 4. Got: ",
|
||||
deltas_shape[1]);
|
||||
}
|
||||
const auto scores_shape_is_static = scores_shape.rank().is_static();
|
||||
if (scores_shape_is_static) {
|
||||
NODE_VALIDATION_CHECK(op, scores_shape.size() == 2, "Input scores rank must be equal to 2.");
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
scores_shape[1].compatible(op->m_attrs.num_classes),
|
||||
"The last dimension of the 'input_scores' input must be compatible with"
|
||||
"the value of the attribute 'num_classes'. Got: ",
|
||||
scores_shape[1]);
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(op, im_info_shape.rank().compatible(2), "Input image info rank must be compatible with 2.");
|
||||
|
||||
if (rois_shape_rank_is_static && deltas_shape_rank_is_static && scores_shape_is_static) {
|
||||
const auto& num_batches_rois = rois_shape[0];
|
||||
const auto& num_batches_deltas = deltas_shape[0];
|
||||
const auto& num_batches_scores = scores_shape[0];
|
||||
auto merge_res = rois_shape[0];
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(merge_res, num_batches_rois, num_batches_deltas) &&
|
||||
DimType::merge(merge_res, merge_res, num_batches_scores),
|
||||
"The first dimension of inputs 'input_rois', 'input_deltas', "
|
||||
"'input_scores' must be the compatible. input_rois batch: ",
|
||||
num_batches_rois,
|
||||
"; input_deltas batch: ",
|
||||
num_batches_deltas,
|
||||
"; input_scores batch: ",
|
||||
num_batches_scores);
|
||||
}
|
||||
|
||||
const auto& rois_num = op->m_attrs.max_detections_per_image;
|
||||
|
||||
output_box_shape[0] = rois_num;
|
||||
output_box_shape[1] = 4;
|
||||
output_det_shape[0] = rois_num;
|
||||
output_score_shape[0] = rois_num;
|
||||
}
|
||||
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "ngraph/op/experimental_detectron_detection_output.hpp"
|
||||
|
||||
#include <experimental_detectron_detection_output_shape_inference.hpp>
|
||||
#include <memory>
|
||||
|
||||
#include "itt.hpp"
|
||||
@ -40,73 +41,20 @@ bool op::v6::ExperimentalDetectronDetectionOutput::visit_attributes(AttributeVis
|
||||
|
||||
void op::v6::ExperimentalDetectronDetectionOutput::validate_and_infer_types() {
|
||||
NGRAPH_OP_SCOPE(v6_ExperimentalDetectronDetectionOutput_validate_and_infer_types);
|
||||
size_t rois_num = m_attrs.max_detections_per_image;
|
||||
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}, ov::PartialShape{}};
|
||||
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
|
||||
get_input_partial_shape(1),
|
||||
get_input_partial_shape(2),
|
||||
get_input_partial_shape(3)};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
auto input_et = get_input_element_type(0);
|
||||
|
||||
auto rois_shape = get_input_partial_shape(0);
|
||||
auto deltas_shape = get_input_partial_shape(1);
|
||||
auto scores_shape = get_input_partial_shape(2);
|
||||
auto im_info_shape = get_input_partial_shape(3);
|
||||
|
||||
set_output_size(3);
|
||||
set_output_type(0, input_et, ov::Shape{rois_num, 4});
|
||||
set_output_type(1, element::Type_t::i32, ov::Shape{rois_num});
|
||||
set_output_type(2, input_et, ov::Shape{rois_num});
|
||||
|
||||
if (rois_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, rois_shape.rank().get_length() == 2, "Input rois rank must be equal to 2.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
rois_shape[1].is_dynamic() || rois_shape[1].get_length() == 4u,
|
||||
"The last dimension of the 'input_rois' input must be equal to 4. "
|
||||
"Got: ",
|
||||
rois_shape[1]);
|
||||
}
|
||||
|
||||
if (deltas_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, deltas_shape.rank().get_length() == 2, "Input deltas rank must be equal to 2.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
deltas_shape[1].is_dynamic() || deltas_shape[1].get_length() == m_attrs.num_classes * 4,
|
||||
"The last dimension of the 'input_deltas' input must be equal to "
|
||||
"the value of the attribute 'num_classes' * 4. Got: ",
|
||||
deltas_shape[1]);
|
||||
}
|
||||
|
||||
if (scores_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, scores_shape.rank().get_length() == 2, "Input scores rank must be equal to 2.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
scores_shape[1].is_dynamic() || scores_shape[1].get_length() == m_attrs.num_classes,
|
||||
"The last dimension of the 'input_scores' input must be equal to "
|
||||
"the value of the attribute 'num_classes'. Got: ",
|
||||
scores_shape[1]);
|
||||
}
|
||||
|
||||
if (im_info_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
im_info_shape.rank().get_length() == 2,
|
||||
"Input image info rank must be equal to 2.");
|
||||
}
|
||||
|
||||
if (rois_shape.rank().is_static() && deltas_shape.rank().is_static() && scores_shape.rank().is_static()) {
|
||||
const auto num_batches_rois = rois_shape[0];
|
||||
const auto num_batches_deltas = deltas_shape[0];
|
||||
const auto num_batches_scores = scores_shape[0];
|
||||
|
||||
if (num_batches_rois.is_static() && num_batches_deltas.is_static() && num_batches_scores.is_static()) {
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
num_batches_rois.same_scheme(num_batches_deltas) && num_batches_deltas.same_scheme(num_batches_scores),
|
||||
"The first dimension of inputs 'input_rois', 'input_deltas', "
|
||||
"'input_scores' must be the same. input_rois batch: ",
|
||||
num_batches_rois,
|
||||
"; input_deltas batch: ",
|
||||
num_batches_deltas,
|
||||
"; input_scores batch: ",
|
||||
num_batches_scores);
|
||||
}
|
||||
}
|
||||
set_output_type(0, input_et, output_shapes[0]);
|
||||
set_output_type(1, element::Type_t::i32, output_shapes[1]);
|
||||
set_output_type(2, input_et, output_shapes[2]);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v6::ExperimentalDetectronDetectionOutput::clone_with_new_inputs(
|
||||
|
Loading…
Reference in New Issue
Block a user