[GPU] Add gemm_tiled_opt i8/u8 output support (#9202)
This commit is contained in:
committed by
GitHub
parent
d10e8005c0
commit
2514c0ef38
@@ -13,6 +13,8 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableInputLayout(DataLayout::bfzyx);
|
||||
@@ -21,6 +23,7 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
|
||||
k.EnableOutputLayout(DataLayout::bfwzyx);
|
||||
|
||||
k.EnableBatching();
|
||||
k.EnableDifferentTypes();
|
||||
|
||||
return k;
|
||||
}
|
||||
@@ -117,25 +120,29 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
|
||||
if (tuning_data.tile_k_size > tuning_data.simd_size) {
|
||||
jit.AddConstants({
|
||||
MakeJitConstant("A_VEC_SIZE", tuning_data.tile_k_size / tuning_data.simd_size),
|
||||
MakeJitConstant("A_FLOATN", std::string("UNIT_TYPE") + toCodeString(tuning_data.tile_k_size / tuning_data.simd_size)),
|
||||
MakeJitConstant("A_FLOATN", std::string("CAT(INPUT0_TYPE, ") + toCodeString(tuning_data.tile_k_size / tuning_data.simd_size) + ")"),
|
||||
});
|
||||
} else {
|
||||
jit.AddConstants({
|
||||
MakeJitConstant("A_VEC_SIZE", 1),
|
||||
MakeJitConstant("A_FLOATN", std::string("UNIT_TYPE")),
|
||||
MakeJitConstant("A_FLOATN", std::string("INPUT0_TYPE")),
|
||||
});
|
||||
}
|
||||
|
||||
if (tuning_data.tile_n_size > tuning_data.simd_size) {
|
||||
jit.AddConstants({
|
||||
MakeJitConstant("B_VEC_SIZE", b_vec_size),
|
||||
MakeJitConstant("B_FLOATN", std::string("UNIT_TYPE") + toCodeString(b_vec_size)),
|
||||
MakeJitConstant("B_FLOATN", std::string("CAT(INPUT1_TYPE, ") + toCodeString(b_vec_size) + ")"),
|
||||
MakeJitConstant("OUTPUT_TYPE_VEC", std::string("CAT(OUTPUT_TYPE, ") + toCodeString(b_vec_size) + ")"),
|
||||
MakeJitConstant("ACCUMULATOR_TYPE_VEC", std::string("CAT(ACCUMULATOR_TYPE, ") + toCodeString(b_vec_size) + ")"),
|
||||
});
|
||||
} else {
|
||||
b_vec_size = 1;
|
||||
jit.AddConstants({
|
||||
MakeJitConstant("B_VEC_SIZE", 1),
|
||||
MakeJitConstant("B_FLOATN", std::string("UNIT_TYPE")),
|
||||
MakeJitConstant("B_VEC_SIZE", b_vec_size),
|
||||
MakeJitConstant("B_FLOATN", std::string("INPUT1_TYPE")),
|
||||
MakeJitConstant("OUTPUT_TYPE_VEC", std::string("OUTPUT_TYPE")),
|
||||
MakeJitConstant("ACCUMULATOR_TYPE_VEC", std::string("ACCUMULATOR_TYPE")),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -183,6 +190,10 @@ bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& o
|
||||
if ((gmm_params.transpose_input0 || gmm_params.transpose_input1) && gemm_leftovers)
|
||||
return false;
|
||||
|
||||
for (size_t i = 1; i < gmm_params.inputs.size(); i++)
|
||||
if (gmm_params.inputs[0].GetDType() != gmm_params.inputs[i].GetDType())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel_selector
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "include/batch_headers/fetch_data.cl"
|
||||
#include "include/unit_type.cl"
|
||||
#include "include/batch_headers/data_types.cl"
|
||||
|
||||
#define unroll_for __attribute__((opencl_unroll_hint)) for
|
||||
|
||||
@@ -14,17 +14,17 @@
|
||||
#endif // INPUT0_TYPE_SIZE == 4
|
||||
|
||||
#if TILE_K > SIMD_WIDTH
|
||||
#define BLOCK_READ_A(ptr, offset) CAT(UNIT_BLOCK_READ, A_VEC_SIZE)(ptr, offset)
|
||||
#define BLOCK_READ_A(ptr, offset) BLOCK_READN(INPUT0_TYPE, A_VEC_SIZE, ptr, offset)
|
||||
#else // TILE_K > SIMD_WIDTH
|
||||
#define BLOCK_READ_A(ptr, offset) UNIT_BLOCK_READ(ptr, offset)
|
||||
#define BLOCK_READ_A(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset)
|
||||
#endif // TILE_K > SIMD_WIDTH
|
||||
|
||||
#if TILE_N > SIMD_WIDTH
|
||||
#define BLOCK_READ_B(ptr, offset) CAT(UNIT_BLOCK_READ, B_VEC_SIZE)(ptr, offset)
|
||||
#define BLOCK_WRITE_C(ptr, offset, data) CAT(UNIT_BLOCK_WRITE, B_VEC_SIZE)(ptr, offset, data)
|
||||
#define BLOCK_READ_B(ptr, offset) BLOCK_READN(INPUT1_TYPE, B_VEC_SIZE, ptr, offset)
|
||||
#define BLOCK_WRITE_C(ptr, offset, data) BLOCK_WRITEN(OUTPUT_TYPE, B_VEC_SIZE, ptr, offset, data)
|
||||
#else // TILE_N > SIMD_WIDTH
|
||||
#define BLOCK_READ_B(ptr, offset) UNIT_BLOCK_READ(ptr, offset)
|
||||
#define BLOCK_WRITE_C(ptr, offset, data) UNIT_BLOCK_WRITE(ptr, offset, data)
|
||||
#define BLOCK_READ_B(ptr, offset) BLOCK_READN(INPUT1_TYPE, 1, ptr, offset)
|
||||
#define BLOCK_WRITE_C(ptr, offset, data) BLOCK_WRITEN(OUTPUT_TYPE, 1, ptr, offset, data)
|
||||
#endif // TILE_N > SIMD_WIDTH
|
||||
|
||||
inline uint FUNC(get_input0_batch_offset)(uint b, uint f, uint w, uint z) {
|
||||
@@ -294,9 +294,9 @@ KERNEL(gemm_tiled_opt)(
|
||||
#if TILE_N_NOT_DIVISIBLE
|
||||
if (b_raw_global_id < N) {
|
||||
#ifdef INPUT2_TYPE
|
||||
OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid];
|
||||
ACCUMULATOR_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid];
|
||||
#else // INPUT2_TYPE
|
||||
OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
|
||||
ACCUMULATOR_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
|
||||
#endif // INPUT2_TYPE
|
||||
|
||||
#if HAS_FUSED_OPS
|
||||
@@ -316,9 +316,9 @@ KERNEL(gemm_tiled_opt)(
|
||||
|
||||
#ifdef INPUT2_TYPE
|
||||
B_FLOATN c_val = BLOCK_READ_B(c_ptr, 0);
|
||||
B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_val;
|
||||
ACCUMULATOR_TYPE_VEC dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_val;
|
||||
#else // INPUT2_TYPE
|
||||
B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
|
||||
ACCUMULATOR_TYPE_VEC dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
|
||||
#endif // INPUT2_TYPE
|
||||
|
||||
#if HAS_FUSED_OPS
|
||||
@@ -327,7 +327,7 @@ KERNEL(gemm_tiled_opt)(
|
||||
#else // FUSED_OPS_CAN_USE_PRELOAD
|
||||
FUSED_OPS_VEC;
|
||||
#endif // FUSED_OPS_CAN_USE_PRELOAD
|
||||
B_FLOATN res = FUSED_OPS_RESULT_VEC;
|
||||
OUTPUT_TYPE_VEC res = FUSED_OPS_RESULT_VEC;
|
||||
BLOCK_WRITE_C(d_ptr, 0, res);
|
||||
#else // HAS_FUSED_OPS
|
||||
BLOCK_WRITE_C(d_ptr, 0, dequantized);
|
||||
|
||||
@@ -3264,6 +3264,35 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_u8,
|
||||
//gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 },
|
||||
}));
|
||||
|
||||
// Fusing test fixture: a two-input gemm ("gemm_prim") followed by a quantize
// primitive and a final reorder, with the gemm forced onto the
// "gemm_tiled_opt" implementation.
class gemm_2in_quantize_float_in : public GemmFusingTest {};
|
||||
TEST_P(gemm_2in_quantize_float_in, basic) {
|
||||
auto p = GetParam();
|
||||
// Topology: input0, input1 -> gemm_prim -> quantize -> reorder_bfyx.
create_topologies(input_layout("input0", get_input_layout(p, 0)),
|
||||
input_layout("input1", get_input_layout(p, 1)),
|
||||
// Quantize range inputs: in_lo/in_hi use the per-channel layout,
// out_lo/out_hi use the single-element layout (0 and 255).
data("in_lo", get_mem(get_per_channel_layout(p), 0)),
|
||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), 0)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 255)),
|
||||
// NOTE(review): data_types::f32 is presumably the gemm output type -- confirm
// against the gemm primitive constructor.
gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
|
||||
// NOTE(review): 256 / data_types::u8 are presumably the quantization levels
// and the quantize output type -- confirm against the quantize primitive.
quantize("quantize", "gemm_prim", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
|
||||
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
// Force "gemm_prim" to use the gemm_tiled_opt implementation with bfyx format.
// The option is set on bo_fused only.
implementation_desc gemm_impl = { format::bfyx, "gemm_tiled_opt" };
|
||||
bo_fused.set_option(build_option::force_implementations({ {"gemm_prim", gemm_impl} }));
|
||||
|
||||
// Comparison tolerance used by execute(); presumably 1.0 to allow a
// one-level difference in the quantized output -- TODO confirm.
tolerance = 1.0f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
// Instantiates gemm_2in_quantize_float_in for FP16 and FP32 two-input gemm
// cases, in both the plain (CASE_GEMM_2IN_*) and eltwise
// (CASE_GEMM_ELTWISE_2IN_*) variants.
// NOTE(review): the trailing "3, 4" fill the remaining gemm_test_params
// fields; presumably expected primitive counts for the fused / non-fused
// networks -- confirm against the gemm_test_params definition.
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_float_in,
|
||||
::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 4 },
|
||||
gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 },
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 4 },
|
||||
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 4 },
|
||||
}));
|
||||
|
||||
class gemm_2in_scale : public GemmFusingTest {};
|
||||
TEST_P(gemm_2in_scale, basic) {
|
||||
auto p = GetParam();
|
||||
|
||||
Reference in New Issue
Block a user