[GPU] Optimize eltwise kernel for onednn formats (#15087)

+ Bugfix of eltwise_b_fs_yx_fsv16 kernel for int saturation
+ Add optimization for fsv32, fsv16 using vload
+ Add optimization for double-blocked format eltwise
+ Support mixed formats and broadcasting
+ Add test-cases to eltwise_gpu_test

Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
Min, Byungil 2023-02-02 18:03:07 +09:00 committed by GitHub
parent 347cd0e180
commit 7bdc9ec36b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 580 additions and 2 deletions

View File

@ -10,6 +10,7 @@
#define OUTPUT_TYPE_BLOCK MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE)
#define TO_TYPE(type, val) CAT(convert_, type)(val)
#define TO_TYPE_SAT(type, val) CAT(CAT(convert_, type), _sat)(val)
#if BLOCK_SIZE != 1
#define READ_FUNC(ptr, offset) CAT(DT_INPUT_BLOCK_READ, BLOCK_SIZE)(ptr, offset)
@ -80,7 +81,12 @@ KERNEL(eltwise_b_fs_yx_fsv16)(INPUTS_DECLS
OUTPUT_TYPE_BLOCK out = TO_TYPE(MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE), FUSED_OPS_RESULT);
#else
#if BLOCK_SIZE != 1
#if OUTPUT_IS_FP
OUTPUT_TYPE_BLOCK out = ACTIVATION_TYPED(TO_TYPE(MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE), res), ACTIVATION_PARAMS_TYPED);
#else
OUTPUT_TYPE_BLOCK out = ACTIVATION_TYPED(TO_TYPE_SAT(MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE), res), ACTIVATION_PARAMS_TYPED);
#endif
#else
OUTPUT_TYPE out = ACTIVATION_TYPED(TO_OUTPUT_TYPE(res), ACTIVATION_PARAMS_TYPED);
#endif

View File

@ -0,0 +1,72 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "include/batch_headers/fetch_data.cl"
#define OUTPUT_TYPE_BLOCK MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE)
#define TO_TYPE(type, val) CAT(convert_, type)(val)
#define TO_TYPE_SAT(type, val) CAT(CAT(convert_, type), _sat)(val)
#if ELTWISE_BROADCAST
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX_SAFE)(idx_order)
#else
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order)
#endif
// Optimized eltwise kernel for feature-blocked (fsv / bsv+fsv) layouts.
// Each work-item handles VEC_SIZE consecutive features at one (b, z, y, x)
// location: the jit-generated DO_ELTWISE macro loads the inputs and computes
// the result vector, which is written back with a single vstore.
KERNEL(eltwise_blocked_opt)(INPUTS_DECLS
__global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
, FUSED_OPS_DECLS
#endif
)
{
// gws[1] packs the spatial dimensions: zyx for 5D outputs, yx for 4D.
const uint zyx = (uint)get_global_id(1);
#if OUTPUT_DIMS == 5
const uint z = zyx / (uint)XY_BLOCK;
const uint yx = zyx % XY_BLOCK;
const uint y = yx / OUTPUT_SIZE_X;
const uint x = yx % OUTPUT_SIZE_X;
#else
const uint z = 0;
const uint y = zyx / OUTPUT_SIZE_X;
const uint x = zyx % OUTPUT_SIZE_X;
#endif
// gws[0] iterates feature blocks (VEC_SIZE features each), gws[2] is batch.
const uint f_block = get_global_id(0);
const uint b = get_global_id(2);
MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, VEC_SIZE) res;
// Jit-generated per-operation loads + arithmetic; leaves the result in 'res'.
DO_ELTWISE
#if HAS_FUSED_OPS
FUSED_OPS;
OUTPUT_TYPE_BLOCK out = TO_TYPE(MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE), FUSED_OPS_RESULT);
#else
#if QUANTIZATION_TERM && !OUTPUT_IS_FP
// Saturating conversion prevents integer wrap-around for quantized outputs.
OUTPUT_TYPE_BLOCK out = ACTIVATION_TYPED(TO_TYPE_SAT(MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE), res), ACTIVATION_PARAMS_TYPED);
#else
OUTPUT_TYPE_BLOCK out = ACTIVATION_TYPED(TO_TYPE(MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE), res), ACTIVATION_PARAMS_TYPED);
#endif
#endif
#ifdef LEFTOVERS
// Overwrite tail lanes with zero when OUTPUT_FEATURE_NUM is not a multiple of
// VEC_SIZE, so the padding features of the last block are not filled with garbage.
if ((f_block*VEC_SIZE + VEC_SIZE) > OUTPUT_FEATURE_NUM) {
for (uint fp = OUTPUT_FEATURE_NUM % VEC_SIZE; fp < VEC_SIZE; fp++) {
out[fp] = OUTPUT_VAL_ZERO;
}
}
#endif
// Single vectorized store of the whole feature block.
#if OUTPUT_DIMS == 5
VSTORE_N(out, 0, &output[OUTPUT_GET_INDEX(b, (f_block*VEC_SIZE), z, y, x)]);
#else
VSTORE_N(out, 0, &output[OUTPUT_GET_INDEX(b, (f_block*VEC_SIZE), y, x)]);
#endif
}
#undef OUTPUT_TYPE_BLOCK
#undef TO_TYPE
#undef TO_TYPE_SAT

