[shape_infer]Implement ExperimentalDetectronDetectionOutput shape infer. (#7903)

* Implement ExperimentalDetectronDetectionOutput shape infer.

Signed-off-by: Luwei Zhou <luwei.zhou@intel.com>

* Update on the review comments.

* Update based on review
This commit is contained in:
Luwei Zhou 2021-10-21 14:15:16 +08:00 committed by GitHub
parent 83fe59bd3e
commit c2512f8dc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 174 additions and 67 deletions

View File

@ -2,13 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "utils/shape_inference/static_shape.hpp"
#include <convolution_shape_inference.hpp>
#include <experimental_detectron_detection_output_shape_inference.hpp>
#include <gtest/gtest.h>
#include <openvino/core/coordinate_diff.hpp>
#include <openvino/op/convolution.hpp>
#include <openvino/op/parameter.hpp>
#include <convolution_shape_inference.hpp>
#include <openvino/op/ops.hpp>
#include "utils/shape_inference/static_shape.hpp"
#include <openvino/op/parameter.hpp>
using namespace ov;
@ -40,6 +41,67 @@ TEST(StaticShapeInferenceTest, ConvolutionTest) {
ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1}));
}
TEST(StaticShapeInferenceTest, ExperimentalDetectronDetectionOutputTest) {
using Attrs = op::v6::ExperimentalDetectronDetectionOutput::Attributes;
Attrs attrs;
attrs.class_agnostic_box_regression = true;
attrs.deltas_weights = {10.0f, 10.0f, 5.0f, 5.0f};
attrs.max_delta_log_wh = 2.0f;
attrs.max_detections_per_image = 5;
attrs.nms_threshold = 0.2f;
attrs.num_classes = 2;
attrs.post_nms_count = 500;
attrs.score_threshold = 0.01000000074505806f;
int64_t rois_num = static_cast<int64_t>(attrs.max_detections_per_image);
auto rois = std::make_shared<ov::op::v0::Parameter>(element::f32,
PartialShape{-1, -1});
auto deltas = std::make_shared<ov::op::v0::Parameter>(element::f32,
PartialShape{-1, -1});
auto scores = std::make_shared<ov::op::v0::Parameter>(element::f32,
PartialShape{-1, -1});
auto im_info = std::make_shared<ov::op::v0::Parameter>(element::f32,
PartialShape{-1, -1});
auto detection =
std::make_shared<ov::op::v6::ExperimentalDetectronDetectionOutput>(
rois, deltas, scores, im_info, attrs);
std::vector<PartialShape> input_shapes = {
PartialShape::dynamic(), PartialShape::dynamic(), PartialShape::dynamic(),
PartialShape::dynamic()};
std::vector<PartialShape> output_shapes = {PartialShape::dynamic(),
PartialShape::dynamic(),
PartialShape::dynamic()};
shape_infer(detection.get(), input_shapes, output_shapes);
ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4}));
ASSERT_EQ(output_shapes[1], (PartialShape{rois_num}));
ASSERT_EQ(output_shapes[2], (PartialShape{rois_num}));
input_shapes = {PartialShape{-1, -1}, PartialShape{-1, -1},
PartialShape{-1, -1}, PartialShape{-1, -1}};
output_shapes = {PartialShape{}, PartialShape{}, PartialShape{}};
shape_infer(detection.get(), input_shapes, output_shapes);
ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4}));
ASSERT_EQ(output_shapes[1], (PartialShape{rois_num}));
ASSERT_EQ(output_shapes[2], (PartialShape{rois_num}));
input_shapes = {PartialShape{16, 4}, PartialShape{16, 8}, PartialShape{16, 2}, PartialShape{1, 3}};
output_shapes = {PartialShape{}, PartialShape{}, PartialShape{}};
shape_infer(detection.get(), input_shapes, output_shapes);
ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4}));
ASSERT_EQ(output_shapes[1], (PartialShape{rois_num}));
ASSERT_EQ(output_shapes[2], (PartialShape{rois_num}));
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 4}, StaticShape{16, 8}, StaticShape{16, 2}, StaticShape{1, 3}};
std::vector<StaticShape> static_output_shapes = {StaticShape{}, StaticShape{},
StaticShape{}};
shape_infer(detection.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0],
StaticShape({attrs.max_detections_per_image, 4}));
ASSERT_EQ(static_output_shapes[1], StaticShape({attrs.max_detections_per_image}));
ASSERT_EQ(static_output_shapes[2], StaticShape({attrs.max_detections_per_image}));
}
#if 0
TEST(StaticShapeInferenceTest, ConvolutionTimeTest) {
Strides strides{1, 1};

View File

@ -69,6 +69,10 @@ public:
private:
Attributes m_attrs;
template <class T>
friend void shape_infer(const ExperimentalDetectronDetectionOutput* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
};
} // namespace v6
} // namespace op

