[IE CLDNN] Refreshment of convolution_gpu_fs_byx_fsv32 kernel (#2536)
parent 3f55733b43
commit dba94b1f19
@@ -37,6 +37,9 @@ ParamsKey ConvolutionKernel_fs_byx_fsv32::GetSupportedKey() const {
     ParamsKey k;
     k.EnableInputDataType(Datatype::F16);
     k.EnableOutputDataType(Datatype::F16);
+    k.EnableOutputDataType(Datatype::INT8);
+    k.EnableOutputDataType(Datatype::UINT8);
+    k.EnableOutputDataType(Datatype::F32);
     k.EnableInputWeightsType(WeightsType::F16);
     k.EnableInputLayout(DataLayout::fs_b_yx_fsv32);
     k.EnableOutputLayout(DataLayout::fs_b_yx_fsv32);
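
These three new EnableOutputDataType calls are the visible surface of the refresh: input and weights stay FP16-only, but the kernel can now store INT8, UINT8 or F32 results. The ACCUMULATOR/ACTIVATION type split introduced below is what makes that mixed-precision output path work.
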
@@ -149,7 +152,11 @@ bool ConvolutionKernel_fs_byx_fsv32::Validate(const Params& p, const optional_pa
 JitConstants ConvolutionKernel_fs_byx_fsv32::GetJitConstants(const convolution_params& params,
                                                              const DispatchData& kd) const {
     auto jit = ConvolutionKernelBase::GetJitConstants(params, kd);
+    auto accumulator_type = GetAccumulatorType(params);
+    auto activation_type = GetActivationType(params);
 
+    jit.Merge(MakeTypeJitConstants(accumulator_type, "ACCUMULATOR"));
+    jit.Merge(MakeTypeJitConstants(activation_type, "ACTIVATION"));
     jit.AddConstant(MakeJitConstant("INPUT_BLOCK_WIDTH", kd.cldnnStyle.inputBlockWidth));
     jit.AddConstant(MakeJitConstant("OUTPUT_BLOCK_WIDTH", kd.cldnnStyle.blockWidth));
     jit.AddConstant(MakeJitConstant("FSV", fsv));
@@ -157,13 +164,12 @@ JitConstants ConvolutionKernel_fs_byx_fsv32::GetJitConstants(const convolution_p
     jit.AddConstant(MakeJitConstant("FSV_PER_THREAD", fsvPerThread));
 
     if (!params.fused_ops.empty()) {
-        auto input_dt = GetUnitType(params);
         FusedOpsConfiguration conf_vec_elem = {"_VEC_ELEM",
                                                {"b", "(fs * FSV + sglid + out_f * SUB_GROUP_SIZE)", "or", "oc + out_x"},
-                                               "tmp_write[out_f]", input_dt, 1 };
+                                               "tmp_write[out_f]", activation_type, 1 };
         FusedOpsConfiguration conf_scalar = {"_SCALAR",
                                              {"b", "(fs * FSV + sglid + out_f * SUB_GROUP_SIZE)", "or", "oc + out_x"},
-                                             "out[out_idx]", input_dt, 1 };
+                                             "res", activation_type, 1 };
         jit.Merge(MakeFusedOpsJitConstants(params, {conf_vec_elem, conf_scalar}));
     }
 
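
Note: MakeTypeJitConstants is what turns these Datatype values into the ACCUMULATOR_* and ACTIVATION_* macros the refreshed kernel relies on. A minimal sketch of the assumed mapping (the real jitter in kernel_selector emits more entries, e.g. *_VAL_MAX and *_VAL_MIN; names here follow its <name>_TYPE / <name>_VAL_ZERO / TO_<name>_TYPE convention):

#include <map>
#include <string>

// Sketch only, not the kernel_selector implementation: how one Datatype is
// expected to map onto the macro family consumed by the OpenCL kernel below.
std::map<std::string, std::string> make_type_jit_constants_sketch(const std::string& name,
                                                                  const std::string& cl_type) {
    return {
        {name + "_TYPE", cl_type},                                  // e.g. ACCUMULATOR_TYPE -> half
        {name + "_VAL_ZERO", "(" + cl_type + ")0"},                 // e.g. ACCUMULATOR_VAL_ZERO
        {"TO_" + name + "_TYPE(v)", "convert_" + cl_type + "(v)"},  // e.g. TO_ACCUMULATOR_TYPE(v)
    };
}
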
@@ -1,4 +1,4 @@
-// Copyright (c) 2019 Intel Corporation
+// Copyright (c) 2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-// Copyright (c) 2019 Intel Corporation
+// Copyright (c) 2019-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
 
 #include "include/common.cl"
 #include "include/data_types.cl"
-#include "include/unit_type.cl"
 #include "include/include_all.cl"
 
 #define unroll_for __attribute__((opencl_unroll_hint)) for
@@ -33,6 +32,14 @@
 
 #define ALIGNED_IFM_NUM (((FILTER_IFM_NUM + FSV - 1) / FSV) * FSV)
 
+#define INPUT_TYPE2 MAKE_VECTOR_TYPE(INPUT0_TYPE, 2)
+#define INPUT_TYPE4 MAKE_VECTOR_TYPE(INPUT0_TYPE, 4)
+#define BIAS_TYPE2 MAKE_VECTOR_TYPE(BIAS_TYPE, 2)
+#define FILTER_TYPE2 MAKE_VECTOR_TYPE(FILTER_TYPE, 2)
+#define ACTIVATION_TYPE2 MAKE_VECTOR_TYPE(ACTIVATION_TYPE, 2)
+#define OUTPUT_TYPE2 MAKE_VECTOR_TYPE(OUTPUT_TYPE, 2)
+#define TO_OUTPUT_TYPE2 CAT(convert_, OUTPUT_TYPE2)
+
 // ======================================================================================
 // Required JIT definitions:
 // --------------------------------------------------------------------------------------
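
These helpers are plain token pasting: MAKE_VECTOR_TYPE(T, 2) is expected to expand to the OpenCL vector type T2, and TO_OUTPUT_TYPE2 to the matching convert_* builtin. A sketch of the assumed definitions (CAT and MAKE_VECTOR_TYPE actually come from the common includes; they are written out here only to make the expansion concrete):

// Illustrative definitions, assuming the usual token-pasting helpers:
#define CAT_IMPL(a, b) a##b
#define CAT(a, b) CAT_IMPL(a, b)
#define MAKE_VECTOR_TYPE(type, size) CAT(type, size)
// With OUTPUT_TYPE defined as half:
//   OUTPUT_TYPE2    -> half2
//   TO_OUTPUT_TYPE2 -> convert_half2
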
@@ -48,11 +55,11 @@
 __attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
 __attribute__((reqd_work_group_size(1, 1, SUB_GROUP_SIZE)))
 KERNEL(convolution_gpu_fs_byx_fsv32)(
-    __global UNIT_TYPE* input,
-    __global UNIT_TYPE* output,
-    __global UNIT_TYPE* weights,
+    __global INPUT0_TYPE* input,
+    __global OUTPUT_TYPE* output,
+    __global FILTER_TYPE* weights,
 #if BIAS_TERM
-    __global UNIT_TYPE* biases,
+    __global BIAS_TYPE* biases,
 #endif
 #if HAS_FUSED_OPS_DECLS
     FUSED_OPS_DECLS,
@@ -67,14 +74,9 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
     uint fs = fs_b_id / INPUT0_BATCH_NUM;
     uint b = fs_b_id - fs * INPUT0_BATCH_NUM;
 
-    UNIT_TYPE in[INPUT_BLOCK_WIDTH * FSV_PER_THREAD];
-    UNIT_TYPE w[FSV_PER_THREAD];
-    UNIT_TYPE out[OUTPUT_BLOCK_WIDTH * FSV_PER_THREAD];
-
-    for (uint out_i = 0; out_i < OUTPUT_BLOCK_WIDTH * FSV_PER_THREAD; ++out_i)
-    {
-        out[out_i] = UNIT_VAL_ZERO;
-    }
+    INPUT0_TYPE in[INPUT_BLOCK_WIDTH * FSV_PER_THREAD];
+    FILTER_TYPE w[FSV_PER_THREAD];
+    ACCUMULATOR_TYPE out[OUTPUT_BLOCK_WIDTH * FSV_PER_THREAD] = { ACCUMULATOR_VAL_ZERO };
 
     // Calculate offset to first input data element
     const uint in_pitch_x = FSV;
@@ -102,7 +104,7 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
         uint in_x = 0;
         unroll_for (; in_x + 2 <= INPUT_BLOCK_WIDTH; in_x += 2)
         {
-            UNIT_TYPE4 tmp_read = UNIT_BLOCK_READ4(input, tmp_input_offset + in_x * FSV);
+            INPUT_TYPE4 tmp_read = DT_INPUT_BLOCK_READ4(input, tmp_input_offset + in_x * FSV);
             in[in_x * FSV_PER_THREAD + 0] = tmp_read.s0;
             in[in_x * FSV_PER_THREAD + 1] = tmp_read.s1;
             in[in_x * FSV_PER_THREAD + 2] = tmp_read.s2;
@@ -110,7 +112,7 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
         }
         unroll_for (; in_x < INPUT_BLOCK_WIDTH; ++in_x)
         {
-            UNIT_TYPE2 tmp_read = UNIT_BLOCK_READ2(input, tmp_input_offset + in_x * FSV);
+            INPUT_TYPE2 tmp_read = DT_INPUT_BLOCK_READ2(input, tmp_input_offset + in_x * FSV);
             in[in_x * FSV_PER_THREAD + 0] = tmp_read.s0;
             in[in_x * FSV_PER_THREAD + 1] = tmp_read.s1;
         }
@@ -118,7 +120,7 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
 
         // Move temporary input offset to next row
         tmp_input_offset += DILATION_SIZE_Y * in_pitch_y;
-
+
         uint tmp_weight_offset = weight_offset;
 
         // ====================================================================
@@ -129,7 +131,7 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
         unroll_for (uint f_x = 0; f_x < FILTER_SIZE_X; ++f_x)
         {
             // Load weights
-            UNIT_TYPE2 tmp_read = UNIT_BLOCK_READ2(weights, tmp_weight_offset + f_x * FSV);
+            FILTER_TYPE2 tmp_read = DT_FILTER_BLOCK_READ2(weights, tmp_weight_offset + f_x * FSV);
             w[0] = tmp_read.s0;
             w[1] = tmp_read.s1;
 
@@ -137,12 +139,12 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
             {
                 unroll_for (uint out_f = 0; out_f < FSV_PER_THREAD; ++out_f)
                 {
-                    UNIT_TYPE in_val = intel_sub_group_shuffle(
+                    INPUT0_TYPE in_val = intel_sub_group_shuffle(
                         in[(out_x * STRIDE_SIZE_X + f_x * DILATION_SIZE_X) * FSV_PER_THREAD + ifii / SUB_GROUP_SIZE],
                         ifii % SUB_GROUP_SIZE);
 
                     const uint out_idx = out_x * FSV_PER_THREAD + out_f;
-                    out[out_idx] = mad(in_val, w[out_f], out[out_idx]);
+                    out[out_idx] = mad(TO_ACCUMULATOR_TYPE(in_val), TO_ACCUMULATOR_TYPE(w[out_f]), out[out_idx]);
                 }
             }
 
@@ -173,30 +175,18 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
                 const uint bias_index = (fs * FSV + out_f * SUB_GROUP_SIZE + sglid) * OUTPUT_SIZE_X * OUTPUT_SIZE_Y +
                                         or * OUTPUT_SIZE_X +
                                         (oc + out_x);
-                out[out_x * FSV_PER_THREAD + out_f] += biases[bias_index];
+                out[out_x * FSV_PER_THREAD + out_f] += TO_ACCUMULATOR_TYPE(biases[bias_index]);
             }
 #   else // BIAS_PER_OUTPUT
             const uint bias_index = fs * FSV;
-            UNIT_TYPE2 bias_read = UNIT_BLOCK_READ2(biases, bias_index);
-            out[out_x * FSV_PER_THREAD + 0] += bias_read.s0;
-            out[out_x * FSV_PER_THREAD + 1] += bias_read.s1;
+            BIAS_TYPE2 bias_read = DT_BIAS_BLOCK_READ2(biases, bias_index);
+            out[out_x * FSV_PER_THREAD + 0] += TO_ACCUMULATOR_TYPE(bias_read.s0);
+            out[out_x * FSV_PER_THREAD + 1] += TO_ACCUMULATOR_TYPE(bias_read.s1);
 #   endif // BIAS_PER_OUTPUT
         }
 #endif // BIAS_TERM
         // ========================================================================
 
-        // ========================================================================
-        // Activation
-        unroll_for (uint out_x = 0; out_x < OUTPUT_BLOCK_WIDTH; ++out_x)
-        {
-            unroll_for (uint out_f = 0; out_f < FSV_PER_THREAD; ++out_f)
-            {
-                const uint out_idx = out_x * FSV_PER_THREAD + out_f;
-                out[out_idx] = ACTIVATION(out[out_idx], ACTIVATION_PARAMS);
-            }
-        }
-        // ========================================================================
-
         // ========================================================================
         // Store results:
         // Calculate offset to first output element
@@ -216,20 +206,25 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
     const bool full_f = OUTPUT_FEATURE_NUM % FSV == 0 || fs * FSV + FSV <= OUTPUT_FEATURE_NUM;
     const bool full_x = OUTPUT_SIZE_X % OUTPUT_BLOCK_WIDTH == 0 || oc + OUTPUT_BLOCK_WIDTH <= OUTPUT_SIZE_X;
 
+    ACTIVATION_TYPE res[OUTPUT_BLOCK_WIDTH * FSV_PER_THREAD] = { ACTIVATION_VAL_ZERO };
+
     if (full_f && full_x)
     {
         // Case without bounds checking
         unroll_for (uint out_x = 0; out_x < OUTPUT_BLOCK_WIDTH; ++out_x)
        {
-            UNIT_TYPE2 tmp_write = (UNIT_TYPE2)(out[out_x * FSV_PER_THREAD + 0],
-                                                out[out_x * FSV_PER_THREAD + 1]);
+            ACTIVATION_TYPE2 tmp_write = (ACTIVATION_TYPE2)(out[out_x * FSV_PER_THREAD + 0],
+                                                            out[out_x * FSV_PER_THREAD + 1]);
+            OUTPUT_TYPE2 final_result;
 #if HAS_FUSED_OPS
             unroll_for (uint out_f = 0; out_f < 2; ++out_f)
             {
-                { FUSED_OPS_VEC_ELEM; tmp_write[out_f] = FUSED_OPS_RESULT_VEC_ELEM; }
+                { FUSED_OPS_VEC_ELEM; final_result[out_f] = FUSED_OPS_RESULT_VEC_ELEM; }
             }
+#else
+            final_result = TO_OUTPUT_TYPE2(ACTIVATION(tmp_write, ACTIVATION_PARAMS));
 #endif
-            UNIT_BLOCK_WRITE2(output, output_offset, tmp_write);
+            DT_OUTPUT_BLOCK_WRITE2(output, output_offset, final_result);
             output_offset += FSV;
         }
     }
@@ -242,10 +237,14 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
                 if (oc + out_x < OUTPUT_SIZE_X && fs * FSV + sglid + out_f * SUB_GROUP_SIZE < OUTPUT_FEATURE_NUM)
                 {
                     const uint out_idx = out_x * FSV_PER_THREAD + out_f;
+                    ACTIVATION_TYPE res = TO_ACTIVATION_TYPE(out[out_idx]);
+                    OUTPUT_TYPE final_result;
 #if HAS_FUSED_OPS
-                    { FUSED_OPS_SCALAR; out[out_idx] = FUSED_OPS_RESULT_SCALAR; }
+                    { FUSED_OPS_SCALAR; final_result = FUSED_OPS_RESULT_SCALAR; }
+#else
+                    final_result = TO_OUTPUT_TYPE(ACTIVATION(res, ACTIVATION_PARAMS));
 #endif
-                    output[output_offset + sglid] = out[out_idx];
+                    output[output_offset + sglid] = final_result;
                 }
                 output_offset += SUB_GROUP_SIZE;
             }
@@ -263,3 +262,11 @@ KERNEL(convolution_gpu_fs_byx_fsv32)(
 #undef OUTPUT_SIZE_X_WITH_PADDING
 #undef OUTPUT_SIZE_Y_WITH_PADDING
 #undef OUTPUT_SIZE_B_WITH_PADDING
+
+#undef INPUT_TYPE2
+#undef INPUT_TYPE4
+#undef BIAS_TYPE2
+#undef FILTER_TYPE2
+#undef ACTIVATION_TYPE2
+#undef OUTPUT_TYPE2
+#undef TO_OUTPUT_TYPE2
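
Taken together, the refresh replaces the old single UNIT_TYPE pipeline with three explicit type domains: values are accumulated in ACCUMULATOR_TYPE via mad(), converted to ACTIVATION_TYPE for the activation/fused-ops stage, and only narrowed to OUTPUT_TYPE at the store. The per-element pattern, written out in the kernel's own macros (a sketch of the flow above, not additional kernel code):

// Sketch of the per-element dataflow after the refresh:
ACCUMULATOR_TYPE acc = ACCUMULATOR_VAL_ZERO;
// accumulation loop: acc = mad(TO_ACCUMULATOR_TYPE(in_val), TO_ACCUMULATOR_TYPE(w), acc);
ACTIVATION_TYPE res = TO_ACTIVATION_TYPE(acc);
OUTPUT_TYPE final_result = TO_OUTPUT_TYPE(ACTIVATION(res, ACTIVATION_PARAMS));
// with fused ops enabled, FUSED_OPS_* produce final_result instead of ACTIVATION()
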
@@ -4898,6 +4898,21 @@ using TestParamType_grouped_convolution_gpu = ::testing::tuple< int, // 0 -
                                                         format,       // 11 - Input data format
                                                         std::string>; // 12 - Implementation name
 
+using TestParamType_general_convolution_gpu = ::testing::tuple< int,         // 0 - Input X size
+                                                                int,         // 1 - Input Y size
+                                                                int,         // 2 - Input Z size
+                                                                int,         // 3 - Input features
+                                                                int,         // 4 - Output features
+                                                                int,         // 5 - Kernel sizeX
+                                                                int,         // 6 - Kernel sizeY
+                                                                int,         // 7 - Kernel sizeZ
+                                                                int,         // 8 - Groups number
+                                                                int,         // 9 - Stride
+                                                                int,         // 10 - Batch
+                                                                format,      // 11 - Input data format
+                                                                std::string, // 12 - Implementation name
+                                                                bool>;       // 13 - With bias
+
 struct convolution_gpu : public ::testing::TestWithParam<TestParamType_convolution_gpu>
 {
     static std::string
@@ -4991,6 +5006,32 @@ struct convolution_grouped_gpu : public ::testing::TestWithParam<TestParamType_g
     }
 };
 
+struct convolution_general_gpu : public ::testing::TestWithParam<TestParamType_general_convolution_gpu> {
+    static std::string PrintToStringParamName(
+        testing::TestParamInfo<TestParamType_general_convolution_gpu> param_info) {
+        // construct a readable name
+        std::string res = "in" + std::to_string(testing::get<0>(param_info.param)) + "x" +
+                          std::to_string(testing::get<1>(param_info.param)) + "y" +
+                          std::to_string(testing::get<2>(param_info.param)) + "z" +
+                          std::to_string(testing::get<3>(param_info.param)) + "f" + "_output" +
+                          std::to_string(testing::get<4>(param_info.param)) + "f" + "_filter" +
+                          std::to_string(testing::get<5>(param_info.param)) + "x" +
+                          std::to_string(testing::get<6>(param_info.param)) + "y" +
+                          std::to_string(testing::get<7>(param_info.param)) + "z" + "_groups" +
+                          std::to_string(testing::get<8>(param_info.param)) + "_stride" +
+                          std::to_string(testing::get<9>(param_info.param)) + "_batch" +
+                          std::to_string(testing::get<10>(param_info.param)) + "_format" +
+                          std::to_string(testing::get<11>(param_info.param)) + "_with_bias_" +
+                          std::to_string(testing::get<13>(param_info.param));
+
+        if (testing::get<12>(param_info.param) != "") {
+            res += "_impl_" + testing::get<12>(param_info.param);
+        }
+
+        return res;
+    }
+};
+
 INSTANTIATE_TEST_CASE_P(convolution_gpu_test,
                         convolution_gpu_fs_byx_fsv32,
                         ::testing::Values(
@@ -7353,6 +7394,173 @@ TEST_P(convolution_grouped_gpu, base) {
     }
 }
 
+INSTANTIATE_TEST_CASE_P(conv_fp16_cases,
+                        convolution_general_gpu,
+                        ::testing::Values(
+                            // Input X size, Input Y size, Input Z size, Input features, Output features,
+                            // Kernel size X, Kernel size Y, Kernel size Z, Groups number, Stride, Batch,
+                            // Input data format, Implementation name, WithBias
+                            TestParamType_general_convolution_gpu(8, 8, 1, 8, 16, 3, 3, 1, 1, 1, 16, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", true),
+                            TestParamType_general_convolution_gpu(12, 12, 1, 4, 16, 3, 3, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", false),
+                            TestParamType_general_convolution_gpu(11, 11, 1, 96, 48, 3, 3, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", true),
+                            TestParamType_general_convolution_gpu(12, 12, 1, 32, 48, 3, 3, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", false),
+                            TestParamType_general_convolution_gpu(7, 7, 1, 16, 16, 3, 3, 1, 1, 1, 16, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", true),
+                            TestParamType_general_convolution_gpu(7, 8, 1, 20, 64, 4, 4, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", false),
+                            TestParamType_general_convolution_gpu(5, 5, 1, 80, 64, 3, 3, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", false),
+                            TestParamType_general_convolution_gpu(7, 7, 1, 32, 64, 4, 4, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", true),
+                            TestParamType_general_convolution_gpu(5, 5, 1, 32, 64, 3, 3, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", true),
+                            TestParamType_general_convolution_gpu(12, 10, 1, 32, 64, 5, 5, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", false),
+                            TestParamType_general_convolution_gpu(5, 5, 1, 32, 64, 3, 3, 1, 1, 1, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", false),
+                            TestParamType_general_convolution_gpu(5, 5, 1, 64, 64, 3, 3, 1, 1, 2, 2, format::fs_b_yx_fsv32, "convolution_gpu_fs_byx_fsv32", true)
+                        ),
+                        convolution_general_gpu::PrintToStringParamName);
+
+TEST_P(convolution_general_gpu, conv_fp16_cases) {
+    const auto& engine = get_test_engine();
+
+    if (!engine.get_info().supports_fp16)
+    {
+        std::cout << "[ SKIPPED ] The test is skipped (cl_khr_fp16 is not supported)." << std::endl;
+        EXPECT_EQ(1, 1);
+        return;
+    }
+
+    const int input_x = testing::get<0>(GetParam()),
+              input_y = testing::get<1>(GetParam()),
+              input_z = testing::get<2>(GetParam()),
+              input_f = testing::get<3>(GetParam()),
+              output_f = testing::get<4>(GetParam()),
+              filter_x = testing::get<5>(GetParam()),
+              filter_y = testing::get<6>(GetParam()),
+              filter_z = testing::get<7>(GetParam()),
+              groups = testing::get<8>(GetParam()),
+              stride = testing::get<9>(GetParam()),
+              batch_num = testing::get<10>(GetParam()),
+              output_padding = 0,
+              input_offset_z = (filter_z - 1) / 2,
+              input_offset_y = (filter_y - 1) / 2,
+              input_offset_x = (filter_x - 1) / 2;
+    auto input_data_format = testing::get<11>(GetParam());
+    auto impl_name = testing::get<12>(GetParam());
+    auto with_bias = testing::get<13>(GetParam());
+
+    const int output_y = 1 + (input_y + 2 * (-input_offset_y) - filter_y) / stride + 2 * output_padding;
+    const int output_x = 1 + (input_x + 2 * (-input_offset_x) - filter_x) / stride + 2 * output_padding;
+
+    auto input_size = tensor(batch_num, input_f, input_x, input_y);
+    auto input_data = generate_random_4d<FLOAT16>(batch_num, input_f, input_y, input_x, -1, 1);
+    auto input_data_bfyx = flatten_4d(format::bfyx, input_data);
+    auto input_mem = memory::allocate(engine, { data_types::f16, format::bfyx, input_size });
+    set_values(input_mem, input_data_bfyx);
+
+    auto weights_size = tensor(output_f, input_f, filter_y, filter_x, 1);
+    auto weights_data = generate_random_4d<FLOAT16>(output_f, input_f, filter_y, filter_x, -1, 1);
+    auto weights_data_bfyx = flatten_4d(format::bfyx, weights_data);
+    auto weights_mem = memory::allocate(engine, {data_types::f16, format::bfyx, weights_size});
+    set_values(weights_mem, weights_data_bfyx);
+
+    // Will be used to store reference values calculated in branches depending on bias
+    auto expected_result = VVVVF<FLOAT16>(batch_num, VVVF<FLOAT16>(output_f));
+    topology topology;
+
+    // Calculate reference values
+    if (with_bias) {
+        auto biases_size = tensor(1, output_f, 1, 1);
+        auto biases_data = generate_random_1d<FLOAT16>(output_f, -1, 1);
+        auto biases_mem = memory::allocate(engine, {data_types::f16, format::bfyx, biases_size});
+        set_values(biases_mem, biases_data);
+
+        for (auto bi = 0; bi < batch_num; ++bi) {
+            for (auto ofi = 0; ofi < output_f; ++ofi) {
+                expected_result[bi][ofi] = reference_convolve(input_data[bi],                    // input
+                                                              weights_data[ofi],                 // weights
+                                                              stride, stride,                    // strides
+                                                              biases_data[ofi],                  // bias
+                                                              1, 1,                              // dilation
+                                                              -input_offset_y, -input_offset_x,  // input padding
+                                                              output_padding, output_padding);   // output_padding
+            }
+        }
+
+        topology.add(input_layout("input", input_mem.get_layout()),
+                     data("weights_fsv", weights_mem),
+                     data("bias", biases_mem),
+                     reorder("input_fsv", "input", {data_types::f16, input_data_format, input_size}));
+
+        auto conv_fsv = convolution("conv_fsv",
+                                    "input_fsv",
+                                    {"weights_fsv"},
+                                    {"bias"},
+                                    groups,
+                                    {1, 1, stride, stride},
+                                    {0, 0, input_offset_x, input_offset_y});
+        conv_fsv.output_padding = padding({0, 0, output_padding, output_padding}, 0.f);
+
+        topology.add(conv_fsv);
+    } else {
+        for (auto bi = 0; bi < batch_num; ++bi) {
+            for (auto ofi = 0; ofi < output_f; ++ofi) {
+                expected_result[bi][ofi] = reference_convolve(input_data[bi],                    // input
+                                                              weights_data[ofi],                 // weights
+                                                              stride, stride,                    // strides
+                                                              0,                                 // bias
+                                                              1, 1,                              // dilation
+                                                              -input_offset_y, -input_offset_x,  // input padding
+                                                              output_padding, output_padding);   // output_padding
+            }
+        }
+
+        topology.add(input_layout("input", input_mem.get_layout()),
+                     data("weights_fsv", weights_mem),
+                     reorder("input_fsv", "input", {data_types::f16, input_data_format, input_size}));
+
+        auto conv_fsv = convolution("conv_fsv",
+                                    "input_fsv",
+                                    {"weights_fsv"},
+                                    groups,
+                                    {1, 1, stride, stride},
+                                    {0, 0, input_offset_x, input_offset_y});
+        conv_fsv.output_padding = padding({0, 0, output_padding, output_padding}, 0.f);
+        topology.add(conv_fsv);
+    }
+    build_options options;
+    options.set_option(build_option::optimize_data(true));
+    implementation_desc conv_impl = {input_data_format, impl_name};
+    options.set_option(build_option::force_implementations({{"conv_fsv", conv_impl}}));
+    network network(engine, topology, options);
+
+    network.set_input_data("input", input_mem);
+    network.execute();
+
+    auto out_mem = network.get_output("conv_fsv").get_memory();
+    auto out_ptr = out_mem.pointer<FLOAT16>();
+    auto out_lay = out_mem.get_layout();
+
+    ASSERT_EQ(out_mem.get_layout().format, input_data_format);
+    ASSERT_EQ(out_lay.size.batch[0], expected_result.size());
+    ASSERT_EQ(out_lay.size.feature[0], expected_result[0].size());
+    ASSERT_EQ(out_lay.size.spatial[1], expected_result[0][0].size());
+    ASSERT_EQ(out_lay.size.spatial[0], expected_result[0][0][0].size());
+
+    for (int bi = 0; bi < out_lay.size.batch[0]; ++bi)
+        for (int ofi = 0; ofi < out_lay.size.feature[0]; ++ofi)
+            for (int yi = 0; yi < out_lay.size.spatial[1]; ++yi)
+                for (int xi = 0; xi < out_lay.size.spatial[0]; ++xi) {
+                    tensor coords = tensor(batch(bi), feature(ofi), spatial(xi, yi, 0, 0));
+                    auto offset = out_lay.get_linear_offset(coords);
+                    auto val = out_ptr[offset];
+                    auto val_ref = expected_result[bi][ofi][yi][xi];
+                    auto equal = are_equal(val_ref, val, 1);
+                    if (!equal) {
+                        std::cout << "Value at batch: " << bi << ", output_f: " << ofi
+                                  << ", y: " << yi << ", x: " << xi << " = " << static_cast<float>(val) << std::endl;
+                        std::cout << "Reference value at batch: " << bi << ", output_f: " << ofi << ", y: " << yi
+                                  << ", x: " << xi << " = " << static_cast<float>(val_ref) << std::endl;
+                    }
+                    EXPECT_TRUE(equal);
+                }
+}
+
 template <typename InputT, typename WeightsT, typename OutputT>
 class convolution_test_base {
 public:
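
A quick sanity check of the output-size formula above, using the first parameter set (8x8 input, 3x3 filter, stride 1): input_offset_y = (3 - 1) / 2 = 1, so output_y = 1 + (8 + 2 * (-1) - 3) / 1 + 2 * 0 = 4. In this test a positive input_offset shrinks the output (the window starts inside the input), which is why reference_convolve is called with the negated offsets as input padding.
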
@@ -79,6 +79,25 @@ struct bc_test_params {
     size_t expected_not_fused_primitives;
 };
 
+struct bc_force_kernel_params {
+    tensor in_shape;
+    tensor out_shape;
+    tensor kernel;
+    tensor stride;
+    tensor pad;
+    tensor dilation;
+    uint32_t groups;
+    data_types data_type;
+    format input_format;
+    data_types weights_type;
+    format weights_format;
+    data_types default_type;
+    format default_format;
+    size_t expected_fused_primitives;
+    size_t expected_not_fused_primitives;
+    std::string kernel_name;
+};
+
 struct conv_eltw_test_params {
     tensor in_shape;
     tensor out_shape;
@@ -326,26 +345,27 @@ public:
     }
 };
 
-class WeightsPrimitiveFusingTest : public ::BaseFusingTest<bc_test_params> {
+template <typename T>
+class WeightsPrimitiveFusingTest : public ::BaseFusingTest<T> {
 public:
 
-    void execute(bc_test_params& p) {
-        auto input_prim = get_mem(get_input_layout(p));
-        network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
-        network network_fused(this->engine, this->topology_fused, bo_fused);
+    void execute(T& p) {
+        auto input_prim = this->get_mem(get_input_layout(p));
+        network network_not_fused(this->engine, this->topology_non_fused, this->bo_not_fused);
+        network network_fused(this->engine, this->topology_fused, this->bo_fused);
         network_fused.set_input_data("input", input_prim);
         network_not_fused.set_input_data("input", input_prim);
 
-        compare(network_not_fused, network_fused, p);
+        this->compare(network_not_fused, network_fused, p);
     }
 
-    layout get_input_layout(bc_test_params& p) {
+    layout get_input_layout(T& p) {
         auto pad = p.pad.negate();
         std::vector<int> pad_ = { 0, 0, pad.spatial[0], pad.spatial[1] };
         return layout{ p.data_type, p.input_format, p.in_shape, padding{pad_} };
     }
 
-    layout get_per_channel_layout(bc_test_params& p) {
+    layout get_per_channel_layout(T& p) {
         return layout{ p.default_type, p.default_format, tensor{1, p.out_shape.feature[0], 1, 1} };
     }
 };
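
The new this-> qualifications are required, not stylistic: once the base class becomes the dependent type ::BaseFusingTest<T>, unqualified names such as get_mem, compare or bo_fused are no longer found by ordinary lookup inside the template. A minimal standalone illustration of the rule (not project code):

// Dependent-base-class name lookup in a class template:
template <typename T>
struct Base {
    int engine = 0;
};

template <typename T>
struct Derived : Base<T> {
    int get() {
        // return engine;   // error: dependent name, not looked up at definition time
        return this->engine;  // OK: lookup is deferred to instantiation
    }
};
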
@@ -465,6 +485,7 @@ public:
 #define CASE_CONV_FP16_10 {32, 16, 4, 5, 4}, {32, 32, 2, 3, 2}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::bs_fs_zyx_bsv16_fsv16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
 #define CASE_CONV_FP16_11 {1, 32, 4, 5, 4}, {1, 16, 2, 3, 2}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::os_is_zyx_isv16_osv16, data_types::f16, format::bfzyx
 #define CASE_CONV_FP16_12 {1, 16, 4, 5, 4}, {1, 16, 2, 3, 2}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::os_is_zyx_isv16_osv16, data_types::f16, format::bfzyx
+#define CASE_CONV_FP16_13 {16, 32, 4, 5}, {16, 64, 2, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::fs_b_yx_fsv32, data_types::f16, format::bfyx, data_types::f16, format::bfyx
 
 #define CASE_CONV_U8S8_1 {1, 15, 4, 5}, {1, 30, 2, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
 #define CASE_CONV_U8S8_2 {1, 15, 5, 5}, {1, 30, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
@@ -535,7 +556,7 @@ public:
 /* ---------------------------------------- FP32 convolution cases ------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */
 /* ----------- NOTE: A part of tests is disabled until all FP kernels don't support fusings ------------ */
-class ConvFusingTest : public WeightsPrimitiveFusingTest {
+class ConvFusingTest : public WeightsPrimitiveFusingTest<bc_test_params> {
 public:
     void execute(bc_test_params& p) {
         auto input_prim = get_mem(get_input_layout(p));
@@ -1026,7 +1047,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_fp32_scale_activation_quantize_i8_eltw
     bc_test_params{CASE_CONV_FP32_3, 2, 7},
 }), );
 
-class conv_fp32_activation_eltwise_in_u8_fp32 : public WeightsPrimitiveFusingTest {};
+class conv_fp32_activation_eltwise_in_u8_fp32 : public WeightsPrimitiveFusingTest<bc_test_params> {};
 TEST_P(conv_fp32_activation_eltwise_in_u8_fp32, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input", get_input_layout(p)),
@@ -1119,7 +1140,6 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_scale_activation_eltwise_fp32_quantize
     conv_eltw_test_params{CASE_CONV_ELTW_FP32_8, 3, 6},
 }), );
 
-
 /* ----------------------------------------------------------------------------------------------------- */
 /* -------------------------------------- binary convolution cases ------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */
@@ -1301,7 +1321,6 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_bin_scale_conv_dw_prelu,
     bc_test_params{CASE_BIN_CONV3, 3, 5},
 }), );
 
-
 /* ----------------------------------------------------------------------------------------------------- */
 /* ---------------------------------------- INT8 convolution cases ------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */
@@ -2309,10 +2328,81 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_i8_activation_eltwise_diff_sizes,
     conv_eltw_test_params{CASE_CONV_ELTW_i8_5, 3, 4},
 }), );
 
+/* ----------------------------------------------------------------------------------------------------- */
+/* ----------------------------------- Force convolution kernel cases ---------------------------------- */
+/* ----------------------------------------------------------------------------------------------------- */
+class ConvFusingForceKernelTest : public ::WeightsPrimitiveFusingTest<bc_force_kernel_params> {
+public:
+    void execute(bc_force_kernel_params& p) {
+        auto input_prim = get_mem(get_input_layout(p));
+        build_options options;
+        options.set_option(build_option::optimize_data(true));
+        implementation_desc conv_impl = {p.input_format, p.kernel_name};
+        options.set_option(build_option::force_implementations({{"conv_prim", conv_impl}}));
+
+        network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
+        network network_fused(this->engine, this->topology_fused, options);
+        network_fused.set_input_data("input", input_prim);
+        network_not_fused.set_input_data("input", input_prim);
+
+        compare(network_not_fused, network_fused, p);
+        auto find_conv = [](primitive_info& p) -> bool {
+            if (p.original_id == "conv_prim")
+                return true;
+            return false;
+        };
+
+        auto pi_fused = network_fused.get_primitives_info();
+        auto info_fused = std::find_if(pi_fused.begin(), pi_fused.end(), find_conv);
+        if (info_fused != pi_fused.end())
+            std::cout << "kernel: " << info_fused->kernel_id << std::endl;
+    }
+};
+
+class conv_fp16_activation : public ConvFusingForceKernelTest {};
+TEST_P(conv_fp16_activation, basic) {
+    auto p = GetParam();
+    create_topologies(input_layout("input", get_input_layout(p)),
+                data("weights", get_mem(get_weights_layout(p))),
+                data("bias", get_mem(get_bias_layout(p))),
+                convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
+                activation("activation", "conv_prim", activation_func::abs),
+                reorder("reorder_bfyx", "activation", p.default_format, data_types::f32)
+    );
+
+    execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_fp16_activation, ::testing::ValuesIn(std::vector<bc_force_kernel_params>{
+    bc_force_kernel_params{CASE_CONV_FP16_13, 2, 3, "convolution_gpu_fs_byx_fsv32"},
+}), );
+
+
+class conv_fp16_scale : public ConvFusingForceKernelTest {};
+TEST_P(conv_fp16_scale, basic) {
+    auto p = GetParam();
+    create_topologies(input_layout("input", get_input_layout(p)),
+                data("weights", get_mem(get_weights_layout(p))),
+                data("bias", get_mem(get_bias_layout(p))),
+                data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
+                convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
+                scale("scale", "conv_prim", "scale_data"),
+                reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
+    );
+
+    tolerance = 1e-5f;
+    execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_fp16_scale,
+                        ::testing::ValuesIn(std::vector<bc_force_kernel_params>{
+                            bc_force_kernel_params{CASE_CONV_FP16_13, 2, 3, "convolution_gpu_fs_byx_fsv32"},
+                        }), );
+
 /* ----------------------------------------------------------------------------------------------------- */
 /* ---------------------------------------- FC cases --------------------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */
-class FCFusingTest : public WeightsPrimitiveFusingTest {};
+class FCFusingTest : public WeightsPrimitiveFusingTest<bc_test_params> {};
 class fc_fp32_activation : public FCFusingTest {};
 TEST_P(fc_fp32_activation, basic) {
     auto p = GetParam();
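
Each bc_force_kernel_params{CASE_CONV_FP16_13, 2, 3, "convolution_gpu_fs_byx_fsv32"} row above is plain aggregate initialization: CASE_CONV_FP16_13 supplies the first thirteen fields, and the trailing values fill expected_fused_primitives, expected_not_fused_primitives and kernel_name. Expanded by hand against the struct added earlier in this file, the case reads roughly as:

// Hand expansion of bc_force_kernel_params{CASE_CONV_FP16_13, 2, 3, "convolution_gpu_fs_byx_fsv32"}:
bc_force_kernel_params p = {
    {16, 32, 4, 5},                 // in_shape
    {16, 64, 2, 3},                 // out_shape
    {1, 1, 3, 3},                   // kernel
    tensor{1},                      // stride
    tensor{0},                      // pad
    tensor{1},                      // dilation
    1,                              // groups
    data_types::f16,                // data_type
    format::fs_b_yx_fsv32,          // input_format
    data_types::f16,                // weights_type
    format::bfyx,                   // weights_format
    data_types::f16,                // default_type
    format::bfyx,                   // default_format
    2,                              // expected_fused_primitives
    3,                              // expected_not_fused_primitives
    "convolution_gpu_fs_byx_fsv32"  // kernel_name
};
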
@@ -3810,7 +3900,7 @@ using deconv_test_params = bc_test_params;
 #define CASE_DECONV_ELTW_i8_5 {1, 16, 2, 4}, {1, 16, 4, 6}, {1, 16, 4, 1}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::b_fs_yx_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx
 
 
-class DeconvolutionFusingTest : public ::WeightsPrimitiveFusingTest {
+class DeconvolutionFusingTest : public ::WeightsPrimitiveFusingTest<bc_test_params> {
 public:
     void execute(deconv_test_params& p) {
         auto input_prim = get_mem(get_input_layout(p));