[GPU] ROIAlign-3 (#8991)

This commit is contained in:
Yaroslav Torzuk 2021-12-07 09:10:09 +02:00 committed by GitHub
parent ad668d6ac6
commit 8ce22396b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 653 additions and 1 deletions

View File

@ -0,0 +1,62 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/roi_align.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
const std::vector<InferenceEngine::Precision> netPRCs = {
InferenceEngine::Precision::FP32,
// There is no possibility to test ROIAlign in fp16 precision,
// because on edge cases where in fp32 version ROI value is
// a little bit smaller than the nearest integer value,
// it would be bigger than the nearest integer in fp16 precision.
// Such behavior leads to completely different results of ROIAlign
// in fp32 and fp16 precisions.
// In real AI applications this problem is solved by precision-aware training.
// InferenceEngine::Precision::FP16,
};
const auto ROIAlignCases_average =
::testing::Combine(
::testing::ValuesIn(
std::vector<std::vector<size_t>> {
{ 3, 8, 16, 16 },
{ 2, 1, 16, 16 },
{ 2, 1, 8, 16 }}),
::testing::Values(std::vector<size_t>{ 2, 4 }),
::testing::Values(2),
::testing::Values(2),
::testing::ValuesIn(std::vector<float> { 1, 0.625 }),
::testing::Values(2),
::testing::Values("avg"),
::testing::ValuesIn(netPRCs),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_SUITE_P(smoke_TestsROIAlign_average, ROIAlignLayerTest, ROIAlignCases_average, ROIAlignLayerTest::getTestCaseName);
const auto ROIAlignCases_max =
::testing::Combine(
::testing::ValuesIn(
std::vector<std::vector<size_t>> {
{ 2, 8, 20, 20 },
{ 2, 1, 20, 20 },
{ 2, 1, 10, 20 }
}),
::testing::Values(std::vector<size_t>{ 2, 4 }),
::testing::Values(2),
::testing::Values(2),
::testing::ValuesIn(std::vector<float> { 1, 0.625 }),
::testing::Values(2),
::testing::Values("max"),
::testing::ValuesIn(netPRCs),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_SUITE_P(smoke_TestsROIAlign_max, ROIAlignLayerTest, ROIAlignCases_max, ROIAlignLayerTest::getTestCaseName);

View File

@ -0,0 +1,69 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "primitive.hpp"
#include <vector>
namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{
/// @brief ROIAlign is a pooling layer used over feature maps of
/// non-uniform input sizes and outputs a feature map of a fixed size.
struct roi_align : public primitive_base<roi_align> {
CLDNN_DECLARE_PRIMITIVE(roi_align)
/// @brief Pooling mode for the @ref roi_align
enum PoolingMode {
Max,
Avg
};
/// @brief Constructs roi_align primitive.
/// @param id This primitive id.
/// @param inputs Inputs data primitive ids.
/// @param pooled_h Height of the ROI output feature map.
/// @param pooled_w Width of the ROI output feature map.
/// @param sampling_ratio Number of bins over height and width to use to calculate each output feature map element.
/// @param spatial_scale multiplicative spatial scale factor to translate ROI coordinates
/// from their input spatial scale to the scale used when pooling.
/// @param mode Method to perform pooling to produce output feature map elements.
/// @param shrink_axis_mask Array of bits, that provide shrinks the dimensionality by 1, taking on the value at index begin[i].
roi_align(const primitive_id& id,
const std::vector<primitive_id>& inputs,
int pooled_h,
int pooled_w,
int sampling_ratio,
float spatial_scale,
PoolingMode mode,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, inputs, ext_prim_id, output_padding),
pooled_h {pooled_h},
pooled_w {pooled_w},
sampling_ratio {sampling_ratio},
spatial_scale {spatial_scale},
mode {mode}
{}
/// @brief Height of the ROI output feature map.
int pooled_h;
/// @brief Width of the ROI output feature map.
int pooled_w;
/// @brief Number of bins over height and width to use to calculate each output feature map element.
int sampling_ratio;
/// @brief multiplicative spatial scale factor to translate ROI coordinates
/// from their input spatial scale to the scale used when pooling.
float spatial_scale;
/// @brief Method to perform pooling to produce output feature map elements.
PoolingMode mode;
};
/// @}
/// @}
/// @}
} // namespace cldnn

View File

@ -21,6 +21,7 @@ enum class KernelType {
NORMALIZE,
POOLING,
ROI_POOLING,
ROI_ALIGN,
FULLY_CONNECTED,
ACTIVATION,
SOFT_MAX,

View File

@ -0,0 +1,90 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "roi_align_kernel_ref.h"
#include <kernel_selector_utils.h>
namespace kernel_selector {
ParamsKey ROIAlignKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();
k.EnablePoolType(PoolType::MAX);
k.EnablePoolType(PoolType::AVG);
return k;
}
namespace {
ROIAlignKernelRef::DispatchData SetDefault(const roi_align_params& params) {
ROIAlignKernelRef::DispatchData dispatchData;
// Determine global work sizes.
dispatchData.gws[0] = params.output.LogicalSize();
dispatchData.gws[1] = 1;
dispatchData.gws[2] = 1;
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
return dispatchData;
}
} // anonymous namespace
KernelsData ROIAlignKernelRef::GetKernelsData(const Params &params, const optional_params &options) const {
if (!Validate(params, options)) {
return {};
}
KernelData kernel_data = KernelData::Default<roi_align_params>(params);
roi_align_params &new_params = dynamic_cast<roi_align_params&>(*kernel_data.params.get());
auto dispatch_data = SetDefault(new_params);
auto entry_point = GetEntryPoint(kernelName, new_params.layerID, params, options);
auto roi_align_specific_jit = GetJitConstants(new_params);
auto jit = CreateJit(kernelName, roi_align_specific_jit, entry_point);
FillCLKernelData(kernel_data.kernels[0], dispatch_data, params.engineInfo,
kernelName, jit, entry_point);
kernel_data.kernels[0].params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1});
kernel_data.kernels[0].params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 2});
return {kernel_data};
}
float ROIAlignKernelRef::GetKernelsPriority(const Params &params, const optional_params &options) const {
return FORCE_PRIORITY_1;
}
bool ROIAlignKernelRef::Validate(const Params& p, const optional_params& o) const {
if (p.GetType() != KernelType::ROI_ALIGN || o.GetType() != KernelType::ROI_ALIGN) {
return false;
}
const roi_align_params &params = static_cast<const roi_align_params&>(p);
if (params.inputs.size() != 3)
return false;
if (params.output.Dimentions() > 4 || params.inputs[0].Dimentions() > 4 || params.inputs[1].Dimentions() > 2)
return false;
return true;
}
JitConstants ROIAlignKernelRef::GetJitConstants(const roi_align_params &params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);
jit.AddConstant(MakeJitConstant("SPATIAL_SCALE", params.spatial_scale));
jit.AddConstant(MakeJitConstant("SAMPLING_RATIO", params.sampling_ratio));
if (params.mode == PoolType::MAX)
jit.AddConstant(MakeJitConstant("MAX_POOL", true));
else if (params.mode == PoolType::AVG)
jit.AddConstant(MakeJitConstant("AVG_POOL", true));
return jit;
}
} // namespace kernel_selector

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <kernel_base_opencl.h>
namespace kernel_selector {
struct roi_align_params : public base_params {
roi_align_params() : base_params{KernelType::ROI_ALIGN} {}
int sampling_ratio = 0;
float spatial_scale = 1.f;
PoolType mode = PoolType::MAX;
ParamsKey GetParamsKey() const override {
auto k = base_params::GetParamsKey();
k.EnablePoolType(mode);
return k;
}
};
struct roi_align_optional_params : optional_params {
roi_align_optional_params() : optional_params{KernelType::ROI_ALIGN} {}
};
class ROIAlignKernelRef : public KernelBaseOpenCL {
public:
using KernelBaseOpenCL::KernelBaseOpenCL;
using DispatchData = CommonDispatchData;
ROIAlignKernelRef() : KernelBaseOpenCL{"roi_align_ref"} {}
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
bool Validate(const Params&, const optional_params&) const override;
protected:
JitConstants GetJitConstants(const roi_align_params& params) const;
};
} // namespace kernel_selector

