[GPU] Make generic logic to find formats from memory::desc (#13730)
* [GPU] Make generic logic to find formats from memory::desc + Added test-cases Signed-off-by: Min, Byungil <byungil.min@intel.com>
This commit is contained in:
parent
c1f6da31b6
commit
c4cd3e152b
@ -332,6 +332,8 @@ struct format {
|
||||
|
||||
static format adjust_to_rank(format fmt, size_t new_rank);
|
||||
|
||||
static const std::vector<std::pair<size_t, int>> per_axis_block_size(format fmt);
|
||||
|
||||
/// @brief Checks if @p format is of grouped type
|
||||
static bool is_grouped(type fmt) { return group_num(fmt) != 0; }
|
||||
/// @brief Checks if @p format is of image type
|
||||
|
@ -309,314 +309,162 @@ dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_t
|
||||
}
|
||||
}
|
||||
|
||||
static bool isSame(dnnl::memory::desc desc, dnnl::memory::format_tag fmt) {
|
||||
dnnl::memory::desc refDesc(desc.dims(), desc.data_type(), fmt);
|
||||
static void get_identical_order(std::vector<std::vector<size_t>>& orders, std::vector<size_t> order,
|
||||
size_t first, size_t depth) {
|
||||
if (depth == 0)
|
||||
return;
|
||||
|
||||
if (desc.data.ndims != refDesc.data.ndims)
|
||||
return false;
|
||||
for (size_t idx = first; idx <= first + depth ; idx++) {
|
||||
std::swap(order[first], order[idx]);
|
||||
if (first != idx)
|
||||
orders.push_back(order);
|
||||
|
||||
if (desc.data.format_kind != dnnl_blocked || refDesc.data.format_kind != dnnl_blocked)
|
||||
throw std::runtime_error("dnnlMemoryDesc::isSame is not implemented for non blocked memory format");
|
||||
|
||||
auto actualBlkDesc = desc.data.format_desc.blocking;
|
||||
auto refBlkDesc = refDesc.data.format_desc.blocking;
|
||||
if (actualBlkDesc.inner_nblks != refBlkDesc.inner_nblks)
|
||||
return false;
|
||||
|
||||
for (int i = 0; i < actualBlkDesc.inner_nblks; ++i)
|
||||
if (actualBlkDesc.inner_blks[i] != refBlkDesc.inner_blks[i])
|
||||
return false;
|
||||
|
||||
for (int i = 0; i < actualBlkDesc.inner_nblks; ++i)
|
||||
if (actualBlkDesc.inner_idxs[i] != refBlkDesc.inner_idxs[i])
|
||||
return false;
|
||||
|
||||
auto actualStrides = desc.data.format_desc.blocking.strides;
|
||||
auto refStrides = refDesc.data.format_desc.blocking.strides;
|
||||
|
||||
std::vector<size_t> actualOrder(desc.data.ndims);
|
||||
std::iota(actualOrder.begin(), actualOrder.end(), 0);
|
||||
std::sort(actualOrder.begin(), actualOrder.end(),
|
||||
[&actualStrides] (size_t ind_l, size_t ind_r) {
|
||||
return actualStrides[ind_l] > actualStrides[ind_r];
|
||||
});
|
||||
|
||||
std::vector<size_t> refOrder(refDesc.data.ndims);
|
||||
std::iota(refOrder.begin(), refOrder.end(), 0);
|
||||
std::sort(refOrder.begin(), refOrder.end(),
|
||||
[&refStrides] (size_t ind_l, size_t ind_r) {
|
||||
return refStrides[ind_l] > refStrides[ind_r];
|
||||
});
|
||||
|
||||
if (actualOrder != refOrder) {
|
||||
return false;
|
||||
get_identical_order(orders, order, first+1, depth-1);
|
||||
std::swap(order[first], order[idx]);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// @brief Looks up the dnnl format_tag that matches the given memory descriptor.
/// @param desc Memory descriptor to classify.
/// @return The first tag from form_tags_by_ndims (for the descriptor's rank) for which
///         isSame(desc, tag) holds, or format_tag::undef when the rank is 0 or greater
///         than 6, or when no candidate tag matches.
dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc) {
|
||||
// TODO [OneDNN]: This used to be a field of tdesc; it is now found by a
|
||||
// brute-force search over candidate tags. Please avoid using this method.
|
||||
const auto ndims = desc.dims().size();
|
||||
|
||||
// There is no suitable format_tag for these ranks
|
||||
if (ndims == 0 || ndims > 6)
|
||||
return dnnl::memory::format_tag::undef;
|
||||
|
||||
// Try every candidate tag registered for this rank and return the first match.
for (const auto fmt : form_tags_by_ndims.at(static_cast<int>(ndims))) {
|
||||
if (isSame(desc, fmt))
|
||||
return fmt;
|
||||
}
|
||||
|
||||
// No candidate tag matched the descriptor.
return dnnl::memory::format_tag::undef;
|
||||
}
|
||||
|
||||
static std::vector<size_t> get_order(dnnl::memory::desc desc) {
|
||||
auto blk = desc.data.format_desc.blocking;
|
||||
auto strides = blk.strides;
|
||||
// Builds the candidate axis orders of a dnnl::memory::desc from its stride values.
// More than one order can be returned, because axes with equal strides cannot be
// ordered unambiguously.
// The first entry is the list of axis indices sorted by descending stride; further
// entries are appended by get_identical_order() for each run of equal-stride axes.
|
||||
std::vector<std::vector<size_t>> get_candidate_orders(dnnl::memory::desc desc) {
|
||||
std::vector<std::vector<size_t>> orders;
|
||||
auto strides = desc.data.format_desc.blocking.strides;
|
||||
std::vector<size_t> order(desc.data.ndims);
|
||||
|
||||
// Start from 0..ndims-1 and sort the axis indices so that larger strides come first.
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(),
|
||||
[&strides] (size_t ind_l, size_t ind_r) {
|
||||
return (strides[ind_l] > strides[ind_r]);
|
||||
});
|
||||
// NOTE(review): the next statement looks like residue of the removed get_order() helper
// shown just above in this diff view; as written it would make the code below
// unreachable - confirm against the real file.
return order;
|
||||
|
||||
// The stride-sorted order is always the first candidate.
orders.push_back(order);
|
||||
|
||||
// The relative order of axes that share the same stride in memory::desc is interchangeable.
|
||||
// If the y and x axes have the same stride, the layout can be either bfyx or bfxy.
|
||||
for (size_t idx = 0 ; idx+1 < order.size() ; idx++) {
|
||||
// depth = number of axes directly following idx that have the same stride as idx.
size_t depth = 0;
|
||||
for (size_t next = idx+1 ; next < order.size() ; next++) {
|
||||
if (strides[order[idx]] == strides[order[next]]) {
|
||||
depth++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Multiple axes can have the same stride value in the memory descriptor;
// append the alternative orders for this run of equal-stride axes.
|
||||
get_identical_order(orders, order, idx, depth);
|
||||
// Skip past the axes of the run that was just handled.
idx += depth;
|
||||
}
|
||||
|
||||
return orders;
|
||||
}
|
||||
|
||||
static bool compare_strides(std::vector<size_t> a, std::vector<size_t> b) {
|
||||
return std::equal(a.begin(), a.end(), b.begin());
|
||||
/// @brief Checks whether any candidate axis order in @p a is exactly equal to @p b.
/// @param a Candidate axis orders to search through.
/// @param b Reference axis order to look for.
/// @return true if at least one candidate has the same length and the same elements as @p b,
///         false otherwise (including when @p a is empty).
static bool compare_orders(const std::vector<std::vector<size_t>>& a, const std::vector<size_t>& b) {
    // Whole-vector equality compares lengths first, so a candidate that is longer than b
    // is rejected instead of reading past the end of b, and a shorter candidate is no
    // longer accepted just because it is a prefix of b.
    return std::any_of(a.begin(), a.end(),
                       [&b](const std::vector<size_t>& candidate) { return candidate == b; });
}
|
||||
|
||||
cldnn::format find_data_format(dnnl::memory::desc desc) {
|
||||
auto onednn_desc = get_format_by_desc(desc);
|
||||
auto blk = desc.data.format_desc.blocking;
|
||||
auto order = get_candidate_orders(desc);
|
||||
|
||||
if (onednn_desc != dnnl::memory::format_tag::undef) {
|
||||
return convert_data_format(onednn_desc);
|
||||
} else {
|
||||
auto blk = desc.data.format_desc.blocking;
|
||||
auto order = get_order(desc);
|
||||
for (int32_t fmt_idx = format::bfyx ; fmt_idx < format::format_num ; fmt_idx++) {
|
||||
auto candidate_trait = format::traits(static_cast<format::type>(fmt_idx));
|
||||
if (desc.data.ndims == static_cast<int>(candidate_trait._order.size())
|
||||
&& blk.inner_nblks == static_cast<int>(candidate_trait.block_sizes.size())
|
||||
&& compare_strides(order, candidate_trait._order)) {
|
||||
bool is_match = true;
|
||||
for (size_t idx = 0 ; idx < candidate_trait.block_sizes.size() ; idx++) {
|
||||
if (blk.inner_blks[idx] != static_cast<int>(candidate_trait.block_sizes[idx].second)
|
||||
|| blk.inner_idxs[idx] != static_cast<int>(candidate_trait.block_sizes[idx].first)) {
|
||||
for (int32_t fmt_idx = format::bfyx ; fmt_idx < format::oiyx ; fmt_idx++) {
|
||||
auto candidate_trait = format::traits(static_cast<format::type>(fmt_idx));
|
||||
if (desc.data.ndims == static_cast<int>(candidate_trait._order.size())
|
||||
&& blk.inner_nblks == static_cast<int>(candidate_trait.block_sizes.size())
|
||||
&& compare_orders(order, candidate_trait._order)) {
|
||||
bool is_match = true;
|
||||
for (size_t idx = 0 ; idx < candidate_trait.block_sizes.size() ; idx++) {
|
||||
if (blk.inner_blks[idx] != static_cast<int>(candidate_trait.block_sizes[idx].second)
|
||||
|| blk.inner_idxs[idx] != static_cast<int>(candidate_trait.block_sizes[idx].first)) {
|
||||
is_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_match)
|
||||
return static_cast<format::type>(fmt_idx);
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream msg;
|
||||
msg << "Unsupported onednn dnnl::memory::desc find_data_format. "
|
||||
<< "ndims: " << desc.data.ndims
|
||||
<< ", inner_nblks: " << blk.inner_nblks
|
||||
<< ", inner_blks: ";
|
||||
for (int i = 0; i < blk.inner_nblks; i++)
|
||||
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
|
||||
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped) {
|
||||
auto blk = desc.data.format_desc.blocking;
|
||||
auto orders = get_candidate_orders(desc);
|
||||
|
||||
format start_format = format::oiyx;
|
||||
if (is_grouped)
|
||||
start_format = format::goiyx;
|
||||
|
||||
for (int32_t fmt_idx = start_format ; fmt_idx < format::format_num ; fmt_idx++) {
|
||||
auto candidate_trait = format::traits(static_cast<format::type>(fmt_idx));
|
||||
if (static_cast<size_t>(desc.data.ndims) == candidate_trait._order.size()
|
||||
&& static_cast<size_t>(blk.inner_nblks) == candidate_trait.block_sizes.size()
|
||||
&& compare_orders(orders, candidate_trait._order)) {
|
||||
// Compare all pairs of dimension number and block size to format_traits_map of all formats
|
||||
bool is_match = true;
|
||||
for (size_t idx = 0 ; idx < candidate_trait.block_sizes.size() ; idx++) {
|
||||
auto block_idx = static_cast<dnnl_dim_t>(candidate_trait.block_sizes[idx].first);
|
||||
auto block_size = static_cast<dnnl_dim_t>(candidate_trait.block_sizes[idx].second);
|
||||
if (is_grouped && candidate_trait.is_group_char(candidate_trait.internal_order[block_idx])) {
|
||||
// inner_idx gets the index of group dimension in mem::desc when blocked axis is group
|
||||
auto inner_idx = candidate_trait.order.find_first_of(candidate_trait.internal_order[block_idx]);
|
||||
if (blk.inner_blks[idx] != block_size ||
|
||||
blk.inner_idxs[idx] != static_cast<dnnl_dim_t>(inner_idx)) {
|
||||
is_match = false;
|
||||
break;
|
||||
}
|
||||
} else if (is_grouped) {
|
||||
// The g, o, i axes of cldnn formats correspond to a, b, c of dnnl, but g is placed at the end of the internal order.
|
||||
if (blk.inner_blks[idx] != block_size ||
|
||||
(blk.inner_idxs[idx] - static_cast<dnnl_dim_t>(candidate_trait.group_num)) != block_idx) {
|
||||
is_match = false;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
if (blk.inner_blks[idx] != block_size ||
|
||||
blk.inner_idxs[idx] != block_idx) {
|
||||
is_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_match)
|
||||
return static_cast<format::type>(fmt_idx);
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream msg;
|
||||
msg << "Unsupported onednn dnnl::memory::desc find_data_format. "
|
||||
<< "ndims: " << desc.data.ndims
|
||||
<< ", inner_nblks: " << blk.inner_nblks
|
||||
<< ", inner_blks: ";
|
||||
for (int i = 0; i < blk.inner_nblks; i++)
|
||||
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
|
||||
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// onednn -> cldnn
|
||||
static cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
|
||||
if (is_grouped) {
|
||||
switch (fmt) {
|
||||
case dnnl::memory::format_tag::abcde: return cldnn::format::goiyx;
|
||||
case dnnl::memory::format_tag::abcdef: return cldnn::format::goizyx;
|
||||
case dnnl::memory::format_tag::Abcdef16a: return cldnn::format::gs_oizyx_gsv16;
|
||||
case dnnl::memory::format_tag::Abcde16a: return cldnn::format::gs_oiyx_gsv16;
|
||||
case dnnl::memory::format_tag::Abcde32a: return cldnn::format::gs_oiyx_gsv32;
|
||||
case dnnl::memory::format_tag::Abcdef32a: return cldnn::format::gs_oizyx_gsv32;
|
||||
case dnnl::memory::format_tag::aCBde16c16b: return cldnn::format::g_is_os_yx_isv16_osv16;
|
||||
case dnnl::memory::format_tag::aBCde2b8c8b2c: return cldnn::format::g_os_is_yx_osa2_isa8_osv8_isv2;
|
||||
case dnnl::memory::format_tag::aBCde4b8c8b4c: return cldnn::format::g_os_is_yx_osa4_isa8_osv8_isv4;
|
||||
case dnnl::memory::format_tag::aBCde4b8c8b2c: return cldnn::format::g_os_is_yx_osa4_isa8_osv8_isv2;
|
||||
case dnnl::memory::format_tag::aBCde8b2c: return cldnn::format::g_os_is_yx_osv8_isv2;
|
||||
case dnnl::memory::format_tag::aBCde8b4c: return cldnn::format::g_os_is_yx_osv8_isv4;
|
||||
case dnnl::memory::format_tag::aBcde8b: return cldnn::format::g_os_iyx_osv8;
|
||||
case dnnl::memory::format_tag::aBCd2b8c16b4c: return cldnn::format::g_os_is_yx_osa2_isa8_osv16_isv4;
|
||||
case dnnl::memory::format_tag::aBCd2b8c16b2c: return cldnn::format::g_os_is_yx_osa2_isa8_osv16_isv2;
|
||||
case dnnl::memory::format_tag::aBCdef16c16b: return cldnn::format::g_os_is_zyx_isv16_osv16;
|
||||
case dnnl::memory::format_tag::aBCdef4b8c8b2c: return cldnn::format::g_os_is_zyx_osa4_isa8_osv8_isv2;
|
||||
case dnnl::memory::format_tag::aBCdef4b8c8b4c: return cldnn::format::g_os_is_zyx_osa4_isa8_osv8_isv4;
|
||||
default: throw std::runtime_error(std::string("Unsupported grouped onednn fmt ") + dnnl_fmt_tag2str((dnnl_format_tag_t)fmt));
|
||||
}
|
||||
} else {
|
||||
switch (fmt) {
|
||||
case dnnl::memory::format_tag::ab: return cldnn::format::oiyx;
|
||||
case dnnl::memory::format_tag::abcd: return cldnn::format::oiyx;
|
||||
case dnnl::memory::format_tag::bacd: return cldnn::format::ioyx;
|
||||
case dnnl::memory::format_tag::bcda: return cldnn::format::iyxo;
|
||||
case dnnl::memory::format_tag::BAcd16b16a: return cldnn::format::is_os_yx_isv16_osv16;
|
||||
case dnnl::memory::format_tag::ABcd16b16a: return cldnn::format::os_is_yx_isv16_osv16;
|
||||
case dnnl::memory::format_tag::abcde: return cldnn::format::oizyx;
|
||||
case dnnl::memory::format_tag::ABcd4a8b8a4b: return cldnn::format::os_is_yx_osa4_isa8_osv8_isv4;
|
||||
case dnnl::memory::format_tag::ABcd4a8b8a2b: return cldnn::format::os_is_yx_osa4_isa8_osv8_isv2;
|
||||
case dnnl::memory::format_tag::ABcde4a8b8a2b: return cldnn::format::os_is_zyx_osa4_isa8_osv8_isv2;
|
||||
case dnnl::memory::format_tag::ABcde4a8b8a4b: return cldnn::format::os_is_zyx_osa4_isa8_osv8_isv4;
|
||||
case dnnl::memory::format_tag::ABcd8a4b: return cldnn::format::os_is_yx_osv8_isv4;
|
||||
case dnnl::memory::format_tag::ABcde8a4b: return cldnn::format::os_is_zyx_osv8_isv4;
|
||||
case dnnl::memory::format_tag::ABcde8a2b: return cldnn::format::os_is_zyx_osv8_isv2;
|
||||
case dnnl::memory::format_tag::ABcd8a2b: return cldnn::format::os_is_yx_osv8_isv2;
|
||||
case dnnl::memory::format_tag::Acdb16a: return cldnn::format::os_yxi_osv16;
|
||||
case dnnl::memory::format_tag::Acdeb16a: return cldnn::format::os_zyxi_osv16;
|
||||
case dnnl::memory::format_tag::ABcde16b16a: return cldnn::format::os_is_zyx_isv16_osv16;
|
||||
case dnnl::memory::format_tag::aBcd16b: return cldnn::format::o_is_yx_isv16;
|
||||
case dnnl::memory::format_tag::Abcd16a: return cldnn::format::os_iyx_osv16;
|
||||
case dnnl::memory::format_tag::ABcd2a8b8a2b: return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
|
||||
case dnnl::memory::format_tag::ABcd2a8b16a4b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv4;
|
||||
case dnnl::memory::format_tag::ABcd2a8b16a2b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
|
||||
case dnnl::memory::format_tag::BAcd4b8a8b4a: return cldnn::format::is_os_yx_isa4_osa8_isv8_osv4;
|
||||
default: throw std::runtime_error(std::string("Unsupported onednn fmt ") + dnnl_fmt_tag2str((dnnl_format_tag_t)fmt));
|
||||
if (is_match)
|
||||
return static_cast<format::type>(fmt_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped) {
|
||||
auto onednn_desc = get_format_by_desc(desc);
|
||||
|
||||
if (onednn_desc != dnnl::memory::format_tag::undef) {
|
||||
return convert_format(onednn_desc, is_grouped);
|
||||
} else {
|
||||
auto blk = desc.data.format_desc.blocking;
|
||||
auto order = get_order(desc);
|
||||
if (is_grouped) {
|
||||
if (desc.data.ndims == 5 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
|
||||
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
|
||||
&& compare_strides(order, {0, 1, 2, 3, 4})) {
|
||||
return cldnn::format::g_os_is_yx_isa8_osv8_isv2;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
|
||||
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
|
||||
&& compare_strides(order, {0, 1, 2, 3, 4})) {
|
||||
return cldnn::format::g_os_is_yx_isa8_osv8_isv4;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
|
||||
if (compare_strides(order, {0, 1, 3, 4, 2})) return cldnn::format::g_os_yx_is_osv8_isv2;
|
||||
else if (compare_strides(order, {0, 1, 3, 2, 4})) return cldnn::format::g_os_y_is_x_osv8_isv2;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
|
||||
if (compare_strides(order, {0, 1, 3, 4, 2})) return cldnn::format::g_os_yx_is_osv8_isv4;
|
||||
else if (compare_strides(order, {0, 1, 3, 2, 4})) return cldnn::format::g_os_y_is_x_osv8_isv4;
|
||||
} else if (desc.data.ndims == 6 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
|
||||
if (compare_strides(order, {0, 1, 3, 4, 5, 2})) return cldnn::format::g_os_zyx_is_osv8_isv2;
|
||||
else if (compare_strides(order, {0, 1, 3, 4, 2, 5})) return cldnn::format::g_os_zy_is_x_osv8_isv2;
|
||||
} else if (desc.data.ndims == 6 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
|
||||
if (compare_strides(order, {0, 1, 3, 4, 5, 2})) return cldnn::format::g_os_zyx_is_osv8_isv4;
|
||||
else if (compare_strides(order, {0, 1, 3, 4, 2, 5})) return cldnn::format::g_os_zy_is_x_osv8_isv4;
|
||||
} else if (desc.data.ndims == 6 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
|
||||
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
|
||||
&& compare_strides(order, {0, 1, 2, 3, 4, 5})) {
|
||||
return cldnn::format::g_os_is_zyx_isa8_osv8_isv2;
|
||||
} else if (desc.data.ndims == 6 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
|
||||
&& blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 2
|
||||
&& compare_strides(order, {0, 1, 2, 3, 4, 5})) {
|
||||
return cldnn::format::g_os_is_zyx_isa8_osv8_isv4;
|
||||
}
|
||||
} else {
|
||||
if (desc.data.ndims == 4 && blk.inner_nblks == 4
|
||||
&& blk.inner_blks[0] == 4 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 4
|
||||
&& blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 0 && blk.inner_idxs[3] == 1
|
||||
&& compare_strides(order, {1, 0, 2, 3})) {
|
||||
return cldnn::format::is_os_yx_osa4_isa8_osv8_isv4;
|
||||
} else if (desc.data.ndims == 4 && blk.inner_nblks == 4
|
||||
&& blk.inner_blks[0] == 2 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 2
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1 && blk.inner_idxs[3] == 0
|
||||
&& compare_strides(order, {1, 0, 2, 3})) {
|
||||
return cldnn::format::is_os_yx_isa2_osa8_isv8_osv2;
|
||||
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 16 && blk.inner_blks[1] == 4 && blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1
|
||||
&& compare_strides(order, {0, 1, 2, 3})) {
|
||||
return cldnn::format::os_is_yx_osv16_isv4;
|
||||
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 16 && blk.inner_blks[1] == 8 && blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0
|
||||
&& compare_strides(order, {0, 1, 2, 3})) {
|
||||
return cldnn::format::is_os_yx_isv16_osv8;
|
||||
} else if (desc.data.ndims == 4 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
|
||||
if (compare_strides(order, {0, 1, 2, 3})) return cldnn::format::os_is_yx_isa8_osv8_isv2;
|
||||
else if (compare_strides(order, {1, 0, 2, 3})) return cldnn::format::is_os_yx_isa8_osv8_isv2;
|
||||
} else if (desc.data.ndims == 4 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
|
||||
if (compare_strides(order, {0, 1, 2, 3})) return cldnn::format::os_is_yx_isa8_osv8_isv4;
|
||||
else if (compare_strides(order, {1, 0, 2, 3})) return cldnn::format::is_os_yx_isa8_osv8_isv4;
|
||||
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2
|
||||
&& blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
|
||||
if (compare_strides(order, {0, 2, 1, 3})) return cldnn::format::os_y_is_x_osv8_isv2;
|
||||
else if (compare_strides(order, {0, 2, 3, 1})) return cldnn::format::os_yx_is_osv8_isv2;
|
||||
} else if (desc.data.ndims == 4 && blk.inner_nblks == 2
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4
|
||||
&& blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
|
||||
if (compare_strides(order, {0, 2, 1, 3})) return cldnn::format::os_y_is_x_osv8_isv4;
|
||||
else if (compare_strides(order, {0, 2, 3, 1})) return cldnn::format::os_yx_is_osv8_isv4;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2 &&
|
||||
blk.inner_blks[0] == 8 && blk.inner_blks[1] == 2 &&
|
||||
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
|
||||
if (compare_strides(order, {0, 2, 3, 4, 1})) return cldnn::format::os_zyx_is_osv8_isv2;
|
||||
else if (compare_strides(order, {0, 2, 3, 1, 4})) return cldnn::format::os_zy_is_x_osv8_isv2;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 2 &&
|
||||
blk.inner_blks[0] == 8 && blk.inner_blks[1] == 4 &&
|
||||
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
|
||||
if (compare_strides(order, {0, 2, 3, 4, 1})) return cldnn::format::os_zyx_is_osv8_isv4;
|
||||
else if (compare_strides(order, {0, 2, 3, 1, 4})) return cldnn::format::os_zy_is_x_osv8_isv4;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 2
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
|
||||
if (compare_strides(order, {0, 1, 2, 3, 4})) return cldnn::format::os_is_zyx_isa8_osv8_isv2;
|
||||
else if (compare_strides(order, {1, 0, 2, 3, 4})) return cldnn::format::is_os_zyx_isa8_osv8_isv2;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 3
|
||||
&& blk.inner_blks[0] == 8 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 4
|
||||
&& blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0 && blk.inner_idxs[2] == 1) {
|
||||
if (compare_strides(order, {0, 1, 2, 3, 4})) return cldnn::format::os_is_zyx_isa8_osv8_isv4;
|
||||
else if (compare_strides(order, {1, 0, 2, 3, 4})) return cldnn::format::is_os_zyx_isa8_osv8_isv4;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 4 &&
|
||||
blk.inner_blks[0] == 2 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 2 &&
|
||||
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 0 && blk.inner_idxs[3] == 1) {
|
||||
return cldnn::format::os_is_zyx_osa2_isa8_osv8_isv2;
|
||||
} else if (desc.data.ndims == 5 && blk.inner_nblks == 4 &&
|
||||
blk.inner_blks[0] == 4 && blk.inner_blks[1] == 8 && blk.inner_blks[2] == 8 && blk.inner_blks[3] == 4 &&
|
||||
blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1 && blk.inner_idxs[2] == 0 && blk.inner_idxs[3] == 1) {
|
||||
return cldnn::format::os_is_zyx_osa4_isa8_osv8_isv4;
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream msg;
|
||||
msg << "Unsupported " << (is_grouped ? "grouped" : "") << "onednn dnnl::memory::desc find_format. "
|
||||
<< "ndims: " << desc.data.ndims
|
||||
<< ", inner_nblks: " << blk.inner_nblks
|
||||
<< ", inner_blks: ";
|
||||
for (int i = 0; i < blk.inner_nblks; i++)
|
||||
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
|
||||
msg << ", strides_order: ";
|
||||
std::stringstream msg;
|
||||
msg << "Unsupported " << (is_grouped ? "grouped" : "") << "onednn dnnl::memory::desc find_format. "
|
||||
<< "ndims: " << desc.data.ndims
|
||||
<< ", inner_nblks: " << blk.inner_nblks
|
||||
<< ", inner_blks: ";
|
||||
for (int i = 0; i < blk.inner_nblks; i++)
|
||||
msg << "(blk " << blk.inner_blks[i] << ", idx " << blk.inner_idxs[i] << ") ";
|
||||
for (auto order : orders) {
|
||||
msg << ", strides_order : ";
|
||||
for (const auto& value : order)
|
||||
msg << value << " ";
|
||||
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
msg << ", stride_value : ";
|
||||
auto strides = desc.data.format_desc.blocking.strides;
|
||||
for (size_t idx = 0; idx < orders[0].size() ; idx++) {
|
||||
msg << strides[idx] << " ";
|
||||
}
|
||||
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Currently, usage of alpha and beta between cldnn::pow and dnnl::eltwise::pow is different : d = pow(src, a) / d = a * pow(src, b)
|
||||
@ -662,7 +510,5 @@ bool is_per_tensor(cldnn::data_node& node, int32_t& zp_val) {
|
||||
|
||||
template bool is_per_tensor<int8_t>(cldnn::data_node& node, int32_t& zp_val);
|
||||
template bool is_per_tensor<uint8_t>(cldnn::data_node& node, int32_t& zp_val);
|
||||
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace cldnn
|
||||
|
@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
#include <oneapi/dnnl/dnnl_debug.h>
|
||||
|
||||
#include <intel_gpu/runtime/layout.hpp>
|
||||
#include <intel_gpu/runtime/engine.hpp>
|
||||
@ -32,6 +33,7 @@ cldnn::format convert_data_format(dnnl::memory::format_tag fmt);
|
||||
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims);
|
||||
dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::undef, bool flatten = false);
|
||||
dnnl::algorithm convert_activation_func(cldnn::activation_func func);
|
||||
std::vector<std::vector<size_t>> get_candidate_orders(dnnl::memory::desc desc);
|
||||
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped = false);
|
||||
cldnn::format find_data_format(dnnl::memory::desc desc);
|
||||
dnnl::memory::format_tag get_format_by_desc(dnnl::memory::desc desc);
|
||||
|
@ -439,7 +439,8 @@ bool program_node::is_padding_supported(int axis, int padding) const {
|
||||
if (fmt == format::fs_b_yx_fsv32 && (axis == 0))
|
||||
return false;
|
||||
|
||||
for (const auto& block : fmt.block_sizes()) {
|
||||
auto block_sizes_dims = format::per_axis_block_size(fmt);
|
||||
for (const auto& block : block_sizes_dims) {
|
||||
size_t block_axis = block.first;
|
||||
int block_size = block.second;
|
||||
|
||||
|
@ -87,39 +87,42 @@ static const std::map<format::type, format_traits> format_traits_map {
|
||||
FMT_TRAITS(image_2d_weights_c4_fyx_b, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
|
||||
FMT_TRAITS(image_2d_weights_c1_b_fyx, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
|
||||
FMT_TRAITS(lstm_weights_dio, 1, 1, 2, 0, {0, 1, 3, 2}, "oixy", "oixy?", {}),
|
||||
FMT_TRAITS(os_is_yx_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
|
||||
FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{1, 8}, {0, 16}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {}),
|
||||
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 32}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_zyx_osa2_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 16}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 2}, {1, 8}, {0, 16}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 2}, {1, 8}, {0, 16}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 2}, {1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_zyx_osa2_isa8_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 2}, {1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_zyx_isa8_osv16_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy?", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(os_is_zyx_osa4_isa8_osv8_isv4_swizzled_by_4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(is_os_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(is_os_yx_isa2_osa8_isv8_osv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 16}, {0, 16}}),
|
||||
FMT_TRAITS(is_os_yx_isa4_osa8_isv8_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 32}, {0, 32}}),
|
||||
FMT_TRAITS(is_os_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(is_os_yx_isa2_osa8_isv8_osv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 2}, {0, 8}, {1, 8}, {0, 2}}),
|
||||
FMT_TRAITS(is_os_yx_isa4_osa8_isv8_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "ioxy?", {{1, 4}, {0, 8}, {1, 8}, {0, 4}}),
|
||||
FMT_TRAITS(is_o_yx_isv32, 1, 1, 2, 0, {1, 0, 2, 3}, "oyxi", "oixy?", {{1, 32}}),
|
||||
FMT_TRAITS(is_o32_yx_isv32_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}),
|
||||
FMT_TRAITS(os_is_y_x8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}),
|
||||
FMT_TRAITS(os_is_y_x8_osv8_isv4_swizzled_by_4, 1, 1, 2, 0, {0, 1, 2, 3}, "oyxi", "oixy?", {}),
|
||||
FMT_TRAITS(os_is_yx_osv16_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 16}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 4}, {0, 8}}),
|
||||
FMT_TRAITS(os_is_zyx_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 4}, {0, 8}}),
|
||||
FMT_TRAITS(os_is_yx_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 2}, {0, 8}}),
|
||||
FMT_TRAITS(os_is_zyx_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 2}, {0, 8}}),
|
||||
FMT_TRAITS(os_is_yx_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_zyx_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_zyx_osv8_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_zyx_osv16_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 16}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_yx_osv32_isv4_swizzled_by_2, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 32}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_osv32_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oixy", "oixy?", {{0, 32}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_zyx_osv32_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_yx_osv32_isv32p, 1, 1, 1, 0, {0, 1, 2, 3}, "oixy", "oixy?", {}),
|
||||
FMT_TRAITS(os_is_zyx_isv16_osv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 16}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_zyx_isv16_osv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 16}, {0, 16}}),
|
||||
FMT_TRAITS(is_os_zyx_isv16_osv16, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 16}, {0, 16}}),
|
||||
FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3, 4}, "ioyx", "oixy", {{1, 16}, {0, 16}}),
|
||||
FMT_TRAITS(is_os_yx_isv16_osv8, 1, 1, 2, 0, {1, 0, 2, 3, 4}, "ioyx", "oixy", {{1, 16}, {0, 8}}),
|
||||
FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 16}}),
|
||||
FMT_TRAITS(is_os_yx_isv16_osv8, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 8}}),
|
||||
FMT_TRAITS(is_os_zyx_isa8_osv8_isv2, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "ioxyz", {{1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(is_os_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "ioxyz", {{1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}),
|
||||
@ -133,7 +136,7 @@ static const std::map<format::type, format_traits> format_traits_map {
|
||||
FMT_TRAITS(os_is_zyx_isv8_osv16_isv2, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 2}}),
|
||||
FMT_TRAITS(os_zyxi_osv16, 1, 1, 3, 0, {0, 2, 3, 4, 1}, "ozyxi", "oixyz", {{0, 16}}),
|
||||
FMT_TRAITS(os_is_yx_isv8_osv16_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 16}, {1, 2}}),
|
||||
FMT_TRAITS(os_is_yx_osv16_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 16}, {0, 16}}),
|
||||
FMT_TRAITS(os_is_yx_osv16_isv16, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 16}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_zyx_osv32_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 32}, {1, 16}}),
|
||||
FMT_TRAITS(os_is_zyx_osv64_isv16, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{0, 64}, {1, 16}}),
|
||||
FMT_TRAITS(os_iyx_osv32__ai32, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}}),
|
||||
@ -185,15 +188,15 @@ static const std::map<format::type, format_traits> format_traits_map {
|
||||
FMT_TRAITS(g_os_zyx_is_osv32_isv32, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "gozyxi", "oixyz?g", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(g_os_is_yx_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(g_os_is_yx_isa8_osv8_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 16}, {1, 16}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(g_os_is_zyx_isa8_osv8_isv2, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(g_os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 16}}),
|
||||
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 32}, {1, 16}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 32}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 32}, {1, 16}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 2}, {1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv4, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 4}, {1, 8}, {0, 8}, {1, 4}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa4_isa8_osv8_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(g_os_is_zyx_osa4_isa8_osv8_isv2, 1, 1, 3, 1, {0, 1, 2, 3, 4, 5}, "goizyx", "oixyz?g", {{0, 4}, {1, 8}, {0, 8}, {1, 2}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 2}, {1, 8}, {0, 16}, {1, 4}}),
|
||||
FMT_TRAITS(g_os_is_yx_osa2_isa8_osv16_isv2, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{0, 2}, {1, 8}, {0, 16}, {1, 2}}),
|
||||
FMT_TRAITS(gs_oi_yxs_gsv4_yxsv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{6, 4}}),
|
||||
FMT_TRAITS(gs_oi_yxs_gsv16_yxsv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{6, 16}}),
|
||||
FMT_TRAITS(gs_oi_yxs_gsv32_yxsv4, 1, 1, 2, 1, {0, 1, 2, 3, 4}, "goiyx", "oixy??g", {{6, 32}}),
|
||||
@ -307,4 +310,18 @@ format format::adjust_to_rank(format fmt, size_t new_rank) {
|
||||
OPENVINO_ASSERT(false, "Can't adjust format ", fmt.to_string(), " to the new rank (", new_rank, ")");
|
||||
}
|
||||
|
||||
// First : block_idx, Second : block_size
|
||||
const std::vector<std::pair<size_t, int>> format::per_axis_block_size(format fmt) {
|
||||
std::vector<std::pair<size_t, int>> sizes_for_dims;
|
||||
for (const auto& block : fmt.block_sizes()) {
|
||||
auto it = std::find_if(sizes_for_dims.begin(), sizes_for_dims.end(),
|
||||
[&block](const std::pair<size_t, int>& ele) { return ele.first == block.first; });
|
||||
if (it != sizes_for_dims.end())
|
||||
it->second *= block.second; // the axis is double blocked
|
||||
else
|
||||
sizes_for_dims.push_back({block.first, block.second});
|
||||
}
|
||||
|
||||
return sizes_for_dims;
|
||||
}
|
||||
} // namespace cldnn
|
||||
|
@ -22,6 +22,17 @@ file(GLOB_RECURSE SOURCES_MAIN
|
||||
"${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/transformations/*.cpp"
|
||||
)
|
||||
|
||||
# When ENABLE_ONEDNN_FOR_GPU is off, remove from SOURCES_MAIN every file whose path
# contains "/onednn/", printing each removed path.
if(NOT ENABLE_ONEDNN_FOR_GPU)
    set(EXCLUDE_DIR "/onednn/")
    foreach(SOURCE_FILE IN LISTS SOURCES_MAIN)
        string(FIND ${SOURCE_FILE} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND)
        # string(FIND) stores -1 when the substring does not occur in the path.
        if(NOT EXCLUDE_DIR_FOUND EQUAL -1)
            message(Exclude : ${SOURCE_FILE})
            list(REMOVE_ITEM SOURCES_MAIN ${SOURCE_FILE})
        endif()
    endforeach()
