[GPU] Implement ExperimentalDetectronDetectionOutput operation (#11772)

* ExperimentalDetectronDetectionOutput: refine sorting criteria for NMS stage

This is to ensure the operation produces stable predictable results across
the possible sorting algorithm implementaions.
This property is useful for the operation testing.

* [GPU] Implement ExperimentalDetectronDetectionOutput operation

* [GPU] ExperimentalDetectronDetectionOutput: use vector types and operations in kernel

* Reformat changed files to make clang format checker happy

* [GPU] ExperimentalDetectronDetectionOutput: add another test case to the unit test

* [GPU] ExperimentalDetectronDetectionOutput: Add f16 test

* ExperimentalDetectronDetectionOutput: single-layer test: use all three outputs

* [GPU] ExperimentalDetectronDetectionOutput: increase single layer test coverage

More attribute permutations were added.
This commit is contained in:
opoluektov-lohika 2022-06-27 17:11:03 +03:00 committed by GitHub
parent 95a297ed68
commit 8a21e4e062
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1491 additions and 55 deletions

View File

@ -203,7 +203,7 @@ void nms_cf(const float* conf_data,
template <typename T>
bool SortScorePairDescend(const std::pair<float, T>& pair1, const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
return (pair1.first > pair2.first) || ((pair1.first == pair2.first) && (pair1.second.second < pair2.second.second));
}
} // namespace

View File

@ -118,9 +118,9 @@ static void refine_boxes(const float* boxes,
}
}
template <typename T>
static bool SortScorePairDescend(const std::pair<float, T>& pair1, const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
static bool SortScorePairDescend(const std::pair<float, std::pair<int, int>>& pair1,
const std::pair<float, std::pair<int, int>>& pair2) {
return (pair1.first > pair2.first) || ((pair1.first == pair2.first) && (pair1.second.second < pair2.second.second));
}
struct ConfidenceComparator {
@ -362,7 +362,7 @@ void ExperimentalDetectronDetectionOutput::execute(dnnl::stream strm) {
std::partial_sort(conf_index_class_map.begin(),
conf_index_class_map.begin() + max_detections_per_image_,
conf_index_class_map.end(),
SortScorePairDescend<std::pair<int, int>>);
SortScorePairDescend);
conf_index_class_map.resize(max_detections_per_image_);
total_detections_num = max_detections_per_image_;
}

View File

@ -3,7 +3,7 @@
//
#ifndef REGISTER_FACTORY
#error "REGISTER_FACTORY is not defined"
# error "REGISTER_FACTORY is not defined"
#endif
// ------------------------------ Supported v0 ops ------------------------------ //
@ -190,7 +190,7 @@ REGISTER_FACTORY(v4, Swish);
REGISTER_FACTORY(v5, HSigmoid);
REGISTER_FACTORY(v5, LogSoftmax);
REGISTER_FACTORY(v5, LSTMSequence);
//REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion
// REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion
REGISTER_FACTORY(v5, Round);
REGISTER_FACTORY(v5, GatherND);
REGISTER_FACTORY(v5, Loop);
@ -208,6 +208,7 @@ REGISTER_FACTORY(v6, ExperimentalDetectronPriorGridGenerator);
REGISTER_FACTORY(v6, ExperimentalDetectronROIFeatureExtractor);
REGISTER_FACTORY(v6, ExperimentalDetectronTopKROIs)
REGISTER_FACTORY(v6, ExperimentalDetectronGenerateProposalsSingleImage);
REGISTER_FACTORY(v6, ExperimentalDetectronDetectionOutput);
// ------------------------------ Supported v7 ops ------------------------------ //
REGISTER_FACTORY(v7, DFT);

View File

@ -0,0 +1,98 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <utility>
#include <vector>
#include "primitive.hpp"
namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{
/// @brief experimental detectron detection output
struct experimental_detectron_detection_output : public primitive_base<experimental_detectron_detection_output> {
CLDNN_DECLARE_PRIMITIVE(experimental_detectron_detection_output)
/// @brief Constructs experimental_detectron_detection_output primitive
/// @param id This primitive id
/// @param input_rois input rois
/// @param input_deltas input deltas
/// @param input_scores input scores
/// @param input_im_info image info
/// @param output_classes ROI scores
/// @param output_scores minimum box width and height
/// @param score_threshold a threshold to consider only detections whose score are larger than the threshold
/// @param nms_threshold a threshold to be used in the NMS stage
/// @param num_classes the number of detected classes
/// @param post_nms_count the maximum number of detections per class
/// @param max_detections_per_image the maximum number of detections per image
/// @param class_agnostic_box_regression specifies whether to delete background classes or not
/// @param max_delta_log_wh the maximum delta of logarithms for width and height
/// @param deltas_weights the weights for bounding boxes sizes deltas
experimental_detectron_detection_output(const primitive_id& id,
const primitive_id& input_rois,
const primitive_id& input_deltas,
const primitive_id& input_scores,
const primitive_id& input_im_info,
const primitive_id& output_classes,
const primitive_id& output_scores,
float score_threshold,
float nms_threshold,
int num_classes,
int post_nms_count,
int max_detections_per_image,
bool class_agnostic_box_regression,
float max_delta_log_wh,
std::vector<float> deltas_weights,
const primitive_id& ext_prim_id = "",
const padding& output_padding = {})
: primitive_base{id,
{input_rois, input_deltas, input_scores, input_im_info, output_classes, output_scores},
ext_prim_id,
output_padding},
output_classes{output_classes},
output_scores{output_scores},
score_threshold{score_threshold},
nms_threshold{nms_threshold},
num_classes{num_classes},
post_nms_count{post_nms_count},
class_agnostic_box_regression{class_agnostic_box_regression},
max_detections_per_image{max_detections_per_image},
max_delta_log_wh{max_delta_log_wh},
deltas_weights{std::move(deltas_weights)} {}
primitive_id output_classes;
primitive_id output_scores;
float score_threshold;
float nms_threshold;
int num_classes;
int post_nms_count;
int max_detections_per_image;
bool class_agnostic_box_regression;
float max_delta_log_wh;
std::vector<float> deltas_weights;
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;
if (!output_classes.empty())
ret.emplace_back(output_classes);
if (!output_scores.empty())
ret.emplace_back(output_scores);
return ret;
}
};
/// @}
/// @}
/// @}
} // namespace cldnn

View File

@ -0,0 +1,49 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include "experimental_detectron_detection_output_inst.hpp"
#include "intel_gpu/runtime/error_handler.hpp"
#include "json_object.h"
#include "primitive_type_base.h"
namespace cldnn {
primitive_type_id experimental_detectron_detection_output::type_id() {
static primitive_type_base<experimental_detectron_detection_output> instance;
return &instance;
}
layout experimental_detectron_detection_output_inst::calc_output_layout(
const experimental_detectron_detection_output_node& node) {
const layout data_layout = node.input().get_output_layout();
auto desc = node.get_primitive();
return layout(data_layout.data_type, format::bfyx, {static_cast<int>(desc->max_detections_per_image), 4, 1, 1});
}
std::string experimental_detectron_detection_output_inst::to_string(
const experimental_detectron_detection_output_node& node) {
auto desc = node.get_primitive();
std::stringstream primitive_description;
json_composite ed_info;
ed_info.add("score_threshold", desc->score_threshold);
ed_info.add("nms_threshold", desc->nms_threshold);
ed_info.add("score_threshold", desc->score_threshold);
ed_info.add("max_delta_log_wh", desc->max_delta_log_wh);
ed_info.add("num_classes", desc->num_classes);
ed_info.add("post_nms_count", desc->post_nms_count);
ed_info.add("max_detections_per_image", desc->max_detections_per_image);
ed_info.add("class_agnostic_box_regression", desc->class_agnostic_box_regression);
ed_info.add("deltas_weights", desc->deltas_weights);
auto node_info = node.desc_to_json();
node_info->add("experimental_detectron_detection_output_info", ed_info);
node_info->dump(primitive_description);
return primitive_description.str();
}
} // namespace cldnn

