[IE CLDNN] Support broadcasting in etlwise b_fs_yx_fsv16 (#4996)

This commit is contained in:
hyunback kim 2021-04-07 15:28:53 +09:00 committed by GitHub
parent 153a20732b
commit 6c290a506f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 16 deletions

View File

@ -32,6 +32,13 @@ ParamsKey EltwiseKernel_b_fs_yx_fsv16::GetSupportedKey() const {
}
static inline size_t GetBlockSize(const eltwise_params& params) {
// Set blocksize 1 when broadcasting X dim
for (size_t i = 0; i < params.inputs.size(); i++) {
if (params.inputs[i].X().v == 1 && params.inputs[i].LogicalSize() != 1) {
return 1;
}
}
size_t optimal_bs_values[] = {8, 4, 2, 1};
for (auto bs : optimal_bs_values) {
@ -43,6 +50,23 @@ static inline size_t GetBlockSize(const eltwise_params& params) {
return 1;
}
static inline bool OpHasFeatureBroadcast(const eltwise_params& params, const size_t op_num) {
const auto &ew = params.operations[op_num];
for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
const auto &input = ew.inputs[input_idx];
if (input.mode == EltwiseInputMode::INPUT_BUFFER) {
if (params.inputs[input_idx].LogicalSize() != 1
&& params.inputs[input_idx].Feature().v == 1
&& params.output.Feature().v != 1) {
return true;
}
}
}
return false;
}
JitConstants EltwiseKernel_b_fs_yx_fsv16::MakeLoadJitConstants(const eltwise_params& params, bool /*useVload8*/) const {
JitConstants jit = {};
std::string vload_decls;
@ -52,13 +76,13 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::MakeLoadJitConstants(const eltwise_par
for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
const auto &input = ew.inputs[input_idx];
const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
std::string idx_order = "INPUT" + std::to_string(input.index) + "_IDX_ORDER";
switch (input.mode) {
case EltwiseInputMode::SCALAR:
jit.AddConstant(MakeJitConstant(name, input.scalar));
break;
case EltwiseInputMode::INPUT_BUFFER:
{
if (params.inputs[input.index].LogicalSize() == params.output.Feature().v &&
params.inputs[input.index].LogicalSize() == params.inputs[input.index].Feature().v) {
jit.AddConstant(MakeJitConstant(name,
@ -69,14 +93,37 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::MakeLoadJitConstants(const eltwise_par
"input" + std::to_string(input.index) +
"[0]"));
} else {
std::string block_read_str = "BLOCK_READN(INPUT" + std::to_string(input.index) + "_TYPE, " +
"BLOCK_SIZE, " +
"input" + std::to_string(input.index) + ", " +
"INPUT" + std::to_string(input.index) + "_GET_INDEX(b, f_block*16, y, x))";
jit.AddConstant(MakeJitConstant(name,
"TO_TYPE(MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, BLOCK_SIZE), " + block_read_str + ")"));
const std::string idx_order = "INPUT" + std::to_string(input.index) + "_IDX_ORDER";
jit.AddConstant(MakeJitConstant(idx_order, "b, f_block*16, y, x"));
bool feature_broadcasting = (params.inputs[input_idx].Feature().v == 1 && params.output.Feature().v != 1);
const std::string block_read_str = "TO_TYPE(MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, BLOCK_SIZE), BLOCK_READN(INPUT" +
std::to_string(input.index) + "_TYPE, BLOCK_SIZE, " +
"input" + std::to_string(input.index) + ", " +
"GET_INDEX(INPUT, " + std::to_string(input.index) + ", " + idx_order + ")))";
if (feature_broadcasting) {
const std::string broadcast_name = "DO_FEATURE_BROADCAST" + std::to_string(op_num);
std::string sub_group_broadcast;
if (GetBlockSize(params) == 1) {
sub_group_broadcast = "\\\n\ttmp_b" + std::to_string(op_num) +
" = sub_group_broadcast(tmp_b" + std::to_string(op_num) + ", 0);";
} else {
sub_group_broadcast = "\\\n\tunroll_for (uint i = 0; i < BLOCK_SIZE; ++i) tmp_b" + std::to_string(op_num) +
"[i] = sub_group_broadcast(tmp_b" + std::to_string(op_num) + "[i], 0);";
}
std::string broadcast_value = "\\\n\tMAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, BLOCK_SIZE) tmp_b" + std::to_string(op_num) +
" = " + block_read_str + ";" + sub_group_broadcast;
jit.AddConstant(MakeJitConstant(broadcast_name, broadcast_value));
jit.AddConstant(MakeJitConstant(name, "tmp_b" + std::to_string(op_num)));
} else {
jit.AddConstant(MakeJitConstant(name, block_read_str));
}
}
break;
}
case EltwiseInputMode::OUTPUT_BUFFER:
jit.AddConstant(MakeJitConstant(name, "output[off]"));
break;
@ -107,13 +154,15 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::GetJitConstants(const eltwise_params&
jit.AddConstant(MakeJitConstant("BLOCKS_COUNT", CeilDiv(params.output.X().v, blockSize)));
jit.Merge(MakeInputDeclsJitConstants(params, useVload8));
jit.Merge(MakeIndexJitConstants(params, useVload8));
jit.Merge(MakeLoadJitConstants(params, useVload8));
jit.Merge(GetOperationsJitConstants(params, useVload8, blockSize));
std::string do_eltwise;
auto& operations = params.operations;
for (size_t op_num = 0; op_num < operations.size(); op_num++) {
if (OpHasFeatureBroadcast(params, op_num)) {
do_eltwise += "\\\n\tDO_FEATURE_BROADCAST" + std::to_string(op_num) + ";";
}
do_eltwise += "\\\n\tOPERATION" + std::to_string(op_num) + ";";
}
@ -144,6 +193,8 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::GetJitConstants(const eltwise_params&
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
}
jit.AddConstant(MakeJitConstant("ELTWISE_BROADCAST", params.broadcast));
return jit;
}
@ -155,17 +206,11 @@ bool EltwiseKernel_b_fs_yx_fsv16::Validate(const Params& params, const optional_
const auto& ewParams = static_cast<const eltwise_params&>(params);
const auto& output = ewParams.output;
const auto count = output.PhysicalSize();
if (count % 8 != 0)
return false;
for (size_t i = 0; i < ewParams.inputs.size(); i++) {
// Allow the same input sizes OR per-channel operation
if ((ewParams.inputs[i].LogicalSize() != output.LogicalSize()) &&
(ewParams.inputs[i].LogicalSize() != output.Feature().v || ewParams.inputs[i].Feature().v != output.Feature().v) &&
(ewParams.inputs[i].LogicalSize() != 1))
if (ewParams.inputs[i].GetLayout() != DataLayout::b_fs_yx_fsv16 && GetBlockSize(ewParams) != 1) {
return false;
}
}
auto input0 = ewParams.inputs[0];

View File

@ -7,6 +7,7 @@
#include "include/data_types.cl"
#define FEATURE_SLICE_SIZE 16
#define unroll_for __attribute__((opencl_unroll_hint())) for
#define OUTPUT_TYPE_BLOCK MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE)
#define TO_TYPE(type, val) CAT(convert_, type)(val)
@ -19,6 +20,12 @@
#define WRITE_FUNC(ptr, offset, val) DT_OUTPUT_BLOCK_WRITE(ptr, offset, val)
#endif
#if ELTWISE_BROADCAST
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX_SAFE)(idx_order)
#else
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order)
#endif
__attribute__((intel_reqd_sub_group_size(FEATURE_SLICE_SIZE)))
KERNEL(eltwise_b_fs_yx_fsv16)(INPUTS_DECLS
__global OUTPUT_TYPE* output
@ -84,3 +91,11 @@ KERNEL(eltwise_b_fs_yx_fsv16)(INPUTS_DECLS
}
}
#undef FEATURE_SLICE_SIZE
#undef unroll_for
#undef OUTPUT_TYPE_BLOCK
#undef TO_TYPE
#undef READ_FUNC
#undef WRITE_FUNC
#undef GET_INDEX

View File

@ -3342,6 +3342,8 @@ static std::vector<std::vector<std::vector<int32_t>>> inputs = {
{{1, 16, 1, 1}, {1, 16, 8, 2}},
{{1, 32, 1, 1}, {1, 32, 2, 2}},
{{1, 32, 1, 1}, {8, 32, 4, 5}},
{{1, 2, 1, 1}, {1, 1, 3, 1}},
{{1, 2, 1, 1}, {4, 1, 3, 5}},
{{1, 16, 8, 2, 4}, {1, 16, 8, 2, 4}},
{{8, 32, 4, 5, 6}, {1, 32, 1, 1, 1}},