View File

@ -0,0 +1,18 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "roi_align_kernel_selector.h"
#include "roi_align_kernel_ref.h"
namespace kernel_selector {
roi_align_kernel_selector::roi_align_kernel_selector() {
Attach<ROIAlignKernelRef>();
}
KernelsData roi_align_kernel_selector::GetBestKernels(const Params &params,
const optional_params &options) const {
return GetNaiveBestKernel(params, options, KernelType::ROI_ALIGN);
}
} // namespace kernel_selector

View File

@ -0,0 +1,20 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "kernel_selector.h"
namespace kernel_selector {
class roi_align_kernel_selector : public kernel_selector_base {
public:
static roi_align_kernel_selector& Instance() {
static roi_align_kernel_selector instance_;
return instance_;
}
roi_align_kernel_selector();
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
};
} // namespace kernel_selector

View File

@ -0,0 +1,114 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "include/batch_headers/common.cl"
#include "include/batch_headers/data_types.cl"
#define MAX(a,b) ((a) > (b) ? (a) : (b))
#define NUM_ROIS OUTPUT_BATCH_NUM
#define NUM_CHANNELS INPUT0_FEATURE_NUM
#define POOLED_WIDTH OUTPUT_SIZE_X
#define POOLED_HEIGHT OUTPUT_SIZE_Y
KERNEL(roi_align_ref)
(
const __global INPUT0_TYPE * src_data,
__global OUTPUT_TYPE * dst_data,
const __global INPUT1_TYPE * src_rois,
const __global INPUT2_TYPE * src_batches
)
{
const size_t i = get_global_id(0);
const uint x = i % POOLED_WIDTH;
const uint y = i / POOLED_WIDTH % POOLED_HEIGHT;
const uint c = i / POOLED_WIDTH / POOLED_HEIGHT % NUM_CHANNELS;
const uint r = i / POOLED_WIDTH / POOLED_HEIGHT / NUM_CHANNELS % NUM_ROIS;
const __global INPUT1_TYPE* roi_ptr = &src_rois[INPUT1_BATCH_PITCH * r];
// Get ROI`s corners
const INPUT1_TYPE x1 = *roi_ptr * (INPUT1_TYPE) SPATIAL_SCALE;
const INPUT1_TYPE y1 = roi_ptr[1] * (INPUT1_TYPE) SPATIAL_SCALE;
const INPUT1_TYPE x2 = roi_ptr[2] * (INPUT1_TYPE) SPATIAL_SCALE;
const INPUT1_TYPE y2 = roi_ptr[3] * (INPUT1_TYPE) SPATIAL_SCALE;
const INPUT1_TYPE roi_width = MAX(x2 - x1, (INPUT1_TYPE) 1.0);
const INPUT1_TYPE roi_height = MAX(y2 - y1, (INPUT1_TYPE) 1.0);
const INPUT1_TYPE bin_width = roi_width / POOLED_WIDTH;
const INPUT1_TYPE bin_height = roi_height / POOLED_HEIGHT;
const int sampling_ratio_x = SAMPLING_RATIO == 0 ? (int) ceil(bin_width) : SAMPLING_RATIO;
const int sampling_ratio_y = SAMPLING_RATIO == 0 ? (int) ceil(bin_height) : SAMPLING_RATIO;
const INPUT1_TYPE sample_distance_x = bin_width / (INPUT1_TYPE) sampling_ratio_x;
const INPUT1_TYPE sample_distance_y = bin_height / (INPUT1_TYPE) sampling_ratio_y;
const __global INPUT0_TYPE* data = src_data + INPUT0_OFFSET + r*INPUT0_BATCH_PITCH + INPUT0_FEATURE_PITCH*c;
OUTPUT_TYPE pooled_value = 0;
for (unsigned int y_sample_ind = 0; y_sample_ind < sampling_ratio_y; y_sample_ind++) {
INPUT1_TYPE sample_y = y1 + (INPUT1_TYPE) y * bin_height +
sample_distance_y * ((INPUT1_TYPE) y_sample_ind + (INPUT1_TYPE) 0.5f);
for (unsigned int x_sample_ind = 0; x_sample_ind < sampling_ratio_x; x_sample_ind++) {
INPUT1_TYPE sample_x = x1 + (INPUT1_TYPE) x * bin_width +
sample_distance_x * ((INPUT1_TYPE) x_sample_ind + (INPUT1_TYPE) 0.5f);
unsigned int sample_y_low = 0;
unsigned int sample_x_low = 0;
unsigned int sample_y_high = 0;
unsigned int sample_x_high = 0;
INPUT1_TYPE weight_left = (INPUT1_TYPE) 0.f;
INPUT1_TYPE weight_right = (INPUT1_TYPE) 0.f;
INPUT1_TYPE weight_top = (INPUT1_TYPE) 0.f;
INPUT1_TYPE weight_bottom = (INPUT1_TYPE) 0.f;
if (sample_x >= -1.0 || sample_x <= INPUT0_SIZE_X || sample_y >= -1.0 || sample_y <= INPUT0_SIZE_Y) {
sample_x = MAX(sample_x, (INPUT1_TYPE) 0.f);
sample_y = MAX(sample_y, (INPUT1_TYPE) 0.f);
sample_y_low = (unsigned int) sample_y;
sample_x_low = (unsigned int) sample_x;
if (sample_y_low >= INPUT0_SIZE_Y - 1) {
sample_y_high = sample_y_low = INPUT0_SIZE_Y - 1;
sample_y = (INPUT1_TYPE) sample_y_low;
} else {
sample_y_high = sample_y_low + 1;
}
if (sample_x_low >= INPUT0_SIZE_X - 1) {
sample_x_high = sample_x_low = INPUT0_SIZE_X - 1;
sample_x = (INPUT1_TYPE) sample_x_low;
} else {
sample_x_high = sample_x_low + 1;
}
// weight calculation for bilinear interpolation
weight_top = sample_y - (INPUT1_TYPE) sample_y_low;
weight_left = sample_x - (INPUT1_TYPE) sample_x_low;
weight_bottom = (INPUT1_TYPE) 1.f - weight_top;
weight_right = (INPUT1_TYPE) 1.f - weight_left;
}
const INPUT0_TYPE top_left = data[sample_y_low * INPUT0_Y_PITCH + sample_x_low * INPUT0_X_PITCH];
const INPUT0_TYPE top_right = data[sample_y_low * INPUT0_Y_PITCH + sample_x_high * INPUT0_X_PITCH];
const INPUT0_TYPE bottom_left = data[sample_y_high * INPUT0_Y_PITCH + sample_x_low * INPUT0_X_PITCH];
const INPUT0_TYPE bottom_right = data[sample_y_high * INPUT0_Y_PITCH + sample_x_high * INPUT0_X_PITCH];
const INPUT0_TYPE interpolated_value = weight_bottom * weight_right * top_left +
weight_bottom * weight_left * top_right +
weight_top * weight_right * bottom_left +
weight_top * weight_left * bottom_right;
#if MAX_POOL
pooled_value = MAX(pooled_value, interpolated_value);
#elif AVG_POOL
pooled_value += interpolated_value;
#endif
}
}
#if AVG_POOL
pooled_value /= sampling_ratio_x * sampling_ratio_x;
#endif
const uint output_offset = OUTPUT_OFFSET + x*OUTPUT_X_PITCH + y*OUTPUT_Y_PITCH + c*OUTPUT_FEATURE_PITCH + r*OUTPUT_BATCH_PITCH;
dst_data[output_offset] = ACTIVATION((OUTPUT_TYPE)pooled_value, ACTIVATION_PARAMS);
}

