[IE CLDNN] Add a byxf format for RegionYolo op. (#4451)

This commit is contained in:
Sungeun Kim 2021-03-09 20:48:00 +09:00 committed by GitHub
parent 3757c079c1
commit e40a44202e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 370 additions and 93 deletions

View File

@ -46,21 +46,46 @@ JitConstants RegionYoloKernelRef::GetJitConstants(const region_yolo_params& ry)
return jit; return jit;
} }
bool RegionYoloKernelRef::Validate(const Params& p, const optional_params& o) const {
if (p.GetType() != KernelType:: REGION_YOLO || o.GetType() != KernelType::REGION_YOLO) {
return false;
}
const region_yolo_params& params = static_cast<const region_yolo_params&>(p);
const size_t expected_feature_size =
params.do_softmax ? params.inputs[0].X().v * params.inputs[0].Y().v * params.inputs[0].Feature().v : params.inputs[0].Feature().v;
if (expected_feature_size != params.output.Feature().v) {
return false;
}
return true;
}
RegionYoloKernelRef::DispatchData SetDefault(const region_yolo_params& params) { RegionYoloKernelRef::DispatchData SetDefault(const region_yolo_params& params) {
RegionYoloKernelRef::DispatchData dispatchData; RegionYoloKernelRef::DispatchData dispatchData;
const auto& input = params.inputs[0]; const auto& input = params.inputs[0];
if (input.GetLayout() == DataLayout::bfyx) {
dispatchData.gws = {input.X().v * input.Y().v, 1, 1}; switch (input.GetLayout()) {
} else { case DataLayout::bfyx:
dispatchData.gws = {input.Feature().v * input.Batch().v, input.X().v, input.Y().v}; case DataLayout::byxf: {
uint32_t region_num = params.do_softmax ? params.num : params.mask_size;
dispatchData.gws = {input.X().v * input.Y().v, region_num, input.Batch().v};
} break;
default:
throw std::invalid_argument("Unsupported DataLayout");
} }
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
return dispatchData; return dispatchData;
} }
KernelsData RegionYoloKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { KernelsData RegionYoloKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
assert(params.GetType() == KernelType::REGION_YOLO); if (!Validate(params, options)) {
return {};
}
const region_yolo_params& orgParams = static_cast<const region_yolo_params&>(params); const region_yolo_params& orgParams = static_cast<const region_yolo_params&>(params);
DispatchData dispatchData = SetDefault(orgParams); DispatchData dispatchData = SetDefault(orgParams);

View File

@ -61,5 +61,6 @@ public:
protected: protected:
virtual JitConstants GetJitConstants(const region_yolo_params& params) const; virtual JitConstants GetJitConstants(const region_yolo_params& params) const;
bool Validate(const Params& p, const optional_params& o) const override;
}; };
} // namespace kernel_selector } // namespace kernel_selector

View File

