[ShapeInference] DeformablePSROIPooling shape infer (#15766)
* Add shape infer function * Update shape_infer and usage * Add setters * Register shape_infer for CPU * Tests * Style * Add cast for dim type * Add precision * Update input size check * Move setters to cpp
This commit is contained in:
parent
63d282fd73
commit
a7bb54da2d
@ -70,9 +70,11 @@ public:
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
void set_output_dim(int64_t output_dim);
|
||||
int64_t get_output_dim() const {
|
||||
return m_output_dim;
|
||||
}
|
||||
void set_group_size(int64_t group_size);
|
||||
int64_t get_group_size() const {
|
||||
return m_group_size;
|
||||
}
|
||||
|
@ -0,0 +1,59 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
#include "openvino/op/deformable_psroi_pooling.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v1 {
|
||||
|
||||
template <class TShape>
|
||||
std::vector<TShape> shape_infer(const DeformablePSROIPooling* op, const std::vector<TShape>& input_shapes) {
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 || input_shapes.size() == 3);
|
||||
|
||||
const auto& input_pshape = input_shapes[0];
|
||||
const auto& box_coords_pshape = input_shapes[1];
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_pshape.rank().compatible(4),
|
||||
"First input rank must be compatible with 4 (input rank: ",
|
||||
input_pshape.rank(),
|
||||
")");
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
box_coords_pshape.rank().compatible(2),
|
||||
"Second input rank must be compatible with 2 (input rank: ",
|
||||
box_coords_pshape.rank(),
|
||||
")");
|
||||
|
||||
if (input_shapes.size() == 3) // offsets input is provided
|
||||
{
|
||||
const auto& offsets_shape = input_shapes[2];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
offsets_shape.rank().compatible(4),
|
||||
"Third input rank must be compatible with 4 (input rank: ",
|
||||
offsets_shape.rank(),
|
||||
")");
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(op, op->get_output_dim() > 0, "Value of `output_dim` attribute has to be greater than 0 ");
|
||||
NODE_VALIDATION_CHECK(op, op->get_group_size() > 0, "Value of `group_size` attribute has to be greater than 0 ");
|
||||
|
||||
using DimType = typename TShape::value_type;
|
||||
using DimTypeVal = typename DimType::value_type;
|
||||
// The output shape: [num_rois, output_dim, group_size, group_size]
|
||||
return {TShape{box_coords_pshape.rank().is_static() ? box_coords_pshape[0] : DimType{},
|
||||
static_cast<DimTypeVal>(op->get_output_dim()),
|
||||
static_cast<DimTypeVal>(op->get_group_size()),
|
||||
static_cast<DimTypeVal>(op->get_group_size())}};
|
||||
}
|
||||
|
||||
template <class TShape>
|
||||
void shape_infer(const DeformablePSROIPooling* op,
|
||||
const std::vector<TShape>& input_shapes,
|
||||
std::vector<TShape>& output_shapes) {
|
||||
output_shapes = shape_infer(op, input_shapes);
|
||||
}
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -4,7 +4,9 @@
|
||||
|
||||
#include "ngraph/op/deformable_psroi_pooling.hpp"
|
||||
|
||||
#include "deformable_psroi_pooling_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -70,46 +72,8 @@ bool op::v1::DeformablePSROIPooling::visit_attributes(AttributeVisitor& visitor)
|
||||
void op::v1::DeformablePSROIPooling::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v1_DeformablePSROIPooling_validate_and_infer_types);
|
||||
const auto& input_et = get_input_element_type(0);
|
||||
|
||||
const auto& input_pshape = get_input_partial_shape(0);
|
||||
const auto& box_coords_pshape = get_input_partial_shape(1);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_pshape.rank().compatible(4),
|
||||
"First input rank must be compatible with 4 (input rank: ",
|
||||
input_pshape.rank(),
|
||||
")");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
box_coords_pshape.rank().compatible(2),
|
||||
"Second input rank must be compatible with 2 (input rank: ",
|
||||
box_coords_pshape.rank(),
|
||||
")");
|
||||
|
||||
if (get_input_size() == 3) // offsets input is provided
|
||||
{
|
||||
const auto& offsets_pshape = get_input_partial_shape(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
offsets_pshape.rank().compatible(4),
|
||||
"Third input rank must be compatible with 4 (input rank: ",
|
||||
offsets_pshape.rank(),
|
||||
")");
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(this, m_group_size > 0, "Value of `group_size` attribute has to be greater than 0 ");
|
||||
|
||||
NODE_VALIDATION_CHECK(this, m_output_dim > 0, "Value of `output_dim` attribute has to be greater than 0 ");
|
||||
|
||||
int64_t output_rank = 4;
|
||||
std::vector<Dimension> output_dim_vec(output_rank, Dimension::dynamic());
|
||||
if (box_coords_pshape.rank().is_static()) {
|
||||
output_dim_vec[0] = box_coords_pshape[0]; // Number of ROIs
|
||||
}
|
||||
output_dim_vec[1] = m_output_dim;
|
||||
for (int i = 2; i < output_rank; ++i) {
|
||||
output_dim_vec[i] = m_group_size;
|
||||
}
|
||||
|
||||
set_output_type(0, input_et, ov::PartialShape(output_dim_vec));
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
set_output_type(0, input_et, shape_infer(this, input_shapes)[0]);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::DeformablePSROIPooling::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
@ -142,3 +106,11 @@ shared_ptr<Node> op::v1::DeformablePSROIPooling::clone_with_new_inputs(const Out
|
||||
throw ngraph_error("Not supported number of DeformablePSROIPooling args");
|
||||
}
|
||||
}
|
||||
|
||||
void op::v1::DeformablePSROIPooling::set_output_dim(int64_t output_dim) {
|
||||
m_output_dim = output_dim;
|
||||
}
|
||||
|
||||
void op::v1::DeformablePSROIPooling::set_group_size(int64_t group_size) {
|
||||
m_group_size = group_size;
|
||||
}
|
||||
|
@ -8,6 +8,50 @@
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
TEST(type_prop, deformable_psroi_pooling_default_ctor) {
|
||||
const int64_t output_dim = 48;
|
||||
const int64_t group_size = 2;
|
||||
|
||||
const auto rois_dim = 30;
|
||||
|
||||
auto input_data = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 64, 56});
|
||||
auto input_coords = make_shared<op::Parameter>(element::f32, PartialShape{rois_dim, 5});
|
||||
|
||||
auto op = make_shared<op::v1::DeformablePSROIPooling>();
|
||||
|
||||
op->set_arguments(OutputVector{input_data, input_coords});
|
||||
op->set_output_dim(output_dim);
|
||||
op->set_group_size(group_size);
|
||||
|
||||
op->validate_and_infer_types();
|
||||
|
||||
const PartialShape expected_output{rois_dim, output_dim, group_size, group_size};
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), expected_output);
|
||||
}
|
||||
|
||||
TEST(type_prop, deformable_psroi_pooling_interval_labels) {
|
||||
const float spatial_scale = 0.05f;
|
||||
const int64_t output_dim = 48;
|
||||
const int64_t group_size = 2;
|
||||
|
||||
const auto rois_dim = Dimension(15, 30);
|
||||
|
||||
auto input_data = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 64, 56});
|
||||
|
||||
auto coords_shape = PartialShape{rois_dim, 5};
|
||||
set_shape_labels(coords_shape, 20);
|
||||
auto input_coords = make_shared<op::Parameter>(element::f32, coords_shape);
|
||||
|
||||
auto op =
|
||||
make_shared<op::v1::DeformablePSROIPooling>(input_data, input_coords, output_dim, spatial_scale, group_size);
|
||||
|
||||
const PartialShape expected_output{rois_dim, output_dim, group_size, group_size};
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), expected_output);
|
||||
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)),
|
||||
ElementsAre(20, ov::no_label, ov::no_label, ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, deformable_psroi_pooling_no_offsets_group_size_3) {
|
||||
const float spatial_scale = 0.0625;
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "ctc_greedy_decoder_seq_len_shape_inference.hpp"
|
||||
#include "ctc_greedy_decoder_shape_inference.hpp"
|
||||
#include "ctc_loss_shape_inference.hpp"
|
||||
#include "deformable_psroi_pooling_shape_inference.hpp"
|
||||
#include "depth_to_space_shape_inference.hpp"
|
||||
#include "detection_output_shape_inference.hpp"
|
||||
#include "einsum_shape_inference.hpp"
|
||||
@ -530,6 +531,7 @@ const IShapeInferCommonFactory::TRegistry IShapeInferCommonFactory::registry{
|
||||
_OV_OP_SHAPE_INFER_REG(CTCGreedyDecoderSeqLen, entryIO),
|
||||
_OV_OP_SHAPE_INFER_REG(CTCLoss, entryIO),
|
||||
_OV_OP_SHAPE_INFER_REG(DeformableConvolution, entryFallbackWithPadding),
|
||||
_OV_OP_SHAPE_INFER_REG(DeformablePSROIPooling, entryIO),
|
||||
_OV_OP_SHAPE_INFER_REG(DepthToSpace, entryIO),
|
||||
_OV_OP_SHAPE_INFER_REG(DetectionOutput, entryIO),
|
||||
_OV_OP_SHAPE_INFER_REG(DFT, entryIOC),
|
||||
|
@ -0,0 +1,87 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <array>
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "gmock/gmock.h"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace ov::opset10;
|
||||
using namespace testing;
|
||||
|
||||
class DeformablePSROIPoolingV1StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v1::DeformablePSROIPooling> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
output_shapes.resize(1);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(DeformablePSROIPoolingV1StaticShapeInferenceTest, default_ctor) {
|
||||
const auto op = make_op();
|
||||
|
||||
const int64_t output_dim = 88;
|
||||
const int64_t group_size = 2;
|
||||
|
||||
const auto rois_dim = 30;
|
||||
|
||||
op->set_output_dim(output_dim);
|
||||
op->set_group_size(group_size);
|
||||
|
||||
auto expected_output = StaticShape{rois_dim, output_dim, group_size, group_size};
|
||||
|
||||
// 2 inputs
|
||||
{
|
||||
input_shapes = {StaticShape{2, 4, 8, 6}, StaticShape{rois_dim, 5}};
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], expected_output);
|
||||
}
|
||||
// 3 inputs
|
||||
{
|
||||
input_shapes = {StaticShape{2, 4, 8, 6}, StaticShape{rois_dim, 5}, StaticShape{rois_dim, 20, group_size, group_size}};
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], expected_output);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DeformablePSROIPoolingV1StaticShapeInferenceTest, no_offsets_input) {
|
||||
const float spatial_scale = 0.05f;
|
||||
const int64_t output_dim = 88;
|
||||
const int64_t group_size = 2;
|
||||
|
||||
const auto rois_dim = 30;
|
||||
|
||||
auto input_data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto input_coords = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
|
||||
auto op = make_op(input_data, input_coords, output_dim, spatial_scale, group_size);
|
||||
|
||||
StaticShape expected_output{rois_dim, output_dim, group_size, group_size};
|
||||
input_shapes = {StaticShape{2, 4, 8, 6}, StaticShape{rois_dim, 5}};
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], expected_output);
|
||||
}
|
||||
|
||||
TEST_F(DeformablePSROIPoolingV1StaticShapeInferenceTest, offsets_input) {
|
||||
const float spatial_scale = 0.05f;
|
||||
const int64_t output_dim = 88;
|
||||
const int64_t group_size = 2;
|
||||
|
||||
const auto rois_dim = 30;
|
||||
|
||||
auto input_data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto input_coords = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto input_offsets = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
|
||||
auto op = make_op(input_data, input_coords, input_offsets, output_dim, spatial_scale, group_size);
|
||||
|
||||
StaticShape expected_output{rois_dim, output_dim, group_size, group_size};
|
||||
input_shapes = {StaticShape{2, 4, 8, 6}, StaticShape{rois_dim, 5}, StaticShape{rois_dim, 20, group_size, group_size}};
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], expected_output);
|
||||
}
|
Loading…
Reference in New Issue
Block a user