View File

@ -54,6 +54,7 @@ void register_implementations() {
REGISTER_OCL(reorg_yolo);
REGISTER_OCL(reshape);
REGISTER_OCL(reverse_sequence);
REGISTER_OCL(roi_align);
REGISTER_OCL(roi_pooling);
REGISTER_OCL(scale);
REGISTER_OCL(scatter_update);

View File

@ -45,6 +45,7 @@
#include "cldnn/primitives/reorg_yolo.hpp"
#include "cldnn/primitives/reshape.hpp"
#include "cldnn/primitives/reverse_sequence.hpp"
#include "cldnn/primitives/roi_align.hpp"
#include "cldnn/primitives/roi_pooling.hpp"
#include "cldnn/primitives/scale.hpp"
#include "cldnn/primitives/scatter_update.hpp"
@ -120,6 +121,7 @@ REGISTER_OCL(reorder);
REGISTER_OCL(reorg_yolo);
REGISTER_OCL(reshape);
REGISTER_OCL(reverse_sequence);
REGISTER_OCL(roi_align);
REGISTER_OCL(roi_pooling);
REGISTER_OCL(scale);
REGISTER_OCL(scatter_update);

View File

@ -0,0 +1,103 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "roi_align_inst.h"
#include "primitive_base.hpp"
#include "impls/implementation_map.hpp"
#include "cldnn/runtime/error_handler.hpp"
#include "kernel_selector_helper.h"
#include "roi_align/roi_align_kernel_selector.h"
#include "roi_align/roi_align_kernel_ref.h"
namespace cldnn {
namespace ocl {
namespace {
kernel_selector::pool_type from(roi_align::PoolingMode mode) {
switch (mode) {
case roi_align::PoolingMode::Max:
return kernel_selector::pool_type::MAX;
default:
case roi_align::PoolingMode::Avg:
return kernel_selector::pool_type::AVG;
}
}
} // namespace
struct roi_align_impl : typed_primitive_impl_ocl<roi_align> {
using parent = typed_primitive_impl_ocl<roi_align>;
using parent::parent;
std::unique_ptr<primitive_impl> clone() const override {
return make_unique<roi_align_impl>(*this);
}
protected:
kernel_arguments_data get_arguments(typed_primitive_inst<roi_align>& instance, int32_t) const override {
kernel_arguments_data args;
args.inputs = { instance.input_memory_ptr(), instance.rois_memory(), instance.batches_memory() };
args.output = instance.output_memory_ptr();
return args;
}
public:
static primitive_impl* create(const roi_align_node& arg) {
const auto& input_layout = arg.input().get_output_layout();
const auto& output_layout = arg.get_output_layout();
const auto& rois_layout = arg.input(1).get_output_layout();
const auto& batches_layout = arg.input(2).get_output_layout();
const auto& primitive = arg.get_primitive();
const auto padding_filling_value = output_layout.data_padding.filling_value();
CLDNN_ERROR_NOT_EQUAL(arg.id(),
"roi_align padding filling value",
padding_filling_value,
"padding mode",
0.0f,
"Unknown padding mode in roi_align.");
CLDNN_ERROR_NOT_PROPER_FORMAT(arg.id(),
"Input_layout.format",
input_layout.format.value,
"output_layout.format",
output_layout.format);
auto roi_align_params = get_default_params<kernel_selector::roi_align_params>(arg);
auto roi_align_optional_params =
get_default_optional_params<kernel_selector::roi_align_optional_params>(arg.get_program());
const auto roi_bfyx = convert_data_tensor(rois_layout);
roi_align_params.inputs.push_back(roi_bfyx.FlattenFeatureAndSpatials());
roi_align_params.inputs.push_back(convert_data_tensor(batches_layout));
roi_align_params.mode = from(primitive->mode);
roi_align_params.sampling_ratio = primitive->sampling_ratio;
roi_align_params.spatial_scale = primitive->spatial_scale;
auto& kernel_selector = kernel_selector::roi_align_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(roi_align_params, roi_align_optional_params);
CLDNN_ERROR_BOOL(arg.id(),
"Best_kernel.empty()",
best_kernels.empty(),
"Cannot find a proper kernel with this arguments");
auto roi_align = new roi_align_impl(arg, best_kernels[0]);
return roi_align;
}
};
namespace detail {
attach_roi_align_impl::attach_roi_align_impl() {
implementation_map<roi_align>::add(impl_types::ocl, roi_align_impl::create,
{
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::f32, format::bfyx),
});
}
} // namespace detail
} // namespace ocl
} // namespace cldnn

