[GPU] Add gemm_tiled_opt i8/u8 output support (#9202)

This commit is contained in:
Sergey Shlyapnikov
2021-12-16 15:20:28 +03:00
committed by GitHub
parent d10e8005c0
commit 2514c0ef38
3 changed files with 57 additions and 17 deletions

View File

@@ -13,6 +13,8 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
@@ -21,6 +23,7 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableBatching();
k.EnableDifferentTypes();
return k;
}
@@ -117,25 +120,29 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
if (tuning_data.tile_k_size > tuning_data.simd_size) {
jit.AddConstants({
MakeJitConstant("A_VEC_SIZE", tuning_data.tile_k_size / tuning_data.simd_size),
MakeJitConstant("A_FLOATN", std::string("UNIT_TYPE") + toCodeString(tuning_data.tile_k_size / tuning_data.simd_size)),
MakeJitConstant("A_FLOATN", std::string("CAT(INPUT0_TYPE, ") + toCodeString(tuning_data.tile_k_size / tuning_data.simd_size) + ")"),
});
} else {
jit.AddConstants({
MakeJitConstant("A_VEC_SIZE", 1),
MakeJitConstant("A_FLOATN", std::string("UNIT_TYPE")),
MakeJitConstant("A_FLOATN", std::string("INPUT0_TYPE")),
});
}
if (tuning_data.tile_n_size > tuning_data.simd_size) {
jit.AddConstants({
MakeJitConstant("B_VEC_SIZE", b_vec_size),
MakeJitConstant("B_FLOATN", std::string("UNIT_TYPE") + toCodeString(b_vec_size)),
MakeJitConstant("B_FLOATN", std::string("CAT(INPUT1_TYPE, ") + toCodeString(b_vec_size) + ")"),
MakeJitConstant("OUTPUT_TYPE_VEC", std::string("CAT(OUTPUT_TYPE, ") + toCodeString(b_vec_size) + ")"),
MakeJitConstant("ACCUMULATOR_TYPE_VEC", std::string("CAT(ACCUMULATOR_TYPE, ") + toCodeString(b_vec_size) + ")"),
});
} else {
b_vec_size = 1;
jit.AddConstants({
MakeJitConstant("B_VEC_SIZE", 1),
MakeJitConstant("B_FLOATN", std::string("UNIT_TYPE")),
MakeJitConstant("B_VEC_SIZE", b_vec_size),
MakeJitConstant("B_FLOATN", std::string("INPUT1_TYPE")),
MakeJitConstant("OUTPUT_TYPE_VEC", std::string("OUTPUT_TYPE")),
MakeJitConstant("ACCUMULATOR_TYPE_VEC", std::string("ACCUMULATOR_TYPE")),
});
}
@@ -183,6 +190,10 @@ bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& o
if ((gmm_params.transpose_input0 || gmm_params.transpose_input1) && gemm_leftovers)
return false;
for (size_t i = 1; i < gmm_params.inputs.size(); i++)
if (gmm_params.inputs[0].GetDType() != gmm_params.inputs[i].GetDType())
return false;
return true;
}
} // namespace kernel_selector

View File

