* [GPU] oneDNN3.3 integration.
* Supports new formats from oneDNN3.3 requires.
* Fix Perf regression because of the wrong mvn kernel selection issue.
    modnet_webcam_portrait_matting.int8
    person-reidentification-retail-0248.int8
* support undefined onednn tag for using any tag instead.

Signed-off-by: hyunback <hyunback.kim@intel.com>
This commit is contained in:
hyunback kim 2023-10-10 12:47:55 +09:00 committed by GitHub
parent 86d0bdb2db
commit bf9bdaa671
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 435 additions and 77 deletions

View File

@ -90,6 +90,8 @@ struct format {
b_fs_zyx_fsv2,
b_fs_yx_fsv4, ///< format for input for IMAD convolutions
b_fs_zyx_fsv4, ///< format for input for IMAD 3D convolutions
b_fs_yx_fsv8,
b_fs_zyx_fsv8,
b_fs_yx_fsv16, ///< format used for blocked convolution
b_fs_yx_fsv32, ///< format used for blocked int8 convolution
b_fs_zyx_fsv16, ///< format used for 3D blocked convolution (features blocked by 16)
@ -107,6 +109,8 @@ struct format {
bs_fs_zyx_bsv8_fsv2, ///< format used for 3D blocked convolution (batch and features blocked by 8 and 2)
bs_fs_yx_bsv16_fsv2, ///< format used for 2D blocked convolution (batch and features blocked by 16 and 2)
bs_fs_zyx_bsv16_fsv2, ///< format used for 3D blocked convolution (batch and features blocked by 16 and 2)
bs_fs_yx_bsv16_fsv8, ///< format used for 2D blocked convolution (batch and features blocked by 16 and 8)
bs_fs_zyx_bsv16_fsv8, ///< format used for 3D blocked convolution (batch and features blocked by 16 and 8)
bs_fs_yx_bsv4_fsv2, ///< format used for 2D blocked convolution (batch blocked by 4, features blocked by 2)
bs_fs_zyx_bsv4_fsv4, ///< format used for 3D blocked convolution (batch and features blocked by 4)
bs_fs_zyx_bsv4_fsv2, ///< format used for 3D blocked convolution (batch blocked by 4, features blocked by 2)
@ -135,7 +139,9 @@ struct format {
oyix,
oxiy,
os_iyx_osv16, ///< format used only for convolution weights
o_is_yx_isv4, ///< format used only for convolution weights
o_is_yx_isv16, ///< format used only for convolution weights
o_is_zyx_isv16, ///< format used only for convolution weights
os_yxi_osv16, ///< format used only for convolution weights
os_is_yx_osv16_isv2, ///< format used only for convolution weights
os_is_yx_osv16_isv16, ///< format used for convolution i8 weights
@ -145,6 +151,7 @@ struct format {
os_is_yx_isv16_osv16, ///< format used for blocked convolution
os_is_zyx_isv16_osv16, ///< format used for weights for blocked 3D convolution
is_os_zyx_isv16_osv16, ///< format used for weights for blocked 3D deconvolution
is_os_yx_osv8_isv4, ///< format used for weights for blocked deconvolution
is_os_yx_isv16_osv16, ///< format used for weights for blocked deconvolution
is_os_yx_isv16_osv8, ///< format used for weights for blocked deconvolution
is_os_yx_isv16_osv4, ///< format used for weights for blocked deconvolution
@ -232,6 +239,8 @@ struct format {
os_zyx_is_osv8_isv4,
os_zy_is_x_osv8_isv2,
os_zy_is_x_osv8_isv4,
os_is_yx_osv4_isv16,
os_is_yx_osv2_isv16,
goiyx, ///< format used for weights for 2D convolution
gioyx, ///< format used for weights for 2D deconvolution
@ -241,7 +250,9 @@ struct format {
g_os_iyx_osv8, ///< format used for weights for 2D convolution
g_os_iyx_osv16, ///< format used for weights for 2D convolution
g_os_iyx_osv32, ///< format used for weights for 2D convolution
gs_oiyx_gsv8, ///< format used for weights for 2D convolution
gs_oiyx_gsv16, ///< format used for weights for 2D convolution
gs_oizyx_gsv8, ///< format used for weights for 3D convolution
gs_oizyx_gsv16, ///< format used for weights for 3D convolution
gs_oiyx_gsv32, ///< format used for weights for 2D convolution
gs_oizyx_gsv32, ///< format used for weights for 3D convolution

View File

@ -219,6 +219,8 @@ kernel_selector::data_layout to_data_layout(format f) {
return kernel_selector::data_layout::b_fs_yx_fsv2;
case format::b_fs_yx_fsv4:
return kernel_selector::data_layout::b_fs_yx_fsv4;
case format::b_fs_yx_fsv8:
return kernel_selector::data_layout::b_fs_yx_fsv8;
case format::b_fs_yx_fsv16:
return kernel_selector::data_layout::b_fs_yx_fsv16;
case format::b_fs_yx_fsv32:
@ -227,6 +229,8 @@ kernel_selector::data_layout to_data_layout(format f) {
return kernel_selector::data_layout::b_fs_zyx_fsv2;
case format::b_fs_zyx_fsv4:
return kernel_selector::data_layout::b_fs_zyx_fsv4;
case format::b_fs_zyx_fsv8:
return kernel_selector::data_layout::b_fs_zyx_fsv8;
case format::b_fs_zyx_fsv32:
return kernel_selector::data_layout::b_fs_zyx_fsv32;
case format::bs_f_bsv16:
@ -277,6 +281,10 @@ kernel_selector::data_layout to_data_layout(format f) {
return kernel_selector::data_layout::bs_fs_yx_bsv16_fsv2;
case format::bs_fs_zyx_bsv16_fsv2:
return kernel_selector::data_layout::bs_fs_zyx_bsv16_fsv2;
case format::bs_fs_yx_bsv16_fsv8:
return kernel_selector::data_layout::bs_fs_yx_bsv16_fsv8;
case format::bs_fs_zyx_bsv16_fsv8:
return kernel_selector::data_layout::bs_fs_zyx_bsv16_fsv8;
case format::bs_fs_yx_bsv8_fsv2:
return kernel_selector::data_layout::bs_fs_yx_bsv8_fsv2;
case format::bs_fs_zyx_bsv8_fsv2:
@ -320,10 +328,14 @@ cldnn::format from_data_layout(kernel_selector::data_layout l) {
return cldnn::format::b_fs_yx_fsv2;
case kernel_selector::data_layout::b_fs_yx_fsv4:
return cldnn::format::b_fs_yx_fsv4;
case kernel_selector::data_layout::b_fs_yx_fsv8:
return cldnn::format::b_fs_yx_fsv8;
case kernel_selector::data_layout::b_fs_yx_fsv16:
return cldnn::format::b_fs_yx_fsv16;
case kernel_selector::data_layout::b_fs_yx_fsv32:
return cldnn::format::b_fs_yx_fsv32;
case kernel_selector::data_layout::b_fs_zyx_fsv8:
return cldnn::format::b_fs_zyx_fsv8;
case kernel_selector::data_layout::b_fs_zyx_fsv32:
return cldnn::format::b_fs_zyx_fsv32;
case kernel_selector::data_layout::bs_f_bsv8__af8:
@ -366,6 +378,10 @@ cldnn::format from_data_layout(kernel_selector::data_layout l) {
return cldnn::format::bs_fs_yx_bsv16_fsv2;
case kernel_selector::data_layout::bs_fs_zyx_bsv16_fsv2:
return cldnn::format::bs_fs_zyx_bsv16_fsv2;
case kernel_selector::data_layout::bs_fs_yx_bsv16_fsv8:
return cldnn::format::bs_fs_yx_bsv16_fsv8;
case kernel_selector::data_layout::bs_fs_zyx_bsv16_fsv8:
return cldnn::format::bs_fs_zyx_bsv16_fsv8;
case kernel_selector::data_layout::bs_fs_yx_bsv8_fsv2:
return cldnn::format::bs_fs_yx_bsv8_fsv2;
case kernel_selector::data_layout::bs_fs_yx_bsv32_fsv32:
@ -406,8 +422,12 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::yxio;
case format::os_yxi_osv16:
return kernel_selector::weights_layout::os_yxi_osv16;
case format::o_is_yx_isv4:
return kernel_selector::weights_layout::o_is_yx_isv4;
case format::o_is_yx_isv16:
return kernel_selector::weights_layout::o_is_yx_isv16;
case format::o_is_zyx_isv16:
return kernel_selector::weights_layout::o_is_zyx_isv16;
case format::os_iyx_osv16:
return kernel_selector::weights_layout::os_iyx_osv16;
case format::os_is_yx_osv16_isv2:
@ -474,6 +494,10 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::os_is_zyx_osv8_isv2;
case format::os_is_yx_osv8_isv4:
return kernel_selector::weights_layout::os_is_yx_osv8_isv4;
case format::os_is_yx_osv4_isv16:
return kernel_selector::weights_layout::os_is_yx_osv4_isv16;
case format::os_is_yx_osv2_isv16:
return kernel_selector::weights_layout::os_is_yx_osv2_isv16;
case format::os_is_zyx_osv8_isv4:
return kernel_selector::weights_layout::os_is_zyx_osv8_isv4;
case format::os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4:
@ -527,6 +551,8 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::is_os_zyx_isv16_osv16;
case format::os_is_zyx_osv32_isv16:
return kernel_selector::weights_layout::os_is_zyx_osv32_isv16;
case format::is_os_yx_osv8_isv4:
return kernel_selector::weights_layout::is_os_yx_osv8_isv4;
case format::is_os_yx_isv16_osv16:
return kernel_selector::weights_layout::is_os_yx_isv16_osv16;
case format::is_os_yx_isv16_osv8:
@ -567,6 +593,10 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::g_os_iyx_osv16;
case format::g_os_iyx_osv32:
return kernel_selector::weights_layout::g_os_iyx_osv32;
case format::gs_oiyx_gsv8:
return kernel_selector::weights_layout::gs_oiyx_gsv8;
case format::gs_oizyx_gsv8:
return kernel_selector::weights_layout::gs_oizyx_gsv8;
case format::gs_oiyx_gsv16:
return kernel_selector::weights_layout::gs_oiyx_gsv16;
case format::gs_oizyx_gsv16:
@ -711,8 +741,12 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::yxio;
case kernel_selector::weights_layout::os_yxi_osv16:
return cldnn::format::os_yxi_osv16;
case kernel_selector::weights_layout::o_is_yx_isv4:
return cldnn::format::o_is_yx_isv4;
case kernel_selector::weights_layout::o_is_yx_isv16:
return cldnn::format::o_is_yx_isv16;
case kernel_selector::weights_layout::o_is_zyx_isv16:
return cldnn::format::o_is_zyx_isv16;
case kernel_selector::weights_layout::os_iyx_osv16:
return cldnn::format::os_iyx_osv16;
case kernel_selector::weights_layout::os_is_yx_isv16_osv16:
@ -821,6 +855,8 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::os_is_zyx_isv16_osv16;
case kernel_selector::weights_layout::is_os_zyx_isv16_osv16:
return cldnn::format::is_os_zyx_isv16_osv16;
case kernel_selector::weights_layout::is_os_yx_osv8_isv4:
return cldnn::format::is_os_yx_osv8_isv4;
case kernel_selector::weights_layout::is_os_yx_isv16_osv16:
return cldnn::format::is_os_yx_isv16_osv16;
case kernel_selector::weights_layout::is_os_yx_isv16_osv8:
@ -855,6 +891,10 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::os_is_zyx_osv8_isv2;
case kernel_selector::weights_layout::os_is_yx_osv8_isv4:
return cldnn::format::os_is_yx_osv8_isv4;
case kernel_selector::weights_layout::os_is_yx_osv4_isv16:
return cldnn::format::os_is_yx_osv4_isv16;
case kernel_selector::weights_layout::os_is_yx_osv2_isv16:
return cldnn::format::os_is_yx_osv2_isv16;
case kernel_selector::weights_layout::os_is_zyx_osv8_isv4:
return cldnn::format::os_is_zyx_osv8_isv4;
case kernel_selector::weights_layout::os_is_zyx_isv8_osv16_isv2:
@ -871,6 +911,10 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::g_os_iyx_osv16;
case kernel_selector::weights_layout::g_os_iyx_osv32:
return cldnn::format::g_os_iyx_osv32;
case kernel_selector::weights_layout::gs_oiyx_gsv8:
return cldnn::format::gs_oiyx_gsv8;
case kernel_selector::weights_layout::gs_oizyx_gsv8:
return cldnn::format::gs_oizyx_gsv8;
case kernel_selector::weights_layout::gs_oiyx_gsv16:
return cldnn::format::gs_oiyx_gsv16;
case kernel_selector::weights_layout::gs_oizyx_gsv16:

View File

@ -53,33 +53,40 @@ protected:
dnnl::memory::desc desc = onednn::layout_to_memory_desc(a_zp->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, a_zp->get_onednn_memory(desc)});
auto dnnl_mem = a_zp->get_onednn_memory(desc);
void *mapped_ptr = dnnl_mem.map_data();
if (mapped_ptr) {
GPU_DEBUG_TRACE_DETAIL << instance.id() << " activations_zero_points: ";
for (size_t i = 0; i < desc.get_size(); ++i) {
GPU_DEBUG_TRACE_DETAIL << static_cast<int32_t*>(mapped_ptr)[i] << " ";
GPU_DEBUG_GET_INSTANCE(debug_config);
GPU_DEBUG_IF(debug_config->verbose >= static_cast<int>(ov::intel_gpu::LogLevel::TRACE_DETAIL)) {
auto dnnl_mem = a_zp->get_onednn_memory(desc);
void *mapped_ptr = dnnl_mem.map_data();
if (mapped_ptr) {
GPU_DEBUG_TRACE_DETAIL << instance.id() << " activations_zero_points: ";
for (size_t i = 0; i < desc.get_size(); ++i) {
GPU_DEBUG_TRACE_DETAIL << static_cast<int32_t*>(mapped_ptr)[i] << " ";
}
GPU_DEBUG_TRACE_DETAIL << std::endl;
dnnl_mem.unmap_data(mapped_ptr);
}
GPU_DEBUG_TRACE_DETAIL << std::endl;
dnnl_mem.unmap_data(mapped_ptr);
}
}
if (instance.weights_zero_points_term()) {
auto w_zp = instance.weights_zero_points_memory();
dnnl::memory::desc desc = onednn::layout_to_memory_desc(w_zp->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, w_zp->get_onednn_memory(desc)});
throw std::runtime_error("Convolution oneDNN primitive doesn't support asymmetric weights quantization");
// auto w_zp = instance.weights_zero_points_memory();
// dnnl::memory::desc desc = onednn::layout_to_memory_desc(w_zp->get_layout(), dnnl::memory::format_tag::a, true);
// args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, w_zp->get_onednn_memory(desc)});
auto dnnl_mem = w_zp->get_onednn_memory(desc);
void *mapped_ptr = dnnl_mem.map_data();
if (mapped_ptr) {
GPU_DEBUG_TRACE_DETAIL << instance.id() << " weights_zero_points: ";
for (size_t i = 0; i < desc.get_size(); ++i) {
GPU_DEBUG_TRACE_DETAIL << static_cast<int32_t*>(mapped_ptr)[i] << " ";
}
GPU_DEBUG_TRACE_DETAIL << std::endl;
dnnl_mem.unmap_data(mapped_ptr);
}
// GPU_DEBUG_GET_INSTANCE(debug_config);
// GPU_DEBUG_IF(debug_config->verbose >= static_cast<int>(ov::intel_gpu::LogLevel::TRACE_DETAIL)) {
// auto dnnl_mem = w_zp->get_onednn_memory(desc);
// void *mapped_ptr = dnnl_mem.map_data();
// if (mapped_ptr) {
// GPU_DEBUG_TRACE_DETAIL << instance.id() << " weights_zero_points: ";
// for (size_t i = 0; i < desc.get_size(); ++i) {
// GPU_DEBUG_TRACE_DETAIL << static_cast<int32_t*>(mapped_ptr)[i] << " ";
// }
// GPU_DEBUG_TRACE_DETAIL << std::endl;
// dnnl_mem.unmap_data(mapped_ptr);
// }
// }
}
return args;
@ -255,6 +262,8 @@ attach_convolution_onednn::attach_convolution_onednn() {
format::b_fs_zyx_fsv2,
format::b_fs_yx_fsv4,
format::b_fs_zyx_fsv4,
format::b_fs_yx_fsv8,
format::b_fs_zyx_fsv8,
format::b_fs_yx_fsv16,
format::b_fs_zyx_fsv16,
format::b_fs_zyx_fsv32,
@ -269,9 +278,11 @@ attach_convolution_onednn::attach_convolution_onednn() {
format::bs_fs_zyx_bsv32_fsv32,
format::bs_fs_yx_bsv4_fsv4,
format::bs_fs_yx_bsv8_fsv4,
format::bs_fs_yx_bsv16_fsv8,
format::bs_fs_yx_bsv16_fsv4,
format::bs_fs_yx_bsv16_fsv2,
format::bs_fs_zyx_bsv8_fsv4,
format::bs_fs_zyx_bsv16_fsv8,
format::bs_fs_zyx_bsv16_fsv4,
format::bs_fs_zyx_bsv16_fsv2,
format::bs_fs_yx_bsv8_fsv2,

View File

@ -13,23 +13,6 @@
namespace cldnn {
namespace onednn {
namespace {
std::string convert_data_format_string(cldnn::format fmt) {
switch (fmt) {
case cldnn::format::b_fs_yx_fsv2: return "aBcd2b";
case cldnn::format::b_fs_zyx_fsv2: return "aBcde2b";
case cldnn::format::bs_fs_yx_bsv16_fsv2: return "ABcd16a2b";
case cldnn::format::bs_fs_zyx_bsv16_fsv2: return "ABcde16a2b";
case cldnn::format::bs_fs_yx_bsv16_fsv4: return "ABcd16a4b";
case cldnn::format::bs_fs_zyx_bsv16_fsv4: return "ABcde16a4b";
case cldnn::format::bs_fs_yx_bsv16_fsv32: return "ABcd16a32b";
case cldnn::format::bs_fs_zyx_bsv16_fsv32: return "ABcde16a32b";
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to onednn layout string" + fmt_to_str(fmt));
}
}
} // namespace
template <typename T>
cldnn::memory::ptr convert_zp_data_to_s32(const memory::ptr zp_memory) {
auto engine = zp_memory->get_engine();
@ -132,9 +115,11 @@ std::vector<std::pair<cldnn::format, dnnl::memory::format_tag>> format_map = {
{ cldnn::format::bzyxf, dnnl::memory::format_tag::ndhwc },
{ cldnn::format::b_fs_yx_fsv2, dnnl::memory::format_tag::undef },
{ cldnn::format::b_fs_yx_fsv4, dnnl::memory::format_tag::aBcd4b },
{ cldnn::format::b_fs_yx_fsv8, dnnl::memory::format_tag::aBcd8b },
{ cldnn::format::b_fs_yx_fsv16, dnnl::memory::format_tag::nChw16c },
{ cldnn::format::b_fs_yx_fsv32, dnnl::memory::format_tag::aBcd32b },
{ cldnn::format::b_fs_zyx_fsv4, dnnl::memory::format_tag::aBcde4b },
{ cldnn::format::b_fs_zyx_fsv8, dnnl::memory::format_tag::aBcde8b },
{ cldnn::format::b_fs_zyx_fsv16, dnnl::memory::format_tag::nCdhw16c },
{ cldnn::format::b_fs_zyx_fsv32, dnnl::memory::format_tag::aBcde32b },
{ cldnn::format::bs_fs_yx_bsv16_fsv16, dnnl::memory::format_tag::NChw16n16c },
@ -157,8 +142,10 @@ dnnl::memory::format_tag convert_data_format(cldnn::format fmt) {
auto ret = std::find_if(format_map.begin(), format_map.end(),
[fmt](std::pair<cldnn::format, dnnl::memory::format_tag> &e) {
return e.first == fmt; });
if (ret == format_map.end())
return dnnl::memory::format_tag::undef;
if (ret == format_map.end()) {
GPU_DEBUG_INFO << "[clDNN] Unsupported conversion from "+ fmt.to_string() + " to onednn format_tag. Any tag will be used instead." << std::endl;
return dnnl::memory::format_tag::any;
}
return ret->second;
}
@ -233,8 +220,6 @@ dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_t
dnnl::memory::data_type dt = convert_data_type(l.data_type);
dnnl::memory::format_tag fmt = target_fmt == dnnl::memory::format_tag::undef ? convert_data_format(l.format) : target_fmt;
OPENVINO_ASSERT(fmt != dnnl::memory::format_tag::undef, "[GPU] Unexpected fmt: ", convert_data_format_string(l.format));
dnnl::memory::desc res(dims, dt, fmt);
return res;

View File

@ -1535,28 +1535,32 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
std::vector<format> onednn_optimized_formats = {
format::byxf,
format::bzyxf,
format::b_fs_zyx_fsv32,
format::b_fs_yx_fsv32,
format::b_fs_zyx_fsv16,
format::b_fs_yx_fsv8,
format::b_fs_zyx_fsv8,
format::b_fs_yx_fsv16,
format::bs_fs_zyx_bsv16_fsv16,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_zyx_bsv16_fsv32,
format::bs_fs_yx_bsv16_fsv32,
format::bs_fs_zyx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_zyx_bsv32_fsv32,
format::bs_fs_yx_bsv32_fsv32,
format::bs_fs_zyx_bsv8_fsv4,
format::b_fs_zyx_fsv16,
format::b_fs_yx_fsv32,
format::b_fs_zyx_fsv32,
format::bs_fs_yx_bsv4_fsv2,
format::bs_fs_yx_bsv4_fsv4,
format::bs_fs_yx_bsv8_fsv2,
format::bs_fs_zyx_bsv8_fsv2,
format::bs_fs_yx_bsv8_fsv4,
format::bs_fs_yx_bsv16_fsv4,
format::bs_fs_zyx_bsv16_fsv4,
format::bs_fs_zyx_bsv8_fsv4,
format::bs_fs_yx_bsv16_fsv2,
format::bs_fs_zyx_bsv16_fsv2,
format::bs_fs_zyx_bsv8_fsv2,
format::bs_fs_yx_bsv8_fsv2,
format::bs_fs_yx_bsv4_fsv4,
format::bs_fs_yx_bsv4_fsv2,
format::bs_fs_yx_bsv16_fsv4,
format::bs_fs_zyx_bsv16_fsv4,
format::bs_fs_yx_bsv16_fsv8,
format::bs_fs_zyx_bsv16_fsv8,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_zyx_bsv16_fsv16,
format::bs_fs_yx_bsv16_fsv32,
format::bs_fs_zyx_bsv16_fsv32,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_zyx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32,
format::bs_fs_zyx_bsv32_fsv32,
};
impl_types impl_candidate = impl_types::onednn;
@ -1715,9 +1719,9 @@ format layout_optimizer::get_preferred_format(program_node& node) {
expected = get_expected_format(node.as<deconvolution>());
} else if (node.is_type<mvn>()) {
auto input_layout = node.get_input_layout(0);
if (input_layout.format.dimension() == 5 &&
(input_layout.data_type == data_types::f32 || input_layout.data_type == data_types::f16))
expected = format::bfzyx;
if (input_layout.data_type == data_types::f32 || input_layout.data_type == data_types::f16) {
expected = format::get_default_format(input_layout.get_rank());
}
} else if (node.is_type<resample>()) {
// if the resample is in the last part of the network and there are no users using blocked format,
// it is better to reorder to bfyx before resample is done.

View File

@ -293,6 +293,38 @@ inline uint get_b_fs_yx_fsv_index_safe(uint b, uint f, uint y, uint x,
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 4)
#define GET_DATA_B_FS_YX_FSV8_INDEX(prefix, b, f, y, x) \
get_b_fs_yx_fsv_index( \
b, f, y, x, \
CAT(prefix, _SIZE_X ), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _BATCH_NUM), \
CAT(prefix, _PAD_BEFORE_BATCH_NUM), \
CAT(prefix, _PAD_AFTER_BATCH_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 8)
#define GET_DATA_B_FS_YX_FSV8_INDEX_SAFE(prefix, b, f, y, x) \
get_b_fs_yx_fsv_index_safe( \
b, f, y, x, \
CAT(prefix, _SIZE_X ), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _BATCH_NUM), \
CAT(prefix, _PAD_BEFORE_BATCH_NUM), \
CAT(prefix, _PAD_AFTER_BATCH_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 8)
#define GET_DATA_B_FS_YX_FSV32_INDEX(prefix, b, f, y, x) \
get_b_fs_yx_fsv_index( \
b, f, y, x, \
@ -482,6 +514,38 @@ inline uint get_fs_b_yx_fsv32_index_safe(uint b, uint f, uint y, uint x,
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 4)
#define GET_DATA_B_FS_ZYX_FSV8_INDEX(prefix, b, f, z, y, x) \
get_b_fs_zyx_fsv_index( \
b, f, z, y, x, \
CAT(prefix, _SIZE_X ), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 8)
#define GET_DATA_B_FS_ZYX_FSV8_INDEX_SAFE(prefix, b, f, z, y, x) \
get_b_fs_zyx_fsv_index_safe( \
b, f, z, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 8)
#define GET_DATA_B_FS_ZYX_FSV16_INDEX(prefix, b, f, z, y, x) \
get_b_fs_zyx_fsv_index( \
b, f, z, y, x, \
@ -775,6 +839,38 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 4)
#define GET_DATA_BS_FS_ZYX_BSV16_FSV8_INDEX(prefix, b, f, z, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \
b, f, z, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 8)
#define GET_DATA_BS_FS_YX_BSV16_FSV8_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 8)
#define GET_DATA_BS_FS_ZYX_BSV8_FSV4_INDEX(prefix, b, f, z, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \
b, f, z, y, x, \
@ -1053,6 +1149,40 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 4)
#define GET_DATA_BS_FS_YX_BSV16_FSV8_INDEX_SAFE(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _BATCH_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 8)
#define GET_DATA_BS_FS_ZYX_BSV16_FSV8_INDEX_SAFE(prefix, b, f, z, y, x) \
get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, z, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _BATCH_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 8)
#define GET_DATA_BS_FS_YX_BSV8_FSV4_INDEX_SAFE(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \

View File

@ -16,6 +16,18 @@
isv \
)
#define GET_FILTER_IS_OS_YX_OSV_ISV_INDEX(prefix, o, i, y, x, osv, isv) \
get_os_is_zyx_isv_osv_index( \
i, o, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
1, \
CAT(prefix, _OFM_NUM), \
CAT(prefix, _IFM_NUM), \
isv, \
osv \
)
#define GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(prefix, o, i, y, x, osv, isv) \
get_is_os_zyx_isv_osv_index( \
o, i, 0, y, x, \
@ -1447,6 +1459,22 @@ inline uint get_g_os_is_yx_osv_isv(uint g, uint o, uint i, uint y, uint x,
x_size, y_size, 1, i_size, o_size, osv_size, isv_size);
}
#define GET_FILTER_OS_IS_YX_OSV2_ISV16_INDEX(prefix, o, i, y, x) \
get_g_os_is_yx_osv_isv( \
0, o, i, y, x, \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), 2, 16)
#define GET_FILTER_OS_IS_YX_OSV4_ISV16_INDEX(prefix, o, i, y, x) \
get_g_os_is_yx_osv_isv( \
0, o, i, y, x, \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), 4, 16)
#define GET_FILTER_OS_IS_YX_OSV8_ISV2_INDEX(prefix, o, i, y, x) \
get_g_os_is_yx_osv_isv( \
0, o, i, y, x, \
@ -1826,14 +1854,15 @@ inline uint get_g_os_zyx_is_osv_isv_index(uint g, uint o, uint i, uint z, uint y
#define GET_FILTER_G_OS_ZYX_IS_OSV32_ISV16_INDEX(tensor, g, o, i, z, y, x) GET_FILTER_G_OS_ZYX_IS_OSV_ISV_INDEX(tensor, g, o, i, z, y, x, 32, 16)
#define GET_FILTER_G_OS_ZYX_IS_OSV32_ISV32_INDEX(tensor, g, o, i, z, y, x) GET_FILTER_G_OS_ZYX_IS_OSV_ISV_INDEX(tensor, g, o, i, z, y, x, 32, 32)
#define GET_FILTER_O_IS_YX_ISV16_INDEX(prefix, o, i, y, x, isv) \
CAT(prefix, _OFFSET) + \
((i) % (isv)) + \
(o)*CAT(prefix, _OFM_PITCH) + \
(isv)*( \
(x)*CAT(prefix, _X_PITCH) + \
(y)*CAT(prefix, _Y_PITCH) + \
((i) / (isv))*CAT(prefix, _IFM_PITCH) \
#define GET_FILTER_O_IS_ZYX_ISV16_INDEX(prefix, o, i, z, y, x, isv) \
CAT(prefix, _OFFSET) + \
((i) % (isv)) + \
(o)*CAT(prefix, _OFM_PITCH) + \
(isv)*( \
(x)*CAT(prefix, _X_PITCH) + \
(y)*CAT(prefix, _Y_PITCH) + \
(z)*CAT(prefix, _Z_PITCH) + \
((i) / (isv))*CAT(prefix, _IFM_PITCH) \
)
#define GET_FILTER_OS_YXI_OSV16(prefix, o, i, y, x) \

View File

@ -280,8 +280,12 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32);
#elif defined INPUT0_LAYOUT_OS_IYX_OSV32__AI32
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32);
#elif defined INPUT0_LAYOUT_O_IS_YX_ISV4
return GET_FILTER_O_IS_ZYX_ISV16_INDEX(INPUT0, o, i, 0, y, x, 4);
#elif defined INPUT0_LAYOUT_O_IS_YX_ISV16
return GET_FILTER_O_IS_YX_ISV16_INDEX(INPUT0, o, i, y, x, 16);
return GET_FILTER_O_IS_ZYX_ISV16_INDEX(INPUT0, o, i, 0, y, x, 16);
#elif defined INPUT0_LAYOUT_O_IS_ZYX_ISV16
return GET_FILTER_O_IS_ZYX_ISV16_INDEX(INPUT0, o, i, z, y, x, 16);
#elif defined INPUT0_LAYOUT_IYX_OSV64
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 64);
#elif defined INPUT0_LAYOUT_OS_IYX_OSV16_ROTATE_180
@ -320,6 +324,8 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV8
return GET_FILTER_IS_OS_YX_ISV16_OSV8_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
#elif defined INPUT0_LAYOUT_IS_OS_YX_OSV8_ISV4
return GET_FILTER_IS_OS_YX_OSV_ISV_INDEX(INPUT0, o, i, y, x, 8, 4);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV4
return GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(INPUT0, o, i, y, x, 16, 4);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV2
@ -384,6 +390,10 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_G_OS_IYX_OSV16(INPUT0, g, o, i, y, x, 16);
#elif defined INPUT0_LAYOUT_G_OS_IYX_OSV32
return GET_FILTER_G_OS_IYX_OSV16(INPUT0, g, o, i, y, x, 32);
#elif defined INPUT0_LAYOUT_GS_OIYX_GSV8
return GET_FILTER_GS_OIYX_GSV16(INPUT0, g, o, i, y, x, 8);
#elif defined INPUT0_LAYOUT_GS_OIZYX_GSV8
return GET_FILTER_GS_OIZYX_GSV16(INPUT0, g, o, i, z, y, x, 8);
#elif defined INPUT0_LAYOUT_GS_OIYX_GSV16
return GET_FILTER_GS_OIYX_GSV16(INPUT0, g, o, i, y, x, 16);
#elif defined INPUT0_LAYOUT_GS_OIZYX_GSV16
@ -426,6 +436,10 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_OS_IS_YX_OSV8_ISV4_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_OSV8_ISV4
return GET_FILTER_OS_IS_ZYX_OSV8_ISV4_INDEX(INPUT0, o, i, z, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_YX_OSV2_ISV16
return GET_FILTER_OS_IS_YX_OSV2_ISV16_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_YX_OSV4_ISV16
return GET_FILTER_OS_IS_YX_OSV4_ISV16_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_G_OS_IS_ZYX_OSV16_ISV16
return GET_FILTER_G_OS_IS_ZYX_OSV16_ISV16_INDEX(INPUT0, g, o, i, z, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_OSV32_ISV16
@ -487,8 +501,12 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 32);
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV64
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 64);
#elif defined OUTPUT_LAYOUT_O_IS_YX_ISV4
return GET_FILTER_O_IS_ZYX_ISV16_INDEX(OUTPUT, o, i, 0, y, x, 4);
#elif defined OUTPUT_LAYOUT_O_IS_YX_ISV16
return GET_FILTER_O_IS_YX_ISV16_INDEX(OUTPUT, o, i, y, x, 16);
return GET_FILTER_O_IS_ZYX_ISV16_INDEX(OUTPUT, o, i, 0, y, x, 16);
#elif defined OUTPUT_LAYOUT_O_IS_ZYX_ISV16
return GET_FILTER_O_IS_ZYX_ISV16_INDEX(OUTPUT, o, i, z, y, x, 16);
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV16_ROTATE_180
return GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_I_YXS_OS_YXSV2_OSV16
@ -523,6 +541,10 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_OS_IS_YX_OSV8_ISV4_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSV8_ISV4
return GET_FILTER_OS_IS_ZYX_OSV8_ISV4_INDEX(OUTPUT, o, i, z, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSV2_ISV16
return GET_FILTER_OS_IS_YX_OSV2_ISV16_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSV4_ISV16
return GET_FILTER_OS_IS_YX_OSV4_ISV16_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSV32_ISV4_SWIZZLED_BY_2
return GET_FILTER_OS_IS_YX_OSV32_ISV4_SWIZZLED_BY_2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSV32_ISV4
@ -583,6 +605,8 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV8
return GET_FILTER_IS_OS_YX_ISV16_OSV8_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_OSV8_ISV4
return GET_FILTER_IS_OS_YX_OSV_ISV_INDEX(OUTPUT, o, i, y, x, 8, 4);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV4
return GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(OUTPUT, o, i, y, x, 16, 4);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV2
@ -645,6 +669,10 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_G_OS_IYX_OSV16(OUTPUT, g, o, i, y, x, 16);
#elif defined OUTPUT_LAYOUT_G_OS_IYX_OSV32
return GET_FILTER_G_OS_IYX_OSV16(OUTPUT, g, o, i, y, x, 32);
#elif defined OUTPUT_LAYOUT_GS_OIYX_GSV8
return GET_FILTER_GS_OIYX_GSV16(OUTPUT, g, o, i, y, x, 8);
#elif defined OUTPUT_LAYOUT_GS_OIZYX_GSV8
return GET_FILTER_GS_OIZYX_GSV16(OUTPUT, g, o, i, z, y, x, 8);
#elif defined OUTPUT_LAYOUT_GS_OIYX_GSV16
return GET_FILTER_GS_OIYX_GSV16(OUTPUT, g, o, i, y, x, 16);
#elif defined OUTPUT_LAYOUT_GS_OIZYX_GSV16

View File

@ -440,10 +440,12 @@ JitDefinitions DataTensorJitConstant::GetDefinitions() const {
layout == DataLayout::b_fs_yx_fsv32 ||
layout == DataLayout::b_fs_yx_fsv2 ||
layout == DataLayout::b_fs_yx_fsv4 ||
layout == DataLayout::b_fs_yx_fsv8 ||
layout == DataLayout::fs_b_yx_fsv32 ||
layout == DataLayout::bs_fs_yx_bsv16_fsv16 ||
layout == DataLayout::bs_fs_yx_bsv16_fsv32 ||
layout == DataLayout::bs_fs_yx_bsv4_fsv4 ||
layout == DataLayout::bs_fs_yx_bsv16_fsv8 ||
layout == DataLayout::bs_fs_yx_bsv16_fsv4 ||
layout == DataLayout::bs_fs_yx_bsv16_fsv2 ||
layout == DataLayout::bs_fs_yx_bsv8_fsv4 ||
@ -508,6 +510,10 @@ JitDefinitions DataTensorJitConstant::GetDefinitions() const {
index_func_val = "GET_DATA_BS_FS_ZYX_BSV16_FSV16_INDEX(" + _name + ", b, f, z, y, x)";
raw_index_func_val = "GET_DATA_BS_FS_ZYX_BSV16_FSV16_INDEX(" + _name + ", b, f, z, y, x)";
safe_index_func_val = "GET_DATA_BS_FS_ZYX_BSV16_FSV16_INDEX_SAFE(" + _name + ", b, f, z, y, x)";
} else if (layout == DataLayout::bs_fs_zyx_bsv16_fsv8) {
index_func_val = "GET_DATA_BS_FS_ZYX_BSV16_FSV8_INDEX(" + _name + ", b, f, z, y, x)";
raw_index_func_val = "GET_DATA_BS_FS_ZYX_BSV16_FSV8_INDEX(" + _name + ", b, f, z, y, x)";
safe_index_func_val = "GET_DATA_BS_FS_ZYX_BSV16_FSV8_INDEX_SAFE(" + _name + ", b, f, z, y, x)";
} else if (layout == DataLayout::b_fs_zyx_fsv32) {
index_func_val = "GET_DATA_B_FS_ZYX_FSV32_INDEX(" + _name + ", b, f, z, y, x)";
raw_index_func_val = "GET_DATA_B_FS_ZYX_FSV32_INDEX(" + _name + ", b, f, z, y, x)";
@ -536,6 +542,10 @@ JitDefinitions DataTensorJitConstant::GetDefinitions() const {
index_func_val = "GET_DATA_B_FS_ZYX_FSV4_INDEX(" + _name + ", b, f, z, y, x)";
raw_index_func_val = "GET_DATA_B_FS_ZYX_FSV4_INDEX(" + _name + ", b, f, z, y, x)";
safe_index_func_val = "GET_DATA_B_FS_ZYX_FSV4_INDEX_SAFE(" + _name + ", b, f, z, y, x)";
} else if (layout == DataLayout::b_fs_zyx_fsv8) {
index_func_val = "GET_DATA_B_FS_ZYX_FSV8_INDEX(" + _name + ", b, f, z, y, x)";
raw_index_func_val = "GET_DATA_B_FS_ZYX_FSV8_INDEX(" + _name + ", b, f, z, y, x)";
safe_index_func_val = "GET_DATA_B_FS_ZYX_FSV8_INDEX_SAFE(" + _name + ", b, f, z, y, x)";
} else {
index_func_val = "GET_DATA_INDEX_5D_RAW(" + _name + ", b, f, z, y, x)";
safe_index_func_val = "GET_DATA_INDEX_5D_RAW(" + _name + ", b, f, z, y, x)";

View File

@ -96,6 +96,7 @@ std::string toString(DataLayout l) {
case kernel_selector::DataLayout::fyxb: return "FYXB";
case kernel_selector::DataLayout::b_fs_yx_fsv2: return "B_FS_YX_FSV2";
case kernel_selector::DataLayout::b_fs_yx_fsv4: return "B_FS_YX_FSV4";
case kernel_selector::DataLayout::b_fs_yx_fsv8: return "B_FS_YX_FSV8";
case kernel_selector::DataLayout::b_fs_yx_fsv16: return "B_FS_YX_FSV16";
case kernel_selector::DataLayout::b_fs_yx_fsv32: return "B_FS_YX_FSV32";
case kernel_selector::DataLayout::b_fs_zyx_fsv32: return "B_FS_ZYX_FSV32";
@ -109,6 +110,7 @@ std::string toString(DataLayout l) {
case kernel_selector::DataLayout::bfwzyx: return "BFWZYX";
case kernel_selector::DataLayout::bfuwzyx: return "BFUWZYX";
case kernel_selector::DataLayout::bfvuwzyx: return "BFVUWZYX";
case kernel_selector::DataLayout::b_fs_zyx_fsv8: return "B_FS_ZYX_FSV8";
case kernel_selector::DataLayout::b_fs_zyx_fsv16: return "B_FS_ZYX_FSV16";
case kernel_selector::DataLayout::bs_fs_yx_bsv16_fsv16: return "BS_FS_YX_BSV16_FSV16";
case kernel_selector::DataLayout::bs_fs_yx_bsv16_fsv32: return "BS_FS_YX_BSV16_FSV32";
@ -117,6 +119,8 @@ std::string toString(DataLayout l) {
case kernel_selector::DataLayout::bs_fs_yx_bsv4_fsv4: return "BS_FS_YX_BSV4_FSV4";
case kernel_selector::DataLayout::bs_fs_yx_bsv8_fsv4: return "BS_FS_YX_BSV8_FSV4";
case kernel_selector::DataLayout::bs_fs_zyx_bsv8_fsv4: return "BS_FS_ZYX_BSV8_FSV4";
case kernel_selector::DataLayout::bs_fs_yx_bsv16_fsv8: return "BS_FS_YX_BSV16_FSV8";
case kernel_selector::DataLayout::bs_fs_zyx_bsv16_fsv8: return "BS_FS_ZYX_BSV16_FSV8";
case kernel_selector::DataLayout::bs_fs_yx_bsv16_fsv4: return "BS_FS_YX_BSV16_FSV4";
case kernel_selector::DataLayout::bs_fs_zyx_bsv16_fsv4: return "BS_FS_ZYX_BSV16_FSV4";
case kernel_selector::DataLayout::bs_fs_yx_bsv16_fsv2: return "BS_FS_YX_BSV16_FSV2";
@ -312,7 +316,9 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::os_is_zyx_osv16_isv16: return "OS_IS_ZYX_OSV16_ISV16";
case WeightsLayout::os_is_zyx_osv32_isv16: return "OS_IS_ZYX_OSV32_ISV16";
case WeightsLayout::os_is_zyx_osv64_isv16: return "OS_IS_ZYX_OSV64_ISV16";
case WeightsLayout::o_is_yx_isv4: return "O_IS_YX_ISV4";
case WeightsLayout::o_is_yx_isv16: return "O_IS_YX_ISV16";
case WeightsLayout::o_is_zyx_isv16: return "O_IS_ZYX_ISV16";
case WeightsLayout::os_yxi_osv16: return "OS_YXI_OSV16";
case WeightsLayout::os_iyx_osv16: return "OS_IYX_OSV16";
case WeightsLayout::os_iyx_osv32: return "OS_IYX_OSV32";
@ -364,6 +370,7 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::os_is_zyx_isa8_osv8_isv2: return "OS_IS_ZYX_ISA8_OSV8_ISV2";
case WeightsLayout::is_os_yx_isa8_osv8_isv2: return "IS_OS_YX_ISA8_OSV8_ISV2";
case WeightsLayout::is_os_yx_isa8_osv8_isv4: return "IS_OS_YX_ISA8_OSV8_ISV4";
case WeightsLayout::is_os_yx_osv8_isv4: return "IS_OS_YX_OSV8_ISV4";
case WeightsLayout::is_os_yx_osa8_isv16_osv4: return "IS_OS_YX_OSA8_ISV16_OSV4";
case WeightsLayout::os_is_yx_isa8_osv8_isv2: return "OS_IS_YX_ISA8_OSV8_ISV2";
case WeightsLayout::os_is_zyx_isv8_osv16_isv2: return "OS_IS_ZYX_ISV8_OSV16_ISV2";
@ -374,6 +381,8 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::os_is_yx_osv8_isv4: return "OS_IS_YX_OSV8_ISV4";
case WeightsLayout::os_is_zyx_osv8_isv4: return "OS_IS_ZYX_OSV8_ISV4";
case WeightsLayout::os_is_yx_osv8_isv2: return "OS_IS_YX_OSV8_ISV2";
case WeightsLayout::os_is_yx_osv2_isv16: return "OS_IS_YX_OSV2_ISV16";
case WeightsLayout::os_is_yx_osv4_isv16: return "OS_IS_YX_OSV4_ISV16";
case WeightsLayout::os_is_zyx_osv8_isv2: return "OS_IS_ZYX_OSV8_ISV2";
case WeightsLayout::goiyx: return "GOIYX";
case WeightsLayout::gioyx: return "GIOYX";
@ -383,6 +392,8 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::g_os_iyx_osv8: return "G_OS_IYX_OSV8";
case WeightsLayout::g_os_iyx_osv16: return "G_OS_IYX_OSV16";
case WeightsLayout::g_os_iyx_osv32: return "G_OS_IYX_OSV32";
case WeightsLayout::gs_oiyx_gsv8: return "GS_OIYX_GSV8";
case WeightsLayout::gs_oizyx_gsv8: return "GS_OIZYX_GSV8";
case WeightsLayout::gs_oiyx_gsv16: return "GS_OIYX_GSV16";
case WeightsLayout::gs_oizyx_gsv16: return "GS_OIZYX_GSV16";
case WeightsLayout::gs_oiyx_gsv32: return "GS_OIYX_GSV32";

View File

@ -252,12 +252,14 @@ std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws, const
auto blocked_fsv_layout = output_layout == DataLayout::b_fs_yx_fsv2 || output_layout == DataLayout::b_fs_zyx_fsv2 ||
output_layout == DataLayout::b_fs_yx_fsv4 || output_layout == DataLayout::b_fs_zyx_fsv4 ||
output_layout == DataLayout::b_fs_yx_fsv8 || output_layout == DataLayout::b_fs_zyx_fsv8 ||
output_layout == DataLayout::b_fs_yx_fsv16 || output_layout == DataLayout::b_fs_zyx_fsv16 ||
output_layout == DataLayout::b_fs_yx_fsv32 || output_layout == DataLayout::b_fs_zyx_fsv32 ||
output_layout == DataLayout::fs_b_yx_fsv32;
auto blocked_bsv_fsv_layout = output_layout == DataLayout::bs_fs_yx_bsv16_fsv2 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv2 ||
output_layout == DataLayout::bs_fs_yx_bsv16_fsv4 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv4 ||
output_layout == DataLayout::bs_fs_yx_bsv16_fsv8 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv8 ||
output_layout == DataLayout::bs_fs_yx_bsv16_fsv16 || output_layout == DataLayout::bs_fs_yx_bsv16_fsv32 ||
output_layout == DataLayout::bs_fs_yx_bsv32_fsv16 || output_layout == DataLayout::bs_fs_yx_bsv32_fsv32 ||
output_layout == DataLayout::bs_fs_zyx_bsv16_fsv16 || output_layout == DataLayout::bs_fs_zyx_bsv16_fsv32 ||
@ -318,10 +320,10 @@ std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws, const
break;
}
} else if (blocked_fsv_layout) {
if (output_layout == DataLayout::b_fs_yx_fsv2 || output_layout == DataLayout::b_fs_yx_fsv4 ||
if (output_layout == DataLayout::b_fs_yx_fsv2 || output_layout == DataLayout::b_fs_yx_fsv4 || output_layout == DataLayout::b_fs_yx_fsv8 ||
output_layout == DataLayout::b_fs_yx_fsv16 || output_layout == DataLayout::b_fs_yx_fsv32) {
layout_order = { f, x, y, b, z, w, u, v };
} else if (output_layout == DataLayout::b_fs_zyx_fsv2 || output_layout == DataLayout::b_fs_zyx_fsv4 ||
} else if (output_layout == DataLayout::b_fs_zyx_fsv2 || output_layout == DataLayout::b_fs_zyx_fsv4 || output_layout == DataLayout::b_fs_zyx_fsv8 ||
output_layout == DataLayout::b_fs_zyx_fsv16 || output_layout == DataLayout::b_fs_zyx_fsv32) {
layout_order = { f, x, y, z, b, w, u, v };
} else { // output_layout == DataLayout::fs_b_yx_fsv32
@ -453,13 +455,18 @@ bool CheckInputsOutputNoPitchSameDims(const base_params& params) {
{DataLayout::b_fs_zyx_fsv16, {1, 16}},
{DataLayout::b_fs_yx_fsv32, {1, 32}},
{DataLayout::b_fs_zyx_fsv32, {1, 32}},
{DataLayout::bs_fs_yx_bsv16_fsv8, {16, 8}},
{DataLayout::bs_fs_yx_bsv16_fsv16, {16, 16}},
{DataLayout::bs_fs_yx_bsv16_fsv32, {16, 32}},
{DataLayout::bs_fs_zyx_bsv16_fsv8, {16, 8}},
{DataLayout::bs_fs_zyx_bsv16_fsv16, {16, 16}},
{DataLayout::bs_fs_zyx_bsv16_fsv32, {16, 32}},
{DataLayout::bs_f_bsv8__af8, {8, 8}},
{DataLayout::bs_f_bsv16__af8, {16, 8}},
{DataLayout::b_fs_yx_fsv4, {1, 4}},
{DataLayout::b_fs_zyx_fsv4, {1, 4}},
{DataLayout::b_fs_yx_fsv8, {1, 8}},
{DataLayout::b_fs_zyx_fsv8, {1, 8}},
{DataLayout::fs_b_yx_fsv32, {1, 32}},
{DataLayout::b_fs_yx_32fp, {1, 32}},
{DataLayout::bs_fs_yx_bsv32_fsv16, {32, 16}},

View File

@ -12,6 +12,8 @@
namespace kernel_selector {
inline uint32_t SubGroupSize(WeightsLayout l) {
switch (l) {
case WeightsLayout::o_is_yx_isv16:
case WeightsLayout::o_is_zyx_isv16:
case WeightsLayout::os_iyx_osv16:
case WeightsLayout::os_iyx_osv32:
case WeightsLayout::os_iyx_osv64:
@ -50,6 +52,8 @@ inline uint32_t SubGroupSize(WeightsLayout l) {
case WeightsLayout::iy_xs_os_xsv2_osv8__ao32:
case WeightsLayout::giy_xs_os_xsv2_osv8__ao32:
case WeightsLayout::g_os_iyx_osv8:
case WeightsLayout::gs_oiyx_gsv8:
case WeightsLayout::gs_oizyx_gsv8:
return 8;
default:
return 1;

View File

@ -27,16 +27,20 @@ DataTensor::DataChannelArray DataTensor::dataChannelArray {{
{ DataLayout::fyxb, { 1, 2, -1, -1, -1, -1, 3, 0 } },
{ DataLayout::b_fs_yx_fsv2, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::b_fs_yx_fsv4, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::b_fs_yx_fsv8, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::b_fs_yx_fsv16, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::b_fs_yx_fsv32, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::b_fs_zyx_fsv2, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::b_fs_zyx_fsv4, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::b_fs_zyx_fsv8, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::b_fs_zyx_fsv16, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::b_fs_zyx_fsv32, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::bs_fs_yx_bsv16_fsv32, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_zyx_bsv16_fsv32, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::bs_fs_zyx_bsv16_fsv16, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::bs_fs_yx_bsv16_fsv16, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_zyx_bsv16_fsv8, { 0, 1, 2, -1, -1, -1, 3, 4 } },
{ DataLayout::bs_fs_yx_bsv16_fsv8, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_yx_bsv4_fsv4, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_yx_bsv8_fsv4, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_zyx_bsv8_fsv4, { 0, 1, 2, -1, -1, -1, 3, 4 } },
@ -82,7 +86,9 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::os_iyx_osv32__ai32, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_iyx_osv64, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_iyx_osv16_rotate_180, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::o_is_yx_isv4, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::o_is_yx_isv16, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::o_is_zyx_isv16, { 0, 1, 2, 3, 4, -1 } },
{ WeightsLayout::os_yxi_osv16, { 1, 2, -1, 0, 3, -1 } },
{ WeightsLayout::os_i_osv8__ai8, { -1, -1, -1, 0, 1, -1 } },
{ WeightsLayout::os_i_osv16__ai8, { -1, -1, -1, 0, 1, -1 } },
@ -132,6 +138,8 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::is_o32_yx_isv32_swizzled_by_4, { 1, 2, -1, 0, 3, -1 } },
{ WeightsLayout::os_is_y_x8_osv8_isv4, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_is_y_x8_osv8_isv4_swizzled_by_4, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_is_yx_osv2_isv16, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_is_yx_osv4_isv16, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_is_yx_osv8_isv4, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_is_zyx_osv8_isv4, { 0, 1, 2, 3, 4, -1 } },
{ WeightsLayout::os_is_yx_osv8_isv2, { 0, 1, -1, 2, 3, -1 } },
@ -145,6 +153,7 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::os_is_yx_osv32_isv32p, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_is_zyx_isv16_osv16, { 0, 1, 2, 3, 4, -1 } },
{ WeightsLayout::os_is_yx_isv16_osv16, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::is_os_yx_osv8_isv4, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_zyx_isv16_osv16, { 0, 1, 2, 4, 3, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv16, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv8, { 0, 1, -1, 3, 2, -1 } },
@ -180,6 +189,8 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::g_os_iyx_osv8, { 0, 1, -1, 2, 3, 4 } },
{ WeightsLayout::g_os_iyx_osv16, { 0, 1, -1, 2, 3, 4 } },
{ WeightsLayout::g_os_iyx_osv32, { 0, 1, -1, 2, 3, 4 } },
{ WeightsLayout::gs_oiyx_gsv8, { 0, 1, -1, 2, 3, 4 } },
{ WeightsLayout::gs_oizyx_gsv8, { 0, 1, 2, 3, 4, 5 } },
{ WeightsLayout::gs_oiyx_gsv16, { 0, 1, -1, 2, 3, 4 } },
{ WeightsLayout::gs_oizyx_gsv16, { 0, 1, 2, 3, 4, 5 } },
{ WeightsLayout::gs_oiyx_gsv32, { 0, 1, -1, 2, 3, 4 } },
@ -233,6 +244,10 @@ NDims DataTensor::GetSimpleDims(const std::vector<size_t>& d, DataLayout l) {
newDims[0] = RoundUp(newDims[0], 8);
newDims[1] = RoundUp(newDims[1], 16);
break;
case b_fs_yx_fsv8:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 8);
break;
case b_fs_yx_fsv16:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
@ -253,10 +268,19 @@ NDims DataTensor::GetSimpleDims(const std::vector<size_t>& d, DataLayout l) {
assert(newDims.size() == 4);
newDims[3] = RoundUp(newDims[3], 32);
break;
case b_fs_zyx_fsv8:
assert(newDims.size() == 5);
newDims[3] = RoundUp(newDims[3], 8);
break;
case b_fs_zyx_fsv16:
assert(newDims.size() == 5);
newDims[3] = RoundUp(newDims[3], 16);
break;
case bs_fs_yx_bsv16_fsv8:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 8);
newDims[3] = RoundUp(newDims[3], 16);
break;
case bs_fs_yx_bsv16_fsv16:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
@ -277,6 +301,11 @@ NDims DataTensor::GetSimpleDims(const std::vector<size_t>& d, DataLayout l) {
newDims[3] = RoundUp(newDims[3], 16);
newDims[4] = RoundUp(newDims[4], 16);
break;
case bs_fs_zyx_bsv16_fsv8:
assert(newDims.size() == 5);
newDims[3] = RoundUp(newDims[3], 8);
newDims[4] = RoundUp(newDims[4], 16);
break;
case bs_fs_yx_bsv4_fsv4:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 4);
@ -588,10 +617,18 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
// TODO: It's not the right pitches. it's here in order to calculate physical size
switch (l) {
case o_is_yx_isv4:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 4);
break;
case o_is_yx_isv16:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
break;
case o_is_zyx_isv16:
assert(newDims.size() == 5);
newDims[2] = RoundUp(newDims[3], 16);
break;
case os_iyx_osv16:
case os_yxi_osv16:
case os_iyx_osv16_rotate_180:
@ -772,6 +809,11 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
newDims[3] = RoundUp(newDims[3], 16);
newDims[4] = RoundUp(newDims[4], 16);
break;
case is_os_yx_osv8_isv4:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 8);
newDims[3] = RoundUp(newDims[3], 4);
break;
case is_os_yx_isv16_osv16:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
@ -806,6 +848,16 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
assert(newDims.size() == 5);
newDims[3] = RoundUp(newDims[0], 16);
break;
case os_is_yx_osv2_isv16:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 2);
break;
case os_is_yx_osv4_isv16:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 4);
break;
case os_is_yx_osv8_isv4:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 4);
@ -827,6 +879,14 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
assert(newDims.size() == 5);
newDims[3] = RoundUp(newDims[3], 32);
break;
case gs_oiyx_gsv8:
assert(newDims.size() == 5);
newDims[4] = RoundUp(newDims[4], 8);
break;
case gs_oizyx_gsv8:
assert(newDims.size() == 6);
newDims[5] = RoundUp(newDims[5], 8);
break;
case gs_oiyx_gsv16:
assert(newDims.size() == 5);
newDims[4] = RoundUp(newDims[4], 16);

View File

@ -40,6 +40,8 @@ enum DataLayout {
b_fs_zyx_fsv2,
b_fs_yx_fsv4, // reordering format for swizzled input for convolution using IMAD
b_fs_zyx_fsv4,
b_fs_yx_fsv8,
b_fs_zyx_fsv8,
b_fs_yx_fsv16, // 3D+batch
b_fs_zyx_fsv16, // batch, feature, 3D spatial. Blocks of 16 input channels
b_fs_yx_fsv32, // 3D+batch
@ -53,6 +55,8 @@ enum DataLayout {
bs_fs_yx_bsv8_fsv2, // batch, feature, 2D spatial. Blocks of 8 batch and 2 channels
bs_fs_zyx_bsv8_fsv4, // batch, feature, 3D spatial. Blocks of 8 batch and 4 channels
bs_fs_zyx_bsv8_fsv2, // batch, feature, 3D spatial. Blocks of 8 batch and 2 channels
bs_fs_yx_bsv16_fsv8, // batch, feature, 2D spatial. Blocks of 16 batch and 8 channels
bs_fs_zyx_bsv16_fsv8, // batch, feature, 3D spatial. Blocks of 16 batch and 8 channels
bs_fs_yx_bsv16_fsv4, // batch, feature, 2D spatial. Blocks of 16 batch and 4 channels
bs_fs_zyx_bsv16_fsv4, // batch, feature, 3D spatial. Blocks of 16 batch and 4 channels
bs_fs_yx_bsv16_fsv2, // batch, feature, 2D spatial. Blocks of 16 batch and 2 channels
@ -90,7 +94,9 @@ enum WeightsLayout {
oxiy,
iyxo,
yxio,
o_is_yx_isv4,
o_is_yx_isv16,
o_is_zyx_isv16,
os_yxi_osv16,
os_iyx_osv16,
os_iyx_osv32,
@ -156,6 +162,7 @@ enum WeightsLayout {
os_is_yx_isa8_osv8_isv2,
is_os_yx_isa8_osv8_isv2,
is_os_yx_isa8_osv8_isv4,
is_os_yx_osv8_isv4,
is_os_yx_osa8_isv16_osv4,
is_os_yx_isa2_osa8_isv8_osv2,
g_os_is_yx_osa2_isa8_osv16_isv4,
@ -179,6 +186,8 @@ enum WeightsLayout {
os_is_yx_osv32_isv4_swizzled_by_2, // weights for bfyx -> b_fs_yx_fsv32 convolution using IMAD with swizzled ofm (0, 2, 4..), (1, 3, 5...)
os_is_yx_osv32_isv4, // weights for bfyx -> b_fs_yx_fsv{32,16} convolution using IMAD
os_is_zyx_osv32_isv4, // weights for bfzyx -> b_fs_zyx_fsv16 convolution using IMAD
os_is_yx_osv2_isv16,
os_is_yx_osv4_isv16,
oizyx,
iozyx,
os_is_yx_osv32_isv32p, // 2 blocks: 32 packed binary in channels and 32 output channels
@ -200,6 +209,8 @@ enum WeightsLayout {
g_os_iyx_osv8,
g_os_iyx_osv16,
g_os_iyx_osv32,
gs_oiyx_gsv8,
gs_oizyx_gsv8,
gs_oiyx_gsv16,
gs_oizyx_gsv16,
gs_oiyx_gsv32,

View File

@ -31,10 +31,12 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(bxfy, 1, 1, 2, 0, {0, 3, 1, 2}, "bxfy", "bfxy", {}),
FMT_TRAITS(b_fs_yx_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 2}}),
FMT_TRAITS(b_fs_yx_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 4}}),
FMT_TRAITS(b_fs_yx_fsv8, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 8}}),
FMT_TRAITS(b_fs_yx_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 16}}),
FMT_TRAITS(b_fs_yx_fsv32, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 32}}),
FMT_TRAITS(b_fs_zyx_fsv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 2}}),
FMT_TRAITS(b_fs_zyx_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 4}}),
FMT_TRAITS(b_fs_zyx_fsv8, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 8}}),
FMT_TRAITS(b_fs_zyx_fsv32, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{1, 32}}),
FMT_TRAITS(bs_fs_fsv8_bsv8, 1, 1, 0, 0, {0, 1}, "bf", "bf", {{0, 8}, {1, 8}}),
FMT_TRAITS(bs_fs_fsv8_bsv16, 1, 1, 0, 0, {0, 1}, "bf", "bf", {{0, 16}, {1, 8}}),
@ -55,6 +57,8 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(bs_fs_yx_bsv4_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 4 }, {1, 4}}),
FMT_TRAITS(bs_fs_yx_bsv8_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 8 }, {1, 4}}),
FMT_TRAITS(bs_fs_zyx_bsv8_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 8 }, {1, 4}}),
FMT_TRAITS(bs_fs_yx_bsv16_fsv8, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 16 }, {1, 8}}),
FMT_TRAITS(bs_fs_zyx_bsv16_fsv8, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 8}}),
FMT_TRAITS(bs_fs_yx_bsv16_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 16 }, {1, 4}}),
FMT_TRAITS(bs_fs_zyx_bsv16_fsv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "bfzyx", "bfxyz", {{0, 16 }, {1, 4}}),
FMT_TRAITS(bs_fs_yx_bsv16_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{0, 16 }, {1, 2}}),
@ -81,7 +85,9 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(oizyx, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {}),
FMT_TRAITS(iozyx, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {}),
FMT_TRAITS(os_is_yx_isv16_osv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 16}, {0, 16}}),
FMT_TRAITS(o_is_yx_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 4}}),
FMT_TRAITS(o_is_yx_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 16}}),
FMT_TRAITS(o_is_zyx_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 16}}),
FMT_TRAITS(os_yxi_osv16, 1, 1, 2, 0, {0, 2, 3, 1}, "oyxi", "oixy", {{0, 16}}),
FMT_TRAITS(os_iyx_osv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}}),
FMT_TRAITS(os_iyx_osv32, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}}),
@ -114,6 +120,8 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(is_o32_yx_isv32_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy", {{0, 32}, {1, 32}}),
FMT_TRAITS(os_is_y_x8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy", {{0, 8}, {1, 4}, {2, 8}}),
FMT_TRAITS(os_is_y_x8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy", {{0, 8}, {1, 4}, {2, 8}}),
FMT_TRAITS(os_is_yx_osv2_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 2}, {1, 16}}),
FMT_TRAITS(os_is_yx_osv4_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 4}, {1, 16}}),
FMT_TRAITS(os_is_yx_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}, {1, 4}}),
FMT_TRAITS(os_is_yx_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 8}, {1, 4}}),
FMT_TRAITS(os_is_zyx_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 8}, {1, 4}}),
@ -126,6 +134,7 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(os_is_yx_osv32_isv32p, 1, 1, 1, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}),
FMT_TRAITS(os_is_zyx_isv16_osv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_zyx_isv16_osv16, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{0, 8}, {1, 4}}),
FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_isv16_osv8, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 8}}),
FMT_TRAITS(is_os_yx_isv16_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 4}}),
@ -174,6 +183,8 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(g_os_iyx_osv8, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy????g", {{0, 8}}),
FMT_TRAITS(g_os_iyx_osv16, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy????g", {{0, 16}}),
FMT_TRAITS(g_os_iyx_osv32, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy????g", {{0, 32}}),
FMT_TRAITS(gs_oiyx_gsv8, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy????g", {{8, 8}}),
FMT_TRAITS(gs_oizyx_gsv8, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz???g", {{8, 8}}),
FMT_TRAITS(gs_oiyx_gsv16, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy????g", {{8, 16}}),
FMT_TRAITS(gs_oizyx_gsv16, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz???g", {{8, 16}}),
FMT_TRAITS(gs_oiyx_gsv32, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy????g", {{8, 32}}),

View File

@ -145,6 +145,8 @@ static format to_weights_format(format f, bool is_grouped) {
throw std::runtime_error("Invalid conversion of data format to weights format. bfwzyx can't be non-grouped as 4D spatials are not supported");
return format::goizyx;
}
case format::b_fs_yx_fsv4:
return format::o_is_yx_isv4;
case format::b_fs_yx_fsv16:
return format::o_is_yx_isv16;
case format::bs_fs_fsv8_bsv8:

View File

@ -176,7 +176,7 @@ public:
const auto params_hash = prim_inst->get_impl_params()->hash();
ASSERT_EQ(primitive_hash, 16293979194373117693UL);
ASSERT_EQ(params_hash, 14231564068060955575UL);
ASSERT_EQ(params_hash, 15950979219660866859UL);
}
void test_reshape_basic(bool is_caching_test) {

@ -1 +1 @@
Subproject commit 4b82a66ed38ecaa993352e5cc6ed7753656b8a26
Subproject commit 284ad4574939fa784e4ddaa1f4aa577b8eb7a017