diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_selector.cpp index de780caae1a..5c3b23ec799 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_selector.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_selector.cpp @@ -16,12 +16,14 @@ #include "permute_kernel_selector.h" #include "permute_kernel_ref.h" #include "permute_kernel_tile_8x8_4x4.h" +#include "permute_kernel_tile_8x8_4x4_fsv.h" namespace kernel_selector { permute_kernel_selector::permute_kernel_selector() { Attach(); Attach(); + Attach(); } KernelsData permute_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const { diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4.cpp index 3e87c66aa03..b6dfdefdbce 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4.cpp @@ -15,11 +15,11 @@ #include "permute_kernel_tile_8x8_4x4.h" #include "kernel_selector_utils.h" +#include "common_tools.h" #include #include #include -#define CEIL_DIV(A, B) ((A + B - 1)/(B)) // Tile size : 4x4 or 8x8 #define MIN_TILE_SIZE 4 #define DEFAULT_TILE_SIZE 8 @@ -128,7 +128,7 @@ JitConstants PermuteKernel_tile_8x8_4x4::GetJitConstants(const permute_params& p jit.AddConstant(MakeJitConstant("TILE_SIZE", tile_size)); jit.AddConstant(MakeJitConstant("N_VECTORS_IN_TILE", tile_size / vector_width)); jit.AddConstant(MakeJitConstant("LWS", total_lws)); - 
jit.AddConstant(MakeJitConstant("NFEATURE_TILES", CEIL_DIV(params.inputs[0].Feature().v, tile_size))); + jit.AddConstant(MakeJitConstant("NFEATURE_TILES", CeilDiv(params.inputs[0].Feature().v, tile_size))); std::string normal_tile_cond = "true"; std::string x_remainder_cond = "true"; @@ -137,7 +137,7 @@ JitConstants PermuteKernel_tile_8x8_4x4::GetJitConstants(const permute_params& p if (params.inputs[0].X().v % tile_size) { jit.AddConstant(MakeJitConstant("X_REMAINDER_ITEM", params.inputs[0].X().v / tile_size)); jit.AddConstant(MakeJitConstant("X_REMAINDER_SIZE", params.inputs[0].X().v % tile_size)); - jit.AddConstant(MakeJitConstant("X_REMAINDER_SIZE_AS_VECTOR", CEIL_DIV(params.inputs[0].X().v % tile_size, vector_width))); + jit.AddConstant(MakeJitConstant("X_REMAINDER_SIZE_AS_VECTOR", CeilDiv(params.inputs[0].X().v % tile_size, vector_width))); normal_tile_cond += " && (x < X_REMAINDER_ITEM)"; x_remainder_cond += " && (x == X_REMAINDER_ITEM)"; f_remainder_cond += " && (x < X_REMAINDER_ITEM)"; @@ -145,7 +145,7 @@ JitConstants PermuteKernel_tile_8x8_4x4::GetJitConstants(const permute_params& p if (params.inputs[0].Feature().v % tile_size) { jit.AddConstant(MakeJitConstant("F_REMAINDER_ITEM", params.inputs[0].Feature().v / tile_size)); jit.AddConstant(MakeJitConstant("F_REMAINDER_SIZE", params.inputs[0].Feature().v % tile_size)); - jit.AddConstant(MakeJitConstant("F_REMAINDER_SIZE_AS_VECTOR", CEIL_DIV(params.inputs[0].Feature().v % tile_size, vector_width))); + jit.AddConstant(MakeJitConstant("F_REMAINDER_SIZE_AS_VECTOR", CeilDiv(params.inputs[0].Feature().v % tile_size, vector_width))); normal_tile_cond += " && (f < F_REMAINDER_ITEM)"; x_remainder_cond += " && (f < F_REMAINDER_ITEM)"; f_remainder_cond += " && (f == F_REMAINDER_ITEM)"; @@ -176,7 +176,7 @@ static std::vector GetBestLwsFromGws(const permute_params& params, const std::vector dims{0, 2, 1}; // SLM size: elemsize * tile_size * tile_size * work_items <= 64K - size_t elem_size = 
sizeof(params.output.GetDType()); + size_t elem_size = params.output.ElementSize(); size_t max_local_mem_size = params.engineInfo.maxLocalMemSize; size_t max_num_work_items = std::min((size_t)256, (size_t)max_local_mem_size / (elem_size * tile_size * tile_size)); @@ -205,13 +205,13 @@ CommonDispatchData PermuteKernel_tile_8x8_4x4::SetDefault(const permute_params& size_t tile_size = GetTileSize(params); switch (in.GetLayout()) { case DataLayout::bfyx: - dispatchData.gws = {CEIL_DIV(in.X().v , tile_size), in.Y().v, CEIL_DIV(in.Feature().v, tile_size) * in.Batch().v}; + dispatchData.gws = {CeilDiv(in.X().v , tile_size), in.Y().v, CeilDiv(in.Feature().v, tile_size) * in.Batch().v}; break; case DataLayout::bfzyx: - dispatchData.gws = {CEIL_DIV(in.X().v , tile_size), in.Y().v * in.Z().v, CEIL_DIV(in.Feature().v, tile_size) * in.Batch().v}; + dispatchData.gws = {CeilDiv(in.X().v , tile_size), in.Y().v * in.Z().v, CeilDiv(in.Feature().v, tile_size) * in.Batch().v}; break; case DataLayout::bfwzyx: - dispatchData.gws = {CEIL_DIV(in.X().v , tile_size), in.Y().v * in.Z().v * in.W().v, CEIL_DIV(in.Feature().v, tile_size) * in.Batch().v}; + dispatchData.gws = {CeilDiv(in.X().v , tile_size), in.Y().v * in.Z().v * in.W().v, CeilDiv(in.Feature().v, tile_size) * in.Batch().v}; break; default: throw std::runtime_error("Unsupported combination\n"); diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4_fsv.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4_fsv.cpp new file mode 100644 index 00000000000..b554ebd728c --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4_fsv.cpp @@ -0,0 +1,316 @@ +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "permute_kernel_tile_8x8_4x4_fsv.h" +#include "kernel_selector_utils.h" +#include "common_tools.h" +#include +#include +#include + +// Tile size : 4x4 or 8x8 +#define MIN_TILE_SIZE 4 +#define DEFAULT_TILE_SIZE 8 + +namespace kernel_selector { + +ParamsKey PermuteKernel_tile_8x8_4x4_fsv::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); + k.EnableInputDataType(Datatype::INT32); + k.EnableInputDataType(Datatype::INT64); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); + k.EnableOutputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::INT64); + k.EnableDifferentTypes(); + k.EnableInputLayout(DataLayout::b_fs_yx_fsv16); + k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16); + k.EnableInputLayout(DataLayout::b_fs_zyx_fsv16); + k.EnableOutputLayout(DataLayout::b_fs_zyx_fsv16); + k.EnableInputLayout(DataLayout::b_fs_yx_fsv32); + k.EnableOutputLayout(DataLayout::b_fs_yx_fsv32); + k.EnableInputLayout(DataLayout::b_fs_zyx_fsv32); + k.EnableOutputLayout(DataLayout::b_fs_zyx_fsv32); + k.EnableInputLayout(DataLayout::b_fs_yx_fsv4); + k.EnableOutputLayout(DataLayout::b_fs_yx_fsv4); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + return k; +} + +static inline std::vector GetFusedOpOrderVector(size_t size) { + std::vector ret; + switch (size) { + 
case 4: + ret = {"b", "y + lh", "x", "f + lw"}; + break; + case 5: + ret = {"b", "z + lh", "y", "x", "f + lw"}; + break; + default : throw std::runtime_error("Unsupported combination\n"); + } + return ret; +} + +static inline std::string GetTiledOutputOrder(size_t size) { + std::string order_str = ""; + switch (size) { + case 4 : + order_str = "b, y, x, f + lw"; + break; + case 5 : + order_str = "b, z, y, x, f + lw"; + break; + default : throw std::runtime_error("Unsupported combination\n"); + } + return order_str; +} + +static inline std::string GetTiledInputOrder(size_t size) { + std::string order_str = ""; + switch (size) { + case 4 : + order_str = "b, f, y + lh, x"; + break; + case 5 : + order_str = "b, f, z + lh, y, x"; + break; + default : throw std::runtime_error("Unsupported combination\n"); + } + return order_str; +} + +static inline size_t GetFsvAlignment(const permute_params& params) { + const auto& in = params.inputs[0]; + int fsv_alignment = -1; + switch (in.GetLayout()) { + case DataLayout::b_fs_yx_fsv16: + case DataLayout::b_fs_zyx_fsv16: + fsv_alignment = 16; + break; + case DataLayout::b_fs_yx_fsv32: + case DataLayout::b_fs_zyx_fsv32: + fsv_alignment = 32; + break; + case DataLayout::b_fs_yx_fsv4: + fsv_alignment = 4; + break; + default: + throw std::runtime_error("Unsupported combination\n"); + } + return fsv_alignment; +} + +static inline size_t GetTileSize(const permute_params& params) { + const Datatype input_type = params.inputs[0].GetDType(); + const Datatype output_type = params.output.GetDType(); + + // i64 only supports tile size 4 + if ((input_type == Datatype::INT64) || (output_type == Datatype::INT64)) { + return MIN_TILE_SIZE; + } + + // supports 4x4 or 8x8 tiling + size_t rotating_dim = 0; + if (params.inputs[0].GetDims().size() == 4) { + rotating_dim = params.inputs[0].Y().v; + } else if (params.inputs[0].GetDims().size() == 5) { + rotating_dim = params.inputs[0].Z().v; + } + + if (rotating_dim < DEFAULT_TILE_SIZE || 
params.inputs[0].Feature().v < DEFAULT_TILE_SIZE) { + return MIN_TILE_SIZE; + } + + return DEFAULT_TILE_SIZE; +} + +JitConstants PermuteKernel_tile_8x8_4x4_fsv::GetJitConstants(const permute_params& params, const CommonDispatchData& dispatchData) const { + auto jit = Parent::GetJitConstants(params, dispatchData); + const size_t f = params.inputs[0].Feature().v; + const size_t z = params.inputs[0].Z().v; + const size_t y = params.inputs[0].Y().v; + const size_t tile_size = GetTileSize(params); + const uint64_t total_lws = dispatchData.lws[0] * dispatchData.lws[1] * dispatchData.lws[2]; + const size_t input_ndims = params.inputs[0].GetDims().size(); + const size_t output_ndims = params.output.GetDims().size(); + const size_t fsv_alignment = GetFsvAlignment(params); + + jit.AddConstant(MakeJitConstant("INPUT0_TILED_ORDER", GetTiledInputOrder(input_ndims))); + jit.AddConstant(MakeJitConstant("OUTPUT_TILED_ORDER", GetTiledOutputOrder(output_ndims))); + jit.AddConstant(MakeJitConstant("INPUT0_FEATURE_SLICE_NUM", CeilDiv(f, fsv_alignment))); + jit.AddConstant(MakeJitConstant("TILE_SIZE", tile_size)); + jit.AddConstant(MakeJitConstant("FSV_ALIGNMENT", fsv_alignment)); + jit.AddConstant(MakeJitConstant("TRANS_BUF_SIZE", tile_size * total_lws)); + + // whether F is tile_size-aligned + if (f % tile_size != 0) { + jit.AddConstant(MakeJitConstant("F_REMAINDER_SIZE", f % tile_size)); + jit.AddConstant(MakeJitConstant("F_REMAINDER_CONDITION", "((INPUT0_FEATURE_NUM - F_REMAINDER_SIZE) <= f) && (f < INPUT0_FEATURE_NUM)")); + jit.AddConstant(MakeJitConstant("F_NO_REMAINDER_CONDITION", "(f < (INPUT0_FEATURE_NUM - F_REMAINDER_SIZE))")); + } else { + jit.AddConstant(MakeJitConstant("F_NO_REMAINDER_CONDITION", "(f < INPUT0_FEATURE_NUM)")); + } + + // whether y (or z if b_fs_zyx_fsv16) is tile_size-aligned + if ((input_ndims == 4) && (y % tile_size != 0)) { + jit.AddConstant(MakeJitConstant("YZ_REMAINDER_SIZE", y % tile_size)); + 
jit.AddConstant(MakeJitConstant("YZ_NO_REMAINDER_CONDITION", "y < (INPUT0_SIZE_Y - YZ_REMAINDER_SIZE)")); + jit.AddConstant(MakeJitConstant("YZ_REMAINDER_CONDITION", "((INPUT0_SIZE_Y - YZ_REMAINDER_SIZE) <= y) && (y < INPUT0_SIZE_Y)")); + } else if ((input_ndims == 5) && (z % tile_size != 0)) { + jit.AddConstant(MakeJitConstant("YZ_REMAINDER_SIZE", z % tile_size)); + jit.AddConstant(MakeJitConstant("YZ_NO_REMAINDER_CONDITION", "z < (INPUT0_SIZE_Z - YZ_REMAINDER_SIZE)")); + jit.AddConstant(MakeJitConstant("YZ_REMAINDER_CONDITION", "((INPUT0_SIZE_Z - YZ_REMAINDER_SIZE) <= z) && (z < INPUT0_SIZE_Z)")); + } + + if (!params.fused_ops.empty()) { + std::vector output_order = GetFusedOpOrderVector(output_ndims); + FusedOpsConfiguration conf = {"", output_order, "input_var", params.inputs[0].GetDType(), 1}; + jit.Merge(MakeFusedOpsJitConstants(params, {conf})); + } + return jit; +} + +static std::vector GetBestLwsFromGws(const permute_params& params, const std::vector& gws, const size_t tile_width, const size_t tile_size) { + std::vector lws{1, 1, 1}; + std::vector dims{0, 1, 2}; + + // SLM size: elem_size * tile_width * tile_size * work_items <= max_local_mem_size + const size_t elem_size = params.output.ElementSize(); + const size_t max_local_mem_size = params.engineInfo.maxLocalMemSize; + const size_t max_work_group_size = params.engineInfo.maxWorkGroupSize; + size_t max_num_work_items = std::min(max_work_group_size, max_local_mem_size / (elem_size * tile_width * tile_size)); + + for (size_t i = 0; i < dims.size(); ++i) { + size_t dim = dims[i]; + size_t max_divider = static_cast(std::sqrt(gws[dim]) + 1); + for (size_t divider = 1; divider <= max_divider; ++divider) { + if (gws[dim] % divider == 0) { + const size_t lws0 = gws[dim] / divider; + if (lws0 <= max_num_work_items) { + lws[dim] = std::max(lws[dim], lws0); + } + if (divider <= max_num_work_items) { + lws[dim] = std::max(lws[dim], divider); + } + } + } + max_num_work_items /= lws[dim]; + } + return lws; +} + +static inline 
std::vector GetGWS(const permute_params& params) { + const auto& in = params.inputs[0]; + const size_t tile_size = GetTileSize(params); + const size_t fsv_alignment = GetFsvAlignment(params); + std::vector gws; + switch (in.GetLayout()) { + case DataLayout::b_fs_yx_fsv16: + case DataLayout::b_fs_yx_fsv32: + case DataLayout::b_fs_yx_fsv4: + gws = {CeilDiv(fsv_alignment, tile_size), + CeilDiv(in.Y().v, tile_size) * in.X().v, + in.Batch().v * CeilDiv(in.Feature().v, fsv_alignment)}; + break; + case DataLayout::b_fs_zyx_fsv16: + case DataLayout::b_fs_zyx_fsv32: + gws = {CeilDiv(fsv_alignment, tile_size), + CeilDiv(in.Z().v, tile_size) * in.X().v * in.Y().v, + in.Batch().v * CeilDiv(in.Feature().v, fsv_alignment)}; + break; + default: + throw std::runtime_error("Unsupported combination\n"); + } + return gws; +} + +CommonDispatchData PermuteKernel_tile_8x8_4x4_fsv::SetDefault(const permute_params& params) const { + CommonDispatchData dispatchData; + const size_t tile_size = GetTileSize(params); + dispatchData.gws = GetGWS(params); + dispatchData.lws = GetBestLwsFromGws(params, dispatchData.gws, tile_size, tile_size); + return dispatchData; +} + +// Validate is the same as permute_kernel_tile_8x8_4x4 +bool PermuteKernel_tile_8x8_4x4_fsv::Validate(const Params& p, const optional_params& o) const { + if (!Parent::Validate(p, o)) return false; + + std::function&)> is_rotating_except_batch = [](const std::vector& order) { + // Target transform: Rotate feature dim to back to be taken as inner-most axis + // ex) 0(b), 4(f), 1(z), 2(y), 3(x) + // ex) 0(b), 3(f), 1(y), 2(x) + if ((int32_t) order[1] != order.size() - 1) return false; + if ((int32_t) order[0] != 0) return false; + for (int32_t i = 2; i < (int32_t) order.size(); ++i) { + if ((int32_t)order[i] != (i - 1)) return false; + } + return true; + }; + + const permute_params& params = static_cast(p); + + if (params.inputs[0].GetDims().size() != params.output.GetDims().size()) { + return false; + } + + if 
(!is_rotating_except_batch(params.order)) { + return false; + } + + return true; +} + +KernelsPriority PermuteKernel_tile_8x8_4x4_fsv::GetKernelsPriority(const Params& params, const optional_params& /*options*/) const { + KernelData kd = KernelData::Default(params); + permute_params& newParams = *static_cast(kd.params.get()); + + // calculate number of working groups + const size_t tile_size = GetTileSize(newParams); + + std::vector gws = GetGWS(newParams); + std::vector lws = GetBestLwsFromGws(newParams, gws, tile_size, tile_size); + size_t num_working_groups = 1; + for (size_t i=0; i < gws.size(); ++i) { + num_working_groups *= gws.at(i) / lws.at(i); + } + + const size_t feature = newParams.inputs[0].Feature().v; + size_t rotating_dim = 0; + if (newParams.inputs[0].GetDims().size() == 4) { + rotating_dim = newParams.inputs[0].Y().v; + } else if (newParams.inputs[0].GetDims().size() == 5) { + rotating_dim = newParams.inputs[0].Z().v; + } + + if (num_working_groups == 1) { + return DONT_USE_IF_HAVE_SOMETHING_ELSE; + } else if ((rotating_dim >= DEFAULT_TILE_SIZE) && (feature >= DEFAULT_TILE_SIZE)) { + return FORCE_PRIORITY_1; + } else if ((rotating_dim >= DEFAULT_TILE_SIZE) || (feature >= DEFAULT_TILE_SIZE)) { + return FORCE_PRIORITY_2; + } else { + return FORCE_PRIORITY_3; + } +} +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4_fsv.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4_fsv.h new file mode 100644 index 00000000000..cccd5e3e2c7 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/permute/permute_kernel_tile_8x8_4x4_fsv.h @@ -0,0 +1,46 @@ +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma once + +#include "permute_kernel_base.h" + +namespace kernel_selector { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// PermuteKernel_tile_8x8_4x4 +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class PermuteKernel_tile_8x8_4x4_fsv : public PermuteKernelBase { +public: + using Parent = PermuteKernelBase; + using Parent::Parent; + PermuteKernel_tile_8x8_4x4_fsv() : PermuteKernelBase("permute_tile_8x8_4x4_fsv") {} + virtual ~PermuteKernel_tile_8x8_4x4_fsv() {} + + bool Validate(const Params& p, const optional_params& o) const override; + KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const; + ParamsKey GetSupportedKey() const override; +protected: + JitConstants GetJitConstants(const permute_params& params, const CommonDispatchData& dispatchData) const; + CommonDispatchData SetDefault(const permute_params& params) const; + std::vector GetSupportedFusedOps() const override { + return { + FusedOpType::ACTIVATION, + FusedOpType::QUANTIZE, + FusedOpType::ELTWISE, + FusedOpType::SCALE + }; + } +}; +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/permute_tile_8x8_4x4_fsv.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/permute_tile_8x8_4x4_fsv.cl new file mode 100644 index 00000000000..db01a408864 --- /dev/null +++ 
b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/permute_tile_8x8_4x4_fsv.cl @@ -0,0 +1,165 @@ +// Copyright (c) 2017-2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "include/fetch.cl" +#include "include/common.cl" +#include "include/data_types.cl" + +#define unroll_for __attribute__((opencl_unroll_hint)) for +#define CEIL_DIV(A, B) (((A) + (B) - 1) / (B)) +#define INPUT0_GET_TILED_INDEX(ORDER) INPUT0_GET_INDEX(ORDER) +#define OUTPUT_GET_TILED_INDEX(ORDER) OUTPUT_GET_INDEX(ORDER) +#define YZ_REMAINDER_LESS_THAN_TILE_SIZE ((YZ_REMAINDER_CONDITION) && (YZ_REMAINDER_SIZE < ( TILE_SIZE /2))) +#define YZ_REMAINDER_MORE_THAN_TILE_SIZE ((YZ_REMAINDER_CONDITION) && (YZ_REMAINDER_SIZE >= ( TILE_SIZE /2))) + +#define INPUTVTYPE CAT(INPUT0_TYPE, TILE_SIZE) +#define OUTPUTVTYPE CAT(OUTPUT_TYPE, TILE_SIZE) +#define VLOAD CAT(vload, TILE_SIZE) +#define VSTORE CAT(vstore, TILE_SIZE) +#define AS_INPUTVTYPE CAT(as_, INPUTVTYPE) + +#define GET_GLOBAL_ID(IDX) ((uint)get_global_id(IDX)) +#define GET_LOCAL_ID(IDX) ((uint)get_local_id(IDX)) +#define GET_LOCAL_SIZE(IDX) ((uint)get_local_size(IDX)) + +KERNEL (permute_tile_8x8_4x4_fsv)( + const __global INPUT0_TYPE* input, + __global OUTPUT_TYPE* output +#if HAS_FUSED_OPS_DECLS + , FUSED_OPS_DECLS +#endif + ) +{ +#if INPUT0_DIMS == 4 + const uint y = (GET_GLOBAL_ID(1) / INPUT0_SIZE_X) * TILE_SIZE; + const uint x = (GET_GLOBAL_ID(1)) % INPUT0_SIZE_X; +#elif INPUT0_DIMS == 5 + const 
uint z = (GET_GLOBAL_ID(1)/ (INPUT0_SIZE_X * INPUT0_SIZE_Y)) * TILE_SIZE; + const uint yx = GET_GLOBAL_ID(1) % (INPUT0_SIZE_X * INPUT0_SIZE_Y); + const uint y = yx / INPUT0_SIZE_X ; + const uint x = yx % INPUT0_SIZE_X; +#endif + const uint fsv = GET_GLOBAL_ID(0) * TILE_SIZE; + const uint fs = GET_GLOBAL_ID(2) % INPUT0_FEATURE_SLICE_NUM; + const uint b = GET_GLOBAL_ID(2) / INPUT0_FEATURE_SLICE_NUM; + const uint f = fsv + fs * FSV_ALIGNMENT; + + __local OUTPUTVTYPE transpose_buf[TRANS_BUF_SIZE]; + const uint local_id = GET_LOCAL_ID(0) * GET_LOCAL_SIZE(2) * GET_LOCAL_SIZE(1) + + GET_LOCAL_ID(1) * GET_LOCAL_SIZE(2) + + GET_LOCAL_ID(2); + const uint local_buf_offset = local_id * TILE_SIZE; + + if (F_NO_REMAINDER_CONDITION) { + // read and transpose + unroll_for (uint lh = 0; lh < TILE_SIZE; ++lh) { + const uint input_idx = INPUT0_GET_TILED_INDEX(INPUT0_TILED_ORDER); + INPUTVTYPE read_data = AS_INPUTVTYPE(VLOAD(0, input + input_idx)); + + unroll_for (uint lw = 0; lw < TILE_SIZE; ++lw) { + const uint dst = local_buf_offset + lw; +#if HAS_FUSED_OPS + INPUT0_TYPE input_var = read_data[lw]; + FUSED_OPS; + transpose_buf[dst][lh] = FUSED_OPS_RESULT; +#else + transpose_buf[dst][lh] = ACTIVATION(read_data[lw], ACTIVATION_PARAMS); +#endif + } + } + // write to ddr +#ifdef YZ_REMAINDER_CONDITION + if (YZ_REMAINDER_LESS_THAN_TILE_SIZE) { + // remainder tile: copy one by one when YZ_REMAINDER_SIZE < TILE_SIZE/2 + unroll_for (uint lw = 0; lw < TILE_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + unroll_for (uint lh = 0; lh < YZ_REMAINDER_SIZE; ++lh) { + output[output_idx + lh] = transpose_buf[local_buf_offset + lw][lh]; + } + } + } else if (YZ_REMAINDER_MORE_THAN_TILE_SIZE) { + // remainder tile: use vstore and fill zero when YZ_REMAINDER_SIZE >= TILE_SIZE/2 + unroll_for (uint lw = 0; lw < TILE_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + VSTORE(transpose_buf[local_buf_offset + lw], 0, output + output_idx); + unroll_for (uint lh = 
YZ_REMAINDER_SIZE; lh < TILE_SIZE; ++lh) { + output[output_idx + lh] = 0.f; + } + } + } else if (YZ_NO_REMAINDER_CONDITION) { + // use vstore for the full tiles that precede the remainder tile + unroll_for (uint lw = 0; lw < TILE_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + VSTORE(transpose_buf[local_buf_offset + lw], 0, output + output_idx); + } + } +#else + unroll_for (uint lw = 0; lw < TILE_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + VSTORE(transpose_buf[local_buf_offset + lw], 0, output + output_idx); + } +#endif + } +#ifdef F_REMAINDER_CONDITION + else if (F_REMAINDER_CONDITION) { + // read and transpose + unroll_for (uint lh = 0; lh < TILE_SIZE; ++lh) { + const uint input_idx = INPUT0_GET_TILED_INDEX(INPUT0_TILED_ORDER); + INPUTVTYPE read_data = AS_INPUTVTYPE(VLOAD(0, input + input_idx)); + unroll_for (uint lw = 0; lw < F_REMAINDER_SIZE; ++lw) { + uint dst = local_buf_offset + lw; + #if HAS_FUSED_OPS + INPUT0_TYPE input_var = read_data[lw]; + FUSED_OPS; + transpose_buf[dst][lh] = FUSED_OPS_RESULT; + #else + transpose_buf[dst][lh] = ACTIVATION(read_data[lw], ACTIVATION_PARAMS); + #endif + } + } + // write to ddr +#ifdef YZ_REMAINDER_CONDITION + if (YZ_REMAINDER_LESS_THAN_TILE_SIZE) { + // remainder tile: copy one by one when YZ_REMAINDER_SIZE < TILE_SIZE/2 + unroll_for (uint lw = 0; lw < F_REMAINDER_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + unroll_for (uint lh = 0; lh < YZ_REMAINDER_SIZE; ++lh) { + output[output_idx + lh] = transpose_buf[local_buf_offset + lw][lh]; + } + } + } else if (YZ_REMAINDER_MORE_THAN_TILE_SIZE) { + // remainder tile: use vstore and fill zero when YZ_REMAINDER_SIZE >= TILE_SIZE/2 + unroll_for (uint lw = 0; lw < F_REMAINDER_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + VSTORE(transpose_buf[local_buf_offset + lw], 0, output + output_idx); + // zero fill for unaligned + unroll_for (uint lh = YZ_REMAINDER_SIZE; lh < TILE_SIZE; ++lh) { + 
output[output_idx + lh] = 0.f; + } + } + } else if (YZ_NO_REMAINDER_CONDITION) { + // use vstore when z % TILE_SIZE == 0 + unroll_for (uint lw = 0; lw < F_REMAINDER_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + VSTORE(transpose_buf[local_buf_offset + lw], 0, output + output_idx); + } + } +#else + unroll_for (uint lw = 0; lw < F_REMAINDER_SIZE; ++lw) { + const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER); + VSTORE(transpose_buf[local_buf_offset + lw], 0, output + output_idx); + } +#endif + } +#endif +} diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index bde4906ff88..a1b2323bf21 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -6062,6 +6062,15 @@ struct permute_params { #define CASE_PERMUTE_TILE_8x8_4x4_6D_2 {1, 8, 5, 2, 2, 2}, {1, 2, 8, 5, 2, 2}, {0, 5, 1, 2, 3, 4}, tensor{0}, data_types::f32, format::bfwzyx, data_types::f32, format::bfwzyx #define CASE_PERMUTE_TILE_8x8_4x4_6D_3 {1, 5, 5, 2, 2, 2}, {1, 2, 5, 5, 2, 2}, {0, 5, 1, 2, 3, 4}, tensor{0}, data_types::f32, format::bfwzyx, data_types::f32, format::bfwzyx +// permute_tile_8x8_4x4_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_0 {1, 16, 16, 2}, {1, 2, 16, 16}, {0, 3, 1, 2}, tensor{0}, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_1 {1, 15, 16, 2}, {1, 2, 15, 16}, {0, 3, 1, 2}, tensor{0}, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_2 {1, 16, 3, 2}, {1, 2, 16, 3}, {0, 3, 1, 2}, tensor{0}, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_3 {1, 5, 7, 2}, {1, 2, 5, 7}, {0, 3, 1, 2}, tensor{0}, 
data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_0 {1, 16, 16, 2, 2}, {1, 2, 16, 16, 2}, {0, 4, 1, 2, 3}, tensor{0}, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::b_fs_zyx_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_1 {1, 15, 16, 2, 2}, {1, 2, 15, 16, 2}, {0, 4, 1, 2, 3}, tensor{0}, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::b_fs_zyx_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_2 {1, 16, 3, 2, 2}, {1, 2, 16, 3, 2}, {0, 4, 1, 2, 3}, tensor{0}, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::b_fs_zyx_fsv16 +#define CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_3 {1, 5, 7, 2, 2}, {1, 2, 5, 7, 2}, {0, 4, 1, 2, 3}, tensor{0}, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::b_fs_zyx_fsv16 class PermuteFusingTest : public ::BaseFusingTest { public: @@ -6168,6 +6177,16 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_activation_scale_eltwise, permute_params{CASE_PERMUTE_TILE_8x8_4x4_6D_1, 2, 5}, permute_params{CASE_PERMUTE_TILE_8x8_4x4_6D_2, 2, 5}, permute_params{CASE_PERMUTE_TILE_8x8_4x4_6D_3, 2, 5}, + + // Fusing tests for permute_tile_8x8_4x4_fsv16 + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_0, 2, 5}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_1, 2, 5}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_2, 2, 5}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_3, 2, 5}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_0, 2, 5}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_1, 2, 5}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_2, 2, 5}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_3, 2, 5}, }), ); class permute_quant_u8: public PermuteFusingTest {}; @@ -6361,6 +6380,16 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_actv_scale_actv, permute_params{CASE_PERMUTE_TILE_8x8_4x4_6D_1, 2, 7}, permute_params{CASE_PERMUTE_TILE_8x8_4x4_6D_2, 2, 7}, 
permute_params{CASE_PERMUTE_TILE_8x8_4x4_6D_3, 2, 7}, + + // Fusing tests for permute_tile_8x8_4x4_fsv16 + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_0, 2, 7}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_1, 2, 7}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_2, 2, 7}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_4D_3, 2, 7}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_0, 2, 7}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_1, 2, 7}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_2, 2, 7}, + permute_params{CASE_PERMUTE_TILE_8x8_4x4_FSV16_5D_3, 2, 7}, }), ); /* ------------------------------------------------------------------------------------------------------------ */ diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/permute_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/permute_gpu_test.cpp index 1b25c4e542c..311a2ef8320 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/permute_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/permute_gpu_test.cpp @@ -31,6 +31,7 @@ #include #include #include +#include using namespace cldnn; using namespace tests; @@ -1621,3 +1622,244 @@ TEST(permute_gpu_f32_tile_8x8_4x4, xf_remainder_bfwzyx_0_5_4_1_2_3) { EXPECT_FLOAT_EQ(answers[i], output_ptr[i]); } } + +struct TiledPermuteParam { + std::vector sizes; + cldnn::format format_fsv; +}; + +class TiledPermuteTest : public ::testing::TestWithParam { +public: + const cldnn::engine engine; + TiledPermuteTest(): engine(get_test_engine()) { } + + template + void compare_value(T a, T b) const { + EXPECT_EQ(a, b); + } + + template + void set_random_values(const cldnn::memory& mem) const { + tests::set_random_values(mem); + } + + template + void run_test(const std::vector& sizes, cldnn::format format_fsv); +}; + +template<> +void TiledPermuteTest::compare_value(float a, float b) const { + EXPECT_FLOAT_EQ(a, b); +} + +// f16 format +template<> +void 
TiledPermuteTest::compare_value(FLOAT16 a, FLOAT16 b) const { + EXPECT_FLOAT_EQ(static_cast(a), static_cast(b)); +} + +template<> +void TiledPermuteTest::set_random_values(const cldnn::memory& mem) const { + // tests::set_random_values() is not supported + std::mt19937 gen; + static std::uniform_int_distribution uid(std::numeric_limits::min(), std::numeric_limits::max()); + auto ptr = mem.pointer(); + for (auto it = ptr.begin(); it != ptr.end(); ++it) { + *it = static_cast(uid(gen)); + } +} + +template +void TiledPermuteTest::run_test(const std::vector& sizes, cldnn::format format_fsv) +{ + // convert half_t to FLOAT16 + using type_ = typename data_type_to_type::type; + using type = typename std::conditional::value, FLOAT16, type_>::type; + + size_t input_size = 1; + for (size_t i = 0; i internal_sizes(sizes); + std::swap(internal_sizes.at(2), internal_sizes.back()); + cldnn::tensor tensor(internal_sizes); + + cldnn::format format = sizes.size() == 4?cldnn::format::bfyx:cldnn::format::bfzyx; + + std::vector order{0, static_cast(sizes.size()-1)}; + for (uint16_t i = 1; i<(sizes.size()-1); ++i) { + order.push_back(i); + } + + auto input = memory::allocate(engine, {Data_Type, format, tensor}); + set_random_values(input); + + topology topology_ref = topology( + input_layout("input", input.get_layout()), + reorder("reorder", "input", {Data_Type, format_fsv, tensor}), + permute("output", "reorder", order ) + ); + + // run with permute_ref + cldnn::build_options options_ref; + cldnn::implementation_desc permute_ref = { format_fsv, "permute_ref" }; + options_ref.set_option(cldnn::build_option::force_implementations({ {"output", permute_ref} })); + + cldnn::network network_ref(engine, topology_ref, options_ref); + network_ref.set_input_data("input", input); + auto outputs_ref = network_ref.execute(); + auto output_ref = outputs_ref.begin()->second.get_memory(); + auto output_ref_ptr = output_ref.pointer(); + + // run with permute_tile_8x8_4x4_fsv16 + cldnn::build_options 
options_tile; + cldnn::implementation_desc permute_tile_8x8_4x4_fsv = { format_fsv, "permute_tile_8x8_4x4_fsv" }; + options_tile.set_option(cldnn::build_option::force_implementations({ {"output", permute_tile_8x8_4x4_fsv} })); + + cldnn::network network_tile(engine, topology_ref, options_tile); + network_tile.set_input_data("input", input); + auto outputs_tile = network_tile.execute(); + auto output_tile = outputs_tile.begin()->second.get_memory(); + auto output_tile_ptr = output_tile.pointer(); + + // compare results + const size_t output_size= output_ref.get_layout().get_linear_size(); + for (size_t i = 0; i < output_size; i++) + { + compare_value(output_ref_ptr[i], output_tile_ptr[i]); + } +} + +class permute_tile_fsv_4d: public TiledPermuteTest {}; + +INSTANTIATE_TEST_CASE_P(, permute_tile_fsv_4d, + ::testing::ValuesIn(std::vector { + // b_fs_yx_fsv16 + // normal cases + {{1, 16, 16, 3}, format::b_fs_yx_fsv16}, + // f_not_aligned + {{1, 16 - 7, 16, 2}, format::b_fs_yx_fsv16}, + {{1, 16 - 15, 16, 2}, format::b_fs_yx_fsv16}, + // y_not_aligned + {{1, 16, 16 - 1, 2}, format::b_fs_yx_fsv16}, + {{1, 16, 16 - 9, 2}, format::b_fs_yx_fsv16}, + // fy_not_aligned + {{1, 16 - 15, 16 - 1, 2}, format::b_fs_yx_fsv16}, + {{1, 16 - 1, 16 - 7, 2}, format::b_fs_yx_fsv16}, + {{1, 16 - 7, 16 - 9, 2}, format::b_fs_yx_fsv16}, + {{1, 16 - 9, 16 - 15, 2}, format::b_fs_yx_fsv16}, + + // b_fs_yx_fsv32 + // normal cases + {{1, 32, 32, 3}, format::b_fs_yx_fsv32}, + // f_not_aligned + {{1, 32 - 7, 32, 2}, format::b_fs_yx_fsv32}, + {{1, 32 - 15, 32, 2}, format::b_fs_yx_fsv32}, + // y_not_aligned + {{1, 32, 32 - 1, 2}, format::b_fs_yx_fsv32}, + {{1, 32, 32 - 9, 2}, format::b_fs_yx_fsv32}, + // fy_not_aligned + {{1, 32 - 15, 32 - 1, 2}, format::b_fs_yx_fsv32}, + {{1, 32 - 1, 32 - 7, 2}, format::b_fs_yx_fsv32}, + {{1, 32 - 7, 32 - 9, 2}, format::b_fs_yx_fsv32}, + {{1, 32 - 9, 32 - 15, 2}, format::b_fs_yx_fsv32}, + + // b_fs_yx_fsv4 + // normal cases + {{1, 4, 4, 2}, format::b_fs_yx_fsv4}, + // 
f_not_aligned + {{1, 4 - 1, 4, 2}, format::b_fs_yx_fsv4}, + {{1, 4 - 3, 4, 2}, format::b_fs_yx_fsv4}, + // y_not_aligned + {{1, 4, 4 - 1, 2}, format::b_fs_yx_fsv4}, + {{1, 4, 4 - 3, 2}, format::b_fs_yx_fsv4}, + // fy_not_aligned + {{1, 4 - 3, 4 - 1, 2}, format::b_fs_yx_fsv4}, + {{1, 4 - 1, 4 - 3, 2}, format::b_fs_yx_fsv4}, + }),); + +TEST_P(permute_tile_fsv_4d, f16) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_4d, f32) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_4d, i8) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_4d, i32) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_4d, i64) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +class permute_tile_fsv_5d: public TiledPermuteTest {}; + +INSTANTIATE_TEST_CASE_P(, permute_tile_fsv_5d, + ::testing::ValuesIn(std::vector { + // b_fs_zyx_fsv16 + // normal cases + {{1, 16, 16, 3, 2}, format::b_fs_zyx_fsv16}, + // f_not_aligned + {{1, 16 - 7, 16, 2, 2}, format::b_fs_zyx_fsv16}, + {{1, 16 - 15, 16, 2, 2}, format::b_fs_zyx_fsv16}, + // z_not_aligned + {{1, 16, 16 - 1, 2, 2}, format::b_fs_zyx_fsv16}, + {{1, 16, 16 - 9, 2, 2}, format::b_fs_zyx_fsv16}, + // fz_not_aligned + {{1, 16 - 15, 16 - 1, 2, 2}, format::b_fs_zyx_fsv16}, + {{1, 16 - 1, 16 - 7, 2, 2}, format::b_fs_zyx_fsv16}, + {{1, 16 - 7, 16 - 9, 2, 2}, format::b_fs_zyx_fsv16}, + {{1, 16 - 9, 16 - 15, 2, 2}, format::b_fs_zyx_fsv16}, + + // b_fs_zyx_fsv32 + // normal cases + {{1, 32, 32, 3, 2}, format::b_fs_zyx_fsv32}, + // f_not_aligned + {{1, 32 - 7, 32, 2, 2}, format::b_fs_zyx_fsv32}, + {{1, 32 - 15, 32, 2, 2}, format::b_fs_zyx_fsv32}, + // z_not_aligned + {{1, 32, 32 - 1, 2, 2}, format::b_fs_zyx_fsv32}, + {{1, 32, 32 - 9, 2, 2}, format::b_fs_zyx_fsv32}, + // fz_not_aligned + {{1, 32 - 15, 32 - 1, 2, 2}, format::b_fs_zyx_fsv32}, + {{1, 32 - 1, 32 - 7, 2, 2}, 
format::b_fs_zyx_fsv32}, + {{1, 32 - 7, 32 - 9, 2, 2}, format::b_fs_zyx_fsv32}, + {{1, 32 - 9, 32 - 15, 2, 2}, format::b_fs_zyx_fsv32}, + }),); + +TEST_P(permute_tile_fsv_5d, f16) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_5d, f32) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_5d, i8) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_5d, i32) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +} + +TEST_P(permute_tile_fsv_5d, i64) { + auto p = GetParam(); + run_test(p.sizes, p.format_fsv); +}