Add reference implementation for PSROIPooling operator (#3245)

* Add reference implementation for PSROIPooling operator

* fix test_roi_pooling

* use std::roundf

* remove unnecessary copies in single layer tets

* Fixes after review

* fixes after review

* use element::Type_t instead of element::

* apply code format

* add PSROIPooling to evaluates_map

* apply code format
This commit is contained in:
Mateusz Tabaka 2020-12-08 04:35:52 +01:00 committed by GitHub
parent ec48fcb29b
commit 305f005605
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1114 additions and 42 deletions

View File

@ -24,7 +24,7 @@ ROIs coordinates are specified in absolute values for the average mode and in no
* *group_size*
* **Description**: *group_size* is the number of groups to encode position-sensitive score maps. Use for *average* mode only.
* **Description**: *group_size* is the number of groups to encode position-sensitive score maps.
* **Range of values**: a positive integer
* **Type**: `int`
* **Default value**: 1
@ -63,14 +63,19 @@ ROIs coordinates are specified in absolute values for the average mode and in no
**Inputs**:
* **1**: 4D input blob with feature maps. Required.
* **1**: 4D input tensor with shape `[N, C, H, W]` and type *T* with feature maps. Required.
* **2**: 2D input blob describing box consisting of five element tuples: `[batch_id, x_1, y_1, x_2, y_2]`. Required.
* **2**: 2D input tensor with shape `[num_boxes, 5]`. It contains a list of five element tuples that describe a region of interest: `[batch_id, x_1, y_1, x_2, y_2]`. Required.
Batch indices must be in the range of `[0, N-1]`.
**Outputs**:
* **1**: 4D output tensor with areas copied and interpolated from the 1st input tensor by coordinates of boxes from the 2nd input.
**Types**
* *T*: any supported floating point type.
**Example**
```xml
@ -97,4 +102,4 @@ ROIs coordinates are specified in absolute values for the average mode and in no
</port>
</output>
</layer>
```
```

View File

@ -102,10 +102,10 @@ public:
roi_width = roi_end_w - roi_start_w;
roi_height = roi_end_h - roi_start_h;
} else if (mode_ == "average") {
roi_start_w = static_cast<float>(round(bottom_rois[1])) * spatial_scale_;
roi_start_h = static_cast<float>(round(bottom_rois[2])) * spatial_scale_;
roi_end_w = static_cast<float>(round(bottom_rois[3]) + 1.0f) * spatial_scale_;
roi_end_h = static_cast<float>(round(bottom_rois[4]) + 1.0f) * spatial_scale_;
roi_start_w = round(bottom_rois[1] * spatial_scale_);
roi_start_h = round(bottom_rois[2] * spatial_scale_);
roi_end_w = round(bottom_rois[3] * spatial_scale_) + 1.0f;
roi_end_h = round(bottom_rois[4] * spatial_scale_) + 1.0f;
// Force too small ROIs to be 1x1
roi_width = std::max<float>(roi_end_w - roi_start_w, 0.1f); // avoid 0
roi_height = std::max<float>(roi_end_h - roi_start_h, 0.1f);

View File

@ -0,0 +1,43 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/psroi_pooling.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
std::vector<float> spatialScales = {1, 0.625};
const auto PSROICases_average = ::testing::Combine(
::testing::Values(std::vector<size_t>{3, 8, 16, 16}),
::testing::Values(std::vector<size_t>{10, 5}),
::testing::Values(2),
::testing::Values(2),
::testing::ValuesIn(spatialScales),
::testing::Values(1),
::testing::Values(1),
::testing::Values("average"),
::testing::Values(InferenceEngine::Precision::FP32),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(smoke_TestsPSROIPooling_average, PSROIPoolingLayerTest, PSROICases_average, PSROIPoolingLayerTest::getTestCaseName);
const auto PSROICases_bilinear = ::testing::Combine(
::testing::Values(std::vector<size_t>{3, 32, 20, 20}),
::testing::Values(std::vector<size_t>{10, 5}),
::testing::Values(4),
::testing::Values(3),
::testing::ValuesIn(spatialScales),
::testing::Values(4),
::testing::Values(2),
::testing::Values("bilinear"),
::testing::Values(InferenceEngine::Precision::FP32),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(smoke_TestsPSROIPooling_bilinear, PSROIPoolingLayerTest, PSROICases_bilinear, PSROIPoolingLayerTest::getTestCaseName);

View File

@ -0,0 +1,48 @@
// Copyright (C) 2020 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "functional_test_utils/layer_test_utils.hpp"
namespace LayerTestsDefinitions {
using psroiParams = std::tuple<std::vector<size_t>, // input shape
std::vector<size_t>, // coords shape
size_t, // output_dim
size_t, // group_size
float, // Spatial scale
size_t, // spatial_bins_x
size_t, // spatial_bins_y
std::string, // mode
InferenceEngine::Precision, // Net precision
LayerTestsUtils::TargetDevice>; // Device name
class PSROIPoolingLayerTest : public testing::WithParamInterface<psroiParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<psroiParams> obj);
void Infer() override;
protected:
void SetUp() override;
private:
size_t groupSize_;
float spatialScale_;
size_t spatialBinsX_;
size_t spatialBinsY_;
std::string mode_;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,146 @@
// Copyright (C) 2020 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include "functional_test_utils/layer_test_utils.hpp"
#include "single_layer_tests/psroi_pooling.hpp"
using namespace InferenceEngine;
using namespace FuncTestUtils::PrecisionUtils;
namespace LayerTestsDefinitions {
std::string PSROIPoolingLayerTest::getTestCaseName(testing::TestParamInfo<psroiParams> obj) {
std::vector<size_t> inputShape;
std::vector<size_t> coordsShape;
size_t outputDim;
size_t groupSize;
float spatialScale;
size_t spatialBinsX;
size_t spatialBinsY;
std::string mode;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::tie(inputShape, coordsShape, outputDim, groupSize, spatialScale, spatialBinsX, spatialBinsY, mode, netPrecision, targetDevice) = obj.param;
std::ostringstream result;
result << "in_shape=" << CommonTestUtils::vec2str(inputShape) << "_";
result << "coord_shape=" << CommonTestUtils::vec2str(coordsShape) << "_";
result << "out_dim=" << outputDim << "_";
result << "group_size=" << groupSize << "_";
result << "scale=" << spatialScale << "_";
result << "bins_x=" << spatialBinsX << "_";
result << "bins_y=" << spatialBinsY << "_";
result << "mode=" << mode << "_";
result << "prec=" << netPrecision.name() << "_";
result << "dev=" << targetDevice;
return result.str();
}
static int randInt(int low, int high) {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<int> dis(low, high);
return dis(gen);
}
static void fillROITensor(float* buffer, int numROIs, int batchSize,
int height, int width, int groupSize,
float spatialScale, int spatialBinsX, int spatialBinsY, const std::string& mode) {
int minRoiWidth = groupSize;
int maxRoiWidth = width / groupSize * groupSize;
int minRoiHeight = groupSize;
int maxRoiHeight = height / groupSize * groupSize;
float scaleX = spatialScale;
float scaleY = spatialScale;
if (mode == "bilinear") {
minRoiWidth = spatialBinsX;
maxRoiWidth = width / spatialBinsX * spatialBinsX;
minRoiHeight = spatialBinsY;
maxRoiHeight = height / spatialBinsY * spatialBinsY;
scaleX *= width;
scaleY *= height;
}
int batchId = 0;
for (int i = 0; i < numROIs; i++) {
int sizeX = std::min(width, randInt(minRoiWidth, maxRoiWidth));
int sizeY = std::min(height, randInt(minRoiHeight, maxRoiHeight));
int startX = randInt(0, std::max(1, width - sizeX - 1));
int startY = randInt(0, std::max(1, height - sizeY - 1));
float* roi = buffer + i * 5;
roi[0] = batchId;
roi[1] = startX / scaleX;
roi[2] = startY / scaleY;
roi[3] = (startX + sizeX - 1) / scaleX;
roi[4] = (startY + sizeY - 1) / scaleY;
batchId = (batchId + 1) % batchSize;
}
}
void PSROIPoolingLayerTest::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
inputs.clear();
auto inputShape = cnnNetwork.getInputShapes().begin()->second;
size_t it = 0;
for (const auto &input : cnnNetwork.getInputsInfo()) {
const auto &info = input.second;
Blob::Ptr blob;
if (it == 1) {
blob = make_blob_with_precision(info->getTensorDesc());
blob->allocate();
fillROITensor(blob->buffer(), blob->size() / 5,
inputShape[0], inputShape[2], inputShape[3], groupSize_,
spatialScale_, spatialBinsX_, spatialBinsY_, mode_);
} else {
blob = GenerateInput(*info);
}
inferRequest.SetBlob(info->name(), blob);
inputs.push_back(blob);
it++;
}
inferRequest.Infer();
}
void PSROIPoolingLayerTest::SetUp() {
std::vector<size_t> inputShape;
std::vector<size_t> coordsShape;
size_t outputDim;
InferenceEngine::Precision netPrecision;
std::tie(inputShape, coordsShape, outputDim, groupSize_, spatialScale_,
spatialBinsX_, spatialBinsY_, mode_, netPrecision, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShape, coordsShape});
auto paramOuts = ngraph::helpers::convert2OutputVector(
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
std::shared_ptr<ngraph::Node> psroiPooling = std::make_shared<ngraph::op::v0::PSROIPooling>(paramOuts[0],
paramOuts[1],
outputDim,
groupSize_,
spatialScale_,
spatialBinsX_,
spatialBinsY_,
mode_);
ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(psroiPooling)};
function = std::make_shared<ngraph::Function>(results, params, "psroi_pooling");
}
TEST_P(PSROIPoolingLayerTest, CompareWithRefs) {
Run();
}
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,206 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <string>
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
enum PSROIPoolingMode
{
AVG,
BILINEAR
};
template <typename T>
void psroi_pooling(const T* input,
const Shape& input_shape,
const T* rois,
const Shape& rois_shape,
T* output,
const Shape& output_shape,
const std::string& mode_str,
float spatial_scale,
int spatial_bins_x,
int spatial_bins_y)
{
PSROIPoolingMode mode;
if (mode_str == "average")
{
mode = AVG;
}
else if (mode_str == "bilinear")
{
mode = BILINEAR;
}
else
{
NGRAPH_CHECK(false, "Invalid PS ROI pooling mode: " + mode_str);
}
size_t channels_in = input_shape[1];
size_t height = input_shape[2];
size_t width = input_shape[3];
size_t num_rois = output_shape[0];
size_t channels_out = output_shape[1];
size_t pooling_height = output_shape[2];
size_t pooling_width = output_shape[3];
int num_spatial_bins = spatial_bins_x * spatial_bins_y;
for (size_t roi = 0; roi < num_rois; roi++)
{
const T* box = rois + roi * 5;
int batch_id = box[0];
float start_w = box[1] * spatial_scale;
float start_h = box[2] * spatial_scale;
float end_w = box[3] * spatial_scale;
float end_h = box[4] * spatial_scale;
if (mode == AVG)
{
start_w = std::roundf(start_w);
start_h = std::roundf(start_h);
end_w = std::roundf(end_w) + 1;
end_h = std::roundf(end_h) + 1;
}
float box_width = end_w - start_w;
float box_height = end_h - start_h;
float bin_width = box_width / pooling_width;
float bin_height = box_height / pooling_height;
float width_scale = 0;
float height_scale = 0;
if (mode == BILINEAR)
{
bin_width = box_width / spatial_bins_x;
bin_height = box_height / spatial_bins_y;
if (pooling_width > 1)
width_scale = bin_width * (width - 1) / (pooling_width - 1);
if (pooling_height > 1)
height_scale = bin_height * (height - 1) / (pooling_height - 1);
}
size_t c_in = 0;
for (size_t c_out = 0; c_out < channels_out; c_out++)
{
for (size_t ph = 0; ph < pooling_height; ph++)
{
for (size_t pw = 0; pw < pooling_width; pw++)
{
size_t index =
((roi * channels_out + c_out) * pooling_height + ph) *
pooling_width +
pw;
output[index] = 0;
if (mode == AVG)
{
size_t bin_start_w = std::min(
static_cast<size_t>(start_w + floorf(pw * bin_width)),
width - 1);
size_t bin_start_h = std::min(
static_cast<size_t>(start_h + floorf(ph * bin_height)),
height - 1);
size_t current_bin_width =
std::min(static_cast<size_t>(start_w +
ceilf((pw + 1) * bin_width)),
width) -
bin_start_w;
size_t current_bin_height =
std::min(static_cast<size_t>(start_h +
ceilf((ph + 1) * bin_height)),
height) -
bin_start_h;
T sum = 0;
const T* input_offset =
input +
((batch_id * channels_in + c_in) * height + bin_start_h) *
width +
bin_start_w;
for (size_t h = 0; h < current_bin_height; h++)
{
for (size_t w = 0; w < current_bin_width; w++)
{
sum += input_offset[h * width + w];
}
}
output[index] = sum / (current_bin_width * current_bin_height);
c_in++;
}
else if (mode == BILINEAR)
{
c_in = 0;
for (size_t sby = 0; sby < spatial_bins_y; sby++)
{
for (size_t sbx = 0; sbx < spatial_bins_x; sbx++)
{
float bin_start_w = start_w + sbx * bin_width;
float bin_start_h = start_h + sby * bin_height;
const T* input_offset = input +
(batch_id * channels_in +
c_in * channels_out + c_out) *
height * width;
float point_x =
pooling_width > 1
? (pw * width_scale + bin_start_w * (width - 1))
: (bin_start_w + bin_start_w + bin_width) *
(width - 1) / 2;
float point_y =
pooling_height > 1
? (ph * height_scale +
bin_start_h * (height - 1))
: (bin_start_h + bin_start_h + bin_height) *
(height - 1) / 2;
if (point_x < width && point_y < height)
{
size_t left = floorf(point_x);
size_t right = std::min(
static_cast<size_t>(ceilf(point_x)), width - 1);
size_t top = floorf(point_y);
size_t bottom =
std::min(static_cast<size_t>(ceilf(point_y)),
height - 1);
T top_left = input_offset[top * width + left];
T top_right = input_offset[top * width + right];
T bottom_left = input_offset[bottom * width + left];
T bottom_right =
input_offset[bottom * width + right];
T top_interp =
top_left +
(top_right - top_left) * (point_x - left);
T bottom_interp =
bottom_left +
(bottom_right - bottom_left) * (point_x - left);
output[index] +=
top_interp +
(bottom_interp - top_interp) * (point_y - top);
}
c_in++;
}
}
output[index] /= num_spatial_bins;
}
}
}
}
}
}
}
}
}

View File

@ -54,29 +54,81 @@ bool ngraph::op::v0::PSROIPooling::visit_attributes(AttributeVisitor& visitor)
void op::PSROIPooling::validate_and_infer_types()
{
auto input_et = get_input_element_type(0);
if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static())
auto feat_maps_et = get_input_element_type(0);
auto coords_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
feat_maps_et.is_real(),
"Feature maps' data type must be floating point. Got " +
feat_maps_et.get_type_name());
NODE_VALIDATION_CHECK(this,
coords_et.is_real(),
"Coords' data type must be floating point. Got " +
coords_et.get_type_name());
NODE_VALIDATION_CHECK(this,
m_mode == "average" || m_mode == "bilinear",
"Expected 'average' or 'bilinear' mode. Got " + m_mode);
NODE_VALIDATION_CHECK(this, m_group_size > 0, "group_size has to be greater than 0");
if (m_mode == "bilinear")
{
Shape input_shape = get_input_partial_shape(0).to_shape();
Shape coords_shape = get_input_partial_shape(1).to_shape();
NODE_VALIDATION_CHECK(this,
input_shape.size() >= 3,
"PSROIPooling expects 3 or higher dimensions for input. Got ",
input_shape.size());
NODE_VALIDATION_CHECK(this,
coords_shape.size() == 2,
"PSROIPooling expects 2 dimensions for box coordinates. Got ",
coords_shape.size());
Shape output_shape{coords_shape[0], m_output_dim};
for (size_t i = 2; i < input_shape.size(); i++)
{
output_shape.push_back(m_group_size);
}
set_output_type(0, input_et, output_shape);
NODE_VALIDATION_CHECK(
this, m_spatial_bins_x > 0, "spatial_bins_x has to be greater than 0");
NODE_VALIDATION_CHECK(
this, m_spatial_bins_y > 0, "spatial_bins_y has to be greater than 0");
}
const PartialShape& feat_map_pshape = get_input_partial_shape(0);
const PartialShape& coords_pshape = get_input_partial_shape(1);
if (feat_map_pshape.rank().is_dynamic() || coords_pshape.rank().is_dynamic())
{
set_output_type(0, feat_maps_et, PartialShape::dynamic());
}
else
{
set_output_type(0, input_et, PartialShape::dynamic());
NODE_VALIDATION_CHECK(this,
feat_map_pshape.rank().get_length() == 4,
"PSROIPooling expects 4 dimensions for input. Got ",
feat_map_pshape.rank().get_length());
NODE_VALIDATION_CHECK(this,
coords_pshape.rank().get_length() == 2,
"PSROIPooling expects 2 dimensions for box coordinates. Got ",
coords_pshape.rank().get_length());
if (feat_map_pshape[1].is_static())
{
auto num_input_channels = feat_map_pshape[1].get_interval().get_min_val();
if (m_mode == "average")
{
NODE_VALIDATION_CHECK(
this,
num_input_channels % (m_group_size * m_group_size) == 0,
"Number of input's channels must be a multiply of group_size * group_size");
NODE_VALIDATION_CHECK(this,
m_output_dim ==
num_input_channels / (m_group_size * m_group_size),
"output_dim must be equal to input channels divided by "
"group_size * group_size");
}
else if (m_mode == "bilinear")
{
NODE_VALIDATION_CHECK(this,
num_input_channels % (m_spatial_bins_x * m_spatial_bins_y) ==
0,
"Number of input's channels must be a multiply of "
"spatial_bins_x * spatial_bins_y");
NODE_VALIDATION_CHECK(
this,
m_output_dim == num_input_channels / (m_spatial_bins_x * m_spatial_bins_y),
"output_dim must be equal to input channels divided by "
"spatial_bins_x * spatial_bins_y");
}
}
std::vector<Dimension> output_shape{coords_pshape[0],
static_cast<Dimension::value_type>(m_output_dim)};
for (size_t i = 2; i < feat_map_pshape.rank().get_length(); i++)
{
output_shape.push_back(m_group_size);
}
set_output_type(0, feat_maps_et, output_shape);
}
}

View File

@ -700,9 +700,9 @@ def test_roi_pooling():
def test_psroi_pooling():
inputs = ng.parameter([1, 3, 4, 5], dtype=np.float32)
inputs = ng.parameter([1, 72, 4, 5], dtype=np.float32)
coords = ng.parameter([150, 5], dtype=np.float32)
node = ng.psroi_pooling(inputs, coords, 2, 6, 0.0625, 0, 0, "Avg")
node = ng.psroi_pooling(inputs, coords, 2, 6, 0.0625, 0, 0, "average")
assert node.get_type_name() == "PSROIPooling"
assert node.get_output_size() == 1

View File

@ -153,6 +153,7 @@ set(SRC
type_prop/parameter.cpp
type_prop/prelu.cpp
type_prop/proposal.cpp
type_prop/psroi_pooling.cpp
type_prop/quantize.cpp
type_prop/range.cpp
type_prop/read_value.cpp
@ -317,6 +318,7 @@ set(MULTI_TEST_SRC
backend/pad.in.cpp
backend/parameter_as_output.in.cpp
backend/power.in.cpp
backend/psroi_pooling.in.cpp
backend/range.in.cpp
backend/reduce_max.in.cpp
backend/reduce_mean.in.cpp

View File

@ -652,12 +652,12 @@ TEST(attributes, psroi_pooling_op)
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1024, 63, 38});
auto coords = make_shared<op::Parameter>(element::f32, Shape{300, 5});
const int64_t output_dim = 882;
const int64_t group_size = 3;
const int64_t output_dim = 64;
const int64_t group_size = 4;
const float spatial_scale = 0.0625;
int spatial_bins_x = 1;
int spatial_bins_y = 1;
string mode = "Avg";
string mode = "average";
auto psroi_pool = make_shared<opset1::PSROIPooling>(
input, coords, output_dim, group_size, spatial_scale, spatial_bins_x, spatial_bins_y, mode);

View File

@ -0,0 +1,234 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/op/psroi_pooling.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
using namespace ngraph;
static std::string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_average)
{
size_t num_channels = 8;
size_t group_size = 2;
size_t output_dim = num_channels / (group_size * group_size);
size_t num_boxes = 3;
Shape image_shape{2, num_channels, 20, 20};
Shape coords_shape{num_boxes, 5};
auto image = std::make_shared<op::Parameter>(element::Type_t::f32, image_shape);
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, coords_shape);
auto f =
std::make_shared<Function>(std::make_shared<op::v0::PSROIPooling>(
image, coords, output_dim, group_size, 1, 1, 1, "average"),
ParameterVector{image, coords});
Shape output_shape{num_boxes, output_dim, group_size, group_size};
std::vector<float> image_input(shape_size(image_shape));
float val = 0;
std::generate(
image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; });
std::vector<float> coords_input{
// batch_id, x1, y1, x2, y2
0,
1,
2,
4,
6,
1,
0,
3,
10,
4,
0,
10,
7,
11,
13,
};
std::vector<float> output{
6.2499962, 46.44986, 90.249184, 130.44876, 166.25095, 206.45341, 250.25606, 290.45853,
326.36069, 366.86316, 408.36572, 448.86816, 486.37045, 526.86841, 568.35828, 608.84839,
18.100033, 58.199684, 104.09898, 144.1996, 178.10167, 218.20412, 264.1069, 304.20935,
};
auto tc = test::TestCase<TestEngine>(f);
tc.add_input<float>(image_input);
tc.add_input<float>(coords_input);
tc.add_expected_output<float>(output_shape, output);
tc.run();
}
NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_average_spatial_scale)
{
size_t num_channels = 8;
size_t group_size = 2;
size_t output_dim = num_channels / (group_size * group_size);
size_t num_boxes = 4;
float spatial_scale = 0.2;
Shape image_shape{2, num_channels, 20, 20};
Shape coords_shape{num_boxes, 5};
auto image = std::make_shared<op::Parameter>(element::Type_t::f32, image_shape);
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, coords_shape);
auto f = std::make_shared<Function>(
std::make_shared<op::v0::PSROIPooling>(
image, coords, output_dim, group_size, spatial_scale, 1, 1, "average"),
ParameterVector{image, coords});
Shape output_shape{num_boxes, output_dim, group_size, group_size};
std::vector<float> image_input(shape_size(image_shape));
float val = 0;
std::generate(
image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; });
std::vector<float> coords_input{
// batch_id, x1, y1, x2, y2
0, 5, 10, 20, 30, 0, 0, 15, 50, 20, 1, 50, 35, 55, 65, 1, 0, 60, 5, 70,
};
std::vector<float> output{
6.2499962, 46.44986, 90.249184, 130.44876, 166.25095, 206.45341, 250.25606, 290.45853,
6.3499966, 46.849857, 88.349236, 128.84866, 166.35095, 206.85341, 248.35596, 288.8584,
338.11142, 378.21387, 424.11667, 464.21912, 498.12119, 538.21564, 584.10443, 624.19464,
345.11185, 385.21429, 427.11685, 467.2193, 505.12161, 545.21393, 587.1037, 627.19391,
};
auto tc = test::TestCase<TestEngine>(f);
tc.add_input<float>(image_input);
tc.add_input<float>(coords_input);
tc.add_expected_output<float>(output_shape, output);
tc.run();
}
NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_bilinear)
{
size_t num_channels = 12;
size_t group_size = 3;
size_t spatial_bins_x = 2;
size_t spatial_bins_y = 3;
size_t output_dim = num_channels / (spatial_bins_x * spatial_bins_y);
size_t num_boxes = 5;
Shape image_shape{2, num_channels, 20, 20};
Shape coords_shape{num_boxes, 5};
auto image = std::make_shared<op::Parameter>(element::Type_t::f32, image_shape);
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, coords_shape);
auto f = std::make_shared<Function>(
std::make_shared<op::v0::PSROIPooling>(
image, coords, output_dim, group_size, 1, spatial_bins_x, spatial_bins_y, "bilinear"),
ParameterVector{image, coords});
Shape output_shape{num_boxes, output_dim, group_size, group_size};
std::vector<float> image_input(shape_size(image_shape));
float val = 0;
std::generate(
image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; });
std::vector<float> coords_input{
0, 0.1, 0.2, 0.7, 0.4, 1, 0.4, 0.1, 0.9, 0.3, 0, 0.5, 0.7,
0.7, 0.9, 1, 0.15, 0.3, 0.65, 0.35, 0, 0.0, 0.2, 0.7, 0.8,
};
std::vector<float> output{
210.71394, 210.99896, 211.28398, 211.98065, 212.26567, 212.55066, 213.24738, 213.53239,
213.8174, 250.71545, 251.00047, 251.28548, 251.98218, 252.2672, 252.5522, 253.2489,
253.53392, 253.81892, 687.40869, 687.64606, 687.88354, 688.67511, 688.91254, 689.14996,
689.94147, 690.17896, 690.41644, 727.40021, 727.6377, 727.87518, 728.66669, 728.90405,
729.14154, 729.93292, 730.17041, 730.4079, 230.28471, 230.3797, 230.47472, 231.55144,
231.64642, 231.74141, 232.81813, 232.91313, 233.00813, 270.28638, 270.38141, 270.47641,
271.5531, 271.64813, 271.74313, 272.81985, 272.91486, 273.00986, 692.63281, 692.87018,
693.1076, 692.94928, 693.18683, 693.42426, 693.26593, 693.50342, 693.74078, 732.62402,
732.86139, 733.09888, 732.94049, 733.17804, 733.41547, 733.25714, 733.49463, 733.73199,
215.63843, 215.97093, 216.30345, 219.43855, 219.77106, 220.10358, 223.23871, 223.57123,
223.90375, 255.63994, 255.97246, 256.30496, 259.44009, 259.77261, 260.10513, 263.2403,
263.57281, 263.9053,
};
auto tc = test::TestCase<TestEngine>(f);
tc.add_input<float>(image_input);
tc.add_input<float>(coords_input);
tc.add_expected_output<float>(output_shape, output);
tc.run();
}
NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_bilinear_spatial_scale)
{
size_t num_channels = 12;
size_t group_size = 4;
size_t spatial_bins_x = 2;
size_t spatial_bins_y = 3;
size_t output_dim = num_channels / (spatial_bins_x * spatial_bins_y);
size_t num_boxes = 6;
float spatial_scale = 0.5;
Shape image_shape{2, num_channels, 20, 20};
Shape coords_shape{num_boxes, 5};
auto image = std::make_shared<op::Parameter>(element::Type_t::f32, image_shape);
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, coords_shape);
auto f = std::make_shared<Function>(std::make_shared<op::v0::PSROIPooling>(image,
coords,
output_dim,
group_size,
spatial_scale,
spatial_bins_x,
spatial_bins_y,
"bilinear"),
ParameterVector{image, coords});
Shape output_shape{num_boxes, output_dim, group_size, group_size};
std::vector<float> image_input(shape_size(image_shape));
float val = 0;
std::generate(
image_input.begin(), image_input.end(), [val]() mutable -> float { return val += 0.1; });
std::vector<float> coords_input{
0, 0.1, 0.2, 0.7, 0.4, 0, 0.5, 0.7, 1.2, 1.3, 0, 1.0, 1.3, 1.2, 1.8,
1, 0.5, 1.1, 0.7, 1.44, 1, 0.2, 1.1, 0.5, 1.2, 1, 0.34, 1.3, 1.15, 1.35,
};
std::vector<float> output{
205.40955, 205.50456, 205.59955, 205.69453, 205.83179, 205.9268, 206.0218, 206.11681,
206.25403, 206.34901, 206.44403, 206.53905, 206.67627, 206.77126, 206.86627, 206.96129,
245.41107, 245.50606, 245.60106, 245.69604, 245.8333, 245.9283, 246.02327, 246.1183,
246.25554, 246.35052, 246.44556, 246.54054, 246.67778, 246.77277, 246.86775, 246.96278,
217.84717, 217.95801, 218.06885, 218.17969, 219.11389, 219.22473, 219.33557, 219.44641,
220.3806, 220.49144, 220.60228, 220.71312, 221.64732, 221.75816, 221.86897, 221.97981,
257.84872, 257.95956, 258.0704, 258.18124, 259.11545, 259.22629, 259.33713, 259.44797,
260.38217, 260.49301, 260.60385, 260.71469, 261.6489, 261.75974, 261.87057, 261.98141,
228.9705, 229.00215, 229.03383, 229.06549, 230.02608, 230.05774, 230.08943, 230.12109,
231.08168, 231.11334, 231.14502, 231.1767, 232.13728, 232.16895, 232.20062, 232.23228,
268.97217, 269.00385, 269.03549, 269.06717, 270.02777, 270.05945, 270.09109, 270.12277,
271.08337, 271.11502, 271.1467, 271.17838, 272.13901, 272.17065, 272.2023, 272.23398,
703.65057, 703.68219, 703.71387, 703.74554, 704.36816, 704.39984, 704.43146, 704.4632,
705.08575, 705.11749, 705.14911, 705.18085, 705.80347, 705.83514, 705.86676, 705.89844,
743.64136, 743.67291, 743.70459, 743.73633, 744.35889, 744.39056, 744.42218, 744.45392,
745.07648, 745.10815, 745.13983, 745.17157, 745.79413, 745.82574, 745.85742, 745.8891,
701.86963, 701.91724, 701.9646, 702.01221, 702.08081, 702.12823, 702.17578, 702.22321,
702.29181, 702.33936, 702.38678, 702.43433, 702.50293, 702.55035, 702.5979, 702.64545,
741.86041, 741.90796, 741.95538, 742.00293, 742.07153, 742.11896, 742.1665, 742.21405,
742.28253, 742.33008, 742.3775, 742.42505, 742.49365, 742.54108, 742.58862, 742.63617,
705.60645, 705.73468, 705.86298, 705.99115, 705.71198, 705.84027, 705.96844, 706.09668,
705.81757, 705.94574, 706.07397, 706.20215, 705.9231, 706.05127, 706.1795, 706.3078,
745.59698, 745.72534, 745.85352, 745.98169, 745.70264, 745.83081, 745.95898, 746.08722,
745.80811, 745.93628, 746.06451, 746.19269, 745.91364, 746.04181, 746.1701, 746.29834,
};
auto tc = test::TestCase<TestEngine>(f);
tc.add_input<float>(image_input);
tc.add_input<float>(coords_input);
tc.add_expected_output<float>(output_shape, output);
tc.run();
}

View File

@ -56,6 +56,7 @@
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/mvn.hpp"
#include "ngraph/runtime/reference/normalize_l2.hpp"
#include "ngraph/runtime/reference/psroi_pooling.hpp"
#include "ngraph/runtime/reference/region_yolo.hpp"
#include "ngraph/runtime/reference/roi_pooling.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
@ -1629,6 +1630,26 @@ namespace
return true;
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::PSROIPooling>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::psroi_pooling<T>(inputs[0]->get_data_ptr<T>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<T>(),
inputs[1]->get_shape(),
outputs[0]->get_data_ptr<T>(),
outputs[0]->get_shape(),
op->get_mode(),
op->get_spatial_scale(),
op->get_spatial_bins_x(),
op->get_spatial_bins_y());
return true;
}
template <typename T>
bool evaluate_node(std::shared_ptr<Node> node,
const HostTensorVector& outputs,
@ -1701,4 +1722,4 @@ runtime::interpreter::EvaluatorsMap& runtime::interpreter::get_evaluators_map()
#undef NGRAPH_OP
};
return evaluatorsMap;
}
}

View File

@ -35,6 +35,7 @@ NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(MVN, ngraph::op::v0)
NGRAPH_OP(NormalizeL2, op::v0)
NGRAPH_OP(PriorBox, ngraph::op::v0)
NGRAPH_OP(PSROIPooling, op::v0)
NGRAPH_OP(RegionYolo, op::v0)
NGRAPH_OP(Relu, op::v0)
NGRAPH_OP(ReorgYolo, op::v0)

View File

@ -0,0 +1,323 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/psroi_pooling.hpp"
#include "util/type_prop.hpp"
using namespace ngraph;
TEST(type_prop, psroi_pooling_average)
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 4, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
ASSERT_EQ(op->get_shape(), (Shape{150, 2, 6, 6}));
ASSERT_EQ(op->get_element_type(), element::Type_t::f32);
}
TEST(type_prop, psroi_pooling_bilinear)
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 4, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 18, 6, 1, 2, 2, "bilinear");
ASSERT_EQ(op->get_shape(), (Shape{150, 18, 6, 6}));
ASSERT_EQ(op->get_element_type(), element::Type_t::f32);
}
TEST(type_prop, psroi_pooling_invalid_type)
{
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::i32, Shape{1, 72, 4, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Feature maps' data type must be floating point"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 4, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::i32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Coords' data type must be floating point"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
}
TEST(type_prop, psroi_pooling_invalid_mode)
{
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 4, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op =
std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "invalid_mode");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected 'average' or 'bilinear' mode"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
}
TEST(type_prop, psroi_pooling_invalid_shapes)
{
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("PSROIPooling expects 4 dimensions for input"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 1, 72, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("PSROIPooling expects 2 dimensions for box coordinates"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
}
TEST(type_prop, psroi_pooling_invalid_group_size)
{
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 0, 1, 0, 0, "average");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("group_size has to be greater than 0"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 5, 1, 0, 0, "average");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string(
"Number of input's channels must be a multiply of group_size * group_size"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
}
TEST(type_prop, psroi_pooling_invalid_output_dim)
{
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 17, 2, 1, 0, 0, "average");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string(
"output_dim must be equal to input channels divided by group_size * group_size"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
}
TEST(type_prop, psroi_pooling_invalid_spatial_bins)
{
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 17, 2, 1, 0, 0, "bilinear");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("spatial_bins_x has to be greater than 0"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 17, 2, 1, 1, 0, "bilinear");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("spatial_bins_y has to be greater than 0"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 17, 2, 1, 2, 5, "bilinear");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Number of input's channels must be a multiply of "
"spatial_bins_x * spatial_bins_y"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
try
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 5, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 10, 2, 1, 2, 4, "bilinear");
FAIL() << "Exception expected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("output_dim must be equal to input channels divided by "
"spatial_bins_x * spatial_bins_y"));
}
catch (...)
{
FAIL() << "Unknown exception was thrown";
}
}
TEST(type_prop, psroi_pooling_dynamic_ranks)
{
{
auto inputs =
std::make_shared<op::Parameter>(element::Type_t::f32, PartialShape::dynamic());
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{150, 5});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
ASSERT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic());
ASSERT_EQ(op->get_element_type(), element::Type_t::f32);
}
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 4, 5});
auto coords =
std::make_shared<op::Parameter>(element::Type_t::f32, PartialShape::dynamic());
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
ASSERT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic());
ASSERT_EQ(op->get_element_type(), element::Type_t::f32);
}
}
TEST(type_prop, psroi_pooling_dynamic_num_boxes)
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 72, 4, 5});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32,
PartialShape{{Dimension::dynamic(), 5}});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
ASSERT_EQ(op->get_output_partial_shape(0), (PartialShape{{Dimension::dynamic(), 2, 6, 6}}));
ASSERT_EQ(op->get_element_type(), element::Type_t::f32);
}
TEST(type_prop, psroi_pooling_static_rank_dynamic_shape)
{
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32,
PartialShape{{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()}});
auto coords = std::make_shared<op::Parameter>(
element::Type_t::f32, PartialShape{{Dimension::dynamic(), Dimension::dynamic()}});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
ASSERT_EQ(op->get_output_partial_shape(0), (PartialShape{{Dimension::dynamic(), 2, 6, 6}}));
ASSERT_EQ(op->get_element_type(), element::Type_t::f32);
}
{
auto inputs = std::make_shared<op::Parameter>(element::Type_t::f32,
PartialShape{{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()}});
auto coords = std::make_shared<op::Parameter>(element::Type_t::f32,
PartialShape{{200, Dimension::dynamic()}});
auto op = std::make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "average");
ASSERT_EQ(op->get_shape(), (Shape{200, 2, 6, 6}));
ASSERT_EQ(op->get_element_type(), element::Type_t::f32);
}
}

View File

@ -22,7 +22,6 @@
#include "ngraph/op/interpolate.hpp"
#include "ngraph/op/prior_box.hpp"
#include "ngraph/op/prior_box_clustered.hpp"
#include "ngraph/op/psroi_pooling.hpp"
#include "ngraph/op/region_yolo.hpp"
#include "ngraph/op/reorg_yolo.hpp"
#include "ngraph/op/roi_pooling.hpp"
@ -157,14 +156,6 @@ TEST(type_prop_layers, reorg_yolo)
ASSERT_EQ(op->get_shape(), (Shape{2, 96, 17, 31}));
}
TEST(type_prop_layers, psroi_pooling)
{
auto inputs = make_shared<op::Parameter>(element::f32, Shape{1, 3, 4, 5});
auto coords = make_shared<op::Parameter>(element::f32, Shape{150, 5});
auto op = make_shared<op::PSROIPooling>(inputs, coords, 2, 6, 0.0625, 0, 0, "Avg");
ASSERT_EQ(op->get_shape(), (Shape{150, 2, 6, 6}));
}
TEST(type_prop_layers, roi_pooling)
{
auto inputs = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});