View File

@ -270,6 +270,7 @@ std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws, const
auto blocked_bsv_fsv_layout = output_layout == DataLayout::bs_fs_yx_bsv16_fsv2 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv2 ||
output_layout == DataLayout::bs_fs_yx_bsv16_fsv4 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv4 ||
output_layout == DataLayout::bs_fs_yx_bsv16_fsv16 || output_layout == DataLayout::bs_fs_yx_bsv16_fsv32 ||
output_layout == DataLayout::bs_fs_yx_bsv32_fsv16 || output_layout == DataLayout::bs_fs_yx_bsv32_fsv32 ||
output_layout == DataLayout::bs_fs_zyx_bsv16_fsv16 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv32 ||
output_layout == DataLayout::bs_fs_zyx_bsv32_fsv16 || output_layout == DataLayout::bs_fs_zyx_bsv32_fsv32;
@ -388,6 +389,9 @@ std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws, const
} else if ((output_layout == DataLayout::bs_fs_yx_bsv16_fsv16 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv16) &&
(axis_by_gws[b] != axis_by_gws[f]) && (axis_by_gws[b] != unused_axis)) {
max_optimal_lws0_value = 16;
} else if ((output_layout == DataLayout::bs_fs_yx_bsv32_fsv32 || output_layout == DataLayout::bs_fs_zyx_bsv32_fsv32) &&
(axis_by_gws[b] != axis_by_gws[f]) && (axis_by_gws[b] != unused_axis)) {
max_optimal_lws0_value = 32;
}
}

View File

@ -0,0 +1,388 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "eltwise_kernel_blocked_opt.h"
#include "kernel_selector_utils.h"
#include <algorithm>
#include <string>
#include <vector>
namespace kernel_selector {
static inline bool InputHasFeatureBroadcast(const eltwise_params& params, const size_t op_num, const size_t input_idx);
static inline bool IsBroadcastingPossibleInput(const DataTensor& input, const DataTensor& output);
static inline int GetFeatureBlockSizeFromFormat(const eltwise_params& params, size_t index);
// Declares the data types and blocked layouts this kernel can consume/produce.
ParamsKey EltwiseKernel_blocked_opt::GetSupportedKey() const {
    ParamsKey k;

    // The same data-type set is supported on both input and output.
    for (auto dt : {Datatype::F16, Datatype::F32, Datatype::INT8, Datatype::UINT8}) {
        k.EnableInputDataType(dt);
        k.EnableOutputDataType(dt);
    }

    // Feature-blocked and double-blocked layouts handled by this kernel.
    for (auto layout : {DataLayout::b_fs_yx_fsv4,
                        DataLayout::b_fs_yx_fsv16,
                        DataLayout::b_fs_zyx_fsv16,
                        DataLayout::b_fs_yx_fsv32,
                        DataLayout::b_fs_zyx_fsv32,
                        DataLayout::bs_fs_yx_bsv32_fsv32,
                        DataLayout::bs_fs_yx_bsv16_fsv32,
                        DataLayout::bs_fs_yx_bsv32_fsv16,
                        DataLayout::bs_fs_yx_bsv16_fsv16,
                        DataLayout::bs_fs_zyx_bsv32_fsv32,
                        DataLayout::bs_fs_zyx_bsv16_fsv32,
                        DataLayout::bs_fs_zyx_bsv32_fsv16,
                        DataLayout::bs_fs_zyx_bsv16_fsv16}) {
        k.EnableInputLayout(layout);
        k.EnableOutputLayout(layout);
    }

    k.EnableDifferentTypes();
    k.EnableBatching();
    k.EnableTensorPitches();
    k.EnableTensorOffset();
    k.EnableEltwiseBroadcast();
    return k;
}
// Builds the KernelData (jit, entry point, work sizes, argument list) for the
// given eltwise parameters; returns an empty list if validation fails.
KernelsData EltwiseKernel_blocked_opt::GetKernelsData(const Params& params, const optional_params& options) const {
    // Reject parameter combinations this kernel cannot handle.
    if (!Validate(params, options)) {
        return {};
    }

    KernelData kernel_data = KernelData::Default<eltwise_params>(params);
    eltwise_params& ew_params = *static_cast<eltwise_params*>(kernel_data.params.get());

    const auto entry_point = GetEntryPoint(kernelName, ew_params.layerID, params, options);
    auto jit_constants = GetJitConstants(ew_params);
    const auto jit = CreateJit(kernelName, jit_constants, entry_point);
    const DispatchData dispatch_data = SetDefault(ew_params);

    auto& cl_kernel = kernel_data.kernels[0];
    cl_kernel.params.workGroups.global = dispatch_data.gws;
    cl_kernel.params.workGroups.local = dispatch_data.lws;
    cl_kernel.code.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, EXE_MODE_DEFAULT);
    cl_kernel.params.arguments = GetArgsDesc((uint32_t)ew_params.inputs.size(),
                                             false,
                                             false,
                                             GetFusedPrimitiveInputsCount(params));

