[Shape infer] Implement ExperimentalDetectronTopKROIs shape infer (#7926)
* Implement ExperimentalDetectronROIs shape infer * Fix review comments * Update for review comments Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>
This commit is contained in:
parent
f3ca0a99a8
commit
0a850ce73e
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <experimental_detectron_topkrois_shape_inference.hpp>
|
||||
#include <openvino/core/coordinate_diff.hpp>
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
|
||||
#include "utils/shape_inference/static_shape.hpp"
|
||||
|
||||
using namespace ov;
|
||||
|
||||
TEST(StaticShapeInferenceTest, ExperimentalDetectronTopKROIsTest) {
|
||||
auto input_rois = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto rois_probs = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||
size_t max_rois = 5;
|
||||
|
||||
auto rois = std::make_shared<op::v6::ExperimentalDetectronTopKROIs>(input_rois, rois_probs, max_rois);
|
||||
|
||||
std::vector<PartialShape> input_shapes = {PartialShape{10, 4}, PartialShape{10}}, output_shapes = {PartialShape{}};
|
||||
shape_infer(rois.get(), input_shapes, output_shapes);
|
||||
|
||||
ASSERT_EQ(output_shapes[0], PartialShape({5, 4}));
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{10, 4}, StaticShape{10}}, static_output_shapes = {StaticShape{}};
|
||||
shape_infer(rois.get(), static_input_shapes, static_output_shapes);
|
||||
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({5, 4}));
|
||||
}
|
@ -40,6 +40,11 @@ public:
|
||||
|
||||
private:
|
||||
size_t m_max_rois;
|
||||
|
||||
template <class T>
|
||||
friend void shape_infer(ExperimentalDetectronTopKROIs* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes);
|
||||
};
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
|
@ -0,0 +1,51 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <openvino/op/experimental_detectron_topkrois.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v6 {
|
||||
|
||||
template <class T>
|
||||
void shape_infer(ExperimentalDetectronTopKROIs* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
|
||||
|
||||
const auto input_rois_shape = input_shapes[0];
|
||||
const auto rois_probs_shape = input_shapes[1];
|
||||
|
||||
if (input_rois_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_rois_shape.rank().get_length() == 2,
|
||||
"The 'input_rois' input is expected to be a 2D. Got: ",
|
||||
input_rois_shape);
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_rois_shape[1].compatible(4),
|
||||
"The second dimension of 'input_rois' should be 4. Got: ",
|
||||
input_rois_shape[1]);
|
||||
}
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
rois_probs_shape.rank().compatible(1),
|
||||
"The 'rois_probs' input is expected to be a 1D. Got: ",
|
||||
rois_probs_shape);
|
||||
|
||||
if (input_rois_shape.rank().is_static() && rois_probs_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_rois_shape[0].compatible(rois_probs_shape[0]),
|
||||
"Number of rois and number of probabilities should be equal. Got: ",
|
||||
input_rois_shape[0],
|
||||
rois_probs_shape[0]);
|
||||
}
|
||||
|
||||
auto& output_shape = output_shapes[0];
|
||||
auto max_rois = op->m_max_rois;
|
||||
|
||||
output_shape.resize(2);
|
||||
output_shape[0] = max_rois;
|
||||
output_shape[1] = 4;
|
||||
}
|
||||
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -4,6 +4,8 @@
|
||||
|
||||
#include "ngraph/op/experimental_detectron_topkrois.hpp"
|
||||
|
||||
#include <experimental_detectron_topkrois_shape_inference.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
@ -39,31 +41,8 @@ void op::v6::ExperimentalDetectronTopKROIs::validate_and_infer_types() {
|
||||
const auto input_rois_shape = get_input_partial_shape(0);
|
||||
const auto rois_probs_shape = get_input_partial_shape(1);
|
||||
|
||||
set_output_type(0, get_input_element_type(0), ov::Shape{m_max_rois, 4});
|
||||
|
||||
if (input_rois_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rois_shape.rank().get_length() == 2,
|
||||
"The 'input_rois' input is expected to be a 2D. Got: ",
|
||||
input_rois_shape);
|
||||
if (input_rois_shape.is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rois_shape[1] == 4,
|
||||
"The second dimension of 'input_rois' should be 4. Got: ",
|
||||
input_rois_shape[1]);
|
||||
}
|
||||
}
|
||||
if (rois_probs_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
rois_probs_shape.rank().get_length() == 1,
|
||||
"The 'rois_probs' input is expected to be a 1D. Got: ",
|
||||
rois_probs_shape);
|
||||
}
|
||||
if (input_rois_shape.rank().is_static() && rois_probs_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rois_shape[0] == rois_probs_shape[0],
|
||||
"Number of rois and number of probabilities should be equal. Got: ",
|
||||
input_rois_shape[0],
|
||||
rois_probs_shape[0]);
|
||||
}
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||
std::vector<ov::PartialShape> input_shapes = {input_rois_shape, rois_probs_shape};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
set_output_type(0, get_input_element_type(0), output_shapes[0]);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user