[Shape infer] Implement ExperimentalDetectronTopKROIs shape infer (#7926)

* Implement ExperimentalDetectronROIs shape infer

* Fix review comments

* Update for review comments

Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>
This commit is contained in:
Mang Guo 2021-11-02 10:32:46 +08:00 committed by GitHub
parent f3ca0a99a8
commit 0a850ce73e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 27 deletions

View File

@ -0,0 +1,32 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <experimental_detectron_topkrois_shape_inference.hpp>
#include <openvino/core/coordinate_diff.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include "utils/shape_inference/static_shape.hpp"
using namespace ov;
TEST(StaticShapeInferenceTest, ExperimentalDetectronTopKROIsTest) {
auto input_rois = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto rois_probs = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
size_t max_rois = 5;
auto rois = std::make_shared<op::v6::ExperimentalDetectronTopKROIs>(input_rois, rois_probs, max_rois);
std::vector<PartialShape> input_shapes = {PartialShape{10, 4}, PartialShape{10}}, output_shapes = {PartialShape{}};
shape_infer(rois.get(), input_shapes, output_shapes);
ASSERT_EQ(output_shapes[0], PartialShape({5, 4}));
std::vector<StaticShape> static_input_shapes = {StaticShape{10, 4}, StaticShape{10}}, static_output_shapes = {StaticShape{}};
shape_infer(rois.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({5, 4}));
}

View File

@ -40,6 +40,11 @@ public:
private:
size_t m_max_rois;
template <class T>
friend void shape_infer(ExperimentalDetectronTopKROIs* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
};
} // namespace v6
} // namespace op

View File

@ -0,0 +1,51 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/op/experimental_detectron_topkrois.hpp>
namespace ov {
namespace op {
namespace v6 {
template <class T>
void shape_infer(ExperimentalDetectronTopKROIs* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
const auto input_rois_shape = input_shapes[0];
const auto rois_probs_shape = input_shapes[1];
if (input_rois_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(op,
input_rois_shape.rank().get_length() == 2,
"The 'input_rois' input is expected to be a 2D. Got: ",
input_rois_shape);
NODE_VALIDATION_CHECK(op,
input_rois_shape[1].compatible(4),
"The second dimension of 'input_rois' should be 4. Got: ",
input_rois_shape[1]);
}
NODE_VALIDATION_CHECK(op,
rois_probs_shape.rank().compatible(1),
"The 'rois_probs' input is expected to be a 1D. Got: ",
rois_probs_shape);
if (input_rois_shape.rank().is_static() && rois_probs_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(op,
input_rois_shape[0].compatible(rois_probs_shape[0]),
"Number of rois and number of probabilities should be equal. Got: ",
input_rois_shape[0],
rois_probs_shape[0]);
}
auto& output_shape = output_shapes[0];
auto max_rois = op->m_max_rois;
output_shape.resize(2);
output_shape[0] = max_rois;
output_shape[1] = 4;
}
} // namespace v6
} // namespace op
} // namespace ov

View File

@ -4,6 +4,8 @@
#include "ngraph/op/experimental_detectron_topkrois.hpp"
#include <experimental_detectron_topkrois_shape_inference.hpp>
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/op_types.hpp"
@ -39,31 +41,8 @@ void op::v6::ExperimentalDetectronTopKROIs::validate_and_infer_types() {
const auto input_rois_shape = get_input_partial_shape(0);
const auto rois_probs_shape = get_input_partial_shape(1);
set_output_type(0, get_input_element_type(0), ov::Shape{m_max_rois, 4});
if (input_rois_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
input_rois_shape.rank().get_length() == 2,
"The 'input_rois' input is expected to be a 2D. Got: ",
input_rois_shape);
if (input_rois_shape.is_static()) {
NODE_VALIDATION_CHECK(this,
input_rois_shape[1] == 4,
"The second dimension of 'input_rois' should be 4. Got: ",
input_rois_shape[1]);
}
}
if (rois_probs_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
rois_probs_shape.rank().get_length() == 1,
"The 'rois_probs' input is expected to be a 1D. Got: ",
rois_probs_shape);
}
if (input_rois_shape.rank().is_static() && rois_probs_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
input_rois_shape[0] == rois_probs_shape[0],
"Number of rois and number of probabilities should be equal. Got: ",
input_rois_shape[0],
rois_probs_shape[0]);
}
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
std::vector<ov::PartialShape> input_shapes = {input_rois_shape, rois_probs_shape};
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}