Optimize permute gemm onednn (#17621)
* [GPU] Optimize out the permute in the permute-gemm (oneDNN) pattern. The permute can be optimized out when its input and output layouts are compatible and the gemm uses the oneDNN implementation. Signed-off-by: hyunback <hyunback.kim@intel.com>
This commit is contained in:
@@ -82,6 +82,8 @@ struct format {
|
||||
byxf, ///< used in bitmaps, input from user i.e b images of RGB format
|
||||
fyxb, ///< format not used inside clDNN, but supported in reorder as extension
|
||||
bzyxf,
|
||||
byfx, ///< To be used when onednn gemm allows permute fusing in transformer network. Not for normal use from cldnn.
|
||||
bxfy, ///< To be used when onednn gemm allows permute fusing in transformer network. Not for normal use from cldnn.
|
||||
///< for user provided formats.
|
||||
b_fs_yx_fsv2,
|
||||
b_fs_zyx_fsv2,
|
||||
@@ -129,6 +131,8 @@ struct format {
|
||||
iozyx, ///< 3D weights format for deconvolutions
|
||||
iyxo,
|
||||
oyxi,
|
||||
oyix,
|
||||
oxiy,
|
||||
os_iyx_osv16, ///< format used only for convolution weights
|
||||
o_is_yx_isv16, ///< format used only for convolution weights
|
||||
os_yxi_osv16, ///< format used only for convolution weights
|
||||
@@ -331,6 +335,7 @@ struct format {
|
||||
/// @brief Checks if @p format is simple data format
|
||||
static bool is_simple_data_format(type fmt) {
|
||||
return (fmt == yxfb || fmt == byxf ||
|
||||
fmt == byfx || fmt == bxfy ||
|
||||
fmt == bfyx || fmt == fyxb ||
|
||||
fmt == bfzyx || fmt == bfwzyx ||
|
||||
fmt == bfuwzyx || fmt == bfvuwzyx);
|
||||
|
||||
@@ -207,6 +207,10 @@ kernel_selector::data_layout to_data_layout(format f) {
|
||||
return kernel_selector::data_layout::yxfb;
|
||||
case format::byxf:
|
||||
return kernel_selector::data_layout::byxf;
|
||||
case format::byfx:
|
||||
return kernel_selector::data_layout::byfx;
|
||||
case format::bxfy:
|
||||
return kernel_selector::data_layout::bxfy;
|
||||
case format::fyxb:
|
||||
return kernel_selector::data_layout::fyxb;
|
||||
case format::b_fs_yx_fsv2:
|
||||
@@ -302,6 +306,10 @@ cldnn::format from_data_layout(kernel_selector::data_layout l) {
|
||||
return cldnn::format::yxfb;
|
||||
case kernel_selector::data_layout::byxf:
|
||||
return cldnn::format::byxf;
|
||||
case kernel_selector::data_layout::byfx:
|
||||
return cldnn::format::byfx;
|
||||
case kernel_selector::data_layout::bxfy:
|
||||
return cldnn::format::bxfy;
|
||||
case kernel_selector::data_layout::fyxb:
|
||||
return cldnn::format::fyxb;
|
||||
case kernel_selector::data_layout::b_fs_yx_fsv2:
|
||||
@@ -381,6 +389,10 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
|
||||
return kernel_selector::weights_layout::iyxo;
|
||||
case format::byxf:
|
||||
return kernel_selector::weights_layout::oyxi;
|
||||
case format::byfx:
|
||||
return kernel_selector::weights_layout::oyix;
|
||||
case format::bxfy:
|
||||
return kernel_selector::weights_layout::oxiy;
|
||||
case format::yxfb:
|
||||
case format::yxio:
|
||||
return kernel_selector::weights_layout::yxio;
|
||||
@@ -648,6 +660,10 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
|
||||
return cldnn::format::oiyx;
|
||||
case kernel_selector::weights_layout::oyxi:
|
||||
return cldnn::format::oyxi;
|
||||
case kernel_selector::weights_layout::oyix:
|
||||
return cldnn::format::oyix;
|
||||
case kernel_selector::weights_layout::oxiy:
|
||||
return cldnn::format::oxiy;
|
||||
case kernel_selector::weights_layout::io:
|
||||
case kernel_selector::weights_layout::iyxo:
|
||||
return cldnn::format::iyxo;
|
||||
|
||||
@@ -102,9 +102,9 @@ protected:
|
||||
in1_dims = onednn::convert_gemm_tensor(in1_l.get_tensor(), rank, batched_dims_can_be_removed);
|
||||
out_dims = onednn::convert_gemm_tensor(out_l.get_tensor(), rank, batched_dims_can_be_removed);
|
||||
|
||||
in0_fmt = onednn::convert_gemm_data_format(in0_dims);
|
||||
in1_fmt = onednn::convert_gemm_data_format(in1_dims);
|
||||
out_fmt = onednn::convert_gemm_data_format(out_dims);
|
||||
in0_fmt = onednn::convert_gemm_data_format(in0_dims, in0_l.format);
|
||||
in1_fmt = onednn::convert_gemm_data_format(in1_dims, in1_l.format);
|
||||
out_fmt = onednn::convert_gemm_data_format(out_dims, out_l.format);
|
||||
|
||||
if (prim->transpose_input0) {
|
||||
in0_fmt = transpose_format(in0_fmt);
|
||||
@@ -121,7 +121,7 @@ protected:
|
||||
auto bias_rank = cldnn::format::dimension(bias_l.format);
|
||||
bias_dt = onednn::convert_data_type(bias_l.data_type);
|
||||
bias_dims = onednn::convert_gemm_tensor(bias_l.get_tensor(), bias_rank, batched_dims_can_be_removed);
|
||||
bias_fmt = onednn::convert_gemm_data_format(bias_dims);
|
||||
bias_fmt = onednn::convert_gemm_data_format(bias_dims, bias_l.format);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,6 +321,9 @@ attach_gemm_onednn::attach_gemm_onednn() {
|
||||
};
|
||||
std::vector<format::type> fmt = {
|
||||
format::bfyx,
|
||||
format::byxf,
|
||||
format::byfx,
|
||||
format::bxfy,
|
||||
format::bfzyx,
|
||||
format::bfwzyx,
|
||||
};
|
||||
|
||||
@@ -80,16 +80,24 @@ dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batche
|
||||
return res;
|
||||
}
|
||||
|
||||
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) {
|
||||
switch (dims.size()) {
|
||||
case 2: return dnnl::memory::format_tag::ab;
|
||||
case 3: return dnnl::memory::format_tag::abc;
|
||||
case 4: return dnnl::memory::format_tag::abcd;
|
||||
default: throw std::invalid_argument("[clDNN] Unsupported conversion from "+ std::to_string(dims.size()) + " to onednn format_tag");
|
||||
/// @brief Selects the oneDNN format tag used to describe a gemm operand.
/// @param dims   oneDNN dimensions of the operand; only its rank (`dims.size()`) is inspected.
/// @param target cldnn format of the operand's layout.
/// @return When the rank of @p dims equals `target.dimension()`, the tag that
///         convert_data_format() maps @p target to; otherwise the plain tag
///         `ab`, `abc` or `abcd` chosen from the rank of @p dims alone.
/// @throws std::invalid_argument if the ranks match but @p target maps to
///         `format_tag::undef`, or if the ranks differ and the rank of @p dims
///         is not 2, 3 or 4.
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims, format target) {
    const auto rank = dims.size();

    // Rank mismatch: the cldnn format cannot describe these dims, so fall back to a
    // plain tag picked purely by rank.
    // NOTE(review): presumably happens when convert_gemm_tensor dropped dimensions - confirm.
    if (rank != target.dimension()) {
        if (rank == 2)
            return dnnl::memory::format_tag::ab;
        if (rank == 3)
            return dnnl::memory::format_tag::abc;
        if (rank == 4)
            return dnnl::memory::format_tag::abcd;
        throw std::invalid_argument("[clDNN] Unsupported conversion from "+ std::to_string(dims.size()) + " to onednn format_tag");
    }

    // Ranks agree: honour the actual cldnn format, but only if it has a oneDNN counterpart.
    const auto tag = convert_data_format(target);
    if (tag == dnnl::memory::format_tag::undef)
        throw std::invalid_argument("[clDNN] Unsupported conversion from "+ target.to_string() + " to onednn format_tag");

    return tag;
}
|
||||
|
||||
|
||||
dnnl::memory::dims convert_spatials(cldnn::tensor t, size_t dims) {
|
||||
auto spatials = t.spatial;
|
||||
dnnl::memory::dims res(dims);
|
||||
@@ -118,6 +126,9 @@ std::vector<std::pair<cldnn::format, dnnl::memory::format_tag>> format_map = {
|
||||
{ cldnn::format::bfyx, dnnl::memory::format_tag::nchw },
|
||||
{ cldnn::format::bfzyx, dnnl::memory::format_tag::ncdhw },
|
||||
{ cldnn::format::byxf, dnnl::memory::format_tag::nhwc },
|
||||
{ cldnn::format::byfx, dnnl::memory::format_tag::acbd },
|
||||
{ cldnn::format::bxfy, dnnl::memory::format_tag::adbc },
|
||||
{ cldnn::format::fyxb, dnnl::memory::format_tag::bcda },
|
||||
{ cldnn::format::bzyxf, dnnl::memory::format_tag::ndhwc },
|
||||
{ cldnn::format::b_fs_yx_fsv2, dnnl::memory::format_tag::undef },
|
||||
{ cldnn::format::b_fs_yx_fsv4, dnnl::memory::format_tag::aBcd4b },
|
||||
|
||||
@@ -29,7 +29,7 @@ dnnl::memory::dims flatten_tensor(cldnn::tensor t);
|
||||
dnnl::memory::data_type convert_data_type(cldnn::data_types dt);
|
||||
dnnl::memory::format_tag convert_data_format(cldnn::format fmt);
|
||||
cldnn::format convert_data_format(dnnl::memory::format_tag fmt);
|
||||
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims);
|
||||
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims, format target);
|
||||
dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::undef, bool flatten = false);
|
||||
dnnl::algorithm convert_activation_func(cldnn::activation_func func);
|
||||
std::vector<std::vector<size_t>> get_candidate_orders(dnnl::memory::desc desc);
|
||||
|
||||
@@ -1873,9 +1873,70 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
|
||||
GPU_DEBUG_LOG << "select_preferred_formats:" << node.id() << ": " << fmt_to_str(target_format) << " --> " << fmt_to_str(target_format)
|
||||
<< " For index : " << idx << std::endl;
|
||||
}
|
||||
// Optimize out the permute in the permute-gemm pattern, i.e. permute -> gemm
|
||||
if (node.is_type<gemm>()) {
|
||||
// Only the formats below support optimizing out the permute in the permute-gemm pattern. For other formats, the gemm performance needs to be checked first.
|
||||
std::vector<format> gemm_in_foramt_white_list = {
|
||||
format::bfyx,
|
||||
format::fyxb,
|
||||
format::byfx,
|
||||
format::bxfy,
|
||||
};
|
||||
for (size_t idx = 0 ; idx < node.get_dependencies().size() ; idx++) {
|
||||
if (node.get_dependency(idx).is_type<permute>()) {
|
||||
auto& pnode = node.get_dependency(idx);
|
||||
if (pnode.has_fused_primitives()) {
|
||||
continue;
|
||||
}
|
||||
auto input_lay = pnode.get_dependency(0).get_output_layout();
|
||||
auto output_lay = pnode.get_output_layout();
|
||||
if (input_lay.compatible(output_lay)) {
|
||||
for (auto candidate : gemm_in_foramt_white_list) {
|
||||
auto impl_param = pnode.get_kernel_impl_params();
|
||||
auto desc = impl_param->typed_desc<permute>();
|
||||
auto permute_order = desc->permute_order;
|
||||
std::vector<size_t> l_permute_order(std::begin(permute_order), std::end(permute_order));
|
||||
if (format::traits(static_cast<format::type>(candidate))._order == l_permute_order) {
|
||||
pnode.init_preferred_fmt(1, 1);
|
||||
pnode.set_preferred_output_fmt(0, format(static_cast<format::type>(candidate)));
|
||||
pnode.can_be_optimized(true);
|
||||
node.set_preferred_input_fmt(idx, format(static_cast<format::type>(candidate)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// gemm -> permute
|
||||
if (node.get_users().size() == 1 && node.get_users().front()->is_type<permute>() && !node.has_fused_primitives()) {
|
||||
std::vector<format> gemm_out_format_white_list = {
|
||||
format::bfyx,
|
||||
format::fyxb,
|
||||
format::byfx,
|
||||
};
|
||||
auto& pnode = node.get_users().front()->as<permute>();
|
||||
if (!pnode.has_fused_primitives()) {
|
||||
auto input_lay = pnode.get_dependency(0).get_output_layout();
|
||||
auto output_lay = pnode.get_output_layout();
|
||||
if (input_lay.compatible(output_lay)) {
|
||||
for (auto candidate : gemm_out_format_white_list) {
|
||||
auto impl_param = pnode.get_kernel_impl_params();
|
||||
auto desc = impl_param->typed_desc<permute>();
|
||||
auto permute_order = desc->permute_order;
|
||||
std::vector<size_t> l_permute_order(std::begin(permute_order), std::end(permute_order));
|
||||
if (format::traits(static_cast<format::type>(candidate))._order == l_permute_order) {
|
||||
node.set_preferred_output_fmt(0, format(static_cast<format::type>(candidate)));
|
||||
pnode.init_preferred_fmt(1, 1);
|
||||
pnode.set_preferred_input_fmt(0, format(static_cast<format::type>(candidate)));
|
||||
pnode.can_be_optimized(true);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||
|
||||
|
||||
@@ -1000,7 +1000,7 @@ void program_node::init_onednn_primitive_attributes() {
|
||||
size_t in_batched_size = in.count() / (in.spatial(0) * in.spatial(1));
|
||||
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in_batched_size == 1);
|
||||
dnnl::memory::data_type dt = onednn::convert_data_type(in.data_type);
|
||||
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format(dims);
|
||||
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format(dims, in.format);
|
||||
post_ops.append_binary(alg, dnnl::memory::desc(dims, dt, fmt));
|
||||
update_onednn_post_op_list(op_type, dep_idx, fmt, false, dims, dt);
|
||||
} else {
|
||||
|
||||
@@ -90,6 +90,8 @@ std::string toString(DataLayout l) {
|
||||
case kernel_selector::DataLayout::bfyx: return "BFYX";
|
||||
case kernel_selector::DataLayout::yxfb: return "YXFB";
|
||||
case kernel_selector::DataLayout::byxf: return "BYXF";
|
||||
case kernel_selector::DataLayout::byfx: return "BYFX";
|
||||
case kernel_selector::DataLayout::bxfy: return "BXFY";
|
||||
case kernel_selector::DataLayout::fyxb: return "FYXB";
|
||||
case kernel_selector::DataLayout::b_fs_yx_fsv2: return "B_FS_YX_FSV2";
|
||||
case kernel_selector::DataLayout::b_fs_yx_fsv4: return "B_FS_YX_FSV4";
|
||||
@@ -297,6 +299,8 @@ std::string toString(WeightsLayout layout) {
|
||||
case WeightsLayout::oiyx: return "OIYX";
|
||||
case WeightsLayout::ioyx: return "IOYX";
|
||||
case WeightsLayout::oyxi: return "OYXI";
|
||||
case WeightsLayout::oyix: return "OYIX";
|
||||
case WeightsLayout::oxiy: return "OXIY";
|
||||
case WeightsLayout::iyxo: return "IYXO";
|
||||
case WeightsLayout::yxio: return "YXIO";
|
||||
case WeightsLayout::os_is_yx_isv16_osv16: return "OS_IS_YX_ISV16_OSV16";
|
||||
|
||||
@@ -301,6 +301,12 @@ std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws, const
|
||||
case DataLayout::byxf:
|
||||
layout_order = { f, x, y, b, z, w, u, v };
|
||||
break;
|
||||
case DataLayout::byfx:
|
||||
layout_order = { x, f, y, b, z, w, u, v };
|
||||
break;
|
||||
case DataLayout::bxfy:
|
||||
layout_order = { y, f, x, b, z, w, u, v };
|
||||
break;
|
||||
case DataLayout::fyxb:
|
||||
layout_order = { b, x, y, f, z, w, u, v };
|
||||
break;
|
||||
|
||||
@@ -21,6 +21,8 @@ DataTensor::DataChannelArray DataTensor::dataChannelArray {{
|
||||
{ DataLayout::bfyx, { 0, 1, -1, -1, -1, -1, 2, 3 } },
|
||||
{ DataLayout::yxfb, { 2, 3, -1, -1, -1, -1, 1, 0 } },
|
||||
{ DataLayout::byxf, { 1, 2, -1, -1, -1, -1, 0, 3 } },
|
||||
{ DataLayout::byfx, { 0, 2, -1, -1, -1, -1, 1, 3 } },
|
||||
{ DataLayout::bxfy, { 2, 0, -1, -1, -1, -1, 1, 3 } },
|
||||
{ DataLayout::fyxb, { 1, 2, -1, -1, -1, -1, 3, 0 } },
|
||||
{ DataLayout::b_fs_yx_fsv2, { 0, 1, -1, -1, -1, -1, 2, 3 } },
|
||||
{ DataLayout::b_fs_yx_fsv4, { 0, 1, -1, -1, -1, -1, 2, 3 } },
|
||||
@@ -68,9 +70,10 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
|
||||
{ WeightsLayout::io, { -1, -1, -1, 1, 0, -1 } },
|
||||
{ WeightsLayout::oiyx, { 0, 1, -1, 2, 3, -1 } },
|
||||
{ WeightsLayout::ioyx, { 0, 1, -1, 3, 2, -1 } },
|
||||
{ WeightsLayout::oyxi, { 1, 2, -1, 0, 3, -1 } },
|
||||
{ WeightsLayout::iyxo, { 1, 2, -1, 3, 0, -1 } },
|
||||
{ WeightsLayout::oyxi, { 1, 2, -1, 0, 3, -1 } },
|
||||
{ WeightsLayout::oyix, { 0, 2, -1, 1, 3, -1 } },
|
||||
{ WeightsLayout::oxiy, { 2, 0, -1, 1, 3, -1 } },
|
||||
{ WeightsLayout::yxio, { 2, 3, -1, 1, 0, -1 } },
|
||||
{ WeightsLayout::os_iyx_osv16, { 0, 1, -1, 2, 3, -1 } },
|
||||
{ WeightsLayout::os_iyx_osv32, { 0, 1, -1, 2, 3, -1 } },
|
||||
|
||||
@@ -33,6 +33,8 @@ enum DataLayout {
|
||||
byxf, // 3D+batch
|
||||
fyxb, // 3D+batch
|
||||
bfxy, // 3D+batch
|
||||
byfx,
|
||||
bxfy,
|
||||
b_fs_yx_fsv2,
|
||||
b_fs_zyx_fsv2,
|
||||
b_fs_yx_fsv4, // reordering format for swizzled input for convolution using IMAD
|
||||
@@ -83,6 +85,8 @@ enum WeightsLayout {
|
||||
oiyx,
|
||||
ioyx,
|
||||
oyxi,
|
||||
oyix,
|
||||
oxiy,
|
||||
iyxo,
|
||||
yxio,
|
||||
o_is_yx_isv16,
|
||||
@@ -289,6 +293,8 @@ inline bool SimpleLayout(WeightsLayout l) {
|
||||
case WeightsLayout::oiyx:
|
||||
case WeightsLayout::ioyx:
|
||||
case WeightsLayout::oyxi:
|
||||
case WeightsLayout::oyix:
|
||||
case WeightsLayout::oxiy:
|
||||
case WeightsLayout::iyxo:
|
||||
case WeightsLayout::yxio:
|
||||
case WeightsLayout::oizyx:
|
||||
@@ -307,6 +313,8 @@ inline bool SimpleLayout(DataLayout l) {
|
||||
case DataLayout::bfyx:
|
||||
case DataLayout::yxfb:
|
||||
case DataLayout::byxf:
|
||||
case DataLayout::byfx:
|
||||
case DataLayout::bxfy:
|
||||
case DataLayout::fyxb:
|
||||
case DataLayout::bfxy:
|
||||
case DataLayout::bfzyx:
|
||||
|
||||
@@ -26,6 +26,8 @@ static const std::map<format::type, format_traits> format_traits_map {
|
||||
FMT_TRAITS(byxf, 1, 1, 2, 0, {0, 2, 3, 1}, "byxf", "bfxy?", {}),
|
||||
FMT_TRAITS(bfyx, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {}),
|
||||
FMT_TRAITS(fyxb, 1, 1, 2, 0, {1, 2, 3, 0}, "fyxb", "bfxy?", {}),
|
||||
FMT_TRAITS(byfx, 1, 1, 2, 0, {0, 2, 1, 3}, "byfx", "bfxy?", {}),
|
||||
FMT_TRAITS(bxfy, 1, 1, 2, 0, {0, 3, 1, 2}, "bxfy", "bfxy?", {}),
|
||||
FMT_TRAITS(b_fs_yx_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{1, 2}}),
|
||||
FMT_TRAITS(b_fs_yx_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{1, 4}}),
|
||||
FMT_TRAITS(b_fs_yx_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 16}}),
|
||||
@@ -72,6 +74,8 @@ static const std::map<format::type, format_traits> format_traits_map {
|
||||
FMT_TRAITS(ioyx, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {}),
|
||||
FMT_TRAITS(iyxo, 1, 1, 2, 0, {1, 2, 3, 0}, "iyxo", "oixy", {}),
|
||||
FMT_TRAITS(oyxi, 1, 1, 2, 0, {0, 2, 3, 1}, "oyxi", "oixy", {}),
|
||||
FMT_TRAITS(oyix, 1, 1, 2, 0, {0, 2, 1, 3}, "oyix", "oixy", {}),
|
||||
FMT_TRAITS(oxiy, 1, 1, 2, 0, {0, 3, 1, 2}, "oxiy", "oixy", {}),
|
||||
FMT_TRAITS(yxio, 1, 1, 2, 0, {2, 3, 1, 0}, "yxio", "oixy?", {}),
|
||||
FMT_TRAITS(oizyx, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {}),
|
||||
FMT_TRAITS(iozyx, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {}),
|
||||
|
||||
@@ -142,6 +142,10 @@ static format to_weights_format(format f, bool is_grouped) {
|
||||
return format::iyxo;
|
||||
case format::byxf:
|
||||
return format::oyxi;
|
||||
case format::byfx:
|
||||
return format::oyix;
|
||||
case format::bxfy:
|
||||
return format::oxiy;
|
||||
case format::yxfb:
|
||||
return format::yxio;
|
||||
case format::bfzyx:
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <intel_gpu/primitives/quantize.hpp>
|
||||
#include <intel_gpu/primitives/eltwise.hpp>
|
||||
#include <intel_gpu/primitives/gemm.hpp>
|
||||
#include <intel_gpu/primitives/permute.hpp>
|
||||
#include <intel_gpu/primitives/data.hpp>
|
||||
#include <intel_gpu/runtime/tensor.hpp>
|
||||
|
||||
@@ -526,3 +527,103 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
// gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_2, 3, 3, "gemm_mmad_int8_slm" }, // tolerance issue
|
||||
gemm_test_params{CASE_GEMM_ELTWISE_2IN_FP16_2, 3, 3, "gemm_tiled_opt"},
|
||||
}));
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
/// Fixture for gemm tests that force the "gemm_prim" primitive onto the oneDNN implementation
/// in the fused configuration and compare it against the non-fused configuration.
class GemmFusingTestOneDNN : public ::BaseFusingTest<gemm_test_params> {
public:
    /// Builds both networks from the prepared topologies, feeds them identical inputs and
    /// compares the results.
    /// @param p               test case parameters (shapes, data types, formats).
    /// @param is_dynamic      value passed to allow_new_shape_infer for both configurations.
    /// @param is_caching_test forwarded to get_network().
    void execute(gemm_test_params& p, bool is_dynamic, bool is_caching_test = false) {
        // Nothing is executed on devices that do not report immad support.
        if (!engine.get_device_info().supports_immad) {
            return;
        }

        cfg_not_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));

        // Keep an output format already forced for "gemm_prim", if any; otherwise use the
        // input format of the test case.
        auto forced_impls = cfg_fused.get_property(ov::intel_gpu::force_implementations);
        auto gemm_format = p.input_format;
        for (auto& entry : forced_impls) {
            if (entry.first == "gemm_prim") {
                gemm_format = entry.second.output_format;
            }
        }

        // Replace the forcing map with a single entry that pins gemm_prim to oneDNN.
        ov::intel_gpu::ImplementationDesc onednn_gemm = { gemm_format, "", impl_types::onednn };
        ov::intel_gpu::ImplForcingMap onednn_forcing{ { "gemm_prim", onednn_gemm } };
        cfg_fused.set_property(ov::intel_gpu::force_implementations(onednn_forcing));
        cfg_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));

        // Input memories are created before the networks, in input order.
        auto mem_in0 = get_mem(get_input_layout(p, 0));
        auto mem_in1 = get_mem(get_input_layout(p, 1));

        network::ptr net_reference = get_network(this->engine, this->topology_non_fused, cfg_not_fused, get_test_stream_ptr(), is_caching_test);
        network::ptr net_onednn = get_network(this->engine, this->topology_fused, cfg_fused, get_test_stream_ptr(), is_caching_test);

        net_onednn->set_input_data("input0", mem_in0);
        net_reference->set_input_data("input0", mem_in0);
        net_onednn->set_input_data("input1", mem_in1);
        net_reference->set_input_data("input1", mem_in1);

        // A third input exists only for test cases that declare more than two shapes.
        if (p.in_shapes.size() > 2) {
            auto mem_in2 = get_mem(get_input_layout(p, 2));
            net_onednn->set_input_data("input2", mem_in2);
            net_reference->set_input_data("input2", mem_in2);
        }

        compare(*net_reference, *net_onednn, p);
    }

    /// Layout of input @p in_no: its shape and data type, in the test's input format.
    /// Any index other than 0 or 1 selects the third input.
    layout get_input_layout(gemm_test_params& p, int in_no) {
        switch (in_no) {
        case 0:
            return layout{ p.in_shapes.at(0), p.data_type_in0, p.input_format };
        case 1:
            return layout{ p.in_shapes.at(1), p.data_type_in1, p.input_format };
        default:
            return layout{ p.in_shapes.at(2), p.data_type_in2, p.input_format };
        }
    }

    /// Layout of shape {1, in_shapes[0][1], 1, 1} in the default type and default format.
    layout get_per_channel_layout(gemm_test_params& p) {
        return layout{ov::PartialShape{ 1, p.in_shapes[0][1], 1, 1 }, p.default_type, p.default_format };
    }

    /// Layout of the expected output: output shape, default type, input format.
    layout get_output_layout(gemm_test_params& p) {
        return layout{ p.out_shape, p.default_type, p.input_format };
    }
};
|
||||
|
||||
class gemm_permute_2in : public GemmFusingTestOneDNN {};
|
||||
TEST_P(gemm_permute_2in, gemm_permute) {
|
||||
auto p = GetParam();
|
||||
create_topologies(
|
||||
input_layout("input0", get_input_layout(p, 0)),
|
||||
input_layout("input1", get_input_layout(p, 1)),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f16),
|
||||
permute("permute", input_info("gemm_prim"), {0, 2, 1, 3}),
|
||||
reorder("reorder_bfyx", input_info("permute"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(data_types::f16);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
fusings_gpu, gemm_permute_2in, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
gemm_test_params{CASE_GEMM_2IN_FP16_1, 3, 4},
|
||||
gemm_test_params{CASE_GEMM_2IN_FP16_2, 3, 4},
|
||||
gemm_test_params{CASE_GEMM_2IN_FP16_3, 3, 4},
|
||||
}));
|
||||
|
||||
class permute_gemm_2in : public GemmFusingTestOneDNN {};
|
||||
TEST_P(permute_gemm_2in, permute_gemm) {
|
||||
auto p = GetParam();
|
||||
create_topologies(
|
||||
input_layout("input0", get_input_layout(p, 0)),
|
||||
input_layout("input1", get_input_layout(p, 1)),
|
||||
permute("permute0", input_info("input0"), {0, 2, 1, 3}),
|
||||
permute("permute1", input_info("input1"), {1, 2, 3, 0}),
|
||||
gemm("gemm_prim", { input_info("permute0"), input_info("permute1") }, data_types::f16),
|
||||
reorder("reorder_bfyx", input_info("gemm_prim"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(data_types::f16);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
fusings_gpu, permute_gemm_2in, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
gemm_test_params{CASE_GEMM_2IN_FP16_1, 3, 5},
|
||||
gemm_test_params{CASE_GEMM_2IN_FP16_2, 3, 5},
|
||||
gemm_test_params{CASE_GEMM_2IN_FP16_3, 3, 5},
|
||||
}));
|
||||
|
||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||
|
||||
@@ -172,7 +172,7 @@ public:
|
||||
const auto params_hash = prim_inst->get_impl_params()->hash();
|
||||
|
||||
ASSERT_EQ(primitive_hash, 16293979194373117693UL);
|
||||
ASSERT_EQ(params_hash, 12014408712579440062UL);
|
||||
ASSERT_EQ(params_hash, 3866569467272213453UL);
|
||||
}
|
||||
|
||||
void test_reshape_basic(bool is_caching_test) {
|
||||
|
||||
Reference in New Issue
Block a user