View File

@ -0,0 +1,80 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "eddo/experimental_detectron_detection_output_kernel_ref.h"
#include "eddo/experimental_detectron_detection_output_kernel_selector.h"
#include "experimental_detectron_detection_output_inst.hpp"
#include "impls/implementation_map.hpp"
#include "kernel_selector_helper.h"
#include "primitive_base.hpp"
namespace cldnn {
namespace ocl {
struct experimental_detectron_detection_output_impl
: public typed_primitive_impl_ocl<experimental_detectron_detection_output> {
using parent = typed_primitive_impl_ocl<experimental_detectron_detection_output>;
using parent::parent;
std::unique_ptr<primitive_impl> clone() const override {
return make_unique<experimental_detectron_detection_output_impl>(*this);
}
protected:
kernel_arguments_data get_arguments(typed_primitive_inst<experimental_detectron_detection_output>& instance,
int32_t unused) const override {
kernel_arguments_data args = parent::get_arguments(instance, unused);
args.inputs.push_back(instance.output_classes_memory());
args.inputs.push_back(instance.output_scores_memory());
return args;
}
public:
static primitive_impl* create(const experimental_detectron_detection_output_node& arg) {
auto params = get_default_params<kernel_selector::experimental_detectron_detection_output_params>(arg);
auto optional_params =
get_default_optional_params<kernel_selector::experimental_detectron_detection_output_optional_params>(
arg.get_program());
const auto& primitive = arg.get_primitive();
params.score_threshold = primitive->score_threshold;
params.nms_threshold = primitive->nms_threshold;
params.max_delta_log_wh = primitive->max_delta_log_wh;
params.num_classes = primitive->num_classes;
params.post_nms_count = primitive->post_nms_count;
params.max_detections_per_image = primitive->max_detections_per_image;
params.class_agnostic_box_regression = primitive->class_agnostic_box_regression;
params.deltas_weights = primitive->deltas_weights;
params.inputs.push_back(convert_data_tensor(arg.deltas().get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.scores().get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.image_size_info().get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.output_classes_node().get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.output_scores_node().get_output_layout()));
const auto& kernel_selector =
kernel_selector::experimental_detectron_detection_output_kernel_selector::Instance();
const auto best_kernels = kernel_selector.GetBestKernels(params, optional_params);
CLDNN_ERROR_BOOL(arg.id(),
"best_kernels.empty()",
best_kernels.empty(),
"Cannot find a proper kernel with this arguments");
return new experimental_detectron_detection_output_impl(arg, best_kernels[0]);
}
};
namespace detail {
attach_experimental_detectron_detection_output_impl::attach_experimental_detectron_detection_output_impl() {
implementation_map<experimental_detectron_detection_output>::add(
impl_types::ocl,
experimental_detectron_detection_output_impl::create,
{std::make_tuple(data_types::f16, format::bfyx), std::make_tuple(data_types::f32, format::bfyx)});
}
} // namespace detail
} // namespace ocl
} // namespace cldnn

View File