    return {kernel_data};
}
// Highest priority: when Validate() accepts the parameters, this kernel is
// preferred over the other attached eltwise implementations.
KernelsPriority EltwiseKernel_blocked_opt::GetKernelsPriority(const Params& /*params*/, const optional_params& /*options*/) const {
return FORCE_PRIORITY_1;
}
// Protected
// Checks whether the eltwise parameters can be executed by this vectorized
// kernel: supported mode, vectorizable (or broadcastable) inputs, block-aligned
// feature padding, and identical descriptors for same-sized inputs.
bool EltwiseKernel_blocked_opt::Validate(const Params& params, const optional_params& o) const {
    if (!EltwiseKernelBase::Validate(params, o)) {
        return false;
    }

    const auto& ewParams = static_cast<const eltwise_params&>(params);
    if (IsUnsupportedModeForVecCode(ewParams))
        return false;

    // Every input must either use a vectorizable blocked layout or be
    // broadcastable (scalar / per-feature) onto the output.
    for (size_t i = 0; i < ewParams.inputs.size(); i++) {
        if ((GetFeatureBlockSizeFromFormat(ewParams, i) == 1) &&
            !IsBroadcastingPossibleInput(ewParams.inputs[i], ewParams.outputs[0])) {
            return false;
        }
    }

    const auto vec_size = GetFeatureBlockSizeFromFormat(ewParams, 0);
    const auto& input0 = ewParams.inputs[0];  // reference: avoid copying a DataTensor
    const auto& output = ewParams.outputs[0];

    // Check that padding before features doesn't mis-align the blocks
    if (input0.Feature().pad.before % vec_size != 0 || output.Feature().pad.before % vec_size != 0)
        return false;

    // Compare every tensor parameter except DataType (layout, padding, offsets,
    // per-dim sizes and pitches).
    auto compareTensors = [](const DataTensor& lhs, const DataTensor& rhs) -> bool {
        auto& lhs_dims = lhs.GetDims();
        auto& rhs_dims = rhs.GetDims();
        bool same = lhs.GetLayout() == rhs.GetLayout() &&
                    lhs.GetPaddedVal() == rhs.GetPaddedVal() &&
                    lhs.GetViewOffset() == rhs.GetViewOffset() &&
                    lhs_dims.size() == rhs_dims.size();
        for (size_t i = 0; i < lhs_dims.size(); i++) {
            same &= lhs_dims[i].v == rhs_dims[i].v &&
                    lhs_dims[i].pad.before == rhs_dims[i].pad.before &&
                    lhs_dims[i].pad.after == rhs_dims[i].pad.after &&
                    lhs_dims[i].pitch == rhs_dims[i].pitch;
        }
        return same;
    };

    for (size_t i = 1; i < ewParams.inputs.size(); i++) {
        // Inputs of the same logical size as input0 must match it exactly;
        // mixed descriptors for equal-sized tensors are not handled.
        if (ewParams.inputs[i].LogicalSize() == input0.LogicalSize() && !(compareTensors(ewParams.inputs[i], input0)))
            return false;
        // Feature padding of every input must stay block-aligned as well.
        if (ewParams.inputs[i].Feature().pad.before % vec_size != 0) {
            return false;
        }
    }

    return true;
}
// Generates per-input load macros (DO_VLOAD*/DO_FEATURE_BROADCAST*) and the
// INPUT_<op>_<idx> accessor names used by the jit-assembled DO_ELTWISE body.
JitConstants EltwiseKernel_blocked_opt::MakeLoadJitConstants(const eltwise_params& params, bool /*use_vload*/) const {
    const auto vec_size = GetFeatureBlockSizeFromFormat(params, 0);
    JitConstants jit = {};

    // Make load jit constants
    for (size_t op_num = 0; op_num < params.operations.size(); op_num++) {
        const std::string op_num_str = toCodeString(op_num);
        const auto &ew = params.operations[op_num];
        for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
            const auto &input = ew.inputs[input_idx];
            const std::string name = "INPUT_" + op_num_str + "_" + toCodeString(input_idx);

            // Get a string for a default index based on dimension.
            // NOTE(review): the index order is picked from params.inputs[input_idx],
            // while the loads below address params.inputs[input.index] — confirm
            // these always refer to the same tensor.
            std::string default_indexing_str;
            if (DataTensor::ChannelsCount(params.inputs[input_idx].GetLayout()) == 4)
                default_indexing_str = "b, (f_block * " + toCodeString(vec_size) +"), y, x";
            else if (DataTensor::ChannelsCount(params.inputs[input_idx].GetLayout()) == 5)
                default_indexing_str = "b, (f_block * " + toCodeString(vec_size) +"), z, y, x";
            else
                // A bare string literal in IE_ASSERT is always truthy and never fires;
                // force the failure explicitly while keeping the message.
                IE_ASSERT(false && "MakeLoadJit : Unexpected dimension for eltwise optimized kernel.");

            switch (input.mode) {
                case EltwiseInputMode::SCALAR:
                    jit.AddConstant(MakeJitConstant(name, input.scalar));
                    break;
                case EltwiseInputMode::INPUT_BUFFER:
                {
                    const std::string idx_order = "INPUT" + toCodeString(input.index) + "_IDX_ORDER";
                    jit.AddConstant(MakeJitConstant(idx_order, default_indexing_str));

                    if (params.inputs[input.index].LogicalSize() == 1) {
                        // Single-element input: splat input[0] across the whole vector.
                        const std::string vload_name = "DO_VLOAD" + toCodeString(op_num) + "_" + toCodeString(input_idx);
                        const std::string vload_value = "\\\n\tMAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, " + toCodeString(vec_size) + ") " +
                                                        "tmp_a" + toCodeString(op_num) + "_" + toCodeString(input_idx) + " = " +
                                                        "(MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, " + toCodeString(vec_size) + "))" +
                                                        "(input" + toCodeString(input.index) + "[0])";
                        jit.AddConstant(MakeJitConstant(vload_name, vload_value));
                        jit.AddConstant(MakeJitConstant(name, "tmp_a" + toCodeString(op_num) + "_" + toCodeString(input_idx)));
                    } else {
                        bool feature_broadcasting = (params.inputs[input_idx].Feature().v == 1 && params.outputs[0].Feature().v != 1);
                        if (feature_broadcasting) {
                            // Feature-broadcast input: read one scalar and splat it.
                            const std::string broadcast_name = "DO_FEATURE_BROADCAST" + toCodeString(op_num) + "_" + toCodeString(input_idx);
                            std::string broadcast_value = "\\\n\tMAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, " + toCodeString(vec_size) + ") tmp_b" +
                                                          toCodeString(op_num) + " = (MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, " + toCodeString(vec_size) + "))" +
                                                          "(input" + toCodeString(input.index) + "[GET_INDEX(INPUT, " + toCodeString(input.index) +
                                                          ", " + idx_order + ")]);";
                            jit.AddConstant(MakeJitConstant(broadcast_name, broadcast_value));
                            jit.AddConstant(MakeJitConstant(name, "tmp_b" + toCodeString(op_num)));
                        } else {
                            // Regular input: vectorized load of vec_size features.
                            const std::string vload_name = "DO_VLOAD" + toCodeString(op_num) + "_" + toCodeString(input_idx);
                            const std::string vload_value = "\\\n\tMAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, " + toCodeString(vec_size) + ")" +
                                                            " tmp_a" + toCodeString(op_num) + "_" + toCodeString(input_idx) +
                                                            " = TO_TYPE(MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, " + toCodeString(vec_size) + "), vload" +
                                                            toCodeString(vec_size) + "(0, &input" + toCodeString(input.index) +
                                                            "[GET_INDEX(INPUT," + toCodeString(input.index) + ", " + idx_order + ")]));";
                            jit.AddConstant(MakeJitConstant(vload_name, vload_value));
                            jit.AddConstant(MakeJitConstant(name, "tmp_a" + toCodeString(op_num) + "_" + toCodeString(input_idx)));
                        }
                    }
                    break;
                }
                case EltwiseInputMode::OUTPUT_BUFFER:
                    jit.AddConstant(MakeJitConstant(name, "output[off]"));
                    break;
                case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
                    jit.AddConstant(MakeJitConstant(
                        name,
                        "input" + toCodeString(input.index) + "[(size_t)tmp" + toCodeString(input.tmpIndex) + "]"));
                    break;
                case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
                    jit.AddConstant(MakeJitConstant(name, "tmp" + toCodeString(input.tmpIndex)));
                    break;
                default:
                    break;
            }
        }
    }

    return jit;
}
// Assembles all jit constants for eltwise_blocked_opt: base parameters,
// accumulator type, per-input load macros, the stitched DO_ELTWISE sequence,
// activation, fused-ops configuration and the VEC_SIZE/VSTORE_N helpers.
JitConstants EltwiseKernel_blocked_opt::GetJitConstants(const eltwise_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);
const auto vec_size = GetFeatureBlockSizeFromFormat(params, 0);
jit.Merge(MakeTypeJitConstants(GetAccumulatorType(params), "ACCUMULATOR"));
jit.AddConstant(MakeJitConstant("BLOCK_SIZE", vec_size));
// XY_BLOCK lets the kernel split the packed zyx global id back into z and yx.
jit.AddConstant(MakeJitConstant("XY_BLOCK", params.outputs[0].X().v * params.outputs[0].Y().v));
bool use_vload = false;
jit.Merge(MakeInputDeclsJitConstants(params, use_vload));
jit.Merge(MakeLoadJitConstants(params, use_vload));
jit.Merge(GetOperationsJitConstants(params, use_vload, vec_size));
// Stitch per-input loads (or feature broadcasts) and per-op arithmetic into
// the single DO_ELTWISE macro expanded by the kernel body.
std::string do_eltwise;
auto& operations = params.operations;
for (size_t op_num = 0; op_num < operations.size(); op_num++) {
const auto &ew = operations[op_num];
for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
const auto &input = ew.inputs[input_idx];
if (input.mode != EltwiseInputMode::INPUT_BUFFER && input.mode != EltwiseInputMode::SCALAR)
continue;
if (InputHasFeatureBroadcast(params, op_num, input_idx)) {
do_eltwise += "\\\n\tDO_FEATURE_BROADCAST" + toCodeString(op_num) + "_" + toCodeString(input_idx) + ";";
} else {
do_eltwise += "\\\n\tDO_VLOAD" + toCodeString(op_num) + "_" + toCodeString(input_idx) + ";";
}
}
do_eltwise += "\\\n\tOPERATION" + toCodeString(op_num) + ";";
}
// The temporary of the last operation carries the final accumulator value.
do_eltwise += "\\\n\tres = tmp" + toCodeString(operations.size() - 1) + ";";
jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
if (params.layoutBased || params.int8_quantization || params.broadcast) {
jit.Merge(GetTensorFriendlyWorkGroupsJit(params.outputs[0]));
}
if (!params.stride.empty()) {
jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
}
jit.Merge(MakeActivationJitConstants(params.activations, params.outputs[0].GetDType(), "_TYPED"));
// LEFTOVERS enables tail-lane zeroing in the kernel when the feature count is
// not a multiple of the vector size.
if (params.outputs[0].Feature().v % vec_size != 0)
jit.AddConstant(MakeJitConstant("LEFTOVERS", params.outputs[0].Feature().v % vec_size))
;
// Fused_ops
if (!params.fused_ops.empty()) {
kernel_selector::Datatype input_dt = GetAccumulatorType(params);
std::vector<std::string> idx_order;
if (DataTensor::ChannelsCount(params.outputs[0].GetLayout()) == 4) {
idx_order = {"b", "f_block * " + toCodeString(vec_size), "y", "x"};
} else if (DataTensor::ChannelsCount(params.outputs[0].GetLayout()) == 5) {
idx_order = {"b", "f_block * " + toCodeString(vec_size), "z", "y", "x"};
}
FusedOpsConfiguration conf = {"", idx_order, "res", input_dt, (size_t)vec_size};
conf.vec_axis = Tensor::DataChannelName::FEATURE;
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
}
jit.AddConstant(MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization));
jit.AddConstant(MakeJitConstant("VEC_SIZE", vec_size));
jit.AddConstant(MakeJitConstant("VSTORE_N", "vstore" + toCodeString(vec_size)));
// ELTWISE_BROADCAST switches GET_INDEX to the _SAFE variant, but only when no
// input is a pure scalar (scalars are splatted without indexing).
if (params.broadcast) {
bool need_idx_safe = true;
for (size_t i = 0; i < params.inputs.size(); i++) {
if (params.inputs[i].LogicalSize() == 1) {
need_idx_safe = false;
break;
}
}
if (need_idx_safe)
jit.AddConstant(MakeJitConstant("ELTWISE_BROADCAST", params.broadcast));
}
return jit;
}
// Computes global/local work sizes: gws[0] = feature blocks, gws[1] = packed
// spatial (zyx or yx), gws[2] = batch.
EltwiseKernelBase::DispatchData EltwiseKernel_blocked_opt::SetDefault(const eltwise_params& params) const {
    DispatchData dispatchData;
    auto in_layout = params.inputs[0].GetLayout();
    auto out_layout = params.outputs[0].GetLayout();
    std::vector<std::vector<Tensor::DataChannelName>> dims_by_gws = {{Tensor::DataChannelName::FEATURE},
                                                                     {Tensor::DataChannelName::X, Tensor::DataChannelName::Y},
                                                                     {Tensor::DataChannelName::BATCH}};

    // Global workgroup size 0: feature, 1: spatial, 2: batch
    dispatchData.gws[0] = CeilDiv(params.outputs[0].Feature().v, GetFeatureBlockSizeFromFormat(params, 0));
    dispatchData.gws[2] = params.outputs[0].Batch().v;
    if (DataTensor::ChannelsCount(params.outputs[0].GetLayout()) == 5)
        dispatchData.gws[1] = params.outputs[0].X().v * params.outputs[0].Y().v * params.outputs[0].Z().v;
    else if (DataTensor::ChannelsCount(params.outputs[0].GetLayout()) == 4)
        dispatchData.gws[1] = params.outputs[0].X().v * params.outputs[0].Y().v;
    else
        // A bare string literal in IE_ASSERT is always truthy and never fires;
        // force the failure explicitly while keeping the message.
        IE_ASSERT(false && "Unexpected dimension for eltwise_blocked_opt kernel.");

    // Calculate local workgroup size
    dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo, in_layout, out_layout, dims_by_gws);

    if (out_layout == DataLayout::b_fs_yx_fsv4) {
        dispatchData.lws[0] = 1;
        dispatchData.lws[2] = 1;
    }

    return dispatchData;
}
// Local
// Returns the feature vector size the kernel uses for the given input's layout:
// 4 for fsv4, 8 for every fsv16/fsv32 (single- or double-blocked) layout, and
// 1 for anything else — meaning the layout is not vectorizable by this kernel.
static inline int GetFeatureBlockSizeFromFormat(const eltwise_params& arg, size_t index) {
    const auto layout = arg.inputs[index].GetLayout();

    if (layout == DataLayout::b_fs_yx_fsv4)
        return 4;

    const bool vec8_layout =
        layout == DataLayout::b_fs_yx_fsv16 || layout == DataLayout::b_fs_yx_fsv32 ||
        layout == DataLayout::b_fs_zyx_fsv16 || layout == DataLayout::b_fs_zyx_fsv32 ||
        layout == DataLayout::bs_fs_yx_bsv32_fsv32 || layout == DataLayout::bs_fs_yx_bsv32_fsv16 ||
        layout == DataLayout::bs_fs_yx_bsv16_fsv32 || layout == DataLayout::bs_fs_yx_bsv16_fsv16 ||
        layout == DataLayout::bs_fs_zyx_bsv32_fsv32 || layout == DataLayout::bs_fs_zyx_bsv32_fsv16 ||
        layout == DataLayout::bs_fs_zyx_bsv16_fsv32 || layout == DataLayout::bs_fs_zyx_bsv16_fsv16;

    return vec8_layout ? 8 : 1;
}
// An input can be broadcast onto the output when it is a single scalar, or when
// it carries exactly one value per output feature (per-feature broadcast).
static inline bool IsBroadcastingPossibleInput(const DataTensor& input, const DataTensor& output) {
    const bool scalar_input = input.LogicalSize() == 1;
    const bool per_feature_input = input.LogicalSize() == output.Feature().v &&
                                   input.Feature().v == output.Feature().v;
    return scalar_input || per_feature_input;
}
// True when the op's input is a buffer whose feature dim is 1 while the output
// has multiple features (and the input is not a pure scalar).
static inline bool InputHasFeatureBroadcast(const eltwise_params& params, const size_t op_num, const size_t input_idx) {
    const auto& input = params.operations[op_num].inputs[input_idx];
    if (input.mode != EltwiseInputMode::INPUT_BUFFER)
        return false;

    const auto& tensor = params.inputs[input_idx];
    return tensor.LogicalSize() != 1 &&
           tensor.Feature().v == 1 &&
           params.outputs[0].Feature().v != 1;
}
} // namespace kernel_selector