@ -12,94 +12,79 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "include/common.cl" #include "include/fetch.cl"
#include "include/data_types.cl"
#define IW INPUT0_SIZES[0] inline INPUT0_TYPE FUNC(logistic_activate)(INPUT0_TYPE x) {
#define IH INPUT0_SIZES[1]
#define IC INPUT0_SIZES[2]
#define IB INPUT0_SIZES[3]
inline UNIT_TYPE FUNC(logistic_activate)(UNIT_TYPE x) {
return 1. / (1. + exp(-x)); return 1. / (1. + exp(-x));
} }
inline int FUNC(entry_index)(int width, int height, int coords, int classes, inline int FUNC(output_index)(int batch, int region_num, int x, int y, int xy, int feature_offset) {
int outputs, int batch, int location,
int entry) {
int n = location / (width * height);
int loc = location % (width * height);
return batch * outputs + n * width * height * (coords + classes + 1) +
entry * width * height + loc;
}
#if DO_SOFTMAX #if DO_SOFTMAX
inline void FUNC(softmax_generic)(const __global UNIT_TYPE* src_data, __global UNIT_TYPE* dst_data, return OUTPUT_GET_INDEX(batch, feature_offset * INPUT0_SIZE_X * INPUT0_SIZE_Y + xy, 1, 1);
int B, int C, int W, int H, int i)
{
for (int b = 0; b < B; b++) {
UNIT_TYPE max = src_data[b*C*H*W + i];
for (int c = 0; c < C; c++) {
UNIT_TYPE val = src_data[b*C*H*W + c*H*W + i];
if (val > max) max = val;
}
UNIT_TYPE expSum = 0;
for (int c = 0; c < C; c++) {
dst_data[b*C*H*W + c*H*W + i] = exp(src_data[b*C*H*W + c*H*W + i] - max);
expSum += dst_data[b*C*H*W + c*H*W + i];
}
for (int c = 0; c < C; c++) {
dst_data[b*C*H*W + c*H*W + i] = dst_data[b*C*H*W + c*H*W + i] / expSum;
}
}
}
#endif
KERNEL (region_yolo_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output)
{
int x = get_global_id(0);
#if DO_SOFTMAX
#define ACTUAL_NUM (NUM)
#define CONF_CLASSES (1)
#else #else
#define ACTUAL_NUM (MASK_SIZE) return OUTPUT_GET_INDEX(batch, feature_offset, y, x);
#define CONF_CLASSES (CLASSES+1) #endif
#endif }
#define INPUTS_COUNT (IH * IW * ACTUAL_NUM * (CLASSES + COORDS + 1))
KERNEL (region_yolo_ref)(const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output)
for (int b = 0; b < IB; b++) { {
for (int n = 0; n < ACTUAL_NUM; n++) { int xy = get_global_id(0);
// coords: x/y int region_num = get_global_id(1);
int index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, 0); int batch = get_global_id(2);
int i = index + 2 * x; int x_index = xy % INPUT0_SIZE_X;
output[i] = FUNC_CALL(logistic_activate)(input[i]); int y_index = (xy / INPUT0_SIZE_X) % (INPUT0_SIZE_Y);
output[i+1] = FUNC_CALL(logistic_activate)(input[i+1]);
/// [x, y, width, height, objectness score, class score]
// coords: w/h: directly copy? /// x,y
index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, 2); int region_offset = region_num * (COORDS + CLASSES + 1);
i = index + 2 * x; int in_i = INPUT0_GET_INDEX(batch, 0 + region_offset, y_index, x_index);
output[i] = input[i]; int out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 0 + region_offset);
output[i+1] = input[i+1]; output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]);
// confidence in_i = INPUT0_GET_INDEX(batch, 1 + region_offset, y_index, x_index);
index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, b, n * IW * IH, COORDS); out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 1 + region_offset);
for (int j = 0; j < CONF_CLASSES; j++) output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]);
{
i = index + x + j*IH*IW; /// width,height
output[i] = FUNC_CALL(logistic_activate)(input[i]); in_i = INPUT0_GET_INDEX(batch, 2 + region_offset, y_index, x_index);
} out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 2 + region_offset);
} output[out_i] = input[in_i];
}
in_i = INPUT0_GET_INDEX(batch, 3 + region_offset, y_index, x_index);
#if DO_SOFTMAX out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, 3 + region_offset);
// the probability of classes output[out_i] = input[in_i];
int index = FUNC_CALL(entry_index)(IW, IH, COORDS, CLASSES, INPUTS_COUNT, 0, 0, COORDS + 1);
int batch_offset = INPUTS_COUNT / NUM; /// objectness score
for (int b = 0; b < IB * NUM; b++) in_i = INPUT0_GET_INDEX(batch, COORDS + region_offset, y_index, x_index);
FUNC_CALL(softmax_generic)(input + index + b * batch_offset, output + index + b * batch_offset, out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + region_offset);
1, CLASSES, IH, IW, x); output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]);
/// class score(confidence)
#if DO_SOFTMAX
in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + region_offset, y_index, x_index);
INPUT0_TYPE max_value = input[in_i];
for (int j = 1; j < CLASSES; j++) {
in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index);
max_value = max(max_value, input[in_i]);
}
OUTPUT_TYPE expSum = 0;
for (int j = 0; j < CLASSES; j++) {
in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index);
out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + 1 + j + region_offset);
output[out_i] = exp(input[in_i] - max_value);
expSum += output[out_i];
}
for (int j = 0; j < CLASSES; j++) {
out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + 1 + j + region_offset);
output[out_i] /= expSum;
}
#else
// Mask (non-softmax) path: every class score gets the logistic activation.
for (int j = 0; j < CLASSES; j++)
{
    in_i = INPUT0_GET_INDEX(batch, COORDS + 1 + j + region_offset, y_index, x_index);
    // Store through the output index like every other write in this kernel;
    // reusing the input index is only right when both index mappings coincide.
    out_i = FUNC_CALL(output_index)(batch, region_num, x_index, y_index, xy, COORDS + 1 + j + region_offset);
    output[out_i] = FUNC_CALL(logistic_activate)(input[in_i]);
}
#endif #endif
} }

View File

