diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.cpp index 8251a4252ea..57fc0500249 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.cpp @@ -18,12 +18,63 @@ #include "kernel_selector_utils.h" namespace kernel_selector { + +namespace { + +size_t getTileXY(const concatenation_params& params) { + auto& input = params.inputs[0]; + size_t tileXY = 1; + if (params.isAligned) { + switch (input.GetDType()) { + case Datatype::F16: + case Datatype::INT8: + case Datatype::UINT8: + tileXY = 4; + break; + default: + return 1; + } + } else { + switch (input.GetDType()) { + case Datatype::F32: + tileXY = 2; + break; + case Datatype::F16: + tileXY = 4; + break; + case Datatype::INT8: + case Datatype::UINT8: + tileXY = 8; + break; + default: + return 1; + } + } + + auto tileXYMultiple = input.X().v; + bool noInputPad = input.X().pad.Total() == 0; + bool noOutputPad = params.output.X().pad.Total() == 0; + if (noInputPad && noOutputPad) + tileXYMultiple = input.X().v * input.Y().v; + + while (tileXYMultiple % tileXY != 0) + tileXY /= 2; + + return tileXY; +} + +} // namespace + ParamsKey ConcatenationKernel_b_fs_yx_fsv16::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); + k.EnableOutputDataType(Datatype::UINT8); k.EnableInputLayout(DataLayout::b_fs_yx_fsv16); k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16); k.EnableTensorOffset(); @@ -60,10 +111,13 @@ bool ConcatenationKernel_b_fs_yx_fsv16::Validate(const Params& p, const optional ConcatenationKernelBase::DispatchData ConcatenationKernel_b_fs_yx_fsv16::SetDefault(const concatenation_params& params) const { DispatchData runInfo = ConcatenationKernelBase::SetDefault(params); const auto& input = params.inputs[0]; + auto tileXY = getTileXY(params); - runInfo.gws0 = input.Batch().v; - runInfo.gws1 = Align(input.Feature().v, 16); - runInfo.gws2 = input.X().v * input.Y().v; + size_t tileF = params.misalignment == 0 ? 
1 : 2; + + runInfo.gws0 = CeilDiv(input.X().v * input.Y().v, tileXY); + runInfo.gws1 = Align(input.Feature().v, 16 * tileF) / tileF; + runInfo.gws2 = input.Batch().v; runInfo.lws0 = 1; runInfo.lws1 = 16; @@ -77,7 +131,9 @@ ConcatenationKernelBase::DispatchData ConcatenationKernel_b_fs_yx_fsv16::SetDefa JitConstants ConcatenationKernel_b_fs_yx_fsv16::GetJitConstants(const concatenation_params& params) const { JitConstants jit = MakeBaseParamsJitConstants(params); - jit.AddConstant(MakeJitConstant("ALIGNED", params.isAligned)); + jit.AddConstant(MakeJitConstant("ALIGNED", params.misalignment == 0)); + jit.AddConstant(MakeJitConstant("MISALIGNMENT", params.misalignment)); + jit.AddConstant(MakeJitConstant("TILE_XY", getTileXY(params))); return jit; } @@ -85,4 +141,8 @@ JitConstants ConcatenationKernel_b_fs_yx_fsv16::GetJitConstants(const concatenat KernelsData ConcatenationKernel_b_fs_yx_fsv16::GetKernelsData(const Params& params, const optional_params& optParams) const { return GetCommonKernelsData(params, optParams); } + +size_t ConcatenationKernel_b_fs_yx_fsv16::GetAlignment(const concatenation_params& /*params*/) const { + return 16; +} } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.h index 9bf2af8652d..cf8e3f92c7c 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.h @@ -28,5 +28,6 @@ public: DispatchData SetDefault(const concatenation_params& params) const override; JitConstants GetJitConstants(const concatenation_params& params) const override; bool Validate(const Params& p, const optional_params& o) const override; + size_t GetAlignment(const concatenation_params& params) const override; }; } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.cpp index d76d5e1b4cf..0eb3fb20741 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.cpp @@ -115,7 +115,8 @@ KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params, newParams.inputs.resize(1); newParams.inputs[0] = input; size_t ifm = input.Feature().v; - newParams.isAligned = ifm_offset % 16 == 0 && ifm % 16 == 0; + newParams.isAligned = ifm_offset % GetAlignment(newParams) == 0; + newParams.misalignment = ifm_offset % GetAlignment(newParams); ifm_offset += ifm; auto& kernel = kd.kernels[i]; @@ -127,7 +128,7 @@ KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params, kernel.workGroups.global = {runInfo.gws0, runInfo.gws1, runInfo.gws2}; kernel.workGroups.local = {runInfo.lws0, runInfo.lws1, runInfo.lws2}; kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo); - kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, (uint32_t)i}); + kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, (uint32_t)i }); 
kernel.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); ScalarDescriptor s; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.h index 1239506bd85..645bc998e71 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.h @@ -26,6 +26,7 @@ struct concatenation_params : public base_params { ConcatAxis axis = ConcatAxis::FEATURE; bool isAligned = true; + size_t misalignment = 0; virtual ParamsKey GetParamsKey() const { auto k = base_params::GetParamsKey(); @@ -71,5 +72,8 @@ protected: KernelsData GetCommonKernelsData(const Params& params, const optional_params&) const; int32_t GetConcatChannelIndex(const concatenation_params& params) const; Tensor::DataChannelName GetConcatChannel(const concatenation_params& params) const; + virtual size_t GetAlignment(const concatenation_params& /*params*/) const { + return 1; + } }; } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/concatenation_gpu_blocked.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/concatenation_gpu_blocked.cl index 0182baab183..0fdbe4dd4de 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/concatenation_gpu_blocked.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/concatenation_gpu_blocked.cl @@ -14,43 +14,145 @@ #include "include/fetch.cl" -#include "include/unit_type.cl" +#include "include/data_types.cl" #define WORK_GROUP_SIZE 16 #define IC_BLOCK 16 +#define INPUT_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, TILE_XY) +#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, TILE_XY) +#define TO_OUTPUT_VEC_TYPE(x) CAT(convert_, OUTPUT_VEC_TYPE)(x) +#define INPUT_BLOCK_READ(ptr, offset) MAKE_VECTOR_TYPE(DT_INPUT_BLOCK_READ, TILE_XY)(ptr, offset) +#define OUTPUT_BLOCK_WRITE(ptr, offset, val) MAKE_VECTOR_TYPE(DT_OUTPUT_BLOCK_WRITE, TILE_XY)(ptr, offset, val) + +#if !ALIGNED +// For non-aligned case process two features together to mitigate misalignment +# define TILE_F 2 +#else +# define TILE_F 1 +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + __attribute__((reqd_work_group_size(1, WORK_GROUP_SIZE, 1))) __attribute__((intel_reqd_sub_group_size(WORK_GROUP_SIZE))) -KERNEL (concatenation_gpu_blocked)(__global UNIT_TYPE* input, __global UNIT_TYPE* output, uint output_offset_in_concat_axis) +KERNEL (concatenation_gpu_blocked)( + __global INPUT0_TYPE* input, + __global OUTPUT_TYPE* output, + uint output_offset_in_concat_axis) { - const int b = get_global_id(0); - const int f_block = get_group_id(1); - const int xy = get_global_id(2); + const int xy = (uint)get_global_id(0) * TILE_XY; + const int f_block = (uint)get_group_id(1) * TILE_F; + const int b = get_group_id(2); const int lid = get_sub_group_local_id(); const int x = xy % OUTPUT_SIZE_X; const int y = xy / OUTPUT_SIZE_X; + const uint input_offset = INPUT0_GET_INDEX(b, f_block*IC_BLOCK, y, x); #if ALIGNED - const uint input_offset = INPUT0_GET_INDEX(b, f_block*IC_BLOCK, y, x); + INPUT_VEC_TYPE src = INPUT_BLOCK_READ(input, input_offset); const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + output_offset_in_concat_axis), y, x); - UNIT_TYPE src = UNIT_BLOCK_READ(input, 
input_offset); - src = ACTIVATION(src, ACTIVATION_PARAMS); - UNIT_BLOCK_WRITE(output, dst_index, src); + bool do_block_write = (INPUT0_FEATURE_NUM % IC_BLOCK == 0) + || (f_block * IC_BLOCK + IC_BLOCK <= INPUT0_FEATURE_NUM); + + if (do_block_write) { + OUTPUT_VEC_TYPE res = TO_OUTPUT_VEC_TYPE(ACTIVATION(src, ACTIVATION_PARAMS)); + OUTPUT_BLOCK_WRITE(output, dst_index, res); + } else { + if (lid < INPUT0_FEATURE_NUM % IC_BLOCK) { + __attribute__((opencl_unroll_hint)) + for (uint tx = 0; tx < TILE_XY; ++tx) { + OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(((INPUT0_TYPE*)&src)[tx], ACTIVATION_PARAMS)); + output[dst_index + tx * IC_BLOCK + lid] = res; + } + } + } #else - if (f_block*IC_BLOCK + lid >= INPUT0_FEATURE_NUM) - return; - const uint input_offset = INPUT0_GET_INDEX(b, f_block*IC_BLOCK + lid, y, x); - const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid + output_offset_in_concat_axis), y, x); +#if TILE_F != 1 + bool full_write = (INPUT0_FEATURE_NUM % (IC_BLOCK * TILE_F) == 0) || (f_block * IC_BLOCK + TILE_F * IC_BLOCK <= INPUT0_FEATURE_NUM); + if (full_write) { + INPUT_VEC_TYPE src0 = INPUT_BLOCK_READ(input, input_offset + 0 * INPUT0_FEATURE_PITCH * IC_BLOCK); + INPUT_VEC_TYPE src1 = INPUT_BLOCK_READ(input, input_offset + 1 * INPUT0_FEATURE_PITCH * IC_BLOCK); + #if TILE_F == 4 + INPUT_VEC_TYPE src2 = INPUT_BLOCK_READ(input, input_offset + 2 * INPUT0_FEATURE_PITCH * IC_BLOCK); + INPUT_VEC_TYPE src3 = INPUT_BLOCK_READ(input, input_offset + 3 * INPUT0_FEATURE_PITCH * IC_BLOCK); + #endif - UNIT_TYPE src = input[input_offset]; - src = ACTIVATION(src, ACTIVATION_PARAMS); - output[dst_index] = src; + uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + (IC_BLOCK - MISALIGNMENT) + output_offset_in_concat_axis), y, x); + + INPUT_VEC_TYPE src_al0 = 0; + #if TILE_F == 4 + INPUT_VEC_TYPE src_al1 = 0; + INPUT_VEC_TYPE src_al2 = 0; + #endif + __attribute__((opencl_unroll_hint)) + for (uint tx = 0; tx < TILE_XY; ++tx) { + ((INPUT0_TYPE*)&src_al0)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src0)[tx], ((INPUT0_TYPE*)&src1)[tx], (IC_BLOCK - MISALIGNMENT)); + #if TILE_F == 4 + ((INPUT0_TYPE*)&src_al1)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src1)[tx], ((INPUT0_TYPE*)&src2)[tx], (IC_BLOCK - MISALIGNMENT)); + ((INPUT0_TYPE*)&src_al2)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src2)[tx], ((INPUT0_TYPE*)&src3)[tx], (IC_BLOCK - MISALIGNMENT)); + #endif + } + OUTPUT_VEC_TYPE res_al0 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al0, ACTIVATION_PARAMS)); + OUTPUT_BLOCK_WRITE(output, dst_index, res_al0); + #if TILE_F == 4 + OUTPUT_VEC_TYPE res_al1 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al1, ACTIVATION_PARAMS)); + OUTPUT_BLOCK_WRITE(output, dst_index + 1 * OUTPUT_FEATURE_PITCH * IC_BLOCK, res_al1); + OUTPUT_VEC_TYPE res_al2 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al2, ACTIVATION_PARAMS)); + OUTPUT_BLOCK_WRITE(output, dst_index + 2 * OUTPUT_FEATURE_PITCH * IC_BLOCK, res_al2); + #endif + uint lid_f_offset = lid; + INPUT_VEC_TYPE src_unal = 0; + + lid_f_offset += lid < (IC_BLOCK - MISALIGNMENT) ? 0 : IC_BLOCK * (TILE_F - 1); + #if TILE_F == 2 + src_unal = lid < (IC_BLOCK - MISALIGNMENT) ? src0 : src1; + #elif TILE_F == 4 + src_unal = lid < (IC_BLOCK - MISALIGNMENT) ? 
src0 : src3; + #endif + + dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid_f_offset + output_offset_in_concat_axis), y, x); + __attribute__((opencl_unroll_hint)) + for (uint tx = 0; tx < TILE_XY; ++tx) { + OUTPUT_TYPE res_unal = TO_OUTPUT_TYPE(ACTIVATION(((INPUT0_TYPE*)&src_unal)[tx], ACTIVATION_PARAMS)); + output[dst_index + tx * IC_BLOCK] = res_unal; + } + } else +#endif // TILE_F != 1 + { + const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid + output_offset_in_concat_axis), y, x); + + __attribute__((opencl_unroll_hint)) + for (uint fw = 0; fw < TILE_F; ++fw) { + if (TILE_F != 1 && CEIL_DIV(INPUT0_FEATURE_NUM, IC_BLOCK) % TILE_F != 0 && CEIL_DIV(INPUT0_FEATURE_NUM, IC_BLOCK) % TILE_F == fw) + break; + + bool do_leftover_write = INPUT0_FEATURE_NUM % IC_BLOCK == 0 || f_block * IC_BLOCK + fw * IC_BLOCK + lid < INPUT0_FEATURE_NUM; + if (do_leftover_write) { + __attribute__((opencl_unroll_hint)) + for (uint tx = 0; tx < TILE_XY; ++tx) { + INPUT0_TYPE src = input[input_offset + lid + tx * IC_BLOCK + fw * INPUT0_FEATURE_PITCH * IC_BLOCK]; + OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(src, ACTIVATION_PARAMS)); + output[dst_index + tx * IC_BLOCK + fw * OUTPUT_FEATURE_PITCH * IC_BLOCK] = res; + } + } + } + } #endif } #undef WORK_GROUP_SIZE #undef IC_BLOCK + +#undef INPUT_VEC_TYPE +#undef OUTPUT_VEC_TYPE +#undef TO_OUTPUT_VEC_TYPE +#undef INPUT_BLOCK_READ +#undef OUTPUT_BLOCK_WRITE + +#undef TILE_F +#undef CEIL_DIV diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/concat_input_order.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/concat_input_order.cpp new file mode 100644 index 00000000000..85594322f72 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/concat_input_order.cpp @@ -0,0 +1,224 @@ +// Copyright (c) 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "pass_manager.h"
+#include "pooling_inst.h"
+#include "convolution_inst.h"
+#include "fully_connected_inst.h"
+#include "data_inst.h"
+#include "memory_impl.h"
+#include "program_impl.h"
+
+#include <utility>
+#include <vector>
+
+using namespace cldnn;
+
+namespace {
+
+using shuffle_range = std::pair<int32_t, int32_t>;
+
+bool can_shuffle_features(program_node& node) {
+    if (node.is_type<convolution>()) {
+        auto& conv_node = node.as<convolution>();
+        auto& wei_node = conv_node.weights();
+
+        return conv_node.get_groups() == 1 && conv_node.get_split() == 1 &&
+               conv_node.get_deformable_groups() == 1 && !conv_node.get_transposed() &&
+               !conv_node.activations_zero_points_term() &&
+               wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
+    }
+    if (node.is_type<fully_connected>()) {
+        auto& fc_node = node.as<fully_connected>();
+        auto& wei_node = fc_node.weights();
+
+        return wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
+    }
+
+    bool pass_through = false;
+    pass_through |= node.is_type<pooling>();
+    pass_through |= node.is_type();
+    // General conditions for pass-through layers
+    pass_through &= !node.is_output() && node.get_dependencies().size() == 1 && !node.has_fused_primitives();
+    if (pass_through) {
+        // Primitives that are feature-order invariant pass shuffled features through to their users
+        for (auto& user : node.get_users()) {
+            if (!can_shuffle_features(*user))
+                return false;
+        }
+        return true;
+    }
+
+    return false;
+}
+
+void shuffle_weights(data_node& node, const std::vector<shuffle_range>& ranges) {
+    // Correct for shuffled features by shuffling the input feature dimension in the weights.
+    // This restores the correct feature order on the output and only changes the calculation order.
+    auto wei_layout = node.get_output_layout();
+    auto& old_weights_memory = node.get_attached_memory();
+    bool need_reset = static_cast<bool>(wei_layout.data_padding) || wei_layout.format.is_blocked();
+    auto new_weights_memory = old_weights_memory.get_engine()->allocate_memory(wei_layout, old_weights_memory.get_net_id(), need_reset);
+
+    auto bytes_per_elem = data_type_traits::size_of(wei_layout.data_type);
+    auto old_ptr = static_cast<uint8_t*>(old_weights_memory.lock());
+    auto new_ptr = static_cast<uint8_t*>(new_weights_memory->lock());
+    for (int32_t ofi = 0; ofi < wei_layout.size.batch[0]; ++ofi) {
+        int32_t new_ifi = 0;
+        for (auto& range : ranges) {
+            for (int32_t ifi = range.first; ifi < range.second; ++ifi, ++new_ifi) {
+                for (int32_t wi = 0; wi < wei_layout.size.spatial[3]; ++wi) {
+                    for (int32_t zi = 0; zi < wei_layout.size.spatial[2]; ++zi) {
+                        for (int32_t yi = 0; yi < wei_layout.size.spatial[1]; ++yi) {
+                            for (int32_t xi = 0; xi < wei_layout.size.spatial[0]; ++xi) {
+                                auto old_coords = tensor(batch(ofi), feature(ifi), spatial(xi, yi, zi, wi));
+                                auto new_coords = tensor(batch(ofi), feature(new_ifi), spatial(xi, yi, zi, wi));
+                                auto old_offset = wei_layout.get_linear_offset(old_coords);
+                                auto new_offset = wei_layout.get_linear_offset(new_coords);
+                                for (size_t byte = 0; byte < bytes_per_elem; ++byte) {
+                                    new_ptr[new_offset * bytes_per_elem + byte] = old_ptr[old_offset * bytes_per_elem + byte];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+    old_weights_memory.unlock();
+    new_weights_memory->unlock();
+
+    node.attach_memory(*new_weights_memory, false);
+}
+
+void shuffle_features(program_node& node, const std::vector<shuffle_range>& ranges) {
+    if (node.is_type<convolution>()) {
+        auto& conv = node.as<convolution>();
+        shuffle_weights(conv.weights().as<data>(), ranges);
+    } else if (node.is_type<fully_connected>()) {
+        auto& fc = node.as<fully_connected>();
+        shuffle_weights(fc.weights().as<data>(), ranges);
+    } else {
+        // General case for pass-through layers
+        for (auto& user : node.get_users()) {
+            shuffle_features(*user, ranges);
+        }
+    }
+}
+
+}  // namespace
+
+void concat_input_order::run(program_impl& p) {
+    for (auto node : p.get_processing_order()) {
+        // Check that the optimization can be performed:
+        // 1. Not an output
+        // 2. Concatenation along features
+        // 3. Currently only fsv16 format on input/output
+        // 4. Not already aligned
+        // 5. Users can accept shuffled features
+        // 6. No fused primitives
+        if (!node->is_type<concatenation>() || node->is_output())
+            continue;
+
+        auto& concat_node = node->as<concatenation>();
+        auto prim = concat_node.get_primitive();
+
+        bool along_f = prim->axis == concatenation::along_f;
+        size_t inputs_count = prim->input_size();
+        bool no_fusing = !concat_node.has_fused_primitives() && concat_node.get_dependencies().size() == inputs_count;
+
+        auto out_format = concat_node.get_output_layout().format;
+        bool correct_format = out_format == format::b_fs_yx_fsv16;
+        tensor::value_type alignment = 1;
+        if (out_format == format::b_fs_yx_fsv16)
+            alignment = 16;
+
+        bool single_format = true;
+        std::vector<tensor::value_type> feature_sizes;
+        feature_sizes.reserve(inputs_count);
+        for (size_t input_idx = 0; input_idx < inputs_count; ++input_idx) {
+            auto& dep = concat_node.get_dependency(input_idx);
+            auto dep_layout = dep.get_output_layout();
+            single_format &= dep_layout.format == out_format;
+            feature_sizes.push_back(dep_layout.size.feature[0]);
+        }
+        // Alignment is not optimal if an aligned input follows an unaligned one
+        bool already_aligned = true;
+        for (size_t i = 1; i < feature_sizes.size(); ++i) {
+            bool current_aligned = feature_sizes[i] % alignment == 0;
+            bool previous_aligned = feature_sizes[i - 1] % alignment == 0;
+            already_aligned &= previous_aligned || !current_aligned;
+        }
+        // Check that we can fuse the shuffling into the users
+        bool can_shuffle_users = true;
+        for (auto user : concat_node.get_users()) {
+            can_shuffle_users &= can_shuffle_features(*user);
+        }
+
+        if (!along_f || !no_fusing || !correct_format || !single_format || already_aligned || !can_shuffle_users)
+            continue;
+
+        // Perform the optimization
+        // Calculate the new input order - first the inputs preserving alignment, then the rest
+        std::vector<size_t> new_order;
+        new_order.reserve(inputs_count);
+        for (size_t i = 0; i < feature_sizes.size(); ++i) {
+            if (feature_sizes[i] % alignment == 0)
+                new_order.push_back(i);
+        }
+        for (size_t i = 0; i < feature_sizes.size(); ++i) {
+            if (feature_sizes[i] % alignment != 0)
+                new_order.push_back(i);
+        }
+        // Calculate the new ranges
+        int32_t current_offset = 0;
+        std::vector<shuffle_range> original_ranges;
+        original_ranges.reserve(inputs_count);
+        for (auto& feature_size : feature_sizes) {
+            original_ranges.emplace_back(current_offset, current_offset + feature_size);
+            current_offset += feature_size;
+        }
+        std::vector<shuffle_range> shuffled_ranges;
+        shuffled_ranges.reserve(inputs_count);
+        for (auto& ord : new_order) {
+            shuffled_ranges.push_back(original_ranges[ord]);
+        }
+        // Change the input order
+        std::vector<program_node*> new_dependencies = {};
+        new_dependencies.reserve(inputs_count);
+        for (auto& ord : new_order) {
+            new_dependencies.push_back(&concat_node.get_dependency(ord));
+        }
+        // Update in place with const_cast instead of replacing
+        auto& dependencies = concat_node.get_dependencies();
+        auto& mutable_dependencies = const_cast<std::vector<program_node*>&>(dependencies);
+        for (size_t i = 0; i < new_dependencies.size(); ++i) {
+            mutable_dependencies[i] = new_dependencies[i];
+        }
+        std::vector<primitive_id> new_input_ids;
+        new_input_ids.reserve(inputs_count);
+        for (auto& ord : new_order) {
+            new_input_ids.push_back(prim->input[ord]);
+        }
+        auto mutable_prim = std::const_pointer_cast<concatenation>(prim);
+        mutable_prim->input = new_input_ids;
+        // Correct the users for the shuffled features
+        for (auto& user : concat_node.get_users()) {
+            shuffle_features(*user, shuffled_ranges);
+        }
+    }
+}
+
diff --git a/inference-engine/thirdparty/clDNN/src/include/pass_manager.h b/inference-engine/thirdparty/clDNN/src/include/pass_manager.h
index 8bd7cfedcd2..bc620bf43da 100644
--- a/inference-engine/thirdparty/clDNN/src/include/pass_manager.h
+++ b/inference-engine/thirdparty/clDNN/src/include/pass_manager.h
@@ -332,6 +332,28 @@ public:
     void run(program_impl& p) override;
 };
 
+class concat_input_order : public base_pass {
+    // This optimization changes the order of inputs to a concatenation to provide
+    // better alignment for execution and to allow the concatenation to be optimized out in some cases.
+    // For example, a concatenation along features with inputs [13, 1024] in format fsv16
+    // has only the first input aligned to feature blocks, blocking a performant implementation
+    // for the second one.
+    // This can be fixed by changing the order to [1024, 13] and fusing the reshuffling of those features
+    // into the following layers, such as convolution or fully connected, where it can be
+    // implemented as compile-time weights shuffling.
+    //
+    // Requirements - may work incorrectly if not fulfilled:
+    // - formats are selected
+    // - implementations aren't selected
+    //
+    // Soft requirements - reduce applicability if not fulfilled:
+    // - constant primitives are reduced to data nodes
+    // - no fused primitives
+public:
+    concat_input_order() : base_pass("concat_input_order") {}
+    void run(program_impl& p) override;
+};
+
 class memory_dependency_pass : public base_pass {
 public:
     explicit memory_dependency_pass(const std::string& pass_name) : base_pass(pass_name) {}
diff --git a/inference-engine/thirdparty/clDNN/src/program.cpp b/inference-engine/thirdparty/clDNN/src/program.cpp
index 2a1733c7781..398eeaa96d4 100644
--- a/inference-engine/thirdparty/clDNN/src/program.cpp
+++ b/inference-engine/thirdparty/clDNN/src/program.cpp
@@ -420,6 +420,10 @@ void program_impl::pre_optimize_graph(bool is_internal) {
     apply_opt_pass(lo);
 
     apply_opt_pass(lo, rf);
 
+    // Ideally this should be done before fusing to simplify the logic and make the pass more powerful,
+    // but after format selection to select the correct alignment.
+    // Unfortunately those passes currently happen in reverse order.
+ apply_opt_pass(); // TODO this code should be moved to post compilation after kernel selector will support handling reorder bias apply_opt_pass(rf); diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp index 535cc85e481..48f4dcad929 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp @@ -638,6 +638,10 @@ TEST_P(concat_gpu_4d_i8, b_fs_yx_fsv32) { ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv32)); } +TEST_P(concat_gpu_4d_i8, b_fs_yx_fsv16) { + ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16)); +} + INSTANTIATE_TEST_CASE_P(smoke_low_precision, concat_gpu_4d_i8, concat_gpu_all_params, @@ -651,3 +655,140 @@ INSTANTIATE_TEST_CASE_P(smoke_low_precision, concat_gpu_4d_u8, concat_gpu_all_params, concat_gpu::PrintToStringParamName); + +template +struct concat_id_conv_gpu_4d : public concat_gpu { +public: + + void test(format::type fmt) { + auto data_type = type_to_data_type::value; + + const auto& engine = get_test_engine(); + const size_t batch_num = testing::get<0>(GetParam()); + const std::vector in_features = testing::get<1>(GetParam()); + const size_t input_y = testing::get<2>(GetParam()); + const size_t input_x = testing::get<3>(GetParam()); + size_t output_f = 0; + for (auto& f : in_features) + output_f += f; + + topology topology; + + std::vector> in_data; + std::vector in_memory; + std::vector input_ids; + for (size_t i = 0; i < in_features.size(); i++) { + auto size = tensor(static_cast(batch_num), + static_cast(in_features[i]), + static_cast(input_x), + static_cast(input_y)); + auto data = generate_random_4d(batch_num, in_features[i], input_y, input_x, -128, 128); + auto in_lay = layout(data_type, fmt, size); + auto data_flat = std::vector(in_lay.get_linear_size(), 0); + + for (size_t bi = 0; bi < batch_num; ++bi) { + for (size_t fi = 0; fi < in_features[i]; ++fi) { + for (size_t yi = 0; yi < input_y; ++yi) { + for (size_t xi = 0; xi < input_x; ++xi) { + auto coords = tensor(batch(bi), feature(fi), spatial(xi, yi, 0, 0)); + auto in_offset = in_lay.get_linear_offset(coords); + + data_flat[in_offset] = data[bi][fi][yi][xi]; + } + } + } + } + + auto in_mem = memory::allocate(engine, in_lay); + set_values(in_mem, data_flat); + in_memory.push_back(in_mem); + + topology.add(input_layout("input" + std::to_string(i), in_lay)); + in_data.emplace_back(std::move(data)); + input_ids.push_back("input" + std::to_string(i)); + } + + topology.add(concatenation("concat", input_ids, concatenation::concatenation_axis::along_f)); + // Add identity convolution + auto weights_lay = cldnn::layout(data_type, cldnn::format::bfyx, tensor(batch(output_f), feature(output_f))); + auto weights_mem = cldnn::memory::allocate(engine, weights_lay); + { + auto weights_ptr = weights_mem.pointer(); + for (size_t fi = 0; fi < output_f; ++fi) { + auto coords = tensor(batch(fi), feature(fi), spatial(0, 0, 0, 0)); + auto offset = weights_lay.get_linear_offset(coords); + weights_ptr[offset] = static_cast(1.f); + } + } + topology.add(data("weights", weights_mem)); + topology.add(convolution("conv", "concat", { "weights" })); + + build_options options; + options.set_option(build_option::optimize_data(true)); + auto conv_forcing = implementation_desc{ fmt, std::string() }; + options.set_option(build_option::force_implementations({ {primitive_id("conv"), conv_forcing} })); + network 
network(engine, topology, options); + + for (size_t i = 0; i < in_features.size(); i++) { + network.set_input_data(input_ids[i], in_memory[i]); + } + + network.execute(); + + auto out_mem = network.get_output("conv").get_memory(); + auto out_ptr = out_mem.pointer(); + + for (size_t bi = 0; bi < batch_num; bi++) { + size_t f_sum = 0; + for (size_t in_i = 0; in_i < in_features.size(); in_i++) { + for (size_t fi = 0; fi < in_features[in_i]; fi++) { + for (size_t yi = 0; yi < input_y; yi++) { + for (size_t xi = 0; xi < input_x; xi++) { + auto output_coords = tensor(batch(bi), feature(f_sum + fi), spatial(xi, yi, 0, 0)); + auto output_offset = out_mem.get_layout().get_linear_offset(output_coords); + + auto ref_val = in_data[in_i][bi][fi][yi][xi]; + auto actual_val = static_cast(out_ptr[output_offset]); + EXPECT_EQ(ref_val, actual_val) + << " b=" << bi << ", f=" << f_sum + fi << "(input " << in_i << "), y=" << yi << ", x=" << xi; + } + } + } + f_sum += in_features[in_i]; + } + } + } +}; + +using concat_id_conv_gpu_4d_f16 = concat_id_conv_gpu_4d; +using concat_id_conv_gpu_4d_i8 = concat_id_conv_gpu_4d; + +TEST_P(concat_id_conv_gpu_4d_f16, input_order_opt_b_fs_yx_fsv16) { + ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16)); +} + +INSTANTIATE_TEST_CASE_P(smoke_low_precision, + concat_id_conv_gpu_4d_f16, + ::testing::Values( + TestParamType_concat(2, { 2, 32 }, 2, 1), + TestParamType_concat(2, { 31, 64 }, 2, 2), + TestParamType_concat(2, { 15, 15, 16 }, 2, 1), + TestParamType_concat(2, { 16, 15, 16 }, 2, 2), + TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2) + ), + concat_gpu::PrintToStringParamName); + +TEST_P(concat_id_conv_gpu_4d_i8, input_order_opt_b_fs_yx_fsv16) { + ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16)); +} + +INSTANTIATE_TEST_CASE_P(smoke_low_precision, + concat_id_conv_gpu_4d_i8, + ::testing::Values( + TestParamType_concat(2, { 2, 32 }, 2, 1), + TestParamType_concat(2, { 31, 64 }, 2, 2), + TestParamType_concat(2, { 15, 15, 16 }, 2, 1), + TestParamType_concat(2, { 16, 15, 16 }, 2, 2), + TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2) + ), + concat_gpu::PrintToStringParamName);
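
Note on the dispatch math in concatenation_kernel_b_fs_yx_fsv16.cpp: the sketch below is illustrative only (the helper name pick_tile_xy and the concrete shape are made up, not part of the patch). It mirrors the tile-selection loop in getTileXY() and the work-size computation in SetDefault() for an fp16 input with X=7, Y=4, F=35, batch 1, no padding and no misalignment.

#include <cstddef>
#include <cstdio>

// Standalone sketch of the tile-size fallback in getTileXY(): the preferred tile
// (4 for fp16) is halved until it divides the flattened spatial size.
static size_t pick_tile_xy(size_t x, size_t y, bool padded, size_t preferred) {
    size_t tile = preferred;
    size_t multiple = padded ? x : x * y;   // with padding only X has to be divisible
    while (multiple % tile != 0)
        tile /= 2;
    return tile;
}

int main() {
    size_t x = 7, y = 4, f = 35, b = 1;
    size_t tile_xy = pick_tile_xy(x, y, /*padded=*/false, /*preferred=*/4);      // 28 % 4 == 0 -> 4
    size_t tile_f = 1;                                                           // misalignment == 0
    size_t gws0 = (x * y + tile_xy - 1) / tile_xy;                               // CeilDiv(X*Y, tileXY) = 7
    size_t gws1 = ((f + 16 * tile_f - 1) / (16 * tile_f)) * 16;                  // Align(F, 16*tileF)/tileF = 48
    size_t gws2 = b;                                                             // one work item per batch
    std::printf("tile_xy=%zu gws=[%zu,%zu,%zu]\n", tile_xy, gws0, gws1, gws2);
}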
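Note on the misaligned branch of concatenation_gpu_blocked.cl: it reads two 16-feature blocks and uses intel_sub_group_shuffle_down to stitch them into one block-aligned write. The host-side emulation below uses made-up values and assumes the standard cl_intel_subgroups semantics (lane i takes the value from lane i+delta, falling over into the second operand past the sub-group size).

#include <array>
#include <cstdio>

constexpr int IC_BLOCK = 16;

int main() {
    const int misalignment = 13;               // previous inputs contributed 13 features
    const int shift = IC_BLOCK - misalignment; // 3
    std::array<int, IC_BLOCK> block0{}, block1{};
    for (int lid = 0; lid < IC_BLOCK; ++lid) {
        block0[lid] = lid;                     // features 0..15 of this input
        block1[lid] = IC_BLOCK + lid;          // features 16..31
    }
    for (int lid = 0; lid < IC_BLOCK; ++lid) {
        // shuffle_down(cur, next, shift): lanes move toward lane 0, upper lanes fill from `next`
        int aligned = (lid + shift < IC_BLOCK) ? block0[lid + shift] : block1[lid + shift - IC_BLOCK];
        std::printf("%d ", aligned);           // prints features 3..18, whose destination is 16-aligned
    }
    std::printf("\n");                         // features 0..2 and the tail go through the scalar path
}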
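Note on concat_input_order: the ordering rule can be shown in isolation. This is a simplified sketch, not the pass itself (it ignores graph nodes and works only on feature counts): inputs whose feature count is a multiple of the block size go first, and the shuffle ranges handed to shuffle_features() are the original feature intervals listed in the new order.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using shuffle_range = std::pair<int32_t, int32_t>;

// For feature sizes {13, 1024} and alignment 16 this yields order {1, 0} and
// shuffle ranges {{13, 1037}, {0, 13}}: the second input's features now come first.
static std::pair<std::vector<size_t>, std::vector<shuffle_range>>
reorder_concat_inputs(const std::vector<int32_t>& feature_sizes, int32_t alignment) {
    std::vector<size_t> new_order;
    for (size_t i = 0; i < feature_sizes.size(); ++i)
        if (feature_sizes[i] % alignment == 0)
            new_order.push_back(i);            // aligned inputs first
    for (size_t i = 0; i < feature_sizes.size(); ++i)
        if (feature_sizes[i] % alignment != 0)
            new_order.push_back(i);            // unaligned inputs last

    std::vector<shuffle_range> original_ranges;
    int32_t offset = 0;
    for (auto f : feature_sizes) {
        original_ranges.emplace_back(offset, offset + f);
        offset += f;
    }
    std::vector<shuffle_range> shuffled_ranges;
    for (auto ord : new_order)
        shuffled_ranges.push_back(original_ranges[ord]);
    return {new_order, shuffled_ranges};
}

int main() {
    auto result = reorder_concat_inputs({13, 1024}, 16);
    for (auto i : result.first)
        std::printf("%zu ", i);                // prints "1 0"
    std::printf("\n");
}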
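Note on why the users can absorb the new input order: for a 1x1 convolution or a fully connected layer, permuting the input features and applying the same permutation to the weights' input-channel axis leaves the output unchanged, which is what shuffle_weights() bakes in at compile time. A minimal numeric sketch with made-up values (dot products only, no spatial dims, no groups):

#include <cassert>
#include <cstddef>
#include <vector>

// out[ofm] = sum_ifm w[ofm][ifm] * in[ifm]; permuting in[] and the ifm axis of w
// identically keeps the result.
static std::vector<float> dot(const std::vector<std::vector<float>>& w, const std::vector<float>& in) {
    std::vector<float> out(w.size(), 0.f);
    for (size_t o = 0; o < w.size(); ++o)
        for (size_t i = 0; i < in.size(); ++i)
            out[o] += w[o][i] * in[i];
    return out;
}

int main() {
    std::vector<float> in = {1, 2, 3, 4};              // features of two concat inputs: {1,2} and {3,4}
    std::vector<std::vector<float>> w = {{1, 0, 2, 0}, {0, 3, 0, 4}};
    std::vector<size_t> order = {2, 3, 0, 1};          // second input first, as the pass would reorder

    std::vector<float> in_p(in.size());
    std::vector<std::vector<float>> w_p = w;
    for (size_t i = 0; i < in.size(); ++i) {
        in_p[i] = in[order[i]];                        // reordered concat output
        for (size_t o = 0; o < w.size(); ++o)
            w_p[o][i] = w[o][order[i]];                // same permutation on the weights' ifm axis
    }
    assert(dot(w, in) == dot(w_p, in_p));              // results match element-wise
    return 0;
}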