[GPU] Make generic logic to find formats from memory::desc (#13730)

* [GPU] Make generic logic to find formats from memory::desc
+ Added test-cases

Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
Min, Byungil 2023-01-06 18:12:51 +09:00 committed by GitHub
parent c1f6da31b6
commit c4cd3e152b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 477 additions and 309 deletions

View File

@ -332,6 +332,8 @@ struct format {
static format adjust_to_rank(format fmt, size_t new_rank);
static const std::vector<std::pair<size_t, int>> per_axis_block_size(format fmt);
/// @brief Checks if @p format is of grouped type
static bool is_grouped(type fmt) { return group_num(fmt) != 0; }
/// @brief Checks if @p format is of image type

View File

@ -309,314 +309,162 @@ dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_t
}
}
static bool isSame(dnnl::memory::desc desc, dnnl::memory::format_tag fmt) {
dnnl::memory::desc refDesc(desc.dims(), desc.data_type(), fmt);
static void get_identical_order(std::vector<std::vector<size_t>>& orders, std::vector<size_t> order,
size_t first, size_t depth) {
if (depth == 0)
return;
if (desc.data.ndims != refDesc.data.ndims)
return false;
for (size_t idx = first; idx <= first + depth ; idx++) {
std::swap(order[first], order[idx]);
if (first != idx)
orders.push_back(order);
if (desc.data.format_kind != dnnl_blocked || refDesc.data.format_kind != dnnl_blocked)
throw std::runtime_error("dnnlMemoryDesc::isSame is not implemented for non blocked memory format");
auto actualBlkDesc = desc.data.format_desc.blocking;
auto refBlkDesc = refDesc.data.format_desc.blocking;
if (actualBlkDesc.inner_nblks != refBlkDesc.inner_nblks)
return false;
for (int i = 0; i < actualBlkDesc.inner_nblks; ++i)
if (actualBlkDesc.inner_blks[i] != refBlkDesc.inner_blks[i])
return false;
for (int i = 0; i < actualBlkDesc.inner_nblks; ++i)
if (actualBlkDesc.inner_idxs[i] != refBlkDesc.inner_idxs[i])
return false;
auto actualStrides = desc.data.format_desc.blocking.strides;
auto refStrides = refDesc.data.format_desc.blocking.strides;
std::vector<size_t> actualOrder(desc.data.ndims);
std::iota(actualOrder.begin(), actualOrder.end(), 0);
std::sort(actualOrder.begin(), actualOrder.end(),
[&actualStrides] (size_t ind_l, size_t ind_r) {
return actualStrides[ind_l] > actualStrides[ind_r];
});
std::vector<size_t> refOrder(refDesc.data.ndims);
std::iota(refOrder.begin(), refOrder.end(), 0);
std::sort(refOrder.begin(), refOrder.end(),
[&refStrides] (size_t ind_l, size_t ind_r) {
return refStrides[ind_l] > refStrides[ind_r];
});
if (actualOrder != refOrder) {
return false;
get_identical_order(orders, order, first+1, depth-1);
std::swap(order[first], order[idx]);
}
return true;
}
/// @brief Looks up the oneDNN format tag that matches @p desc by brute-force search.
///
/// TODO [OneDNN]: the tag used to be stored as a field of the descriptor; it is now
/// recovered by testing every known tag of the same rank, so avoid calling this
/// function where possible.
///
/// @param desc memory descriptor to identify.
/// @return the first tag from form_tags_by_ndims for which isSame() holds, or
///         format_tag::undef when the rank is 0, greater than 6, or nothing matches.
dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc) {
    const auto rank = desc.dims().size();

    // Format tags are only looked up for ranks 1..6.
    const bool rank_has_tags = (rank >= 1 && rank <= 6);
    if (rank_has_tags) {
        const auto& candidate_tags = form_tags_by_ndims.at(static_cast<int>(rank));
        for (const auto tag : candidate_tags) {
            if (isSame(desc, tag)) {
                return tag;
            }
        }
    }

    return dnnl::memory::format_tag::undef;
}
static std::vector<size_t> get_order(dnnl::memory::desc desc) {
auto blk = desc.data.format_desc.blocking;
auto strides = blk.strides;
/// @brief Collects the axis orders that are compatible with the strides of @p desc.
///
/// The axes are sorted by descending stride to obtain a base order. Stride values alone
/// cannot rank axes that share the same stride, so more than one order may be valid; for
/// every run of equal-stride axes, get_identical_order() appends the reorderings of that
/// run (applied to the base order) as additional candidates.
///
/// @param desc oneDNN memory descriptor whose blocking strides are inspected.
/// @return list of candidate orders; the first element is the stride-sorted base order.
std::vector<std::vector<size_t>> get_candidate_orders(dnnl::memory::desc desc) {
    std::vector<std::vector<size_t>> orders;
    auto strides = desc.data.format_desc.blocking.strides;
    std::vector<size_t> order(desc.data.ndims);

    // Base order: axis indices sorted from the largest stride to the smallest.
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&strides] (size_t ind_l, size_t ind_r) {
                  return (strides[ind_l] > strides[ind_r]);
              });
    orders.push_back(order);

    // Orders of those axes which have the same stride in memory::desc can be changed.
    // If the y and x axes have the same stride, the order can be bfyx or bfxy.
    for (size_t idx = 0 ; idx+1 < order.size() ; idx++) {
        // depth = number of axes following idx that share its stride value.
        size_t depth = 0;
        for (size_t next = idx+1 ; next < order.size() ; next++) {
            if (strides[order[idx]] == strides[order[next]]) {
                depth++;
            } else {
                break;
            }
        }

        // Multiple axes can have the same stride value in the memory descriptor.
        get_identical_order(orders, order, idx, depth);
        // Skip the rest of the run that was just handled.
        idx += depth;
    }

    return orders;
}
static bool compare_strides(std::vector<size_t> a, std::vector<size_t> b) {
return std::equal(a.begin(), a.end(), b.begin());
/// @brief Checks whether any candidate order in @p a is identical to the order @p b.
///
/// Two orders are identical only when they have the same length and the same elements;
/// a candidate that merely shares a prefix with @p b does not match, and a candidate
/// longer than @p b is never read past the end of @p b.
///
/// @param a candidate orders to search.
/// @param b reference order to look for.
/// @return true if at least one element of @p a equals @p b, false otherwise
///         (including when @p a is empty).
static bool compare_orders(const std::vector<std::vector<size_t>>& a, const std::vector<size_t>& b) {
    return std::any_of(a.begin(), a.end(),
                       [&b](const std::vector<size_t>& candidate) {
                           // vector equality compares sizes first, then elements.
                           return candidate == b;
                       });
}
cldnn::format find_data_format(dnnl::memory::desc desc) {
auto onednn_desc = get_format_by_desc(desc);
auto blk = desc.data.format_desc.blocking;
auto order = get_candidate_orders(desc);
if (onednn_desc != dnnl::memory::format_tag::undef) {
return convert_data_format(onednn_desc);
} else {
auto blk = desc.data.format_desc.blocking;
auto order = get_order(desc);
for (int32_t fmt_idx = format::bfyx ; fmt_idx < format::format_num ; fmt_idx++) {
auto candidate_trait = format::traits(static_cast<format::type>(fmt_idx));
if (desc.data.ndims == static_cast<int>(candidate_trait._order.size())
&& blk.inner_nblks == static_cast<int>(candidate_trait.block_sizes.size())
&& compare_strides(order, candidate_trait._order)) {
bool is_match = true;
for (size_t idx = 0 ; idx < candidate_trait.block_sizes.size() ; idx++) {
if (blk.inner_blks[idx] != static_cast<int>(candidate_trait.block_sizes[idx].second)
|| blk.inner_idxs[idx] != static_cast<int>(candidate_trait.block_sizes[idx].first)) {
for (int32_t fmt_idx = format::bfyx ; fmt_idx < format::oiyx ; fmt_idx++) {
auto candidate_trait = format::traits(static_cast<format::type>(fmt_idx));
if (desc.data.ndims == static_cast<int>(candidate_trait._order.size())
&& blk.inner_nblks == static_cast<int>(candidate_trait.block_sizes.size())
&& compare_orders(order, candidate_trait._order)) {
bool is_match = true;
for (size_t idx = 0 ; idx < candidate_trait.block_sizes.size() ; idx++) {
if (blk.inner_blks[idx] != static_cast<int>(candidate_trait.block_sizes[idx].second)
|| blk.inner_idxs[idx] != static_cast<int>(candidate_trait.block_sizes[idx].first)) {
is_match = false;
break;
}
}
if (is_match)
return static_cast<format::type>(fmt_idx);
}
}
std::stringstream msg;
msg << "Unsupported onednn dnnl::memory::desc find_data_format. "
<< "ndims: " << desc.data.ndims
<< ", inner_nblks: " << blk.inner_nblks
<< ", inner_blks: ";
for (int i = 0; i < blk.inner_nblks; i++)
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
throw std::runtime_error(msg.str());
}
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped) {
auto blk = desc.data.format_desc.blocking;
auto orders = get_candidate_orders(desc);
format start_format = format::oiyx;
if (is_grouped)
start_format = format::goiyx;
for (int32_t fmt_idx = start_format ; fmt_idx < format::format_num ; fmt_idx++) {
auto candidate_trait = format::traits(static_cast<format::type>(fmt_idx));
if (static_cast<size_t>(desc.data.ndims) == candidate_trait._order.size()
&& static_cast<size_t>(blk.inner_nblks) == candidate_trait.block_sizes.size()
&& compare_orders(orders, candidate_trait._order)) {
// Compare all pairs of dimension number and block size to format_traits_map of all formats
bool is_match = true;
for (size_t idx = 0 ; idx < candidate_trait.block_sizes.size() ; idx++) {
auto block_idx = static_cast<dnnl_dim_t>(candidate_trait.block_sizes[idx].first);
auto block_size = static_cast<dnnl_dim_t>(candidate_trait.block_sizes[idx].second);
if (is_grouped && candidate_trait.is_group_char(candidate_trait.internal_order[block_idx])) {
// inner_idx gets the index of group dimension in mem::desc when blocked axis is group
auto inner_idx = candidate_trait.order.find_first_of(candidate_trait.internal_order[block_idx]);
if (blk.inner_blks[idx] != block_size ||
blk.inner_idxs[idx] != static_cast<dnnl_dim_t>(inner_idx)) {
is_match = false;
break;
}
} else if (is_grouped) {
// g,o,i from cldnn formats are matching to a,b,c of dnnl. But g is at the end of internal order.
if (blk.inner_blks[idx] != block_size ||
(blk.inner_idxs[idx] - static_cast<dnnl_dim_t>(candidate_trait.group_num)) != block_idx) {
is_match = false;
break;
}
} else {
if (blk.inner_blks[idx] != block_size ||
blk.inner_idxs[idx] != block_idx) {
is_match = false;
break;
}
}
if (is_match)
return static_cast<format::type>(fmt_idx);
}
}
std::stringstream msg;
msg << "Unsupported onednn dnnl::memory::desc find_data_format. "
<< "ndims: " << desc.data.ndims
<< ", inner_nblks: " << blk.inner_nblks
<< ", inner_blks: ";
for (int i = 0; i < blk.inner_nblks; i++)
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
throw std::runtime_error(msg.str());
}
}
// onednn -> cldnn
static cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
if (is_grouped) {
switch (fmt) {
case dnnl::memory::format_tag::abcde: return cldnn::format::goiyx;
case dnnl::memory::format_tag::abcdef: return cldnn::format::goizyx;
case dnnl::memory::format_tag::Abcdef16a: return cldnn::format::gs_oizyx_gsv16;
case dnnl::memory::format_tag::Abcde16a: return cldnn::format::gs_oiyx_gsv16;
case dnnl::memory::format_tag::Abcde32a: return cldnn::format::gs_oiyx_gsv32;
case dnnl::memory::format_tag::Abcdef32a: return cldnn::format::gs_oizyx_gsv32;
case dnnl::memory::format_tag::aCBde16c16b: return cldnn::format::g_is_os_yx_isv16_osv16;
case dnnl::memory::format_tag::aBCde2b8c8b2c: return cldnn::format::g_os_is_yx_osa2_isa8_osv8_isv2;
case dnnl::memory::format_tag::aBCde4b8c8b4c: return cldnn::format::g_os_is_yx_osa4_isa8_osv8_isv4;
case dnnl::memory::format_tag::aBCde4b8c8b2c: return cldnn::format::g_os_is_yx_osa4_isa8_osv8_isv2;
case dnnl::memory::format_tag::aBCde8b2c: return cldnn::format::g_os_is_yx_osv8_isv2;
case dnnl::memory::format_tag::aBCde8b4c: return cldnn::format::g_os_is_yx_osv8_isv4;
case dnnl::memory::format_tag::aBcde8b: return cldnn::format::g_os_iyx_osv8;
case dnnl::memory::format_tag::aBCd2b8c16b4c: return cldnn::format::g_os_is_yx_osa2_isa8_osv16_isv4;
case dnnl::memory::format_tag::aBCd2b8c16b2c: return cldnn::format::g_os_is_yx_osa2_isa8_osv16_isv2;
case dnnl::memory::format_tag::aBCdef16c16b: return cldnn::format::g_os_is_zyx_isv16_osv16;
case dnnl::memory::format_tag::aBCdef4b8c8b2c: return cldnn::format::g_os_is_zyx_osa4_isa8_osv8_isv2;
case dnnl::memory::format_tag::aBCdef4b8c8b4c: return cldnn::format::g_os_is_zyx_osa4_isa8_osv8_isv4;
default: throw std::runtime_error(std::string("Unsupported grouped onednn fmt ") + dnnl_fmt_tag2str((dnnl_format_tag_t)fmt));
}
} else {
switch (fmt) {
case dnnl::memory::format_tag::ab: return cldnn::format::oiyx;
case dnnl::memory::format_tag::abcd: return cldnn::format::oiyx;
case dnnl::memory::format_tag::bacd: return cldnn::format::ioyx;
case dnnl::memory::format_tag::bcda: return cldnn::format::iyxo;
case dnnl::memory::format_tag::BAcd16b16a: return cldnn::format::is_os_yx_isv16_osv16;
case dnnl::memory::format_tag::ABcd16b16a: return cldnn::format::os_is_yx_isv16_osv16;
case dnnl::memory::format_tag::abcde: return cldnn::format::oizyx;
case dnnl::memory::format_tag::ABcd4a8b8a4b: return cldnn::format::os_is_yx_osa4_isa8_osv8_isv4;
case dnnl::memory::format_tag::ABcd4a8b8a2b: return cldnn::format::os_is_yx_osa4_isa8_osv8_isv2;
case dnnl::memory::format_tag::ABcde4a8b8a2b: return cldnn::format::os_is_zyx_osa4_isa8_osv8_isv2;
case dnnl::memory::format_tag::ABcde4a8b8a4b: return cldnn::format::os_is_zyx_osa4_isa8_osv8_isv4;
case dnnl::memory::format_tag::ABcd8a4b: return cldnn::format::os_is_yx_osv8_isv4;
case dnnl::memory::format_tag::ABcde8a4b: return cldnn::format::os_is_zyx_osv8_isv4;
case dnnl::memory::format_tag::ABcde8a2b: return cldnn::format::os_is_zyx_osv8_isv2;
case dnnl::memory::format_tag::ABcd8a2b: return cldnn::format::os_is_yx_osv8_isv2;
case dnnl::memory::format_tag::Acdb16a: return cldnn::format::os_yxi_osv16;
case dnnl::memory::format_tag::Acdeb16a: return cldnn::format::os_zyxi_osv16;
case dnnl::memory::format_tag::ABcde16b16a: return cldnn::format::os_is_zyx_isv16_osv16;
case dnnl::memory::format_tag::aBcd16b: return cldnn::format::o_is_yx_isv16;
case dnnl::memory::format_tag::Abcd16a: return cldnn::format::os_iyx_osv16;
case dnnl::memory::format_tag::ABcd2a8b8a2b: return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
case dnnl::memory::format_tag::ABcd2a8b16a4b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv4;
case dnnl::memory::format_tag::ABcd2a8b16a2b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
case dnnl::memory::format_tag::BAcd4b8a8b4a: return cldnn::format::is_os_yx_isa4_osa8_isv8_osv4;
default: throw std::runtime_error(std::string("Unsupported onednn fmt ") + dnnl_fmt_tag2str((dnnl_format_tag_t)fmt));
if (is_match)
return static_cast<format::type>(fmt_idx);
}
}
}
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped) {
auto onednn_desc = get_format_by_desc(desc);
if (onednn_desc != dnnl::memory::format_tag::undef) {
return convert_format(onednn_desc, is_grouped);
} else {
auto blk = desc.data.format_desc.blocking;
auto order = get_order(desc);
if (is_grouped) {
if (desc.data.ndims == 5 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
&& compare_strides(order, {0, 1, 2, 3, 4})) {
return cldnn::format::g_os_is_yx_isa8_osv8_isv2;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
&& compare_strides(order, {0, 1, 2, 3, 4})) {
return cldnn::format::g_os_is_yx_isa8_osv8_isv4;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
if (compare_strides(order, {0, 1, 3, 4, 2})) return cldnn::format::g_os_yx_is_osv8_isv2;
else if (compare_strides(order, {0, 1, 3, 2, 4})) return cldnn::format::g_os_y_is_x_osv8_isv2;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
if (compare_strides(order, {0, 1, 3, 4, 2})) return cldnn::format::g_os_yx_is_osv8_isv4;
else if (compare_strides(order, {0, 1, 3, 2, 4})) return cldnn::format::g_os_y_is_x_osv8_isv4;
} else if (desc.data.ndims == 6 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
if (compare_strides(order, {0, 1, 3, 4, 5, 2})) return cldnn::format::g_os_zyx_is_osv8_isv2;
else if (compare_strides(order, {0, 1, 3, 4, 2, 5})) return cldnn::format::g_os_zy_is_x_osv8_isv2;
} else if (desc.data.ndims == 6 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
if (compare_strides(order, {0, 1, 3, 4, 5, 2})) return cldnn::format::g_os_zyx_is_osv8_isv4;
else if (compare_strides(order, {0, 1, 3, 4, 2, 5})) return cldnn::format::g_os_zy_is_x_osv8_isv4;
} else if (desc.data.ndims == 6 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
&& compare_strides(order, {0, 1, 2, 3, 4, 5})) {
return cldnn::format::g_os_is_zyx_isa8_osv8_isv2;
} else if (desc.data.ndims == 6 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
&& compare_strides(order, {0, 1, 2, 3, 4, 5})) {
return cldnn::format::g_os_is_zyx_isa8_osv8_isv4;
}
} else {
if (desc.data.ndims == 4 && blk.inner_nblks == 4
&& blk.inner_blks[0] == 4 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 4
&& blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 0 && blk.inner_idxs[3] == 1
&& compare_strides(order, {1, 0, 2, 3})) {
return cldnn::format::is_os_yx_osa4_isa8_osv8_isv4;
} else if (desc.data.ndims == 4 && blk.inner_nblks == 4
&& blk.inner_blks[0] == 2 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 2
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1 && blk.inner_idxs[3] == 0
&& compare_strides(order, {1, 0, 2, 3})) {
return cldnn::format::is_os_yx_isa2_osa8_isv8_osv2;
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 16 && blk.inner_blks[1] == 4 && blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1
&& compare_strides(order, {0, 1, 2, 3})) {
return cldnn::format::os_is_yx_osv16_isv4;
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 16 && blk.inner_blks[1] == 8 && blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0
&& compare_strides(order, {0, 1, 2, 3})) {
return cldnn::format::is_os_yx_isv16_osv8;
} else if (desc.data.ndims == 4 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
if (compare_strides(order, {0, 1, 2, 3})) return cldnn::format::os_is_yx_isa8_osv8_isv2;
else if (compare_strides(order, {1, 0, 2, 3})) return cldnn::format::is_os_yx_isa8_osv8_isv2;
} else if (desc.data.ndims == 4 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
if (compare_strides(order, {0, 1, 2, 3})) return cldnn::format::os_is_yx_isa8_osv8_isv4;
else if (compare_strides(order, {1, 0, 2, 3})) return cldnn::format::is_os_yx_isa8_osv8_isv4;
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2
&& blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
if (compare_strides(order, {0, 2, 1, 3})) return cldnn::format::os_y_is_x_osv8_isv2;
else if (compare_strides(order, {0, 2, 3, 1})) return cldnn::format::os_yx_is_osv8_isv2;
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4
&& blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
if (compare_strides(order, {0, 2, 1, 3})) return cldnn::format::os_y_is_x_osv8_isv4;
else if (compare_strides(order, {0, 2, 3, 1})) return cldnn::format::os_yx_is_osv8_isv4;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2 &&
blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2 &&
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
if (compare_strides(order, {0, 2, 3, 4, 1})) return cldnn::format::os_zyx_is_osv8_isv2;
else if (compare_strides(order, {0, 2, 3, 1, 4})) return cldnn::format::os_zy_is_x_osv8_isv2;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2 &&
blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4 &&
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
if (compare_strides(order, {0, 2, 3, 4, 1})) return cldnn::format::os_zyx_is_osv8_isv4;
else if (compare_strides(order, {0, 2, 3, 1, 4})) return cldnn::format::os_zy_is_x_osv8_isv4;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
if (compare_strides(order, {0, 1, 2, 3, 4})) return cldnn::format::os_is_zyx_isa8_osv8_isv2;
else if (compare_strides(order, {1, 0, 2, 3, 4})) return cldnn::format::is_os_zyx_isa8_osv8_isv2;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 3
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
if (compare_strides(order, {0, 1, 2, 3, 4})) return cldnn::format::os_is_zyx_isa8_osv8_isv4;
else if (compare_strides(order, {1, 0, 2, 3, 4})) return cldnn::format::is_os_zyx_isa8_osv8_isv4;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 4 &&
blk.inner_blks[0] == 2 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 2 &&
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 0 && blk.inner_idxs[3] == 1) {
return cldnn::format::os_is_zyx_osa2_isa8_osv8_isv2;
} else if (desc.data.ndims == 5 && blk.inner_nblks == 4 &&
blk.inner_blks[0] == 4 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 4 &&
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 0 && blk.inner_idxs[3] == 1) {
return cldnn::format::os_is_zyx_osa4_isa8_osv8_isv4;
}
}
std::stringstream msg;
msg << "Unsupported " << (is_grouped ? "grouped" : "") << "onednn dnnl::memory::desc find_format. "
<< "ndims: " << desc.data.ndims
<< ", inner_nblks: " << blk.inner_nblks
<< ", inner_blks: ";
for (int i = 0; i < blk.inner_nblks; i++)
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
msg << ", strides_order: ";
std::stringstream msg;
msg << "Unsupported " << (is_grouped ? "grouped" : "") << "onednn dnnl::memory::desc find_format. "
<< "ndims: " << desc.data.ndims
<< ", inner_nblks: " << blk.inner_nblks
<< ", inner_blks: ";
for (int i = 0; i < blk.inner_nblks; i++)
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
for (auto order : orders) {
msg << ", strides_order : ";
for (const auto& value : order)
msg << value << " ";
throw std::runtime_error(msg.str());
}
msg << ", stride_value : ";
auto strides = desc.data.format_desc.blocking.strides;
for (size_t idx = 0; idx < orders[0].size() ; idx++) {
msg << strides[idx] << " ";
}
throw std::runtime_error(msg.str());
}
// Currently, usage of alpha and beta between cldnn::pow and dnnl::eltwise::pow is different : d = pow(src, a) / d = a * pow(src, b)
@ -662,7 +510,5 @@ bool is_per_tensor(cldnn::data_node& node, int32_t& zp_val) {
template bool is_per_tensor<int8_t>(cldnn::data_node& node, int32_t& zp_val);
template bool is_per_tensor<uint8_t>(cldnn::data_node& node, int32_t& zp_val);
} // namespace onednn
} // namespace cldnn

View File

@ -5,6 +5,7 @@
#pragma once
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_debug.h>
#include <intel_gpu/runtime/layout.hpp>
#include <intel_gpu/runtime/engine.hpp>
@ -32,6 +33,7 @@ cldnn::format convert_data_format(dnnl::memory::format_tag fmt);
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims);
dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::undef, bool flatten = false);
dnnl::algorithm convert_activation_func(cldnn::activation_func func);
std::vector<std::vector<size_t>> get_candidate_orders(dnnl::memory::desc desc);
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped = false);
cldnn::format find_data_format(dnnl::memory::desc desc);
dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc);