endif()
|
||||
|
||||
if (MSVC)
|
||||
file(GLOB SOURCES_NATVIS
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/float16.natvis"
|
||||
@ -60,6 +71,10 @@ target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
$<TARGET_PROPERTY:openvino_intel_gpu_runtime,INTERFACE_INCLUDE_DIRECTORIES>
|
||||
${CMAKE_HOME_DIRECTORY}/src/core/reference/include/)
|
||||
|
||||
# With oneDNN enabled, add the include directories of the dnnl target to this target.
if(ENABLE_ONEDNN_FOR_GPU)
    target_include_directories(${TARGET_NAME}
        PRIVATE $<TARGET_PROPERTY:dnnl,INCLUDE_DIRECTORIES>)
endif()
|
||||
|
||||
if(WIN32)
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE setupapi)
|
||||
elseif((NOT ANDROID) AND (UNIX))
|
||||
|
@ -68,3 +68,79 @@ INSTANTIATE_TEST_SUITE_P(smoke, format_adjust_test,
|
||||
{cldnn::format::oiyx, 5, cldnn::format::any},
|
||||
}),
|
||||
format_adjust_test::PrintToString);
|
||||
|
||||
struct axes_test_format_params {
|
||||
cldnn::format in_format;
|
||||
// First : block_idx, Second : block_size
|
||||
std::vector<std::pair<size_t, int>> inner_block;
|
||||
std::vector<std::pair<size_t, int>> per_axis_block;
|
||||
};
|
||||
|
||||
// Parameterized fixture for the block-size queries of cldnn::format.
class axes_test_format : public testing::TestWithParam<axes_test_format_params> {
public:
    // Builds the test label: the input format, its inner blocks, then " > " and the per-axis blocks.
    static std::string PrintToString(testing::TestParamInfo<axes_test_format_params> param_info) {
        auto& params = param_info.param;
        std::string label = "in_fmt = " + params.in_format.to_string() + " : ";

        // Appends every {index, size} pair of a block list to the label.
        const auto append_blocks = [&label](const std::vector<std::pair<size_t, int>>& block_list) {
            for (const auto& item : block_list) {
                label += " { " + std::to_string(item.first) + ", " + std::to_string(item.second) + "}";
            }
        };

        append_blocks(params.inner_block);
        label += " > ";
        append_blocks(params.per_axis_block);

        return label;
    }
};
|
||||
|
||||
// Verifies both block queries of a format against the expected lists, element by element.
TEST_P(axes_test_format, simple_test) {
    auto param = GetParam();

    // Blocks merged per axis.
    auto merged = format::per_axis_block_size(param.in_format);
    ASSERT_EQ(merged.size(), param.per_axis_block.size());
    for (size_t i = 0; i < merged.size(); ++i) {
        const auto& actual = merged.at(i);
        const auto& expected = param.per_axis_block.at(i);
        ASSERT_EQ(actual.first, expected.first);
        ASSERT_EQ(actual.second, expected.second);
    }

    // Raw inner blocks as listed in the format traits.
    auto raw = format::block_sizes(param.in_format);
    ASSERT_EQ(raw.size(), param.inner_block.size());
    for (size_t i = 0; i < raw.size(); ++i) {
        const auto& actual = raw.at(i);
        const auto& expected = param.inner_block.at(i);
        ASSERT_EQ(actual.first, expected.first);
        ASSERT_EQ(actual.second, expected.second);
    }
}
|
||||
|
||||
// Each row: {format, expected inner blocks, expected per-axis blocks}; pairs are {block_idx, block_size}.
INSTANTIATE_TEST_SUITE_P(smoke, axes_test_format,
    testing::ValuesIn(std::vector<axes_test_format_params>{
        {format::os_is_yx_isa8_osv8_isv4, {{1, 8}, {0, 8}, {1, 4}}, {{1, 32}, {0, 8}}},
        {format::os_is_yx_isa8_osv16_isv4, {{1, 8}, {0, 16}, {1, 4}}, {{1, 32}, {0, 16}}},
        {format::os_is_yx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
        {format::os_is_yx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
        {format::os_is_zyx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
        {format::os_is_zyx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
        {format::os_is_yx_osa2_isa8_osv16_isv2, {{0, 2}, {1, 8}, {0, 16}, {1, 2}}, {{0, 32}, {1, 16}}},
        {format::os_is_yx_osa2_isa8_osv16_isv4, {{0, 2}, {1, 8}, {0, 16}, {1, 4}}, {{0, 32}, {1, 32}}},
        {format::os_is_yx_osa2_isa8_osv8_isv2, {{0, 2}, {1, 8}, {0, 8}, {1, 2}}, {{0, 16}, {1, 16}}},
        {format::os_is_zyx_osa2_isa8_osv8_isv2, {{0, 2}, {1, 8}, {0, 8}, {1, 2}}, {{0, 16}, {1, 16}}},
        {format::os_is_zyx_isa8_osv8_isv4, {{1, 8}, {0, 8}, {1, 4}}, {{1, 32}, {0, 8}}},
        {format::os_is_zyx_isa8_osv16_isv4, {{1, 8}, {0, 16}, {1, 4}}, {{1, 32}, {0, 16}}},
        {format::is_os_yx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
        {format::is_os_yx_isa2_osa8_isv8_osv2, {{1, 2}, {0, 8}, {1, 8}, {0, 2}}, {{1, 16}, {0, 16}}},
        {format::is_os_yx_isa4_osa8_isv8_osv4, {{1, 4}, {0, 8}, {1, 8}, {0, 4}}, {{1, 32}, {0, 32}}},
        {format::os_is_yx_osv8_isv4, {{0, 8}, {1, 4}}, {{0, 8}, {1, 4}}},
        {format::os_is_zyx_osv8_isv4, {{0, 8}, {1, 4}}, {{0, 8}, {1, 4}}},
        {format::os_is_yx_osv8_isv2, {{0, 8}, {1, 2}}, {{0, 8}, {1, 2}}},
        {format::os_is_zyx_osv8_isv2, {{0, 8}, {1, 2}}, {{0, 8}, {1, 2}}},
        {format::g_os_is_yx_osa2_isa8_osv8_isv2, {{0, 2}, {1, 8}, {0, 8}, {1, 2}}, {{0, 16}, {1, 16}}},
        {format::g_os_is_yx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
        {format::g_os_is_zyx_osa4_isa8_osv8_isv4, {{0, 4}, {1, 8}, {0, 8}, {1, 4}}, {{0, 32}, {1, 32}}},
        {format::g_os_is_yx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
        {format::g_os_is_zyx_osa4_isa8_osv8_isv2, {{0, 4}, {1, 8}, {0, 8}, {1, 2}}, {{0, 32}, {1, 16}}},
        {format::g_os_is_yx_osa2_isa8_osv16_isv4, {{0, 2}, {1, 8}, {0, 16}, {1, 4}}, {{0, 32}, {1, 32}}},
        {format::g_os_is_yx_osa2_isa8_osv16_isv2, {{0, 2}, {1, 8}, {0, 16}, {1, 2}}, {{0, 32}, {1, 16}}},
    }),
    axes_test_format::PrintToString);
|
||||
|
209
src/plugins/intel_gpu/tests/onednn/utils_test.cpp
Normal file
209
src/plugins/intel_gpu/tests/onednn/utils_test.cpp
Normal file
@ -0,0 +1,209 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include "intel_gpu/runtime/format.hpp"
|
||||
#include "graph/impls/onednn/utils.hpp"
|
||||
|
||||
using namespace cldnn;
|
||||
|
||||
struct dnnl_desc_params {
|
||||
// Descriptor info to test
|
||||
dnnl::memory::dims dims;
|
||||
dnnl::memory::data_type data_type;
|
||||
dnnl::memory::dims strides; // In case of plain (non-blocked) formats the strides between dimensions
|
||||
};
|
||||
|
||||
struct desc_stride_test_params {
|
||||
dnnl_desc_params test_desc;
|
||||
std::vector<std::vector<size_t>> expected_orders;
|
||||
};
|
||||
|
||||
// Parameterized fixture for candidate-order detection from strides.
class format_test_stride : public testing::TestWithParam<desc_stride_test_params> {
public:
    // Builds the test label from the strides of the descriptor, each followed by a space.
    static std::string PrintToString(testing::TestParamInfo<desc_stride_test_params> param_info) {
        std::string label = " stride : {";
        for (const auto& step : param_info.param.test_desc.strides) {
            label += std::to_string(step) + " ";
        }
        label += "}";

        return label;
    }
};
|
||||
|
||||
// Checks that onednn::get_candidate_orders() returns exactly as many orders as expected
// and that every expected order occurs among the returned candidates (in any sequence).
TEST_P(format_test_stride, test_candidates_using_stride) {
    auto param = GetParam();

    dnnl::memory::desc test_desc(param.test_desc.dims, param.test_desc.data_type, param.test_desc.strides);
    auto candidates = onednn::get_candidate_orders(test_desc);

    ASSERT_EQ(candidates.size(), param.expected_orders.size());

    for (const auto& expected : param.expected_orders) {
        bool found_match = false;
        for (const auto& candidate : candidates) {
            // Compare lengths first: the three-iterator std::equal only walks the candidate,
            // so a shorter candidate would match as a prefix and a longer one would read
            // past the end of the expected order.
            if (candidate.size() == expected.size() &&
                std::equal(candidate.begin(), candidate.end(), expected.begin())) {
                found_match = true;
                break;
            }
        }

        ASSERT_TRUE(found_match);
    }
}
|
||||
|
||||
// Each row: {{dims, data type, strides}, {expected candidate orders}}.
INSTANTIATE_TEST_SUITE_P(smoke, format_test_stride,
    testing::ValuesIn(std::vector<desc_stride_test_params>{
        {{{1, 3, 8, 8}, dnnl::memory::data_type::f16, {768, 256, 16, 1}}, {{0, 1, 2, 3}}},
        {{{6, 1, 1, 8}, dnnl::memory::data_type::f16, {96, 16, 16, 1}}, {{0, 1, 2, 3}, {0, 2, 1, 3}}},
        {{{1, 1, 1, 16}, dnnl::memory::data_type::f16, {32, 32, 32, 1}}, {{0, 1, 2, 3}, {0, 2, 1, 3}, {1, 0, 2, 3}, {2, 1, 0, 3}, {1, 2, 0, 3}, {2, 0, 1, 3}}}
    }),
    format_test_stride::PrintToString);
|
||||
|
||||
|
||||
struct format_matching_test_params {
|
||||
dnnl::memory::dims dims;
|
||||
dnnl::memory::data_type data_type;
|
||||
dnnl::memory::format_tag dnnl_format;
|
||||
cldnn::format cldnn_format;
|
||||
};
|
||||
|
||||
// Parameterized fixture for cases where the data-format lookup must equal the given cldnn format.
class data_format_test_match_dnnl : public testing::TestWithParam<format_matching_test_params> {
public:
    // Builds the test label: the oneDNN tag name, " > ", then the cldnn format's traits string.
    static std::string PrintToString(testing::TestParamInfo<format_matching_test_params> param_info) {
        const std::string tag_name(dnnl_fmt_tag2str(static_cast<dnnl_format_tag_t>(param_info.param.dnnl_format)));

        std::string label = " data format (dnnl::memory::format_tag) : " + tag_name;
        label += " > " + format::traits(param_info.param.cldnn_format).str;

        return label;
    }
};
|
||||
|
||||
// The format returned by onednn::find_data_format() for a tagged descriptor must equal
// the cldnn format of the case.
TEST_P(data_format_test_match_dnnl, test_match_data_format) {
    auto tc = GetParam();

    dnnl::memory::desc md(tc.dims, tc.data_type, tc.dnnl_format);
    auto found = onednn::find_data_format(md);

    ASSERT_TRUE(found == tc.cldnn_format);
}
|
||||
|
||||
// Each row: {dims, data type, oneDNN format tag, cldnn format the lookup must return}.
INSTANTIATE_TEST_SUITE_P(smoke, data_format_test_match_dnnl,
    testing::ValuesIn(std::vector<format_matching_test_params>{
        {{1, 3, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::nchw, cldnn::format::bfyx},
        {{1, 3, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcd4b, cldnn::format::b_fs_yx_fsv4},
        {{1, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::nChw16c, cldnn::format::b_fs_yx_fsv16},
        {{1, 12, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::aBcd32b, cldnn::format::b_fs_yx_fsv32},
        {{1, 3, 16, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcde4b, cldnn::format::b_fs_zyx_fsv4},
        {{12, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::NChw16n16c, cldnn::format::bs_fs_yx_bsv16_fsv16},
        {{32, 32, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::NChw32n32c, cldnn::format::bs_fs_yx_bsv32_fsv32},
    }),
    data_format_test_match_dnnl::PrintToString);
|
||||
|
||||
// Parameterized fixture for cases where the data-format lookup must differ from the given cldnn format.
class data_format_test_not_match_dnnl : public testing::TestWithParam<format_matching_test_params> {
public:
    // Builds the test label from the name of the oneDNN format tag.
    static std::string PrintToString(testing::TestParamInfo<format_matching_test_params> param_info) {
        const std::string tag_name(dnnl_fmt_tag2str(static_cast<dnnl_format_tag_t>(param_info.param.dnnl_format)));
        return " Failed case for " + tag_name;
    }
};
|
||||
|
||||
// The format returned by onednn::find_data_format() for a tagged descriptor must differ
// from the cldnn format of the case.
TEST_P(data_format_test_not_match_dnnl, test_not_match_data_format) {
    auto tc = GetParam();

    dnnl::memory::desc md(tc.dims, tc.data_type, tc.dnnl_format);
    auto found = onednn::find_data_format(md);

    ASSERT_FALSE(found == tc.cldnn_format);
}
|
||||
|
||||
// Each row: {dims, data type, oneDNN format tag, cldnn format the lookup must NOT return}.
INSTANTIATE_TEST_SUITE_P(smoke, data_format_test_not_match_dnnl,
    testing::ValuesIn(std::vector<format_matching_test_params>{
        {{1, 3, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::nchw, cldnn::format::byxf},
        {{1, 3, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcd4b, cldnn::format::b_fs_yx_fsv2},
        {{1, 12, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::nChw16c, cldnn::format::b_fs_yx_fsv32},
        {{1, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcd32b, cldnn::format::b_fs_yx_fsv16},
        {{1, 3, 16, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::aBcde4b, cldnn::format::b_fs_zyx_fsv16},
        {{32, 32, 16, 16}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::NChw16n16c, cldnn::format::bs_fs_yx_bsv32_fsv32},
        {{12, 12, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::NChw32n32c, cldnn::format::bs_fs_yx_bsv16_fsv16},
    }),
    data_format_test_not_match_dnnl::PrintToString);
|
||||
|
||||
// Parameterized fixture for cases where the weight-format lookup must equal the given cldnn format.
class weight_format_test_match_dnnl : public testing::TestWithParam<format_matching_test_params> {
public:
    // Builds the test label: the oneDNN tag name, " > ", then the cldnn format's traits string.
    static std::string PrintToString(testing::TestParamInfo<format_matching_test_params> param_info) {
        const std::string tag_name(dnnl_fmt_tag2str(static_cast<dnnl_format_tag_t>(param_info.param.dnnl_format)));

        std::string label = " weight format (dnnl::memory::format_tag) : " + tag_name;
        label += " > " + format::traits(param_info.param.cldnn_format).str;

        return label;
    }
};
|
||||
|
||||
// The format returned by onednn::find_format(desc, false) for a tagged descriptor must equal
// the cldnn format of the case.
TEST_P(weight_format_test_match_dnnl, test_match_data_format) {
    auto tc = GetParam();

    dnnl::memory::desc md(tc.dims, tc.data_type, tc.dnnl_format);
    auto found = onednn::find_format(md, false);

    ASSERT_TRUE(found == tc.cldnn_format);
}
|
||||
|
||||
// Each row: {dims, data type, oneDNN format tag, cldnn weight format the lookup must return}.
INSTANTIATE_TEST_SUITE_P(smoke, weight_format_test_match_dnnl,
    testing::ValuesIn(std::vector<format_matching_test_params>{
        {{1, 3, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::abcd, cldnn::format::oiyx},
        {{16, 16, 8, 8}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::ABcd16b16a, cldnn::format::os_is_yx_isv16_osv16},
        {{8, 4, 16, 16, 16}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::ABcde8a4b, cldnn::format::os_is_zyx_osv8_isv4},
        {{16, 16, 8, 8}, dnnl::memory::data_type::f16, dnnl::memory::format_tag::ABcd2a8b8a2b, cldnn::format::os_is_yx_osa2_isa8_osv8_isv2},
        {{32, 32, 8, 8}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::BAcd4b8a8b4a, cldnn::format::is_os_yx_isa4_osa8_isv8_osv4},
    }),
    weight_format_test_match_dnnl::PrintToString);
|
||||
|
||||
struct stride_matching_test_params {
|
||||
dnnl_desc_params test_desc;
|
||||
int inner_nblks; // The number of innermost blocks
|
||||
dnnl_dims_t inner_idxs; // The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i`
|
||||
dnnl_dims_t inner_blks; // The logical indices of the blocks, e.g. `{1, 0, 1}` in case of 4i16o4i
|
||||
cldnn::format cldnn_format;
|
||||
};
|
||||
|
||||
// Parameterized fixture for weight-format lookup on descriptors built from strides and inner blocks.
class weight_format_test_with_stride : public testing::TestWithParam<stride_matching_test_params> {
public:
    // Builds the test label: the strides (each followed by a space), then the cldnn format's traits string.
    static std::string PrintToString(testing::TestParamInfo<stride_matching_test_params> param_info) {
        std::string label = " stride : {";
        for (const auto& step : param_info.param.test_desc.strides) {
            label += std::to_string(step) + " ";
        }
        label += "} > " + format::traits(param_info.param.cldnn_format).str;

        return label;
    }
};
|
||||
|
||||
// Builds a strided descriptor, writes the inner blocking of the case into it, and checks that
// onednn::find_format(desc, false) returns the cldnn format of the case.
TEST_P(weight_format_test_with_stride, test_match_data_format) {
    auto param = GetParam();

    dnnl::memory::desc test_desc(param.test_desc.dims, param.test_desc.data_type, param.test_desc.strides);

    // Patch the blocking part of the descriptor in place with the inner blocks of the case.
    auto& blocking = test_desc.data.format_desc.blocking;
    blocking.inner_nblks = param.inner_nblks;
    for (int blk = 0; blk < param.inner_nblks; ++blk) {
        blocking.inner_idxs[blk] = param.inner_idxs[blk];
        blocking.inner_blks[blk] = param.inner_blks[blk];
    }

    auto found = onednn::find_format(test_desc, false);
    ASSERT_TRUE(found == param.cldnn_format);
}
|
||||
|
||||
// Each row: {{dims, data type, strides}, inner_nblks, inner_idxs, inner_blks, cldnn format the lookup must return}.
INSTANTIATE_TEST_SUITE_P(smoke, weight_format_test_with_stride,
    testing::ValuesIn(std::vector<stride_matching_test_params>{
        {{{16, 16, 1, 1}, dnnl::memory::data_type::f16, {16, 256, 1, 1}}, 2, {1, 0}, {16, 16}, cldnn::format::is_os_yx_isv16_osv16},
        {{{16, 16, 1, 1}, dnnl::memory::data_type::f16, {256, 16, 1, 1}}, 2, {1, 0}, {16, 16}, cldnn::format::os_is_yx_isv16_osv16},
    }),
    weight_format_test_with_stride::PrintToString);
|
Loading…
Reference in New Issue
Block a user