[IE CLDNN] Weights reorders optimization (#1542)
* [IE CLDNN] Fix incorrect (g)_is_os_(z)yx_isv16_osv16 layout name * [IE CLDNN] Constant blobs copying improvements * [IE CLDNN] Weights reorders optimization
This commit is contained in:
parent
4c9fe89487
commit
fe6cb54bca
@ -771,9 +771,7 @@ cldnn::primitive_id Program::CreatePrimitiveFromBlob(cldnn::topology& topology,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < bufSize; i++) {
|
||||
buf[i] = data[i];
|
||||
}
|
||||
std::memcpy(&buf[0], &data[0], bufSize);
|
||||
}
|
||||
topology.add(cldnn::data(primID, mem));
|
||||
blobMemCache[data] = primID;
|
||||
|
16
inference-engine/thirdparty/clDNN/api/tensor.hpp
vendored
16
inference-engine/thirdparty/clDNN/api/tensor.hpp
vendored
@ -132,8 +132,8 @@ struct format {
|
||||
os_zyxi_osv16, ///< format used for weights for 3D convolution
|
||||
os_is_yx_isv16_osv16, ///< format used for blocked convolution
|
||||
os_is_zyx_isv16_osv16, ///< format used for weights for blocked 3D convolution
|
||||
is_os_zyx_osv16_isv16, ///< format used for weights for blocked 3D deconvolution
|
||||
is_os_yx_osv16_isv16, ///< format used for weights for blocked deconvolution
|
||||
is_os_zyx_isv16_osv16, ///< format used for weights for blocked 3D deconvolution
|
||||
is_os_yx_isv16_osv16, ///< format used for weights for blocked deconvolution
|
||||
os_is_yx_isv8_osv16_isv2, ///< format used for weights for blocked 2D convolution
|
||||
os_is_zyx_isv8_osv16_isv2, ///< format used for weights for blocked 3D convolution
|
||||
///< os - output feature maps slice, i - input feature maps,
|
||||
@ -187,10 +187,10 @@ struct format {
|
||||
gs_oiyx_gsv16, ///< format used for weights for 2D convolution
|
||||
gs_oizyx_gsv16, ///< format used for weights for 3D convolution
|
||||
gs_oiyx_gsv32, ///< format used for weights for 2D convolution
|
||||
g_is_os_zyx_osv16_isv16, ///< format used for grouped weights for blocked 3D deconvolution
|
||||
g_is_os_zyx_isv16_osv16, ///< format used for grouped weights for blocked 3D deconvolution
|
||||
g_os_is_yx_osv16_isv4,
|
||||
g_os_is_zyx_osv16_isv16,
|
||||
g_is_os_yx_osv16_isv16,
|
||||
g_is_os_yx_isv16_osv16,
|
||||
g_os_is_zyx_isv8_osv16_isv2,
|
||||
g_os_is_yx_isv8_osv16_isv2,
|
||||
g_os_is_zyx_isv16_osv16,
|
||||
@ -273,8 +273,8 @@ struct format {
|
||||
{ os_is_yx_osv32_isv4, { 1, 1, 2, 0, 0, "bfxy", "bfxy?", {{0, 32}, {1, 4}}}},
|
||||
{ os_is_yx_osv32_isv32p, { 1, 1, 1, 0, 0, "bfxy", "bfxy?", {}}},
|
||||
{ os_is_zyx_isv16_osv16, { 1, 1, 3, 0, 0, "bfzyx", "bfxyz", {{0, 16}, {1, 16}}}},
|
||||
{ is_os_zyx_osv16_isv16, { 1, 1, 3, 0, 0, "fbzyx", "bfxyz", {{0, 16}, {1, 16}}}},
|
||||
{ is_os_yx_osv16_isv16, { 1, 1, 2, 0, 0, "fbyx", "bfxyz", {{0, 16}, {1, 16}}}},
|
||||
{ is_os_zyx_isv16_osv16, { 1, 1, 3, 0, 0, "fbzyx", "bfxyz", {{1, 16}, {0, 16}}}},
|
||||
{ is_os_yx_isv16_osv16, { 1, 1, 2, 0, 0, "fbyx", "bfxyz", {{1, 16}, {0, 16}}}},
|
||||
{ os_is_osv32_isv32_swizzled_by_4, { 1, 1, 0, 0, 0, "bfxy", "bfxy?", {{0, 32}, {1, 32}}}},
|
||||
{ os_is_zyx_isv8_osv16_isv2, { 1, 1, 3, 0, 0, "bfzyx", "bfxyz", {{1, 8}, {0, 16}, {1, 2}}}},
|
||||
{ os_zyxi_osv16, { 1, 1, 3, 0, 0, "bzyxf", "bfxyz", {{0, 16}}}},
|
||||
@ -289,8 +289,8 @@ struct format {
|
||||
{ gs_oizyx_gsv16, { 1, 1, 3, 0, 1, "gbfzyx", "bfxyz???g", {{8, 16}}}},
|
||||
{ gs_oiyx_gsv32, { 1, 1, 2, 0, 1, "gbfyx", "bfxy????g", {{8, 32}}}},
|
||||
{ gyxio, { 1, 1, 2, 0, 1, "gyxfb", "bfxy????g", {}}},
|
||||
{ g_is_os_zyx_osv16_isv16, { 1, 1, 3, 0, 1, "gfbzyx", "bfxyz???g", {{0, 16}, {1, 16}}}},
|
||||
{ g_is_os_yx_osv16_isv16, { 1, 1, 2, 0, 1, "gfbyx", "bfxy????g", {{0, 16}, {1, 16}}}},
|
||||
{ g_is_os_zyx_isv16_osv16, { 1, 1, 3, 0, 1, "gfbzyx", "bfxyz???g", {{1, 16}, {0, 16}}}},
|
||||
{ g_is_os_yx_isv16_osv16, { 1, 1, 2, 0, 1, "gfbyx", "bfxy????g", {{1, 16}, {0, 16}}}},
|
||||
{ g_os_is_zyx_isv8_osv16_isv2, { 1, 1, 3, 0, 1, "gbfzyx", "bfxyz???g", {{1, 8}, {0, 16}, {1, 2}}}},
|
||||
{ g_os_is_yx_isv8_osv16_isv2, { 1, 1, 2, 0, 1, "gbfyx", "bfxy????g", {{1, 8}, {0, 16}, {1, 2}}}},
|
||||
{ g_os_is_zyx_isv16_osv16, { 1, 1, 3, 0, 1, "gbfzyx", "bfxyz???g", {{0, 16}, {1, 16}}}},
|
||||
|
@ -104,14 +104,13 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
|
||||
{ WeightsLayout::os_is_yx_osv32_isv32p, { 0, 1, -1, 2, 3, -1, -1, -1 } },
|
||||
{ WeightsLayout::os_is_zyx_isv16_osv16, { 0, 1, 2, 3, 4, -1, -1, -1 } },
|
||||
{ WeightsLayout::os_is_yx_isv16_osv16, { 0, 1, -1, 2, 3, -1, -1, -1 } },
|
||||
{ WeightsLayout::is_os_zyx_osv16_isv16, { 0, 1, 2, 4, 3, -1, -1, -1 } },
|
||||
{ WeightsLayout::is_os_yx_osv16_isv16, { 0, 1, -1, 3, 2, -1, -1, -1 } },
|
||||
{ WeightsLayout::is_os_zyx_isv16_osv16, { 0, 1, 2, 4, 3, -1, -1, -1 } },
|
||||
{ WeightsLayout::is_os_yx_isv16_osv16, { 0, 1, -1, 3, 2, -1, -1, -1 } },
|
||||
{ WeightsLayout::os_is_osv32_isv32_swizzled_by_4, { -1, -1, -1, 0, 1, -1, -1, -1 } },
|
||||
{ WeightsLayout::os_is_zyx_isv8_osv16_isv2, { 0, 1, 2, 3, 4, -1, -1, -1 } },
|
||||
{ WeightsLayout::os_is_yx_isv8_osv16_isv2, { 0, 1, -1, 2, 3, -1, -1, -1 } },
|
||||
{ WeightsLayout::os_zyxi_osv16, { 1, 2, 3, 0, 4, -1, -1, -1 } },
|
||||
{ WeightsLayout::os_i_yxs_osv4_yxsv4, { 0, 1, -1, 2, 3, -1, -1, -1 } },
|
||||
{ WeightsLayout::is_os_yx_osv16_isv16, { 0, 1, -1, 3, 2, -1, -1, -1 } },
|
||||
{ WeightsLayout::goiyx, { 0, 1, -1, 2, 3, -1, -1, 4 } },
|
||||
{ WeightsLayout::goizyx, { 0, 1, 2, 3, 4, -1, -1, 5 } },
|
||||
{ WeightsLayout::g_os_iyx_osv16, { 0, 1, -1, 2, 3, -1, -1, 4 } },
|
||||
@ -121,8 +120,8 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
|
||||
{ WeightsLayout::gs_oiyx_gsv32, { 0, 1, -1, 2, 3, -1, -1, 4 } },
|
||||
{ WeightsLayout::gyxio, { 2, 3, -1, 1, 0, -1, -1, 4 } },
|
||||
{ WeightsLayout::gi_yxs_os_yxsv2_osv16, { 1, 2, -1, 3, 0, -1, -1, 4 } },
|
||||
{ WeightsLayout::g_is_os_zyx_osv16_isv16, { 0, 1, 2, 4, 3, -1, -1, 5 } },
|
||||
{ WeightsLayout::g_is_os_yx_osv16_isv16, { 0, 1, -1, 3, 2, -1, -1, 4 } },
|
||||
{ WeightsLayout::g_is_os_zyx_isv16_osv16, { 0, 1, 2, 4, 3, -1, -1, 5 } },
|
||||
{ WeightsLayout::g_is_os_yx_isv16_osv16, { 0, 1, -1, 3, 2, -1, -1, 4 } },
|
||||
{ WeightsLayout::g_os_is_zyx_isv8_osv16_isv2, { 0, 1, 2, 3, 4, -1, -1, 5 } },
|
||||
{ WeightsLayout::g_os_is_yx_isv8_osv16_isv2, { 0, 1, -1, 2, 3, -1, -1, 4 } },
|
||||
{ WeightsLayout::g_os_is_zyx_isv16_osv16, { 0, 1, 2, 3, 4, -1, -1, 5 } },
|
||||
@ -542,12 +541,12 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
|
||||
newDims[3] = RoundUp(newDims[3], 16);
|
||||
newDims[4] = RoundUp(newDims[4], 16);
|
||||
break;
|
||||
case is_os_zyx_osv16_isv16:
|
||||
case is_os_zyx_isv16_osv16:
|
||||
assert(newDims.size() == 5);
|
||||
newDims[3] = RoundUp(newDims[3], 16);
|
||||
newDims[4] = RoundUp(newDims[4], 16);
|
||||
break;
|
||||
case is_os_yx_osv16_isv16:
|
||||
case is_os_yx_isv16_osv16:
|
||||
assert(newDims.size() == 4);
|
||||
newDims[2] = RoundUp(newDims[2], 16);
|
||||
newDims[3] = RoundUp(newDims[3], 16);
|
||||
@ -594,12 +593,12 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
|
||||
assert(newDims.size() == 5);
|
||||
newDims[0] = RoundUp(newDims[0], 16);
|
||||
break;
|
||||
case g_is_os_zyx_osv16_isv16:
|
||||
case g_is_os_zyx_isv16_osv16:
|
||||
assert(newDims.size() == 6);
|
||||
newDims[3] = RoundUp(newDims[3], 16);
|
||||
newDims[4] = RoundUp(newDims[4], 16);
|
||||
break;
|
||||
case g_is_os_yx_osv16_isv16:
|
||||
case g_is_os_yx_isv16_osv16:
|
||||
assert(newDims.size() == 5);
|
||||
newDims[2] = RoundUp(newDims[2], 16);
|
||||
newDims[3] = RoundUp(newDims[3], 16);
|
||||
|
@ -81,8 +81,8 @@ enum WeightsLayout {
|
||||
os_iyx_osv32__ai32,
|
||||
os_iyx_osv64,
|
||||
os_is_zyx_isv16_osv16,
|
||||
is_os_zyx_osv16_isv16,
|
||||
is_os_yx_osv16_isv16,
|
||||
is_os_zyx_isv16_osv16,
|
||||
is_os_yx_isv16_osv16,
|
||||
os_is_zyx_isv8_osv16_isv2,
|
||||
os_is_yx_isv8_osv16_isv2,
|
||||
os_is_yx_isv16_osv16,
|
||||
@ -138,8 +138,8 @@ enum WeightsLayout {
|
||||
gs_oiyx_gsv32,
|
||||
g_os_iyx_osv16_rotate_180,
|
||||
gi_yxs_os_yxsv2_osv16,
|
||||
g_is_os_zyx_osv16_isv16,
|
||||
g_is_os_yx_osv16_isv16,
|
||||
g_is_os_zyx_isv16_osv16,
|
||||
g_is_os_yx_isv16_osv16,
|
||||
g_os_is_zyx_isv8_osv16_isv2,
|
||||
g_os_is_yx_isv8_osv16_isv2,
|
||||
g_os_is_zyx_isv16_osv16,
|
||||
|
@ -41,9 +41,9 @@ protected:
|
||||
return WeightsLayout::os_zyxi_osv16;
|
||||
} else if (use_data_type == Datatype::F32 && params.inputs[0].Batch().v % 16 == 0) {
|
||||
if (is_3d_case)
|
||||
return (params.groups > 1) ? WeightsLayout::g_is_os_zyx_osv16_isv16 : WeightsLayout::is_os_zyx_osv16_isv16;
|
||||
return (params.groups > 1) ? WeightsLayout::g_is_os_zyx_isv16_osv16 : WeightsLayout::is_os_zyx_isv16_osv16;
|
||||
else
|
||||
return (params.groups > 1) ? WeightsLayout::g_is_os_yx_osv16_isv16 : WeightsLayout::is_os_yx_osv16_isv16;
|
||||
return (params.groups > 1) ? WeightsLayout::g_is_os_yx_isv16_osv16 : WeightsLayout::is_os_yx_isv16_osv16;
|
||||
} else if (use_data_type == Datatype::F16 && params.inputs[0].Batch().v % 32 == 0) {
|
||||
if (is_3d_case)
|
||||
return (params.groups > 1) ? WeightsLayout::g_os_is_zyx_isv8_osv16_isv2 : WeightsLayout::os_is_zyx_isv8_osv16_isv2;
|
||||
|
@ -32,9 +32,9 @@ public:
|
||||
protected:
|
||||
WeightsLayout GetPreferredWeightsLayout(const deconvolution_params& p) const override {
|
||||
if (p.output.Dimentions() == 4)
|
||||
return WeightsLayout::is_os_yx_osv16_isv16;
|
||||
return WeightsLayout::is_os_yx_isv16_osv16;
|
||||
else
|
||||
return WeightsLayout::is_os_zyx_osv16_isv16;
|
||||
return WeightsLayout::is_os_zyx_isv16_osv16;
|
||||
}
|
||||
bool Validate(const Params& p, const optional_params& o) const override;
|
||||
CommonDispatchData SetDefault(const deconvolution_params& arg) const override;
|
||||
|
@ -34,8 +34,8 @@ inline uint32_t SubGroupSize(WeightsLayout l) {
|
||||
case WeightsLayout::os_is_yx_osv32_isv32p:
|
||||
case WeightsLayout::os_is_yx_isv16_osv16:
|
||||
case WeightsLayout::os_is_zyx_isv16_osv16:
|
||||
case WeightsLayout::is_os_zyx_osv16_isv16:
|
||||
case WeightsLayout::is_os_yx_osv16_isv16:
|
||||
case WeightsLayout::is_os_zyx_isv16_osv16:
|
||||
case WeightsLayout::is_os_yx_isv16_osv16:
|
||||
case WeightsLayout::os_is_yx_isv8_osv16_isv2:
|
||||
case WeightsLayout::os_is_zyx_isv8_osv16_isv2:
|
||||
case WeightsLayout::os_zyxi_osv16:
|
||||
@ -46,8 +46,8 @@ inline uint32_t SubGroupSize(WeightsLayout l) {
|
||||
case WeightsLayout::gs_oiyx_gsv32:
|
||||
case WeightsLayout::g_os_iyx_osv16_rotate_180:
|
||||
case WeightsLayout::gi_yxs_os_yxsv2_osv16:
|
||||
case WeightsLayout::g_is_os_zyx_osv16_isv16:
|
||||
case WeightsLayout::g_is_os_yx_osv16_isv16:
|
||||
case WeightsLayout::g_is_os_zyx_isv16_osv16:
|
||||
case WeightsLayout::g_is_os_yx_isv16_osv16:
|
||||
case WeightsLayout::g_os_is_zyx_isv8_osv16_isv2:
|
||||
case WeightsLayout::g_os_is_yx_isv8_osv16_isv2:
|
||||
case WeightsLayout::g_os_is_zyx_isv16_osv16:
|
||||
@ -217,6 +217,8 @@ ReorderKernelBase::DispatchData ReorderKernelBase::SetDefault(const reorder_para
|
||||
|
||||
KernelsData ReorderKernelBase::GetCommonKernelsData(const reorder_weights_params& params, const optional_params& options, float estimated_time) const {
|
||||
assert(params.GetType() == KernelType::REORDER);
|
||||
if (!Validate(params, options))
|
||||
return {};
|
||||
|
||||
KernelData kd = KernelData::Default<reorder_weights_params>(params);
|
||||
reorder_weights_params& newParams = *static_cast<reorder_weights_params*>(kd.params.get());
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -106,6 +106,7 @@ protected:
|
||||
virtual JitConstants GetJitConstants(const reorder_params& params) const;
|
||||
virtual DispatchData SetDefault(const reorder_weights_params& params) const;
|
||||
virtual DispatchData SetDefault(const reorder_params& params) const;
|
||||
virtual bool Validate(const Params&, const optional_params&) const { return true; };
|
||||
KernelsData GetCommonKernelsData(const reorder_weights_params& params,
|
||||
const optional_params&,
|
||||
float estimated_time) const;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016-2019 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -20,6 +20,7 @@
|
||||
#include "reorder_weights_image_fyx_b_kernel.h"
|
||||
#include "reorder_weights_image_winograd_6x3_kernel.h"
|
||||
#include "reorder_weights_binary_kernel.h"
|
||||
#include "reorder_weights_opt.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
@ -30,6 +31,7 @@ ReorderWeightsKernelSelctor::ReorderWeightsKernelSelctor() {
|
||||
Attach<ReorderWeightsImage_fyx_b_Kernel>();
|
||||
Attach<ReorderWeightsImageWinograd6x3Kernel>();
|
||||
Attach<ReorderWeightsBinaryKernel>();
|
||||
Attach<ReorderWeightsOpt>();
|
||||
}
|
||||
|
||||
KernelsData ReorderWeightsKernelSelctor::GetBestKernels(const Params& params, const optional_params& options) const {
|
||||
|
@ -0,0 +1,193 @@
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "reorder_weights_opt.h"
|
||||
#include "kernel_selector_utils.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
static const std::vector<size_t> preferred_sizes = {8, 4, 2, 1};
|
||||
|
||||
ParamsKey ReorderWeightsOpt::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputWeightsType(WeightsType::INT8);
|
||||
k.EnableInputWeightsType(WeightsType::F16);
|
||||
k.EnableInputWeightsType(WeightsType::F32);
|
||||
k.EnableOutputWeightsType(WeightsType::INT8);
|
||||
k.EnableOutputWeightsType(WeightsType::F16);
|
||||
k.EnableOutputWeightsType(WeightsType::F32);
|
||||
k.EnableInputWeightsLayout(WeightsLayout::oiyx);
|
||||
k.EnableInputWeightsLayout(WeightsLayout::oizyx);
|
||||
k.EnableInputWeightsLayout(WeightsLayout::goiyx);
|
||||
k.EnableInputWeightsLayout(WeightsLayout::goizyx);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_is_yx_isv16_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_is_zyx_isv16_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_is_yx_isv16_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_is_zyx_isv16_osv16);
|
||||
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_iyx_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_iyx_osv32);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_iyx_osv32__ai32);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_iyx_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_iyx_osv32);
|
||||
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::is_os_yx_isv16_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::is_os_zyx_isv16_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_is_os_yx_isv16_osv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_is_os_zyx_isv16_osv16);
|
||||
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_is_yx_osv16_isv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_is_zyx_osv32_isv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::os_is_zyx_osv64_isv16);
|
||||
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_zyx_is_osv16_isv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_zyx_is_osv16_isv32);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_zyx_is_osv32_isv16);
|
||||
k.EnableOutputWeightsLayout(WeightsLayout::g_os_zyx_is_osv32_isv32);
|
||||
|
||||
k.EnableDifferentTypes();
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
return k;
|
||||
}
|
||||
|
||||
static inline std::pair<size_t, size_t> GetSliceSizes(WeightsLayout l) {
|
||||
if (l == WeightsLayout::os_is_yx_isv16_osv16 || l == WeightsLayout::os_is_zyx_isv16_osv16 ||
|
||||
l == WeightsLayout::g_os_is_yx_isv16_osv16 || l == WeightsLayout::g_os_is_zyx_isv16_osv16 ||
|
||||
l == WeightsLayout::is_os_zyx_isv16_osv16 || l == WeightsLayout::is_os_yx_isv16_osv16 ||
|
||||
l == WeightsLayout::os_is_yx_osv16_isv16 || l == WeightsLayout::g_os_zyx_is_osv16_isv16 ||
|
||||
l == WeightsLayout::g_is_os_yx_isv16_osv16 || l == WeightsLayout::g_is_os_zyx_isv16_osv16)
|
||||
return {16, 16};
|
||||
else if (l == WeightsLayout::os_iyx_osv16 || l == WeightsLayout::g_os_iyx_osv16)
|
||||
return {1, 16};
|
||||
else if (l == WeightsLayout::os_iyx_osv32 || l == WeightsLayout::g_os_iyx_osv32 || l == WeightsLayout::os_iyx_osv32__ai32)
|
||||
return {1, 32};
|
||||
else if (l == WeightsLayout::os_is_zyx_osv32_isv16 || l == WeightsLayout::g_os_zyx_is_osv32_isv16)
|
||||
return {16, 32};
|
||||
else if (l == WeightsLayout::os_is_zyx_osv64_isv16)
|
||||
return {16, 64};
|
||||
else if (l == WeightsLayout::g_os_zyx_is_osv16_isv32)
|
||||
return {32, 16};
|
||||
else if (l == WeightsLayout::g_os_zyx_is_osv32_isv32)
|
||||
return {32, 32};
|
||||
else
|
||||
return {1, 1};
|
||||
}
|
||||
|
||||
static inline bool IsOsvFirst(WeightsLayout l) {
|
||||
if (l == WeightsLayout::os_is_yx_isv16_osv16 || l == WeightsLayout::os_is_zyx_isv16_osv16 ||
|
||||
l == WeightsLayout::g_os_is_yx_isv16_osv16 || l == WeightsLayout::g_os_is_zyx_isv16_osv16 ||
|
||||
l == WeightsLayout::os_iyx_osv16 || l == WeightsLayout::g_os_iyx_osv16||
|
||||
l == WeightsLayout::os_iyx_osv32 || l == WeightsLayout::g_os_iyx_osv32 ||
|
||||
l == WeightsLayout::os_iyx_osv32__ai32 || l == WeightsLayout::is_os_yx_isv16_osv16 ||
|
||||
l == WeightsLayout::is_os_zyx_isv16_osv16 || l == WeightsLayout::g_is_os_yx_isv16_osv16 ||
|
||||
l == WeightsLayout::g_is_os_zyx_isv16_osv16)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline size_t GetOptimalSize(size_t val, std::vector<size_t> optimal_sizes) {
|
||||
for (auto& s : optimal_sizes)
|
||||
if (val % s == 0)
|
||||
return s;
|
||||
return 1;
|
||||
}
|
||||
|
||||
ReorderWeightsOpt::DispatchData ReorderWeightsOpt::SetDefault(
|
||||
const reorder_weights_params& params) const {
|
||||
DispatchData kd;
|
||||
|
||||
const auto& output = params.output;
|
||||
const auto output_layout = output.GetLayout();
|
||||
const auto subgroup_size = 16;
|
||||
const auto ifm_block_supported = (output_layout != WeightsLayout::os_iyx_osv16 &&
|
||||
output_layout != WeightsLayout::os_iyx_osv32 &&
|
||||
output_layout != WeightsLayout::g_os_iyx_osv16 &&
|
||||
output_layout != WeightsLayout::g_os_iyx_osv32 &&
|
||||
output_layout != WeightsLayout::os_iyx_osv32__ai32);
|
||||
|
||||
const auto osv_first = IsOsvFirst(output_layout);
|
||||
const auto ofm_block = (osv_first) ? subgroup_size : GetOptimalSize(output.OFM().v, preferred_sizes);
|
||||
const auto ifm_block = (osv_first) ? ifm_block_supported ? GetOptimalSize(output.IFM().v, preferred_sizes) : 1
|
||||
: subgroup_size;
|
||||
|
||||
std::vector<size_t> global;
|
||||
if (osv_first) {
|
||||
global = {output.G().v * (output.IFM().v / ifm_block), output.Z().v * output.Y().v * output.X().v, Align(output.OFM().v, ofm_block)};
|
||||
} else {
|
||||
global = {output.G().v * (output.OFM().v / ofm_block), output.Z().v * output.Y().v * output.X().v, Align(output.IFM().v, ifm_block)};
|
||||
}
|
||||
|
||||
kd.gws0 = global[0];
|
||||
kd.gws1 = global[1];
|
||||
kd.gws2 = global[2];
|
||||
|
||||
kd.lws0 = 1;
|
||||
kd.lws1 = 1;
|
||||
kd.lws2 = 16;
|
||||
|
||||
return kd;
|
||||
}
|
||||
|
||||
JitConstants ReorderWeightsOpt::GetJitConstants(const reorder_weights_params& params) const {
|
||||
auto jit = ReorderKernelBase::GetJitConstants(params);
|
||||
const auto& output = params.output;
|
||||
const auto subgroup_size = 16;
|
||||
const auto ifm_block_supported = (output.GetLayout() != WeightsLayout::os_iyx_osv16 &&
|
||||
output.GetLayout() != WeightsLayout::os_iyx_osv32 &&
|
||||
output.GetLayout() != WeightsLayout::g_os_iyx_osv16 &&
|
||||
output.GetLayout() != WeightsLayout::g_os_iyx_osv32 &&
|
||||
output.GetLayout() != WeightsLayout::os_iyx_osv32__ai32);
|
||||
|
||||
const auto slice_sizes = GetSliceSizes(output.GetLayout());
|
||||
const auto osv_first = IsOsvFirst(output.GetLayout());
|
||||
const auto leftovers = (osv_first) ? output.OFM().v % subgroup_size : output.IFM().v % subgroup_size;
|
||||
const auto ofm_block = (osv_first) ? subgroup_size : GetOptimalSize(output.OFM().v, preferred_sizes);
|
||||
const auto ifm_block = (osv_first) ? ifm_block_supported ? GetOptimalSize(output.IFM().v, preferred_sizes) : 1
|
||||
: subgroup_size;
|
||||
|
||||
jit.AddConstant(MakeJitConstant("IFM_SIZE", slice_sizes.first));
|
||||
jit.AddConstant(MakeJitConstant("OFM_SIZE", slice_sizes.second));
|
||||
jit.AddConstant(MakeJitConstant("OSV_FIRST", osv_first));
|
||||
jit.AddConstant(MakeJitConstant("IFM_BLOCK_SIZE", ifm_block));
|
||||
jit.AddConstant(MakeJitConstant("OFM_BLOCK_SIZE", ofm_block));
|
||||
|
||||
if (leftovers)
|
||||
jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS", leftovers));
|
||||
|
||||
return jit;
|
||||
}
|
||||
|
||||
bool ReorderWeightsOpt::Validate(const Params& params, const optional_params& /*options*/) const {
|
||||
const auto& p = static_cast<const reorder_weights_params&>(params);
|
||||
const auto& input = p.input;
|
||||
const auto& output = p.output;
|
||||
|
||||
if (input.GroupedLayout() != output.GroupedLayout()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input.GetDims().size() != output.GetDims().size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
KernelsData ReorderWeightsOpt::GetKernelsData(const Params& params, const optional_params& options) const {
|
||||
const reorder_weights_params& orgParams = static_cast<const reorder_weights_params&>(params);
|
||||
return GetCommonKernelsData(orgParams, options, FORCE_PRIORITY_5);
|
||||
}
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "reorder_kernel_base.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
class ReorderWeightsOpt : public ReorderKernelBase {
|
||||
public:
|
||||
ReorderWeightsOpt() : ReorderKernelBase("reorder_weights_opt") {}
|
||||
virtual ~ReorderWeightsOpt() {}
|
||||
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
DispatchData SetDefault(const reorder_weights_params& arg) const override;
|
||||
|
||||
protected:
|
||||
virtual bool Validate(const Params& params, const optional_params& options) const override;
|
||||
virtual JitConstants GetJitConstants(const reorder_weights_params& params) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
@ -280,7 +280,7 @@ inline uint FUNC(get_b_fs_yx_fsv_index_safe)(uint b, uint f, uint y, uint x,
|
||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
||||
)
|
||||
|
||||
#define GET_FILTER_IS_OS_ZYX_OSV16_ISV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \
|
||||
#define GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
((o) % (sub_group_size)) + \
|
||||
(sub_group_size)*( \
|
||||
@ -292,7 +292,7 @@ inline uint FUNC(get_b_fs_yx_fsv_index_safe)(uint b, uint f, uint y, uint x,
|
||||
((i) / (sub_group_size))*CAT(prefix, _IFM_PITCH) \
|
||||
)
|
||||
|
||||
#define GET_FILTER_IS_OS_YX_OSV16_ISV16_INDEX(prefix, o, i, y, x, sub_group_size) \
|
||||
#define GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(prefix, o, i, y, x, sub_group_size) \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
((o) % (sub_group_size)) + \
|
||||
(sub_group_size)*( \
|
||||
@ -1477,7 +1477,7 @@ inline uint FUNC(get_os_i_yxs_osv_yxsv4_index)(uint o, uint i, uint y, uint x, u
|
||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
||||
)
|
||||
|
||||
#define GET_FILTER_G_IS_OS_ZYX_OSV16_ISV16_INDEX(prefix, g, o, i, z, y, x, sub_group_size) \
|
||||
#define GET_FILTER_G_IS_OS_ZYX_ISV16_OSV16_INDEX(prefix, g, o, i, z, y, x, sub_group_size) \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
(g)*CAT(prefix, _GROUPS_PITCH) + \
|
||||
((o) % (sub_group_size)) + \
|
||||
@ -1490,7 +1490,7 @@ inline uint FUNC(get_os_i_yxs_osv_yxsv4_index)(uint o, uint i, uint y, uint x, u
|
||||
((i) / (sub_group_size))*CAT(prefix, _IFM_PITCH) \
|
||||
)
|
||||
|
||||
#define GET_FILTER_G_IS_OS_YX_OSV16_ISV16_INDEX(prefix, g, o, i, y, x, sub_group_size) \
|
||||
#define GET_FILTER_G_IS_OS_YX_ISV16_OSV16_INDEX(prefix, g, o, i, y, x, sub_group_size) \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
(g)*CAT(prefix, _GROUPS_PITCH) + \
|
||||
((o) % (sub_group_size)) + \
|
||||
|
@ -62,10 +62,10 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
|
||||
return GET_FILTER_OIYX_O16(INPUT0, o, i, y, x);
|
||||
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV16_OSV16
|
||||
return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined INPUT0_LAYOUT_IS_OS_ZYX_OSV16_ISV16
|
||||
return GET_FILTER_IS_OS_ZYX_OSV16_ISV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined INPUT0_LAYOUT_IS_OS_YX_OSV16_ISV16
|
||||
return GET_FILTER_IS_OS_YX_OSV16_ISV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined INPUT0_LAYOUT_IS_OS_ZYX_ISV16_OSV16
|
||||
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV16
|
||||
return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined INPUT0_LAYOUT_OS_IS_OSV32_ISV32_SWIZZLED_BY_4
|
||||
return GET_FILTER_OS_IS_OSV32_ISV32_SWIZZLED_BY_4_INDEX(INPUT0, o, i, y, x);
|
||||
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV8_OSV16_ISV2
|
||||
@ -179,10 +179,10 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
|
||||
return GET_FILTER_OIYX_O16(OUTPUT, o, i, y, x);
|
||||
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_ISV16_OSV16
|
||||
return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_IS_OS_ZYX_OSV16_ISV16
|
||||
return GET_FILTER_IS_OS_ZYX_OSV16_ISV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_IS_OS_YX_OSV16_ISV16
|
||||
return GET_FILTER_IS_OS_YX_OSV16_ISV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_IS_OS_ZYX_ISV16_OSV16
|
||||
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV16
|
||||
return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_OS_IS_OSV32_ISV32_SWIZZLED_BY_4
|
||||
return GET_FILTER_OS_IS_OSV32_ISV32_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, y, x);
|
||||
#elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV8_OSV16_ISV2
|
||||
@ -210,10 +210,10 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
|
||||
return GET_FILTER_GOIYX(OUTPUT, g, o, i, y, x);
|
||||
#elif defined OUTPUT_LAYOUT_GI_YXS_OS_YXSV2_OSV16
|
||||
return GET_FILTER_GI_YXS_OS_YXSV2_OSV_INDEX(OUTPUT, g, o, i, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_G_IS_OS_ZYX_OSV16_ISV16
|
||||
return GET_FILTER_G_IS_OS_ZYX_OSV16_ISV16_INDEX(OUTPUT, g, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_G_IS_OS_YX_OSV16_ISV16
|
||||
return GET_FILTER_G_IS_OS_YX_OSV16_ISV16_INDEX(OUTPUT, g, o, i, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_G_IS_OS_ZYX_ISV16_OSV16
|
||||
return GET_FILTER_G_IS_OS_ZYX_ISV16_OSV16_INDEX(OUTPUT, g, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_G_IS_OS_YX_ISV16_OSV16
|
||||
return GET_FILTER_G_IS_OS_YX_ISV16_OSV16_INDEX(OUTPUT, g, o, i, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_G_OS_IS_ZYX_ISV16_OSV16
|
||||
return GET_FILTER_G_OS_IS_ZYX_ISV16_OSV16_INDEX(OUTPUT, g, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||
#elif defined OUTPUT_LAYOUT_G_OS_IS_YX_ISV8_OSV16_ISV2
|
||||
|
143
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/reorder_weights_opt.cl
vendored
Normal file
143
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/reorder_weights_opt.cl
vendored
Normal file
@ -0,0 +1,143 @@
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "include/common.cl"
|
||||
#include "include/data_types.cl"
|
||||
|
||||
INIT_INPUT0_INDEX_FUNC_HERE
|
||||
INIT_OUTPUT_INDEX_FUNC_HERE
|
||||
|
||||
#if OUTPUT_GROUPED
|
||||
# if OUTPUT_DIMS == 5
|
||||
# define IDX_ORDER g, o, i, y, x
|
||||
# define BLOCK_IDX_ORDER g, o_blocked, i_blocked, y, x
|
||||
# elif OUTPUT_DIMS == 6
|
||||
# define IDX_ORDER g, o, i, z, y, x
|
||||
# define BLOCK_IDX_ORDER g, o_blocked, i_blocked, z, y, x
|
||||
# endif
|
||||
#else
|
||||
# if OUTPUT_DIMS == 4
|
||||
# define IDX_ORDER o, i, y, x
|
||||
# define BLOCK_IDX_ORDER o_blocked, i_blocked, y, x
|
||||
# elif OUTPUT_DIMS == 5
|
||||
# define IDX_ORDER o, i, z, y, x
|
||||
# define BLOCK_IDX_ORDER o_blocked, i_blocked, z, y, x
|
||||
# endif
|
||||
#endif
|
||||
#define GET_INDEX(PREFIX, ORDER) CAT(PREFIX, _GET_INDEX)(ORDER)
|
||||
|
||||
#if OSV_FIRST
|
||||
# define FIRST_BLOCK_SIZE OFM_BLOCK_SIZE
|
||||
# define SECOND_BLOCK_SIZE IFM_BLOCK_SIZE
|
||||
# define PITCH INPUT0_IFM_PITCH
|
||||
# define SECOND_SIZE IFM_SIZE
|
||||
#else
|
||||
# define FIRST_BLOCK_SIZE IFM_BLOCK_SIZE
|
||||
# define SECOND_BLOCK_SIZE OFM_BLOCK_SIZE
|
||||
# define PITCH INPUT0_OFM_PITCH
|
||||
# define SECOND_SIZE OFM_SIZE
|
||||
#endif
|
||||
|
||||
#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, SECOND_BLOCK_SIZE)
|
||||
#define OUTPUT_BLOCK_WRITE(ptr, offset, val) BLOCK_WRITEN(OUTPUT_TYPE, SECOND_BLOCK_SIZE, ptr, offset, val)
|
||||
|
||||
__attribute__((intel_reqd_sub_group_size(FIRST_BLOCK_SIZE)))
|
||||
__attribute__((reqd_work_group_size(1, 1, FIRST_BLOCK_SIZE)))
|
||||
KERNEL(reorder_weights_blocked_opt)(const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output)
|
||||
{
|
||||
const int lid = get_sub_group_local_id();
|
||||
const int g_io = get_global_id(0);
|
||||
#if OSV_FIRST
|
||||
#if OUTPUT_GROUPED
|
||||
const int i = (g_io % (OUTPUT_IFM_NUM / SECOND_BLOCK_SIZE)) * SECOND_BLOCK_SIZE;
|
||||
const int g = (g_io / (OUTPUT_IFM_NUM / SECOND_BLOCK_SIZE));
|
||||
#else
|
||||
const int i = g_io * SECOND_BLOCK_SIZE;
|
||||
#endif // OUTPUT_GROUPED
|
||||
const int o_blocked = (int)get_group_id(2) * FIRST_BLOCK_SIZE;
|
||||
const int o = o_blocked + lid;
|
||||
const int i_blocked = i;
|
||||
#else // OSV_FIRST
|
||||
#if OUTPUT_GROUPED
|
||||
const int o = (g_io % (OUTPUT_OFM_NUM / SECOND_BLOCK_SIZE)) * SECOND_BLOCK_SIZE;
|
||||
const int g = (g_io / (OUTPUT_OFM_NUM / SECOND_BLOCK_SIZE));
|
||||
#else
|
||||
const int o = g_io * SECOND_BLOCK_SIZE;
|
||||
#endif // OUTPUT_GROUPED
|
||||
const int i_blocked = (int)get_group_id(2) * FIRST_BLOCK_SIZE;
|
||||
const int i = i_blocked + lid;
|
||||
const int o_blocked = o;
|
||||
#endif // OSV_FIRST
|
||||
|
||||
const int zyx = get_global_id(1);
|
||||
const int x = zyx % OUTPUT_SIZE_X;
|
||||
#if (OUTPUT_DIMS - OUTPUT_GROUPED) == 5
|
||||
const int y = zyx / OUTPUT_SIZE_X % OUTPUT_SIZE_Y;
|
||||
const int z = zyx / OUTPUT_SIZE_X / OUTPUT_SIZE_Y;
|
||||
#else
|
||||
const int y = zyx / OUTPUT_SIZE_X;
|
||||
#endif // (OUTPUT_DIMS - OUTPUT_GROUPED) == 5
|
||||
|
||||
int input_idx = GET_INDEX(INPUT0, IDX_ORDER);
|
||||
const int output_idx = GET_INDEX(OUTPUT, BLOCK_IDX_ORDER);
|
||||
|
||||
#if SECOND_BLOCK_SIZE == 1
|
||||
const OUTPUT_TYPE val = TO_OUTPUT_TYPE(input[input_idx]);
|
||||
#else
|
||||
OUTPUT_VEC_TYPE val = 0;
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for (int b = 0; b < SECOND_BLOCK_SIZE; b++) {
|
||||
val[b] = TO_OUTPUT_TYPE(input[input_idx]);
|
||||
input_idx += PITCH;
|
||||
}
|
||||
#endif // SECOND_BLOCK_SIZE == 1
|
||||
#if OUTPUT_LEFTOVERS
|
||||
#if OSV_FIRST
|
||||
const bool doWrite = o < OUTPUT_OFM_NUM;
|
||||
if (o_blocked >= OUTPUT_OFM_NUM - FIRST_BLOCK_SIZE) {
|
||||
#else
|
||||
const bool doWrite = i < OUTPUT_IFM_NUM;
|
||||
if (i_blocked >= OUTPUT_IFM_NUM - FIRST_BLOCK_SIZE) {
|
||||
#endif // OSV_FIRST
|
||||
#if SECOND_BLOCK_SIZE > 1
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for (int b = 0; b < SECOND_BLOCK_SIZE; b++)
|
||||
if (doWrite)
|
||||
output[output_idx + b * SECOND_SIZE + lid] = val[b];
|
||||
#else
|
||||
if (doWrite)
|
||||
output[output_idx + lid] = val;
|
||||
#endif // SECOND_BLOCK_SIZE > 1
|
||||
}
|
||||
else
|
||||
#endif // OUTPUT_LEFTOVERS
|
||||
{
|
||||
OUTPUT_BLOCK_WRITE(output, output_idx, val);
|
||||
}
|
||||
}
|
||||
|
||||
#undef OUTPUT_VEC_TYPE
|
||||
#undef OSV_FIRST
|
||||
#undef FIRST_BLOCK_SIZE
|
||||
#undef SECOND_BLOCK_SIZE
|
||||
#undef PITCH
|
||||
#undef SECOND_SIZE
|
||||
#undef OUTPUT_BLOCK_WRITE8
|
||||
#undef OUTPUT_BLOCK_WRITE4
|
||||
#undef OUTPUT_BLOCK_WRITE2
|
||||
#undef OUTPUT_BLOCK_WRITE1
|
||||
#undef OUTPUT_BLOCK_WRITE
|
||||
#undef GET_INDEX
|
||||
#undef BLOCK_IDX_ORDER
|
||||
#undef IDX_ORDER
|
@ -471,6 +471,220 @@ std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const Data
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class WeightTensorJitConstant : public TensorBaseTJitConstant<WeightsType, WeightsLayout> {
|
||||
const WeightsTensor _tensor;
|
||||
struct WeightIndexFuncDesc {
|
||||
std::string macroName;
|
||||
std::string macroBody;
|
||||
std::string calcFunction;
|
||||
|
||||
WeightIndexFuncDesc() = default;
|
||||
WeightIndexFuncDesc(const WeightsLayout l) {
|
||||
using args = std::initializer_list<std::string>;
|
||||
if (l == WeightsLayout::oiyx || l == WeightsLayout::oizyx || l == WeightsLayout::goiyx ||
|
||||
l == WeightsLayout::goizyx) {
|
||||
args macroNameArgs = {"prefix", "g", "o", "i", "z", "y", "x"};
|
||||
const auto name = toString(l);
|
||||
this->calcFunction = FuncBody(name);
|
||||
this->macroName = MacroName(name, macroNameArgs);
|
||||
this->macroBody = R"V0G0N( \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
(x)*CAT(prefix, _X_PITCH) + \
|
||||
(y)*CAT(prefix, _Y_PITCH) + \
|
||||
(z)*CAT(prefix, _Z_PITCH) + \
|
||||
(i)*CAT(prefix, _IFM_PITCH) + \
|
||||
(o)*CAT(prefix, _OFM_PITCH) + \
|
||||
(g)*CAT(prefix, _GROUPS_PITCH)
|
||||
)V0G0N";
|
||||
} else if (l == WeightsLayout::os_is_yx_isv16_osv16 || l == WeightsLayout::os_is_zyx_isv16_osv16 ||
|
||||
l == WeightsLayout::g_os_is_yx_isv16_osv16 || l == WeightsLayout::g_os_is_zyx_isv16_osv16) {
|
||||
args macroNameArgs = {"prefix", "g", "o", "i", "z", "y", "x", "sub_group_size"};
|
||||
const auto name = toString(l);
|
||||
this->calcFunction = FuncBody(name);
|
||||
this->macroName = MacroName(name, macroNameArgs);
|
||||
this->macroBody = R"V0G0N( \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
(g)*CAT(prefix, _GROUPS_PITCH) + \
|
||||
((o) % (sub_group_size)) + \
|
||||
(sub_group_size)*( \
|
||||
(x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \
|
||||
(y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \
|
||||
(z)*(sub_group_size)*CAT(prefix, _Z_PITCH) + \
|
||||
((i) % (sub_group_size)) + \
|
||||
((i) / (sub_group_size))*(sub_group_size)*CAT(prefix, _IFM_PITCH) + \
|
||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
||||
)
|
||||
)V0G0N";
|
||||
} else if (l == WeightsLayout::os_iyx_osv16 || l == WeightsLayout::os_iyx_osv32 ||
|
||||
l == WeightsLayout::os_iyx_osv32__ai32 || l == WeightsLayout::g_os_iyx_osv16 ||
|
||||
l == WeightsLayout::g_os_iyx_osv32) {
|
||||
args macroNameArgs = {"prefix", "g", "o", "i", "y", "x", "sub_group_size"};
|
||||
const auto name = toString(l);
|
||||
this->calcFunction = FuncBody(name);
|
||||
this->macroName = MacroName(name, macroNameArgs);
|
||||
this->macroBody = R"V0G0N( \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
(g * CAT(prefix, _GROUPS_PITCH)) + \
|
||||
((o) % (sub_group_size)) + \
|
||||
(sub_group_size)*( \
|
||||
(x)*CAT(prefix, _X_PITCH) + \
|
||||
(y)*CAT(prefix, _Y_PITCH) + \
|
||||
(i)*CAT(prefix, _IFM_PITCH) + \
|
||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
||||
)
|
||||
)V0G0N";
|
||||
} else if (l == WeightsLayout::is_os_yx_isv16_osv16 || l == WeightsLayout::is_os_zyx_isv16_osv16 ||
|
||||
l == WeightsLayout::g_is_os_yx_isv16_osv16 || l == WeightsLayout::g_is_os_zyx_isv16_osv16) {
|
||||
args macroNameArgs = {"prefix", "g", "o", "i", "z", "y", "x", "sub_group_size"};
|
||||
const auto name = toString(l);
|
||||
this->calcFunction = FuncBody(name);
|
||||
this->macroName = MacroName(name, macroNameArgs);
|
||||
this->macroBody = R"V0G0N( \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
(g)*CAT(prefix, _GROUPS_PITCH) + \
|
||||
((o) % (sub_group_size)) + \
|
||||
(sub_group_size)*( \
|
||||
(x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \
|
||||
(y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \
|
||||
(z)*(sub_group_size)*CAT(prefix, _Z_PITCH) + \
|
||||
((i) % (sub_group_size)) + \
|
||||
((o) / (sub_group_size))*(sub_group_size)*CAT(prefix, _OFM_PITCH) + \
|
||||
((i) / (sub_group_size))*CAT(prefix, _IFM_PITCH) \
|
||||
)
|
||||
)V0G0N";
|
||||
} else if (l == WeightsLayout::os_is_yx_osv16_isv16 || l == WeightsLayout::os_is_zyx_osv32_isv16 ||
|
||||
l == WeightsLayout::os_is_zyx_osv64_isv16) {
|
||||
args macroNameArgs = {"prefix", "o", "i", "z", "y", "x"};
|
||||
args funcArgs = {"o", "i", "z", "y", "x", "x_size", "y_size", "z_size", "i_size", "o_size", "osv_size", "isv_size"};
|
||||
const auto name = toString(l);
|
||||
const auto body = R"V0G0N( \
|
||||
const uint isv = i % isv_size; \
|
||||
const uint osv = o % osv_size; \
|
||||
const uint is = i / isv_size; \
|
||||
const uint os = o / osv_size; \
|
||||
const uint x_pitch = osv_size * isv_size; \
|
||||
const uint y_pitch = x_pitch * x_size; \
|
||||
const uint z_pitch = y_pitch * y_size; \
|
||||
const uint is_pitch = z_pitch * z_size; \
|
||||
const uint os_pitch = is_pitch * ((i_size + isv_size - 1) / isv_size); \
|
||||
const uint output_offset = \
|
||||
isv + \
|
||||
osv * isv_size + \
|
||||
x * x_pitch + \
|
||||
y * y_pitch + \
|
||||
z * z_pitch + \
|
||||
is * is_pitch + \
|
||||
os * os_pitch; \
|
||||
return output_offset; \
|
||||
)V0G0N";
|
||||
this->macroName = MacroName(name, macroNameArgs);
|
||||
this->calcFunction = FuncBody(name, funcArgs, body);
|
||||
if (l == WeightsLayout::os_is_yx_osv16_isv16)
|
||||
this->macroBody = FuncCall(name, {"o", "i", "0", "y", "x", Cat("_SIZE_X"), Cat("_SIZE_Y"), "1", Cat("_IFM_NUM"), Cat("_OFM_NUM"), "16", "16"});
|
||||
else if (l == WeightsLayout::os_is_zyx_osv32_isv16)
|
||||
this->macroBody = FuncCall(name, {"o", "i", "z", "y", "x", Cat("_SIZE_X"), Cat("_SIZE_Y"), Cat("_SIZE_Z"), Cat("_IFM_NUM"), Cat("_OFM_NUM"), "32", "16"});
|
||||
else if (l == WeightsLayout::os_is_zyx_osv64_isv16)
|
||||
this->macroBody = FuncCall(name, {"o", "i", "z", "y", "x", Cat("_SIZE_X"), Cat("_SIZE_Y"), Cat("_SIZE_Z"), Cat("_IFM_NUM"), Cat("_OFM_NUM"), "64", "16"});
|
||||
} else if (l == WeightsLayout::g_os_zyx_is_osv16_isv16 || l == WeightsLayout::g_os_zyx_is_osv16_isv32 ||
|
||||
l == WeightsLayout::g_os_zyx_is_osv32_isv16 || l == WeightsLayout::g_os_zyx_is_osv32_isv32) {
|
||||
args macroNameArgs = {"prefix", "g", "o", "i", "z", "y", "x"};
|
||||
args funcArgs = {"g", "o", "i", "z", "y", "x", "g_size", "o_size", "i_size", "z_size", "y_size", "x_size", "osv", "isv"};
|
||||
const auto name = toString(l);
|
||||
const auto body = R"V0G0N( \
|
||||
uint is_size = (i_size + isv - 1) / isv; \
|
||||
uint os_size = (o_size + osv - 1) / osv; \
|
||||
uint isv_index = i % isv; \
|
||||
uint osv_index = o % osv; \
|
||||
uint is_index = i / isv; \
|
||||
uint os_index = o / osv; \
|
||||
uint isv_pitch = 1; \
|
||||
uint osv_pitch = isv_pitch * isv; \
|
||||
uint is_pitch = osv_pitch * osv; \
|
||||
uint x_pitch = is_pitch * is_size; \
|
||||
uint y_pitch = x_pitch * x_size; \
|
||||
uint z_pitch = y_pitch * y_size; \
|
||||
uint os_pitch = z_pitch * z_size; \
|
||||
uint g_pitch = os_pitch * os_size; \
|
||||
uint index = 0; \
|
||||
index += isv_index * isv_pitch; \
|
||||
index += osv_index * osv_pitch; \
|
||||
index += is_index * is_pitch; \
|
||||
index += x * x_pitch; \
|
||||
index += y * y_pitch; \
|
||||
index += z * z_pitch; \
|
||||
index += os_index * os_pitch; \
|
||||
index += g * g_pitch; \
|
||||
return index; \
|
||||
)V0G0N";
|
||||
this->macroName = MacroName(name, macroNameArgs);
|
||||
this->calcFunction = FuncBody(name, funcArgs, body);
|
||||
std::string osv = "16", isv = "16";
|
||||
if (l == WeightsLayout::g_os_zyx_is_osv16_isv16) {
|
||||
osv = "16"; isv = "16";
|
||||
} else if (l == WeightsLayout::g_os_zyx_is_osv16_isv32) {
|
||||
osv = "16"; isv = "32";
|
||||
} else if (l == WeightsLayout::g_os_zyx_is_osv32_isv16) {
|
||||
osv = "32"; isv = "16";
|
||||
} else if (l == WeightsLayout::g_os_zyx_is_osv32_isv32) {
|
||||
osv = "32"; isv = "32";
|
||||
}
|
||||
this->macroBody = FuncCall(name, {"g", "o", "i", "z", "y", "x", Cat("_GROUPS_NUM"), Cat("_OFM_NUM"), Cat("_IFM_NUM"), Cat("_SIZE_Z"),
|
||||
Cat("_SIZE_Y"), Cat("_SIZE_X"), osv, isv});
|
||||
} else if (l == WeightsLayout::os_is_yx_osv16_isv4 || l == WeightsLayout::os_is_yx_osv32_isv4) {
|
||||
args macroNameArgs = {"prefix", "o", "i", "y", "x"};
|
||||
args funcArgs = {"o", "i", "y", "x", "i_size", "o_size", "x_size", "otd"};
|
||||
const auto name = toString(l);
|
||||
const auto body = R"V0G0N( \
|
||||
uint out_depth_tile = o / otd; \
|
||||
uint od = o - out_depth_tile * otd; \
|
||||
const uint tile = 4; \
|
||||
uint id_tile = i / tile; \
|
||||
uint id = i - id_tile * tile; \
|
||||
uint idx = out_depth_tile * (o_size / tile) * otd * tile \
|
||||
+ id_tile * i_size * otd * tile \
|
||||
+ y * x_size * otd * tile \
|
||||
+ x * otd * tile \
|
||||
+ od * tile \
|
||||
+ id; \
|
||||
return idx; \
|
||||
)V0G0N";
|
||||
this->macroName = MacroName(name, macroNameArgs);
|
||||
this->calcFunction = FuncBody(name, funcArgs, body);
|
||||
if (l == WeightsLayout::os_is_yx_osv16_isv4)
|
||||
this->macroBody = FuncCall(name, {"o", "i", "y", "x", Cat("_IFM_PITCH"), Cat("_OFM_PITCH"), Cat("_SIZE_X"), "16"});
|
||||
else if (l == WeightsLayout::os_is_yx_osv32_isv4)
|
||||
this->macroBody = FuncCall(name, {"o", "i", "y", "x", Cat("_IFM_PITCH"), Cat("_OFM_PITCH"), Cat("_SIZE_X"), "32"});
|
||||
} else {
|
||||
// throw error?
|
||||
}
|
||||
}
|
||||
|
||||
static const std::string Cat(std::string name, std::string prefix = "prefix") {
|
||||
return "CAT(" + prefix + ", " + name + ")";
|
||||
}
|
||||
|
||||
static const std::string FuncCall(std::string name, std::initializer_list<std::string> args) {
|
||||
std::string args_str = "";
|
||||
size_t counter = 0;
|
||||
for (auto& arg : args)
|
||||
args_str += (++counter == args.size()) ? (arg) : (arg + ", ");
|
||||
return "FUNC_CALL(" + name + ")(" + args_str + ")";
|
||||
}
|
||||
|
||||
static const std::string MacroName(std::string name, std::initializer_list<std::string> args) {
|
||||
std::string args_str = "";
|
||||
size_t counter = 0;
|
||||
for (auto& arg : args)
|
||||
args_str += (++counter == args.size()) ? (arg) : (arg + ", ");
|
||||
return "GET_WEIGHTS_" + name + "_INDEX(" + args_str + ")";
|
||||
}
|
||||
|
||||
static const std::string FuncBody(std::string name, std::initializer_list<std::string> args = {}, std::string body = "return 0;") {
|
||||
std::string args_str = "";
|
||||
size_t counter = 0;
|
||||
for (auto& arg : args)
|
||||
args_str += (++counter == args.size()) ? (arg) : (arg + ", ");
|
||||
return "inline uint FUNC(" + name + ")(" + args_str + "){" + body + "}";
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
WeightTensorJitConstant(const std::string& name, const WeightsTensor& t) : TensorBaseTJitConstant(name), _tensor(t) {}
|
||||
@ -498,6 +712,130 @@ JitDefinitions WeightTensorJitConstant::GetDefinitions() const {
|
||||
|
||||
definitions.insert(definitions.end(), baseDefinitions.begin(), baseDefinitions.end());
|
||||
|
||||
auto is_common_nd_layout = [](std::vector<Tensor::WeightsChannelName> common_channels, WeightsLayout l) -> bool {
|
||||
for (size_t c = 0; c < static_cast<size_t>(Tensor::WeightsChannelName::COUNT); c++) {
|
||||
auto channel = static_cast<Tensor::WeightsChannelName>(c);
|
||||
if (WeightsTensor::Channelndex(l, channel) != -1) {
|
||||
if (std::find(common_channels.begin(), common_channels.end(), channel) == common_channels.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
std::string index_func_name = _name + "_INDEX_FUNC";
|
||||
std::string index_macro_name;
|
||||
std::string index_func_val;
|
||||
|
||||
auto layout = _tensor.GetLayout();
|
||||
WeightIndexFuncDesc indexFuncDesc {layout};
|
||||
if (WeightsTensor::DoesGroupDimExist(layout)) {
|
||||
if (WeightsTensor::ChannelsCount(layout) <= 5) {
|
||||
std::vector<Tensor::WeightsChannelName> grouped_4d_channels = {
|
||||
Tensor::WeightsChannelName::G,
|
||||
Tensor::WeightsChannelName::OFM,
|
||||
Tensor::WeightsChannelName::IFM,
|
||||
Tensor::WeightsChannelName::Y,
|
||||
Tensor::WeightsChannelName::X,
|
||||
};
|
||||
bool is_grouped_4d_layout = is_common_nd_layout(grouped_4d_channels, layout);
|
||||
if (is_grouped_4d_layout) {
|
||||
index_macro_name = _name + "_GET_INDEX(g, o, i, y, x)";
|
||||
auto layout_str = toString(layout);
|
||||
if (layout == WeightsLayout::goiyx)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", g, o, i, 0, y, x)";
|
||||
else if (layout == WeightsLayout::g_os_is_yx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", g, o, i, 0, y, x, 16)";
|
||||
else if (layout == WeightsLayout::g_os_iyx_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", g, o, i, y, x, 16)";
|
||||
else if (layout == WeightsLayout::g_is_os_yx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", g, o, i, 0, y, x, 16)";
|
||||
} else {
|
||||
assert(0);
|
||||
}
|
||||
} else if (WeightsTensor::ChannelsCount(layout) == 6) {
|
||||
std::vector<Tensor::WeightsChannelName> grouped_5d_channels = {
|
||||
Tensor::WeightsChannelName::G,
|
||||
Tensor::WeightsChannelName::OFM,
|
||||
Tensor::WeightsChannelName::IFM,
|
||||
Tensor::WeightsChannelName::Z,
|
||||
Tensor::WeightsChannelName::Y,
|
||||
Tensor::WeightsChannelName::X,
|
||||
};
|
||||
bool is_grouped_5d_layout = is_common_nd_layout(grouped_5d_channels, layout);
|
||||
if (is_grouped_5d_layout) {
|
||||
index_macro_name = _name + "_GET_INDEX(g, o, i, z, y, x)";
|
||||
auto layout_str = toString(layout);
|
||||
if (layout == WeightsLayout::goizyx)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", g, o, i, z, y, x)";
|
||||
else if (layout == WeightsLayout::g_os_is_zyx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", g, o, i, z, y, x, 16)";
|
||||
else if (layout == WeightsLayout::g_is_os_zyx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", g, o, i, z, y, x, 16)";
|
||||
} else {
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (WeightsTensor::ChannelsCount(layout) <= 4) {
|
||||
std::vector<Tensor::WeightsChannelName> base_4d_channels = {
|
||||
Tensor::WeightsChannelName::OFM,
|
||||
Tensor::WeightsChannelName::IFM,
|
||||
Tensor::WeightsChannelName::Y,
|
||||
Tensor::WeightsChannelName::X,
|
||||
};
|
||||
bool is_common_4d_layout = is_common_nd_layout(base_4d_channels, layout);
|
||||
if (is_common_4d_layout) {
|
||||
index_macro_name = _name + "_GET_INDEX(o, i, y, x)";
|
||||
auto layout_str = toString(layout);
|
||||
if (layout == WeightsLayout::oiyx)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, 0, y, x)";
|
||||
else if (layout == WeightsLayout::os_is_yx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, 0, y, x, 16)";
|
||||
else if (layout == WeightsLayout::os_iyx_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, y, x, 16)";
|
||||
else if (layout == WeightsLayout::os_iyx_osv32 || layout == WeightsLayout::os_iyx_osv32__ai32)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, y, x, 32)";
|
||||
else if (layout == WeightsLayout::is_os_yx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, 0, y, x, 16)";
|
||||
else if (layout == WeightsLayout::os_is_yx_osv16_isv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", o, i, 0, y, x)";
|
||||
} else {
|
||||
assert(0);
|
||||
}
|
||||
} else if (WeightsTensor::ChannelsCount(layout) == 5) {
|
||||
std::vector<Tensor::WeightsChannelName> base_5d_channels = {
|
||||
Tensor::WeightsChannelName::OFM,
|
||||
Tensor::WeightsChannelName::IFM,
|
||||
Tensor::WeightsChannelName::Z,
|
||||
Tensor::WeightsChannelName::Y,
|
||||
Tensor::WeightsChannelName::X,
|
||||
};
|
||||
bool is_common_5d_layout = is_common_nd_layout(base_5d_channels, layout);
|
||||
if (is_common_5d_layout) {
|
||||
index_macro_name = _name + "_GET_INDEX(o, i, z, y, x)";
|
||||
auto layout_str = toString(layout);
|
||||
if (layout == WeightsLayout::oizyx)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, z, y, x)";
|
||||
else if (layout == WeightsLayout::os_is_zyx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, z, y, x, 16)";
|
||||
else if (layout == WeightsLayout::is_os_zyx_isv16_osv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", 0, o, i, z, y, x, 16)";
|
||||
else if (layout == WeightsLayout::os_is_zyx_osv32_isv16 || layout == WeightsLayout::os_is_zyx_osv64_isv16)
|
||||
index_func_val = "GET_WEIGHTS_" + layout_str + "_INDEX(" + _name + ", o, i, z, y, x)";
|
||||
} else {
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!indexFuncDesc.macroName.empty()) {
|
||||
definitions.push_back({ index_func_name, indexFuncDesc.calcFunction });
|
||||
definitions.push_back({ "INIT_" + index_func_name + "_HERE", index_func_name });
|
||||
definitions.push_back({ indexFuncDesc.macroName, indexFuncDesc.macroBody });
|
||||
definitions.push_back({ index_macro_name, index_func_val });
|
||||
}
|
||||
return definitions;
|
||||
}
|
||||
|
||||
|
@ -344,8 +344,8 @@ std::string toString(WeightsLayout layout) {
|
||||
case WeightsLayout::os_is_yx_osv32_isv32p: return "OS_IS_YX_OSV32_ISV32P";
|
||||
case WeightsLayout::oizyx: return "OIZYX";
|
||||
case WeightsLayout::os_is_zyx_isv16_osv16: return "OS_IS_ZYX_ISV16_OSV16";
|
||||
case WeightsLayout::is_os_zyx_osv16_isv16: return "IS_OS_ZYX_OSV16_ISV16";
|
||||
case WeightsLayout::is_os_yx_osv16_isv16: return "IS_OS_YX_OSV16_ISV16";
|
||||
case WeightsLayout::is_os_zyx_isv16_osv16: return "IS_OS_ZYX_ISV16_OSV16";
|
||||
case WeightsLayout::is_os_yx_isv16_osv16: return "IS_OS_YX_ISV16_OSV16";
|
||||
case WeightsLayout::os_is_zyx_isv8_osv16_isv2: return "OS_IS_ZYX_ISV8_OSV16_ISV2";
|
||||
case WeightsLayout::os_zyxi_osv16: return "OS_ZYXI_OSV16";
|
||||
case WeightsLayout::os_is_yx_isv8_osv16_isv2: return "OS_IS_YX_ISV8_OSV16_ISV2";
|
||||
@ -360,8 +360,8 @@ std::string toString(WeightsLayout layout) {
|
||||
case WeightsLayout::gs_oizyx_gsv16: return "GS_OIZYX_GSV16";
|
||||
case WeightsLayout::gs_oiyx_gsv32: return "GS_OIYX_GSV32";
|
||||
case WeightsLayout::gi_yxs_os_yxsv2_osv16: return "GI_YXS_OS_YXSV2_OSV16";
|
||||
case WeightsLayout::g_is_os_zyx_osv16_isv16: return "G_IS_OS_ZYX_OSV16_ISV16";
|
||||
case WeightsLayout::g_is_os_yx_osv16_isv16: return "G_IS_OS_YX_OSV16_ISV16";
|
||||
case WeightsLayout::g_is_os_zyx_isv16_osv16: return "G_IS_OS_ZYX_ISV16_OSV16";
|
||||
case WeightsLayout::g_is_os_yx_isv16_osv16: return "G_IS_OS_YX_ISV16_OSV16";
|
||||
case WeightsLayout::g_os_is_zyx_isv8_osv16_isv2: return "G_OS_IS_ZYX_ISV8_OSV16_ISV2";
|
||||
case WeightsLayout::g_os_is_yx_isv8_osv16_isv2: return "G_OS_IS_YX_ISV8_OSV16_ISV2";
|
||||
case WeightsLayout::g_os_is_zyx_isv16_osv16: return "G_OS_IS_ZYX_ISV16_OSV16";
|
||||
|
@ -170,10 +170,10 @@ inline std::string fmt_to_str(format fmt) {
|
||||
return "os_is_yx_osv32_isv32p";
|
||||
case format::os_is_zyx_isv16_osv16:
|
||||
return "os_is_zyx_isv16_osv16";
|
||||
case format::is_os_zyx_osv16_isv16:
|
||||
return "is_os_zyx_osv16_isv16";
|
||||
case format::is_os_yx_osv16_isv16:
|
||||
return "is_os_yx_osv16_isv16";
|
||||
case format::is_os_zyx_isv16_osv16:
|
||||
return "is_os_zyx_isv16_osv16";
|
||||
case format::is_os_yx_isv16_osv16:
|
||||
return "is_os_yx_isv16_osv16";
|
||||
case format::os_is_osv32_isv32_swizzled_by_4:
|
||||
return "os_is_osv32_isv32_swizzled_by_4";
|
||||
case format::os_is_zyx_isv8_osv16_isv2:
|
||||
@ -193,10 +193,10 @@ inline std::string fmt_to_str(format fmt) {
|
||||
return "gs_oiyx_gsv16";
|
||||
case format::gs_oiyx_gsv32:
|
||||
return "gs_oiyx_gsv32";
|
||||
case format::g_is_os_zyx_osv16_isv16:
|
||||
return "g_is_os_zyx_osv16_isv16";
|
||||
case format::g_is_os_yx_osv16_isv16:
|
||||
return "g_is_os_yx_osv16_isv16";
|
||||
case format::g_is_os_zyx_isv16_osv16:
|
||||
return "g_is_os_zyx_isv16_osv16";
|
||||
case format::g_is_os_yx_isv16_osv16:
|
||||
return "g_is_os_yx_isv16_osv16";
|
||||
case format::g_os_is_zyx_isv8_osv16_isv2:
|
||||
return "g_os_is_zyx_isv8_osv16_isv2";
|
||||
case format::g_os_is_yx_isv8_osv16_isv2:
|
||||
|
@ -289,10 +289,10 @@ kernel_selector::weights_layout to_weights_layout(format f) {
|
||||
return kernel_selector::weights_layout::os_i_osv16;
|
||||
case format::os_is_zyx_isv16_osv16:
|
||||
return kernel_selector::weights_layout::os_is_zyx_isv16_osv16;
|
||||
case format::is_os_zyx_osv16_isv16:
|
||||
return kernel_selector::weights_layout::is_os_zyx_osv16_isv16;
|
||||
case format::is_os_yx_osv16_isv16:
|
||||
return kernel_selector::weights_layout::is_os_yx_osv16_isv16;
|
||||
case format::is_os_zyx_isv16_osv16:
|
||||
return kernel_selector::weights_layout::is_os_zyx_isv16_osv16;
|
||||
case format::is_os_yx_isv16_osv16:
|
||||
return kernel_selector::weights_layout::is_os_yx_isv16_osv16;
|
||||
case format::os_is_osv32_isv32_swizzled_by_4:
|
||||
return kernel_selector::weights_layout::os_is_osv32_isv32_swizzled_by_4;
|
||||
case format::os_is_zyx_isv8_osv16_isv2:
|
||||
@ -315,10 +315,10 @@ kernel_selector::weights_layout to_weights_layout(format f) {
|
||||
return kernel_selector::weights_layout::gs_oiyx_gsv32;
|
||||
case format::gyxio:
|
||||
return kernel_selector::weights_layout::gyxio;
|
||||
case format::g_is_os_zyx_osv16_isv16:
|
||||
return kernel_selector::weights_layout::g_is_os_zyx_osv16_isv16;
|
||||
case format::g_is_os_yx_osv16_isv16:
|
||||
return kernel_selector::weights_layout::g_is_os_yx_osv16_isv16;
|
||||
case format::g_is_os_zyx_isv16_osv16:
|
||||
return kernel_selector::weights_layout::g_is_os_zyx_isv16_osv16;
|
||||
case format::g_is_os_yx_isv16_osv16:
|
||||
return kernel_selector::weights_layout::g_is_os_yx_isv16_osv16;
|
||||
case format::g_os_is_zyx_isv8_osv16_isv2:
|
||||
return kernel_selector::weights_layout::g_os_is_zyx_isv8_osv16_isv2;
|
||||
case format::g_os_is_yx_isv8_osv16_isv2:
|
||||
@ -418,10 +418,10 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
|
||||
return cldnn::format::bfzyx;
|
||||
case kernel_selector::weights_layout::os_is_zyx_isv16_osv16:
|
||||
return cldnn::format::os_is_zyx_isv16_osv16;
|
||||
case kernel_selector::weights_layout::is_os_zyx_osv16_isv16:
|
||||
return cldnn::format::is_os_zyx_osv16_isv16;
|
||||
case kernel_selector::weights_layout::is_os_yx_osv16_isv16:
|
||||
return cldnn::format::is_os_yx_osv16_isv16;
|
||||
case kernel_selector::weights_layout::is_os_zyx_isv16_osv16:
|
||||
return cldnn::format::is_os_zyx_isv16_osv16;
|
||||
case kernel_selector::weights_layout::is_os_yx_isv16_osv16:
|
||||
return cldnn::format::is_os_yx_isv16_osv16;
|
||||
case kernel_selector::weights_layout::os_is_zyx_isv8_osv16_isv2:
|
||||
return cldnn::format::os_is_zyx_isv8_osv16_isv2;
|
||||
case kernel_selector::weights_layout::os_zyxi_osv16:
|
||||
@ -442,10 +442,10 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
|
||||
return cldnn::format::gs_oiyx_gsv32;
|
||||
case kernel_selector::weights_layout::gyxio:
|
||||
return cldnn::format::gyxio;
|
||||
case kernel_selector::weights_layout::g_is_os_zyx_osv16_isv16:
|
||||
return cldnn::format::g_is_os_zyx_osv16_isv16;
|
||||
case kernel_selector::weights_layout::g_is_os_yx_osv16_isv16:
|
||||
return cldnn::format::g_is_os_yx_osv16_isv16;
|
||||
case kernel_selector::weights_layout::g_is_os_zyx_isv16_osv16:
|
||||
return cldnn::format::g_is_os_zyx_isv16_osv16;
|
||||
case kernel_selector::weights_layout::g_is_os_yx_isv16_osv16:
|
||||
return cldnn::format::g_is_os_yx_isv16_osv16;
|
||||
case kernel_selector::weights_layout::g_os_is_zyx_isv8_osv16_isv2:
|
||||
return cldnn::format::g_os_is_zyx_isv8_osv16_isv2;
|
||||
case kernel_selector::weights_layout::g_os_is_yx_isv8_osv16_isv2:
|
||||
|
@ -186,6 +186,8 @@ memory_impl::ptr primitive_inst::allocate_output() {
|
||||
_node.get_memory_dependencies(),
|
||||
alloc_type,
|
||||
false);
|
||||
} else if (_network.is_internal() && _node.is_output() && _node.is_type<generic_layer>()) {
|
||||
return engine.allocate_memory(layout, allocation_type::usm_device, net_id);
|
||||
} else if (_network.is_internal() || (!_node.can_share_buffer()) || _node.can_be_optimized() || _node.is_output()) {
|
||||
return engine.allocate_memory(layout, alloc_type, net_id);
|
||||
}
|
||||
|
@ -3730,21 +3730,21 @@ using deconv_test_params = bc_test_params;
|
||||
|
||||
// in_shape; out_shape; kernel; stride; pad; dilation; groups; data_type; input_format; weights_type; weights_format; default_type; default_format;
|
||||
#define CASE_DECONV_FP32_1 {1, 15, 4, 5}, {1, 30, 6, 7}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_2 {1, 16, 4, 5}, {1, 32, 6, 7}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_osv16_isv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_3 {1, 16, 4, 5}, {1, 32, 4, 5}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_osv16_isv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_2 {1, 16, 4, 5}, {1, 32, 6, 7}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_isv16_osv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_3 {1, 16, 4, 5}, {1, 32, 4, 5}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_isv16_osv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_4 {1, 32, 4, 5}, {1, 32, 4, 5}, {1, 1, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 32, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::gs_oiyx_gsv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_5 {1, 15, 4, 5}, {1, 30, 9, 11}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_6 {1, 16, 4, 5}, {1, 32, 9, 11}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_osv16_isv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_7 {1, 16, 4, 5}, {1, 32, 7, 9}, {1, 1, 1, 1}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_osv16_isv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_6 {1, 16, 4, 5}, {1, 32, 9, 11}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_isv16_osv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_7 {1, 16, 4, 5}, {1, 32, 7, 9}, {1, 1, 1, 1}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_isv16_osv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_FP32_8 {1, 32, 4, 5}, {1, 32, 7, 9}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 32, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::gs_oiyx_gsv16, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_DECONV_FP16_1 {1, 15, 4, 5}, {1, 30, 6, 7}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_2 {1, 16, 4, 5}, {1, 32, 6, 7}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_osv16_isv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_3 {1, 16, 4, 5}, {1, 32, 4, 5}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_osv16_isv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_2 {1, 16, 4, 5}, {1, 32, 6, 7}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_isv16_osv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_3 {1, 16, 4, 5}, {1, 32, 4, 5}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_isv16_osv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_4 {1, 32, 4, 5}, {1, 32, 4, 5}, {1, 1, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 32, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::gs_oiyx_gsv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_5 {1, 15, 4, 5}, {1, 30, 9, 11}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_6 {1, 16, 4, 5}, {1, 32, 9, 11}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_osv16_isv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_7 {1, 16, 4, 5}, {1, 32, 7, 9}, {1, 1, 1, 1}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_osv16_isv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_6 {1, 16, 4, 5}, {1, 32, 9, 11}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_isv16_osv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_7 {1, 16, 4, 5}, {1, 32, 7, 9}, {1, 1, 1, 1}, tensor{1, 1, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::is_os_yx_isv16_osv16, data_types::f16, format::bfyx
|
||||
#define CASE_DECONV_FP16_8 {1, 32, 4, 5}, {1, 32, 7, 9}, {1, 1, 3, 3}, tensor{1, 1, 2, 2}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 32, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::gs_oiyx_gsv16, data_types::f16, format::bfyx
|
||||
|
||||
#define CASE_DECONV_S8S8_1 {1, 15, 4, 5}, {1, 30, 6, 7}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
|
||||
@ -3768,24 +3768,24 @@ using deconv_test_params = bc_test_params;
|
||||
// 3D
|
||||
// in_shape; out_shape; kernel; stride; pad; dilation; groups; data_type; input_format; weights_type; weights_format; default_type; default_format;
|
||||
#define CASE_DECONV_FP32_3D_1 {1, 15, 4, 5, 3}, {1, 30, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfzyx, data_types::f32, format::oizyx, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_2 {1, 16, 4, 5, 3}, {1, 32, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_osv16_isv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_3 {1, 16, 4, 5, 3}, {1, 32, 4, 5, 3}, {1, 1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_osv16_isv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_2 {1, 16, 4, 5, 3}, {1, 32, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_isv16_osv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_3 {1, 16, 4, 5, 3}, {1, 32, 4, 5, 3}, {1, 1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_isv16_osv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_4 {1, 32, 4, 5, 3}, {1, 32, 4, 5, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, -1}, tensor{1}, 32, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::gs_oizyx_gsv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_5 {1, 15, 4, 5, 3}, {1, 30, 9, 11, 7}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfzyx, data_types::f32, format::oizyx, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_6 {1, 16, 4, 5, 3}, {1, 32, 9, 11, 7}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_osv16_isv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_7 {1, 16, 4, 5, 3}, {1, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_osv16_isv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_6 {1, 16, 4, 5, 3}, {1, 32, 9, 11, 7}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_isv16_osv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_7 {1, 16, 4, 5, 3}, {1, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::is_os_zyx_isv16_osv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_8 {1, 32, 4, 5, 3}, {1, 32, 7, 9, 5}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0, 0, -1, -1, -1}, tensor{1}, 32, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::gs_oizyx_gsv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_9 {16, 16, 4, 5, 3}, {16, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::bs_fs_zyx_bsv16_fsv16, data_types::f32, format::is_os_zyx_osv16_isv16, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_FP32_3D_9 {16, 16, 4, 5, 3}, {16, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f32, format::bs_fs_zyx_bsv16_fsv16, data_types::f32, format::is_os_zyx_isv16_osv16, data_types::f32, format::bfzyx
|
||||
|
||||
#define CASE_DECONV_FP16_3D_1 {1, 15, 4, 5, 3}, {1, 30, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::bfzyx, data_types::f16, format::oizyx, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_2 {1, 16, 4, 5, 3}, {1, 32, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_osv16_isv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_3 {1, 16, 4, 5, 3}, {1, 32, 4, 5, 3}, {1, 1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_osv16_isv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_2 {1, 16, 4, 5, 3}, {1, 32, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_isv16_osv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_3 {1, 16, 4, 5, 3}, {1, 32, 4, 5, 3}, {1, 1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_isv16_osv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_4 {1, 32, 4, 5, 3}, {1, 32, 4, 5, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, -1}, tensor{1}, 32, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::gs_oizyx_gsv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_5 {1, 15, 4, 5, 3}, {1, 30, 9, 11, 7}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::bfzyx, data_types::f16, format::oizyx, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_6 {1, 16, 4, 5, 3}, {1, 32, 9, 11, 7}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_osv16_isv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_7 {1, 16, 4, 5, 3}, {1, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_osv16_isv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_6 {1, 16, 4, 5, 3}, {1, 32, 9, 11, 7}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_isv16_osv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_7 {1, 16, 4, 5, 3}, {1, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::is_os_zyx_isv16_osv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_8 {1, 32, 4, 5, 3}, {1, 32, 7, 9, 5}, {1, 1, 3, 3, 3}, tensor{1, 1, 2, 2, 2}, tensor{0, 0, -1, -1, -1}, tensor{1}, 32, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::gs_oizyx_gsv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_9 {16, 16, 4, 5, 3}, {16, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::bs_fs_zyx_bsv16_fsv16, data_types::f16, format::is_os_zyx_osv16_isv16, data_types::f16, format::bfzyx
|
||||
#define CASE_DECONV_FP16_3D_9 {16, 16, 4, 5, 3}, {16, 32, 7, 9, 5}, {1, 1, 1, 1, 1}, tensor{1, 1, 2, 2, 2}, tensor{0}, tensor{1}, 1, data_types::f16, format::bs_fs_zyx_bsv16_fsv16, data_types::f16, format::is_os_zyx_isv16_osv16, data_types::f16, format::bfzyx
|
||||
|
||||
#define CASE_DECONV_S8S8_3D_1 {1, 15, 4, 5, 3}, {1, 30, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfzyx, data_types::i8, format::oizyx, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_S8S8_3D_2 {1, 16, 4, 5, 3}, {1, 32, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::b_fs_zyx_fsv16, data_types::i8, format::oizyx, data_types::f32, format::bfzyx
|
||||
@ -3807,7 +3807,7 @@ using deconv_test_params = bc_test_params;
|
||||
|
||||
#define CASE_DECONV_ELTW_FP32_1 {1, 16, 4, 5}, {1, 32, 6, 7}, {1, 32, 1, 1}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::oiyx, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_ELTW_FP32_2 {1, 16, 4, 5}, {1, 32, 6, 7}, {1, 1, 1, 1}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::os_is_yx_isv16_osv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_ELTW_FP32_3 {1, 16, 4, 5}, {1, 32, 4, 5}, {1, 1, 1, 1}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_osv16_isv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_ELTW_FP32_3 {1, 16, 4, 5}, {1, 32, 4, 5}, {1, 1, 1, 1}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::is_os_yx_isv16_osv16, data_types::f32, format::bfyx
|
||||
#define CASE_DECONV_ELTW_FP32_4 {1, 15, 4, 5, 3}, {1, 30, 6, 7, 5}, {1, 1, 6, 7, 5}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfzyx, data_types::f32, format::oizyx, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_ELTW_FP32_5 {1, 15, 4, 5, 4}, {1, 30, 6, 7, 6}, {1, 30, 6, 1, 6}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::bfzyx, data_types::f32, format::oizyx, data_types::f32, format::bfzyx
|
||||
#define CASE_DECONV_ELTW_FP32_6 {1, 32, 2, 2, 2}, {1, 16, 4, 4, 4}, {1, 16, 1, 4, 1}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::os_is_zyx_isv16_osv16, data_types::f32, format::bfzyx
|
||||
|
Loading…
Reference in New Issue
Block a user