View File

@ -439,7 +439,8 @@ bool program_node::is_padding_supported(int axis, int padding) const {
if (fmt == format::fs_b_yx_fsv32 && (axis == 0))
return false;
for (const auto& block : fmt.block_sizes()) {
auto block_sizes_dims = format::per_axis_block_size(fmt);
for (const auto& block : block_sizes_dims) {
size_t block_axis = block.first;
int block_size = block.second;

View File

@ -87,39 +87,42 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(image_2d_weights_c4_fyx_b, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
FMT_TRAITS(image_2d_weights_c1_b_fyx, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
FMT_TRAITS(lstm_weights_dio, 1, 1, 2, 0, {0, 1, 3, 2}, "oixy", "oixy?", {}),
FMT_TRAITS(os_is_yx_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(os_is_yx_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 16}, {1, 4}}),
FMT_TRAITS(os_is_yx_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 32}, {1, 16}}),
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}),
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 16}}),
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 32}}),
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 16}}),
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}),
FMT_TRAITS(os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}, {1, 16}}),
FMT_TRAITS(os_is_zyx_osa2_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 16}, {1, 16}}),
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 2}, {1, 8}, {0, 16}, {1, 2}}),
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 2}, {1, 8}, {0, 16}, {1, 4}}),
FMT_TRAITS(os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 2}, {1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(os_is_zyx_osa2_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 2}, {1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(os_is_zyx_isa8_osv16_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 4}}),
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 32}, {1, 32}}),
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4_swizzled_by_4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 32}}),
FMT_TRAITS(is_os_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{0, 32}, {1, 32}}),
FMT_TRAITS(is_os_yx_isa2_osa8_isv8_osv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_isa4_osa8_isv8_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 32}, {0, 32}}),
FMT_TRAITS(is_os_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(is_os_yx_isa2_osa8_isv8_osv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 2}, {0, 8}, {1, 8}, {0, 2}}),
FMT_TRAITS(is_os_yx_isa4_osa8_isv8_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 4}, {0, 8}, {1, 8}, {0, 4}}),
FMT_TRAITS(is_o_yx_isv32, 1, 1, 2, 0, {1, 0, 2, 3}, "oyxi", "oixy?", {{1, 32}}),
FMT_TRAITS(is_o32_yx_isv32_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}),
FMT_TRAITS(os_is_y_x8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}),
FMT_TRAITS(os_is_y_x8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}),
FMT_TRAITS(os_is_yx_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 16}, {1, 4}}),
FMT_TRAITS(os_is_yx_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 4}, {0, 8}}),
FMT_TRAITS(os_is_zyx_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 4}, {0, 8}}),
FMT_TRAITS(os_is_yx_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 2}, {0, 8}}),
FMT_TRAITS(os_is_zyx_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 2}, {0, 8}}),
FMT_TRAITS(os_is_yx_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 8}, {1, 4}}),
FMT_TRAITS(os_is_zyx_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 8}, {1, 4}}),
FMT_TRAITS(os_is_yx_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 8}, {1, 2}}),
FMT_TRAITS(os_is_zyx_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 8}, {1, 2}}),
FMT_TRAITS(os_is_zyx_osv16_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 16}, {1, 16}}),
FMT_TRAITS(os_is_yx_osv32_isv4_swizzled_by_2, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 32}, {1, 4}}),
FMT_TRAITS(os_is_yx_osv32_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 32}, {1, 4}}),
FMT_TRAITS(os_is_zyx_osv32_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 4}}),
FMT_TRAITS(os_is_yx_osv32_isv32p, 1, 1, 1, 0, {0, 1, 2, 3}, "oixy", "oixy?", {}),
FMT_TRAITS(os_is_zyx_isv16_osv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 16}, {1, 16}}),
FMT_TRAITS(os_is_zyx_isv16_osv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_zyx_isv16_osv16, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3, 4}, "ioyx", "oixy", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_isv16_osv8, 1, 1, 2, 0, {1, 0, 2, 3, 4}, "ioyx", "oixy", {{1, 16}, {0, 8}}),
FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_isv16_osv8, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 8}}),
FMT_TRAITS(is_os_zyx_isa8_osv8_isv2, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "ioxyz", {{1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(is_os_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "ioxyz", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}),
@ -133,7 +136,7 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(os_is_zyx_isv8_osv16_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 2}}),
FMT_TRAITS(os_zyxi_osv16, 1, 1, 3, 0, {0, 2, 3, 4, 1}, "ozyxi", "oixyz", {{0, 16}}),
FMT_TRAITS(os_is_yx_isv8_osv16_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 16}, {1, 2}}),
FMT_TRAITS(os_is_yx_osv16_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 16}, {0, 16}}),
FMT_TRAITS(os_is_yx_osv16_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}, {1, 16}}),
FMT_TRAITS(os_is_zyx_osv32_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 16}}),
FMT_TRAITS(os_is_zyx_osv64_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 64}, {1, 16}}),
FMT_TRAITS(os_iyx_osv32__ai32, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}}),
@ -185,15 +188,15 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(g_os_zyx_is_osv32_isv32, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "gozyxi", "oixyz?g", {{0, 32}, {1, 32}}),
FMT_TRAITS(g_os_is_yx_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(g_os_is_yx_isa8_osv8_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 16}, {1, 16}}),
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 32}}),
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 32}, {1, 32}}),
FMT_TRAITS(g_os_is_zyx_isa8_osv8_isv2, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(g_os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 16}}),
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 32}, {1, 16}}),
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 32}}),
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 16}}),
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 2}, {1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 2}, {1, 8}, {0, 16}, {1, 4}}),
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 2}, {1, 8}, {0, 16}, {1, 2}}),
FMT_TRAITS(gs_oi_yxs_gsv4_yxsv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{6, 4}}),
FMT_TRAITS(gs_oi_yxs_gsv16_yxsv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{6, 16}}),
FMT_TRAITS(gs_oi_yxs_gsv32_yxsv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{6, 32}}),
@ -307,4 +310,18 @@ format format::adjust_to_rank(format fmt, size_t new_rank) {
OPENVINO_ASSERT(false, "Can't adjust format ", fmt.to_string(), " to the new rank (", new_rank, ")");
}
/// @brief Accumulates the inner block sizes of @p fmt per block index.
/// @details Walks fmt.block_sizes() in order. The first block seen for an index creates an
/// entry {block_idx, block_size}; each later block with the same index multiplies the size
/// already stored, so an index that is blocked several times ends up with the product of
/// its block sizes. Entries stay in the order in which their index first appears.
/// @return Pairs where first is the block index and second is the accumulated block size.
const std::vector<std::pair<size_t, int>> format::per_axis_block_size(format fmt) {
    std::vector<std::pair<size_t, int>> accumulated;
    for (const auto& inner : fmt.block_sizes()) {
        bool merged = false;
        for (auto& entry : accumulated) {
            if (entry.first == inner.first) {
                entry.second *= inner.second;  // this index is blocked more than once
                merged = true;
                break;
            }
        }
        if (!merged)
            accumulated.push_back({inner.first, inner.second});
    }
    return accumulated;
}
} // namespace cldnn