@@ -3,7 +3,7 @@
//
#include "include/batch_headers/fetch_data.cl"
#include "include/unit_type.cl"
#include "include/batch_headers/data_types.cl"
#define unroll_for __attribute__((opencl_unroll_hint)) for
@@ -14,17 +14,17 @@
#endif // INPUT0_TYPE_SIZE == 4
#if TILE_K > SIMD_WIDTH
#define BLOCK_READ_A(ptr, offset) CAT(UNIT_BLOCK_READ, A_VEC_SIZE)(ptr, offset)
#define BLOCK_READ_A(ptr, offset) BLOCK_READN(INPUT0_TYPE, A_VEC_SIZE, ptr, offset)
#else // TILE_K > SIMD_WIDTH
#define BLOCK_READ_A(ptr, offset) UNIT_BLOCK_READ(ptr, offset)
#define BLOCK_READ_A(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset)
#endif // TILE_K > SIMD_WIDTH
#if TILE_N > SIMD_WIDTH
#define BLOCK_READ_B(ptr, offset) CAT(UNIT_BLOCK_READ, B_VEC_SIZE)(ptr, offset)
#define BLOCK_WRITE_C(ptr, offset, data) CAT(UNIT_BLOCK_WRITE, B_VEC_SIZE)(ptr, offset, data)
#define BLOCK_READ_B(ptr, offset) BLOCK_READN(INPUT1_TYPE, B_VEC_SIZE, ptr, offset)
#define BLOCK_WRITE_C(ptr, offset, data) BLOCK_WRITEN(OUTPUT_TYPE, B_VEC_SIZE, ptr, offset, data)
#else // TILE_N > SIMD_WIDTH
#define BLOCK_READ_B(ptr, offset) UNIT_BLOCK_READ(ptr, offset)
#define BLOCK_WRITE_C(ptr, offset, data) UNIT_BLOCK_WRITE(ptr, offset, data)
#define BLOCK_READ_B(ptr, offset) BLOCK_READN(INPUT1_TYPE, 1, ptr, offset)
#define BLOCK_WRITE_C(ptr, offset, data) BLOCK_WRITEN(OUTPUT_TYPE, 1, ptr, offset, data)
#endif // TILE_N > SIMD_WIDTH
inline uint FUNC(get_input0_batch_offset)(uint b, uint f, uint w, uint z) {
@@ -294,9 +294,9 @@ KERNEL(gemm_tiled_opt)(
#if TILE_N_NOT_DIVISIBLE
if (b_raw_global_id < N) {
#ifdef INPUT2_TYPE
OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid];
ACCUMULATOR_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid];
#else // INPUT2_TYPE
OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
ACCUMULATOR_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
#endif // INPUT2_TYPE
#if HAS_FUSED_OPS
@@ -316,9 +316,9 @@ KERNEL(gemm_tiled_opt)(
#ifdef INPUT2_TYPE
B_FLOATN c_val = BLOCK_READ_B(c_ptr, 0);
B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_val;
ACCUMULATOR_TYPE_VEC dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_val;
#else // INPUT2_TYPE
B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
ACCUMULATOR_TYPE_VEC dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
#endif // INPUT2_TYPE
#if HAS_FUSED_OPS
@@ -327,7 +327,7 @@ KERNEL(gemm_tiled_opt)(
#else // FUSED_OPS_CAN_USE_PRELOAD
FUSED_OPS_VEC;
#endif // FUSED_OPS_CAN_USE_PRELOAD
B_FLOATN res = FUSED_OPS_RESULT_VEC;
OUTPUT_TYPE_VEC res = FUSED_OPS_RESULT_VEC;
BLOCK_WRITE_C(d_ptr, 0, res);
#else // HAS_FUSED_OPS
BLOCK_WRITE_C(d_ptr, 0, dequantized);

View File

@@ -3264,6 +3264,35 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_u8,
//gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 },
}));
// Fusing test: GEMM over two float runtime inputs with a fused quantize op
// producing a u8 output — exercises the i8/u8 output support added to the
// gemm_tiled_opt kernel in this commit.
class gemm_2in_quantize_float_in : public GemmFusingTest {};
TEST_P(gemm_2in_quantize_float_in, basic) {
auto p = GetParam();
// Topology: input0/input1 -> gemm (f32 output) -> quantize to u8
// -> reorder back to f32 in the default format for result comparison.
create_topologies(input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
// Quantize input range is per-channel [0, random); output range is the
// scalar [0, 255] with 256 levels, i.e. a full u8 quantization.
data("in_lo", get_mem(get_per_channel_layout(p), 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), 0)),
data("out_hi", get_mem(get_single_element_layout(p), 255)),
gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
quantize("quantize", "gemm_prim", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
// Force the tiled-opt GEMM implementation so the new kernel path (not a
// reference fallback) is the one under test.
implementation_desc gemm_impl = { format::bfyx, "gemm_tiled_opt" };
bo_fused.set_option(build_option::force_implementations({ {"gemm_prim", gemm_impl} }));
// Tolerance of 1 — presumably to absorb one quantization step of rounding
// difference in the u8 result; TODO confirm against GemmFusingTest::execute.
tolerance = 1.0f;
execute(p);
}
// Run the test over FP16/FP32 cases, with and without a fused eltwise,
// using the shared CASE_GEMM_* parameter tables.
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_float_in,
::testing::ValuesIn(std::vector<gemm_test_params>{
gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 4 },
gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 },
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 4 },
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 4 },
}));
class gemm_2in_scale : public GemmFusingTest {};
TEST_P(gemm_2in_scale, basic) {
auto p = GetParam();