[GPU][DG2] Fix gemm unit tests (#14720)
This commit is contained in:
parent
c593b34654
commit
1c99fa265b
@ -3825,6 +3825,8 @@ public:
|
||||
class onednn_binary_add_full_tensor : public EltwiseSumFusingTestOneDNN {};
|
||||
TEST_P(onednn_binary_add_full_tensor, basic) {
|
||||
auto p = GetParam();
|
||||
if (engine.get_device_info().supports_immad)
|
||||
p.expected_fused_primitives = p.expected_fused_primitives_onednn;
|
||||
|
||||
create_topologies(
|
||||
input_layout("input", get_input_layout(p)),
|
||||
@ -3853,15 +3855,17 @@ TEST_P(onednn_binary_add_full_tensor, basic) {
|
||||
#define CASE_CONV_ELTW_SUM_SUM_DIFF_DTYPE_1 { 1, 32, 4, 4 }, { 1, 16, 4, 4 }, { 1, 1, 3, 3 }, { 1, 1 }, { 1, 1 }, { 1, 1 }, 1, data_types::u8, format::b_fs_yx_fsv32, data_types::i8, format::bfyx, data_types::i8, format::b_fs_yx_fsv32, data_types::u8, format::b_fs_yx_fsv32, data_types::f32, format::bfyx
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(eltwise_sum_fusings_gpu, onednn_binary_add_full_tensor, ::testing::ValuesIn(std::vector<convolution_eltw_sum_test_params>{
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_BINARY_ADD_1, 2, 3, 5 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_SUM_1, 2, 3, 5 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_SUM_DIFF_DTYPE_1, 2, 3, 5 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_BINARY_ADD_1, 2, 4, 5 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_SUM_1, 2, 4, 5 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_SUM_DIFF_DTYPE_1, 2, 4, 5 },
|
||||
}));
|
||||
|
||||
|
||||
class onednn_multiple_binary_add_full_tensor : public EltwiseSumFusingTestOneDNN {};
|
||||
TEST_P(onednn_multiple_binary_add_full_tensor, basic) {
|
||||
auto p = GetParam();
|
||||
if (engine.get_device_info().supports_immad)
|
||||
p.expected_fused_primitives = p.expected_fused_primitives_onednn;
|
||||
|
||||
create_topologies(
|
||||
input_layout("input", get_input_layout(p)),
|
||||
@ -3889,8 +3893,8 @@ TEST_P(onednn_multiple_binary_add_full_tensor, basic) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(multiple_eltwise_sum_fusings_gpu, onednn_multiple_binary_add_full_tensor, ::testing::ValuesIn(std::vector<convolution_eltw_sum_test_params>{
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_BINARY_ADD_1, 2, 3, 7 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_SUM_1, 2, 3, 7 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_BINARY_ADD_1, 2, 4, 7 },
|
||||
convolution_eltw_sum_test_params{ CASE_CONV_ELTW_SUM_SUM_1, 2, 4, 7 },
|
||||
}));
|
||||
|
||||
struct implicit_crop_concat_convolution_test_params {
|
||||
|
@ -139,7 +139,7 @@ TEST_P(gemm_3in_quantize_i8, basic) {
|
||||
reorder("reorder_bfyx", input_info("quantize"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1.0f;
|
||||
tolerance = default_tolerance(data_types::i8);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -173,7 +173,7 @@ TEST_P(gemm_2in_quantize_u8, basic) {
|
||||
reorder("reorder_bfyx", input_info("quantize"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1.0f;
|
||||
tolerance = default_tolerance(data_types::u8);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -210,7 +210,7 @@ TEST_P(gemm_2in_quantize_float_in, basic) {
|
||||
implementation_desc gemm_impl = { format::bfyx, "gemm_tiled_opt" };
|
||||
bo_fused.set_option(build_option::force_implementations({ { "gemm_prim", gemm_impl } }));
|
||||
|
||||
tolerance = 1.0f;
|
||||
tolerance = default_tolerance(data_types::u8);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -239,7 +239,7 @@ TEST_P(gemm_2in_scale, basic) {
|
||||
reorder("reorder_bfyx", input_info("scale"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1e-5f;
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -254,7 +254,7 @@ TEST_P(gemm_2in_scale, fp16_scale_out) {
|
||||
reorder("reorder_bfyx", input_info("scale"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1e-5f;
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -291,7 +291,7 @@ TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
|
||||
reorder("reorder_bfyx", input_info("quantize"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1.0f;
|
||||
tolerance = default_tolerance(data_types::i8);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -329,7 +329,7 @@ TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) {
|
||||
reorder("reorder_bfyx", input_info("sum"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1.0f;
|
||||
tolerance = default_tolerance(data_types::i8);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -358,7 +358,7 @@ TEST_P(gemm_2in_act_scale_eltwise, basic) {
|
||||
if (engine.get_device_info().supports_immad && !p.kernel_name.empty())
|
||||
p.expected_fused_primitives += 2;
|
||||
|
||||
tolerance = 1e-4f;
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -379,7 +379,7 @@ TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
|
||||
if (engine.get_device_info().supports_immad && !p.kernel_name.empty())
|
||||
p.expected_fused_primitives += 2;
|
||||
|
||||
tolerance = 1e-4f;
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
@ -388,7 +388,7 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_act_scale_eltwise, ::testing::Val
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 6 },
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 6 },
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 6 },
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_2, 3, 3 , "gemm_mmad_int8" },
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_2, 3, 3, "gemm_mmad_int8" },
|
||||
// gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_2, 3, 3 , "gemm_mmad_int8_slm" }, // tolerance issue
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_2, 3, 3 , "gemm_tiled_opt" },
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_2, 3, 3, "gemm_tiled_opt" },
|
||||
}));
|
||||
|
Loading…
Reference in New Issue
Block a user