View File

@ -0,0 +1,93 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/op/experimental_detectron_detection_output.hpp>
namespace ov {
namespace op {
namespace v6 {
template <class T>
void shape_infer(const ExperimentalDetectronDetectionOutput* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4 && output_shapes.size() == 3);
const auto& rois_shape = input_shapes[0];
const auto& deltas_shape = input_shapes[1];
const auto& scores_shape = input_shapes[2];
const auto& im_info_shape = input_shapes[3];
auto& output_box_shape = output_shapes[0];
auto& output_det_shape = output_shapes[1];
auto& output_score_shape = output_shapes[2];
output_box_shape.resize(2);
output_det_shape.resize(1);
output_score_shape.resize(1);
const auto rois_shape_rank_is_static = rois_shape.rank().is_static();
if (rois_shape_rank_is_static) {
NODE_VALIDATION_CHECK(op, rois_shape.size() == 2, "Input rois rank must be equal to 2.");
NODE_VALIDATION_CHECK(op,
rois_shape[1].compatible(4),
"The last dimension of the 'input_rois' input must be compatible with 4. "
"Got: ",
rois_shape[1]);
}
const auto deltas_shape_rank_is_static = deltas_shape.rank().is_static();
if (deltas_shape_rank_is_static) {
NODE_VALIDATION_CHECK(op, deltas_shape.size() == 2, "Input deltas rank must be equal to 2.");
NODE_VALIDATION_CHECK(op,
deltas_shape[1].compatible(op->m_attrs.num_classes * 4),
"The last dimension of the 'input_deltas' input be compatible with "
"the value of the attribute 'num_classes' * 4. Got: ",
deltas_shape[1]);
}
const auto scores_shape_is_static = scores_shape.rank().is_static();
if (scores_shape_is_static) {
NODE_VALIDATION_CHECK(op, scores_shape.size() == 2, "Input scores rank must be equal to 2.");
NODE_VALIDATION_CHECK(op,
scores_shape[1].compatible(op->m_attrs.num_classes),
"The last dimension of the 'input_scores' input must be compatible with"
"the value of the attribute 'num_classes'. Got: ",
scores_shape[1]);
}
NODE_VALIDATION_CHECK(op, im_info_shape.rank().compatible(2), "Input image info rank must be compatible with 2.");
if (rois_shape_rank_is_static && deltas_shape_rank_is_static && scores_shape_is_static) {
const auto& num_batches_rois = rois_shape[0];
const auto& num_batches_deltas = deltas_shape[0];
const auto& num_batches_scores = scores_shape[0];
auto merge_res = rois_shape[0];
NODE_VALIDATION_CHECK(op,
DimType::merge(merge_res, num_batches_rois, num_batches_deltas) &&
DimType::merge(merge_res, merge_res, num_batches_scores),
"The first dimension of inputs 'input_rois', 'input_deltas', "
"'input_scores' must be the compatible. input_rois batch: ",
num_batches_rois,
"; input_deltas batch: ",
num_batches_deltas,
"; input_scores batch: ",
num_batches_scores);
}
const auto& rois_num = op->m_attrs.max_detections_per_image;
output_box_shape[0] = rois_num;
output_box_shape[1] = 4;
output_det_shape[0] = rois_num;
output_score_shape[0] = rois_num;
}
} // namespace v6
} // namespace op
} // namespace ov

