From 2514c0ef38b8a46f515dbe1dd95355bd2a7652ad Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Thu, 16 Dec 2021 15:20:28 +0300 Subject: [PATCH] [GPU] Add gemm_tiled_opt i8/u8 output support (#9202) --- .../gemm/gemm_kernel_tiled_opt.cpp | 21 ++++++++++---- .../core/cl_kernels/gemm_tiled_opt.cl | 24 +++++++-------- .../tests/test_cases/fusings_gpu_test.cpp | 29 +++++++++++++++++++ 3 files changed, 57 insertions(+), 17 deletions(-) diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_tiled_opt.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_tiled_opt.cpp index 93df406663c..9f77050b46d 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_tiled_opt.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_tiled_opt.cpp @@ -13,6 +13,8 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const { k.EnableInputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); k.EnableInputLayout(DataLayout::bfyx); k.EnableOutputLayout(DataLayout::bfyx); k.EnableInputLayout(DataLayout::bfzyx); @@ -21,6 +23,7 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const { k.EnableOutputLayout(DataLayout::bfwzyx); k.EnableBatching(); + k.EnableDifferentTypes(); return k; } @@ -117,25 +120,29 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons if (tuning_data.tile_k_size > tuning_data.simd_size) { jit.AddConstants({ MakeJitConstant("A_VEC_SIZE", tuning_data.tile_k_size / tuning_data.simd_size), - MakeJitConstant("A_FLOATN", std::string("UNIT_TYPE") + toCodeString(tuning_data.tile_k_size / tuning_data.simd_size)), + MakeJitConstant("A_FLOATN", std::string("CAT(INPUT0_TYPE, ") + toCodeString(tuning_data.tile_k_size / tuning_data.simd_size) + ")"), }); } 
else { jit.AddConstants({ MakeJitConstant("A_VEC_SIZE", 1), - MakeJitConstant("A_FLOATN", std::string("UNIT_TYPE")), + MakeJitConstant("A_FLOATN", std::string("INPUT0_TYPE")), }); } if (tuning_data.tile_n_size > tuning_data.simd_size) { jit.AddConstants({ MakeJitConstant("B_VEC_SIZE", b_vec_size), - MakeJitConstant("B_FLOATN", std::string("UNIT_TYPE") + toCodeString(b_vec_size)), + MakeJitConstant("B_FLOATN", std::string("CAT(INPUT1_TYPE, ") + toCodeString(b_vec_size) + ")"), + MakeJitConstant("OUTPUT_TYPE_VEC", std::string("CAT(OUTPUT_TYPE, ") + toCodeString(b_vec_size) + ")"), + MakeJitConstant("ACCUMULATOR_TYPE_VEC", std::string("CAT(ACCUMULATOR_TYPE, ") + toCodeString(b_vec_size) + ")"), }); } else { b_vec_size = 1; jit.AddConstants({ - MakeJitConstant("B_VEC_SIZE", 1), - MakeJitConstant("B_FLOATN", std::string("UNIT_TYPE")), + MakeJitConstant("B_VEC_SIZE", b_vec_size), + MakeJitConstant("B_FLOATN", std::string("INPUT1_TYPE")), + MakeJitConstant("OUTPUT_TYPE_VEC", std::string("OUTPUT_TYPE")), + MakeJitConstant("ACCUMULATOR_TYPE_VEC", std::string("ACCUMULATOR_TYPE")), }); } @@ -183,6 +190,10 @@ bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& o if ((gmm_params.transpose_input0 || gmm_params.transpose_input1) && gemm_leftovers) return false; + for (size_t i = 1; i < gmm_params.inputs.size(); i++) + if (gmm_params.inputs[0].GetDType() != gmm_params.inputs[i].GetDType()) + return false; + return true; } } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_tiled_opt.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_tiled_opt.cl index ae79242b369..cba34cdcf8c 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_tiled_opt.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_tiled_opt.cl @@ -3,7 +3,7 @@ // #include "include/batch_headers/fetch_data.cl" -#include "include/unit_type.cl" +#include 
"include/batch_headers/data_types.cl" #define unroll_for __attribute__((opencl_unroll_hint)) for @@ -14,17 +14,17 @@ #endif // INPUT0_TYPE_SIZE == 4 #if TILE_K > SIMD_WIDTH - #define BLOCK_READ_A(ptr, offset) CAT(UNIT_BLOCK_READ, A_VEC_SIZE)(ptr, offset) + #define BLOCK_READ_A(ptr, offset) BLOCK_READN(INPUT0_TYPE, A_VEC_SIZE, ptr, offset) #else // TILE_K > SIMD_WIDTH - #define BLOCK_READ_A(ptr, offset) UNIT_BLOCK_READ(ptr, offset) + #define BLOCK_READ_A(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset) #endif // TILE_K > SIMD_WIDTH #if TILE_N > SIMD_WIDTH - #define BLOCK_READ_B(ptr, offset) CAT(UNIT_BLOCK_READ, B_VEC_SIZE)(ptr, offset) - #define BLOCK_WRITE_C(ptr, offset, data) CAT(UNIT_BLOCK_WRITE, B_VEC_SIZE)(ptr, offset, data) + #define BLOCK_READ_B(ptr, offset) BLOCK_READN(INPUT1_TYPE, B_VEC_SIZE, ptr, offset) + #define BLOCK_WRITE_C(ptr, offset, data) BLOCK_WRITEN(OUTPUT_TYPE, B_VEC_SIZE, ptr, offset, data) #else // TILE_N > SIMD_WIDTH - #define BLOCK_READ_B(ptr, offset) UNIT_BLOCK_READ(ptr, offset) - #define BLOCK_WRITE_C(ptr, offset, data) UNIT_BLOCK_WRITE(ptr, offset, data) + #define BLOCK_READ_B(ptr, offset) BLOCK_READN(INPUT1_TYPE, 1, ptr, offset) + #define BLOCK_WRITE_C(ptr, offset, data) BLOCK_WRITEN(OUTPUT_TYPE, 1, ptr, offset, data) #endif // TILE_N > SIMD_WIDTH inline uint FUNC(get_input0_batch_offset)(uint b, uint f, uint w, uint z) { @@ -294,9 +294,9 @@ KERNEL(gemm_tiled_opt)( #if TILE_N_NOT_DIVISIBLE if (b_raw_global_id < N) { #ifdef INPUT2_TYPE - OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid]; + ACCUMULATOR_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid]; #else // INPUT2_TYPE - OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id]; + ACCUMULATOR_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id]; #endif // INPUT2_TYPE #if HAS_FUSED_OPS @@ -316,9 +316,9 @@ KERNEL(gemm_tiled_opt)( #ifdef 
INPUT2_TYPE B_FLOATN c_val = BLOCK_READ_B(c_ptr, 0); - B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_val; + ACCUMULATOR_TYPE_VEC dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_val; #else // INPUT2_TYPE - B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id]; + ACCUMULATOR_TYPE_VEC dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id]; #endif // INPUT2_TYPE #if HAS_FUSED_OPS @@ -327,7 +327,7 @@ KERNEL(gemm_tiled_opt)( #else // FUSED_OPS_CAN_USE_PRELOAD FUSED_OPS_VEC; #endif // FUSED_OPS_CAN_USE_PRELOAD - B_FLOATN res = FUSED_OPS_RESULT_VEC; + OUTPUT_TYPE_VEC res = FUSED_OPS_RESULT_VEC; BLOCK_WRITE_C(d_ptr, 0, res); #else // HAS_FUSED_OPS BLOCK_WRITE_C(d_ptr, 0, dequantized); diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 962759bdc7c..35e4fe25e08 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -3264,6 +3264,35 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_u8, //gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 }, })); +class gemm_2in_quantize_float_in : public GemmFusingTest {}; +TEST_P(gemm_2in_quantize_float_in, basic) { + auto p = GetParam(); + create_topologies(input_layout("input0", get_input_layout(p, 0)), + input_layout("input1", get_input_layout(p, 1)), + data("in_lo", get_mem(get_per_channel_layout(p), 0)), + data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_lo", get_mem(get_single_element_layout(p), 0)), + data("out_hi", get_mem(get_single_element_layout(p), 255)), + gemm("gemm_prim", { "input0", "input1" }, data_types::f32), + quantize("quantize", "gemm_prim", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8), + reorder("reorder_bfyx", "quantize", 
p.default_format, data_types::f32) + ); + + implementation_desc gemm_impl = { format::bfyx, "gemm_tiled_opt" }; + bo_fused.set_option(build_option::force_implementations({ {"gemm_prim", gemm_impl} })); + + tolerance = 1.0f; + execute(p); +} + +INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_float_in, + ::testing::ValuesIn(std::vector<gemm_test_params>{ + gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 4 }, + gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 }, + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 4 }, + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 4 }, +})); + class gemm_2in_scale : public GemmFusingTest {}; TEST_P(gemm_2in_scale, basic) { auto p = GetParam();