View File

@ -0,0 +1,40 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cldnn/primitives/roi_align.hpp>
#include "primitive_inst.h"
#include <cldnn/runtime/error_handler.hpp>
namespace cldnn {
template <>
struct typed_program_node<roi_align> : public typed_program_node_base<roi_align> {
using parent = typed_program_node_base<roi_align>;
public:
using parent::parent;
program_node& input(std::size_t index = 0) const { return get_dependency(index); }
};
using roi_align_node = typed_program_node<roi_align>;
template <>
class typed_primitive_inst<roi_align> : public typed_primitive_inst_base<roi_align> {
using parent = typed_primitive_inst_base<roi_align>;
public:
static layout calc_output_layout(roi_align_node const& node);
static std::string to_string(roi_align_node const& node);
public:
typed_primitive_inst(network& network, roi_align_node const& desc);
memory::ptr rois_memory() const { return dep_memory_ptr(1); }
memory::ptr batches_memory() const { return dep_memory_ptr(2); }
};
using roi_align_inst = typed_primitive_inst<roi_align>;
} // namespace cldnn

View File

@ -0,0 +1,46 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <roi_align_inst.h>
#include "primitive_type_base.h"
#include <sstream>
#include <json_object.h>
namespace cldnn {
primitive_type_id roi_align::type_id() {
static primitive_type_base<roi_align> instance;
return &instance;
}
roi_align_inst::typed_primitive_inst(network& network, roi_align_node const& node)
: parent(network, node) {}
layout roi_align_inst::calc_output_layout(roi_align_node const& node) {
auto primitive = node.get_primitive();
auto input_layout = node.input(0).get_output_layout();
auto rois_layout = node.input(0).get_output_layout();
auto num_rois = rois_layout.size.batch[0];
auto num_channels = input_layout.size.feature[0];
return layout(input_layout.data_type, format::bfyx, {num_rois, num_channels, primitive->pooled_h, primitive->pooled_w});
}
std::string roi_align_inst::to_string(roi_align_node const& node) {
auto node_info = node.desc_to_json();
json_composite roi_align_info;
roi_align_info.add("input id", node.input().id());
roi_align_info.add("rois id", node.get_dependency(1).id());
roi_align_info.add("batches id", node.get_dependency(2).id());
roi_align_info.add("pooled_h", node.get_primitive()->pooled_h);
roi_align_info.add("pooled_w", node.get_primitive()->pooled_w);
roi_align_info.add("sampling_ratio", node.get_primitive()->sampling_ratio);
roi_align_info.add("spatial_scale", node.get_primitive()->spatial_scale);
roi_align_info.add("mode", node.get_primitive()->mode == roi_align::PoolingMode::Max ? "Max" : "Avg");
node_info->add("roi_align info", roi_align_info);
std::stringstream primitive_description;
node_info->dump(primitive_description);
return primitive_description.str();
}
} // namespace cldnn

