Optimize permute-gemm (oneDNN) pattern (#17621)

* [GPU] Optimized out permute in the permute-gemm (oneDNN) pattern.

A permute can be optimized out when its input and output layouts are compatible and the adjacent gemm runs on the oneDNN implementation.

Signed-off-by: hyunback <hyunback.kim@intel.com>
hyunback kim
2023-06-07 16:20:59 +09:00
committed by GitHub
parent 3a1326fb58
commit 13028397b7
15 changed files with 243 additions and 17 deletions
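
As an illustration (not part of the commit), the pattern this change targets looks roughly like the following in clDNN API terms; primitive ids, shapes, and header paths are assumptions based on the fusion tests added at the end of this diff.

#include <intel_gpu/graph/topology.hpp>           // header paths assumed
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/permute.hpp>
#include <intel_gpu/primitives/gemm.hpp>

cldnn::topology build_permute_gemm(const cldnn::layout& in0, const cldnn::layout& in1) {
    using namespace cldnn;
    topology t;
    t.add(input_layout("input0", in0));
    t.add(input_layout("input1", in1));
    // Order {0, 2, 1, 3} swaps the f and y axes of a bfyx tensor; this is exactly
    // the dimension order of format::byfx, so when the gemm runs through oneDNN the
    // permute can be expressed as a preferred-format change instead of a copy kernel.
    t.add(permute("permute0", input_info("input0"), {0, 2, 1, 3}));
    t.add(gemm("gemm_prim", { input_info("permute0"), input_info("input1") }, data_types::f16));
    return t;
}

With this change, the layout optimizer marks such a permute as can_be_optimized and feeds the gemm with a byfx/bxfy preferred format; the same applies to a permute that consumes the gemm output.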

View File

@@ -82,6 +82,8 @@ struct format {
byxf, ///< used in bitmaps, input from user i.e b images of RGB format
fyxb, ///< format not used inside clDNN, but supported in reorder as extension
bzyxf,
byfx, ///< To be used when oneDNN gemm allows permute fusing in transformer networks. Not for normal use from cldnn.
bxfy, ///< To be used when oneDNN gemm allows permute fusing in transformer networks. Not for normal use from cldnn.
///< for user provided formats.
b_fs_yx_fsv2,
b_fs_zyx_fsv2,
@@ -129,6 +131,8 @@ struct format {
iozyx, ///< 3D weights format for deconvolutions
iyxo,
oyxi,
oyix,
oxiy,
os_iyx_osv16, ///< format used only for convolution weights
o_is_yx_isv16, ///< format used only for convolution weights
os_yxi_osv16, ///< format used only for convolution weights
@@ -331,6 +335,7 @@ struct format {
/// @brief Checks if @p format is simple data format
static bool is_simple_data_format(type fmt) {
return (fmt == yxfb || fmt == byxf ||
fmt == byfx || fmt == bxfy ||
fmt == bfyx || fmt == fyxb ||
fmt == bfzyx || fmt == bfwzyx ||
fmt == bfuwzyx || fmt == bfvuwzyx);

View File

@@ -207,6 +207,10 @@ kernel_selector::data_layout to_data_layout(format f) {
return kernel_selector::data_layout::yxfb;
case format::byxf:
return kernel_selector::data_layout::byxf;
case format::byfx:
return kernel_selector::data_layout::byfx;
case format::bxfy:
return kernel_selector::data_layout::bxfy;
case format::fyxb:
return kernel_selector::data_layout::fyxb;
case format::b_fs_yx_fsv2:
@@ -302,6 +306,10 @@ cldnn::format from_data_layout(kernel_selector::data_layout l) {
return cldnn::format::yxfb;
case kernel_selector::data_layout::byxf:
return cldnn::format::byxf;
case kernel_selector::data_layout::byfx:
return cldnn::format::byfx;
case kernel_selector::data_layout::bxfy:
return cldnn::format::bxfy;
case kernel_selector::data_layout::fyxb:
return cldnn::format::fyxb;
case kernel_selector::data_layout::b_fs_yx_fsv2:
@@ -381,6 +389,10 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::iyxo;
case format::byxf:
return kernel_selector::weights_layout::oyxi;
case format::byfx:
return kernel_selector::weights_layout::oyix;
case format::bxfy:
return kernel_selector::weights_layout::oxiy;
case format::yxfb:
case format::yxio:
return kernel_selector::weights_layout::yxio;
@@ -648,6 +660,10 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::oiyx;
case kernel_selector::weights_layout::oyxi:
return cldnn::format::oyxi;
case kernel_selector::weights_layout::oyix:
return cldnn::format::oyix;
case kernel_selector::weights_layout::oxiy:
return cldnn::format::oxiy;
case kernel_selector::weights_layout::io:
case kernel_selector::weights_layout::iyxo:
return cldnn::format::iyxo;

View File

@@ -102,9 +102,9 @@ protected:
in1_dims = onednn::convert_gemm_tensor(in1_l.get_tensor(), rank, batched_dims_can_be_removed);
out_dims = onednn::convert_gemm_tensor(out_l.get_tensor(), rank, batched_dims_can_be_removed);
in0_fmt = onednn::convert_gemm_data_format(in0_dims);
in1_fmt = onednn::convert_gemm_data_format(in1_dims);
out_fmt = onednn::convert_gemm_data_format(out_dims);
in0_fmt = onednn::convert_gemm_data_format(in0_dims, in0_l.format);
in1_fmt = onednn::convert_gemm_data_format(in1_dims, in1_l.format);
out_fmt = onednn::convert_gemm_data_format(out_dims, out_l.format);
if (prim->transpose_input0) {
in0_fmt = transpose_format(in0_fmt);
@@ -121,7 +121,7 @@ protected:
auto bias_rank = cldnn::format::dimension(bias_l.format);
bias_dt = onednn::convert_data_type(bias_l.data_type);
bias_dims = onednn::convert_gemm_tensor(bias_l.get_tensor(), bias_rank, batched_dims_can_be_removed);
bias_fmt = onednn::convert_gemm_data_format(bias_dims);
bias_fmt = onednn::convert_gemm_data_format(bias_dims, bias_l.format);
}
}
@@ -321,6 +321,9 @@ attach_gemm_onednn::attach_gemm_onednn() {
};
std::vector<format::type> fmt = {
format::bfyx,
format::byxf,
format::byfx,
format::bxfy,
format::bfzyx,
format::bfwzyx,
};

View File

@@ -80,16 +80,24 @@ dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batche
return res;
}
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) {
switch (dims.size()) {
case 2: return dnnl::memory::format_tag::ab;
case 3: return dnnl::memory::format_tag::abc;
case 4: return dnnl::memory::format_tag::abcd;
default: throw std::invalid_argument("[clDNN] Unsupported conversion from "+ std::to_string(dims.size()) + " to onednn format_tag");
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims, format target) {
if (dims.size() == target.dimension()) {
auto tag = convert_data_format(target);
if (tag != dnnl::memory::format_tag::undef) {
return tag;
} else {
throw std::invalid_argument("[clDNN] Unsupported conversion from "+ target.to_string() + " to onednn format_tag");
}
} else {
switch (dims.size()) {
case 2: return dnnl::memory::format_tag::ab;
case 3: return dnnl::memory::format_tag::abc;
case 4: return dnnl::memory::format_tag::abcd;
default: throw std::invalid_argument("[clDNN] Unsupported conversion from "+ std::to_string(dims.size()) + " to onednn format_tag");
}
}
}
dnnl::memory::dims convert_spatials(cldnn::tensor t, size_t dims) {
auto spatials = t.spatial;
dnnl::memory::dims res(dims);
@@ -118,6 +126,9 @@ std::vector<std::pair<cldnn::format, dnnl::memory::format_tag>> format_map = {
{ cldnn::format::bfyx, dnnl::memory::format_tag::nchw },
{ cldnn::format::bfzyx, dnnl::memory::format_tag::ncdhw },
{ cldnn::format::byxf, dnnl::memory::format_tag::nhwc },
{ cldnn::format::byfx, dnnl::memory::format_tag::acbd },
{ cldnn::format::bxfy, dnnl::memory::format_tag::adbc },
{ cldnn::format::fyxb, dnnl::memory::format_tag::bcda },
{ cldnn::format::bzyxf, dnnl::memory::format_tag::ndhwc },
{ cldnn::format::b_fs_yx_fsv2, dnnl::memory::format_tag::undef },
{ cldnn::format::b_fs_yx_fsv4, dnnl::memory::format_tag::aBcd4b },
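
For reference, the acbd and adbc tags added above line up with byfx and bxfy as follows; convert_gemm_data_format() returns these tags whenever the dims rank matches the target format's rank, and falls back to the plain ab/abc/abcd tags otherwise. The snippet below is a small standalone sanity check against the oneDNN v3 C++ API (assumed), not part of the commit.

#include <oneapi/dnnl/dnnl.hpp>
#include <cstdio>

int main() {
    // Logical dims (b, f, y, x) map to oneDNN letters (a, b, c, d). byfx stores
    // them in memory order b, y, f, x == a, c, b, d, hence format_tag::acbd.
    dnnl::memory::dims dims = {2, 3, 4, 5};  // b=2, f=3, y=4, x=5
    dnnl::memory::desc md(dims, dnnl::memory::data_type::f16, dnnl::memory::format_tag::acbd);
    // Expected strides {60, 5, 15, 1}: x is innermost, then f, then y, then b,
    // which is exactly the byfx memory order.
    for (auto s : md.get_strides())
        std::printf("%lld ", static_cast<long long>(s));
    std::printf("\n");
    return 0;
}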

View File

@@ -29,7 +29,7 @@ dnnl::memory::dims flatten_tensor(cldnn::tensor t);
dnnl::memory::data_type convert_data_type(cldnn::data_types dt);
dnnl::memory::format_tag convert_data_format(cldnn::format fmt);
cldnn::format convert_data_format(dnnl::memory::format_tag fmt);
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims);
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims, format target);
dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::undef, bool flatten = false);
dnnl::algorithm convert_activation_func(cldnn::activation_func func);
std::vector<std::vector<size_t>> get_candidate_orders(dnnl::memory::desc desc);

View File

@@ -1873,9 +1873,70 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
GPU_DEBUG_LOG << "select_preferred_formats:" << node.id() << ": " << fmt_to_str(target_format) << " --> " << fmt_to_str(target_format)
<< " For index : " << idx << std::endl;
}
// Optimize out the permute in the permute-gemm pattern, i.e. permute -> gemm
if (node.is_type<gemm>()) {
// Only the formats below support optimizing out the permute in the permute-gemm pattern. For other formats, gemm performance needs to be checked first.
std::vector<format> gemm_in_format_white_list = {
format::bfyx,
format::fyxb,
format::byfx,
format::bxfy,
};
for (size_t idx = 0 ; idx < node.get_dependencies().size() ; idx++) {
if (node.get_dependency(idx).is_type<permute>()) {
auto& pnode = node.get_dependency(idx);
if (pnode.has_fused_primitives()) {
continue;
}
auto input_lay = pnode.get_dependency(0).get_output_layout();
auto output_lay = pnode.get_output_layout();
if (input_lay.compatible(output_lay)) {
for (auto candidate : gemm_in_format_white_list) {
auto impl_param = pnode.get_kernel_impl_params();
auto desc = impl_param->typed_desc<permute>();
auto permute_order = desc->permute_order;
std::vector<size_t> l_permute_order(std::begin(permute_order), std::end(permute_order));
if (format::traits(static_cast<format::type>(candidate))._order == l_permute_order) {
pnode.init_preferred_fmt(1, 1);
pnode.set_preferred_output_fmt(0, format(static_cast<format::type>(candidate)));
pnode.can_be_optimized(true);
node.set_preferred_input_fmt(idx, format(static_cast<format::type>(candidate)));
break;
}
}
}
}
}
// gemm -> permute
if (node.get_users().size() == 1 && node.get_users().front()->is_type<permute>() && !node.has_fused_primitives()) {
std::vector<format> gemm_out_format_white_list = {
format::bfyx,
format::fyxb,
format::byfx,
};
auto& pnode = node.get_users().front()->as<permute>();
if (!pnode.has_fused_primitives()) {
auto input_lay = pnode.get_dependency(0).get_output_layout();
auto output_lay = pnode.get_output_layout();
if (input_lay.compatible(output_lay)) {
for (auto candidate : gemm_out_format_white_list) {
auto impl_param = pnode.get_kernel_impl_params();
auto desc = impl_param->typed_desc<permute>();
auto permute_order = desc->permute_order;
std::vector<size_t> l_permute_order(std::begin(permute_order), std::end(permute_order));
if (format::traits(static_cast<format::type>(candidate))._order == l_permute_order) {
node.set_preferred_output_fmt(0, format(static_cast<format::type>(candidate)));
pnode.init_preferred_fmt(1, 1);
pnode.set_preferred_input_fmt(0, format(static_cast<format::type>(candidate)));
pnode.can_be_optimized(true);
break;
}
}
}
}
}
}
}
return;
}
#endif // ENABLE_ONEDNN_FOR_GPU
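
The matching rule used in both branches above reduces to comparing the permute order with a candidate format's dimension order. A simplified standalone restatement (not the actual layout_optimizer code; header path and types assumed):

#include <vector>
#include <cstdint>
#include <intel_gpu/runtime/format.hpp>  // location of struct format assumed

static bool permute_matches_format(const std::vector<uint16_t>& permute_order,
                                   cldnn::format::type candidate) {
    std::vector<size_t> order(permute_order.begin(), permute_order.end());
    // format::traits(fmt)._order is the dimension order of the format,
    // e.g. {0, 2, 1, 3} for byfx and {0, 3, 1, 2} for bxfy.
    return cldnn::format::traits(candidate)._order == order;
}

// permute_matches_format({0, 2, 1, 3}, cldnn::format::byfx) == true, so the permute
// node becomes a no-op (can_be_optimized) and byfx is set as the gemm's preferred
// input (permute -> gemm) or output (gemm -> permute) format.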

View File

@@ -1000,7 +1000,7 @@ void program_node::init_onednn_primitive_attributes() {
size_t in_batched_size = in.count() / (in.spatial(0) * in.spatial(1));
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in_batched_size == 1);
dnnl::memory::data_type dt = onednn::convert_data_type(in.data_type);
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format(dims);
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format(dims, in.format);
post_ops.append_binary(alg, dnnl::memory::desc(dims, dt, fmt));
update_onednn_post_op_list(op_type, dep_idx, fmt, false, dims, dt);
} else {

View File

@@ -90,6 +90,8 @@ std::string toString(DataLayout l) {
case kernel_selector::DataLayout::bfyx: return "BFYX";
case kernel_selector::DataLayout::yxfb: return "YXFB";
case kernel_selector::DataLayout::byxf: return "BYXF";
case kernel_selector::DataLayout::byfx: return "BYFX";
case kernel_selector::DataLayout::bxfy: return "BXFY";
case kernel_selector::DataLayout::fyxb: return "FYXB";
case kernel_selector::DataLayout::b_fs_yx_fsv2: return "B_FS_YX_FSV2";
case kernel_selector::DataLayout::b_fs_yx_fsv4: return "B_FS_YX_FSV4";
@@ -297,6 +299,8 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::oiyx: return "OIYX";
case WeightsLayout::ioyx: return "IOYX";
case WeightsLayout::oyxi: return "OYXI";
case WeightsLayout::oyix: return "OYIX";
case WeightsLayout::oxiy: return "OXIY";
case WeightsLayout::iyxo: return "IYXO";
case WeightsLayout::yxio: return "YXIO";
case WeightsLayout::os_is_yx_isv16_osv16: return "OS_IS_YX_ISV16_OSV16";

View File

@@ -301,6 +301,12 @@ std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws, const
case DataLayout::byxf:
layout_order = { f, x, y, b, z, w, u, v };
break;
case DataLayout::byfx:
layout_order = { x, f, y, b, z, w, u, v };
break;
case DataLayout::bxfy:
layout_order = { y, f, x, b, z, w, u, v };
break;
case DataLayout::fyxb:
layout_order = { b, x, y, f, z, w, u, v };
break;

View File

@@ -21,6 +21,8 @@ DataTensor::DataChannelArray DataTensor::dataChannelArray {{
{ DataLayout::bfyx, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::yxfb, { 2, 3, -1, -1, -1, -1, 1, 0 } },
{ DataLayout::byxf, { 1, 2, -1, -1, -1, -1, 0, 3 } },
{ DataLayout::byfx, { 0, 2, -1, -1, -1, -1, 1, 3 } },
{ DataLayout::bxfy, { 2, 0, -1, -1, -1, -1, 1, 3 } },
{ DataLayout::fyxb, { 1, 2, -1, -1, -1, -1, 3, 0 } },
{ DataLayout::b_fs_yx_fsv2, { 0, 1, -1, -1, -1, -1, 2, 3 } },
{ DataLayout::b_fs_yx_fsv4, { 0, 1, -1, -1, -1, -1, 2, 3 } },
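
The two new entries can be read as follows, assuming the channel array is ordered { X, Y, Z, W, U, V, FEATURE, BATCH } and each value is the channel's index counted from the innermost dimension (consistent with the existing bfyx/byxf/fyxb rows). A tiny standalone check, not part of the commit:

#include <array>

int main() {
    // byfx memory order (outer -> inner) is b, y, f, x, so from the innermost dim:
    // x = 0, f = 1, y = 2, b = 3.
    constexpr std::array<int, 8> byfx = {0, 2, -1, -1, -1, -1, 1, 3};  // x, y, z, w, u, v, f, b
    // bxfy memory order is b, x, f, y: y = 0, f = 1, x = 2, b = 3.
    constexpr std::array<int, 8> bxfy = {2, 0, -1, -1, -1, -1, 1, 3};
    static_assert(byfx[0] == 0 && byfx[1] == 2 && byfx[6] == 1 && byfx[7] == 3, "byfx");
    static_assert(bxfy[0] == 2 && bxfy[1] == 0 && bxfy[6] == 1 && bxfy[7] == 3, "bxfy");
    return 0;
}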
@@ -68,9 +70,10 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::io, { -1, -1, -1, 1, 0, -1 } },
{ WeightsLayout::oiyx, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::ioyx, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::oyxi, { 1, 2, -1, 0, 3, -1 } },
{ WeightsLayout::iyxo, { 1, 2, -1, 3, 0, -1 } },
{ WeightsLayout::oyxi, { 1, 2, -1, 0, 3, -1 } },
{ WeightsLayout::oyix, { 0, 2, -1, 1, 3, -1 } },
{ WeightsLayout::oxiy, { 2, 0, -1, 1, 3, -1 } },
{ WeightsLayout::yxio, { 2, 3, -1, 1, 0, -1 } },
{ WeightsLayout::os_iyx_osv16, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_iyx_osv32, { 0, 1, -1, 2, 3, -1 } },

View File

@@ -33,6 +33,8 @@ enum DataLayout {
byxf, // 3D+batch
fyxb, // 3D+batch
bfxy, // 3D+batch
byfx,
bxfy,
b_fs_yx_fsv2,
b_fs_zyx_fsv2,
b_fs_yx_fsv4, // reordering format for swizzled input for convolution using IMAD
@@ -83,6 +85,8 @@ enum WeightsLayout {
oiyx,
ioyx,
oyxi,
oyix,
oxiy,
iyxo,
yxio,
o_is_yx_isv16,
@@ -289,6 +293,8 @@ inline bool SimpleLayout(WeightsLayout l) {
case WeightsLayout::oiyx:
case WeightsLayout::ioyx:
case WeightsLayout::oyxi:
case WeightsLayout::oyix:
case WeightsLayout::oxiy:
case WeightsLayout::iyxo:
case WeightsLayout::yxio:
case WeightsLayout::oizyx:
@@ -307,6 +313,8 @@ inline bool SimpleLayout(DataLayout l) {
case DataLayout::bfyx:
case DataLayout::yxfb:
case DataLayout::byxf:
case DataLayout::byfx:
case DataLayout::bxfy:
case DataLayout::fyxb:
case DataLayout::bfxy:
case DataLayout::bfzyx:

View File

@@ -26,6 +26,8 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(byxf, 1, 1, 2, 0, {0, 2, 3, 1}, "byxf", "bfxy?", {}),
FMT_TRAITS(bfyx, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {}),
FMT_TRAITS(fyxb, 1, 1, 2, 0, {1, 2, 3, 0}, "fyxb", "bfxy?", {}),
FMT_TRAITS(byfx, 1, 1, 2, 0, {0, 2, 1, 3}, "byfx", "bfxy?", {}),
FMT_TRAITS(bxfy, 1, 1, 2, 0, {0, 3, 1, 2}, "bxfy", "bfxy?", {}),
FMT_TRAITS(b_fs_yx_fsv2, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{1, 2}}),
FMT_TRAITS(b_fs_yx_fsv4, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy?", {{1, 4}}),
FMT_TRAITS(b_fs_yx_fsv16, 1, 1, 2, 0, {0, 1, 2, 3}, "bfyx", "bfxy", {{1, 16}}),
@@ -72,6 +74,8 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(ioyx, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {}),
FMT_TRAITS(iyxo, 1, 1, 2, 0, {1, 2, 3, 0}, "iyxo", "oixy", {}),
FMT_TRAITS(oyxi, 1, 1, 2, 0, {0, 2, 3, 1}, "oyxi", "oixy", {}),
FMT_TRAITS(oyix, 1, 1, 2, 0, {0, 2, 1, 3}, "oyix", "oixy", {}),
FMT_TRAITS(oxiy, 1, 1, 2, 0, {0, 3, 1, 2}, "oxiy", "oixy", {}),
FMT_TRAITS(yxio, 1, 1, 2, 0, {2, 3, 1, 0}, "yxio", "oixy?", {}),
FMT_TRAITS(oizyx, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {}),
FMT_TRAITS(iozyx, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {}),

View File

@@ -142,6 +142,10 @@ static format to_weights_format(format f, bool is_grouped) {
return format::iyxo;
case format::byxf:
return format::oyxi;
case format::byfx:
return format::oyix;
case format::bxfy:
return format::oxiy;
case format::yxfb:
return format::yxio;
case format::bfzyx:

View File

@@ -9,6 +9,7 @@
#include <intel_gpu/primitives/quantize.hpp>
#include <intel_gpu/primitives/eltwise.hpp>
#include <intel_gpu/primitives/gemm.hpp>
#include <intel_gpu/primitives/permute.hpp>
#include <intel_gpu/primitives/data.hpp>
#include <intel_gpu/runtime/tensor.hpp>
@@ -526,3 +527,103 @@ INSTANTIATE_TEST_SUITE_P(
// gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_2, 3, 3, "gemm_mmad_int8_slm" }, // tolerance issue
gemm_test_params{CASE_GEMM_ELTWISE_2IN_FP16_2, 3, 3, "gemm_tiled_opt"},
}));
#ifdef ENABLE_ONEDNN_FOR_GPU
class GemmFusingTestOneDNN : public ::BaseFusingTest<gemm_test_params> {
public:
void execute(gemm_test_params& p, bool is_dynamic, bool is_caching_test = false) {
if (!engine.get_device_info().supports_immad)
return;
cfg_not_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));
auto impl_forcing = cfg_fused.get_property(ov::intel_gpu::force_implementations);
auto forcing_format = p.input_format;
for (auto& forcing : impl_forcing)
if (forcing.first == "gemm_prim")
forcing_format = forcing.second.output_format;
ov::intel_gpu::ImplementationDesc gemm_impl = { forcing_format, "", impl_types::onednn };
cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemm_impl } }));
cfg_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));
auto input0_prim = get_mem(get_input_layout(p, 0));
auto input1_prim = get_mem(get_input_layout(p, 1));
network::ptr network_not_fused = get_network(this->engine, this->topology_non_fused, cfg_not_fused, get_test_stream_ptr(), is_caching_test);
network::ptr network_fused = get_network(this->engine, this->topology_fused, cfg_fused, get_test_stream_ptr(), is_caching_test);
network_fused->set_input_data("input0", input0_prim);
network_not_fused->set_input_data("input0", input0_prim);
network_fused->set_input_data("input1", input1_prim);
network_not_fused->set_input_data("input1", input1_prim);
if (p.in_shapes.size() > 2) {
auto input2_prim = get_mem(get_input_layout(p, 2));
network_fused->set_input_data("input2", input2_prim);
network_not_fused->set_input_data("input2", input2_prim);
}
compare(*network_not_fused, *network_fused, p);
}
layout get_input_layout(gemm_test_params& p, int in_no) {
if (in_no == 0)
return layout{ p.in_shapes.at(0), p.data_type_in0, p.input_format };
else if (in_no == 1)
return layout{ p.in_shapes.at(1), p.data_type_in1, p.input_format };
else
return layout{ p.in_shapes.at(2), p.data_type_in2, p.input_format };
}
layout get_per_channel_layout(gemm_test_params& p) {
return layout{ov::PartialShape{ 1, p.in_shapes[0][1], 1, 1 }, p.default_type, p.default_format };
}
layout get_output_layout(gemm_test_params& p) {
return layout{ p.out_shape, p.default_type, p.input_format };
}
};
class gemm_permute_2in : public GemmFusingTestOneDNN {};
TEST_P(gemm_permute_2in, gemm_permute) {
auto p = GetParam();
create_topologies(
input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f16),
permute("permute", input_info("gemm_prim"), {0, 2, 1, 3}),
reorder("reorder_bfyx", input_info("permute"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(data_types::f16);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(
fusings_gpu, gemm_permute_2in, ::testing::ValuesIn(std::vector<gemm_test_params>{
gemm_test_params{CASE_GEMM_2IN_FP16_1, 3, 4},
gemm_test_params{CASE_GEMM_2IN_FP16_2, 3, 4},
gemm_test_params{CASE_GEMM_2IN_FP16_3, 3, 4},
}));
class permute_gemm_2in : public GemmFusingTestOneDNN {};
TEST_P(permute_gemm_2in, permute_gemm) {
auto p = GetParam();
create_topologies(
input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
permute("permute0", input_info("input0"), {0, 2, 1, 3}),
permute("permute1", input_info("input1"), {1, 2, 3, 0}),
gemm("gemm_prim", { input_info("permute0"), input_info("permute1") }, data_types::f16),
reorder("reorder_bfyx", input_info("gemm_prim"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(data_types::f16);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(
fusings_gpu, permute_gemm_2in, ::testing::ValuesIn(std::vector<gemm_test_params>{
gemm_test_params{CASE_GEMM_2IN_FP16_1, 3, 5},
gemm_test_params{CASE_GEMM_2IN_FP16_2, 3, 5},
gemm_test_params{CASE_GEMM_2IN_FP16_3, 3, 5},
}));
#endif // ENABLE_ONEDNN_FOR_GPU

View File

@@ -172,7 +172,7 @@ public:
const auto params_hash = prim_inst->get_impl_params()->hash();
ASSERT_EQ(primitive_hash, 16293979194373117693UL);
ASSERT_EQ(params_hash, 12014408712579440062UL);
ASSERT_EQ(params_hash, 3866569467272213453UL);
}
void test_reshape_basic(bool is_caching_test) {