[GPU] reorg_yolo blocked layouts support (#12463)

* add unit test for reorg_yolo
* add validation to reorg_yolo kernel
* add blocked formats support
* remove non-working yxfb optimization
* add reorg_yolo to whitelist for blocked formats
This commit is contained in:
Oleksii Khovan 2022-10-20 18:35:43 +02:00 committed by GitHub
parent 99bb3bba6e
commit 2f982b9490
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 475 additions and 42 deletions

View File

@ -46,14 +46,19 @@ struct reorg_yolo_impl : typed_primitive_impl_ocl<reorg_yolo> {
namespace detail {
// Registers the OCL reorg_yolo implementation (reorg_yolo_impl::create) for
// every combination of the data types and formats listed below.
// The registration is done with a single add() call: the plain formats
// (bfyx, yxfb, byxf) are part of the same list as the blocked ones, so no
// (type, format) pair is registered more than once.
attach_reorg_yolo_impl::attach_reorg_yolo_impl() {
    auto types = {data_types::f16, data_types::f32};
    auto formats = {
        format::bfyx,
        format::yxfb,
        format::byxf,
        format::b_fs_yx_fsv16,
        format::b_fs_yx_fsv32,
        format::bs_fs_yx_bsv16_fsv16,
        format::bs_fs_yx_bsv32_fsv16,
        format::bs_fs_yx_bsv32_fsv32,
    };

    implementation_map<reorg_yolo>::add(impl_types::ocl, reorg_yolo_impl::create, types, formats);
}
} // namespace detail

View File

@ -1447,7 +1447,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::resample::type_id() &&
prim.type() != cldnn::eye::type_id() &&
prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id()) {
prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id()) {
can_use_fsv16 = false;
}
@ -1486,7 +1487,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::prior_box::type_id() &&
prim.type() != cldnn::eye::type_id() &&
prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id()) {
prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id()) {
can_use_bs_fs_yx_bsv16_fsv16 = false;
}
}

View File