@ -8,8 +8,7 @@
namespace cldnn {
namespace ocl {
#define REGISTER_OCL(prim) \
static detail::attach_##prim##_impl attach_##prim
#define REGISTER_OCL(prim) static detail::attach_##prim##_impl attach_##prim
void register_implementations() {
REGISTER_OCL(activation);
@ -31,6 +30,7 @@ void register_implementations() {
REGISTER_OCL(detection_output);
REGISTER_OCL(dft);
REGISTER_OCL(batch_to_space);
REGISTER_OCL(experimental_detectron_detection_output);
REGISTER_OCL(experimental_detectron_generate_proposals_single_image);
REGISTER_OCL(experimental_detectron_prior_grid_generator);
REGISTER_OCL(experimental_detectron_roi_feature_extractor);

View File

@ -5,6 +5,7 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "generic_layer.hpp"
#include "intel_gpu/primitives/activation.hpp"
#include "intel_gpu/primitives/arg_max_min.hpp"
#include "intel_gpu/primitives/average_unpooling.hpp"
@ -14,8 +15,10 @@
#include "intel_gpu/primitives/broadcast.hpp"
#include "intel_gpu/primitives/bucketize.hpp"
#include "intel_gpu/primitives/concatenation.hpp"
#include "intel_gpu/primitives/convert_color.hpp"
#include "intel_gpu/primitives/convolution.hpp"
#include "intel_gpu/primitives/crop.hpp"
#include "intel_gpu/primitives/ctc_greedy_decoder.hpp"
#include "intel_gpu/primitives/custom_gpu_primitive.hpp"
#include "intel_gpu/primitives/deconvolution.hpp"
#include "intel_gpu/primitives/depth_to_space.hpp"
@ -26,12 +29,16 @@
#include "intel_gpu/primitives/experimental_detectron_topk_rois.hpp"
#include "intel_gpu/primitives/fully_connected.hpp"
#include "intel_gpu/primitives/gather.hpp"
#include "intel_gpu/primitives/gather_nd.hpp"
#include "intel_gpu/primitives/gather_elements.hpp"
#include "intel_gpu/primitives/gather_nd.hpp"
#include "intel_gpu/primitives/gather_tree.hpp"
#include "intel_gpu/primitives/gemm.hpp"
#include "intel_gpu/primitives/grn.hpp"
#include "intel_gpu/primitives/lrn.hpp"
#include "intel_gpu/primitives/lstm.hpp"
#include "intel_gpu/primitives/lstm_dynamic.hpp"
#include "intel_gpu/primitives/lstm_dynamic_input.hpp"
#include "intel_gpu/primitives/lstm_dynamic_timeloop.hpp"
#include "intel_gpu/primitives/max_unpooling.hpp"
#include "intel_gpu/primitives/mutable_data.hpp"
#include "intel_gpu/primitives/mvn.hpp"
@ -48,15 +55,16 @@
#include "intel_gpu/primitives/region_yolo.hpp"
#include "intel_gpu/primitives/reorder.hpp"
#include "intel_gpu/primitives/reorg_yolo.hpp"
#include "intel_gpu/primitives/resample.hpp"
#include "intel_gpu/primitives/reshape.hpp"
#include "intel_gpu/primitives/reverse_sequence.hpp"
#include "intel_gpu/primitives/roi_align.hpp"
#include "intel_gpu/primitives/roi_pooling.hpp"
#include "intel_gpu/primitives/roll.hpp"
#include "intel_gpu/primitives/scale.hpp"
#include "intel_gpu/primitives/scatter_update.hpp"
#include "intel_gpu/primitives/scatter_elements_update.hpp"
#include "intel_gpu/primitives/scatter_nd_update.hpp"
#include "intel_gpu/primitives/scatter_update.hpp"
#include "intel_gpu/primitives/select.hpp"
#include "intel_gpu/primitives/shape_of.hpp"
#include "intel_gpu/primitives/shuffle_channels.hpp"
@ -65,15 +73,6 @@
#include "intel_gpu/primitives/space_to_batch.hpp"
#include "intel_gpu/primitives/strided_slice.hpp"
#include "intel_gpu/primitives/tile.hpp"
#include "intel_gpu/primitives/resample.hpp"
#include "intel_gpu/primitives/gather_tree.hpp"
#include "intel_gpu/primitives/lstm_dynamic_input.hpp"
#include "intel_gpu/primitives/lstm_dynamic_timeloop.hpp"
#include "intel_gpu/primitives/grn.hpp"
#include "intel_gpu/primitives/ctc_greedy_decoder.hpp"
#include "intel_gpu/primitives/convert_color.hpp"
#include "generic_layer.hpp"
namespace cldnn {
namespace ocl {
@ -81,9 +80,9 @@ void register_implementations();
namespace detail {
#define REGISTER_OCL(prim) \
struct attach_##prim##_impl { \
attach_##prim##_impl(); \
#define REGISTER_OCL(prim) \
struct attach_##prim##_impl { \
attach_##prim##_impl(); \
}
REGISTER_OCL(activation);
@ -106,6 +105,7 @@ REGISTER_OCL(deformable_interp);
REGISTER_OCL(depth_to_space);
REGISTER_OCL(detection_output);
REGISTER_OCL(dft);
REGISTER_OCL(experimental_detectron_detection_output);
REGISTER_OCL(experimental_detectron_generate_proposals_single_image);
REGISTER_OCL(experimental_detectron_prior_grid_generator);
REGISTER_OCL(experimental_detectron_roi_feature_extractor);

View File

@ -0,0 +1,66 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "intel_gpu/primitives/experimental_detectron_detection_output.hpp"
#include "primitive_inst.h"
namespace cldnn {
template <>
struct typed_program_node<experimental_detectron_detection_output>
: public typed_program_node_base<experimental_detectron_detection_output> {
using parent = typed_program_node_base<experimental_detectron_detection_output>;
public:
using parent::parent;
program_node& input() const {
return get_dependency(0);
}
program_node& deltas() const {
return get_dependency(1);
}
program_node& scores() const {
return get_dependency(2);
}
program_node& image_size_info() const {
return get_dependency(3);
}
program_node& output_classes_node() const {
return get_dependency(4);
}
program_node& output_scores_node() const {
return get_dependency(5);
}
};
using experimental_detectron_detection_output_node = typed_program_node<experimental_detectron_detection_output>;
template <>
class typed_primitive_inst<experimental_detectron_detection_output>
: public typed_primitive_inst_base<experimental_detectron_detection_output> {
using parent = typed_primitive_inst_base<experimental_detectron_detection_output>;
public:
static layout calc_output_layout(const experimental_detectron_detection_output_node& node);
static std::string to_string(const experimental_detectron_detection_output_node& node);
typed_primitive_inst(network& network, const experimental_detectron_detection_output_node& node)
: parent(network, node) {}
memory::ptr output_classes_memory() const {
return dep_memory_ptr(4);
}
memory::ptr output_scores_memory() const {
return dep_memory_ptr(5);
}
};
using experimental_detectron_detection_output_inst = typed_primitive_inst<experimental_detectron_detection_output>;
} // namespace cldnn

View File

@ -79,6 +79,7 @@ enum class KernelType {
LOOP,
NON_MAX_SUPPRESSION,
DETECTION_OUTPUT,
EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT,
EXPERIMENTAL_DETECTRON_GENERATE_PROPOSALS_SINGLE_IMAGE,
EXPERIMENTAL_DETECTRON_PRIOR_GRID_GENERATOR,
EXPERIMENTAL_DETECTRON_ROI_FEATURE_EXTRACTOR,

View File

@ -0,0 +1,199 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "experimental_detectron_detection_output_kernel_ref.h"
#include <algorithm>
#include <string>
#include "kernel_selector_utils.h"
namespace kernel_selector {
ParamsKey ExperimentalDetectronDetectionOutputKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::INT64);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT64);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableBatching();
k.EnableDifferentTypes();
return k;
}
KernelsPriority ExperimentalDetectronDetectionOutputKernelRef::GetKernelsPriority(const Params&,
const optional_params&) const {
return DONT_USE_IF_HAVE_SOMETHING_ELSE;
}
bool ExperimentalDetectronDetectionOutputKernelRef::Validate(const Params& p, const optional_params& o) const {
if (p.GetType() != KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT ||
o.GetType() != KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {
return false;
}
return true;
}
constexpr int kBoxesInputIdx = 0;
constexpr int kDeltasInputIdx = 1;
constexpr int kScoresInputIdx = 2;
constexpr int kImInfoInputIdx = 3;
constexpr int kOutputClassesInputIdx = 4;
constexpr int kOutputScoresInputIdx = 5;
constexpr int kRefinedBoxesBufferIdx = 0;
constexpr int kRefinedBoxAreasBufferIdx = 1;
constexpr int kRefinedScoresBufferIdx = 2;
constexpr int kScoreClassIndexBufferIdx = 3;
constexpr int kDetectionCountBufferIdx = 4;
constexpr int kBufferCount = 5;
constexpr int kOutputIdx = 0;
JitConstants ExperimentalDetectronDetectionOutputKernelRef::GetJitConstants(
const experimental_detectron_detection_output_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);
jit.AddConstants({
MakeJitConstant("SCORE_THRESHOLD", params.score_threshold),
MakeJitConstant("NMS_THRESHOLD", params.nms_threshold),
MakeJitConstant("NUM_CLASSES", params.num_classes),
MakeJitConstant("POST_NMS_COUNT", params.post_nms_count),
MakeJitConstant("MAX_DETECTIONS_PER_IMAGE", params.max_detections_per_image),
MakeJitConstant("MAX_DELTA_LOG_WH", params.max_delta_log_wh),
MakeJitConstant("DELTA_WEIGHT_X", params.deltas_weights[0]),
MakeJitConstant("DELTA_WEIGHT_Y", params.deltas_weights[1]),
MakeJitConstant("DELTA_WEIGHT_LOG_W", params.deltas_weights[2]),
MakeJitConstant("DELTA_WEIGHT_LOG_H", params.deltas_weights[3]),
MakeJitConstant("ROI_COUNT", params.inputs[kScoresInputIdx].Batch().v),
MakeJitConstant("OUTPUT_INDICES_TYPE", "INPUT4_TYPE"),
});
return jit;
}
using DispatchData = CommonDispatchData;
void ExperimentalDetectronDetectionOutputKernelRef::PrepareKernelCommon(
const experimental_detectron_detection_output_params& params,
const optional_params& options,
std::vector<size_t> gws,
const std::string& stage_name,
size_t stage_index,
clKernelData& kernel) const {
DispatchData dispatch_data;
dispatch_data.gws = std::move(gws);
dispatch_data.lws = GetOptimalLocalWorkGroupSizes(dispatch_data.gws, params.engineInfo);
const auto entry_point = GetEntryPoint(kernelName, params.layerID, params, options, stage_index);
auto cldnn_jit = GetJitConstants(params);
cldnn_jit.AddConstant(MakeJitConstant(stage_name, "true"));
const auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
KernelBase::CheckDispatchData(kernelName, dispatch_data, params.engineInfo.maxWorkGroupSize);
kernel.params.workGroups.global = dispatch_data.gws;
kernel.params.workGroups.local = dispatch_data.lws;
kernel.code.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo);
}
void ExperimentalDetectronDetectionOutputKernelRef::PrepareRefineBoxesKernel(
const experimental_detectron_detection_output_params& params,
const optional_params& options,
clKernelData& kernel) const {
const size_t roi_count = params.inputs[kScoresInputIdx].Batch().v;
const size_t class_count = params.num_classes;
PrepareKernelCommon(params, options, {roi_count, class_count, 1}, "EDDO_STAGE_0_REFINE_BOXES", 0, kernel);
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kBoxesInputIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kDeltasInputIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kScoresInputIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kImInfoInputIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxAreasBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedScoresBufferIdx});
}
void ExperimentalDetectronDetectionOutputKernelRef::PrepareNmsClassWiseKernel(
const experimental_detectron_detection_output_params& params,
const optional_params& options,
clKernelData& kernel) const {
PrepareKernelCommon(params, options, {1, 1, 1}, "EDDO_STAGE_1_NMS", 1, kernel);
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedScoresBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxAreasBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx});
}
void ExperimentalDetectronDetectionOutputKernelRef::PrepareTopKDetectionsKernel(
const experimental_detectron_detection_output_params& params,
const optional_params& options,
clKernelData& kernel) const {
PrepareKernelCommon(params, options, {1, 1, 1}, "EDDO_STAGE_2_TOPK", 2, kernel);
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx});
}
void ExperimentalDetectronDetectionOutputKernelRef::PrepareCopyOutputKernel(
const experimental_detectron_detection_output_params& params,
const optional_params& options,
clKernelData& kernel) const {
PrepareKernelCommon(params,
options,
{static_cast<size_t>(params.max_detections_per_image), 1, 1},
"EDDO_STAGE_3_COPY_OUTPUT",
3,
kernel);
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, kOutputIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kOutputClassesInputIdx});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kOutputScoresInputIdx});
}
KernelsData ExperimentalDetectronDetectionOutputKernelRef::GetKernelsData(const Params& params,
const optional_params& options) const {
if (!Validate(params, options)) {
return {};
}
constexpr size_t kKernelCount = 4;
KernelData kd = KernelData::Default<experimental_detectron_detection_output_params>(params, kKernelCount);
const auto& eddo_params = static_cast<const experimental_detectron_detection_output_params&>(params);
const auto roi_count = eddo_params.inputs[kScoresInputIdx].Batch().v;
const auto class_count = static_cast<size_t>(eddo_params.num_classes);
kd.internalBufferDataType = Datatype::F32;
kd.internalBufferSizes.resize(kBufferCount);
kd.internalBufferSizes[kRefinedBoxesBufferIdx] = class_count * roi_count * 4 * sizeof(float);
kd.internalBufferSizes[kRefinedBoxAreasBufferIdx] = class_count * roi_count * sizeof(float);
kd.internalBufferSizes[kRefinedScoresBufferIdx] = class_count * roi_count * sizeof(float);
kd.internalBufferSizes[kScoreClassIndexBufferIdx] = class_count * roi_count * 12; // sizeof ScoreClassIndex
kd.internalBufferSizes[kDetectionCountBufferIdx] = sizeof(uint32_t);
PrepareRefineBoxesKernel(eddo_params, options, kd.kernels[0]);
PrepareNmsClassWiseKernel(eddo_params, options, kd.kernels[1]);
PrepareTopKDetectionsKernel(eddo_params, options, kd.kernels[2]);
PrepareCopyOutputKernel(eddo_params, options, kd.kernels[3]);
return {kd};
}
} // namespace kernel_selector