View File

@ -166,7 +166,7 @@ REGISTER_FACTORY(v3, ScatterNDUpdate);
// REGISTER_FACTORY(v3, Bucketize);
// REGISTER_FACTORY(v3, GRUCell);
// REGISTER_FACTORY(v3, NonZero);
// REGISTER_FACTORY(v3, ROIAlign);
REGISTER_FACTORY(v3, ROIAlign);
// REGISTER_FACTORY(v3, ReadValue);
// REGISTER_FACTORY(v3, ShapeOf);
// REGISTER_FACTORY(v3, TopK);

View File

@ -0,0 +1,42 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "cldnn_program.h"
#include "cldnn_common_utils.h"
#include "ngraph/op/roi_align.hpp"
#include "cldnn/primitives/roi_align.hpp"
#include <memory>
namespace CLDNNPlugin {
namespace {
cldnn::roi_align::PoolingMode from(ngraph::op::v3::ROIAlign::PoolingMode mode) {
switch (mode) {
case ngraph::op::v3::ROIAlign::PoolingMode::MAX:
return cldnn::roi_align::PoolingMode::Max;
case ngraph::op::v3::ROIAlign::PoolingMode::AVG:
default:
return cldnn::roi_align::PoolingMode::Avg;
}
}
void CreateROIAlignOp(Program& p, const std::shared_ptr<ngraph::op::v3::ROIAlign>& op) {
p.ValidateInputs(op, { 3 });
auto roi_align_prim = cldnn::roi_align(layer_type_name_ID(op),
p.GetInputPrimitiveIDs(op),
op->get_pooled_h(),
op->get_pooled_w(),
op->get_sampling_ratio(),
op->get_spatial_scale(),
from(op->get_mode()),
op->get_friendly_name());
p.AddPrimitive(roi_align_prim);
p.AddPrimitiveToProfiler(op);
}
} // anonymous namespace
REGISTER_FACTORY_IMPL(v3, ROIAlign);
} // namespace CLDNNPlugin