@ -4,31 +4,55 @@
#include "include/batch_headers/common.cl"
#include "include/batch_headers/data_types.cl"
#include "include/batch_headers/fetch_data.cl"
#if OUTPUT_LAYOUT_BFYX
#define IW INPUT0_SIZES[0]
#define IH INPUT0_SIZES[1]
#define IC INPUT0_SIZES[2]
#define B INPUT0_SIZES[3]
#elif OUTPUT_LAYOUT_YXFB
#if OUTPUT_LAYOUT_YXFB
#define IW INPUT0_SIZES[3]
#define IH INPUT0_SIZES[2]
#define IC INPUT0_SIZES[1]
#define B INPUT0_SIZES[0]
#elif OUTPUT_LAYOUT_BYXF
#define IW INPUT0_SIZES[1]
#define IH INPUT0_SIZES[2]
#define IC INPUT0_SIZES[0]
#define B INPUT0_SIZES[3]
#else
#define IW INPUT0_SIZES[0]
#define IH INPUT0_SIZES[1]
#define IC INPUT0_SIZES[2]
#define B INPUT0_SIZES[3]
#endif
#define ic_off (IC / (STRIDE * STRIDE))
#define ih_off (IH * STRIDE)
#define iw_off (IW * STRIDE)
#if !defined(OUTPUT_LAYOUT_BFYX)
// Decomposes a linear index into a densely packed tensor of
// channel_num x height x width elements per batch into its
// (batch, feature, y, x) coordinates.
// batch_num is accepted for symmetry with the other extents but is not
// needed for the decomposition.
inline void FUNC(planar_to_bfyx)(const uint planar_index,
                                 const uint batch_num, const uint channel_num, const uint height, const uint width,
                                 uint* dst_b, uint* dst_f, uint* dst_y, uint* dst_x)
{
    const uint plane_elems = height * width;             // elements in one feature plane
    const uint image_elems = channel_num * plane_elems;  // elements in one batch entry

    const uint within_image = planar_index % image_elems;
    const uint within_plane = within_image % plane_elems;

    *dst_b = planar_index / image_elems;
    *dst_f = within_image / plane_elems;
    *dst_y = within_plane / width;
    *dst_x = within_plane % width;
}
#endif
KERNEL (reorg_yolo_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output)
{
#if OUTPUT_LAYOUT_BFYX
int ic = get_global_id(2);
int ih = get_global_id(1);
int iw = get_global_id(0);
for (int b = 0; b < B; b++) {
#if OUTPUT_LAYOUT_BFYX
for (int b = 0; b < B; b++) {
int dstIndex = b*IC*IH*IW + ic*IH*IW + ih*IW + iw;
int oc = ic % ic_off;
@ -41,25 +65,38 @@ KERNEL (reorg_yolo_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* out
output[dstIndex] = input[srcIndex];
}
#elif OUTPUT_LAYOUT_YXFB
int ic = get_global_id(0) / B;
int ib = get_global_id(0) % B;
int ih = get_global_id(2);
int iw = get_global_id(1);
#else
const uint OC = IC * STRIDE * STRIDE;
const uint OH = IH / STRIDE;
const uint OW = IW / STRIDE;
for (int b = 0; b < B; b++) {
int dstIndex = ib + ic*B + ih*IC*B + iw*IH*IC*B;
const uint dstPlanarIndex = b*IC*IH*IW + ic*IH*IW + ih*IW + iw;
uint dstB, dstC, dstY, dstX;
FUNC_CALL(planar_to_bfyx)(dstPlanarIndex, B, OC, OH, OW, &dstB, &dstC, &dstY, &dstX);
const uint dstIndex = OUTPUT_GET_INDEX(dstB, dstC, dstY, dstX);
int oc = ic % ic_off;
int offset = ic / ic_off;
const int oc = ic % ic_off;
const int offset = ic / ic_off;
int ow = iw * STRIDE + offset % STRIDE;
int oh = ih * STRIDE + offset / STRIDE;
const int ow = iw * STRIDE + offset % STRIDE;
const int oh = ih * STRIDE + offset / STRIDE;
int srcIndex = b*ic_off*ih_off*iw_off + oc*ih_off*iw_off + oh*iw_off + ow;
const int srcPlanarIndex = b*ic_off*ih_off*iw_off + oc*ih_off*iw_off + oh*iw_off + ow;
uint srcB, srcC, srcY, srcX;
FUNC_CALL(planar_to_bfyx)(srcPlanarIndex, B, IC, IH, IW, &srcB, &srcC, &srcY, &srcX);
const uint srcIndex = INPUT0_GET_INDEX(srcB, srcC, srcY, srcX);
output[dstIndex] = input[srcIndex];
}
#endif
}
#undef iw_off
#undef ih_off
#undef ic_off
#undef B
#undef IC
#undef IH
#undef IW

View File

@ -38,23 +38,20 @@ ReorgYoloKernelRef::DispatchData SetDefault(const reorg_yolo_params& params) {
std::vector<std::vector<Tensor::DataChannelName>> dims_by_gws;
const auto& input = params.inputs[0];
if (input.GetLayout() == DataLayout::bfyx) {
dispatchData.gws = {input.X().v, input.Y().v, input.Feature().v};
dims_by_gws = {{Tensor::DataChannelName::X},
{Tensor::DataChannelName::Y},
{Tensor::DataChannelName::FEATURE}};
} else {
dispatchData.gws = {input.Feature().v * input.Batch().v, input.X().v, input.Y().v};
dims_by_gws = {{Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH},
{Tensor::DataChannelName::X},
{Tensor::DataChannelName::Y}};
}
dispatchData.gws = {input.X().v, input.Y().v, input.Feature().v};
dims_by_gws = {{Tensor::DataChannelName::X},
{Tensor::DataChannelName::Y},
{Tensor::DataChannelName::FEATURE}};
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo, in_layout, out_layout, dims_by_gws);
return dispatchData;
}
KernelsData ReorgYoloKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
assert(params.GetType() == KernelType::REORG_YOLO);
if (!Validate(params, options)) {
return {};
}
const reorg_yolo_params& orgParams = static_cast<const reorg_yolo_params&>(params);
DispatchData dispatchData = SetDefault(orgParams);
@ -73,4 +70,22 @@ KernelsData ReorgYoloKernelRef::GetKernelsData(const Params& params, const optio
// Reports a fixed priority (FORCE_PRIORITY_9) for this kernel; neither the
// params nor the options are inspected.
KernelsPriority ReorgYoloKernelRef::GetKernelsPriority(const Params& /*params*/, const optional_params& /*options*/) const {
return FORCE_PRIORITY_9;
}
// Checks whether this kernel can process the given parameters.
//
// Returns true only when all of the following hold:
//   * stride is non-zero,
//   * the first input has exactly 4 dimensions,
//   * the input feature count is at least stride * stride,
//   * the input X and Y extents are both multiples of stride.
// The optional params are not inspected.
bool ReorgYoloKernelRef::Validate(const Params& p, const optional_params& /*o*/) const {
    const reorg_yolo_params& params = static_cast<const reorg_yolo_params&>(p);
    const auto& input = params.inputs[0];

    // A zero stride would make the modulo checks below divide by zero.
    if (params.stride == 0) {
        return false;
    }

    if (input.GetDims().size() != 4) {
        return false;
    }

    if (!(input.Feature().v >= params.stride * params.stride
        && input.X().v % params.stride == 0
        && input.Y().v % params.stride == 0)) {
        return false;
    }

    return true;
}
} // namespace kernel_selector

View File

@ -41,6 +41,7 @@ public:
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
bool Validate(const Params& params, const optional_params& options) const override;
protected:
virtual JitConstants GetJitConstants(const reorg_yolo_params& params) const;

View File

@ -0,0 +1,373 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils.h"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/activation.hpp>
#include <intel_gpu/primitives/reorg_yolo.hpp>
#include <cstddef>
#include <string>
using namespace cldnn;
using namespace ::tests;
namespace {
// One reorg_yolo test case: input shape and values, the stride to apply and
// the values the output is compared against.
template<typename T>
struct ReorgYoloParams {
ov::PartialShape inputTensor;  // shape used to allocate the input memory
std::vector<T> input;  // values written into the input memory
uint32_t stride;  // stride passed to the reorg_yolo primitive
std::vector<T> expected;  // expected output values, compared element by element
};
// Full gtest parameter: a test case plus the layout to run it in and whether
// the run is expected to be rejected.
template<typename T>
using ReorgYoloParamsWithLayout = std::tuple<
ReorgYoloParams<T>,
format::type, // layout the input is reordered to before reorg_yolo (plain or blocked)
bool // should_fail: when true the test expects std::invalid_argument
>;
// Layouts the valid-input suites are run with: the three plain formats
// followed by the blocked ones.
const std::vector<format::type> dataFormats = {
format::bfyx,
format::yxfb,
format::byxf,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32
};
// Converts a list of float literals into a vector of the element type T
// used by the test (each element is constructed from the float value).
template<typename T>
std::vector<T> getValues(const std::vector<float> &values) {
    std::vector<T> converted;
    converted.reserve(values.size());
    for (const float value : values) {
        converted.emplace_back(value);
    }
    return converted;
}
// Absolute tolerance handed to EXPECT_NEAR, selected by the element type T.
// Only the specializations below are defined.
template <typename T>
float getError();

// Tolerance for float outputs.
template<>
float getError<float>() {
    const float tolerance = 0.001;
    return tolerance;
}

// Tolerance for half_t outputs.
template<>
float getError<half_t>() {
    const float tolerance = 0.2;
    return tolerance;
}
// Returns the valid reorg_yolo test cases for element type T.
// The table is built once (function-local static) and returned by value.
// Each entry is {input shape, input values, stride, expected output values}.
template<typename T>
std::vector<ReorgYoloParams<T>> generateParams() {
static const std::vector<ReorgYoloParams<T>> result = {
// Case 1: shape {1, 4, 2, 2}, stride 2.
{
ov::PartialShape{1, 4, 2, 2},
getValues<T>({
0.0, 1.0,
2.0, 3.0,
4.0, 5.0,
6.0, 7.0,
8.0, 9.0,
10.0, 11.0,
12.0, 13.0,
14.0, 15.0,
}),
2,
getValues<T>({
0.0f, 2.0f, 8.0f, 10.0f,
1.0f, 3.0f, 9.0f, 11.0f,
4.0f, 6.0f, 12.0f, 14.0f,
5.0f, 7.0f, 13.0f, 15.0f,
})
},
// Case 2: shape {2, 9, 3, 3}, stride 3 (two batches, input values 0..161).
{
ov::PartialShape{2, 9, 3, 3},
getValues<T>({
0.0f, 1.0f, 2.0f,
3.0f, 4.0f, 5.0f,
6.0f, 7.0f, 8.0f,
9.0f, 10.0f, 11.0f,
12.0f, 13.0f, 14.0f,
15.0f, 16.0f, 17.0f,
18.0f, 19.0f, 20.0f,
21.0f, 22.0f, 23.0f,
24.0f, 25.0f, 26.0f,
27.0f, 28.0f, 29.0f,
30.0f, 31.0f, 32.0f,
33.0f, 34.0f, 35.0f,
36.0f, 37.0f, 38.0f, 39.0f,
40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f,
55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f,
64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f, 70.0f, 71.0f, 72.0f,
73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, 81.0f,
82.0f, 83.0f, 84.0f, 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f, 97.0f, 98.0f, 99.0f,
100.0f, 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 106.0f, 107.0f, 108.0f,
109.0f, 110.0f, 111.0f, 112.0f, 113.0f, 114.0f, 115.0f, 116.0f, 117.0f,
118.0f, 119.0f, 120.0f, 121.0f, 122.0f, 123.0f, 124.0f, 125.0f, 126.0f,
127.0f, 128.0f, 129.0f, 130.0f, 131.0f, 132.0f, 133.0f, 134.0f, 135.0f,
136.0f, 137.0f, 138.0f, 139.0f, 140.0f, 141.0f, 142.0f, 143.0f, 144.0f,
145.0f, 146.0f, 147.0f, 148.0f, 149.0f, 150.0f, 151.0f, 152.0f, 153.0f,
154.0f, 155.0f, 156.0f, 157.0f, 158.0f, 159.0f, 160.0f, 161.0f
}),
3,
getValues<T>({
0.0f, 3.0f, 6.0f, 27.0f, 30.0f, 33.0f, 54.0f, 57.0f, 60.0f,
1.0f, 4.0f, 7.0f, 28.0f, 31.0f, 34.0f, 55.0f, 58.0f, 61.0f,
2.0f, 5.0f, 8.0f, 29.0f, 32.0f, 35.0f, 56.0f, 59.0f, 62.0f,
9.0f, 12.0f, 15.0f, 36.0f, 39.0f, 42.0f, 63.0f, 66.0f, 69.0f,
10.0f, 13.0f, 16.0f, 37.0f, 40.0f, 43.0f, 64.0f, 67.0f, 70.0f,
11.0f, 14.0f, 17.0f, 38.0f, 41.0f, 44.0f, 65.0f, 68.0f, 71.0f,
18.0f, 21.0f, 24.0f, 45.0f, 48.0f, 51.0f, 72.0f, 75.0f, 78.0f,
19.0f, 22.0f, 25.0f, 46.0f, 49.0f, 52.0f, 73.0f, 76.0f, 79.0f,
20.0f, 23.0f, 26.0f, 47.0f, 50.0f, 53.0f, 74.0f, 77.0f, 80.0f,
81.0f, 84.0f, 87.0f, 108.0f, 111.0f, 114.0f, 135.0f, 138.0f, 141.0f,
82.0f, 85.0f, 88.0f, 109.0f, 112.0f, 115.0f, 136.0f, 139.0f, 142.0f,
83.0f, 86.0f, 89.0f, 110.0f, 113.0f, 116.0f, 137.0f, 140.0f, 143.0f,
90.0f, 93.0f, 96.0f, 117.0f, 120.0f, 123.0f, 144.0f, 147.0f, 150.0f,
91.0f, 94.0f, 97.0f, 118.0f, 121.0f, 124.0f, 145.0f, 148.0f, 151.0f,
92.0f, 95.0f, 98.0f, 119.0f, 122.0f, 125.0f, 146.0f, 149.0f, 152.0f,
99.0f, 102.0f, 105.0f, 126.0f, 129.0f, 132.0f, 153.0f, 156.0f, 159.0f,
100.0f, 103.0f, 106.0f, 127.0f, 130.0f, 133.0f, 154.0f, 157.0f, 160.0f,
101.0f, 104.0f, 107.0f, 128.0f, 131.0f, 134.0f, 155.0f, 158.0f, 161.0f,
}),
},
// Case 3: shape {2, 5, 4, 4}, stride 2. The feature count (5) is not a
// multiple of stride * stride (4), so some input values appear more than
// once in the expected output.
{
ov::PartialShape{2, 5, 4, 4},
getValues<T>({
0.0f, 1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 9.0f, 10.0f, 11.0f,
12.0f, 13.0f, 14.0f, 15.0f,
16.0f, 17.0f, 18.0f, 19.0f,
20.0f, 21.0f, 22.0f, 23.0f,
24.0f, 25.0f, 26.0f, 27.0f,
28.0f, 29.0f, 30.0f, 31.0f,
32.0f, 33.0f, 34.0f, 35.0f,
36.0f, 37.0f, 38.0f, 39.0f,
40.0f, 41.0f, 42.0f, 43.0f,
44.0f, 45.0f, 46.0f, 47.0f,
48.0f, 49.0f, 50.0f, 51.0f,
52.0f, 53.0f, 54.0f, 55.0f,
56.0f, 57.0f, 58.0f, 59.0f,
60.0f, 61.0f, 62.0f, 63.0f,
64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
81.0f, 82.0f, 83.0f, 84.0f, 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f, 97.0f, 98.0f, 99.0f, 100.0f,
101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 106.0f, 107.0f, 108.0f, 109.0f, 110.0f,
111.0f, 112.0f, 113.0f, 114.0f, 115.0f, 116.0f, 117.0f, 118.0f, 119.0f, 120.0f,
121.0f, 122.0f, 123.0f, 124.0f, 125.0f, 126.0f, 127.0f, 128.0f, 129.0f, 130.0f,
131.0f, 132.0f, 133.0f, 134.0f, 135.0f, 136.0f, 137.0f, 138.0f, 139.0f, 140.0f,
141.0f, 142.0f, 143.0f, 144.0f, 145.0f, 146.0f, 147.0f, 148.0f, 149.0f, 150.0f,
151.0f, 152.0f, 153.0f, 154.0f, 155.0f, 156.0f, 157.0f, 158.0f, 159.0f,
}),
2,
getValues<T>({
0.0f, 2.0f,
4.0f, 6.0f,
16.0f, 18.0f,
20.0f, 22.0f,
32.0f, 34.0f,
36.0f, 38.0f,
48.0f, 50.0f,
52.0f, 54.0f,
1.0f, 3.0f,
5.0f, 7.0f,
17.0f, 19.0f,
21.0f, 23.0f,
33.0f, 35.0f,
37.0f, 39.0f,
49.0f, 51.0f,
53.0f, 55.0f,
8.0f, 10.0f, 12.0f, 14.0f, 24.0f, 26.0f, 28.0f, 30.0f,
40.0f, 42.0f, 44.0f, 46.0f, 56.0f, 58.0f, 60.0f, 62.0f, 9.0f, 11.0f,
13.0f, 15.0f, 25.0f, 27.0f, 29.0f, 31.0f, 41.0f, 43.0f, 45.0f, 47.0f,
57.0f, 59.0f, 61.0f, 63.0f, 16.0f, 18.0f, 20.0f, 22.0f, 32.0f, 34.0f,
36.0f, 38.0f, 48.0f, 50.0f, 52.0f, 54.0f, 64.0f, 66.0f, 68.0f, 70.0f,
64.0f, 66.0f, 68.0f, 70.0f, 80.0f, 82.0f, 84.0f, 86.0f, 96.0f, 98.0f,
100.0f, 102.0f, 112.0f, 114.0f, 116.0f, 118.0f, 65.0f, 67.0f, 69.0f, 71.0f,
81.0f, 83.0f, 85.0f, 87.0f, 97.0f, 99.0f, 101.0f, 103.0f, 113.0f, 115.0f,
117.0f, 119.0f, 72.0f, 74.0f, 76.0f, 78.0f, 88.0f, 90.0f, 92.0f, 94.0f,
104.0f, 106.0f, 108.0f, 110.0f, 120.0f, 122.0f, 124.0f, 126.0f, 73.0f, 75.0f,
77.0f, 79.0f, 89.0f, 91.0f, 93.0f, 95.0f, 105.0f, 107.0f, 109.0f, 111.0f,
121.0f, 123.0f, 125.0f, 127.0f, 80.0f, 82.0f, 84.0f, 86.0f, 96.0f, 98.0f,
100.0f, 102.0f, 112.0f, 114.0f, 116.0f, 118.0f, 128.0f, 130.0f, 132.0f, 134.0f,
}),
},
};
return result;
}
// Returns test cases whose shape/stride combination is expected to be
// rejected (they are instantiated with should_fail = true).
// The expected-output vector of every entry is empty because no output is
// compared for these cases. The table is built once (function-local static)
// and returned by value.
template<typename T>
std::vector<ReorgYoloParams<T>> generateInvalidParams() {
static const std::vector<ReorgYoloParams<T>> result = {
{ // Feature < stride*stride
ov::PartialShape{1, 3, 4, 4},
getValues<T>({
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f,
}),
2,
getValues<T>({}),
},
{ // Height % stride != 0
ov::PartialShape{1, 4, 5, 4},
getValues<T>({
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f,
}),
2,
getValues<T>({}),
},
{ // Width % stride != 0
ov::PartialShape{1, 4, 4, 5},
getValues<T>({
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f,
}),
2,
getValues<T>({}),
},
};
return result;
}
// Name generator for the parameterized suites: produces
// "InputTensor=<shape>.stride=<n>.TargetLayout=<format>" from the test
// parameter. The should_fail flag is not part of the name.
struct PrintToStringParamName {
    template<class T>
    std::string operator()(const testing::TestParamInfo<ReorgYoloParamsWithLayout<T> > &param) {
        const ReorgYoloParams<T>& test_case = std::get<0>(param.param);
        const format::type layout = std::get<1>(param.param);

        std::stringstream name;
        name << "InputTensor=" << to_string(test_case.inputTensor);
        name << ".stride=" << test_case.stride;
        name << ".TargetLayout=" << fmt_to_str(layout);
        return name.str();
    }
};
}; // namespace
// Parameterized fixture that runs the reorg_yolo primitive for one
// (test case, target layout, should_fail) combination and element type T.
template<typename T>
struct reorg_yolo_test
        : public ::testing::TestWithParam<ReorgYoloParamsWithLayout<T> > {
public:
    // Entry point called from TEST_P.
    // When should_fail is set, running the network must throw
    // std::invalid_argument; otherwise the run must finish without a fatal
    // failure.
    void test() {
        // Bind to the stored parameter instead of copying the data vectors.
        const auto& test_param = this->GetParam();
        const ReorgYoloParams<T>& params = std::get<0>(test_param);
        const format::type target_format = std::get<1>(test_param);
        const bool should_fail = std::get<2>(test_param);

        if (should_fail) {
            ASSERT_THROW(run_test(params, target_format), std::invalid_argument);
        } else {
            ASSERT_NO_FATAL_FAILURE(run_test(params, target_format));
        }
    }

private:
    // Builds input(bfyx) -> reorder(target_format) -> reorg_yolo -> reorder(bfyx),
    // executes it and compares the final output with params.expected using
    // the per-type tolerance from getError<T>().
    void run_test(const ReorgYoloParams<T>& params, const format::type target_format) {
        const auto data_type = type_to_data_type<T>::value;
        const format::type plain_format = format::bfyx;

        auto& engine = get_test_engine();

        auto input = engine.allocate_memory({params.inputTensor, data_type, plain_format});
        set_values(input, params.input);

        topology topology;
        topology.add(input_layout("input", input->get_layout()));
        topology.add(reorder("input_reordered", "input", target_format, data_type));
        topology.add(reorg_yolo("reorg_yolo", "input_reordered", params.stride));
        topology.add(reorder("reorg_yolo_reordered", "reorg_yolo", plain_format, data_type));

        network network(engine, topology);
        network.set_input_data("input", input);
        const auto result = network.execute();

        auto out_mem = result.at("reorg_yolo_reordered").get_memory();
        cldnn::mem_lock<T> out_ptr(out_mem, get_test_stream());

        ASSERT_EQ(params.expected.size(), out_ptr.size());
        for (size_t i = 0; i < params.expected.size(); ++i) {
            // Print the layout by name; streaming the enum directly would
            // show only its integer value.
            EXPECT_NEAR(params.expected[i], out_ptr[i], getError<T>())
                << "format=" << fmt_to_str(target_format) << ", i= " << i;
        }
    }
};
// Concrete fixtures for the two element types under test.
using test_f32 = reorg_yolo_test<float>;
using test_f16 = reorg_yolo_test<half_t>;
// Both test bodies delegate to the fixture's test() method.
TEST_P(test_f32, basic) {
test();
}
TEST_P(test_f16, basic) {
test();
}
// f32 suite: every valid case from generateParams<float>() on every layout in
// dataFormats, all expected to succeed. Uses ValuesIn over the whole table so
// the f32 suite covers the same cases as the f16 suite.
INSTANTIATE_TEST_SUITE_P(reorg_yolo_f32,
                         test_f32,
                         ::testing::Combine(
                                 ::testing::ValuesIn(generateParams<float>()),
                                 ::testing::ValuesIn(dataFormats),
                                 ::testing::Values(false)),
                         PrintToStringParamName());
// f16 suite: every valid case from generateParams<half_t>() on every layout
// in dataFormats, all expected to succeed.
INSTANTIATE_TEST_SUITE_P(reorg_yolo_f16,
test_f16,
::testing::Combine(
::testing::ValuesIn(generateParams<half_t>()),
::testing::ValuesIn(dataFormats),
::testing::Values(false)),
PrintToStringParamName());
// Invalid-input suite: the cases from generateInvalidParams<float>() are run
// in bfyx only, with should_fail = true.
INSTANTIATE_TEST_SUITE_P(reorg_yolo_invalid_input,
test_f32,
::testing::Combine(
::testing::ValuesIn(generateInvalidParams<float>()),
::testing::Values(format::bfyx),
::testing::Values(true)),
PrintToStringParamName());