@ -866,10 +866,6 @@ format layout_optimizer::get_preferred_format(program_node& node) {
if (input_layout.format.dimension() == 5 && if (input_layout.format.dimension() == 5 &&
(input_layout.data_type == data_types::f32 || input_layout.data_type == data_types::f16)) (input_layout.data_type == data_types::f32 || input_layout.data_type == data_types::f16))
expected = format::bfzyx; expected = format::bfzyx;
} else if (node.is_type<region_yolo>()) {
if (_optimization_attributes.b_fs_yx_fsv16_network) {
expected = format::bfyx;
}
} }
return expected; return expected;

View File

@ -3536,7 +3536,7 @@ public:
EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size)); EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size));
if (sizeof(input0_type) == 1) { if (sizeof(input0_type) == 1) {
for (size_t i = 0; i < out_data.size(); ++i) { for (size_t i = 0; i < out_data.size(); ++i) {
EXPECT_FLOAT_EQ(float(output_ptr[i]), float(out_data[i])) << "index = " << i; EXPECT_NEAR(float(output_ptr[i]), float(out_data[i]), 1e-1) << "index = " << i;
} }
} else if (sizeof(input0_type) == 2) { } else if (sizeof(input0_type) == 2) {
for (size_t i = 0; i < out_data.size(); ++i) { for (size_t i = 0; i < out_data.size(); ++i) {

View File

@ -0,0 +1,270 @@
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>
#include <api/input_layout.hpp>
#include <api/memory.hpp>
#include <api/region_yolo.hpp>
#include <api/reorder.hpp>
#include <api/topology.hpp>
#include <api/network.hpp>
#include <cstddef>
#include <tests/test_utils/test_utils.h>
using namespace cldnn;
using namespace ::tests;
namespace internal
{
/// Flat offset of one entry plane inside a region of a planar buffer.
///
/// `location` is split into a region number (location / (width * height)) and
/// a cell inside the plane (location % (width * height)). Each region spans
/// (coords + classes + 1) planes and each batch spans `outputs` elements.
///
/// @return batch * outputs + region * regionStride + entry * plane + cell
static inline int entry_index(int width,
                              int height,
                              int coords,
                              int classes,
                              int outputs,
                              int batch,
                              int location,
                              int entry)
{
    const int plane = width * height;
    const int region = location / plane;
    const int cell = location % plane;
    const int region_stride = plane * (coords + classes + 1);

    const int batch_base = batch * outputs;
    const int region_base = region * region_stride;
    const int entry_base = entry * plane;
    return batch_base + region_base + entry_base + cell;
}
/// Logistic function 1 / (1 + e^-x), evaluated in single precision and then
/// converted to the requested result type T.
template <typename T>
static inline T sigmoid(float x)
{
    const float denominator = 1.f + std::exp(-x);
    const float activated = 1.f / denominator;
    return static_cast<T>(activated);
}
/// Channel-wise softmax over a planar buffer.
///
/// The buffer is addressed as [batch][channel][height * width]. For every
/// batch and every spatial position the values across `channels` are shifted
/// by their maximum, exponentiated, and divided by their sum.
///
/// @param src_data  Source values.
/// @param dst_data  Destination; receives the normalised values.
/// @param batches   Number of batch slices to process.
/// @param channels  Number of channels the softmax runs across.
/// @param height    Plane height.
/// @param width     Plane width.
template <typename T>
static inline void softmax_generic(const T* src_data, T* dst_data,
                                   uint32_t batches, uint32_t channels, uint32_t height, uint32_t width)
{
    const uint32_t area = height * width;
    for (uint32_t b = 0; b < batches; ++b)
    {
        const int base = b * channels * area;
        const T* src_batch = src_data + base;
        T* dst_batch = dst_data + base;

        for (uint32_t pos = 0; pos < area; ++pos)
        {
            // Largest value across channels; subtracted before exponentiating.
            T peak = src_batch[pos];
            for (uint32_t c = 0; c < channels; ++c)
            {
                peak = std::max(peak, src_batch[c * area + pos]);
            }

            // Exponentiate into the destination while accumulating the total.
            T total = 0;
            for (uint32_t c = 0; c < channels; ++c)
            {
                T& slot = dst_batch[c * area + pos];
                slot = std::exp(static_cast<float>(src_batch[c * area + pos] - peak));
                total += slot;
            }

            // Normalise so the channel values at this position sum to one.
            for (uint32_t c = 0; c < channels; ++c)
            {
                dst_batch[c * area + pos] /= total;
            }
        }
    }
}
/// Product of all extents in `input_shape`; an empty shape yields 1.
uint32_t shape_size(const std::vector<uint32_t>& input_shape)
{
    uint32_t total = 1;
    for (const uint32_t extent : input_shape)
    {
        total *= extent;
    }
    return total;
}
/// Host-side reference implementation used to check the GPU RegionYolo output.
///
/// The input is copied to the output first, so every element that is not
/// explicitly activated below passes through unchanged. For each batch and
/// region:
///   - the first two entry planes are run through sigmoid;
///   - starting at entry `coords`, one plane (do_softmax) or (classes + 1)
///     planes (otherwise) are run through sigmoid.
/// With do_softmax set, a softmax across `classes` planes is then applied for
/// every batch/region slice, starting at entry `coords + 1`.
///
/// @param input        Source buffer with shape_size(input_shape) elements.
/// @param output       Destination buffer of the same size.
/// @param input_shape  Four dimensions; indices 0, 2 and 3 are read as batches, height and width.
/// @param coords       Number of coordinate entries per region.
/// @param classes      Number of class entries per region.
/// @param regions      Region count used when do_softmax is true.
/// @param do_softmax   Selects the softmax path; otherwise mask.size() is the region count.
/// @param mask         Only its size is used.
template <typename T>
void region_yolo(const T* input,
                 T* output,
                 const std::vector<uint32_t>& input_shape,
                 const uint32_t coords,
                 const uint32_t classes,
                 const uint32_t regions,
                 const bool do_softmax,
                 const std::vector<int64_t>& mask)
{
    EXPECT_EQ(input_shape.size(), 4);

    const uint32_t batches = input_shape[0];
    const uint32_t height = input_shape[2];
    const uint32_t width = input_shape[3];
    const uint32_t plane = width * height;

    std::copy(input, input + shape_size(input_shape), output);

    // Region layer (Yolo v2) when do_softmax is set, Yolo layer (Yolo v3) otherwise.
    const uint32_t num_regions = do_softmax ? regions : static_cast<uint32_t>(mask.size());
    const uint32_t end_index = do_softmax ? plane : plane * (classes + 1);
    const uint32_t inputs_size = plane * num_regions * (classes + coords + 1);

    const auto activate = [](T elem) { return sigmoid<T>(elem); };

    for (uint32_t batch_idx = 0; batch_idx < batches; ++batch_idx)
    {
        for (uint32_t n = 0; n < num_regions; ++n)
        {
            // First two entry planes of the region.
            const int xy_begin =
                entry_index(width, height, coords, classes, inputs_size, batch_idx, n * plane, 0);
            std::transform(input + xy_begin, input + xy_begin + 2 * plane, output + xy_begin, activate);

            // Planes starting at entry `coords`; how many is set by end_index.
            const int conf_begin =
                entry_index(width, height, coords, classes, inputs_size, batch_idx, n * plane, coords);
            std::transform(input + conf_begin, input + conf_begin + end_index, output + conf_begin, activate);
        }
    }

    if (!do_softmax)
    {
        return;
    }

    const int first_class = entry_index(width, height, coords, classes, inputs_size, 0, 0, coords + 1);
    const int batch_offset = inputs_size / regions;
    const uint32_t slice_count = batches * regions;
    for (uint32_t slice = 0; slice < slice_count; ++slice)
    {
        const T* slice_src = input + first_class + slice * batch_offset;
        T* slice_dst = output + first_class + slice * batch_offset;
        softmax_generic<T>(slice_src, slice_dst, 1, classes, height, width);
    }
}
/// One RegionYolo test configuration. The tests brace-initialise this
/// aggregate positionally, so the field order below is part of its interface.
struct region_yolo_test_params {
/// Four input dimensions. runRegionTest passes elements 0..3 positionally to
/// the cldnn tensor constructor and hands the whole vector to the host
/// reference as its input_shape.
std::vector<uint32_t> tensor;
/// Mask values; runRegionTest passes only mask.size() to the primitive and
/// the vector itself to the host reference.
std::vector<int64_t> mask;
/// Forwarded as the `coords` argument of the primitive and the reference.
uint32_t coords;
/// Forwarded as the `classes` argument of the primitive and the reference.
uint32_t classes;
/// Forwarded as the region count of the primitive and the reference.
uint32_t regionNum;
/// Element type used to allocate the input memory and for both reorders.
data_types dataType;
/// Format the input is reordered to before it reaches region_yolo.
format fmt;
/// Forwarded as the do_softmax flag of the primitive and the reference.
bool softMax;
};
}
/// Runs RegionYolo on the GPU for one configuration and compares the result
/// with the host reference internal::region_yolo.
///
/// Pipeline: random bfyx input -> reorder to params.fmt -> region_yolo ->
/// reorder back to bfyx. Each output element must be within 0.01 of the
/// reference value.
///
/// @tparam T      Host element type matching params.dataType.
/// @param params  Test configuration; params.tensor must hold four dimensions.
template <typename T>
static void runRegionTest(internal::region_yolo_test_params& params)
{
    // The four tensor entries are indexed directly below; stop the test early
    // instead of reading past the end of the vector.
    ASSERT_EQ(params.tensor.size(), 4u);

    engine eng;
    const tensor kInputTensor(params.tensor[0], params.tensor[1], params.tensor[2], params.tensor[3]);
    auto inputData = generate_random_1d<T>(params.tensor[0] * params.tensor[1] * params.tensor[2] * params.tensor[3], -1, 1);

    auto inputPrim = memory::allocate(eng, { params.dataType, format::bfyx, kInputTensor });
    set_values(inputPrim, inputData);

    topology topo;
    topo.add(input_layout("InputData", inputPrim.get_layout()));
    topo.add(reorder("reorder_pre", "InputData", params.fmt, params.dataType));
    topo.add(region_yolo("region_yolo", "reorder_pre", params.coords, params.classes,
                         params.regionNum, static_cast<uint32_t>(params.mask.size()), params.softMax));
    topo.add(reorder("reorder_post", "region_yolo", format::bfyx, params.dataType));

    network net(eng, topo);
    net.set_input_data("InputData", inputPrim);

    auto outputs = net.execute();
    auto output = outputs.at("reorder_post").get_memory();
    auto outputData = output.pointer<T>();

    // The comparison loop indexes the output by input position, so a shorter
    // output would be read out of bounds; make that a fatal failure instead.
    ASSERT_EQ(outputData.size(), inputData.size());

    /// reference value
    std::vector<T> refOutputData(inputData.size());
    internal::region_yolo<T>(inputData.data(), refOutputData.data(),
                             params.tensor, params.coords, params.classes,
                             params.regionNum, params.softMax, params.mask);

    /// compare values
    for (size_t i = 0; i < inputData.size(); ++i) {
        // Report the element index so a mismatch can be located.
        EXPECT_NEAR(refOutputData[i], outputData[i], 0.01) << "index = " << i;
    }
}
// All cases share one 1x33x52x52 input with 4 coords, 6 classes, 3 regions and
// a three-element mask; they vary element type, layout and the softmax flag.

// fp32, bfyx layout, softmax disabled.
TEST(region_yolo_gpu_fp32, bfyx) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f32,
        format::bfyx,
        false               // softMax
    };
    runRegionTest<float>(config);
}

// fp32, bfyx layout, softmax enabled.
TEST(region_yolo_gpu_fp32, bfyx_softmax) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f32,
        format::bfyx,
        true                // softMax
    };
    runRegionTest<float>(config);
}