View File

@ -22,6 +22,17 @@ file(GLOB_RECURSE SOURCES_MAIN
"${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/transformations/*.cpp"
)
# When ENABLE_ONEDNN_FOR_GPU is off, remove every entry of SOURCES_MAIN whose path
# contains "/onednn/" and print each removed file.
if (NOT ENABLE_ONEDNN_FOR_GPU)
set(EXCLUDE_DIR "/onednn/")
foreach (SOURCE_FILE ${SOURCES_MAIN})
# string(FIND) stores -1 in EXCLUDE_DIR_FOUND when the substring is not present
string (FIND ${SOURCE_FILE} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND)
if (NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
message (Exclude : ${SOURCE_FILE})
list (REMOVE_ITEM SOURCES_MAIN ${SOURCE_FILE})
endif ()
endforeach(SOURCE_FILE)
endif()
if (MSVC)
file(GLOB SOURCES_NATVIS
"${CMAKE_CURRENT_SOURCE_DIR}/float16.natvis"
@ -60,6 +71,10 @@ target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
$<TARGET_PROPERTY:openvino_intel_gpu_runtime,INTERFACE_INCLUDE_DIRECTORIES>
${CMAKE_HOME_DIRECTORY}/src/core/reference/include/)
# Only when ENABLE_ONEDNN_FOR_GPU is on: add the include directories taken from the
# dnnl target's INCLUDE_DIRECTORIES property as private include paths of this target.
if(ENABLE_ONEDNN_FOR_GPU)
target_include_directories(${TARGET_NAME} PRIVATE $<TARGET_PROPERTY:dnnl,INCLUDE_DIRECTORIES>)
endif()
if(WIN32)
target_link_libraries(${TARGET_NAME} PRIVATE setupapi)
elseif((NOT ANDROID) AND (UNIX))

View File

@ -68,3 +68,79 @@ INSTANTIATE_TEST_SUITE_P(smoke, format_adjust_test,
{cldnn::format::oiyx, 5, cldnn::format::any},
}),
format_adjust_test::PrintToString);
// Parameters for axes_test_format: a format together with the block descriptions
// the test expects to read back for it.
struct axes_test_format_params {
// Format under test
cldnn::format in_format;
// First : block_idx, Second : block_size
// Expected result of format::block_sizes(in_format)
std::vector<std::pair<size_t, int>> inner_block;
// Expected result of format::per_axis_block_size(in_format)
std::vector<std::pair<size_t, int>> per_axis_block;
};
// Fixture for the block-size tests, parameterized on axes_test_format_params.
class axes_test_format : public testing::TestWithParam<axes_test_format_params> {
public:
    // Builds the test name: the input format, then the expected inner blocks,
    // then " > " followed by the expected per-axis blocks.
    static std::string PrintToString(testing::TestParamInfo<axes_test_format_params> param_info) {
        auto& params = param_info.param;
        std::string res = "in_fmt = " + params.in_format.to_string() + " : ";

        // Appends every {block_idx, block_size} pair of one list to the name.
        const auto append_blocks = [&res](const std::vector<std::pair<size_t, int>>& blocks) {
            for (const auto& block : blocks) {
                res += " { ";
                res += std::to_string(block.first);
                res += ", ";
                res += std::to_string(block.second);
                res += "}";
            }
        };

        append_blocks(params.inner_block);
        res += " > ";
        append_blocks(params.per_axis_block);
        return res;
    }
};
// Compares both block views of the format with the expected values from the parameters:
// first the per-axis accumulated sizes, then the raw inner block list.
TEST_P(axes_test_format, simple_test) {
    auto param = GetParam();

    const auto accumulated = format::per_axis_block_size(param.in_format);
    const auto& expected_accumulated = param.per_axis_block;
    ASSERT_EQ(accumulated.size(), expected_accumulated.size());
    for (size_t i = 0; i < accumulated.size(); ++i) {
        ASSERT_EQ(accumulated[i].first, expected_accumulated[i].first);
        ASSERT_EQ(accumulated[i].second, expected_accumulated[i].second);
    }

    const auto inner = format::block_sizes(param.in_format);
    const auto& expected_inner = param.inner_block;
    ASSERT_EQ(inner.size(), expected_inner.size());
    for (size_t i = 0; i < inner.size(); ++i) {
        ASSERT_EQ(inner[i].first, expected_inner[i].first);
        ASSERT_EQ(inner[i].second, expected_inner[i].second);
    }
}
// Cases for axes_test_format. Each row is
//   { format, expected format::block_sizes(), expected format::per_axis_block_size() }
// with every block written as {block_idx, block_size}. In the last column the sizes of blocks
// that share an index are multiplied together, e.g. {1, 8} and {1, 4} become {1, 32}.
INSTANTIATE_TEST_SUITE_P(smoke, axes_test_format,
testing::ValuesIn(std::vector<axes_test_format_params>{
{format::os_is_yx_isa8_osv8_isv4, {{1, 8}, {0, 8}, {1, 4}}, {{1, 32}, {0, 8}}},
{format::os_is_yx_isa8_osv16_isv4, {{1, 8}, {0, 16}, {1, 4}}, {{1, 32}, {0, 16}}},
{format::os_is_yx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
{format::os_is_yx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
{format::os_is_zyx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
{format::os_is_zyx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
{format::os_is_yx_osa2_isa8_osv16_isv2, {{0, 2}, {1, 8}, {0, 16}, {1, 2}}, {{0, 32}, {1, 16}}},
{format::os_is_yx_osa2_isa8_osv16_isv4, {{0, 2}, {1, 8}, {0, 16}, {1, 4}}, {{0, 32}, {1, 32}}},
{format::os_is_yx_osa2_isa8_osv8_isv2, {{0, 2}, {1, 8}, {0, 8}, {1, 2}}, {{0, 16}, {1, 16}}},
{format::os_is_zyx_osa2_isa8_osv8_isv2, {{0, 2}, {1, 8}, {0, 8}, {1, 2}}, {{0, 16}, {1, 16}}},
{format::os_is_zyx_isa8_osv8_isv4, {{1, 8}, {0, 8}, {1, 4}}, {{1, 32}, {0, 8}}},
{format::os_is_zyx_isa8_osv16_isv4, {{1, 8}, {0, 16}, {1, 4}}, {{1, 32}, {0, 16}}},
{format::is_os_yx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
{format::is_os_yx_isa2_osa8_isv8_osv2, {{1, 2}, {0, 8}, {1, 8}, {0, 2}}, {{1, 16}, {0, 16}}},
{format::is_os_yx_isa4_osa8_isv8_osv4, {{1, 4}, {0, 8}, {1, 8}, {0, 4}}, {{1, 32}, {0, 32}}},
{format::os_is_yx_osv8_isv4, {{0, 8}, {1, 4}}, {{0, 8}, {1, 4}}},
{format::os_is_zyx_osv8_isv4, {{0, 8}, {1, 4}}, {{0, 8}, {1, 4}}},
{format::os_is_yx_osv8_isv2, {{0, 8}, {1, 2}}, {{0, 8}, {1, 2}}},
{format::os_is_zyx_osv8_isv2, {{0, 8}, {1, 2}}, {{0, 8}, {1, 2}}},
{format::g_os_is_yx_osa2_isa8_osv8_isv2, {{0, 2}, {1, 8}, {0, 8}, {1, 2}}, {{0, 16}, {1, 16}}},
{format::g_os_is_yx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
{format::g_os_is_zyx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
{format::g_os_is_yx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
{format::g_os_is_zyx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
{format::g_os_is_yx_osa2_isa8_osv16_isv4, {{0, 2}, {1, 8}, {0, 16}, {1, 4}}, {{0, 32}, {1, 32}}},
{format::g_os_is_yx_osa2_isa8_osv16_isv2, {{0, 2}, {1, 8}, {0, 16}, {1, 2}}, {{0, 32}, {1, 16}}},
}),
axes_test_format::PrintToString);

View File

@ -0,0 +1,209 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <oneapi/dnnl/dnnl.hpp>
#include "test_utils.h"
#include "intel_gpu/runtime/format.hpp"
#include "graph/impls/onednn/utils.hpp"
using namespace cldnn;
// The three arguments the tests pass to the dnnl::memory::desc constructor
// that takes dims, a data type and explicit strides.
struct dnnl_desc_params {
// Descriptor info to test
dnnl::memory::dims dims;
dnnl::memory::data_type data_type;
dnnl::memory::dims strides; // In case of plain (non-blocked) formats the strides between dimensions
};
// Parameters for format_test_stride: a descriptor definition and the set of
// orders that onednn::get_candidate_orders is expected to return for it.
struct desc_stride_test_params {
dnnl_desc_params test_desc;
// Expected candidates; the test does not depend on the order in which they are listed
std::vector<std::vector<size_t>> expected_orders;
};
// Fixture for the candidate-order tests, parameterized on desc_stride_test_params.
class format_test_stride : public testing::TestWithParam<desc_stride_test_params> {
public:
    // Builds the test name from the strides of the descriptor under test.
    static std::string PrintToString(testing::TestParamInfo<desc_stride_test_params> param_info) {
        std::string name = " stride : {";
        for (const auto& value : param_info.param.test_desc.strides) {
            name += std::to_string(value);
            name += " ";
        }
        name += "}";
        return name;
    }
};
// Builds a strides-based memory descriptor and checks that onednn::get_candidate_orders
// returns as many orders as expected and that every expected order is among them.
// The position of an order inside the returned list is not checked.
TEST_P(format_test_stride, test_candidates_using_stride) {
    auto param = GetParam();

    dnnl::memory::desc test_desc(param.test_desc.dims, param.test_desc.data_type, param.test_desc.strides);

    const auto candidates = onednn::get_candidate_orders(test_desc);
    ASSERT_EQ(candidates.size(), param.expected_orders.size());

    for (size_t expected_idx = 0; expected_idx < param.expected_orders.size(); expected_idx++) {
        const auto& expected = param.expected_orders.at(expected_idx);

        bool found_match = false;
        for (const auto& candidate : candidates) {
            // Compare lengths first: std::equal with a single end iterator would read past
            // the end of `expected` if a candidate were longer.
            if (candidate.size() == expected.size() &&
                std::equal(candidate.begin(), candidate.end(), expected.begin())) {
                found_match = true;
                break;
            }
        }

        ASSERT_TRUE(found_match) << "expected order #" << expected_idx << " is missing from the candidates";
    }
}
// Cases for format_test_stride. Each row is
//   { { dims, data_type, strides }, { expected orders... } }
// Rows in which several dimensions share the same stride list more than one expected order.
INSTANTIATE_TEST_SUITE_P(smoke, format_test_stride,
testing::ValuesIn(std::vector<desc_stride_test_params>{
{{{1, 3, 8, 8}, dnnl::memory::data_type::f16, {768, 256, 16, 1}}, {{0, 1, 2, 3}}},
{{{6, 1, 1, 8}, dnnl::memory::data_type::f16, {96, 16, 16, 1}}, {{0, 1, 2, 3}, {0, 2, 1, 3}}},
{{{1, 1, 1, 16}, dnnl::memory::data_type::f16, {32, 32, 32, 1}}, {{0, 1, 2, 3}, {0, 2, 1, 3}, {1, 0, 2, 3}, {2, 1, 0, 3}, {1, 2, 0, 3}, {2, 0, 1, 3}}}
}),
format_test_stride::PrintToString);
// Parameters shared by the format matching tests: the descriptor is built from
// dims, data_type and dnnl_format, and the lookup result is compared with cldnn_format.
struct format_matching_test_params {
dnnl::memory::dims dims;
dnnl::memory::data_type data_type;
// oneDNN format tag used to construct the descriptor under test
dnnl::memory::format_tag dnnl_format;
// cldnn format the result is compared against (must match or must differ, depending on the test)
cldnn::format cldnn_format;
};
// Fixture for the positive data-format lookups, parameterized on format_matching_test_params.
class data_format_test_match_dnnl : public testing::TestWithParam<format_matching_test_params> {
public:
    // Builds the test name: the oneDNN format tag string, then " > " and the cldnn format string.
    static std::string PrintToString(testing::TestParamInfo<format_matching_test_params> param_info) {
        const auto tag = static_cast<dnnl_format_tag_t>(param_info.param.dnnl_format);
        std::string name = " data format (dnnl::memory::format_tag) : ";
        name += dnnl_fmt_tag2str(tag);
        name += " > ";
        name += format::traits(param_info.param.cldnn_format).str;
        return name;
    }
};
// Builds a descriptor from the format tag and expects onednn::find_data_format
// to return the cldnn format given in the parameters.
TEST_P(data_format_test_match_dnnl, test_match_data_format) {
    auto param = GetParam();

    dnnl::memory::desc test_desc(param.dims, param.data_type, param.dnnl_format);

    ASSERT_TRUE(onednn::find_data_format(test_desc) == param.cldnn_format);
}
// Positive cases for onednn::find_data_format. Each row is
//   { dims, data_type, oneDNN format tag, cldnn format that must be returned }
INSTANTIATE_TEST_SUITE_P(smoke, data_format_test_match_dnnl,
testing::ValuesIn(std::vector<format_matching_test_params>{
{{1, 3, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::nchw, cldnn::format::bfyx},
{{1, 3, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcd4b, cldnn::format::b_fs_yx_fsv4},
{{1, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::nChw16c, cldnn::format::b_fs_yx_fsv16},
{{1, 12, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::aBcd32b, cldnn::format::b_fs_yx_fsv32},
{{1, 3, 16, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcde4b, cldnn::format::b_fs_zyx_fsv4},
{{12, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::NChw16n16c, cldnn::format::bs_fs_yx_bsv16_fsv16},
{{32, 32, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::NChw32n32c, cldnn::format::bs_fs_yx_bsv32_fsv32},
}),
data_format_test_match_dnnl::PrintToString);
// Fixture for the negative data-format lookups, parameterized on format_matching_test_params.
class data_format_test_not_match_dnnl : public testing::TestWithParam<format_matching_test_params> {
public:
    // Builds the test name from the oneDNN format tag string of the case.
    static std::string PrintToString(testing::TestParamInfo<format_matching_test_params> param_info) {
        const auto tag = static_cast<dnnl_format_tag_t>(param_info.param.dnnl_format);
        std::string name = " Failed case for ";
        name += dnnl_fmt_tag2str(tag);
        return name;
    }
};
// Builds a descriptor from the format tag and expects onednn::find_data_format
// to return something other than the cldnn format given in the parameters.
TEST_P(data_format_test_not_match_dnnl, test_not_match_data_format) {
    auto param = GetParam();

    dnnl::memory::desc test_desc(param.dims, param.data_type, param.dnnl_format);

    ASSERT_FALSE(onednn::find_data_format(test_desc) == param.cldnn_format);
}
// Negative cases for onednn::find_data_format. Each row is
//   { dims, data_type, oneDNN format tag, cldnn format that must NOT be returned }
INSTANTIATE_TEST_SUITE_P(smoke, data_format_test_not_match_dnnl,
testing::ValuesIn(std::vector<format_matching_test_params>{
{{1, 3, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::nchw, cldnn::format::byxf},
{{1, 3, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcd4b, cldnn::format::b_fs_yx_fsv2},
{{1, 12, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::nChw16c, cldnn::format::b_fs_yx_fsv32},
{{1, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcd32b, cldnn::format::b_fs_yx_fsv16},
{{1, 3, 16, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcde4b, cldnn::format::b_fs_zyx_fsv16},
{{32, 32, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::NChw16n16c, cldnn::format::bs_fs_yx_bsv32_fsv32},
{{12, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::NChw32n32c, cldnn::format::bs_fs_yx_bsv16_fsv16},
}),
data_format_test_not_match_dnnl::PrintToString);
// Fixture for the weight-format lookups, parameterized on format_matching_test_params.
class weight_format_test_match_dnnl : public testing::TestWithParam<format_matching_test_params> {
public:
    // Builds the test name: the oneDNN format tag string, then " > " and the cldnn format string.
    static std::string PrintToString(testing::TestParamInfo<format_matching_test_params> param_info) {
        const auto tag = static_cast<dnnl_format_tag_t>(param_info.param.dnnl_format);
        std::string name = " weight format (dnnl::memory::format_tag) : ";
        name += dnnl_fmt_tag2str(tag);
        name += " > ";
        name += format::traits(param_info.param.cldnn_format).str;
        return name;
    }
};
// Builds a descriptor from the format tag and expects onednn::find_format (called with
// `false` as its second argument) to return the cldnn format given in the parameters.
TEST_P(weight_format_test_match_dnnl, test_match_data_format) {
    auto param = GetParam();

    dnnl::memory::desc test_desc(param.dims, param.data_type, param.dnnl_format);

    ASSERT_TRUE(onednn::find_format(test_desc, false) == param.cldnn_format);
}
// Cases for onednn::find_format on weight layouts. Each row is
//   { dims, data_type, oneDNN format tag, cldnn format that must be returned }
INSTANTIATE_TEST_SUITE_P(smoke, weight_format_test_match_dnnl,
testing::ValuesIn(std::vector<format_matching_test_params>{
{{1, 3, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::abcd, cldnn::format::oiyx},
{{16, 16, 8, 8}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::ABcd16b16a, cldnn::format::os_is_yx_isv16_osv16},
{{8, 4, 16, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::ABcde8a4b, cldnn::format::os_is_zyx_osv8_isv4},
{{16, 16, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::ABcd2a8b8a2b, cldnn::format::os_is_yx_osa2_isa8_osv8_isv2},
{{32, 32, 8, 8}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::BAcd4b8a8b4a, cldnn::format::is_os_yx_isa4_osa8_isv8_osv4},
}),
weight_format_test_match_dnnl::PrintToString);
// Parameters for weight_format_test_with_stride: a strides-based descriptor, the inner
// blocking that the test writes into that descriptor, and the cldnn format expected
// from onednn::find_format.
struct stride_matching_test_params {
dnnl_desc_params test_desc;
int inner_nblks; // The number of innermost blocks
dnnl_dims_t inner_idxs; // The logical indices of the blocks, e.g. `{1, 0, 1}` in case of 4i16o4i
dnnl_dims_t inner_blks; // The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i`
cldnn::format cldnn_format;
};
// Fixture for the stride-based weight-format lookups, parameterized on stride_matching_test_params.
class weight_format_test_with_stride : public testing::TestWithParam<stride_matching_test_params> {
public:
    // Builds the test name: the strides of the descriptor, then the expected cldnn format string.
    static std::string PrintToString(testing::TestParamInfo<stride_matching_test_params> param_info) {
        std::string name = " stride : {";
        for (const auto& value : param_info.param.test_desc.strides) {
            name += std::to_string(value);
            name += " ";
        }
        name += "} > ";
        name += format::traits(param_info.param.cldnn_format).str;
        return name;
    }
};
// Builds a strides-based descriptor, overwrites its inner blocking (block count, logical
// indices and block sizes) with the values from the parameters, and expects
// onednn::find_format to return the cldnn format given in the parameters.
TEST_P(weight_format_test_with_stride, test_match_data_format) {
    auto param = GetParam();
    const auto& desc_params = param.test_desc;

    dnnl::memory::desc test_desc(desc_params.dims, desc_params.data_type, desc_params.strides);

    // Patch the blocking part of the raw descriptor in place.
    auto& blocking = test_desc.data.format_desc.blocking;
    blocking.inner_nblks = param.inner_nblks;
    for (int blk = 0; blk < param.inner_nblks; ++blk) {
        blocking.inner_idxs[blk] = param.inner_idxs[blk];
        blocking.inner_blks[blk] = param.inner_blks[blk];
    }

    ASSERT_TRUE(onednn::find_format(test_desc, false) == param.cldnn_format);
}
// Cases for weight_format_test_with_stride. Each row is
//   { { dims, data_type, strides }, inner_nblks, inner_idxs, inner_blks, expected cldnn format }
// The two rows differ only in their strides and in the expected format.
INSTANTIATE_TEST_SUITE_P(smoke, weight_format_test_with_stride,
testing::ValuesIn(std::vector<stride_matching_test_params>{
{{{16, 16, 1, 1}, dnnl::memory::data_type::f16, {16, 256, 1, 1}}, 2, {1, 0}, {16, 16}, cldnn::format::is_os_yx_isv16_osv16},
{{{16, 16, 1, 1}, dnnl::memory::data_type::f16, {256, 16, 1, 1}}, 2, {1, 0}, {16, 16}, cldnn::format::os_is_yx_isv16_osv16},
}),
weight_format_test_with_stride::PrintToString);