View File

@ -0,0 +1,63 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "kernel_base_opencl.h"
namespace kernel_selector {
struct experimental_detectron_detection_output_params : public base_params {
experimental_detectron_detection_output_params()
: base_params(KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {}
float score_threshold;
float nms_threshold;
float max_delta_log_wh;
int num_classes;
int post_nms_count;
int max_detections_per_image;
bool class_agnostic_box_regression;
std::vector<float> deltas_weights;
};
struct experimental_detectron_detection_output_optional_params : public optional_params {
experimental_detectron_detection_output_optional_params()
: optional_params(KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {}
};
class ExperimentalDetectronDetectionOutputKernelRef : public KernelBaseOpenCL {
public:
ExperimentalDetectronDetectionOutputKernelRef() : KernelBaseOpenCL("experimental_detectron_detection_output_ref") {}
~ExperimentalDetectronDetectionOutputKernelRef() = default;
protected:
bool Validate(const Params& p, const optional_params& o) const override;
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
private:
JitConstants GetJitConstants(const experimental_detectron_detection_output_params& params) const;
void PrepareKernelCommon(const experimental_detectron_detection_output_params& params,
const optional_params& options,
std::vector<size_t> gws,
const std::string& stage_name,
size_t stage_index,
clKernelData& kernel) const;
void PrepareRefineBoxesKernel(const experimental_detectron_detection_output_params&,
const optional_params&,
clKernelData&) const;
void PrepareNmsClassWiseKernel(const experimental_detectron_detection_output_params&,
const optional_params&,
clKernelData&) const;
void PrepareTopKDetectionsKernel(const experimental_detectron_detection_output_params&,
const optional_params&,
clKernelData&) const;
void PrepareCopyOutputKernel(const experimental_detectron_detection_output_params&,
const optional_params&,
clKernelData&) const;
};
} // namespace kernel_selector

View File

@ -0,0 +1,25 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "experimental_detectron_detection_output_kernel_selector.h"
#include "experimental_detectron_detection_output_kernel_ref.h"
namespace kernel_selector {
experimental_detectron_detection_output_kernel_selector::experimental_detectron_detection_output_kernel_selector() {
Attach<ExperimentalDetectronDetectionOutputKernelRef>();
}
experimental_detectron_detection_output_kernel_selector&
experimental_detectron_detection_output_kernel_selector::Instance() {
static experimental_detectron_detection_output_kernel_selector instance_;
return instance_;
}
KernelsData experimental_detectron_detection_output_kernel_selector::GetBestKernels(
const Params& params,
const optional_params& options) const {
return GetNaiveBestKernel(params, options, KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT);
}
} // namespace kernel_selector

View File

@ -0,0 +1,19 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "kernel_selector.h"
namespace kernel_selector {
class experimental_detectron_detection_output_kernel_selector : public kernel_selector_base {
public:
static experimental_detectron_detection_output_kernel_selector& Instance();
experimental_detectron_detection_output_kernel_selector();
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
};
} // namespace kernel_selector

View File

@ -0,0 +1,319 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "include/batch_headers/data_types.cl"
#define INPUT_TYPE INPUT0_TYPE
#define INPUT_TYPE2 MAKE_VECTOR_TYPE(INPUT0_TYPE, 2)
#define INPUT_TYPE4 MAKE_VECTOR_TYPE(INPUT0_TYPE, 4)
#if INPUT0_TYPE_SIZE == 2 // f16
# define HALF_ONE (INPUT_TYPE2)(0.5h)
# define ZERO 0.0h
# define ONE 1.0h
#else
# define HALF_ONE (INPUT_TYPE2)(0.5f)
# define ZERO 0.0f
# define ONE 1.0f
#endif
#define ZERO2 (INPUT_TYPE2)(ZERO)
#define ZERO4 (INPUT_TYPE4)(ZERO)
#define COORDINATE_OFFSET (INPUT_TYPE2)(ONE)
#define DELTA_WEIGHTS (INPUT_TYPE4)(DELTA_WEIGHT_X, DELTA_WEIGHT_Y, DELTA_WEIGHT_LOG_W, DELTA_WEIGHT_LOG_H)
#define MAX_DELTA_LOG_SIZE (INPUT_TYPE2)(TO_INPUT0_TYPE(MAX_DELTA_LOG_WH))
typedef struct __attribute__((packed)) {
INPUT_TYPE score __attribute__((aligned(4)));
uint class_idx;
uint box_idx;
} FUNC(SCI);
#define ScoreClassIndex FUNC(SCI)
inline void FUNC(swap_info)(__global ScoreClassIndex* a, __global ScoreClassIndex* b) {
const ScoreClassIndex temp = *a;
*a = *b;
*b = temp;
}
inline int FUNC(partition)(__global ScoreClassIndex* arr, int l, int h) {
const INPUT_TYPE pivot_score = arr[h].score;
const size_t pivot_box_idx = arr[h].box_idx;
int i = (l - 1);
for (int j = l; j <= h - 1; j++) {
if ((arr[j].score > pivot_score) || (arr[j].score == pivot_score && arr[j].box_idx < pivot_box_idx)) {
i++;
FUNC_CALL(swap_info)(&arr[i], &arr[j]);
}
}
FUNC_CALL(swap_info)(&arr[i + 1], &arr[h]);
return (i + 1);
}
inline void FUNC(bubbleSortIterative)(__global ScoreClassIndex* arr, int l, int h) {
for (int i = 0; i < h - l; i++) {
bool swapped = false;
for (int j = l; j < h - i; j++) {
if ((arr[j].score > arr[j + 1].score) ||
(arr[j].score == arr[j + 1].score && arr[j].box_idx < arr[j + 1].box_idx)) {
FUNC_CALL(swap_info)(&arr[j], &arr[j + 1]);
swapped = true;
}
}
if (!swapped)
break;
}
}
inline void FUNC(quickSortIterative)(__global ScoreClassIndex* arr, int l, int h) {
// Create an auxiliary stack
const int kStackSize = 100;
int stack[kStackSize];
// initialize top of stack
int top = -1;
// push initial values of l and h to stack
stack[++top] = l;
stack[++top] = h;
// Keep popping from stack while is not empty
while (top >= 0) {
// Pop h and l
h = stack[top--];
l = stack[top--];
// Set pivot element at its correct position
// in sorted array
int p = FUNC_CALL(partition)(arr, l, h);
// If there are elements on left side of pivot,
// then push left side to stack
if (p - 1 > l) {
if (top >= (kStackSize - 1)) {
FUNC_CALL(bubbleSortIterative)(arr, l, p - 1);
} else {
stack[++top] = l;
stack[++top] = p - 1;
}
}
// If there are elements on right side of pivot,
// then push right side to stack
if (p + 1 < h) {
if (top >= (kStackSize - 1)) {
FUNC_CALL(bubbleSortIterative)(arr, p + 1, h);
} else {
stack[++top] = p + 1;
stack[++top] = h;
}
}
}
}
// FIXME: rename stages accordingly
#ifdef EDDO_STAGE_0_REFINE_BOXES
// 0. Refine boxes
KERNEL(eddo_ref_stage_0)
(const __global INPUT_TYPE* boxes,
const __global INPUT_TYPE* deltas,
const __global INPUT_TYPE* scores,
const __global INPUT_TYPE* im_info,
__global INPUT_TYPE* refined_boxes,
__global INPUT_TYPE* refined_box_areas,
__global INPUT_TYPE* refined_scores) {
const size_t roi_count = get_global_size(0);
size_t roi_idx = get_global_id(0);
size_t class_idx = get_global_id(1);
INPUT_TYPE4 box = vload4(roi_idx, boxes);
if (any(islessequal(box.hi - box.lo, ZERO2))) {
const int refined_offset = roi_count * class_idx + roi_idx;
refined_scores[refined_offset] = ZERO;
} else {
const int offset = NUM_CLASSES * roi_idx + class_idx;
// width & height of box
INPUT_TYPE2 box_size = (box.hi - box.lo + COORDINATE_OFFSET);
// center location of box
const INPUT_TYPE2 center = box.lo + HALF_ONE * box_size;
const INPUT_TYPE4 delta = vload4(offset, deltas) / DELTA_WEIGHTS;
// new center location according to deltas (dx, dy)
const INPUT_TYPE2 new_center = delta.lo * box_size + center;
// new width & height according to deltas d(log w), d(log h)
const INPUT_TYPE2 new_size = exp(min(delta.hi, MAX_DELTA_LOG_SIZE)) * box_size;
// update upper-left corner and lower-right corners respectively
INPUT_TYPE4 new_box =
(INPUT_TYPE4)(new_center - HALF_ONE * new_size, new_center + HALF_ONE * new_size - COORDINATE_OFFSET);
// adjust new corner locations to be within the image region
const INPUT_TYPE2 img_size = vload2(0, im_info).s10;
new_box = clamp(new_box, ZERO4, img_size.xyxy);
// recompute new width & height
const INPUT_TYPE2 new_box_size = new_box.hi - new_box.lo + COORDINATE_OFFSET;
const int refined_offset = roi_count * class_idx + roi_idx;
vstore4(new_box, refined_offset, refined_boxes);
refined_box_areas[refined_offset] = new_box_size.x * new_box_size.y;
refined_scores[refined_offset] = scores[offset];
}
}
#endif /* EDDO_STAGE_0_REFINE_BOXES */
#ifdef EDDO_STAGE_1_NMS
inline INPUT_TYPE FUNC(jaccard_overlap)(const __global INPUT_TYPE* refined_boxes,
const __global INPUT_TYPE* refined_box_areas,
size_t idx1,
size_t idx2) {
INPUT_TYPE4 box1 = vload4(idx1, refined_boxes);
INPUT_TYPE4 box2 = vload4(idx2, refined_boxes);
const bool bbox_not_covered = any(isgreater((INPUT_TYPE4)(box1.lo, box2.lo), (INPUT_TYPE4)(box2.hi, box1.hi)));
if (bbox_not_covered) {
return ZERO;
}
INPUT_TYPE2 intersect_min = max(box1.lo, box2.lo);
INPUT_TYPE2 intersect_max = min(box1.hi, box2.hi);
INPUT_TYPE2 intersect_size = intersect_max - intersect_min + COORDINATE_OFFSET;
if (any(islessequal(intersect_size, ZERO2))) {
return ZERO;
}
INPUT_TYPE intersect_area = intersect_size.x * intersect_size.y;
INPUT_TYPE bbox1_area = refined_box_areas[idx1];
INPUT_TYPE bbox2_area = refined_box_areas[idx2];
return intersect_area / (bbox1_area + bbox2_area - intersect_area);
}
inline void FUNC(nms_cf)(const __global INPUT_TYPE* refined_scores,
const __global INPUT_TYPE* refined_boxes,
const __global INPUT_TYPE* refined_box_areas,
size_t class_idx,
size_t roi_count,
__global ScoreClassIndex* score_class_index_map,
__global uint* detection_count) {
size_t count = 0;
for (size_t i = 0; i < roi_count; ++i) {
if (refined_scores[i] > SCORE_THRESHOLD) {
score_class_index_map[count] = (ScoreClassIndex){refined_scores[i], class_idx, i};
count++;
}
}
FUNC_CALL(quickSortIterative)(score_class_index_map, 0, count - 1);
int detections = 0;
for (size_t i = 0; i < count; ++i) {
const size_t idx = score_class_index_map[i].box_idx;
bool keep = true;
for (size_t k = 0; k < detections; ++k) {
const size_t kept_idx = score_class_index_map[k].box_idx;
INPUT_TYPE overlap = FUNC_CALL(jaccard_overlap)(refined_boxes, refined_box_areas, idx, kept_idx);
if (overlap > NMS_THRESHOLD) {
keep = false;
break;
}
}
if (keep) {
score_class_index_map[detections] = score_class_index_map[i];
detections++;
}
}
*detection_count = min(POST_NMS_COUNT, detections);
}
KERNEL(eddo_ref_stage_1)
(const __global INPUT_TYPE* refined_scores,
const __global INPUT_TYPE* refined_boxes,
const __global INPUT_TYPE* refined_box_areas,
__global ScoreClassIndex* score_class_index_map,
__global uint* detection_count) {
size_t total_detections_num = 0;
// FIXME: figure out how to parallelize this!!!
for (int class_idx = 0; class_idx < NUM_CLASSES; ++class_idx) {
FUNC_CALL(nms_cf)
(&refined_scores[ROI_COUNT * class_idx],
&refined_boxes[ROI_COUNT * 4 * class_idx],
&refined_box_areas[ROI_COUNT * class_idx],
class_idx,
ROI_COUNT,
&score_class_index_map[total_detections_num],
detection_count);
total_detections_num += *detection_count;
}
*detection_count = total_detections_num;
}
#endif /* EDDO_STAGE_1_NMS */
#ifdef EDDO_STAGE_2_TOPK
KERNEL(eddo_ref_stage_2)
(__global ScoreClassIndex* score_class_index_map, const __global uint* detection_count) {
if (*detection_count > MAX_DETECTIONS_PER_IMAGE) {
FUNC_CALL(quickSortIterative)(score_class_index_map, 0, *detection_count - 1);
}
}
#endif /* EDDO_STAGE_2_TOPK */
#ifdef EDDO_STAGE_3_COPY_OUTPUT
KERNEL(eddo_ref_stage_3)
(const __global ScoreClassIndex* score_class_index_map,
const __global uint* detection_count,
const __global INPUT_TYPE* refined_boxes,
__global OUTPUT_TYPE* output_boxes,
__global OUTPUT_INDICES_TYPE* output_classes,
__global OUTPUT_TYPE* output_scores) {
size_t i = get_global_id(0);
if (i < *detection_count) {
OUTPUT_TYPE score = score_class_index_map[i].score;
OUTPUT_INDICES_TYPE cls = score_class_index_map[i].class_idx;
OUTPUT_INDICES_TYPE idx = score_class_index_map[i].box_idx;
vstore4(vload4(ROI_COUNT * cls + idx, refined_boxes), i, output_boxes);
output_scores[i] = score;
output_classes[i] = cls;
} else {
vstore4(ZERO4, i, output_boxes);
output_scores[i] = ZERO;
output_classes[i] = 0;
}
}
#endif /* EDDO_STAGE_3_COPY_OUTPUT */
#undef INPUT_TYPE
#undef INPUT_TYPE2
#undef INPUT_TYPE4
#undef HALF_ONE
#undef ZERO
#undef ONE
#undef ZERO2
#undef ZERO4
#undef COORDINATE_OFFSET
#undef DELTA_WEIGHTS
#undef MAX_DELTA_LOG_SIZE

View File

@ -0,0 +1,101 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/experimental_detectron_detection_output.hpp"
#include "intel_gpu/plugin/common_utils.hpp"
#include "intel_gpu/plugin/program.hpp"
#include "intel_gpu/primitives/experimental_detectron_detection_output.hpp"
#include "intel_gpu/primitives/mutable_data.hpp"
namespace ov {
namespace runtime {
namespace intel_gpu {
static void CreateExperimentalDetectronDetectionOutputOp(
Program& p,
const std::shared_ptr<ngraph::op::v6::ExperimentalDetectronDetectionOutput>& op) {
p.ValidateInputs(op, {4});
if (op->get_output_size() != 3) {
IE_THROW() << "ExperimentalDetectronDetectionOutput requires 3 outputs";
}
auto inputs = p.GetInputPrimitiveIDs(op);
const auto& attrs = op->get_attrs();
const auto op_friendly_name = op->get_friendly_name();
const auto layer_type_name = layer_type_name_ID(op);
const auto layer_name = layer_type_name + ".0";
const auto mutable_precision1 = op->get_output_element_type(1);
const auto output_shape1 = op->get_output_shape(1);
const cldnn::layout mutable_layout1{DataTypeFromPrecision(mutable_precision1),
DefaultFormatForDims(output_shape1.size()),
tensor_from_dims(output_shape1)};
cldnn::memory::ptr shared_memory1{p.GetEngine().allocate_memory(mutable_layout1)};
const auto mutable_id_w1 = layer_type_name + "_md_write.1";
const cldnn::mutable_data mutable_prim_w{mutable_id_w1, shared_memory1, op_friendly_name};
p.primitiveIDs[mutable_id_w1] = mutable_id_w1;
p.AddPrimitive(mutable_prim_w);
inputs.push_back(mutable_id_w1);
const auto mutable_precision2 = op->get_output_element_type(2);
const auto output_shape2 = op->get_output_shape(2);
const cldnn::layout mutable_layout2{DataTypeFromPrecision(mutable_precision2),
DefaultFormatForDims(output_shape2.size()),
tensor_from_dims(output_shape2)};
cldnn::memory::ptr shared_memory2{p.GetEngine().allocate_memory(mutable_layout2)};
const auto mutable_id_w2 = layer_type_name + "_md_write.2";
const cldnn::mutable_data mutable_prim_w2{mutable_id_w2, shared_memory2, op_friendly_name};
p.primitiveIDs[mutable_id_w2] = mutable_id_w2;
p.AddPrimitive(mutable_prim_w2);
inputs.push_back(mutable_id_w2);
const auto expectedPrimInputCount = 4 + 2; // 4 operation inputs plus 2 input-outputs
if (inputs.size() != expectedPrimInputCount) {
IE_THROW() << "experimental_detectron_detection_output primitive requires 6 inputs";
}
const cldnn::experimental_detectron_detection_output prim{layer_name,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4], // output classes
inputs[5], // output scores
attrs.score_threshold,
attrs.nms_threshold,
static_cast<int>(attrs.num_classes),
static_cast<int>(attrs.post_nms_count),
static_cast<int>(attrs.max_detections_per_image),
attrs.class_agnostic_box_regression,
attrs.max_delta_log_wh,
attrs.deltas_weights,
op_friendly_name};
p.AddPrimitive(prim);
const auto mutable_id_r1 = layer_type_name + ".1";
const cldnn::mutable_data mutable_prim_r1{mutable_id_r1, {layer_name}, shared_memory1, op_friendly_name};
p.primitiveIDs[mutable_id_r1] = mutable_id_r1;
p.AddPrimitive(mutable_prim_r1);
const auto mutable_id_r2 = layer_type_name + ".2";
const cldnn::mutable_data mutable_prim_r2{mutable_id_r2, {layer_name}, shared_memory2, op_friendly_name};
p.primitiveIDs[mutable_id_r2] = mutable_id_r2;
p.AddPrimitive(mutable_prim_r2);
p.AddPrimitiveToProfiler(prim, op);
}
REGISTER_FACTORY_IMPL(v6, ExperimentalDetectronDetectionOutput);
} // namespace intel_gpu
} // namespace runtime
} // namespace ov

View File

@ -0,0 +1,322 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <intel_gpu/primitives/experimental_detectron_detection_output.hpp>
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/mutable_data.hpp>
#include "test_utils.h"
using namespace cldnn;
using namespace ::tests;
namespace {
template <typename T>
std::vector<T> getValues(const std::vector<float>& values) {
std::vector<T> result(values.begin(), values.end());
return result;
}
template <typename T>
float getError();
template <>
float getError<float>() {
return 0.001;
}
template <>
float getError<half_t>() {
return 0.2;
}
}; // namespace
template <typename T>
struct ExperimentalDetectronDetectionOutputParams {
float score_threshold;
float nms_threshold;
float max_delta_log_wh;
int num_classes;
int post_nms_count;
int max_detections_per_image;
bool class_agnostic_box_regression;
std::vector<float> deltas_weights;
size_t roi_count;
std::vector<T> boxes;
std::vector<T> deltas;
std::vector<T> scores;
std::vector<T> im_info;
std::vector<T> expected_boxes;
std::vector<int32_t> expected_classes;
std::vector<T> expected_scores;
};
template <typename T>
struct experimental_detectron_detection_output_test
: public ::testing::TestWithParam<ExperimentalDetectronDetectionOutputParams<T>> {
public:
void test() {
const ExperimentalDetectronDetectionOutputParams<T> param =
testing::TestWithParam<ExperimentalDetectronDetectionOutputParams<T>>::GetParam();
auto data_type = type_to_data_type<T>::value;
auto& engine = get_test_engine();
const primitive_id input_boxes_id = "InputBoxes";
const auto input_boxes =
engine.allocate_memory({data_type, format::bfyx, tensor{batch(param.roi_count), feature(4)}});
set_values(input_boxes, param.boxes);
const primitive_id input_deltas_id = "InputDeltas";
auto input_deltas = engine.allocate_memory(
{data_type, format::bfyx, tensor{batch(param.roi_count), feature(param.num_classes * 4)}});
set_values(input_deltas, param.deltas);
const primitive_id input_scores_id = "InputScores";
auto input_scores = engine.allocate_memory(
{data_type, format::bfyx, tensor{batch(param.roi_count), feature(param.num_classes)}});
set_values(input_scores, param.scores);
const primitive_id input_im_info_id = "InputImInfo";
const auto input_im_info = engine.allocate_memory({data_type, format::bfyx, tensor{batch(1), feature(3)}});
set_values(input_im_info, param.im_info);
const primitive_id output_scores_id = "OutputScores";
auto output_scores =
engine.allocate_memory({data_type, format::bfyx, tensor{batch(param.max_detections_per_image)}});
const primitive_id output_classes_id = "OutputClasses";
auto output_classes =
engine.allocate_memory({data_types::i32, format::bfyx, tensor{batch(param.max_detections_per_image)}});
topology topology;
topology.add(input_layout(input_boxes_id, input_boxes->get_layout()));
topology.add(input_layout(input_deltas_id, input_deltas->get_layout()));
topology.add(input_layout(input_scores_id, input_scores->get_layout()));
topology.add(input_layout(input_im_info_id, input_im_info->get_layout()));
topology.add(mutable_data(output_classes_id, output_classes));
topology.add(mutable_data(output_scores_id, output_scores));
const primitive_id eddo_id = "experimental_detectron_detection_output";
const auto eddo_primitive = experimental_detectron_detection_output{
eddo_id,
input_boxes_id,
input_deltas_id,
input_scores_id,
input_im_info_id,
output_classes_id,
output_scores_id,
param.score_threshold,
param.nms_threshold,
param.num_classes,
param.post_nms_count,
param.max_detections_per_image,
param.class_agnostic_box_regression,
param.max_delta_log_wh,
param.deltas_weights,
};
topology.add(eddo_primitive);
network network(engine, topology);
network.set_input_data(input_boxes_id, input_boxes);
network.set_input_data(input_deltas_id, input_deltas);
network.set_input_data(input_scores_id, input_scores);
network.set_input_data(input_im_info_id, input_im_info);
const auto outputs = network.execute();
const auto output_boxes = outputs.at(eddo_id).get_memory();
const cldnn::mem_lock<T> output_boxes_ptr(output_boxes, get_test_stream());
ASSERT_EQ(output_boxes_ptr.size(), param.max_detections_per_image * 4);
const cldnn::mem_lock<int32_t> output_classes_ptr(output_classes, get_test_stream());
ASSERT_EQ(output_classes_ptr.size(), param.max_detections_per_image);
const cldnn::mem_lock<T> output_scores_ptr(output_scores, get_test_stream());
ASSERT_EQ(output_scores_ptr.size(), param.max_detections_per_image);
const auto& expected_boxes = param.expected_boxes;
const auto& expected_classes = param.expected_classes;
const auto& expected_scores = param.expected_scores;
for (size_t i = 0; i < param.max_detections_per_image; ++i) {
EXPECT_NEAR(expected_scores[i], output_scores_ptr[i], 0.001) << "i=" << i;
for (size_t coord = 0; coord < 4; ++coord) {
const auto roi_idx = i * 4 + coord;
EXPECT_NEAR(expected_boxes[roi_idx], output_boxes_ptr[roi_idx], getError<T>())
<< "i=" << i << ", coord=" << coord;
}
EXPECT_EQ(expected_classes[i], output_classes_ptr[i]) << "i=" << i;
}
}
};
using experimental_detectron_detection_output_test_f32 = experimental_detectron_detection_output_test<float>;
using experimental_detectron_detection_output_test_f16 = experimental_detectron_detection_output_test<half_t>;
TEST_P(experimental_detectron_detection_output_test_f32, basic) {
ASSERT_NO_FATAL_FAILURE(test());
}
TEST_P(experimental_detectron_detection_output_test_f16, basic) {
ASSERT_NO_FATAL_FAILURE(test());
}
template <typename T>
std::vector<ExperimentalDetectronDetectionOutputParams<T>> getExperimentalDetectronDetectionOutputParams() {
std::vector<ExperimentalDetectronDetectionOutputParams<T>> params = {
{
0.01000000074505806f, // score_threshold
0.2f, // nms_threshold
2.0f, // max_delta_log_wh
2, // num_classes
500, // post_nms_count
5, // max_detections_per_image
true, // class_agnostic_box_regression
{10.0f, 10.0f, 5.0f, 5.0f}, // deltas_weights
16, // roi count
// boxes
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f,
1.0f, 8.0f, 5.0f, 1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
getValues<T>({0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}),
getValues<T>({16.0f, 12.0f, 1.0f}),
getValues<T>({4.8929863f, 0.892986298f, 12.0f, 12.1070137f, 0.0f, 0.892986298f, 10.1070137f,
12.1070137f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}),
std::vector<int32_t>{0, 1, 0, 0, 0},
getValues<T>({0.8f, 0.9f, 0.0f, 0.0f, 0.0f}),
},
{
0.0500000007, // score_threshold
0.5, // nms_threshold
4.13516665, // max_delta_log_wh
5, // num_classes
10, // post_nms_count
10, // max_detections_per_image
false, // class_agnostic_box_regression
{10.0f, 10.0f, 5.0f, 5.0f}, // deltas_weights
10, // roi count
// boxes
getValues<T>({
4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188,
6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281,
0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125,
6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719,
}),
// deltas
getValues<T>({
4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188,
6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281,
0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125,
6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719,
3.74609, 4.54688, 5.83203, 5.91406, 2.85547, 7.46875, 4.31641, 2.71094, 9.71484, 1.14062,
6.55078, 0.257812, 4.32422, 9.5625, 8.53516, 0.554688, 8.68359, 2.73438, 6.26953, 5.60156,
2.79297, 8.65625, 5.75391, 5.39844, 2.65234, 7.32812, 8.98828, 7.94531, 6.26172, 4.75,
7.97266, 1.24219, 5.62109, 8.92188, 2.70703, 1.28906, 4.73047, 7.84375, 5.19141, 6.08594,
7.58984, 9.51562, 7.42578, 5.63281, 6.19922, 7.9375, 5.41016, 9.92969, 2.55859, 1.10938,
1.14453, 8.97656, 4.66797, 9.03125, 4.62891, 0.773438, 4.52734, 1.70312, 9.86328, 1.32031,
0.136719, 9.125, 2.84766, 4.61719, 9.49609, 5.29688, 5.58203, 0.664062, 2.60547, 6.21875,
8.06641, 5.46094, 1.46484, 7.89062, 0.300781, 5.00781, 0.0742188, 0.3125, 6.28516, 3.30469,
4.43359, 1.48438, 2.01953, 8.35156, 8.54297, 7.40625, 9.50391, 2.14844, 2.40234, 2.07812,
2.73828, 2.69531, 4.01172, 9.5, 7.72266, 9.99219, 1.37109, 3.67188, 2.45703, 2.03906,
0.480469, 4.59375, 2.94141, 4.83594, 1.33984, 0.265625, 1.17578, 4.38281, 5.94922, 8.6875,
5.16016, 0.679688, 4.30859, 5.85938, 4.89453, 7.72656, 4.41797, 5.78125, 4.37891, 1.52344,
8.27734, 4.45312, 3.61328, 4.07031, 7.88672, 9.875, 4.59766, 1.36719, 7.24609, 8.04688,
5.33203, 5.41406, 4.35547, 0.96875, 1.81641, 8.21094, 3.21484, 4.64062, 4.05078, 9.75781,
7.82422, 3.0625, 4.03516, 0.0546875, 8.18359, 8.23438, 1.76953, 1.10156, 2.29297, 8.15625,
9.25391, 0.898438, 6.15234, 8.82812, 6.48828, 7.44531, 1.76172, 2.25, 9.47266, 0.742188,
}),
// scores
getValues<T>({
4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188,
6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281,
0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125,
6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719,
3.74609, 4.54688, 5.83203, 5.91406, 2.85547, 7.46875, 4.31641, 2.71094, 9.71484, 1.14062,
}),
// im_info
getValues<T>({
4.90234,
6.57812,
5.23828,
}),
// out_boxes
getValues<T>({
0, 2.97829, 6.57812, 4.90234, 0, 4.90234, 6.57812, 4.90234, 4.37184, 4.90234,
6.03075, 4.90234, 5.95093, 3.66966, 6.57812, 4.90234, 0, 4.90234, 6.57812, 4.90234,
1.31075, 4.90234, 6.57812, 4.90234, 3.24829, 4.90234, 6.57812, 4.90234, 0, 0,
6.57812, 4.90234, 4.20346, 0, 6.57812, 4.90234, 0, 0, 6.57812, 4.90234,
}),
// out_classes
std::vector<int32_t>({
4,
3,
3,
4,
2,
0,
1,
0,
2,
3,
}),
// out_scores
getValues<T>({
9.86719,
9.71484,
9.19531,
8.51172,
8.375,
7.46875,
6.57812,
6.57031,
5.23828,
5.09766,
}),
},
};
return params;
}
INSTANTIATE_TEST_SUITE_P(experimental_detectron_detection_output_gpu_test,
experimental_detectron_detection_output_test_f32,
::testing::ValuesIn(getExperimentalDetectronDetectionOutputParams<float>()));
INSTANTIATE_TEST_SUITE_P(experimental_detectron_detection_output_gpu_test,
experimental_detectron_detection_output_test_f16,
::testing::ValuesIn(getExperimentalDetectronDetectionOutputParams<half_t>()));

View File

@ -0,0 +1,66 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "single_layer_tests/experimental_detectron_detection_output.hpp"
#include <vector>
#include "common_test_utils/ov_tensor_utils.hpp"
using namespace ov::test;
using namespace ov::test::subgraph;
namespace {
const std::vector<ov::test::ElementType> netPrecisions = {
ov::element::Type_t::f16,
ov::element::Type_t::f32,
};
const std::vector<float> score_threshold = {0.01f, 0.8f};
const std::vector<float> nms_threshold = {0.2f, 0.5f};
// specifies maximal delta of logarithms for width and height
const std::vector<float> max_delta_log_wh = {2.0f, 5.0f};
// specifies number of detected classes
const std::vector<int64_t> num_classes = {2};
// specifies maximal number of detections per class
const std::vector<int64_t> post_nms_count = {5, 25};
// specifies maximual number of detections per image
const std::vector<size_t> max_detections_per_image = {5, 25};
// a flag specifies whether to delete background classes or not
// `true` means background classes should be deleted,
// `false` means background classes shouldn't be deleted.
const std::vector<bool> class_agnostic_box_regression = {true, false};
// specifies deltas of weights
const std::vector<std::vector<float>> deltas_weights = {{10.0f, 10.0f, 5.0f, 5.0f}};
const std::vector<std::vector<InputShape>> inputShapes = {
// inputRois / inputDeltas / inputScores / inputImInfos
static_shapes_to_test_representation({{16, 4}, {16, 8}, {16, 2}, {1, 3}}),
};
INSTANTIATE_TEST_SUITE_P(smoke_ExperimentalDetectronDetectionOutput,
ExperimentalDetectronDetectionOutputLayerTest,
::testing::Combine(::testing::ValuesIn(inputShapes),
::testing::ValuesIn(score_threshold),
::testing::ValuesIn(nms_threshold),
::testing::ValuesIn(max_delta_log_wh),
::testing::ValuesIn(num_classes),
::testing::ValuesIn(post_nms_count),
::testing::ValuesIn(max_detections_per_image),
::testing::ValuesIn(class_agnostic_box_regression),
::testing::ValuesIn(deltas_weights),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
ExperimentalDetectronDetectionOutputLayerTest::getTestCaseName);
} // namespace

View File

@ -79,6 +79,9 @@ void ExperimentalDetectronDetectionOutputLayerTest::SetUp() {
netPrecision,
targetName) = this->GetParam();
if (netPrecision == element::f16)
abs_threshold = 0.01;
inType = outType = netPrecision;
targetDevice = targetName;
@ -93,42 +96,66 @@ void ExperimentalDetectronDetectionOutputLayerTest::SetUp() {
params[2], // input_scores
params[3], // input_im_info
attributes);
function = std::make_shared<ov::Model>(
ov::OutputVector{experimentalDetectron->output(0), experimentalDetectron->output(1)},
"ExperimentalDetectronDetectionOutput");
function = std::make_shared<ov::Model>(ov::OutputVector{experimentalDetectron->output(0),
experimentalDetectron->output(1),
experimentalDetectron->output(2)},
"ExperimentalDetectronDetectionOutput");
}
namespace {
template <typename T>
std::vector<T> getValues(const std::vector<float>& values) {
std::vector<T> result(values.begin(), values.end());
return result;
}
template <typename T>
std::vector<ov::Tensor> generateInputTensors() {
const auto netPrecision = ov::element::from<T>();
std::vector<ov::Tensor> inputTensors = {
// 16 x 4 = 64
ov::test::utils::create_tensor<T>(
netPrecision,
Shape{16, 4},
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f,
1.0f, 8.0f, 5.0f, 1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f})),
// 16 x 8
ov::test::utils::create_tensor<T>(
netPrecision,
Shape{16, 8},
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f})),
// 16 x 2 = 32
ov::test::utils::create_tensor<T>(
netPrecision,
Shape{16, 2},
getValues<T>({0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f})),
// 1 x 3 = 3
ov::test::utils::create_tensor<T>(netPrecision, Shape{1, 3}, getValues<T>({16.0f, 12.0f, 1.0f}))};
return inputTensors;
}
} // namespace
void ExperimentalDetectronDetectionOutputLayerTest::generate_inputs(
const std::vector<ngraph::Shape>& targetInputStaticShapes) {
static const std::vector<ov::Tensor> inputTensors = {
// 16 x 4 = 64
ov::test::utils::create_tensor<float>(
ov::element::f32,
Shape{16, 4},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 8.0f, 5.0f,
1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
// 16 x 8
ov::test::utils::create_tensor<float>(
ov::element::f32,
Shape{16, 8},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
// 16 x 2 = 32
ov::test::utils::create_tensor<float>(
ov::element::f32,
Shape{16, 2},
{0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}),
// 1 x 3 = 3
ov::test::utils::create_tensor<float>(ov::element::f32, Shape{1, 3}, {16.0f, 12.0f, 1.0f})};
const auto netPrecision = std::get<9>(GetParam());
const std::vector<ov::Tensor> inputTensors =
(netPrecision == element::f16) ? generateInputTensors<ov::float16>() : generateInputTensors<float>();
inputs.clear();
const auto& funcInputs = function->inputs();