// fp32, byxf layout, softmax disabled.
TEST(region_yolo_gpu_fp32, byxf) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f32,
        format::byxf,
        false               // softMax
    };
    runRegionTest<float>(config);
}

// fp32, byxf layout, softmax enabled.
TEST(region_yolo_gpu_fp32, byxf_softmax) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f32,
        format::byxf,
        true                // softMax
    };
    runRegionTest<float>(config);
}

// fp16, bfyx layout, softmax disabled.
TEST(region_yolo_gpu_fp16, bfyx) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f16,
        format::bfyx,
        false               // softMax
    };
    runRegionTest<FLOAT16>(config);
}

// fp16, bfyx layout, softmax enabled.
TEST(region_yolo_gpu_fp16, bfyx_softmax) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f16,
        format::bfyx,
        true                // softMax
    };
    runRegionTest<FLOAT16>(config);
}

// fp16, byxf layout, softmax disabled.
TEST(region_yolo_gpu_fp16, byxf) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f16,
        format::byxf,
        false               // softMax
    };
    runRegionTest<FLOAT16>(config);
}

// fp16, byxf layout, softmax enabled.
TEST(region_yolo_gpu_fp16, byxf_softmax) {
    internal::region_yolo_test_params config{
        { 1, 33, 52, 52 },  // tensor
        { 0, 1, 2 },        // mask
        4,                  // coords
        6,                  // classes
        3,                  // regionNum
        data_types::f16,
        format::byxf,
        true                // softMax
    };
    runRegionTest<FLOAT16>(config);
}