View File

@ -0,0 +1,43 @@
// Copyright (C) 2018-2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "eltwise_kernel_base.h"
namespace kernel_selector {
// Vectorized eltwise kernel for feature-blocked (fsv4/fsv16/fsv32) and
// double-blocked (bsv+fsv) layouts; implementation in eltwise_blocked_opt.cl.
class EltwiseKernel_blocked_opt : public EltwiseKernelBase {
public:
    EltwiseKernel_blocked_opt() : EltwiseKernelBase("eltwise_blocked_opt") {}
    virtual ~EltwiseKernel_blocked_opt() = default;  // defaulted: no resources to release

    KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
    KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
    ParamsKey GetSupportedKey() const override;
    std::vector<FusedOpType> GetSupportedFusedOps() const override {
        return {
            FusedOpType::QUANTIZE,
            FusedOpType::ACTIVATION,
            FusedOpType::ELTWISE
        };
    }

protected:
    bool Validate(const Params& p, const optional_params& o) const override;
    JitConstants MakeLoadJitConstants(const eltwise_params& params, bool useVload8) const override;
    JitConstants GetJitConstants(const eltwise_params& params) const override;
    DispatchData SetDefault(const eltwise_params& params) const override;
    void PrintWorkSize(const DispatchData& dis);
};
} // namespace kernel_selector

