[IE CLDNN] Add b_fs_fsv16 concat optimizations (#1452)

1. Add fsv16 int8 support to optimized kernel
2. Optimize fsv16 concat kernel
3. Add graph optimization to improve concat alignment

Issue: CVS-28494
Konrad Dobros 2020-07-27 13:49:22 +02:00 committed by GitHub
parent 3632dde431
commit 0846f2050e
9 changed files with 581 additions and 22 deletions

View File

@ -18,12 +18,63 @@
#include "kernel_selector_utils.h" #include "kernel_selector_utils.h"
namespace kernel_selector { namespace kernel_selector {
namespace {
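// Selects how many spatial (X*Y) elements each work-item processes: starts from a
// data-type-dependent tile size and halves it until it evenly divides X (or X*Y when
// neither input nor output has spatial padding).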
size_t getTileXY(const concatenation_params& params) {
auto& input = params.inputs[0];
size_t tileXY = 1;
if (params.isAligned) {
switch (input.GetDType()) {
case Datatype::F16:
case Datatype::INT8:
case Datatype::UINT8:
tileXY = 4;
break;
default:
return 1;
}
} else {
switch (input.GetDType()) {
case Datatype::F32:
tileXY = 2;
break;
case Datatype::F16:
tileXY = 4;
break;
case Datatype::INT8:
case Datatype::UINT8:
tileXY = 8;
break;
default:
return 1;
}
}
auto tileXYMultiple = input.X().v;
bool noInputPad = input.X().pad.Total() == 0;
bool noOutputPad = params.output.X().pad.Total() == 0;
if (noInputPad && noOutputPad)
tileXYMultiple = input.X().v * input.Y().v;
while (tileXYMultiple % tileXY != 0)
tileXY /= 2;
return tileXY;
}
} // namespace
ParamsKey ConcatenationKernel_b_fs_yx_fsv16::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::INT8);
k.EnableInputDataType(Datatype::UINT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
k.EnableTensorOffset();
@ -60,10 +111,13 @@ bool ConcatenationKernel_b_fs_yx_fsv16::Validate(const Params& p, const optional
ConcatenationKernelBase::DispatchData ConcatenationKernel_b_fs_yx_fsv16::SetDefault(const concatenation_params& params) const {
DispatchData runInfo = ConcatenationKernelBase::SetDefault(params);
const auto& input = params.inputs[0];
auto tileXY = getTileXY(params);
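// Dispatch note: gws0 walks X*Y in tiles of TILE_XY, gws1 walks features in 16-wide
// slices (each sub-group covers tileF slices when the input is misaligned), gws2 walks batches.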
size_t tileF = params.misalignment == 0 ? 1 : 2;

runInfo.gws0 = CeilDiv(input.X().v * input.Y().v, tileXY);
runInfo.gws1 = Align(input.Feature().v, 16 * tileF) / tileF;
runInfo.gws2 = input.Batch().v;

runInfo.lws0 = 1;
runInfo.lws1 = 16;
@ -77,7 +131,9 @@ ConcatenationKernelBase::DispatchData ConcatenationKernel_b_fs_yx_fsv16::SetDefa
JitConstants ConcatenationKernel_b_fs_yx_fsv16::GetJitConstants(const concatenation_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);
jit.AddConstant(MakeJitConstant("ALIGNED", params.misalignment == 0));
jit.AddConstant(MakeJitConstant("MISALIGNMENT", params.misalignment));
jit.AddConstant(MakeJitConstant("TILE_XY", getTileXY(params)));
return jit;
}
@ -85,4 +141,8 @@ JitConstants ConcatenationKernel_b_fs_yx_fsv16::GetJitConstants(const concatenat
KernelsData ConcatenationKernel_b_fs_yx_fsv16::GetKernelsData(const Params& params, const optional_params& optParams) const {
return GetCommonKernelsData(params, optParams);
}
size_t ConcatenationKernel_b_fs_yx_fsv16::GetAlignment(const concatenation_params& /*params*/) const {
return 16;
}
} // namespace kernel_selector

View File

@ -28,5 +28,6 @@ public:
DispatchData SetDefault(const concatenation_params& params) const override;
JitConstants GetJitConstants(const concatenation_params& params) const override;
bool Validate(const Params& p, const optional_params& o) const override;
size_t GetAlignment(const concatenation_params& params) const override;
};
} // namespace kernel_selector

View File

@ -115,7 +115,8 @@ KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params,
newParams.inputs.resize(1);
newParams.inputs[0] = input;
size_t ifm = input.Feature().v;
newParams.isAligned = ifm_offset % GetAlignment(newParams) == 0;
newParams.misalignment = ifm_offset % GetAlignment(newParams);
ifm_offset += ifm;
auto& kernel = kd.kernels[i];
@ -127,7 +128,7 @@ KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params,
kernel.workGroups.global = {runInfo.gws0, runInfo.gws1, runInfo.gws2};
kernel.workGroups.local = {runInfo.lws0, runInfo.lws1, runInfo.lws2};
kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, (uint32_t)i });
kernel.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
ScalarDescriptor s;

View File

@ -26,6 +26,7 @@ struct concatenation_params : public base_params {
ConcatAxis axis = ConcatAxis::FEATURE;
bool isAligned = true;
size_t misalignment = 0;
virtual ParamsKey GetParamsKey() const {
auto k = base_params::GetParamsKey();
@ -71,5 +72,8 @@ protected:
KernelsData GetCommonKernelsData(const Params& params, const optional_params&) const;
int32_t GetConcatChannelIndex(const concatenation_params& params) const;
Tensor::DataChannelName GetConcatChannel(const concatenation_params& params) const;
virtual size_t GetAlignment(const concatenation_params& /*params*/) const {
return 1;
}
};
} // namespace kernel_selector

View File

@ -14,43 +14,145 @@
#include "include/fetch.cl" #include "include/fetch.cl"
#include "include/unit_type.cl" #include "include/data_types.cl"
#define WORK_GROUP_SIZE 16 #define WORK_GROUP_SIZE 16
#define IC_BLOCK 16 #define IC_BLOCK 16
#define INPUT_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, TILE_XY)
#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, TILE_XY)
#define TO_OUTPUT_VEC_TYPE(x) CAT(convert_, OUTPUT_VEC_TYPE)(x)
#define INPUT_BLOCK_READ(ptr, offset) MAKE_VECTOR_TYPE(DT_INPUT_BLOCK_READ, TILE_XY)(ptr, offset)
#define OUTPUT_BLOCK_WRITE(ptr, offset, val) MAKE_VECTOR_TYPE(DT_OUTPUT_BLOCK_WRITE, TILE_XY)(ptr, offset, val)
#if !ALIGNED
// For the non-aligned case, process two 16-feature slices together to mitigate the misalignment
# define TILE_F 2
#else
# define TILE_F 1
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
__attribute__((reqd_work_group_size(1, WORK_GROUP_SIZE, 1)))
__attribute__((intel_reqd_sub_group_size(WORK_GROUP_SIZE)))
KERNEL (concatenation_gpu_blocked)(
__global INPUT0_TYPE* input,
__global OUTPUT_TYPE* output,
uint output_offset_in_concat_axis)
{
const int xy = (uint)get_global_id(0) * TILE_XY;
const int f_block = (uint)get_group_id(1) * TILE_F;
const int b = get_group_id(2);
const int lid = get_sub_group_local_id();
const int x = xy % OUTPUT_SIZE_X;
const int y = xy / OUTPUT_SIZE_X;
const uint input_offset = INPUT0_GET_INDEX(b, f_block*IC_BLOCK, y, x);
#if ALIGNED
INPUT_VEC_TYPE src = INPUT_BLOCK_READ(input, input_offset);
const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + output_offset_in_concat_axis), y, x);
bool do_block_write = (INPUT0_FEATURE_NUM % IC_BLOCK == 0)
                      || (f_block * IC_BLOCK + IC_BLOCK <= INPUT0_FEATURE_NUM);
if (do_block_write) {
OUTPUT_VEC_TYPE res = TO_OUTPUT_VEC_TYPE(ACTIVATION(src, ACTIVATION_PARAMS));
OUTPUT_BLOCK_WRITE(output, dst_index, res);
} else {
if (lid < INPUT0_FEATURE_NUM % IC_BLOCK) {
__attribute__((opencl_unroll_hint))
for (uint tx = 0; tx < TILE_XY; ++tx) {
OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(((INPUT0_TYPE*)&src)[tx], ACTIVATION_PARAMS));
output[dst_index + tx * IC_BLOCK + lid] = res;
}
}
}
#else
#if TILE_F != 1
bool full_write = (INPUT0_FEATURE_NUM % (IC_BLOCK * TILE_F) == 0) || (f_block * IC_BLOCK + TILE_F * IC_BLOCK <= INPUT0_FEATURE_NUM);
if (full_write) {
INPUT_VEC_TYPE src0 = INPUT_BLOCK_READ(input, input_offset + 0 * INPUT0_FEATURE_PITCH * IC_BLOCK);
INPUT_VEC_TYPE src1 = INPUT_BLOCK_READ(input, input_offset + 1 * INPUT0_FEATURE_PITCH * IC_BLOCK);
#if TILE_F == 4
INPUT_VEC_TYPE src2 = INPUT_BLOCK_READ(input, input_offset + 2 * INPUT0_FEATURE_PITCH * IC_BLOCK);
INPUT_VEC_TYPE src3 = INPUT_BLOCK_READ(input, input_offset + 3 * INPUT0_FEATURE_PITCH * IC_BLOCK);
#endif
uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + (IC_BLOCK - MISALIGNMENT) + output_offset_in_concat_axis), y, x);

INPUT_VEC_TYPE src_al0 = 0;
#if TILE_F == 4
INPUT_VEC_TYPE src_al1 = 0;
INPUT_VEC_TYPE src_al2 = 0;
#endif
__attribute__((opencl_unroll_hint))
for (uint tx = 0; tx < TILE_XY; ++tx) {
((INPUT0_TYPE*)&src_al0)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src0)[tx], ((INPUT0_TYPE*)&src1)[tx], (IC_BLOCK - MISALIGNMENT));
#if TILE_F == 4
((INPUT0_TYPE*)&src_al1)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src1)[tx], ((INPUT0_TYPE*)&src2)[tx], (IC_BLOCK - MISALIGNMENT));
((INPUT0_TYPE*)&src_al2)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src2)[tx], ((INPUT0_TYPE*)&src3)[tx], (IC_BLOCK - MISALIGNMENT));
#endif
}
OUTPUT_VEC_TYPE res_al0 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al0, ACTIVATION_PARAMS));
OUTPUT_BLOCK_WRITE(output, dst_index, res_al0);
#if TILE_F == 4
OUTPUT_VEC_TYPE res_al1 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al1, ACTIVATION_PARAMS));
OUTPUT_BLOCK_WRITE(output, dst_index + 1 * OUTPUT_FEATURE_PITCH * IC_BLOCK, res_al1);
OUTPUT_VEC_TYPE res_al2 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al2, ACTIVATION_PARAMS));
OUTPUT_BLOCK_WRITE(output, dst_index + 2 * OUTPUT_FEATURE_PITCH * IC_BLOCK, res_al2);
#endif
uint lid_f_offset = lid;
INPUT_VEC_TYPE src_unal = 0;
lid_f_offset += lid < (IC_BLOCK - MISALIGNMENT) ? 0 : IC_BLOCK * (TILE_F - 1);
#if TILE_F == 2
src_unal = lid < (IC_BLOCK - MISALIGNMENT) ? src0 : src1;
#elif TILE_F == 4
src_unal = lid < (IC_BLOCK - MISALIGNMENT) ? src0 : src3;
#endif
dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid_f_offset + output_offset_in_concat_axis), y, x);
__attribute__((opencl_unroll_hint))
for (uint tx = 0; tx < TILE_XY; ++tx) {
OUTPUT_TYPE res_unal = TO_OUTPUT_TYPE(ACTIVATION(((INPUT0_TYPE*)&src_unal)[tx], ACTIVATION_PARAMS));
output[dst_index + tx * IC_BLOCK] = res_unal;
}
} else
#endif // TILE_F != 1
{
const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid + output_offset_in_concat_axis), y, x);
__attribute__((opencl_unroll_hint))
for (uint fw = 0; fw < TILE_F; ++fw) {
if (TILE_F != 1 && CEIL_DIV(INPUT0_FEATURE_NUM, IC_BLOCK) % TILE_F != 0 && CEIL_DIV(INPUT0_FEATURE_NUM, IC_BLOCK) % TILE_F == fw)
break;
bool do_leftover_write = INPUT0_FEATURE_NUM % IC_BLOCK == 0 || f_block * IC_BLOCK + fw * IC_BLOCK + lid < INPUT0_FEATURE_NUM;
if (do_leftover_write) {
__attribute__((opencl_unroll_hint))
for (uint tx = 0; tx < TILE_XY; ++tx) {
INPUT0_TYPE src = input[input_offset + lid + tx * IC_BLOCK + fw * INPUT0_FEATURE_PITCH * IC_BLOCK];
OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(src, ACTIVATION_PARAMS));
output[dst_index + tx * IC_BLOCK + fw * OUTPUT_FEATURE_PITCH * IC_BLOCK] = res;
}
}
}
}
#endif
}

#undef WORK_GROUP_SIZE
#undef IC_BLOCK
#undef INPUT_VEC_TYPE
#undef OUTPUT_VEC_TYPE
#undef TO_OUTPUT_VEC_TYPE
#undef INPUT_BLOCK_READ
#undef OUTPUT_BLOCK_WRITE
#undef TILE_F
#undef CEIL_DIV
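For reference, the realignment performed in the unaligned branch above can be pictured with a small host-side sketch (illustration only, not part of the patch): it mimics intel_sub_group_shuffle_down for a 16-lane sub-group and assumes a hypothetical misalignment of 13 features.

// Host-side illustration of the shuffle-based realignment used by the unaligned path.
#include <array>
#include <cstdio>

constexpr int kIcBlock = 16;  // features per fsv16 slice, one per sub-group lane

// Mirrors intel_sub_group_shuffle_down(cur, next, delta) for one scalar per lane.
std::array<int, kIcBlock> shuffle_down(const std::array<int, kIcBlock>& cur,
                                       const std::array<int, kIcBlock>& next,
                                       int delta) {
    std::array<int, kIcBlock> out{};
    for (int lane = 0; lane < kIcBlock; ++lane) {
        int src = lane + delta;
        out[lane] = src < kIcBlock ? cur[src] : next[src - kIcBlock];
    }
    return out;
}

int main() {
    const int misalignment = 13;  // hypothetical: previous inputs contributed 13 features
    std::array<int, kIcBlock> slice0{}, slice1{};
    for (int i = 0; i < kIcBlock; ++i) {
        slice0[i] = i;              // this input's features 0..15
        slice1[i] = kIcBlock + i;   // this input's features 16..31
    }
    // Features 0..2 still belong to the partially filled output block (handled separately);
    // shifting down by 16 - 13 = 3 makes features 3..18 fill the next full output block.
    auto aligned = shuffle_down(slice0, slice1, kIcBlock - misalignment);
    for (int lane = 0; lane < kIcBlock; ++lane)
        printf("aligned output lane %2d <- input feature %2d\n", lane, aligned[lane]);
    return 0;
}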

View File

@ -0,0 +1,224 @@
// Copyright (c) 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////////////////////////////////
#include "pass_manager.h"
#include "pooling_inst.h"
#include "convolution_inst.h"
#include "fully_connected_inst.h"
#include "data_inst.h"
#include "memory_impl.h"
#include "program_impl.h"
#include <vector>
#include <tuple>
using namespace cldnn;
namespace {
using shuffle_range = std::pair<int32_t, int32_t>;
bool can_shuffle_features(program_node& node) {
if (node.is_type<convolution>()) {
auto& conv_node = node.as<convolution>();
auto& wei_node = conv_node.weights();
return conv_node.get_groups() == 1 && conv_node.get_split() == 1 &&
conv_node.get_deformable_groups() == 1 && !conv_node.get_transposed() &&
!conv_node.activations_zero_points_term() &&
wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
}
if (node.is_type<fully_connected>()) {
auto& fc_node = node.as<fully_connected>();
auto& wei_node = fc_node.weights();
return wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
}
bool pass_through = false;
pass_through |= node.is_type<activation>();
pass_through |= node.is_type<pooling>();
// General conditions for pass-through layers
pass_through &= !node.is_output() && node.get_dependencies().size() == 1 && !node.has_fused_primitives();
if (pass_through) {
// Feature-order-invariant primitives pass shuffled features through to their users
for (auto& user : node.get_users()) {
if (!can_shuffle_features(*user))
return false;
}
return true;
}
return false;
}
void shuffle_weights(data_node& node, const std::vector<shuffle_range>& ranges) {
// Correct for shuffled features by shuffling the input-feature dimension of the weights.
// This restores the correct feature order on the output and only changes the calculation order.
auto wei_layout = node.get_output_layout();
auto& old_weights_memory = node.get_attached_memory();
bool need_reset = static_cast<bool>(wei_layout.data_padding) || wei_layout.format.is_blocked();
auto new_weights_memory = old_weights_memory.get_engine()->allocate_memory(wei_layout, old_weights_memory.get_net_id(), need_reset);
auto bytes_per_elem = data_type_traits::size_of(wei_layout.data_type);
auto old_ptr = static_cast<char*>(old_weights_memory.lock());
auto new_ptr = static_cast<char*>(new_weights_memory->lock());
for (int32_t ofi = 0; ofi < wei_layout.size.batch[0]; ++ofi) {
int32_t new_ifi = 0;
for (auto& range : ranges) {
for (int32_t ifi = range.first; ifi < range.second; ++ifi, ++new_ifi) {
for (int32_t wi = 0; wi < wei_layout.size.spatial[3]; ++wi) {
for (int32_t zi = 0; zi < wei_layout.size.spatial[2]; ++zi) {
for (int32_t yi = 0; yi < wei_layout.size.spatial[1]; ++yi) {
for (int32_t xi = 0; xi < wei_layout.size.spatial[0]; ++xi) {
auto old_coords = tensor(batch(ofi), feature(ifi), spatial(xi, yi, zi, wi));
auto new_coords = tensor(batch(ofi), feature(new_ifi), spatial(xi, yi, zi, wi));
auto old_offset = wei_layout.get_linear_offset(old_coords);
auto new_offset = wei_layout.get_linear_offset(new_coords);
for (size_t byte = 0; byte < bytes_per_elem; ++byte) {
new_ptr[new_offset * bytes_per_elem + byte] = old_ptr[old_offset * bytes_per_elem + byte];
}
}
}
}
}
}
}
}
old_weights_memory.unlock();
new_weights_memory->unlock();
node.attach_memory(*new_weights_memory, false);
}
void shuffle_features(program_node& node, const std::vector<shuffle_range>& ranges) {
if (node.is_type<convolution>()) {
auto& conv = node.as<convolution>();
shuffle_weights(conv.weights().as<data>(), ranges);
} else if (node.is_type<fully_connected>()) {
auto& fc = node.as<fully_connected>();
shuffle_weights(fc.weights().as<data>(), ranges);
} else {
// General case for pass-through layers
for (auto& user : node.get_users()) {
shuffle_features(*user, ranges);
}
}
}
} // namespace
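The effect of shuffle_weights on the input-feature dimension can be sketched with a standalone illustration (not part of the patch; the sizes and ranges below are hypothetical): each new input channel of the weights simply reads an old channel picked by walking the shuffled ranges in order.

// Illustration: builds the old-channel lookup that range-based weight shuffling implies.
#include <cstdio>
#include <utility>
#include <vector>

using shuffle_range = std::pair<int, int>;  // [first, second) over old input-feature indices

std::vector<int> input_feature_permutation(const std::vector<shuffle_range>& ranges) {
    std::vector<int> old_for_new;
    for (const auto& r : ranges)
        for (int ifi = r.first; ifi < r.second; ++ifi)
            old_for_new.push_back(ifi);  // next new channel reads old channel ifi
    return old_for_new;
}

int main() {
    // Concat inputs [13, 1024] reordered to [1024, 13]: the user layer now sees the
    // 1024 features first, so its weights read old channels 13..1036, then 0..12.
    auto perm = input_feature_permutation({{13, 13 + 1024}, {0, 13}});
    printf("new channel 0    reads old channel %d\n", perm[0]);     // 13
    printf("new channel 1023 reads old channel %d\n", perm[1023]);  // 1036
    printf("new channel 1024 reads old channel %d\n", perm[1024]);  // 0
    return 0;
}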
void concat_input_order::run(program_impl& p) {
for (auto node : p.get_processing_order()) {
// Check that optimization can be performed:
// 1. Not an output
// 2. Concatenation along features
// 3. Currently only fsv16 format on input/output
// 4. Not already aligned
// 5. Users can accept shuffled features
// 6. No fused primitives
if (!node->is_type<concatenation>() || node->is_output())
continue;
auto& concat_node = node->as<concatenation>();
auto prim = concat_node.get_primitive();
bool along_f = prim->axis == concatenation::along_f;
size_t inputs_count = prim->input_size();
bool no_fusing = !concat_node.has_fused_primitives() && concat_node.get_dependencies().size() == inputs_count;
auto out_format = concat_node.get_output_layout().format;
bool correct_format = out_format == format::b_fs_yx_fsv16;
tensor::value_type alignment = 1;
if (out_format == format::b_fs_yx_fsv16)
alignment = 16;
bool single_format = true;
std::vector<tensor::value_type> feature_sizes;
feature_sizes.reserve(inputs_count);
for (size_t input_idx = 0; input_idx < inputs_count; ++input_idx) {
auto& dep = concat_node.get_dependency(input_idx);
auto dep_layout = dep.get_output_layout();
single_format &= dep_layout.format == out_format;
feature_sizes.push_back(dep_layout.size.feature[0]);
}
// Alignment is not optimal if an aligned input follows an unaligned one
bool already_aligned = true;
for (size_t i = 1; i < feature_sizes.size(); ++i) {
bool current_aligned = feature_sizes[i] % alignment == 0;
bool previous_aligned = feature_sizes[i - 1] % alignment == 0;
already_aligned &= previous_aligned || !current_aligned;
}
// Check that we can fuse shuffling to users
bool can_shuffle_users = true;
for (auto user : concat_node.get_users()) {
can_shuffle_users &= can_shuffle_features(*user);
}
if (!along_f || !no_fusing || !correct_format || !single_format || already_aligned || !can_shuffle_users)
continue;
// Perform the optimization
// Calculate new input order - first inputs preserving alignment, then rest
std::vector<size_t> new_order;
new_order.reserve(inputs_count);
for (size_t i = 0; i < feature_sizes.size(); ++i) {
if (feature_sizes[i] % alignment == 0)
new_order.push_back(i);
}
for (size_t i = 0; i < feature_sizes.size(); ++i) {
if (feature_sizes[i] % alignment != 0)
new_order.push_back(i);
}
// Calculate new ranges
int32_t current_offset = 0;
std::vector<shuffle_range> original_ranges;
original_ranges.reserve(inputs_count);
for (auto& feature_size : feature_sizes) {
original_ranges.emplace_back(current_offset, current_offset + feature_size);
current_offset += feature_size;
}
std::vector<shuffle_range> shuffled_ranges;
shuffled_ranges.reserve(inputs_count);
for (auto& ord : new_order) {
shuffled_ranges.push_back(original_ranges[ord]);
}
// Change input order
std::vector<program_node*> new_dependencies = {};
new_dependencies.reserve(inputs_count);
for (auto& ord : new_order) {
new_dependencies.push_back(&concat_node.get_dependency(ord));
}
// Update in place with const cast instead of replacing
auto& dependencies = concat_node.get_dependencies();
auto& mutable_dependencies = const_cast<std::vector<program_node*>&>(dependencies);
for (size_t i = 0; i < new_dependencies.size(); ++i) {
mutable_dependencies[i] = new_dependencies[i];
}
std::vector<primitive_id> new_input_ids;
new_input_ids.reserve(inputs_count);
for (auto& ord : new_order) {
new_input_ids.push_back(prim->input[ord]);
}
auto mutable_prim = std::const_pointer_cast<concatenation>(prim);
mutable_prim->input = new_input_ids;
// Correct users for shuffled features
for (auto& user : concat_node.get_users()) {
shuffle_features(*user, shuffled_ranges);
}
}
}

View File

@ -332,6 +332,28 @@ public:
void run(program_impl& p) override;
};
class concat_input_order : public base_pass {
// This optimization changes the order of inputs for concatenation to provide
// better alignment for execution and to allow the concatenation to be optimized out in some cases.
// For example, a concatenation along features with inputs [13, 1024] in fsv16 format
// has only its first input aligned to feature blocks, blocking a performant implementation
// for the second one.
// This can be fixed by changing the order to [1024, 13] and fusing the reshuffling of those features
// into the following layers, such as convolution or fully connected, where it can be
// implemented as compile-time weights shuffling.
//
// Requirements - may work incorrectly if not fulfilled:
// - formats are selected
// - implementations aren't selected
//
// Soft requirements - reduce applicability if not fulfilled:
// - constant primitives are reduced to data nodes
// - no fused primitives
public:
concat_input_order() : base_pass("concat_input_order") {}
void run(program_impl& p) override;
};
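To make the example from the comment concrete, here is a tiny standalone check (illustration only, not part of the patch) of which input boundaries are 16-aligned before and after reordering.

// Illustration: a concat input is cheap to copy in fsv16 when its feature offset is a multiple of 16.
#include <cstdio>
#include <vector>

void report(const char* name, const std::vector<int>& feature_sizes) {
    int offset = 0;
    for (int f : feature_sizes) {
        printf("%s: input of %4d features starts at offset %4d (%saligned)\n",
               name, f, offset, offset % 16 == 0 ? "" : "mis");
        offset += f;
    }
}

int main() {
    report("original ", {13, 1024});   // second input starts at offset 13 -> misaligned
    report("reordered", {1024, 13});   // both inputs start at multiples of 16
    return 0;
}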
class memory_dependency_pass : public base_pass {
public:
explicit memory_dependency_pass(const std::string& pass_name) : base_pass(pass_name) {}

View File

@ -420,6 +420,10 @@ void program_impl::pre_optimize_graph(bool is_internal) {
apply_opt_pass<prepare_primitive_fusing>(lo);
apply_opt_pass<reorder_inputs>(lo, rf);
// Ideally this should be done before fusing to simplify logic and make the pass more powerful,
// but after format selection to select correct alignment.
// Unfortunately those passes currently happen in reverse order.
apply_opt_pass<concat_input_order>();
// TODO this code should be moved to post compilation after kernel selector will support handling reorder bias
apply_opt_pass<pre_optimize_bias>(rf); apply_opt_pass<pre_optimize_bias>(rf);

View File

@ -638,6 +638,10 @@ TEST_P(concat_gpu_4d_i8, b_fs_yx_fsv32) {
ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv32));
}
TEST_P(concat_gpu_4d_i8, b_fs_yx_fsv16) {
ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16));
}
INSTANTIATE_TEST_CASE_P(smoke_low_precision,
concat_gpu_4d_i8,
concat_gpu_all_params,
@ -651,3 +655,140 @@ INSTANTIATE_TEST_CASE_P(smoke_low_precision,
concat_gpu_4d_u8,
concat_gpu_all_params,
concat_gpu::PrintToStringParamName);
template <typename Type, typename OutputT>
struct concat_id_conv_gpu_4d : public concat_gpu {
public:
void test(format::type fmt) {
auto data_type = type_to_data_type<Type>::value;
const auto& engine = get_test_engine();
const size_t batch_num = testing::get<0>(GetParam());
const std::vector<size_t> in_features = testing::get<1>(GetParam());
const size_t input_y = testing::get<2>(GetParam());
const size_t input_x = testing::get<3>(GetParam());
size_t output_f = 0;
for (auto& f : in_features)
output_f += f;
topology topology;
std::vector<VVVVF<Type>> in_data;
std::vector<memory> in_memory;
std::vector<primitive_id> input_ids;
for (size_t i = 0; i < in_features.size(); i++) {
auto size = tensor(static_cast<int32_t>(batch_num),
static_cast<int32_t>(in_features[i]),
static_cast<int32_t>(input_x),
static_cast<int32_t>(input_y));
auto data = generate_random_4d<Type>(batch_num, in_features[i], input_y, input_x, -128, 128);
auto in_lay = layout(data_type, fmt, size);
auto data_flat = std::vector<Type>(in_lay.get_linear_size(), 0);
for (size_t bi = 0; bi < batch_num; ++bi) {
for (size_t fi = 0; fi < in_features[i]; ++fi) {
for (size_t yi = 0; yi < input_y; ++yi) {
for (size_t xi = 0; xi < input_x; ++xi) {
auto coords = tensor(batch(bi), feature(fi), spatial(xi, yi, 0, 0));
auto in_offset = in_lay.get_linear_offset(coords);
data_flat[in_offset] = data[bi][fi][yi][xi];
}
}
}
}
auto in_mem = memory::allocate(engine, in_lay);
set_values(in_mem, data_flat);
in_memory.push_back(in_mem);
topology.add(input_layout("input" + std::to_string(i), in_lay));
in_data.emplace_back(std::move(data));
input_ids.push_back("input" + std::to_string(i));
}
topology.add(concatenation("concat", input_ids, concatenation::concatenation_axis::along_f));
// Add identity convolution
auto weights_lay = cldnn::layout(data_type, cldnn::format::bfyx, tensor(batch(output_f), feature(output_f)));
auto weights_mem = cldnn::memory::allocate(engine, weights_lay);
{
auto weights_ptr = weights_mem.pointer<Type>();
for (size_t fi = 0; fi < output_f; ++fi) {
auto coords = tensor(batch(fi), feature(fi), spatial(0, 0, 0, 0));
auto offset = weights_lay.get_linear_offset(coords);
weights_ptr[offset] = static_cast<Type>(1.f);
}
}
topology.add(data("weights", weights_mem));
topology.add(convolution("conv", "concat", { "weights" }));
build_options options;
options.set_option(build_option::optimize_data(true));
auto conv_forcing = implementation_desc{ fmt, std::string() };
options.set_option(build_option::force_implementations({ {primitive_id("conv"), conv_forcing} }));
network network(engine, topology, options);
for (size_t i = 0; i < in_features.size(); i++) {
network.set_input_data(input_ids[i], in_memory[i]);
}
network.execute();
auto out_mem = network.get_output("conv").get_memory();
auto out_ptr = out_mem.pointer<OutputT>();
for (size_t bi = 0; bi < batch_num; bi++) {
size_t f_sum = 0;
for (size_t in_i = 0; in_i < in_features.size(); in_i++) {
for (size_t fi = 0; fi < in_features[in_i]; fi++) {
for (size_t yi = 0; yi < input_y; yi++) {
for (size_t xi = 0; xi < input_x; xi++) {
auto output_coords = tensor(batch(bi), feature(f_sum + fi), spatial(xi, yi, 0, 0));
auto output_offset = out_mem.get_layout().get_linear_offset(output_coords);
auto ref_val = in_data[in_i][bi][fi][yi][xi];
auto actual_val = static_cast<Type>(out_ptr[output_offset]);
EXPECT_EQ(ref_val, actual_val)
<< " b=" << bi << ", f=" << f_sum + fi << "(input " << in_i << "), y=" << yi << ", x=" << xi;
}
}
}
f_sum += in_features[in_i];
}
}
}
};
using concat_id_conv_gpu_4d_f16 = concat_id_conv_gpu_4d<FLOAT16, FLOAT16>;
using concat_id_conv_gpu_4d_i8 = concat_id_conv_gpu_4d<int8_t, float>;
TEST_P(concat_id_conv_gpu_4d_f16, input_order_opt_b_fs_yx_fsv16) {
ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16));
}
INSTANTIATE_TEST_CASE_P(smoke_low_precision,
concat_id_conv_gpu_4d_f16,
::testing::Values(
TestParamType_concat(2, { 2, 32 }, 2, 1),
TestParamType_concat(2, { 31, 64 }, 2, 2),
TestParamType_concat(2, { 15, 15, 16 }, 2, 1),
TestParamType_concat(2, { 16, 15, 16 }, 2, 2),
TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2)
),
concat_gpu::PrintToStringParamName);
TEST_P(concat_id_conv_gpu_4d_i8, input_order_opt_b_fs_yx_fsv16) {
ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16));
}
INSTANTIATE_TEST_CASE_P(smoke_low_precision,
concat_id_conv_gpu_4d_i8,
::testing::Values(
TestParamType_concat(2, { 2, 32 }, 2, 1),
TestParamType_concat(2, { 31, 64 }, 2, 2),
TestParamType_concat(2, { 15, 15, 16 }, 2, 1),
TestParamType_concat(2, { 16, 15, 16 }, 2, 2),
TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2)
),
concat_gpu::PrintToStringParamName);