[GPU] ROIAlign-3 (#8991)
This commit is contained in:
parent
ad668d6ac6
commit
8ce22396b5
@ -0,0 +1,62 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <vector>
|
||||
|
||||
#include "single_layer_tests/roi_align.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
|
||||
// Network precisions exercised by the GPU ROIAlign smoke tests.
//
// FP16 is intentionally disabled. ROIAlign cannot be validated in half precision:
// on edge cases an ROI coordinate that is slightly below the nearest integer in
// fp32 ends up slightly above that integer in fp16, so the fp32 and fp16 runs
// produce completely different outputs. Real AI applications address this with
// precision-aware training.
const std::vector<InferenceEngine::Precision> netPRCs{
    InferenceEngine::Precision::FP32,
    // InferenceEngine::Precision::FP16,
};
|
||||
|
||||
const auto ROIAlignCases_average =
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(
|
||||
std::vector<std::vector<size_t>> {
|
||||
{ 3, 8, 16, 16 },
|
||||
{ 2, 1, 16, 16 },
|
||||
{ 2, 1, 8, 16 }}),
|
||||
::testing::Values(std::vector<size_t>{ 2, 4 }),
|
||||
::testing::Values(2),
|
||||
::testing::Values(2),
|
||||
::testing::ValuesIn(std::vector<float> { 1, 0.625 }),
|
||||
::testing::Values(2),
|
||||
::testing::Values("avg"),
|
||||
::testing::ValuesIn(netPRCs),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_TestsROIAlign_average, ROIAlignLayerTest, ROIAlignCases_average, ROIAlignLayerTest::getTestCaseName);
|
||||
|
||||
const auto ROIAlignCases_max =
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(
|
||||
std::vector<std::vector<size_t>> {
|
||||
{ 2, 8, 20, 20 },
|
||||
{ 2, 1, 20, 20 },
|
||||
{ 2, 1, 10, 20 }
|
||||
}),
|
||||
::testing::Values(std::vector<size_t>{ 2, 4 }),
|
||||
::testing::Values(2),
|
||||
::testing::Values(2),
|
||||
::testing::ValuesIn(std::vector<float> { 1, 0.625 }),
|
||||
::testing::Values(2),
|
||||
::testing::Values("max"),
|
||||
::testing::ValuesIn(netPRCs),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_TestsROIAlign_max, ROIAlignLayerTest, ROIAlignCases_max, ROIAlignLayerTest::getTestCaseName);
|
69
inference-engine/thirdparty/clDNN/api/cldnn/primitives/roi_align.hpp
vendored
Normal file
69
inference-engine/thirdparty/clDNN/api/cldnn/primitives/roi_align.hpp
vendored
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
#include "primitive.hpp"
|
||||
#include <vector>
|
||||
|
||||
namespace cldnn {
|
||||
/// @addtogroup cpp_api C++ API
|
||||
/// @{
|
||||
/// @addtogroup cpp_topology Network Topology
|
||||
/// @{
|
||||
/// @addtogroup cpp_primitives Primitives
|
||||
/// @{
|
||||
|
||||
/// @brief ROIAlign is a pooling layer used over feature maps of
|
||||
/// non-uniform input sizes and outputs a feature map of a fixed size.
|
||||
struct roi_align : public primitive_base<roi_align> {
|
||||
CLDNN_DECLARE_PRIMITIVE(roi_align)
|
||||
|
||||
/// @brief Pooling mode for the @ref roi_align
|
||||
enum PoolingMode {
|
||||
Max,
|
||||
Avg
|
||||
};
|
||||
|
||||
/// @brief Constructs roi_align primitive.
|
||||
/// @param id This primitive id.
|
||||
/// @param inputs Inputs data primitive ids.
|
||||
/// @param pooled_h Height of the ROI output feature map.
|
||||
/// @param pooled_w Width of the ROI output feature map.
|
||||
/// @param sampling_ratio Number of bins over height and width to use to calculate each output feature map element.
|
||||
/// @param spatial_scale multiplicative spatial scale factor to translate ROI coordinates
|
||||
/// from their input spatial scale to the scale used when pooling.
|
||||
/// @param mode Method to perform pooling to produce output feature map elements.
|
||||
/// @param shrink_axis_mask Array of bits, that provide shrinks the dimensionality by 1, taking on the value at index begin[i].
|
||||
roi_align(const primitive_id& id,
|
||||
const std::vector<primitive_id>& inputs,
|
||||
int pooled_h,
|
||||
int pooled_w,
|
||||
int sampling_ratio,
|
||||
float spatial_scale,
|
||||
PoolingMode mode,
|
||||
const primitive_id& ext_prim_id = "",
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, inputs, ext_prim_id, output_padding),
|
||||
pooled_h {pooled_h},
|
||||
pooled_w {pooled_w},
|
||||
sampling_ratio {sampling_ratio},
|
||||
spatial_scale {spatial_scale},
|
||||
mode {mode}
|
||||
{}
|
||||
|
||||
/// @brief Height of the ROI output feature map.
|
||||
int pooled_h;
|
||||
/// @brief Width of the ROI output feature map.
|
||||
int pooled_w;
|
||||
/// @brief Number of bins over height and width to use to calculate each output feature map element.
|
||||
int sampling_ratio;
|
||||
/// @brief multiplicative spatial scale factor to translate ROI coordinates
|
||||
/// from their input spatial scale to the scale used when pooling.
|
||||
float spatial_scale;
|
||||
/// @brief Method to perform pooling to produce output feature map elements.
|
||||
PoolingMode mode;
|
||||
};
|
||||
/// @}
|
||||
/// @}
|
||||
/// @}
|
||||
} // namespace cldnn
|
@ -21,6 +21,7 @@ enum class KernelType {
|
||||
NORMALIZE,
|
||||
POOLING,
|
||||
ROI_POOLING,
|
||||
ROI_ALIGN,
|
||||
FULLY_CONNECTED,
|
||||
ACTIVATION,
|
||||
SOFT_MAX,
|
||||
|
@ -0,0 +1,90 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "roi_align_kernel_ref.h"
|
||||
#include <kernel_selector_utils.h>
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
// Describes the parameter combinations this reference kernel accepts.
ParamsKey ROIAlignKernelRef::GetSupportedKey() const {
    ParamsKey key;

    // Input data types.
    key.EnableInputDataType(Datatype::F16);
    key.EnableInputDataType(Datatype::F32);
    key.EnableInputDataType(Datatype::INT32);

    // Output data types.
    key.EnableOutputDataType(Datatype::F16);
    key.EnableOutputDataType(Datatype::F32);

    // Only the plain bfyx layout is handled on both sides.
    key.EnableInputLayout(DataLayout::bfyx);
    key.EnableOutputLayout(DataLayout::bfyx);

    // Tensor features.
    key.EnableTensorOffset();
    key.EnableTensorPitches();
    key.EnableBatching();
    key.EnableDifferentTypes();

    // Both pooling modes are implemented by the kernel.
    key.EnablePoolType(PoolType::MAX);
    key.EnablePoolType(PoolType::AVG);

    return key;
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds the dispatch configuration: a one-dimensional global range with one
// work item per logical output element, and local sizes chosen by the common helper.
ROIAlignKernelRef::DispatchData SetDefault(const roi_align_params& params) {
    ROIAlignKernelRef::DispatchData dispatch;

    // Global work sizes: everything is flattened into dimension 0.
    dispatch.gws[0] = params.output.LogicalSize();
    dispatch.gws[1] = 1;
    dispatch.gws[2] = 1;

    dispatch.lws = GetOptimalLocalWorkGroupSizes(dispatch.gws, params.engineInfo);
    return dispatch;
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Produces the kernel description for the given parameters.
// Returns an empty KernelsData when Validate() rejects the parameters.
KernelsData ROIAlignKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
    if (!Validate(params, options)) {
        return {};
    }

    KernelData kernel_data = KernelData::Default<roi_align_params>(params);
    roi_align_params& new_params = dynamic_cast<roi_align_params&>(*kernel_data.params.get());

    auto dispatch_data = SetDefault(new_params);
    auto entry_point = GetEntryPoint(kernelName, new_params.layerID, params, options);
    auto roi_align_specific_jit = GetJitConstants(new_params);
    auto jit = CreateJit(kernelName, roi_align_specific_jit, entry_point);

    auto& kernel = kernel_data.kernels[0];
    FillCLKernelData(kernel, dispatch_data, params.engineInfo, kernelName, jit, entry_point);

    // Register the additional inputs 1 and 2 as kernel arguments.
    // NOTE(review): presumably FillCLKernelData already adds input 0 and the output,
    // and these correspond to the ROI and batch-index tensors - confirm.
    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1});
    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 2});

    return {kernel_data};
}
|
||||
|
||||
float ROIAlignKernelRef::GetKernelsPriority(const Params ¶ms, const optional_params &options) const {
|
||||
return FORCE_PRIORITY_1;
|
||||
}
|
||||
|
||||
bool ROIAlignKernelRef::Validate(const Params& p, const optional_params& o) const {
|
||||
if (p.GetType() != KernelType::ROI_ALIGN || o.GetType() != KernelType::ROI_ALIGN) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const roi_align_params ¶ms = static_cast<const roi_align_params&>(p);
|
||||
if (params.inputs.size() != 3)
|
||||
return false;
|
||||
|
||||
if (params.output.Dimentions() > 4 || params.inputs[0].Dimentions() > 4 || params.inputs[1].Dimentions() > 2)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Builds the JIT constants consumed by the kernel: the base constants plus
// SPATIAL_SCALE, SAMPLING_RATIO and exactly one of MAX_POOL / AVG_POOL,
// depending on the pooling mode.
JitConstants ROIAlignKernelRef::GetJitConstants(const roi_align_params& params) const {
    JitConstants jit = MakeBaseParamsJitConstants(params);
    jit.AddConstant(MakeJitConstant("SPATIAL_SCALE", params.spatial_scale));
    jit.AddConstant(MakeJitConstant("SAMPLING_RATIO", params.sampling_ratio));

    if (params.mode == PoolType::MAX) {
        jit.AddConstant(MakeJitConstant("MAX_POOL", true));
    } else if (params.mode == PoolType::AVG) {
        jit.AddConstant(MakeJitConstant("AVG_POOL", true));
    }

    return jit;
}
|
||||
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <kernel_base_opencl.h>
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
struct roi_align_params : public base_params {
|
||||
roi_align_params() : base_params{KernelType::ROI_ALIGN} {}
|
||||
|
||||
int sampling_ratio = 0;
|
||||
float spatial_scale = 1.f;
|
||||
PoolType mode = PoolType::MAX;
|
||||
|
||||
ParamsKey GetParamsKey() const override {
|
||||
auto k = base_params::GetParamsKey();
|
||||
k.EnablePoolType(mode);
|
||||
return k;
|
||||
}
|
||||
};
|
||||
|
||||
struct roi_align_optional_params : optional_params {
|
||||
roi_align_optional_params() : optional_params{KernelType::ROI_ALIGN} {}
|
||||
};
|
||||
|
||||
class ROIAlignKernelRef : public KernelBaseOpenCL {
|
||||
public:
|
||||
using KernelBaseOpenCL::KernelBaseOpenCL;
|
||||
|
||||
using DispatchData = CommonDispatchData;
|
||||
|
||||
ROIAlignKernelRef() : KernelBaseOpenCL{"roi_align_ref"} {}
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
bool Validate(const Params&, const optional_params&) const override;
|
||||
|
||||
protected:
|
||||
JitConstants GetJitConstants(const roi_align_params& params) const;
|
||||
};
|
||||
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,18 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "roi_align_kernel_selector.h"
|
||||
#include "roi_align_kernel_ref.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
// Registers the ROIAlign kernel implementations; only the reference kernel is attached.
roi_align_kernel_selector::roi_align_kernel_selector() {
    Attach<ROIAlignKernelRef>();
}
|
||||
|
||||
// Delegates to the generic naive selection restricted to ROI_ALIGN kernels.
KernelsData roi_align_kernel_selector::GetBestKernels(const Params& params,
                                                      const optional_params& options) const {
    return GetNaiveBestKernel(params, options, KernelType::ROI_ALIGN);
}
|
||||
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,20 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "kernel_selector.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
class roi_align_kernel_selector : public kernel_selector_base {
|
||||
public:
|
||||
static roi_align_kernel_selector& Instance() {
|
||||
static roi_align_kernel_selector instance_;
|
||||
return instance_;
|
||||
}
|
||||
|
||||
roi_align_kernel_selector();
|
||||
|
||||
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
114
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/roi_align_ref.cl
vendored
Normal file
114
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/roi_align_ref.cl
vendored
Normal file
@ -0,0 +1,114 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "include/batch_headers/common.cl"
|
||||
#include "include/batch_headers/data_types.cl"
|
||||
|
||||
#define MAX(a,b) ((a) > (b) ? (a) : (b))
|
||||
#define NUM_ROIS OUTPUT_BATCH_NUM
|
||||
#define NUM_CHANNELS INPUT0_FEATURE_NUM
|
||||
#define POOLED_WIDTH OUTPUT_SIZE_X
|
||||
#define POOLED_HEIGHT OUTPUT_SIZE_Y
|
||||
|
||||
KERNEL(roi_align_ref)
|
||||
(
|
||||
const __global INPUT0_TYPE * src_data,
|
||||
__global OUTPUT_TYPE * dst_data,
|
||||
const __global INPUT1_TYPE * src_rois,
|
||||
const __global INPUT2_TYPE * src_batches
|
||||
)
|
||||
{
|
||||
const size_t i = get_global_id(0);
|
||||
|
||||
const uint x = i % POOLED_WIDTH;
|
||||
const uint y = i / POOLED_WIDTH % POOLED_HEIGHT;
|
||||
const uint c = i / POOLED_WIDTH / POOLED_HEIGHT % NUM_CHANNELS;
|
||||
const uint r = i / POOLED_WIDTH / POOLED_HEIGHT / NUM_CHANNELS % NUM_ROIS;
|
||||
|
||||
const __global INPUT1_TYPE* roi_ptr = &src_rois[INPUT1_BATCH_PITCH * r];
|
||||
|
||||
// Get ROI`s corners
|
||||
const INPUT1_TYPE x1 = *roi_ptr * (INPUT1_TYPE) SPATIAL_SCALE;
|
||||
const INPUT1_TYPE y1 = roi_ptr[1] * (INPUT1_TYPE) SPATIAL_SCALE;
|
||||
const INPUT1_TYPE x2 = roi_ptr[2] * (INPUT1_TYPE) SPATIAL_SCALE;
|
||||
const INPUT1_TYPE y2 = roi_ptr[3] * (INPUT1_TYPE) SPATIAL_SCALE;
|
||||
|
||||
const INPUT1_TYPE roi_width = MAX(x2 - x1, (INPUT1_TYPE) 1.0);
|
||||
const INPUT1_TYPE roi_height = MAX(y2 - y1, (INPUT1_TYPE) 1.0);
|
||||
|
||||
const INPUT1_TYPE bin_width = roi_width / POOLED_WIDTH;
|
||||
const INPUT1_TYPE bin_height = roi_height / POOLED_HEIGHT;
|
||||
|
||||
const int sampling_ratio_x = SAMPLING_RATIO == 0 ? (int) ceil(bin_width) : SAMPLING_RATIO;
|
||||
const int sampling_ratio_y = SAMPLING_RATIO == 0 ? (int) ceil(bin_height) : SAMPLING_RATIO;
|
||||
|
||||
const INPUT1_TYPE sample_distance_x = bin_width / (INPUT1_TYPE) sampling_ratio_x;
|
||||
const INPUT1_TYPE sample_distance_y = bin_height / (INPUT1_TYPE) sampling_ratio_y;
|
||||
|
||||
const __global INPUT0_TYPE* data = src_data + INPUT0_OFFSET + r*INPUT0_BATCH_PITCH + INPUT0_FEATURE_PITCH*c;
|
||||
OUTPUT_TYPE pooled_value = 0;
|
||||
for (unsigned int y_sample_ind = 0; y_sample_ind < sampling_ratio_y; y_sample_ind++) {
|
||||
INPUT1_TYPE sample_y = y1 + (INPUT1_TYPE) y * bin_height +
|
||||
sample_distance_y * ((INPUT1_TYPE) y_sample_ind + (INPUT1_TYPE) 0.5f);
|
||||
for (unsigned int x_sample_ind = 0; x_sample_ind < sampling_ratio_x; x_sample_ind++) {
|
||||
INPUT1_TYPE sample_x = x1 + (INPUT1_TYPE) x * bin_width +
|
||||
sample_distance_x * ((INPUT1_TYPE) x_sample_ind + (INPUT1_TYPE) 0.5f);
|
||||
unsigned int sample_y_low = 0;
|
||||
unsigned int sample_x_low = 0;
|
||||
unsigned int sample_y_high = 0;
|
||||
unsigned int sample_x_high = 0;
|
||||
INPUT1_TYPE weight_left = (INPUT1_TYPE) 0.f;
|
||||
INPUT1_TYPE weight_right = (INPUT1_TYPE) 0.f;
|
||||
INPUT1_TYPE weight_top = (INPUT1_TYPE) 0.f;
|
||||
INPUT1_TYPE weight_bottom = (INPUT1_TYPE) 0.f;
|
||||
if (sample_x >= -1.0 || sample_x <= INPUT0_SIZE_X || sample_y >= -1.0 || sample_y <= INPUT0_SIZE_Y) {
|
||||
sample_x = MAX(sample_x, (INPUT1_TYPE) 0.f);
|
||||
sample_y = MAX(sample_y, (INPUT1_TYPE) 0.f);
|
||||
|
||||
sample_y_low = (unsigned int) sample_y;
|
||||
sample_x_low = (unsigned int) sample_x;
|
||||
|
||||
if (sample_y_low >= INPUT0_SIZE_Y - 1) {
|
||||
sample_y_high = sample_y_low = INPUT0_SIZE_Y - 1;
|
||||
sample_y = (INPUT1_TYPE) sample_y_low;
|
||||
} else {
|
||||
sample_y_high = sample_y_low + 1;
|
||||
}
|
||||
|
||||
if (sample_x_low >= INPUT0_SIZE_X - 1) {
|
||||
sample_x_high = sample_x_low = INPUT0_SIZE_X - 1;
|
||||
sample_x = (INPUT1_TYPE) sample_x_low;
|
||||
} else {
|
||||
sample_x_high = sample_x_low + 1;
|
||||
}
|
||||
|
||||
// weight calculation for bilinear interpolation
|
||||
weight_top = sample_y - (INPUT1_TYPE) sample_y_low;
|
||||
weight_left = sample_x - (INPUT1_TYPE) sample_x_low;
|
||||
weight_bottom = (INPUT1_TYPE) 1.f - weight_top;
|
||||
weight_right = (INPUT1_TYPE) 1.f - weight_left;
|
||||
}
|
||||
const INPUT0_TYPE top_left = data[sample_y_low * INPUT0_Y_PITCH + sample_x_low * INPUT0_X_PITCH];
|
||||
const INPUT0_TYPE top_right = data[sample_y_low * INPUT0_Y_PITCH + sample_x_high * INPUT0_X_PITCH];
|
||||
const INPUT0_TYPE bottom_left = data[sample_y_high * INPUT0_Y_PITCH + sample_x_low * INPUT0_X_PITCH];
|
||||
const INPUT0_TYPE bottom_right = data[sample_y_high * INPUT0_Y_PITCH + sample_x_high * INPUT0_X_PITCH];
|
||||
|
||||
const INPUT0_TYPE interpolated_value = weight_bottom * weight_right * top_left +
|
||||
weight_bottom * weight_left * top_right +
|
||||
weight_top * weight_right * bottom_left +
|
||||
weight_top * weight_left * bottom_right;
|
||||
#if MAX_POOL
|
||||
pooled_value = MAX(pooled_value, interpolated_value);
|
||||
#elif AVG_POOL
|
||||
pooled_value += interpolated_value;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#if AVG_POOL
|
||||
pooled_value /= sampling_ratio_x * sampling_ratio_x;
|
||||
#endif
|
||||
const uint output_offset = OUTPUT_OFFSET + x*OUTPUT_X_PITCH + y*OUTPUT_Y_PITCH + c*OUTPUT_FEATURE_PITCH + r*OUTPUT_BATCH_PITCH;
|
||||
dst_data[output_offset] = ACTIVATION((OUTPUT_TYPE)pooled_value, ACTIVATION_PARAMS);
|
||||
}
|
||||
|
@ -54,6 +54,7 @@ void register_implementations() {
|
||||
REGISTER_OCL(reorg_yolo);
|
||||
REGISTER_OCL(reshape);
|
||||
REGISTER_OCL(reverse_sequence);
|
||||
REGISTER_OCL(roi_align);
|
||||
REGISTER_OCL(roi_pooling);
|
||||
REGISTER_OCL(scale);
|
||||
REGISTER_OCL(scatter_update);
|
||||
|
@ -45,6 +45,7 @@
|
||||
#include "cldnn/primitives/reorg_yolo.hpp"
|
||||
#include "cldnn/primitives/reshape.hpp"
|
||||
#include "cldnn/primitives/reverse_sequence.hpp"
|
||||
#include "cldnn/primitives/roi_align.hpp"
|
||||
#include "cldnn/primitives/roi_pooling.hpp"
|
||||
#include "cldnn/primitives/scale.hpp"
|
||||
#include "cldnn/primitives/scatter_update.hpp"
|
||||
@ -120,6 +121,7 @@ REGISTER_OCL(reorder);
|
||||
REGISTER_OCL(reorg_yolo);
|
||||
REGISTER_OCL(reshape);
|
||||
REGISTER_OCL(reverse_sequence);
|
||||
REGISTER_OCL(roi_align);
|
||||
REGISTER_OCL(roi_pooling);
|
||||
REGISTER_OCL(scale);
|
||||
REGISTER_OCL(scatter_update);
|
||||
|
103
inference-engine/thirdparty/clDNN/src/impls/ocl/roi_align.cpp
vendored
Normal file
103
inference-engine/thirdparty/clDNN/src/impls/ocl/roi_align.cpp
vendored
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "roi_align_inst.h"
|
||||
#include "primitive_base.hpp"
|
||||
#include "impls/implementation_map.hpp"
|
||||
#include "cldnn/runtime/error_handler.hpp"
|
||||
#include "kernel_selector_helper.h"
|
||||
#include "roi_align/roi_align_kernel_selector.h"
|
||||
#include "roi_align/roi_align_kernel_ref.h"
|
||||
|
||||
namespace cldnn {
|
||||
namespace ocl {
|
||||
|
||||
namespace {
|
||||
// Maps the primitive's pooling mode onto the kernel-selector pool type.
// Max maps to MAX; Avg and any other value map to AVG.
kernel_selector::pool_type from(roi_align::PoolingMode mode) {
    if (mode == roi_align::PoolingMode::Max) {
        return kernel_selector::pool_type::MAX;
    }
    return kernel_selector::pool_type::AVG;
}
|
||||
} // namespace
|
||||
|
||||
// OpenCL implementation of the roi_align primitive.
struct roi_align_impl : typed_primitive_impl_ocl<roi_align> {
    using parent = typed_primitive_impl_ocl<roi_align>;
    using parent::parent;

    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<roi_align_impl>(*this);
    }

protected:
    // Kernel arguments: inputs are passed in the order
    // (feature map, ROIs, batch indices), followed by the output.
    kernel_arguments_data get_arguments(typed_primitive_inst<roi_align>& instance, int32_t) const override {
        kernel_arguments_data kernel_args;
        kernel_args.inputs = { instance.input_memory_ptr(), instance.rois_memory(), instance.batches_memory() };
        kernel_args.output = instance.output_memory_ptr();
        return kernel_args;
    }

public:
    // Factory: validates the node, fills the kernel-selector parameters, picks
    // the best kernel and returns a newly allocated implementation.
    static primitive_impl* create(const roi_align_node& arg) {
        const auto& input_layout = arg.input().get_output_layout();
        const auto& output_layout = arg.get_output_layout();
        const auto& rois_layout = arg.input(1).get_output_layout();
        const auto& batches_layout = arg.input(2).get_output_layout();
        const auto& primitive = arg.get_primitive();

        // Only a zero padding fill value is accepted.
        const auto padding_filling_value = output_layout.data_padding.filling_value();
        CLDNN_ERROR_NOT_EQUAL(arg.id(),
                              "roi_align padding filling value",
                              padding_filling_value,
                              "padding mode",
                              0.0f,
                              "Unknown padding mode in roi_align.");
        CLDNN_ERROR_NOT_PROPER_FORMAT(arg.id(),
                                      "Input_layout.format",
                                      input_layout.format.value,
                                      "output_layout.format",
                                      output_layout.format);

        auto params = get_default_params<kernel_selector::roi_align_params>(arg);
        auto optional_params =
            get_default_optional_params<kernel_selector::roi_align_optional_params>(arg.get_program());

        // Extra inputs: the ROI tensor (feature and spatial dims flattened)
        // followed by the batch-index tensor.
        params.inputs.push_back(convert_data_tensor(rois_layout).FlattenFeatureAndSpatials());
        params.inputs.push_back(convert_data_tensor(batches_layout));

        params.mode = from(primitive->mode);
        params.sampling_ratio = primitive->sampling_ratio;
        params.spatial_scale = primitive->spatial_scale;

        auto& selector = kernel_selector::roi_align_kernel_selector::Instance();
        auto best_kernels = selector.GetBestKernels(params, optional_params);

        CLDNN_ERROR_BOOL(arg.id(),
                         "Best_kernel.empty()",
                         best_kernels.empty(),
                         "Cannot find a proper kernel with this arguments");

        return new roi_align_impl(arg, best_kernels[0]);
    }
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Registers the OpenCL roi_align implementation for f16 and f32 data in bfyx format.
attach_roi_align_impl::attach_roi_align_impl() {
    implementation_map<roi_align>::add(
        impl_types::ocl,
        roi_align_impl::create,
        {
            std::make_tuple(data_types::f16, format::bfyx),
            std::make_tuple(data_types::f32, format::bfyx),
        });
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ocl
|
||||
} // namespace cldnn
|
40
inference-engine/thirdparty/clDNN/src/include/roi_align_inst.h
vendored
Normal file
40
inference-engine/thirdparty/clDNN/src/include/roi_align_inst.h
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cldnn/primitives/roi_align.hpp>
|
||||
#include "primitive_inst.h"
|
||||
#include <cldnn/runtime/error_handler.hpp>
|
||||
|
||||
namespace cldnn {
|
||||
|
||||
template <>
|
||||
struct typed_program_node<roi_align> : public typed_program_node_base<roi_align> {
|
||||
using parent = typed_program_node_base<roi_align>;
|
||||
|
||||
public:
|
||||
using parent::parent;
|
||||
|
||||
program_node& input(std::size_t index = 0) const { return get_dependency(index); }
|
||||
};
|
||||
|
||||
using roi_align_node = typed_program_node<roi_align>;
|
||||
|
||||
template <>
|
||||
class typed_primitive_inst<roi_align> : public typed_primitive_inst_base<roi_align> {
|
||||
using parent = typed_primitive_inst_base<roi_align>;
|
||||
|
||||
public:
|
||||
static layout calc_output_layout(roi_align_node const& node);
|
||||
static std::string to_string(roi_align_node const& node);
|
||||
|
||||
public:
|
||||
typed_primitive_inst(network& network, roi_align_node const& desc);
|
||||
memory::ptr rois_memory() const { return dep_memory_ptr(1); }
|
||||
memory::ptr batches_memory() const { return dep_memory_ptr(2); }
|
||||
};
|
||||
|
||||
using roi_align_inst = typed_primitive_inst<roi_align>;
|
||||
|
||||
} // namespace cldnn
|
46
inference-engine/thirdparty/clDNN/src/roi_align.cpp
vendored
Normal file
46
inference-engine/thirdparty/clDNN/src/roi_align.cpp
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <roi_align_inst.h>
|
||||
#include "primitive_type_base.h"
|
||||
#include <sstream>
|
||||
#include <json_object.h>
|
||||
|
||||
namespace cldnn {
|
||||
|
||||
// Returns the address of the function-local static type descriptor for roi_align.
primitive_type_id roi_align::type_id() {
    static primitive_type_base<roi_align> type_instance;
    return &type_instance;
}
|
||||
|
||||
// Forwards construction to the base primitive instance; no extra state is set up.
roi_align_inst::typed_primitive_inst(network& network, roi_align_node const& node)
    : parent(network, node) {
}
|
||||
|
||||
// Computes the output layout: data type of the feature-map input, bfyx format,
// batch = number of ROIs, feature = number of input channels, and the pooled
// size taken from the primitive.
layout roi_align_inst::calc_output_layout(roi_align_node const& node) {
    auto primitive = node.get_primitive();
    auto input_layout = node.input(0).get_output_layout();
    // The ROI tensor is dependency 1; dependency 0 is the feature map.
    auto rois_layout = node.input(1).get_output_layout();
    auto num_rois = rois_layout.size.batch[0];
    auto num_channels = input_layout.size.feature[0];
    // NOTE(review): confirm the spatial argument order expected by the tensor
    // constructor (pooled_h is passed before pooled_w here).
    return layout(input_layout.data_type, format::bfyx, {num_rois, num_channels, primitive->pooled_h, primitive->pooled_w});
}
|
||||
|
||||
// Serializes the node description, extended with the roi_align attributes, to a string.
std::string roi_align_inst::to_string(roi_align_node const& node) {
    const auto prim = node.get_primitive();

    json_composite roi_align_info;
    roi_align_info.add("input id", node.input().id());
    roi_align_info.add("rois id", node.get_dependency(1).id());
    roi_align_info.add("batches id", node.get_dependency(2).id());
    roi_align_info.add("pooled_h", prim->pooled_h);
    roi_align_info.add("pooled_w", prim->pooled_w);
    roi_align_info.add("sampling_ratio", prim->sampling_ratio);
    roi_align_info.add("spatial_scale", prim->spatial_scale);
    roi_align_info.add("mode", prim->mode == roi_align::PoolingMode::Max ? "Max" : "Avg");

    auto node_info = node.desc_to_json();
    node_info->add("roi_align info", roi_align_info);

    std::stringstream out;
    node_info->dump(out);
    return out.str();
}
|
||||
|
||||
} // namespace cldnn
|
@ -166,7 +166,7 @@ REGISTER_FACTORY(v3, ScatterNDUpdate);
|
||||
// REGISTER_FACTORY(v3, Bucketize);
|
||||
// REGISTER_FACTORY(v3, GRUCell);
|
||||
// REGISTER_FACTORY(v3, NonZero);
|
||||
// REGISTER_FACTORY(v3, ROIAlign);
|
||||
REGISTER_FACTORY(v3, ROIAlign);
|
||||
// REGISTER_FACTORY(v3, ReadValue);
|
||||
// REGISTER_FACTORY(v3, ShapeOf);
|
||||
// REGISTER_FACTORY(v3, TopK);
|
||||
|
42
src/plugins/intel_gpu/src/plugin/ops/roi_align.cpp
Normal file
42
src/plugins/intel_gpu/src/plugin/ops/roi_align.cpp
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "cldnn_program.h"
|
||||
#include "cldnn_common_utils.h"
|
||||
#include "ngraph/op/roi_align.hpp"
|
||||
#include "cldnn/primitives/roi_align.hpp"
|
||||
#include <memory>
|
||||
|
||||
namespace CLDNNPlugin {
|
||||
|
||||
namespace {
|
||||
|
||||
// Translates the ngraph pooling mode into the clDNN one.
// MAX maps to Max; AVG and any other value map to Avg.
cldnn::roi_align::PoolingMode from(ngraph::op::v3::ROIAlign::PoolingMode mode) {
    if (mode == ngraph::op::v3::ROIAlign::PoolingMode::MAX) {
        return cldnn::roi_align::PoolingMode::Max;
    }
    return cldnn::roi_align::PoolingMode::Avg;
}
|
||||
|
||||
// Converts an ngraph v3 ROIAlign operation into a clDNN roi_align primitive and
// adds it to the program. The operation must have exactly three inputs.
void CreateROIAlignOp(Program& p, const std::shared_ptr<ngraph::op::v3::ROIAlign>& op) {
    p.ValidateInputs(op, { 3 });

    cldnn::roi_align roi_align_prim(layer_type_name_ID(op),
                                    p.GetInputPrimitiveIDs(op),
                                    op->get_pooled_h(),
                                    op->get_pooled_w(),
                                    op->get_sampling_ratio(),
                                    op->get_spatial_scale(),
                                    from(op->get_mode()),
                                    op->get_friendly_name());

    p.AddPrimitive(roi_align_prim);
    p.AddPrimitiveToProfiler(op);
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
REGISTER_FACTORY_IMPL(v3, ROIAlign);
|
||||
|
||||
} // namespace CLDNNPlugin
|
Loading…
Reference in New Issue
Block a user