View File

@ -9,6 +9,8 @@
#include "eltwise_kernel_b_fs_yx_fsv16.h"
#include "eltwise_kernel_mixed_byxf_and_fs_b_yx_fsv32.h"
#include "eltwise_kernel_b_fs_yx_fsv4.h"
#include "eltwise_kernel_blocked_opt.h"
namespace kernel_selector {
eltwise_kernel_selector::eltwise_kernel_selector() {
@ -17,6 +19,7 @@ eltwise_kernel_selector::eltwise_kernel_selector() {
Attach<EltwiseKernel_fs_b_yx_fsv32>();
Attach<EltwiseKernel_mixed_byxf_and_fs_b_yx_fsv32>();
Attach<EltwiseKernel_b_fs_yx_fsv16>();
Attach<EltwiseKernel_blocked_opt>();
Attach<EltwiseKernel_b_fs_yx_fsv4>();
}

View File

@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "intel_gpu/runtime/layout.hpp"
#include "test_utils.h"
#include <intel_gpu/primitives/input_layout.hpp>
@ -3932,6 +3933,14 @@ struct eltwise_layout_test_params {
#define CASE_ELTWISE_TEST9 eltwise_mode::eq, {4, 2, 4, 4}, {1, 1, 1, 1}, format::b_fs_yx_fsv16, format::bfyx, "generic_eltwise_ref"
class eltwise_layout_test : public BaseEltwiseTest<eltwise_layout_test_params> {
public:
// Human-readable summary of the two input formats for test failure messages.
std::string PrintToString(const eltwise_layout_test_params& params) {
    return " format1 (" + format::traits(params.input0_format).str + ")" +
           " format2 (" + format::traits(params.input1_format).str + ")";
}
};
class eltwise_test_mixed_layout : public eltwise_layout_test {};
@ -4027,6 +4036,15 @@ struct eltwise_random_test_params {
struct eltwise_random_test : testing::TestWithParam<eltwise_random_test_params>
{
// Human-readable summary of data type, format and both input sizes for
// test failure messages.
static std::string PrintToString(const eltwise_random_test_params& params) {
    return " data (" + cldnn::data_type_traits::name(params.input_type) + "), " +
           " format (" + format::traits(params.in_format).str + ") input1 : " +
           params.first_input_size.to_string() + " / input2 : " +
           params.second_input_size.to_string() + "\n";
}
template <typename T>
void fill_random_typed(memory::ptr mem, int min, int max, int k) {
auto l = mem->get_layout();
@ -4188,15 +4206,17 @@ struct eltwise_random_test_param_generator : std::vector<eltwise_random_test_par
// Parameter sets where one input must be broadcast onto the other:
// per-feature broadcasts in both directions, plus one scalar case.
eltwise_random_test_param_generator& broadcast_params(data_types type, format::type input_format, format::type output_format) {
push_back(eltwise_random_test_params{ type, {1, 1, 48, 64}, {1, 10, 48, 64}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 16, 48, 64}, {1, 1, 48, 64}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 5, 4, 4}, {1, 1, 4, 4}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 36, 48, 64}, {1, 1, 48, 64}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 36, 4, 4}, {1, 1, 4, 4}, input_format, input_format, output_format, eltwise_mode::sum, false });
// Scalar second input supplied in plain bfyx to exercise mixed-format support.
push_back(eltwise_random_test_params{ type, {1, 8, 4, 4}, {1, 1, 1, 1}, input_format, format::bfyx, output_format, eltwise_mode::sum, false });
return *this;
}
// Parameter sets with equally-shaped inputs (no broadcasting), covering
// feature counts both aligned and unaligned to the kernel's vector size.
eltwise_random_test_param_generator& simple_params(data_types type, format::type input_format, format::type output_format) {
push_back(eltwise_random_test_params{ type, {1, 10, 10, 10}, {1, 10, 10, 10}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 20, 10, 10}, {1, 20, 10, 10}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 5, 4, 4}, {1, 5, 4, 4}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 20, 16, 16}, {1, 20, 16, 16}, input_format, input_format, output_format, eltwise_mode::sum, false });
push_back(eltwise_random_test_params{ type, {1, 32, 16, 16}, {1, 32, 16, 16}, input_format, input_format, output_format, eltwise_mode::sum, false });
return *this;
}
};
@ -4227,3 +4247,45 @@ INSTANTIATE_TEST_SUITE_P(export_import,
.add(eltwise_random_test_params{ data_types::f16, {1, 1, 48, 64}, {1, 10, 48, 64}, format::b_fs_yx_fsv4,
format::b_fs_yx_fsv4, format::b_fs_yx_fsv4, eltwise_mode::sum, true })
));
INSTANTIATE_TEST_SUITE_P(eltwise_smoke_fsv16,
eltwise_random_test,
testing::ValuesIn(
eltwise_random_test_param_generator()
.broadcast_params(data_types::f32, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
.broadcast_params(data_types::f16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
.broadcast_params(data_types::i8, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
.broadcast_params(data_types::u8, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
.simple_params(data_types::f32, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
.simple_params(data_types::f16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
.simple_params(data_types::i8, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
.simple_params(data_types::u8, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16)
));
INSTANTIATE_TEST_SUITE_P(eltwise_smoke_fsv32,
eltwise_random_test,
testing::ValuesIn(
eltwise_random_test_param_generator()
.broadcast_params(data_types::f32, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
.broadcast_params(data_types::f16, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
.broadcast_params(data_types::i8, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
.broadcast_params(data_types::u8, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
.simple_params(data_types::f32, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
.simple_params(data_types::f16, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
.simple_params(data_types::i8, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
.simple_params(data_types::u8, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32)
));
INSTANTIATE_TEST_SUITE_P(eltwise_smoke_bsv_fsv,
eltwise_random_test,
testing::ValuesIn(
eltwise_random_test_param_generator()
.broadcast_params(data_types::f32, format::bs_fs_yx_bsv16_fsv16, format::bs_fs_yx_bsv16_fsv16)
.broadcast_params(data_types::f16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16)
.broadcast_params(data_types::i8, format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32)
.broadcast_params(data_types::u8, format::bs_fs_yx_bsv16_fsv32, format::bs_fs_yx_bsv16_fsv32)
.simple_params(data_types::f32, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16)
.simple_params(data_types::f16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16)
.simple_params(data_types::i8, format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32)
.simple_params(data_types::u8, format::bs_fs_yx_bsv16_fsv32, format::bs_fs_yx_bsv16_fsv32)
));