[IE CLDNN] Add a byxf format for RegionYolo op. (#4451)
commit e40a44202e (parent 3757c079c1)
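For orientation, bfyx and byxf differ only in dimension order: bfyx keeps x innermost, byxf keeps the feature dimension innermost. A minimal sketch of the two linear offsets (illustrative helpers, not clDNN code):

    #include <cstddef>

    // Offsets into a dense 4-D tensor with dimensions B, F, Y, X.
    // bfyx: b outermost, then f, then y, then x (x is contiguous).
    // byxf: b outermost, then y, then x, then f (f is contiguous) -- the layout this commit adds.
    static size_t offset_bfyx(size_t b, size_t f, size_t y, size_t x, size_t F, size_t Y, size_t X) {
        return ((b * F + f) * Y + y) * X + x;
    }
    static size_t offset_byxf(size_t b, size_t f, size_t y, size_t x, size_t F, size_t Y, size_t X) {
        return ((b * Y + y) * X + x) * F + f;
    }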
@@ -46,21 +46,46 @@ JitConstants RegionYoloKernelRef::GetJitConstants(const region_yolo_params& ry)
     return jit;
 }
 
+bool RegionYoloKernelRef::Validate(const Params& p, const optional_params& o) const {
+    if (p.GetType() != KernelType::REGION_YOLO || o.GetType() != KernelType::REGION_YOLO) {
+        return false;
+    }
+
+    const region_yolo_params& params = static_cast<const region_yolo_params&>(p);
+    const size_t expected_feature_size =
+        params.do_softmax ? params.inputs[0].X().v * params.inputs[0].Y().v * params.inputs[0].Feature().v : params.inputs[0].Feature().v;
+
+    if (expected_feature_size != params.output.Feature().v) {
+        return false;
+    }
+
+    return true;
+}
+
 RegionYoloKernelRef::DispatchData SetDefault(const region_yolo_params& params) {
     RegionYoloKernelRef::DispatchData dispatchData;
 
     const auto& input = params.inputs[0];
-    if (input.GetLayout() == DataLayout::bfyx) {
-        dispatchData.gws = {input.X().v * input.Y().v, 1, 1};
-    } else {
-        dispatchData.gws = {input.Feature().v * input.Batch().v, input.X().v, input.Y().v};
+    switch (input.GetLayout()) {
+    case DataLayout::bfyx:
+    case DataLayout::byxf: {
+        uint32_t region_num = params.do_softmax ? params.num : params.mask_size;
+        dispatchData.gws = {input.X().v * input.Y().v, region_num, input.Batch().v};
+    } break;
+    default:
+        throw std::invalid_argument("Unsupported DataLayout");
     }
 
     dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
 
     return dispatchData;
 }
 
 KernelsData RegionYoloKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
-    assert(params.GetType() == KernelType::REGION_YOLO);
+    if (!Validate(params, options)) {
+        return {};
+    }
 
     const region_yolo_params& orgParams = static_cast<const region_yolo_params&>(params);
 
     DispatchData dispatchData = SetDefault(orgParams);
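As a quick check of the new Validate()/SetDefault() logic, here is the arithmetic for the shape used by the tests added below (1x33x52x52 with coords = 4, classes = 6, mask size 3); worked numbers only, not code from the commit:

    #include <cstddef>

    // 33 feature maps == 3 regions * (6 classes + 4 coords + 1 objectness).
    constexpr size_t X = 52, Y = 52, F = 33;
    constexpr size_t feature_no_softmax = F;         // do_softmax == false: output keeps 33 features
    constexpr size_t feature_softmax    = X * Y * F; // do_softmax == true: 89232, output flattened to (1, 89232, 1, 1)
    // SetDefault(): gws = { X * Y, region_num, batch } = { 2704, 3, 1 } in both cases.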
@@ -61,5 +61,6 @@ public:
 
 protected:
     virtual JitConstants GetJitConstants(const region_yolo_params& params) const;
+    bool Validate(const Params& p, const optional_params& o) const override;
 };
 } // namespace kernel_selector
@@ -12,94 +12,79 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "include/common.cl"
-#include "include/data_types.cl"
+#include "include/fetch.cl"
 
-#define IW INPUT0_SIZES[0]
-#define IH INPUT0_SIZES[1]
-#define IC INPUT0_SIZES[2]
-#define IB INPUT0_SIZES[3]
-
-inline UNIT_TYPE FUNC(logistic_activate)(UNIT_TYPE x) {
+inline INPUT0_TYPE FUNC(logistic_activate)(INPUT0_TYPE x) {
     return 1. / (1. + exp(-x));
 }
 
-inline int FUNC(entry_index)(int width, int height, int coords, int classes,
-                             int outputs, int batch, int location,
-                             int entry) {
-    int n = location / (width * height);
-    int loc = location % (width * height);
-    return batch * outputs + n * width * height * (coords + classes + 1) +
-           entry * width * height + loc;
-}
-
+inline int FUNC(output_index)(int batch, int region_num, int x, int y, int xy, int feature_offset) {
 #if DO_SOFTMAX
-inline void FUNC(softmax_generic)(const __global UNIT_TYPE* src_data, __global UNIT_TYPE* dst_data,
-                                  int B, int C, int W, int H, int i)
-{
-    for (int b = 0; b < B; b++) {
-        UNIT_TYPE max = src_data[b*C*H*W + i];
-        for (int c = 0; c < C; c++) {
-            UNIT_TYPE val = src_data[b*C*H*W + c*H*W + i];
-            if (val > max) max = val;
-        }
-
-        UNIT_TYPE expSum = 0;
-        for (int c = 0; c < C; c++) {
-            dst_data[b*C*H*W + c*H*W + i] = exp(src_data[b*C*H*W + c*H*W + i] - max);
-            expSum += dst_data[b*C*H*W + c*H*W + i];
-        }
-
-        for (int c = 0; c < C; c++) {
-            dst_data[b*C*H*W + c*H*W + i] = dst_data[b*C*H*W + c*H*W + i] / expSum;
-        }
-    }
-}
-#endif
-
-KERNEL (region_yolo_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output)
-{
-    int x = get_global_id(0);
-
-#if DO_SOFTMAX
-#define ACTUAL_NUM (NUM)
-#define CONF_CLASSES (1)
+    return OUTPUT_GET_INDEX(batch, feature_offset * INPUT0_SIZE_X * INPUT0_SIZE_Y + xy, 1, 1);
 #else
-#define ACTUAL_NUM (MASK_SIZE)
-#define CONF_CLASSES (CLASSES+1)
+    return OUTPUT_GET_INDEX(batch, feature_offset, y, x);
 #endif
-#define INPUTS_COUNT (IH * IW * ACTUAL_NUM * (CLASSES + COORDS + 1))
+}
 
-    for (int b = 0; b < IB; b++) {
-        for (int n = 0; n < ACTUAL_NUM; n++) {
-            // coords: x/y
-            int index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, 0);
-            int i = index + 2 * x;
-            output[i] = FUNC_CALL(logistic_activate)(input[i]);
-            output[i+1] = FUNC_CALL(logistic_activate)(input[i+1]);
+KERNEL (region_yolo_ref)(const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output)
+{
+    int xy = get_global_id(0);
+    int region_num = get_global_id(1);
+    int batch = get_global_id(2);
+    int x_index = xy % INPUT0_SIZE_X;
+    int y_index = (xy / INPUT0_SIZE_X) % (INPUT0_SIZE_Y);
 
-            // coords: w/h: directly copy?
-            index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, 2);
-            i = index + 2 * x;
-            output[i] = input[i];
-            output[i+1] = input[i+1];
+    /// [x, y, width, height, objectness score, class score]
+    /// x,y
+    int region_offset = region_num * (COORDS + CLASSES + 1);
+    int in_i = INPUT0_GET_INDEX(batch, 0 + region_offset, y_index, x_index);
+    int out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 0 + region_offset);
+    output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]);
 
-            // confidence
-            index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, COORDS);
-            for (int j = 0; j < CONF_CLASSES; j++)
-            {
-                i = index + x + j*IH*IW;
-                output[i] = FUNC_CALL(logistic_activate)(input[i]);
-            }
-        }
-    }
+    in_i = INPUT0_GET_INDEX(batch, 1 + region_offset, y_index, x_index);
+    out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 1 + region_offset);
+    output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]);
 
-#if DO_SOFTMAX
-    // the probability of classes
-    int index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, 0, 0, COORDS + 1);
-    int batch_offset = INPUTS_COUNT / NUM;
-    for (int b = 0; b < IB * NUM; b++)
-        FUNC_CALL(softmax_generic)(input + index + b * batch_offset, output + index + b * batch_offset,
-                                   1, CLASSES, IH, IW, x);
+    /// width,height
+    in_i = INPUT0_GET_INDEX(batch, 2 + region_offset, y_index, x_index);
+    out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 2 + region_offset);
+    output[out_i] = input[in_i];
+
+    in_i = INPUT0_GET_INDEX(batch, 3 + region_offset, y_index, x_index);
+    out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 3 + region_offset);
+    output[out_i] = input[in_i];
+
+    /// objectness score
+    in_i = INPUT0_GET_INDEX(batch, COORDS + region_offset, y_index, x_index);
+    out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + region_offset);
+    output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]);
+
+    /// class score(confidence)
+#if DO_SOFTMAX
+    in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + region_offset, y_index, x_index);
+    INPUT0_TYPE max_value = input[in_i];
+    for (int j = 1; j < CLASSES; j++) {
+        in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index);
+        max_value = max(max_value, input[in_i]);
+    }
+
+    OUTPUT_TYPE expSum = 0;
+    for (int j = 0; j < CLASSES; j++) {
+        in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index);
+        out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + 1 + j + region_offset);
+        output[out_i] = exp(input[in_i] - max_value);
+        expSum += output[out_i];
+    }
+
+    for (int j = 0; j < CLASSES; j++) {
+        out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + 1 + j + region_offset);
+        output[out_i] /= expSum;
+    }
+#else
+    for (int j = 0; j < CLASSES; j++)
+    {
+        in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index);
+        output[in_i] = FUNC_CALL(logistic_activate)(input[in_i]);
+    }
 #endif
 }
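To make the rewritten kernel's indexing easier to follow, here is the per-work-item picture for the test configuration (COORDS = 4, CLASSES = 6, 52x52 map). INPUT0_GET_INDEX/OUTPUT_GET_INDEX are the jit-generated (b, f, y, x) index macros; the numbers below are illustrative arithmetic, not code from the commit:

    // One work-item processes one spatial position of one region:
    //   xy         = get_global_id(0)   -> x_index = xy % 52, y_index = (xy / 52) % 52
    //   region_num = get_global_id(1)
    //   batch      = get_global_id(2)
    //
    // region_offset = region_num * (COORDS + CLASSES + 1); for region_num = 2 that is 2 * 11 = 22,
    // so this work-item touches input features 22..32:
    //   22, 23  -> box x, y          (logistic_activate)
    //   24, 25  -> box width, height (copied as-is)
    //   26      -> objectness        (logistic_activate)
    //   27..32  -> 6 class scores    (logistic_activate, or softmax when DO_SOFTMAX)
    //
    // With DO_SOFTMAX the output is validated as a flattened (B, F*Y*X, 1, 1) blob, so
    // output_index() folds the spatial position into the feature coordinate
    // (feature_offset * 52 * 52 + xy) instead of addressing it by (y_index, x_index).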
@@ -866,10 +866,6 @@ format layout_optimizer::get_preferred_format(program_node& node) {
         if (input_layout.format.dimension() == 5 &&
             (input_layout.data_type == data_types::f32 || input_layout.data_type == data_types::f16))
            expected = format::bfzyx;
-    } else if (node.is_type<region_yolo>()) {
-        if (_optimization_attributes.b_fs_yx_fsv16_network) {
-            expected = format::bfyx;
-        }
     }
 
     return expected;
@@ -3536,7 +3536,7 @@ public:
         EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size));
         if (sizeof(input0_type) == 1) {
             for (size_t i = 0; i < out_data.size(); ++i) {
-                EXPECT_FLOAT_EQ(float(output_ptr[i]), float(out_data[i])) << "index = " << i;
+                EXPECT_NEAR(float(output_ptr[i]), float(out_data[i]), 1e-1) << "index = " << i;
             }
         } else if (sizeof(input0_type) == 2) {
             for (size_t i = 0; i < out_data.size(); ++i) {
inference-engine/thirdparty/clDNN/tests/test_cases/region_yolo_gpu_test.cpp (new file, 270 lines):

// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

///////////////////////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>

#include <api/input_layout.hpp>
#include <api/memory.hpp>
#include <api/region_yolo.hpp>
#include <api/reorder.hpp>
#include <api/topology.hpp>
#include <api/network.hpp>

#include <cstddef>
#include <tests/test_utils/test_utils.h>

using namespace cldnn;
using namespace ::tests;

namespace internal
{
static inline int entry_index(int width,
                              int height,
                              int coords,
                              int classes,
                              int outputs,
                              int batch,
                              int location,
                              int entry)
{
    int n = location / (width * height);
    int loc = location % (width * height);
    return batch * outputs + n * width * height * (coords + classes + 1) +
           entry * width * height + loc;
}

template <typename T>
static inline T sigmoid(float x)
{
    return static_cast<T>(1.f / (1.f + std::exp(-x)));
}

template <typename T>
static inline void softmax_generic(const T* src_data, T* dst_data,
                                   uint32_t batches, uint32_t channels, uint32_t height, uint32_t width)
{
    const uint32_t area = height * width;
    for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
    {
        const int offset = batch_idx * channels * area;
        for (unsigned int i = 0; i < height * width; i++)
        {
            T max = src_data[batch_idx * channels * area + i];
            for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
            {
                T val = src_data[offset + channel_idx * area + i];
                max = std::max(max, val);
            }

            T sum = 0;
            for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
            {
                dst_data[offset + channel_idx * area + i] =
                    std::exp((float)(src_data[offset + channel_idx * area + i] - max));
                sum += dst_data[offset + channel_idx * area + i];
            }

            for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
            {
                dst_data[offset + channel_idx * area + i] /= sum;
            }
        }
    }
}

uint32_t shape_size(const std::vector<uint32_t>& input_shape)
{
    uint32_t ret = 1;
    std::for_each(input_shape.begin(), input_shape.end(), [&ret](uint32_t n){
        ret *= n;
    });

    return ret;
}

template <typename T>
void region_yolo(const T* input,
                 T* output,
                 const std::vector<uint32_t>& input_shape,
                 const uint32_t coords,
                 const uint32_t classes,
                 const uint32_t regions,
                 const bool do_softmax,
                 const std::vector<int64_t>& mask)
{
    EXPECT_EQ(input_shape.size(), 4);

    const uint32_t batches = input_shape[0];
    //const uint32_t channels = input_shape[1];
    const uint32_t height = input_shape[2];
    const uint32_t width = input_shape[3];

    const auto mask_size = mask.size();

    std::copy(input, input + shape_size(input_shape), output);

    uint32_t num_regions = 0;
    uint32_t end_index = 0;

    if (do_softmax)
    {
        // Region layer (Yolo v2)
        num_regions = regions;
        end_index = width * height;
    }
    else
    {
        // Yolo layer (Yolo v3)
        num_regions = static_cast<uint32_t>(mask_size);
        end_index = width * height * (classes + 1);
    }

    const uint32_t inputs_size = width * height * num_regions * (classes + coords + 1);
    for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
    {
        for (unsigned int n = 0; n < num_regions; n++)
        {
            int index = entry_index(width,
                                    height,
                                    coords,
                                    classes,
                                    inputs_size,
                                    batch_idx,
                                    n * width * height,
                                    0);
            std::transform(input + index,
                           input + index + 2 * width * height,
                           output + index,
                           [](T elem) { return sigmoid<T>(elem); });

            index = entry_index(width,
                                height,
                                coords,
                                classes,
                                inputs_size,
                                batch_idx,
                                n * width * height,
                                coords);
            std::transform(input + index,
                           input + index + end_index,
                           output + index,
                           [](T elem) { return sigmoid<T>(elem); });
        }
    }

    if (do_softmax)
    {
        int index =
            entry_index(width, height, coords, classes, inputs_size, 0, 0, coords + 1);
        int batch_offset = inputs_size / regions;
        for (unsigned int batch_idx = 0; batch_idx < batches * regions; batch_idx++)
        {
            softmax_generic<T>(input + index + batch_idx * batch_offset,
                               output + index + batch_idx * batch_offset,
                               1,
                               classes,
                               height,
                               width);
        }
    }
}

struct region_yolo_test_params {
    std::vector<uint32_t> tensor;
    std::vector<int64_t> mask;
    uint32_t coords;
    uint32_t classes;
    uint32_t regionNum;
    data_types dataType;
    format fmt;
    bool softMax;
};
}

template <typename T>
static void runRegionTest(internal::region_yolo_test_params& params)
{
    engine eng;
    const tensor kInputTensor(params.tensor[0], params.tensor[1], params.tensor[2], params.tensor[3]);
    auto inputData = generate_random_1d<T>(params.tensor[0] * params.tensor[1] * params.tensor[2] * params.tensor[3], -1, 1);

    auto inputPrim = memory::allocate(eng, { params.dataType, format::bfyx, kInputTensor });
    set_values(inputPrim, inputData);

    topology topology;
    topology.add(input_layout("InputData", inputPrim.get_layout()));
    topology.add(reorder("reorder_pre", "InputData", params.fmt, params.dataType));
    topology.add(region_yolo("region_yolo", "reorder_pre", params.coords, params.classes,
                             params.regionNum, static_cast<uint32_t>(params.mask.size()), params.softMax));
    topology.add(reorder("reorder_post", "region_yolo", format::bfyx, params.dataType));

    network network(eng, topology);
    network.set_input_data("InputData", inputPrim);

    auto outputs = network.execute();
    auto output = outputs.at("reorder_post").get_memory();
    auto outputData = output.pointer<T>();

    /// reference value
    std::vector<T> refOutputData(inputData.size());
    internal::region_yolo<T>(inputData.data(), refOutputData.data(),
                             params.tensor, params.coords, params.classes,
                             params.regionNum, params.softMax, params.mask);

    /// compare values
    for (size_t i = 0; i < inputData.size(); ++i) {
        EXPECT_NEAR(refOutputData[i], outputData[i], 0.01);
    }
}

TEST(region_yolo_gpu_fp32, bfyx) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, false};
    runRegionTest<float>(params);
}

TEST(region_yolo_gpu_fp32, bfyx_softmax) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, true};
    runRegionTest<float>(params);
}

TEST(region_yolo_gpu_fp32, byxf) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, false};
    runRegionTest<float>(params);
}

TEST(region_yolo_gpu_fp32, byxf_softmax) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, true};
    runRegionTest<float>(params);
}

TEST(region_yolo_gpu_fp16, bfyx) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, false};
    runRegionTest<FLOAT16>(params);
}

TEST(region_yolo_gpu_fp16, bfyx_softmax) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, true};
    runRegionTest<FLOAT16>(params);
}

TEST(region_yolo_gpu_fp16, byxf) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, false};
    runRegionTest<FLOAT16>(params);
}

TEST(region_yolo_gpu_fp16, byxf_softmax) {
    internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, true};
    runRegionTest<FLOAT16>(params);
}
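All of the added cases use the same YOLOv3-style 1x33x52x52 configuration. As a purely hypothetical sketch (not part of this commit), a YOLOv2-style case would look the same, with the input feature count sized as regions * (classes + coords + 1):

    // Hypothetical, for illustration only: 5 regions, 20 classes, 4 coords
    // -> 5 * (20 + 4 + 1) = 125 input feature maps.
    TEST(region_yolo_gpu_fp32, byxf_softmax_v2_like) {
        internal::region_yolo_test_params params{{ 1, 125, 13, 13 }, { 0, 1, 2, 3, 4 }, 4, 20, 5, data_types::f32, format::byxf, true};
        runRegionTest<float>(params);
    }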