diff --git a/inference-engine/tests/unit/cpu/shape_inference_test.cpp b/inference-engine/tests/unit/cpu/shape_inference_test.cpp index b5bd9b1ea38..fea6360d231 100644 --- a/inference-engine/tests/unit/cpu/shape_inference_test.cpp +++ b/inference-engine/tests/unit/cpu/shape_inference_test.cpp @@ -2,13 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "utils/shape_inference/static_shape.hpp" +#include +#include #include #include #include -#include -#include #include -#include "utils/shape_inference/static_shape.hpp" +#include using namespace ov; @@ -40,6 +41,67 @@ TEST(StaticShapeInferenceTest, ConvolutionTest) { ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1})); } +TEST(StaticShapeInferenceTest, ExperimentalDetectronDetectionOutputTest) { + using Attrs = op::v6::ExperimentalDetectronDetectionOutput::Attributes; + Attrs attrs; + attrs.class_agnostic_box_regression = true; + attrs.deltas_weights = {10.0f, 10.0f, 5.0f, 5.0f}; + attrs.max_delta_log_wh = 2.0f; + attrs.max_detections_per_image = 5; + attrs.nms_threshold = 0.2f; + attrs.num_classes = 2; + attrs.post_nms_count = 500; + attrs.score_threshold = 0.01000000074505806f; + int64_t rois_num = static_cast(attrs.max_detections_per_image); + + auto rois = std::make_shared(element::f32, + PartialShape{-1, -1}); + auto deltas = std::make_shared(element::f32, + PartialShape{-1, -1}); + auto scores = std::make_shared(element::f32, + PartialShape{-1, -1}); + auto im_info = std::make_shared(element::f32, + PartialShape{-1, -1}); + + auto detection = + std::make_shared( + rois, deltas, scores, im_info, attrs); + std::vector input_shapes = { + PartialShape::dynamic(), PartialShape::dynamic(), PartialShape::dynamic(), + PartialShape::dynamic()}; + std::vector output_shapes = {PartialShape::dynamic(), + PartialShape::dynamic(), + PartialShape::dynamic()}; + shape_infer(detection.get(), input_shapes, output_shapes); + ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4})); + ASSERT_EQ(output_shapes[1], (PartialShape{rois_num})); + ASSERT_EQ(output_shapes[2], (PartialShape{rois_num})); + + input_shapes = {PartialShape{-1, -1}, PartialShape{-1, -1}, + PartialShape{-1, -1}, PartialShape{-1, -1}}; + output_shapes = {PartialShape{}, PartialShape{}, PartialShape{}}; + shape_infer(detection.get(), input_shapes, output_shapes); + ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4})); + ASSERT_EQ(output_shapes[1], (PartialShape{rois_num})); + ASSERT_EQ(output_shapes[2], (PartialShape{rois_num})); + + input_shapes = {PartialShape{16, 4}, PartialShape{16, 8}, PartialShape{16, 2}, PartialShape{1, 3}}; + output_shapes = {PartialShape{}, PartialShape{}, PartialShape{}}; + shape_infer(detection.get(), input_shapes, output_shapes); + ASSERT_EQ(output_shapes[0], (PartialShape{rois_num, 4})); + ASSERT_EQ(output_shapes[1], (PartialShape{rois_num})); + ASSERT_EQ(output_shapes[2], (PartialShape{rois_num})); + + std::vector static_input_shapes = {StaticShape{16, 4}, StaticShape{16, 8}, StaticShape{16, 2}, StaticShape{1, 3}}; + std::vector static_output_shapes = {StaticShape{}, StaticShape{}, + StaticShape{}}; + shape_infer(detection.get(), static_input_shapes, static_output_shapes); + ASSERT_EQ(static_output_shapes[0], + StaticShape({attrs.max_detections_per_image, 4})); + ASSERT_EQ(static_output_shapes[1], StaticShape({attrs.max_detections_per_image})); + ASSERT_EQ(static_output_shapes[2], StaticShape({attrs.max_detections_per_image})); +} + #if 0 TEST(StaticShapeInferenceTest, ConvolutionTimeTest) { Strides strides{1, 1}; diff --git a/ngraph/core/include/openvino/op/experimental_detectron_detection_output.hpp b/ngraph/core/include/openvino/op/experimental_detectron_detection_output.hpp index 0bd812cefaf..1668df170fb 100644 --- a/ngraph/core/include/openvino/op/experimental_detectron_detection_output.hpp +++ b/ngraph/core/include/openvino/op/experimental_detectron_detection_output.hpp @@ -69,6 +69,10 @@ public: private: Attributes m_attrs; + template + friend void shape_infer(const ExperimentalDetectronDetectionOutput* op, + const std::vector& input_shapes, + std::vector& output_shapes); }; } // namespace v6 } // namespace op diff --git a/ngraph/core/shape_inference/include/experimental_detectron_detection_output_shape_inference.hpp b/ngraph/core/shape_inference/include/experimental_detectron_detection_output_shape_inference.hpp new file mode 100644 index 00000000000..aec8bb3ef99 --- /dev/null +++ b/ngraph/core/shape_inference/include/experimental_detectron_detection_output_shape_inference.hpp @@ -0,0 +1,93 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +namespace ov { +namespace op { +namespace v6 { + +template +void shape_infer(const ExperimentalDetectronDetectionOutput* op, + const std::vector& input_shapes, + std::vector& output_shapes) { + using DimType = typename std::iterator_traits::value_type; + NODE_VALIDATION_CHECK(op, input_shapes.size() == 4 && output_shapes.size() == 3); + + const auto& rois_shape = input_shapes[0]; + const auto& deltas_shape = input_shapes[1]; + const auto& scores_shape = input_shapes[2]; + const auto& im_info_shape = input_shapes[3]; + + auto& output_box_shape = output_shapes[0]; + auto& output_det_shape = output_shapes[1]; + auto& output_score_shape = output_shapes[2]; + + output_box_shape.resize(2); + output_det_shape.resize(1); + output_score_shape.resize(1); + + const auto rois_shape_rank_is_static = rois_shape.rank().is_static(); + + if (rois_shape_rank_is_static) { + NODE_VALIDATION_CHECK(op, rois_shape.size() == 2, "Input rois rank must be equal to 2."); + + NODE_VALIDATION_CHECK(op, + rois_shape[1].compatible(4), + "The last dimension of the 'input_rois' input must be compatible with 4. " + "Got: ", + rois_shape[1]); + } + const auto deltas_shape_rank_is_static = deltas_shape.rank().is_static(); + if (deltas_shape_rank_is_static) { + NODE_VALIDATION_CHECK(op, deltas_shape.size() == 2, "Input deltas rank must be equal to 2."); + + NODE_VALIDATION_CHECK(op, + deltas_shape[1].compatible(op->m_attrs.num_classes * 4), + "The last dimension of the 'input_deltas' input be compatible with " + "the value of the attribute 'num_classes' * 4. Got: ", + deltas_shape[1]); + } + const auto scores_shape_is_static = scores_shape.rank().is_static(); + if (scores_shape_is_static) { + NODE_VALIDATION_CHECK(op, scores_shape.size() == 2, "Input scores rank must be equal to 2."); + + NODE_VALIDATION_CHECK(op, + scores_shape[1].compatible(op->m_attrs.num_classes), + "The last dimension of the 'input_scores' input must be compatible with" + "the value of the attribute 'num_classes'. Got: ", + scores_shape[1]); + } + + NODE_VALIDATION_CHECK(op, im_info_shape.rank().compatible(2), "Input image info rank must be compatible with 2."); + + if (rois_shape_rank_is_static && deltas_shape_rank_is_static && scores_shape_is_static) { + const auto& num_batches_rois = rois_shape[0]; + const auto& num_batches_deltas = deltas_shape[0]; + const auto& num_batches_scores = scores_shape[0]; + auto merge_res = rois_shape[0]; + + NODE_VALIDATION_CHECK(op, + DimType::merge(merge_res, num_batches_rois, num_batches_deltas) && + DimType::merge(merge_res, merge_res, num_batches_scores), + "The first dimension of inputs 'input_rois', 'input_deltas', " + "'input_scores' must be the compatible. input_rois batch: ", + num_batches_rois, + "; input_deltas batch: ", + num_batches_deltas, + "; input_scores batch: ", + num_batches_scores); + } + + const auto& rois_num = op->m_attrs.max_detections_per_image; + + output_box_shape[0] = rois_num; + output_box_shape[1] = 4; + output_det_shape[0] = rois_num; + output_score_shape[0] = rois_num; +} + +} // namespace v6 +} // namespace op +} // namespace ov diff --git a/ngraph/core/src/op/experimental_detectron_detection_output.cpp b/ngraph/core/src/op/experimental_detectron_detection_output.cpp index 19cbf2787cb..e08b9c0d783 100644 --- a/ngraph/core/src/op/experimental_detectron_detection_output.cpp +++ b/ngraph/core/src/op/experimental_detectron_detection_output.cpp @@ -4,6 +4,7 @@ #include "ngraph/op/experimental_detectron_detection_output.hpp" +#include #include #include "itt.hpp" @@ -40,73 +41,20 @@ bool op::v6::ExperimentalDetectronDetectionOutput::visit_attributes(AttributeVis void op::v6::ExperimentalDetectronDetectionOutput::validate_and_infer_types() { NGRAPH_OP_SCOPE(v6_ExperimentalDetectronDetectionOutput_validate_and_infer_types); - size_t rois_num = m_attrs.max_detections_per_image; + + std::vector output_shapes = {ov::PartialShape{}, ov::PartialShape{}, ov::PartialShape{}}; + std::vector input_shapes = {get_input_partial_shape(0), + get_input_partial_shape(1), + get_input_partial_shape(2), + get_input_partial_shape(3)}; + shape_infer(this, input_shapes, output_shapes); + auto input_et = get_input_element_type(0); - auto rois_shape = get_input_partial_shape(0); - auto deltas_shape = get_input_partial_shape(1); - auto scores_shape = get_input_partial_shape(2); - auto im_info_shape = get_input_partial_shape(3); - set_output_size(3); - set_output_type(0, input_et, ov::Shape{rois_num, 4}); - set_output_type(1, element::Type_t::i32, ov::Shape{rois_num}); - set_output_type(2, input_et, ov::Shape{rois_num}); - - if (rois_shape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, rois_shape.rank().get_length() == 2, "Input rois rank must be equal to 2."); - - NODE_VALIDATION_CHECK(this, - rois_shape[1].is_dynamic() || rois_shape[1].get_length() == 4u, - "The last dimension of the 'input_rois' input must be equal to 4. " - "Got: ", - rois_shape[1]); - } - - if (deltas_shape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, deltas_shape.rank().get_length() == 2, "Input deltas rank must be equal to 2."); - - NODE_VALIDATION_CHECK(this, - deltas_shape[1].is_dynamic() || deltas_shape[1].get_length() == m_attrs.num_classes * 4, - "The last dimension of the 'input_deltas' input must be equal to " - "the value of the attribute 'num_classes' * 4. Got: ", - deltas_shape[1]); - } - - if (scores_shape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, scores_shape.rank().get_length() == 2, "Input scores rank must be equal to 2."); - - NODE_VALIDATION_CHECK(this, - scores_shape[1].is_dynamic() || scores_shape[1].get_length() == m_attrs.num_classes, - "The last dimension of the 'input_scores' input must be equal to " - "the value of the attribute 'num_classes'. Got: ", - scores_shape[1]); - } - - if (im_info_shape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, - im_info_shape.rank().get_length() == 2, - "Input image info rank must be equal to 2."); - } - - if (rois_shape.rank().is_static() && deltas_shape.rank().is_static() && scores_shape.rank().is_static()) { - const auto num_batches_rois = rois_shape[0]; - const auto num_batches_deltas = deltas_shape[0]; - const auto num_batches_scores = scores_shape[0]; - - if (num_batches_rois.is_static() && num_batches_deltas.is_static() && num_batches_scores.is_static()) { - NODE_VALIDATION_CHECK( - this, - num_batches_rois.same_scheme(num_batches_deltas) && num_batches_deltas.same_scheme(num_batches_scores), - "The first dimension of inputs 'input_rois', 'input_deltas', " - "'input_scores' must be the same. input_rois batch: ", - num_batches_rois, - "; input_deltas batch: ", - num_batches_deltas, - "; input_scores batch: ", - num_batches_scores); - } - } + set_output_type(0, input_et, output_shapes[0]); + set_output_type(1, element::Type_t::i32, output_shapes[1]); + set_output_type(2, input_et, output_shapes[2]); } shared_ptr op::v6::ExperimentalDetectronDetectionOutput::clone_with_new_inputs(