[GPU] Fix levit-128s accuracy issue (#17136)

* [GPU] Fix levit-128s accuracy issue

Wrong batch dims for fused eltwise of gemm.
-> The issue is getting incorrect batch size of fused eltwise used by gemm.
     Its rank is different from src tensor. Eltwise tensor rank was reduced by mistake.
     It is only reproduce in batch 1 and full tensor. 
     The batch size in here means all of non spatial dims, but previous implementation was default batch dim role.

Signed-off-by: hyunback <hyunback.kim@intel.com>
This commit is contained in:
hyunback kim 2023-04-24 18:16:00 +09:00 committed by GitHub
parent 6ff0cad127
commit 63f5c2f0e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 24 deletions

View File

@ -75,7 +75,7 @@ dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batche
res.erase(res.begin(), res.begin() + dims - 4);
}
if (res.size() == 4 && batched_dims_can_be_removed) {
res.erase(res.begin());
res.erase(res.begin(), res.begin() + 2);
}
return res;
}

View File

@ -991,7 +991,8 @@ void program_node::init_onednn_primitive_attributes() {
mem_desc.get_dims(), mem_desc.get_data_type());
} else if (is_type<gemm>()) {
size_t rank = cldnn::format::dimension(in.format);
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in.batch() == 1);
size_t in_batched_size = in.count() / (in.spatial(0) * in.spatial(1));
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in_batched_size == 1);
dnnl::memory::data_type dt = onednn::convert_data_type(in.data_type);
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format(dims);
post_ops.append_binary(alg, dnnl::memory::desc(dims, dt, fmt));

View File

@ -19,6 +19,11 @@ using namespace ::details;
using namespace ::tests;
namespace {
enum class broadcast_kinds {
none,
batch,
feature
};
struct gemm_test_params {
std::vector<ov::PartialShape> in_shapes;
ov::PartialShape out_shape;
@ -31,7 +36,7 @@ struct gemm_test_params {
size_t expected_fused_primitives;
size_t expected_not_fused_primitives;
std::string kernel_name;
dim_vec_kind broadcast_kind;
broadcast_kinds broadcast_kind;
eltwise_mode eltwise_m;
};
@ -282,9 +287,8 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_scale, ::testing::ValuesIn(std::v
gemm_test_params{ CASE_GEMM_2IN_U8U8_3, 3, 4 },
}));
class gemm_2in_add : public GemmFusingTest {};
TEST_P(gemm_2in_add, eltwise_postop) {
TEST_P(gemm_2in_add, eltwise_postop_static) {
auto p = GetParam();
if (engine.get_device_info().supports_immad) {
@ -294,9 +298,9 @@ TEST_P(gemm_2in_add, eltwise_postop) {
auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_partial_shape();
if (p.broadcast_kind == dim_vec_kind::batch)
if (p.broadcast_kind == broadcast_kinds::batch)
add_data_size[0] = 1;
else
else if (p.broadcast_kind == broadcast_kinds::feature)
add_data_size[1] = 1;
add_data_layout.set_partial_shape(add_data_size);
@ -306,7 +310,7 @@ TEST_P(gemm_2in_add, eltwise_postop) {
create_topologies(
input_layout("input0", in_layout0),
input_layout("input1", in_layout1),
data("add_data", get_mem(add_data_layout, 0.5f)),
data("add_data", get_mem(add_data_layout, 0.5f)), // TODO: Meanless setting, iGPU failed in CASE_GEMM_2IN_FP16_5D_1 with get_mem(add_data_layout, 0, 10)
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
@ -327,9 +331,9 @@ TEST_P(gemm_2in_add, eltwise_postop_dynamic) {
auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_partial_shape();
if (p.broadcast_kind == dim_vec_kind::batch)
if (p.broadcast_kind == broadcast_kinds::batch)
add_data_size[0] = 1;
else
else if (p.broadcast_kind == broadcast_kinds::feature)
add_data_size[1] = 1;
add_data_layout.set_partial_shape(add_data_size);
@ -362,9 +366,9 @@ TEST_P(gemm_2in_add, eltwise_postop_cached) {
auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_partial_shape();
if (p.broadcast_kind == dim_vec_kind::batch)
if (p.broadcast_kind == broadcast_kinds::batch)
add_data_size[0] = 1;
else
else if (p.broadcast_kind == broadcast_kinds::feature)
add_data_size[1] = 1;
add_data_layout.set_partial_shape(add_data_size);
@ -385,18 +389,20 @@ TEST_P(gemm_2in_add, eltwise_postop_cached) {
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vector<gemm_test_params>{
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::batch, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub },
// gemm_test_params{ CASE_GEMM_2IN_FP16_3, 3, 4, "", broadcast_kinds::none, eltwise_mode::sum }, // TODO: check why failed in eltwise_postop_dynamic
gemm_test_params{ CASE_GEMM_2IN_FP16_4, 3, 4, "", broadcast_kinds::none, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", broadcast_kinds::batch, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", broadcast_kinds::batch, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", broadcast_kinds::batch, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", broadcast_kinds::feature, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", broadcast_kinds::feature, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", broadcast_kinds::feature, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", broadcast_kinds::batch, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", broadcast_kinds::batch, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", broadcast_kinds::batch, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", broadcast_kinds::feature, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", broadcast_kinds::feature, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", broadcast_kinds::feature, eltwise_mode::sub },
}));
class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {};