[IE CLDNN] Support broadcasting in eltwise b_fs_yx_fsv16 (#4996)
This commit is contained in:
parent
153a20732b
commit
6c290a506f
@ -32,6 +32,13 @@ ParamsKey EltwiseKernel_b_fs_yx_fsv16::GetSupportedKey() const {
|
||||
}
|
||||
|
||||
static inline size_t GetBlockSize(const eltwise_params& params) {
|
||||
// Set blocksize 1 when broadcasting X dim
|
||||
for (size_t i = 0; i < params.inputs.size(); i++) {
|
||||
if (params.inputs[i].X().v == 1 && params.inputs[i].LogicalSize() != 1) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
size_t optimal_bs_values[] = {8, 4, 2, 1};
|
||||
|
||||
for (auto bs : optimal_bs_values) {
|
||||
@ -43,6 +50,23 @@ static inline size_t GetBlockSize(const eltwise_params& params) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
static inline bool OpHasFeatureBroadcast(const eltwise_params& params, const size_t op_num) {
|
||||
const auto &ew = params.operations[op_num];
|
||||
|
||||
for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
|
||||
const auto &input = ew.inputs[input_idx];
|
||||
if (input.mode == EltwiseInputMode::INPUT_BUFFER) {
|
||||
if (params.inputs[input_idx].LogicalSize() != 1
|
||||
&& params.inputs[input_idx].Feature().v == 1
|
||||
&& params.output.Feature().v != 1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
JitConstants EltwiseKernel_b_fs_yx_fsv16::MakeLoadJitConstants(const eltwise_params& params, bool /*useVload8*/) const {
|
||||
JitConstants jit = {};
|
||||
std::string vload_decls;
|
||||
@ -52,13 +76,13 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::MakeLoadJitConstants(const eltwise_par
|
||||
for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
|
||||
const auto &input = ew.inputs[input_idx];
|
||||
const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
|
||||
std::string idx_order = "INPUT" + std::to_string(input.index) + "_IDX_ORDER";
|
||||
|
||||
switch (input.mode) {
|
||||
case EltwiseInputMode::SCALAR:
|
||||
jit.AddConstant(MakeJitConstant(name, input.scalar));
|
||||
break;
|
||||
case EltwiseInputMode::INPUT_BUFFER:
|
||||
{
|
||||
if (params.inputs[input.index].LogicalSize() == params.output.Feature().v &&
|
||||
params.inputs[input.index].LogicalSize() == params.inputs[input.index].Feature().v) {
|
||||
jit.AddConstant(MakeJitConstant(name,
|
||||
@ -69,14 +93,37 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::MakeLoadJitConstants(const eltwise_par
|
||||
"input" + std::to_string(input.index) +
|
||||
"[0]"));
|
||||
} else {
|
||||
std::string block_read_str = "BLOCK_READN(INPUT" + std::to_string(input.index) + "_TYPE, " +
|
||||
"BLOCK_SIZE, " +
|
||||
"input" + std::to_string(input.index) + ", " +
|
||||
"INPUT" + std::to_string(input.index) + "_GET_INDEX(b, f_block*16, y, x))";
|
||||
jit.AddConstant(MakeJitConstant(name,
|
||||
"TO_TYPE(MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, BLOCK_SIZE), " + block_read_str + ")"));
|
||||
const std::string idx_order = "INPUT" + std::to_string(input.index) + "_IDX_ORDER";
|
||||
jit.AddConstant(MakeJitConstant(idx_order, "b, f_block*16, y, x"));
|
||||
|
||||
bool feature_broadcasting = (params.inputs[input_idx].Feature().v == 1 && params.output.Feature().v != 1);
|
||||
|
||||
const std::string block_read_str = "TO_TYPE(MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, BLOCK_SIZE), BLOCK_READN(INPUT" +
|
||||
std::to_string(input.index) + "_TYPE, BLOCK_SIZE, " +
|
||||
"input" + std::to_string(input.index) + ", " +
|
||||
"GET_INDEX(INPUT, " + std::to_string(input.index) + ", " + idx_order + ")))";
|
||||
if (feature_broadcasting) {
|
||||
const std::string broadcast_name = "DO_FEATURE_BROADCAST" + std::to_string(op_num);
|
||||
std::string sub_group_broadcast;
|
||||
if (GetBlockSize(params) == 1) {
|
||||
sub_group_broadcast = "\\\n\ttmp_b" + std::to_string(op_num) +
|
||||
" = sub_group_broadcast(tmp_b" + std::to_string(op_num) + ", 0);";
|
||||
} else {
|
||||
sub_group_broadcast = "\\\n\tunroll_for (uint i = 0; i < BLOCK_SIZE; ++i) tmp_b" + std::to_string(op_num) +
|
||||
"[i] = sub_group_broadcast(tmp_b" + std::to_string(op_num) + "[i], 0);";
|
||||
}
|
||||
|
||||
std::string broadcast_value = "\\\n\tMAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, BLOCK_SIZE) tmp_b" + std::to_string(op_num) +
|
||||
" = " + block_read_str + ";" + sub_group_broadcast;
|
||||
|
||||
jit.AddConstant(MakeJitConstant(broadcast_name, broadcast_value));
|
||||
jit.AddConstant(MakeJitConstant(name, "tmp_b" + std::to_string(op_num)));
|
||||
} else {
|
||||
jit.AddConstant(MakeJitConstant(name, block_read_str));
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case EltwiseInputMode::OUTPUT_BUFFER:
|
||||
jit.AddConstant(MakeJitConstant(name, "output[off]"));
|
||||
break;
|
||||
@ -107,13 +154,15 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::GetJitConstants(const eltwise_params&
|
||||
jit.AddConstant(MakeJitConstant("BLOCKS_COUNT", CeilDiv(params.output.X().v, blockSize)));
|
||||
|
||||
jit.Merge(MakeInputDeclsJitConstants(params, useVload8));
|
||||
jit.Merge(MakeIndexJitConstants(params, useVload8));
|
||||
jit.Merge(MakeLoadJitConstants(params, useVload8));
|
||||
jit.Merge(GetOperationsJitConstants(params, useVload8, blockSize));
|
||||
|
||||
std::string do_eltwise;
|
||||
auto& operations = params.operations;
|
||||
for (size_t op_num = 0; op_num < operations.size(); op_num++) {
|
||||
if (OpHasFeatureBroadcast(params, op_num)) {
|
||||
do_eltwise += "\\\n\tDO_FEATURE_BROADCAST" + std::to_string(op_num) + ";";
|
||||
}
|
||||
do_eltwise += "\\\n\tOPERATION" + std::to_string(op_num) + ";";
|
||||
}
|
||||
|
||||
@ -144,6 +193,8 @@ JitConstants EltwiseKernel_b_fs_yx_fsv16::GetJitConstants(const eltwise_params&
|
||||
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
|
||||
}
|
||||
|
||||
jit.AddConstant(MakeJitConstant("ELTWISE_BROADCAST", params.broadcast));
|
||||
|
||||
return jit;
|
||||
}
|
||||
|
||||
@ -155,17 +206,11 @@ bool EltwiseKernel_b_fs_yx_fsv16::Validate(const Params& params, const optional_
|
||||
const auto& ewParams = static_cast<const eltwise_params&>(params);
|
||||
|
||||
const auto& output = ewParams.output;
|
||||
const auto count = output.PhysicalSize();
|
||||
|
||||
if (count % 8 != 0)
|
||||
return false;
|
||||
|
||||
for (size_t i = 0; i < ewParams.inputs.size(); i++) {
|
||||
// Allow the same input sizes OR per-channel operation
|
||||
if ((ewParams.inputs[i].LogicalSize() != output.LogicalSize()) &&
|
||||
(ewParams.inputs[i].LogicalSize() != output.Feature().v || ewParams.inputs[i].Feature().v != output.Feature().v) &&
|
||||
(ewParams.inputs[i].LogicalSize() != 1))
|
||||
if (ewParams.inputs[i].GetLayout() != DataLayout::b_fs_yx_fsv16 && GetBlockSize(ewParams) != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto input0 = ewParams.inputs[0];
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "include/data_types.cl"
|
||||
|
||||
#define FEATURE_SLICE_SIZE 16
|
||||
#define unroll_for __attribute__((opencl_unroll_hint())) for
|
||||
|
||||
#define OUTPUT_TYPE_BLOCK MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE)
|
||||
#define TO_TYPE(type, val) CAT(convert_, type)(val)
|
||||
@ -19,6 +20,12 @@
|
||||
#define WRITE_FUNC(ptr, offset, val) DT_OUTPUT_BLOCK_WRITE(ptr, offset, val)
|
||||
#endif
|
||||
|
||||
// GET_INDEX(prefix, num, idx_order) builds the name of an indexing macro from
// `prefix` and `num` and invokes it with `idx_order`:
//   ELTWISE_BROADCAST non-zero -> <prefix><num>_GET_INDEX_SAFE(idx_order)
//   otherwise                  -> <prefix><num>_GET_INDEX(idx_order)
// NOTE(review): the _SAFE variant is presumably the one that tolerates coordinates
// outside a broadcast input's extent; its definition is not visible here - confirm.
#if ELTWISE_BROADCAST
|
||||
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX_SAFE)(idx_order)
|
||||
#else
|
||||
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order)
|
||||
#endif
|
||||
|
||||
__attribute__((intel_reqd_sub_group_size(FEATURE_SLICE_SIZE)))
|
||||
KERNEL(eltwise_b_fs_yx_fsv16)(INPUTS_DECLS
|
||||
__global OUTPUT_TYPE* output
|
||||
@ -84,3 +91,11 @@ KERNEL(eltwise_b_fs_yx_fsv16)(INPUTS_DECLS
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#undef FEATURE_SLICE_SIZE
|
||||
#undef unroll_for
|
||||
#undef OUTPUT_TYPE_BLOCK
|
||||
#undef TO_TYPE
|
||||
#undef READ_FUNC
|
||||
#undef WRITE_FUNC
|
||||
#undef GET_INDEX
|
||||
|
@ -3342,6 +3342,8 @@ static std::vector<std::vector<std::vector<int32_t>>> inputs = {
|
||||
{{1, 16, 1, 1}, {1, 16, 8, 2}},
|
||||
{{1, 32, 1, 1}, {1, 32, 2, 2}},
|
||||
{{1, 32, 1, 1}, {8, 32, 4, 5}},
|
||||
{{1, 2, 1, 1}, {1, 1, 3, 1}},
|
||||
{{1, 2, 1, 1}, {4, 1, 3, 5}},
|
||||
|
||||
{{1, 16, 8, 2, 4}, {1, 16, 8, 2, 4}},
|
||||
{{8, 32, 4, 5, 6}, {1, 32, 1, 1, 1}},
|
||||
|
Loading…
Reference in New Issue
Block a user