View File

@ -4,6 +4,7 @@
#include "ngraph/op/experimental_detectron_detection_output.hpp"
#include <experimental_detectron_detection_output_shape_inference.hpp>
#include <memory>
#include "itt.hpp"
@ -40,73 +41,20 @@ bool op::v6::ExperimentalDetectronDetectionOutput::visit_attributes(AttributeVis
void op::v6::ExperimentalDetectronDetectionOutput::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v6_ExperimentalDetectronDetectionOutput_validate_and_infer_types);
size_t rois_num = m_attrs.max_detections_per_image;
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}, ov::PartialShape{}};
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
get_input_partial_shape(1),
get_input_partial_shape(2),
get_input_partial_shape(3)};
shape_infer(this, input_shapes, output_shapes);
auto input_et = get_input_element_type(0);
auto rois_shape = get_input_partial_shape(0);
auto deltas_shape = get_input_partial_shape(1);
auto scores_shape = get_input_partial_shape(2);
auto im_info_shape = get_input_partial_shape(3);
set_output_size(3);
set_output_type(0, input_et, ov::Shape{rois_num, 4});
set_output_type(1, element::Type_t::i32, ov::Shape{rois_num});
set_output_type(2, input_et, ov::Shape{rois_num});
if (rois_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this, rois_shape.rank().get_length() == 2, "Input rois rank must be equal to 2.");
NODE_VALIDATION_CHECK(this,
rois_shape[1].is_dynamic() || rois_shape[1].get_length() == 4u,
"The last dimension of the 'input_rois' input must be equal to 4. "
"Got: ",
rois_shape[1]);
}
if (deltas_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this, deltas_shape.rank().get_length() == 2, "Input deltas rank must be equal to 2.");
NODE_VALIDATION_CHECK(this,
deltas_shape[1].is_dynamic() || deltas_shape[1].get_length() == m_attrs.num_classes * 4,
"The last dimension of the 'input_deltas' input must be equal to "
"the value of the attribute 'num_classes' * 4. Got: ",
deltas_shape[1]);
}
if (scores_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this, scores_shape.rank().get_length() == 2, "Input scores rank must be equal to 2.");
NODE_VALIDATION_CHECK(this,
scores_shape[1].is_dynamic() || scores_shape[1].get_length() == m_attrs.num_classes,
"The last dimension of the 'input_scores' input must be equal to "
"the value of the attribute 'num_classes'. Got: ",
scores_shape[1]);
}
if (im_info_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
im_info_shape.rank().get_length() == 2,
"Input image info rank must be equal to 2.");
}
if (rois_shape.rank().is_static() && deltas_shape.rank().is_static() && scores_shape.rank().is_static()) {
const auto num_batches_rois = rois_shape[0];
const auto num_batches_deltas = deltas_shape[0];
const auto num_batches_scores = scores_shape[0];
if (num_batches_rois.is_static() && num_batches_deltas.is_static() && num_batches_scores.is_static()) {
NODE_VALIDATION_CHECK(
this,
num_batches_rois.same_scheme(num_batches_deltas) && num_batches_deltas.same_scheme(num_batches_scores),
"The first dimension of inputs 'input_rois', 'input_deltas', "
"'input_scores' must be the same. input_rois batch: ",
num_batches_rois,
"; input_deltas batch: ",
num_batches_deltas,
"; input_scores batch: ",
num_batches_scores);
}
}
set_output_type(0, input_et, output_shapes[0]);
set_output_type(1, element::Type_t::i32, output_shapes[1]);
set_output_type(2, input_et, output_shapes[2]);
}
shared_ptr<Node> op::v6::ExperimentalDetectronDetectionOutput::clone_with_new_inputs(