diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.cpp index 40900737174..18e0390ac18 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.cpp @@ -46,21 +46,46 @@ JitConstants RegionYoloKernelRef::GetJitConstants(const region_yolo_params& ry) return jit; } +bool RegionYoloKernelRef::Validate(const Params& p, const optional_params& o) const { + if (p.GetType() != KernelType:: REGION_YOLO || o.GetType() != KernelType::REGION_YOLO) { + return false; + } + + const region_yolo_params& params = static_cast(p); + const size_t expected_feature_size = + params.do_softmax ? params.inputs[0].X().v * params.inputs[0].Y().v * params.inputs[0].Feature().v : params.inputs[0].Feature().v; + + if (expected_feature_size != params.output.Feature().v) { + return false; + } + + return true; +} + RegionYoloKernelRef::DispatchData SetDefault(const region_yolo_params& params) { RegionYoloKernelRef::DispatchData dispatchData; const auto& input = params.inputs[0]; - if (input.GetLayout() == DataLayout::bfyx) { - dispatchData.gws = {input.X().v * input.Y().v, 1, 1}; - } else { - dispatchData.gws = {input.Feature().v * input.Batch().v, input.X().v, input.Y().v}; + + switch (input.GetLayout()) { + case DataLayout::bfyx: + case DataLayout::byxf: { + uint32_t region_num = params.do_softmax ? params.num : params.mask_size; + dispatchData.gws = {input.X().v * input.Y().v, region_num, input.Batch().v}; + } break; + default: + throw std::invalid_argument("Unsupported DataLayout"); } dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); return dispatchData; } + KernelsData RegionYoloKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { - assert(params.GetType() == KernelType::REGION_YOLO); + if (!Validate(params, options)) { + return {}; + } + const region_yolo_params& orgParams = static_cast(params); DispatchData dispatchData = SetDefault(orgParams); diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.h index d3639e39ea0..3c9fc70c76c 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/region_yolo/region_yolo_kernel_ref.h @@ -61,5 +61,6 @@ public: protected: virtual JitConstants GetJitConstants(const region_yolo_params& params) const; + bool Validate(const Params& p, const optional_params& o) const override; }; } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/region_yolo_gpu_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/region_yolo_gpu_ref.cl index de62864f8b8..8e905c64662 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/region_yolo_gpu_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/region_yolo_gpu_ref.cl @@ -12,94 +12,79 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "include/common.cl" -#include "include/data_types.cl" +#include "include/fetch.cl" -#define IW INPUT0_SIZES[0] -#define IH INPUT0_SIZES[1] -#define IC INPUT0_SIZES[2] -#define IB INPUT0_SIZES[3] - -inline UNIT_TYPE FUNC(logistic_activate)(UNIT_TYPE x) { +inline INPUT0_TYPE FUNC(logistic_activate)(INPUT0_TYPE x) { return 1. / (1. + exp(-x)); } -inline int FUNC(entry_index)(int width, int height, int coords, int classes, - int outputs, int batch, int location, - int entry) { - int n = location / (width * height); - int loc = location % (width * height); - return batch * outputs + n * width * height * (coords + classes + 1) + - entry * width * height + loc; -} - +inline int FUNC(output_index)(int batch, int region_num, int x, int y, int xy, int feature_offset) { #if DO_SOFTMAX -inline void FUNC(softmax_generic)(const __global UNIT_TYPE* src_data, __global UNIT_TYPE* dst_data, - int B, int C, int W, int H, int i) -{ - for (int b = 0; b < B; b++) { - UNIT_TYPE max = src_data[b*C*H*W + i]; - for (int c = 0; c < C; c++) { - UNIT_TYPE val = src_data[b*C*H*W + c*H*W + i]; - if (val > max) max = val; - } - - UNIT_TYPE expSum = 0; - for (int c = 0; c < C; c++) { - dst_data[b*C*H*W + c*H*W + i] = exp(src_data[b*C*H*W + c*H*W + i] - max); - expSum += dst_data[b*C*H*W + c*H*W + i]; - } - - for (int c = 0; c < C; c++) { - dst_data[b*C*H*W + c*H*W + i] = dst_data[b*C*H*W + c*H*W + i] / expSum; - } - } -} -#endif - -KERNEL (region_yolo_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output) -{ - int x = get_global_id(0); - -#if DO_SOFTMAX - #define ACTUAL_NUM (NUM) - #define CONF_CLASSES (1) + return OUTPUT_GET_INDEX(batch, feature_offset * INPUT0_SIZE_X * INPUT0_SIZE_Y + xy, 1, 1); #else - #define ACTUAL_NUM (MASK_SIZE) - #define CONF_CLASSES (CLASSES+1) -#endif - #define INPUTS_COUNT (IH * IW * ACTUAL_NUM * (CLASSES + COORDS + 1)) - - for (int b = 0; b < IB; b++) { - for (int n = 0; n < ACTUAL_NUM; n++) { - // coords: x/y - int index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, 0); - int i = index + 2 * x; - output[i] = FUNC_CALL(logistic_activate)(input[i]); - output[i+1] = FUNC_CALL(logistic_activate)(input[i+1]); - - // coords: w/h: directly copy? - index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, 2); - i = index + 2 * x; - output[i] = input[i]; - output[i+1] = input[i+1]; - - // confidence - index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, COORDS); - for (int j = 0; j < CONF_CLASSES; j++) - { - i = index + x + j*IH*IW; - output[i] = FUNC_CALL(logistic_activate)(input[i]); - } - } - } - -#if DO_SOFTMAX - // the probability of classes - int index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, 0, 0, COORDS + 1); - int batch_offset = INPUTS_COUNT / NUM; - for (int b = 0; b < IB * NUM; b++) - FUNC_CALL(softmax_generic)(input + index + b * batch_offset, output + index + b * batch_offset, - 1, CLASSES, IH, IW, x); + return OUTPUT_GET_INDEX(batch, feature_offset, y, x); +#endif +} + +KERNEL (region_yolo_ref)(const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output) +{ + int xy = get_global_id(0); + int region_num = get_global_id(1); + int batch = get_global_id(2); + int x_index = xy % INPUT0_SIZE_X; + int y_index = (xy / INPUT0_SIZE_X) % (INPUT0_SIZE_Y); + + /// [x, y, width, height, objectness score, class score] + /// x,y + int region_offset = region_num * (COORDS + CLASSES + 1); + int in_i = INPUT0_GET_INDEX(batch, 0 + region_offset, y_index, x_index); + int out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 0 + region_offset); + output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]); + + in_i = INPUT0_GET_INDEX(batch, 1 + region_offset, y_index, x_index); + out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 1 + region_offset); + output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]); + + /// width,height + in_i = INPUT0_GET_INDEX(batch, 2 + region_offset, y_index, x_index); + out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 2 + region_offset); + output[out_i] = input[in_i]; + + in_i = INPUT0_GET_INDEX(batch, 3 + region_offset, y_index, x_index); + out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 3 + region_offset); + output[out_i] = input[in_i]; + + /// objectness score + in_i = INPUT0_GET_INDEX(batch, COORDS + region_offset, y_index, x_index); + out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + region_offset); + output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]); + + /// class score(confidence) +#if DO_SOFTMAX + in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + region_offset, y_index, x_index); + INPUT0_TYPE max_value = input[in_i]; + for (int j = 1; j < CLASSES; j++) { + in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index); + max_value = max(max_value, input[in_i]); + } + + OUTPUT_TYPE expSum = 0; + for (int j = 0; j < CLASSES; j++) { + in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index); + out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + 1 + j + region_offset); + output[out_i] = exp(input[in_i] - max_value); + expSum += output[out_i]; + } + + for (int j = 0; j < CLASSES; j++) { + out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + 1 + j + region_offset); + output[out_i] /= expSum; + } +#else + for (int j = 0; j < CLASSES; j++) + { + in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index); + output[in_i] = FUNC_CALL(logistic_activate)(input[in_i]); + } #endif } diff --git a/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp b/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp index 3816042ad4c..a73f5993cbb 100644 --- a/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp +++ b/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp @@ -866,10 +866,6 @@ format layout_optimizer::get_preferred_format(program_node& node) { if (input_layout.format.dimension() == 5 && (input_layout.data_type == data_types::f32 || input_layout.data_type == data_types::f16)) expected = format::bfzyx; - } else if (node.is_type()) { - if (_optimization_attributes.b_fs_yx_fsv16_network) { - expected = format::bfyx; - } } return expected; diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp index 12ccd310ac8..57407dcd014 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp @@ -3536,7 +3536,7 @@ public: EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size)); if (sizeof(input0_type) == 1) { for (size_t i = 0; i < out_data.size(); ++i) { - EXPECT_FLOAT_EQ(float(output_ptr[i]), float(out_data[i])) << "index = " << i; + EXPECT_NEAR(float(output_ptr[i]), float(out_data[i]), 1e-1) << "index = " << i; } } else if (sizeof(input0_type) == 2) { for (size_t i = 0; i < out_data.size(); ++i) { diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/region_yolo_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/region_yolo_gpu_test.cpp new file mode 100644 index 00000000000..1e754486c72 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/region_yolo_gpu_test.cpp @@ -0,0 +1,270 @@ +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace cldnn; +using namespace ::tests; + +namespace internal +{ + static inline int entry_index(int width, + int height, + int coords, + int classes, + int outputs, + int batch, + int location, + int entry) + { + int n = location / (width * height); + int loc = location % (width * height); + return batch * outputs + n * width * height * (coords + classes + 1) + + entry * width * height + loc; + } + + template + static inline T sigmoid(float x) + { + return static_cast(1.f / (1.f + std::exp(-x))); + } + + template + static inline void softmax_generic(const T* src_data, T* dst_data, + uint32_t batches, uint32_t channels, uint32_t height, uint32_t width) + { + const uint32_t area = height * width; + for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++) + { + const int offset = batch_idx * channels * area; + for (unsigned int i = 0; i < height * width; i++) + { + T max = src_data[batch_idx * channels * area + i]; + for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++) + { + T val = src_data[offset + channel_idx * area + i]; + max = std::max(max, val); + } + + T sum = 0; + for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++) + { + dst_data[offset + channel_idx * area + i] = + std::exp((float)(src_data[offset + channel_idx * area + i] - max)); + sum += dst_data[offset + channel_idx * area + i]; + } + + for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++) + { + dst_data[offset + channel_idx * area + i] /= sum; + } + } + } + } + + uint32_t shape_size(const std::vector& input_shape) + { + uint32_t ret = 1; + std::for_each(input_shape.begin(), input_shape.end(), [&ret](uint32_t n){ + ret *= n; + }); + + return ret; + } + + template + void region_yolo(const T* input, + T* output, + const std::vector& input_shape, + const uint32_t coords, + const uint32_t classes, + const uint32_t regions, + const bool do_softmax, + const std::vector& mask) + { + EXPECT_EQ(input_shape.size(), 4); + + const uint32_t batches = input_shape[0]; + //const uint32_t channels = input_shape[1]; + const uint32_t height = input_shape[2]; + const uint32_t width = input_shape[3]; + + const auto mask_size = mask.size(); + + std::copy(input, input + shape_size(input_shape), output); + + uint32_t num_regions = 0; + uint32_t end_index = 0; + + if (do_softmax) + { + // Region layer (Yolo v2) + num_regions = regions; + end_index = width * height; + } + else + { + // Yolo layer (Yolo v3) + num_regions = static_cast(mask_size); + end_index = width * height * (classes + 1); + } + + const uint32_t inputs_size = width * height * num_regions * (classes + coords + 1); + for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++) + { + for (unsigned int n = 0; n < num_regions; n++) + { + int index = entry_index(width, + height, + coords, + classes, + inputs_size, + batch_idx, + n * width * height, + 0); + std::transform(input + index, + input + index + 2 * width * height, + output + index, + [](T elem) { return sigmoid(elem); }); + + index = entry_index(width, + height, + coords, + classes, + inputs_size, + batch_idx, + n * width * height, + coords); + std::transform(input + index, + input + index + end_index, + output + index, + [](T elem) { return sigmoid(elem); }); + } + } + + if (do_softmax) + { + int index = + entry_index(width, height, coords, classes, inputs_size, 0, 0, coords + 1); + int batch_offset = inputs_size / regions; + for (unsigned int batch_idx = 0; batch_idx < batches * regions; batch_idx++) + { + softmax_generic(input + index + batch_idx * batch_offset, + output + index + batch_idx * batch_offset, + 1, + classes, + height, + width); + } + } + } + + struct region_yolo_test_params { + std::vector tensor; + std::vector mask; + uint32_t coords; + uint32_t classes; + uint32_t regionNum; + data_types dataType; + format fmt; + bool softMax; + }; +} + +template +static void runRegionTest(internal::region_yolo_test_params& params) +{ + engine eng; + const tensor kInputTensor(params.tensor[0], params.tensor[1], params.tensor[2], params.tensor[3]); + auto inputData = generate_random_1d(params.tensor[0] * params.tensor[1] * params.tensor[2] * params.tensor[3], -1, 1); + + auto inputPrim = memory::allocate(eng, { params.dataType, format::bfyx, kInputTensor }); + set_values(inputPrim, inputData); + + topology topology; + topology.add(input_layout("InputData", inputPrim.get_layout())); + topology.add(reorder("reorder_pre", "InputData", params.fmt, params.dataType)); + topology.add(region_yolo("region_yolo", "reorder_pre", params.coords, params.classes, + params.regionNum, static_cast(params.mask.size()), params.softMax)); + topology.add(reorder("reorder_post", "region_yolo", format::bfyx, params.dataType)); + + network network(eng, topology); + network.set_input_data("InputData", inputPrim); + + auto outputs = network.execute(); + auto output = outputs.at("reorder_post").get_memory(); + auto outputData = output.pointer(); + + /// reference value + std::vector refOutputData(inputData.size()); + internal::region_yolo(inputData.data(), refOutputData.data(), + params.tensor, params.coords, params.classes, + params.regionNum, params.softMax, params.mask); + + /// compare values + for (size_t i = 0; i < inputData.size(); ++i) { + EXPECT_NEAR(refOutputData[i], outputData[i], 0.01); + } +} + +TEST(region_yolo_gpu_fp32, bfyx) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, false}; + runRegionTest(params); +} + +TEST(region_yolo_gpu_fp32, bfyx_softmax) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, true}; + runRegionTest(params); +} + +TEST(region_yolo_gpu_fp32, byxf) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, false}; + runRegionTest(params); +} + +TEST(region_yolo_gpu_fp32, byxf_softmax) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, true}; + runRegionTest(params); +} + +TEST(region_yolo_gpu_fp16, bfyx) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, false}; + runRegionTest(params); +} + +TEST(region_yolo_gpu_fp16, bfyx_softmax) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, true}; + runRegionTest(params); +} + +TEST(region_yolo_gpu_fp16, byxf) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, false}; + runRegionTest(params); +} + +TEST(region_yolo_gpu_fp16, byxf_softmax) { + internal::region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, true}; + runRegionTest(params); +}