From 305f0056059b091e0290b983dabd93f655e86e8d Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Tue, 8 Dec 2020 04:35:52 +0100 Subject: [PATCH] Add reference implementation for PSROIPooling operator (#3245) * Add reference implementation for PSROIPooling operator * fix test_roi_pooling * use std::roundf * remove unnecessary copies in single layer tets * Fixes after review * fixes after review * use element::Type_t instead of element:: * apply code format * add PSROIPooling to evaluates_map * apply code format --- docs/ops/detection/PSROIPooling_1.md | 13 +- .../src/mkldnn_plugin/nodes/psroi.cpp | 8 +- .../single_layer_tests/psroi_pooling.cpp | 43 +++ .../single_layer_tests/psroi_pooling.hpp | 48 +++ .../src/single_layer_tests/psroi_pooling.cpp | 146 ++++++++ .../runtime/reference/psroi_pooling.hpp | 206 +++++++++++ ngraph/core/src/op/psroi_pooling.cpp | 90 +++-- .../tests/test_ngraph/test_create_op.py | 4 +- ngraph/test/CMakeLists.txt | 2 + ngraph/test/attributes.cpp | 6 +- ngraph/test/backend/psroi_pooling.in.cpp | 234 +++++++++++++ .../runtime/interpreter/evaluates_map.cpp | 23 +- .../runtime/interpreter/opset_int_tbl.hpp | 1 + ngraph/test/type_prop/psroi_pooling.cpp | 323 ++++++++++++++++++ ngraph/test/type_prop_layers.cpp | 9 - 15 files changed, 1114 insertions(+), 42 deletions(-) create mode 100644 inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/psroi_pooling.cpp create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/psroi_pooling.hpp create mode 100644 inference-engine/tests/functional/plugin/shared/src/single_layer_tests/psroi_pooling.cpp create mode 100644 ngraph/core/reference/include/ngraph/runtime/reference/psroi_pooling.hpp create mode 100644 ngraph/test/backend/psroi_pooling.in.cpp create mode 100644 ngraph/test/type_prop/psroi_pooling.cpp diff --git a/docs/ops/detection/PSROIPooling_1.md b/docs/ops/detection/PSROIPooling_1.md index ae82d0f93dc..98841ccf4dc 100644 --- a/docs/ops/detection/PSROIPooling_1.md +++ b/docs/ops/detection/PSROIPooling_1.md @@ -24,7 +24,7 @@ ROIs coordinates are specified in absolute values for the average mode and in no * *group_size* - * **Description**: *group_size* is the number of groups to encode position-sensitive score maps. Use for *average* mode only. + * **Description**: *group_size* is the number of groups to encode position-sensitive score maps. * **Range of values**: a positive integer * **Type**: `int` * **Default value**: 1 @@ -63,14 +63,19 @@ ROIs coordinates are specified in absolute values for the average mode and in no **Inputs**: -* **1**: 4D input blob with feature maps. Required. +* **1**: 4D input tensor with shape `[N, C, H, W]` and type *T* with feature maps. Required. -* **2**: 2D input blob describing box consisting of five element tuples: `[batch_id, x_1, y_1, x_2, y_2]`. Required. +* **2**: 2D input tensor with shape `[num_boxes, 5]`. It contains a list of five element tuples that describe a region of interest: `[batch_id, x_1, y_1, x_2, y_2]`. Required. +Batch indices must be in the range of `[0, N-1]`. **Outputs**: * **1**: 4D output tensor with areas copied and interpolated from the 1st input tensor by coordinates of boxes from the 2nd input. +**Types** + +* *T*: any supported floating point type. + **Example** ```xml @@ -97,4 +102,4 @@ ROIs coordinates are specified in absolute values for the average mode and in no -``` \ No newline at end of file +``` diff --git a/inference-engine/src/mkldnn_plugin/nodes/psroi.cpp b/inference-engine/src/mkldnn_plugin/nodes/psroi.cpp index 7b03df16e83..7e0d7709b25 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/psroi.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/psroi.cpp @@ -102,10 +102,10 @@ public: roi_width = roi_end_w - roi_start_w; roi_height = roi_end_h - roi_start_h; } else if (mode_ == "average") { - roi_start_w = static_cast(round(bottom_rois[1])) * spatial_scale_; - roi_start_h = static_cast(round(bottom_rois[2])) * spatial_scale_; - roi_end_w = static_cast(round(bottom_rois[3]) + 1.0f) * spatial_scale_; - roi_end_h = static_cast(round(bottom_rois[4]) + 1.0f) * spatial_scale_; + roi_start_w = round(bottom_rois[1] * spatial_scale_); + roi_start_h = round(bottom_rois[2] * spatial_scale_); + roi_end_w = round(bottom_rois[3] * spatial_scale_) + 1.0f; + roi_end_h = round(bottom_rois[4] * spatial_scale_) + 1.0f; // Force too small ROIs to be 1x1 roi_width = std::max(roi_end_w - roi_start_w, 0.1f); // avoid 0 roi_height = std::max(roi_end_h - roi_start_h, 0.1f); diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/psroi_pooling.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/psroi_pooling.cpp new file mode 100644 index 00000000000..c648a990667 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/psroi_pooling.cpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "single_layer_tests/psroi_pooling.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +std::vector spatialScales = {1, 0.625}; + +const auto PSROICases_average = ::testing::Combine( + ::testing::Values(std::vector{3, 8, 16, 16}), + ::testing::Values(std::vector{10, 5}), + ::testing::Values(2), + ::testing::Values(2), + ::testing::ValuesIn(spatialScales), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values("average"), + ::testing::Values(InferenceEngine::Precision::FP32), + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +INSTANTIATE_TEST_CASE_P(smoke_TestsPSROIPooling_average, PSROIPoolingLayerTest, PSROICases_average, PSROIPoolingLayerTest::getTestCaseName); + + +const auto PSROICases_bilinear = ::testing::Combine( + ::testing::Values(std::vector{3, 32, 20, 20}), + ::testing::Values(std::vector{10, 5}), + ::testing::Values(4), + ::testing::Values(3), + ::testing::ValuesIn(spatialScales), + ::testing::Values(4), + ::testing::Values(2), + ::testing::Values("bilinear"), + ::testing::Values(InferenceEngine::Precision::FP32), + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +INSTANTIATE_TEST_CASE_P(smoke_TestsPSROIPooling_bilinear, PSROIPoolingLayerTest, PSROICases_bilinear, PSROIPoolingLayerTest::getTestCaseName); diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/psroi_pooling.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/psroi_pooling.hpp new file mode 100644 index 00000000000..8234502d795 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/psroi_pooling.hpp @@ -0,0 +1,48 @@ +// Copyright (C) 2020 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" + +#include "functional_test_utils/layer_test_utils.hpp" + +namespace LayerTestsDefinitions { + +using psroiParams = std::tuple, // input shape + std::vector, // coords shape + size_t, // output_dim + size_t, // group_size + float, // Spatial scale + size_t, // spatial_bins_x + size_t, // spatial_bins_y + std::string, // mode + InferenceEngine::Precision, // Net precision + LayerTestsUtils::TargetDevice>; // Device name + +class PSROIPoolingLayerTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon { + public: + static std::string getTestCaseName(testing::TestParamInfo obj); + void Infer() override; + + protected: + void SetUp() override; + + private: + size_t groupSize_; + float spatialScale_; + size_t spatialBinsX_; + size_t spatialBinsY_; + std::string mode_; + }; + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/psroi_pooling.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/psroi_pooling.cpp new file mode 100644 index 00000000000..d184a8ec456 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/psroi_pooling.cpp @@ -0,0 +1,146 @@ +// Copyright (C) 2020 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "common_test_utils/common_utils.hpp" +#include "functional_test_utils/skip_tests_config.hpp" +#include "functional_test_utils/layer_test_utils.hpp" + +#include "single_layer_tests/psroi_pooling.hpp" + +using namespace InferenceEngine; +using namespace FuncTestUtils::PrecisionUtils; + +namespace LayerTestsDefinitions { + + std::string PSROIPoolingLayerTest::getTestCaseName(testing::TestParamInfo obj) { + std::vector inputShape; + std::vector coordsShape; + size_t outputDim; + size_t groupSize; + float spatialScale; + size_t spatialBinsX; + size_t spatialBinsY; + std::string mode; + InferenceEngine::Precision netPrecision; + std::string targetDevice; + std::tie(inputShape, coordsShape, outputDim, groupSize, spatialScale, spatialBinsX, spatialBinsY, mode, netPrecision, targetDevice) = obj.param; + + std::ostringstream result; + + result << "in_shape=" << CommonTestUtils::vec2str(inputShape) << "_"; + result << "coord_shape=" << CommonTestUtils::vec2str(coordsShape) << "_"; + result << "out_dim=" << outputDim << "_"; + result << "group_size=" << groupSize << "_"; + result << "scale=" << spatialScale << "_"; + result << "bins_x=" << spatialBinsX << "_"; + result << "bins_y=" << spatialBinsY << "_"; + result << "mode=" << mode << "_"; + result << "prec=" << netPrecision.name() << "_"; + result << "dev=" << targetDevice; + return result.str(); + } + + static int randInt(int low, int high) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(low, high); + return dis(gen); + } + + static void fillROITensor(float* buffer, int numROIs, int batchSize, + int height, int width, int groupSize, + float spatialScale, int spatialBinsX, int spatialBinsY, const std::string& mode) { + int minRoiWidth = groupSize; + int maxRoiWidth = width / groupSize * groupSize; + int minRoiHeight = groupSize; + int maxRoiHeight = height / groupSize * groupSize; + float scaleX = spatialScale; + float scaleY = spatialScale; + if (mode == "bilinear") { + minRoiWidth = spatialBinsX; + maxRoiWidth = width / spatialBinsX * spatialBinsX; + minRoiHeight = spatialBinsY; + maxRoiHeight = height / spatialBinsY * spatialBinsY; + scaleX *= width; + scaleY *= height; + } + int batchId = 0; + for (int i = 0; i < numROIs; i++) { + int sizeX = std::min(width, randInt(minRoiWidth, maxRoiWidth)); + int sizeY = std::min(height, randInt(minRoiHeight, maxRoiHeight)); + int startX = randInt(0, std::max(1, width - sizeX - 1)); + int startY = randInt(0, std::max(1, height - sizeY - 1)); + + float* roi = buffer + i * 5; + roi[0] = batchId; + roi[1] = startX / scaleX; + roi[2] = startY / scaleY; + roi[3] = (startX + sizeX - 1) / scaleX; + roi[4] = (startY + sizeY - 1) / scaleY; + + batchId = (batchId + 1) % batchSize; + } + } + + void PSROIPoolingLayerTest::Infer() { + inferRequest = executableNetwork.CreateInferRequest(); + inputs.clear(); + + auto inputShape = cnnNetwork.getInputShapes().begin()->second; + + size_t it = 0; + for (const auto &input : cnnNetwork.getInputsInfo()) { + const auto &info = input.second; + Blob::Ptr blob; + + if (it == 1) { + blob = make_blob_with_precision(info->getTensorDesc()); + blob->allocate(); + fillROITensor(blob->buffer(), blob->size() / 5, + inputShape[0], inputShape[2], inputShape[3], groupSize_, + spatialScale_, spatialBinsX_, spatialBinsY_, mode_); + } else { + blob = GenerateInput(*info); + } + inferRequest.SetBlob(info->name(), blob); + inputs.push_back(blob); + it++; + } + inferRequest.Infer(); + } + + void PSROIPoolingLayerTest::SetUp() { + std::vector inputShape; + std::vector coordsShape; + size_t outputDim; + InferenceEngine::Precision netPrecision; + std::tie(inputShape, coordsShape, outputDim, groupSize_, spatialScale_, + spatialBinsX_, spatialBinsY_, mode_, netPrecision, targetDevice) = this->GetParam(); + + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShape, coordsShape}); + auto paramOuts = ngraph::helpers::convert2OutputVector( + ngraph::helpers::castOps2Nodes(params)); + std::shared_ptr psroiPooling = std::make_shared(paramOuts[0], + paramOuts[1], + outputDim, + groupSize_, + spatialScale_, + spatialBinsX_, + spatialBinsY_, + mode_); + ngraph::ResultVector results{std::make_shared(psroiPooling)}; + function = std::make_shared(results, params, "psroi_pooling"); + } + + TEST_P(PSROIPoolingLayerTest, CompareWithRefs) { + Run(); + } +} // namespace LayerTestsDefinitions diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/psroi_pooling.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/psroi_pooling.hpp new file mode 100644 index 00000000000..0034e486fcd --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/psroi_pooling.hpp @@ -0,0 +1,206 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include + +#include "ngraph/shape.hpp" + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + enum PSROIPoolingMode + { + AVG, + BILINEAR + }; + template + void psroi_pooling(const T* input, + const Shape& input_shape, + const T* rois, + const Shape& rois_shape, + T* output, + const Shape& output_shape, + const std::string& mode_str, + float spatial_scale, + int spatial_bins_x, + int spatial_bins_y) + { + PSROIPoolingMode mode; + if (mode_str == "average") + { + mode = AVG; + } + else if (mode_str == "bilinear") + { + mode = BILINEAR; + } + else + { + NGRAPH_CHECK(false, "Invalid PS ROI pooling mode: " + mode_str); + } + size_t channels_in = input_shape[1]; + size_t height = input_shape[2]; + size_t width = input_shape[3]; + size_t num_rois = output_shape[0]; + size_t channels_out = output_shape[1]; + size_t pooling_height = output_shape[2]; + size_t pooling_width = output_shape[3]; + int num_spatial_bins = spatial_bins_x * spatial_bins_y; + for (size_t roi = 0; roi < num_rois; roi++) + { + const T* box = rois + roi * 5; + int batch_id = box[0]; + float start_w = box[1] * spatial_scale; + float start_h = box[2] * spatial_scale; + float end_w = box[3] * spatial_scale; + float end_h = box[4] * spatial_scale; + if (mode == AVG) + { + start_w = std::roundf(start_w); + start_h = std::roundf(start_h); + end_w = std::roundf(end_w) + 1; + end_h = std::roundf(end_h) + 1; + } + float box_width = end_w - start_w; + float box_height = end_h - start_h; + float bin_width = box_width / pooling_width; + float bin_height = box_height / pooling_height; + float width_scale = 0; + float height_scale = 0; + if (mode == BILINEAR) + { + bin_width = box_width / spatial_bins_x; + bin_height = box_height / spatial_bins_y; + if (pooling_width > 1) + width_scale = bin_width * (width - 1) / (pooling_width - 1); + if (pooling_height > 1) + height_scale = bin_height * (height - 1) / (pooling_height - 1); + } + size_t c_in = 0; + for (size_t c_out = 0; c_out < channels_out; c_out++) + { + for (size_t ph = 0; ph < pooling_height; ph++) + { + for (size_t pw = 0; pw < pooling_width; pw++) + { + size_t index = + ((roi * channels_out + c_out) * pooling_height + ph) * + pooling_width + + pw; + output[index] = 0; + if (mode == AVG) + { + size_t bin_start_w = std::min( + static_cast(start_w + floorf(pw * bin_width)), + width - 1); + size_t bin_start_h = std::min( + static_cast(start_h + floorf(ph * bin_height)), + height - 1); + size_t current_bin_width = + std::min(static_cast(start_w + + ceilf((pw + 1) * bin_width)), + width) - + bin_start_w; + size_t current_bin_height = + std::min(static_cast(start_h + + ceilf((ph + 1) * bin_height)), + height) - + bin_start_h; + T sum = 0; + const T* input_offset = + input + + ((batch_id * channels_in + c_in) * height + bin_start_h) * + width + + bin_start_w; + for (size_t h = 0; h < current_bin_height; h++) + { + for (size_t w = 0; w < current_bin_width; w++) + { + sum += input_offset[h * width + w]; + } + } + output[index] = sum / (current_bin_width * current_bin_height); + c_in++; + } + else if (mode == BILINEAR) + { + c_in = 0; + for (size_t sby = 0; sby < spatial_bins_y; sby++) + { + for (size_t sbx = 0; sbx < spatial_bins_x; sbx++) + { + float bin_start_w = start_w + sbx * bin_width; + float bin_start_h = start_h + sby * bin_height; + + const T* input_offset = input + + (batch_id * channels_in + + c_in * channels_out + c_out) * + height * width; + float point_x = + pooling_width > 1 + ? (pw * width_scale + bin_start_w * (width - 1)) + : (bin_start_w + bin_start_w + bin_width) * + (width - 1) / 2; + float point_y = + pooling_height > 1 + ? (ph * height_scale + + bin_start_h * (height - 1)) + : (bin_start_h + bin_start_h + bin_height) * + (height - 1) / 2; + if (point_x < width && point_y < height) + { + size_t left = floorf(point_x); + size_t right = std::min( + static_cast(ceilf(point_x)), width - 1); + size_t top = floorf(point_y); + size_t bottom = + std::min(static_cast(ceilf(point_y)), + height - 1); + T top_left = input_offset[top * width + left]; + T top_right = input_offset[top * width + right]; + T bottom_left = input_offset[bottom * width + left]; + T bottom_right = + input_offset[bottom * width + right]; + + T top_interp = + top_left + + (top_right - top_left) * (point_x - left); + T bottom_interp = + bottom_left + + (bottom_right - bottom_left) * (point_x - left); + output[index] += + top_interp + + (bottom_interp - top_interp) * (point_y - top); + } + c_in++; + } + } + output[index] /= num_spatial_bins; + } + } + } + } + } + } + } + } +} diff --git a/ngraph/core/src/op/psroi_pooling.cpp b/ngraph/core/src/op/psroi_pooling.cpp index 2ba3035fcfc..b6217438e81 100644 --- a/ngraph/core/src/op/psroi_pooling.cpp +++ b/ngraph/core/src/op/psroi_pooling.cpp @@ -54,29 +54,81 @@ bool ngraph::op::v0::PSROIPooling::visit_attributes(AttributeVisitor& visitor) void op::PSROIPooling::validate_and_infer_types() { - auto input_et = get_input_element_type(0); - if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static()) + auto feat_maps_et = get_input_element_type(0); + auto coords_et = get_input_element_type(1); + NODE_VALIDATION_CHECK(this, + feat_maps_et.is_real(), + "Feature maps' data type must be floating point. Got " + + feat_maps_et.get_type_name()); + NODE_VALIDATION_CHECK(this, + coords_et.is_real(), + "Coords' data type must be floating point. Got " + + coords_et.get_type_name()); + NODE_VALIDATION_CHECK(this, + m_mode == "average" || m_mode == "bilinear", + "Expected 'average' or 'bilinear' mode. Got " + m_mode); + NODE_VALIDATION_CHECK(this, m_group_size > 0, "group_size has to be greater than 0"); + if (m_mode == "bilinear") { - Shape input_shape = get_input_partial_shape(0).to_shape(); - Shape coords_shape = get_input_partial_shape(1).to_shape(); - NODE_VALIDATION_CHECK(this, - input_shape.size() >= 3, - "PSROIPooling expects 3 or higher dimensions for input. Got ", - input_shape.size()); - NODE_VALIDATION_CHECK(this, - coords_shape.size() == 2, - "PSROIPooling expects 2 dimensions for box coordinates. Got ", - coords_shape.size()); - Shape output_shape{coords_shape[0], m_output_dim}; - for (size_t i = 2; i < input_shape.size(); i++) - { - output_shape.push_back(m_group_size); - } - set_output_type(0, input_et, output_shape); + NODE_VALIDATION_CHECK( + this, m_spatial_bins_x > 0, "spatial_bins_x has to be greater than 0"); + NODE_VALIDATION_CHECK( + this, m_spatial_bins_y > 0, "spatial_bins_y has to be greater than 0"); + } + + const PartialShape& feat_map_pshape = get_input_partial_shape(0); + const PartialShape& coords_pshape = get_input_partial_shape(1); + if (feat_map_pshape.rank().is_dynamic() || coords_pshape.rank().is_dynamic()) + { + set_output_type(0, feat_maps_et, PartialShape::dynamic()); } else { - set_output_type(0, input_et, PartialShape::dynamic()); + NODE_VALIDATION_CHECK(this, + feat_map_pshape.rank().get_length() == 4, + "PSROIPooling expects 4 dimensions for input. Got ", + feat_map_pshape.rank().get_length()); + NODE_VALIDATION_CHECK(this, + coords_pshape.rank().get_length() == 2, + "PSROIPooling expects 2 dimensions for box coordinates. Got ", + coords_pshape.rank().get_length()); + + if (feat_map_pshape[1].is_static()) + { + auto num_input_channels = feat_map_pshape[1].get_interval().get_min_val(); + if (m_mode == "average") + { + NODE_VALIDATION_CHECK( + this, + num_input_channels % (m_group_size * m_group_size) == 0, + "Number of input's channels must be a multiply of group_size * group_size"); + NODE_VALIDATION_CHECK(this, + m_output_dim == + num_input_channels / (m_group_size * m_group_size), + "output_dim must be equal to input channels divided by " + "group_size * group_size"); + } + else if (m_mode == "bilinear") + { + NODE_VALIDATION_CHECK(this, + num_input_channels % (m_spatial_bins_x * m_spatial_bins_y) == + 0, + "Number of input's channels must be a multiply of " + "spatial_bins_x * spatial_bins_y"); + NODE_VALIDATION_CHECK( + this, + m_output_dim == num_input_channels / (m_spatial_bins_x * m_spatial_bins_y), + "output_dim must be equal to input channels divided by " + "spatial_bins_x * spatial_bins_y"); + } + } + std::vector output_shape{coords_pshape[0], + static_cast(m_output_dim)}; + for (size_t i = 2; i < feat_map_pshape.rank().get_length(); i++) + { + output_shape.push_back(m_group_size); + } + set_output_type(0, feat_maps_et, output_shape); } } diff --git a/ngraph/python/tests/test_ngraph/test_create_op.py b/ngraph/python/tests/test_ngraph/test_create_op.py index 7c8d13b1c87..4a3b6d0eeef 100644 --- a/ngraph/python/tests/test_ngraph/test_create_op.py +++ b/ngraph/python/tests/test_ngraph/test_create_op.py @@ -700,9 +700,9 @@ def test_roi_pooling(): def test_psroi_pooling(): - inputs = ng.parameter([1, 3, 4, 5], dtype=np.float32) + inputs = ng.parameter([1, 72, 4, 5], dtype=np.float32) coords = ng.parameter([150, 5], dtype=np.float32) - node = ng.psroi_pooling(inputs, coords, 2, 6, 0.0625, 0, 0, "Avg") + node = ng.psroi_pooling(inputs, coords, 2, 6, 0.0625, 0, 0, "average") assert node.get_type_name() == "PSROIPooling" assert node.get_output_size() == 1 diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 70b90a36596..ddbbcd5f2bd 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -153,6 +153,7 @@ set(SRC type_prop/parameter.cpp type_prop/prelu.cpp type_prop/proposal.cpp + type_prop/psroi_pooling.cpp type_prop/quantize.cpp type_prop/range.cpp type_prop/read_value.cpp @@ -317,6 +318,7 @@ set(MULTI_TEST_SRC backend/pad.in.cpp backend/parameter_as_output.in.cpp backend/power.in.cpp + backend/psroi_pooling.in.cpp backend/range.in.cpp backend/reduce_max.in.cpp backend/reduce_mean.in.cpp diff --git a/ngraph/test/attributes.cpp b/ngraph/test/attributes.cpp index 88efac34f63..f015be27f03 100644 --- a/ngraph/test/attributes.cpp +++ b/ngraph/test/attributes.cpp @@ -652,12 +652,12 @@ TEST(attributes, psroi_pooling_op) auto input = make_shared(element::f32, Shape{1, 1024, 63, 38}); auto coords = make_shared(element::f32, Shape{300, 5}); - const int64_t output_dim = 882; - const int64_t group_size = 3; + const int64_t output_dim = 64; + const int64_t group_size = 4; const float spatial_scale = 0.0625; int spatial_bins_x = 1; int spatial_bins_y = 1; - string mode = "Avg"; + string mode = "average"; auto psroi_pool = make_shared( input, coords, output_dim, group_size, spatial_scale, spatial_bins_x, spatial_bins_y, mode); diff --git a/ngraph/test/backend/psroi_pooling.in.cpp b/ngraph/test/backend/psroi_pooling.in.cpp new file mode 100644 index 00000000000..725ab8fce7c --- /dev/null +++ b/ngraph/test/backend/psroi_pooling.in.cpp @@ -0,0 +1,234 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "gtest/gtest.h" + +#include "ngraph/op/psroi_pooling.hpp" +#include "util/engine/test_engines.hpp" +#include "util/test_case.hpp" +#include "util/test_control.hpp" + +using namespace ngraph; + +static std::string s_manifest = "${MANIFEST}"; +using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME}); + +NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_average) +{ + size_t num_channels = 8; + size_t group_size = 2; + size_t output_dim = num_channels / (group_size * group_size); + size_t num_boxes = 3; + Shape image_shape{2, num_channels, 20, 20}; + Shape coords_shape{num_boxes, 5}; + auto image = std::make_shared(element::Type_t::f32, image_shape); + auto coords = std::make_shared(element::Type_t::f32, coords_shape); + auto f = + std::make_shared(std::make_shared( + image, coords, output_dim, group_size, 1, 1, 1, "average"), + ParameterVector{image, coords}); + Shape output_shape{num_boxes, output_dim, group_size, group_size}; + + std::vector image_input(shape_size(image_shape)); + float val = 0; + std::generate( + image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; }); + std::vector coords_input{ + // batch_id, x1, y1, x2, y2 + 0, + 1, + 2, + 4, + 6, + 1, + 0, + 3, + 10, + 4, + 0, + 10, + 7, + 11, + 13, + }; + std::vector output{ + 6.2499962, 46.44986, 90.249184, 130.44876, 166.25095, 206.45341, 250.25606, 290.45853, + 326.36069, 366.86316, 408.36572, 448.86816, 486.37045, 526.86841, 568.35828, 608.84839, + 18.100033, 58.199684, 104.09898, 144.1996, 178.10167, 218.20412, 264.1069, 304.20935, + + }; + + auto tc = test::TestCase(f); + tc.add_input(image_input); + tc.add_input(coords_input); + tc.add_expected_output(output_shape, output); + tc.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_average_spatial_scale) +{ + size_t num_channels = 8; + size_t group_size = 2; + size_t output_dim = num_channels / (group_size * group_size); + size_t num_boxes = 4; + float spatial_scale = 0.2; + Shape image_shape{2, num_channels, 20, 20}; + Shape coords_shape{num_boxes, 5}; + auto image = std::make_shared(element::Type_t::f32, image_shape); + auto coords = std::make_shared(element::Type_t::f32, coords_shape); + auto f = std::make_shared( + std::make_shared( + image, coords, output_dim, group_size, spatial_scale, 1, 1, "average"), + ParameterVector{image, coords}); + Shape output_shape{num_boxes, output_dim, group_size, group_size}; + + std::vector image_input(shape_size(image_shape)); + float val = 0; + std::generate( + image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; }); + std::vector coords_input{ + // batch_id, x1, y1, x2, y2 + 0, 5, 10, 20, 30, 0, 0, 15, 50, 20, 1, 50, 35, 55, 65, 1, 0, 60, 5, 70, + }; + std::vector output{ + 6.2499962, 46.44986, 90.249184, 130.44876, 166.25095, 206.45341, 250.25606, 290.45853, + 6.3499966, 46.849857, 88.349236, 128.84866, 166.35095, 206.85341, 248.35596, 288.8584, + 338.11142, 378.21387, 424.11667, 464.21912, 498.12119, 538.21564, 584.10443, 624.19464, + 345.11185, 385.21429, 427.11685, 467.2193, 505.12161, 545.21393, 587.1037, 627.19391, + + }; + + auto tc = test::TestCase(f); + tc.add_input(image_input); + tc.add_input(coords_input); + tc.add_expected_output(output_shape, output); + tc.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_bilinear) +{ + size_t num_channels = 12; + size_t group_size = 3; + size_t spatial_bins_x = 2; + size_t spatial_bins_y = 3; + size_t output_dim = num_channels / (spatial_bins_x * spatial_bins_y); + size_t num_boxes = 5; + Shape image_shape{2, num_channels, 20, 20}; + Shape coords_shape{num_boxes, 5}; + auto image = std::make_shared(element::Type_t::f32, image_shape); + auto coords = std::make_shared(element::Type_t::f32, coords_shape); + auto f = std::make_shared( + std::make_shared( + image, coords, output_dim, group_size, 1, spatial_bins_x, spatial_bins_y, "bilinear"), + ParameterVector{image, coords}); + Shape output_shape{num_boxes, output_dim, group_size, group_size}; + + std::vector image_input(shape_size(image_shape)); + float val = 0; + std::generate( + image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; }); + std::vector coords_input{ + 0, 0.1, 0.2, 0.7, 0.4, 1, 0.4, 0.1, 0.9, 0.3, 0, 0.5, 0.7, + 0.7, 0.9, 1, 0.15, 0.3, 0.65, 0.35, 0, 0.0, 0.2, 0.7, 0.8, + }; + std::vector output{ + 210.71394, 210.99896, 211.28398, 211.98065, 212.26567, 212.55066, 213.24738, 213.53239, + 213.8174, 250.71545, 251.00047, 251.28548, 251.98218, 252.2672, 252.5522, 253.2489, + 253.53392, 253.81892, 687.40869, 687.64606, 687.88354, 688.67511, 688.91254, 689.14996, + 689.94147, 690.17896, 690.41644, 727.40021, 727.6377, 727.87518, 728.66669, 728.90405, + 729.14154, 729.93292, 730.17041, 730.4079, 230.28471, 230.3797, 230.47472, 231.55144, + 231.64642, 231.74141, 232.81813, 232.91313, 233.00813, 270.28638, 270.38141, 270.47641, + 271.5531, 271.64813, 271.74313, 272.81985, 272.91486, 273.00986, 692.63281, 692.87018, + 693.1076, 692.94928, 693.18683, 693.42426, 693.26593, 693.50342, 693.74078, 732.62402, + 732.86139, 733.09888, 732.94049, 733.17804, 733.41547, 733.25714, 733.49463, 733.73199, + 215.63843, 215.97093, 216.30345, 219.43855, 219.77106, 220.10358, 223.23871, 223.57123, + 223.90375, 255.63994, 255.97246, 256.30496, 259.44009, 259.77261, 260.10513, 263.2403, + 263.57281, 263.9053, + + }; + + auto tc = test::TestCase(f); + tc.add_input(image_input); + tc.add_input(coords_input); + tc.add_expected_output(output_shape, output); + tc.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_bilinear_spatial_scale) +{ + size_t num_channels = 12; + size_t group_size = 4; + size_t spatial_bins_x = 2; + size_t spatial_bins_y = 3; + size_t output_dim = num_channels / (spatial_bins_x * spatial_bins_y); + size_t num_boxes = 6; + float spatial_scale = 0.5; + Shape image_shape{2, num_channels, 20, 20}; + Shape coords_shape{num_boxes, 5}; + auto image = std::make_shared(element::Type_t::f32, image_shape); + auto coords = std::make_shared(element::Type_t::f32, coords_shape); + auto f = std::make_shared(std::make_shared(image, + coords, + output_dim, + group_size, + spatial_scale, + spatial_bins_x, + spatial_bins_y, + "bilinear"), + ParameterVector{image, coords}); + Shape output_shape{num_boxes, output_dim, group_size, group_size}; + + std::vector image_input(shape_size(image_shape)); + float val = 0; + std::generate( + image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; }); + std::vector coords_input{ + 0, 0.1, 0.2, 0.7, 0.4, 0, 0.5, 0.7, 1.2, 1.3, 0, 1.0, 1.3, 1.2, 1.8, + 1, 0.5, 1.1, 0.7, 1.44, 1, 0.2, 1.1, 0.5, 1.2, 1, 0.34, 1.3, 1.15, 1.35, + }; + std::vector output{ + 205.40955, 205.50456, 205.59955, 205.69453, 205.83179, 205.9268, 206.0218, 206.11681, + 206.25403, 206.34901, 206.44403, 206.53905, 206.67627, 206.77126, 206.86627, 206.96129, + 245.41107, 245.50606, 245.60106, 245.69604, 245.8333, 245.9283, 246.02327, 246.1183, + 246.25554, 246.35052, 246.44556, 246.54054, 246.67778, 246.77277, 246.86775, 246.96278, + 217.84717, 217.95801, 218.06885, 218.17969, 219.11389, 219.22473, 219.33557, 219.44641, + 220.3806, 220.49144, 220.60228, 220.71312, 221.64732, 221.75816, 221.86897, 221.97981, + 257.84872, 257.95956, 258.0704, 258.18124, 259.11545, 259.22629, 259.33713, 259.44797, + 260.38217, 260.49301, 260.60385, 260.71469, 261.6489, 261.75974, 261.87057, 261.98141, + 228.9705, 229.00215, 229.03383, 229.06549, 230.02608, 230.05774, 230.08943, 230.12109, + 231.08168, 231.11334, 231.14502, 231.1767, 232.13728, 232.16895, 232.20062, 232.23228, + 268.97217, 269.00385, 269.03549, 269.06717, 270.02777, 270.05945, 270.09109, 270.12277, + 271.08337, 271.11502, 271.1467, 271.17838, 272.13901, 272.17065, 272.2023, 272.23398, + 703.65057, 703.68219, 703.71387, 703.74554, 704.36816, 704.39984, 704.43146, 704.4632, + 705.08575, 705.11749, 705.14911, 705.18085, 705.80347, 705.83514, 705.86676, 705.89844, + 743.64136, 743.67291, 743.70459, 743.73633, 744.35889, 744.39056, 744.42218, 744.45392, + 745.07648, 745.10815, 745.13983, 745.17157, 745.79413, 745.82574, 745.85742, 745.8891, + 701.86963, 701.91724, 701.9646, 702.01221, 702.08081, 702.12823, 702.17578, 702.22321, + 702.29181, 702.33936, 702.38678, 702.43433, 702.50293, 702.55035, 702.5979, 702.64545, + 741.86041, 741.90796, 741.95538, 742.00293, 742.07153, 742.11896, 742.1665, 742.21405, + 742.28253, 742.33008, 742.3775, 742.42505, 742.49365, 742.54108, 742.58862, 742.63617, + 705.60645, 705.73468, 705.86298, 705.99115, 705.71198, 705.84027, 705.96844, 706.09668, + 705.81757, 705.94574, 706.07397, 706.20215, 705.9231, 706.05127, 706.1795, 706.3078, + 745.59698, 745.72534, 745.85352, 745.98169, 745.70264, 745.83081, 745.95898, 746.08722, + 745.80811, 745.93628, 746.06451, 746.19269, 745.91364, 746.04181, 746.1701, 746.29834, + }; + + auto tc = test::TestCase(f); + tc.add_input(image_input); + tc.add_input(coords_input); + tc.add_expected_output(output_shape, output); + tc.run(); +} diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index 32505a58e6d..b40942d3f5d 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -56,6 +56,7 @@ #include "ngraph/runtime/reference/lrn.hpp" #include "ngraph/runtime/reference/mvn.hpp" #include "ngraph/runtime/reference/normalize_l2.hpp" +#include "ngraph/runtime/reference/psroi_pooling.hpp" #include "ngraph/runtime/reference/region_yolo.hpp" #include "ngraph/runtime/reference/roi_pooling.hpp" #include "ngraph/runtime/reference/scatter_nd_update.hpp" @@ -1629,6 +1630,26 @@ namespace return true; } + template + bool evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& inputs) + { + using T = typename element_type_traits::value_type; + runtime::reference::psroi_pooling(inputs[0]->get_data_ptr(), + inputs[0]->get_shape(), + inputs[1]->get_data_ptr(), + inputs[1]->get_shape(), + outputs[0]->get_data_ptr(), + outputs[0]->get_shape(), + op->get_mode(), + op->get_spatial_scale(), + op->get_spatial_bins_x(), + op->get_spatial_bins_y()); + + return true; + } + template bool evaluate_node(std::shared_ptr node, const HostTensorVector& outputs, @@ -1701,4 +1722,4 @@ runtime::interpreter::EvaluatorsMap& runtime::interpreter::get_evaluators_map() #undef NGRAPH_OP }; return evaluatorsMap; -} \ No newline at end of file +} diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 85d25805282..d9d00a1747b 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -35,6 +35,7 @@ NGRAPH_OP(LRN, ngraph::op::v0) NGRAPH_OP(MVN, ngraph::op::v0) NGRAPH_OP(NormalizeL2, op::v0) NGRAPH_OP(PriorBox, ngraph::op::v0) +NGRAPH_OP(PSROIPooling, op::v0) NGRAPH_OP(RegionYolo, op::v0) NGRAPH_OP(Relu, op::v0) NGRAPH_OP(ReorgYolo, op::v0) diff --git a/ngraph/test/type_prop/psroi_pooling.cpp b/ngraph/test/type_prop/psroi_pooling.cpp new file mode 100644 index 00000000000..1c6057af0a6 --- /dev/null +++ b/ngraph/test/type_prop/psroi_pooling.cpp @@ -0,0 +1,323 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "gtest/gtest.h" + +#include "ngraph/ngraph.hpp" +#include "ngraph/op/psroi_pooling.hpp" +#include "util/type_prop.hpp" + +using namespace ngraph; + +TEST(type_prop, psroi_pooling_average) +{ + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 4, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + ASSERT_EQ(op->get_shape(), (Shape{150, 2, 6, 6})); + ASSERT_EQ(op->get_element_type(), element::Type_t::f32); +} + +TEST(type_prop, psroi_pooling_bilinear) +{ + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 4, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 18, 6, 1, 2, 2, "bilinear"); + ASSERT_EQ(op->get_shape(), (Shape{150, 18, 6, 6})); + ASSERT_EQ(op->get_element_type(), element::Type_t::f32); +} + +TEST(type_prop, psroi_pooling_invalid_type) +{ + try + { + auto inputs = std::make_shared(element::Type_t::i32, Shape{1, 72, 4, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Feature maps' data type must be floating point")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } + + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 4, 5}); + auto coords = std::make_shared(element::Type_t::i32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Coords' data type must be floating point")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } +} + +TEST(type_prop, psroi_pooling_invalid_mode) +{ + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 4, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = + std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "invalid_mode"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected 'average' or 'bilinear' mode")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } +} + +TEST(type_prop, psroi_pooling_invalid_shapes) +{ + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("PSROIPooling expects 4 dimensions for input")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } + + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 1, 72, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("PSROIPooling expects 2 dimensions for box coordinates")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } +} + +TEST(type_prop, psroi_pooling_invalid_group_size) +{ + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 2, 0, 1, 0, 0, "average"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("group_size has to be greater than 0")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } + + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 2, 5, 1, 0, 0, "average"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string( + "Number of input's channels must be a multiply of group_size * group_size")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } +} + +TEST(type_prop, psroi_pooling_invalid_output_dim) +{ + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 17, 2, 1, 0, 0, "average"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string( + "output_dim must be equal to input channels divided by group_size * group_size")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } +} + +TEST(type_prop, psroi_pooling_invalid_spatial_bins) +{ + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 17, 2, 1, 0, 0, "bilinear"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("spatial_bins_x has to be greater than 0")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } + + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 17, 2, 1, 1, 0, "bilinear"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("spatial_bins_y has to be greater than 0")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } + + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 17, 2, 1, 2, 5, "bilinear"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Number of input's channels must be a multiply of " + "spatial_bins_x * spatial_bins_y")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } + + try + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 5, 5}); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 10, 2, 1, 2, 4, "bilinear"); + FAIL() << "Exception expected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("output_dim must be equal to input channels divided by " + "spatial_bins_x * spatial_bins_y")); + } + catch (...) + { + FAIL() << "Unknown exception was thrown"; + } +} + +TEST(type_prop, psroi_pooling_dynamic_ranks) +{ + { + auto inputs = + std::make_shared(element::Type_t::f32, PartialShape::dynamic()); + auto coords = std::make_shared(element::Type_t::f32, Shape{150, 5}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + ASSERT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic()); + ASSERT_EQ(op->get_element_type(), element::Type_t::f32); + } + { + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 4, 5}); + auto coords = + std::make_shared(element::Type_t::f32, PartialShape::dynamic()); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + ASSERT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic()); + ASSERT_EQ(op->get_element_type(), element::Type_t::f32); + } +} + +TEST(type_prop, psroi_pooling_dynamic_num_boxes) +{ + auto inputs = std::make_shared(element::Type_t::f32, Shape{1, 72, 4, 5}); + auto coords = std::make_shared(element::Type_t::f32, + PartialShape{{Dimension::dynamic(), 5}}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + ASSERT_EQ(op->get_output_partial_shape(0), (PartialShape{{Dimension::dynamic(), 2, 6, 6}})); + ASSERT_EQ(op->get_element_type(), element::Type_t::f32); +} + +TEST(type_prop, psroi_pooling_static_rank_dynamic_shape) +{ + { + auto inputs = std::make_shared(element::Type_t::f32, + PartialShape{{Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic()}}); + auto coords = std::make_shared( + element::Type_t::f32, PartialShape{{Dimension::dynamic(), Dimension::dynamic()}}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + ASSERT_EQ(op->get_output_partial_shape(0), (PartialShape{{Dimension::dynamic(), 2, 6, 6}})); + ASSERT_EQ(op->get_element_type(), element::Type_t::f32); + } + { + auto inputs = std::make_shared(element::Type_t::f32, + PartialShape{{Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic()}}); + auto coords = std::make_shared(element::Type_t::f32, + PartialShape{{200, Dimension::dynamic()}}); + auto op = std::make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "average"); + ASSERT_EQ(op->get_shape(), (Shape{200, 2, 6, 6})); + ASSERT_EQ(op->get_element_type(), element::Type_t::f32); + } +} diff --git a/ngraph/test/type_prop_layers.cpp b/ngraph/test/type_prop_layers.cpp index 1d4c012089d..10050741c43 100644 --- a/ngraph/test/type_prop_layers.cpp +++ b/ngraph/test/type_prop_layers.cpp @@ -22,7 +22,6 @@ #include "ngraph/op/interpolate.hpp" #include "ngraph/op/prior_box.hpp" #include "ngraph/op/prior_box_clustered.hpp" -#include "ngraph/op/psroi_pooling.hpp" #include "ngraph/op/region_yolo.hpp" #include "ngraph/op/reorg_yolo.hpp" #include "ngraph/op/roi_pooling.hpp" @@ -157,14 +156,6 @@ TEST(type_prop_layers, reorg_yolo) ASSERT_EQ(op->get_shape(), (Shape{2, 96, 17, 31})); } -TEST(type_prop_layers, psroi_pooling) -{ - auto inputs = make_shared(element::f32, Shape{1, 3, 4, 5}); - auto coords = make_shared(element::f32, Shape{150, 5}); - auto op = make_shared(inputs, coords, 2, 6, 0.0625, 0, 0, "Avg"); - ASSERT_EQ(op->get_shape(), (Shape{150, 2, 6, 6})); -} - TEST(type_prop_layers, roi_pooling) { auto inputs = make_shared(element::f32, Shape{2, 3, 4, 5});