[IE CLDNN] Gemm fp16/fp32 optimized kernel (#1646)

This commit is contained in:
Ilya Znamenskiy
2020-08-11 09:54:00 +03:00
committed by GitHub
parent 2d2a6dbfd8
commit 6cccbcf28a
5 changed files with 937 additions and 97 deletions

View File

@@ -16,12 +16,14 @@
#include "gemm_kernel_selector.h"
#include "gemm_kernel_ref.h"
#include "gemm_kernel_tiled_opt.h"
#include "gemm_kernel_mmad_int8.h"
#include "gemm_kernel_mmad_int8_slm.h"
namespace kernel_selector {
// Registers the GEMM kernel implementations this selector can pick from:
// the reference kernel, the tiled fp16/fp32 kernel and the two MMAD int8 kernels.
// NOTE(review): whether the Attach order influences selection is not visible
// here - keep the order unchanged unless confirmed.
gemm_kernel_selector::gemm_kernel_selector() {
Attach<GemmKernelRef>();
Attach<GemmKernelTiledOpt>();
Attach<GemmKernelMMADint8>();
Attach<GemmKernelMMADslmInt8>();
}

View File

@@ -0,0 +1,195 @@
/*
// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#include "gemm_kernel_tiled_opt.h"
#include <iostream>
namespace kernel_selector {
// Describes what the tiled GEMM kernel accepts: F16/F32 data on inputs and
// output, the bfyx / bfzyx / bfwzyx layouts on inputs and output, and batching.
ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
    ParamsKey key;

    // The same floating-point types are enabled on the input and output side.
    for (const auto data_type : {Datatype::F16, Datatype::F32}) {
        key.EnableInputDataType(data_type);
        key.EnableOutputDataType(data_type);
    }

    // The same layouts are enabled on the input and output side.
    for (const auto layout : {DataLayout::bfyx, DataLayout::bfzyx, DataLayout::bfwzyx}) {
        key.EnableInputLayout(layout);
        key.EnableOutputLayout(layout);
    }

    key.EnableBatching();
    return key;
}
// Computes the dispatch (global / local work sizes) for the tiled kernel.
//   gws0: output width aligned to tile_n, divided by the number of output
//         elements each work-item covers along X (tile_n / simd)
//   gws1: number of tile_m-high row tiles
//   gws2: every output dimension other than X and Y, flattened
//   lws : one group of simd_size work-items along dimension 0
GemmKernelBase::DispatchData GemmKernelTiledOpt::SetDefault(const gemm_params& params) const {
    const auto& out = params.output;
    const GemmTuningData tuning = SetTuningParams(params);

    const size_t out_width = out.X().v;
    const size_t out_height = out.Y().v;
    const size_t batch_count = out.LogicalSize() / (out.X().v * out.Y().v);

    // Output elements handled per work-item along X.
    const auto elements_per_item = tuning.tile_n_size / tuning.simd_size;

    DispatchData dispatch;
    dispatch.gws0 = Align(out_width, tuning.tile_n_size) / elements_per_item;
    dispatch.gws1 = Align(out_height, tuning.tile_m_size) / tuning.tile_m_size;
    dispatch.gws2 = batch_count;

    dispatch.lws0 = tuning.simd_size;
    dispatch.lws1 = 1;
    dispatch.lws2 = 1;

    return dispatch;
}
// Chooses SIMD width and tile sizes for the given GEMM shape.
//
// Fast path (simd 8): tile_m = tile_k = 8 and tile_n is the largest power of
// two in [8, 64] that does not exceed N. It is kept only when M, K and N are
// exact multiples of their tiles, there is a single batch and no input is
// transposed. Every other case falls back to 16x16x16 tiles with simd 16.
//
// The result is also stored in the mutable `tuning_data` member and returned
// by value.
GemmKernelTiledOpt::GemmTuningData GemmKernelTiledOpt::SetTuningParams(const gemm_params& params) const {
    const auto& output = params.output;

    const auto m_size = output.Y().v;
    const auto n_size = output.X().v;
    const auto k_size = params.transpose_input0 ? params.inputs[0].Y().v : params.inputs[0].X().v;
    const auto total_batches = output.LogicalSize() / (output.X().v * output.Y().v);

    // Build the result in a local so that no value from a previous call can leak in.
    GemmTuningData td;
    td.simd_size = 8;

    // tile_n must always be (re)initialised. Previously it was only assigned when
    // n_size >= 8, so for narrow outputs it kept its default of 1 (or a stale value),
    // and tile_n_size / simd_size in SetDefault() could become a division by zero.
    // Starting from simd_size makes N < 8 register as a leftover case below.
    td.tile_n_size = td.simd_size;
    while (td.tile_n_size < 64 && n_size / (td.tile_n_size * 2) >= 1) {
        td.tile_n_size *= 2;
    }

    // tile_k_size must be the same as simd_size when k % tile_k != 0
    td.tile_k_size = td.simd_size;
    td.tile_m_size = 8;

    const bool leftovers = m_size % td.tile_m_size != 0 ||
                           k_size % td.tile_k_size != 0 ||
                           n_size % td.tile_n_size != 0;

    if (leftovers || total_batches > 1 || params.transpose_input0 || params.transpose_input1) {
        td.simd_size = 16;
        td.tile_n_size = td.simd_size;
        td.tile_k_size = td.simd_size;
        td.tile_m_size = 16;
    }

    // Publish to the (mutable) member for code that still reads it.
    tuning_data = td;
    return td;
}
// Produces the JIT constants consumed by the gemm_tiled_opt kernel: matrix sizes
// (M, K, N), SIMD width, tile sizes, leftover information for each dimension,
// the vector types used to load A and B, and the fused-ops configurations.
JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) const {
    JitConstants jit = Parent::GetJitConstants(params);

    // Derive the tuning data from `params` here instead of reading the mutable
    // member, so the constants never depend on which call ran before this one.
    const GemmTuningData td = SetTuningParams(params);

    const auto& output = params.output;
    const auto m_size = output.Y().v;
    const auto n_size = output.X().v;
    const auto k_size = params.transpose_input0 ? params.inputs[0].Y().v : params.inputs[0].X().v;

    const auto leftover_m = m_size % td.tile_m_size;
    const auto leftover_n = n_size % td.tile_n_size;
    const auto leftover_k = k_size % td.tile_k_size;

    jit.Merge(MakeTypeJitConstants(params.inputs[0].GetDType(), "ACCUMULATOR"));

    jit.AddConstants({
        MakeJitConstant("M", m_size),
        MakeJitConstant("K", k_size),
        MakeJitConstant("N", n_size),
        MakeJitConstant("SIMD_WIDTH", td.simd_size),
        MakeJitConstant("TILE_M", td.tile_m_size),
        MakeJitConstant("TILE_K", td.tile_k_size),
        MakeJitConstant("TILE_N", td.tile_n_size),
        MakeJitConstant("K_FULL_ITERATIONS", k_size / td.tile_k_size),
        MakeJitConstant("TILE_M_NOT_DIVISIBLE", leftover_m != 0),
        MakeJitConstant("TILE_K_NOT_DIVISIBLE", leftover_k != 0),
        MakeJitConstant("TILE_N_NOT_DIVISIBLE", leftover_n != 0),
        MakeJitConstant("TILE_M_LEFTOVER", leftover_m),
        MakeJitConstant("TILE_K_LEFTOVER", leftover_k),
        MakeJitConstant("TILE_N_LEFTOVER", leftover_n),
    });

    // A tile wider than the sub-group is loaded as a vector per work-item;
    // otherwise a scalar UNIT_TYPE is used and the vector size is reported as 1.
    const std::string unit_type = "UNIT_TYPE";

    const bool a_is_vector = td.tile_k_size > td.simd_size;
    const size_t a_vec_size = a_is_vector ? td.tile_k_size / td.simd_size : 1;
    jit.AddConstants({
        MakeJitConstant("A_VEC_SIZE", a_vec_size),
        MakeJitConstant("A_FLOATN", a_is_vector ? unit_type + std::to_string(a_vec_size) : unit_type),
    });

    const bool b_is_vector = td.tile_n_size > td.simd_size;
    const size_t b_vec_size = b_is_vector ? td.tile_n_size / td.simd_size : 1;
    jit.AddConstants({
        MakeJitConstant("B_VEC_SIZE", b_vec_size),
        MakeJitConstant("B_FLOATN", b_is_vector ? unit_type + std::to_string(b_vec_size) : unit_type),
    });

    if (!params.fused_ops.empty()) {
        auto input_dt = GetActivationType(params);

        // "_VEC" is used by the kernel when whole B_FLOATN vectors are written,
        // "_SCALAR" when results are written one element per work-item.
        FusedOpsConfiguration conf_vec = { "_VEC", {"b", "f", "(y + write_id)", "x"},
                                           "dequantized",
                                           input_dt,
                                           b_vec_size,
                                           LoadType::LT_ALIGNED_READ,
                                           BoundaryCheck::ENABLED,
                                           IndexType::TENSOR_COORD,
                                           Tensor::DataChannelName::Y };
        FusedOpsConfiguration conf_scalar = { "_SCALAR", {"b", "f", "(y + write_id)", "x"},
                                              "dequantized",
                                              input_dt,
                                              1,
                                              LoadType::LT_ALIGNED_READ,
                                              BoundaryCheck::ENABLED,
                                              IndexType::TENSOR_COORD,
                                              Tensor::DataChannelName::Y };
        jit.Merge(MakeFusedOpsJitConstants(params, { conf_vec, conf_scalar }));
    }

    return jit;
}
// Builds the kernel data through the common GEMM path. The forced priority is
// FORCE_PRIORITY_6 when either input is transposed and FORCE_PRIORITY_3 otherwise.
// NOTE(review): how the priority values are ranked is defined outside this file.
KernelsData GemmKernelTiledOpt::GetKernelsData(const Params& params, const optional_params& options) const {
    const auto& gemm = static_cast<const gemm_params&>(params);
    const bool any_transposed = gemm.transpose_input0 || gemm.transpose_input1;
    const auto priority = any_transposed ? FORCE_PRIORITY_6 : FORCE_PRIORITY_3;

    return GetCommonKernelsData(params, options, priority);
}
bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& options) const {
if (!Parent::Validate(params, options))
return false;
const auto& gmm_params = static_cast<const gemm_params&>(params);
bool gemm_leftovers = gmm_params.inputs[0].X().v % 16 || gmm_params.inputs[0].Y().v % 16 ||
gmm_params.inputs[1].X().v % 16 || gmm_params.inputs[1].Y().v % 16;
if ((gmm_params.transpose_input0 || gmm_params.transpose_input1) && gemm_leftovers)
return false;
return true;
}
} // namespace kernel_selector

View File

@@ -0,0 +1,49 @@
// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "gemm_kernel_base.h"
#include <vector>
namespace kernel_selector {
// Tiled GEMM kernel for F16/F32 data, registered under the name "gemm_tiled_opt".
class GemmKernelTiledOpt : public GemmKernelBase {
public:
    using Parent = GemmKernelBase;

    // SIMD width and tile sizes chosen for a particular GEMM shape.
    struct GemmTuningData {
        size_t simd_size = 8;
        size_t tile_m_size = 1;
        size_t tile_k_size = 1;
        size_t tile_n_size = 1;
    };

    // Most recently computed tuning data; mutable because it is written from
    // const member functions.
    mutable GemmTuningData tuning_data;

    GemmKernelTiledOpt() : GemmKernelBase("gemm_tiled_opt") {}

    KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
    ParamsKey GetSupportedKey() const override;

protected:
    // Fused operations this kernel reports as supported.
    std::vector<FusedOpType> GetSupportedFusedOps() const override {
        return {
            FusedOpType::QUANTIZE,
            FusedOpType::ACTIVATION,
            FusedOpType::SCALE,
            FusedOpType::ELTWISE
        };
    }

    bool Validate(const Params& params, const optional_params& options) const override;
    DispatchData SetDefault(const gemm_params& params) const override;
    JitConstants GetJitConstants(const gemm_params& params) const override;

    // Selects SIMD width and tile sizes for `params`.
    GemmTuningData SetTuningParams(const gemm_params& params) const;
};
} // namespace kernel_selector

View File

@@ -0,0 +1,359 @@
// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/common.cl"
#include "include/fetch.cl"
#include "include/unit_type.cl"
// `for` loop carrying the OpenCL unroll hint attribute.
#define unroll_for __attribute__((opencl_unroll_hint)) for
// BLOCK_SHUFFLE: sub-group shuffle of a whole tile vector.
// When INPUT0_TYPE_SIZE is 4 the intrinsic is used directly; otherwise the data
// is reinterpreted as short16, shuffled, and reinterpreted back as half16.
// NOTE(review): the non-4-byte variant hard-codes 16-element vectors, so it
// presumably requires SIMD_WIDTH == 16 - confirm against the host-side tuning.
#if INPUT0_TYPE_SIZE == 4
#define BLOCK_SHUFFLE intel_sub_group_shuffle
#else // INPUT0_TYPE_SIZE == 4
#define BLOCK_SHUFFLE(data, sg_lid) as_half16(intel_sub_group_shuffle(as_short16(data), sg_lid))
#endif // INPUT0_TYPE_SIZE == 4
// BLOCK_READ_A: block read for matrix A; uses the A_VEC_SIZE-wide variant when
// the K tile is wider than the sub-group, the scalar variant otherwise.
#if TILE_K > SIMD_WIDTH
#define BLOCK_READ_A(ptr, offset) CAT(UNIT_BLOCK_READ, A_VEC_SIZE)(ptr, offset)
#else // TILE_K > SIMD_WIDTH
#define BLOCK_READ_A(ptr, offset) UNIT_BLOCK_READ(ptr, offset)
#endif // TILE_K > SIMD_WIDTH
// BLOCK_READ_B / BLOCK_WRITE_C: block read for matrix B and block write for the
// result; use the B_VEC_SIZE-wide variants when the N tile is wider than the
// sub-group, the scalar variants otherwise.
#if TILE_N > SIMD_WIDTH
#define BLOCK_READ_B(ptr, offset) CAT(UNIT_BLOCK_READ, B_VEC_SIZE)(ptr, offset)
#define BLOCK_WRITE_C(ptr, offset, data) CAT(UNIT_BLOCK_WRITE, B_VEC_SIZE)(ptr, offset, data)
#else // TILE_N > SIMD_WIDTH
#define BLOCK_READ_B(ptr, offset) UNIT_BLOCK_READ(ptr, offset)
#define BLOCK_WRITE_C(ptr, offset, data) UNIT_BLOCK_WRITE(ptr, offset, data)
#endif // TILE_N > SIMD_WIDTH
// Returns the element index in input 0 of coordinate (b, f, w, z) with the two
// innermost coordinates set to 0, i.e. the start of that batch's matrix.
// Only simple input formats are supported; anything else is a compile error.
// NOTE(review): the _SAFE index macro is defined in the included headers;
// presumably it keeps out-of-range batch coordinates valid - confirm there.
inline uint FUNC(get_input0_batch_offset)(uint b, uint f, uint w, uint z) {
#if INPUT0_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, 0, 0);
#else // INPUT0_SIMPLE
# error gemm_nn_tiled.cl : Unsupported input 0 format
#endif // INPUT0_SIMPLE
}
// Returns the element index in input 1 of coordinate (b, f, w, z) with the two
// innermost coordinates set to 0, i.e. the start of that batch's matrix.
// Only simple input formats are supported; anything else is a compile error.
// NOTE(review): the _SAFE index macro is defined in the included headers;
// presumably it keeps out-of-range batch coordinates valid - confirm there.
inline uint FUNC(get_input1_batch_offset)(uint b, uint f, uint w, uint z) {
#if INPUT1_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, 0, 0);
#else // INPUT1_SIMPLE
# error gemm_nn_tiled.cl : Unsupported input 1 format
#endif // INPUT1_SIMPLE
}
// Only compiled when a third input (INPUT2_TYPE) is present.
#ifdef INPUT2_TYPE
// Returns the element index in input 2 of coordinate (b, f, w, z) with the two
// innermost coordinates set to 0, i.e. the start of that batch's matrix.
// Only simple input formats are supported; anything else is a compile error.
// NOTE(review): the _SAFE index macro is defined in the included headers;
// presumably it keeps out-of-range batch coordinates valid - confirm there.
inline uint FUNC(get_input2_batch_offset)(uint b, uint f, uint w, uint z) {
#if INPUT2_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, 0, 0);
#else // INPUT2_SIMPLE
# error gemm_nn_tiled.cl : Unsupported input 2 format
#endif // INPUT2_SIMPLE
}
#endif // INPUT2_TYPE
// Returns the element index in the output of coordinate (b, f, w, z) with the
// two innermost coordinates set to 0, i.e. the start of that batch's matrix.
// Unlike the input helpers this uses the plain (non-_SAFE) index macro.
// Only simple output formats are supported; anything else is a compile error.
inline uint FUNC(get_output_batch_offset)(uint b, uint f, uint w, uint z) {
#if OUTPUT_SIMPLE
return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, 0, 0);
#else // OUTPUT_SIMPLE
# error gemm_nn_tiled.cl : Unsupported output format
#endif // OUTPUT_SIMPLE
}
// Optimized gemm kernel for fp16/fp32 inputs
//
// Computes output = ALPHA * (A x B) [+ BETA * input2], optionally followed by fused ops.
// Each work-group is a single sub-group of SIMD_WIDTH work-items and produces one
// TILE_M x TILE_N tile of the output for one batch:
//   group id 0  -> tile index along N, group id 1 -> tile index along M,
//   global id 2 -> flattened batch index (decomposed into z, w, f, b below).
// The K dimension is walked in K_FULL_ITERATIONS steps of TILE_K, followed by an
// optional leftover step of TILE_K_LEFTOVER when K is not a multiple of TILE_K.
// The transposed paths shuffle exactly 16 lanes (s0..sf), i.e. they rely on
// SIMD_WIDTH == 16 being selected by the host whenever an input is transposed.
__attribute__((intel_reqd_sub_group_size(SIMD_WIDTH)))
__attribute__((reqd_work_group_size(SIMD_WIDTH, 1, 1)))
KERNEL(gemm_tiled_opt)(
    const __global INPUT0_TYPE* input0,
    const __global INPUT1_TYPE* input1,
#ifdef INPUT2_TYPE
    const __global INPUT2_TYPE* input2,
#endif // INPUT2_TYPE
    __global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
    , FUSED_OPS_DECLS
#endif // HAS_FUSED_OPS_DECLS
    )
{
    const uint tile_n_num = (uint)get_group_id(0);
    const uint tile_m_num = (uint)get_group_id(1);
    // Number of tiles along M (global size of dimension 1), not the tile height.
    const uint tile_m_size = (uint)get_global_size(1);
    const uint tile_m_offset = tile_m_num * TILE_M;
    const uint tile_n_offset = tile_n_num * TILE_N;
    uint batch_number = (uint)get_global_id(2);
    const uint sglid = (uint)get_sub_group_local_id();

    // Setting x and y for fusings indexing
#if B_VEC_SIZE == 1
    const uint x = (uint)get_global_id(0);
#else // B_VEC_SIZE == 1
    const uint x = tile_n_num * SIMD_WIDTH * B_VEC_SIZE;
#endif // B_VEC_SIZE == 1
    uint y = tile_m_offset;

    // The last tile along M may cover fewer than TILE_M rows.
#if TILE_M_NOT_DIVISIBLE
    const uint tile_m_iterations = tile_m_num == (tile_m_size - 1) ? TILE_M_LEFTOVER : TILE_M;
#else // TILE_M_NOT_DIVISIBLE
    const uint tile_m_iterations = TILE_M;
#endif // TILE_M_NOT_DIVISIBLE

    // Decompose the flattened batch index into output coordinates.
    const uint z = batch_number % OUTPUT_SIZE_Z;
    batch_number /= OUTPUT_SIZE_Z;
    const uint w = batch_number % OUTPUT_SIZE_W;
    batch_number /= OUTPUT_SIZE_W;
    const uint f = batch_number % OUTPUT_FEATURE_NUM;
    batch_number /= OUTPUT_FEATURE_NUM;
    const uint b = batch_number % OUTPUT_BATCH_NUM;

    // Batch offsets
    const uint batch_offset_input0 = FUNC_CALL(get_input0_batch_offset)(b, f, w, z);
    const uint batch_offset_input1 = FUNC_CALL(get_input1_batch_offset)(b, f, w, z);
#ifdef INPUT2_TYPE
    const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(b, f, w, z);
#endif // INPUT2_TYPE
    const uint batch_offset_output = FUNC_CALL(get_output_batch_offset)(b, f, w, z);

    // Start pointers offsets
#if !TRANSPOSE_INPUT0
    const __global INPUT0_TYPE* a_ptr = input0 + batch_offset_input0 + tile_m_offset * K;
#else // !TRANSPOSE_INPUT0
    const __global INPUT0_TYPE* a_ptr = input0 + batch_offset_input0 + tile_m_offset;
#endif // !TRANSPOSE_INPUT0
#if !TRANSPOSE_INPUT1
    const __global INPUT1_TYPE* b_ptr = input1 + batch_offset_input1 + tile_n_offset;
#else // !TRANSPOSE_INPUT1
    const __global INPUT1_TYPE* b_ptr = input1 + batch_offset_input1 + tile_n_offset * K;
#endif // !TRANSPOSE_INPUT1
#ifdef INPUT2_TYPE
    const __global INPUT2_TYPE* c_ptr = input2 + batch_offset_input2 + tile_m_offset * N + tile_n_offset;
#endif // INPUT2_TYPE
    __global OUTPUT_TYPE* d_ptr = output + batch_offset_output + tile_m_offset * N + tile_n_offset;

    // Column of B / of the output handled by this work-item (used for N bounds checks).
    const uint b_raw_global_id = tile_n_offset + sglid;

#if TRANSPOSE_INPUT0
    MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile;
#endif // TRANSPOSE_INPUT0
#if !TRANSPOSE_INPUT1
    B_FLOATN b_tile[TILE_K];
#else // !TRANSPOSE_INPUT1
    MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile;
#endif // !TRANSPOSE_INPUT1
    B_FLOATN c_tile[TILE_M];

    // Accumulators start at zero.
    unroll_for (uint i = 0; i < TILE_M; i++) {
        c_tile[i] = (B_FLOATN)(ACCUMULATOR_VAL_ZERO);
    }

    // Full tile calculation
    for (uint k = 0; k < K_FULL_ITERATIONS; k++) {

        // Loading B tile
        unroll_for (uint b_load_id = 0; b_load_id < TILE_K; b_load_id++) {
#if TILE_N_NOT_DIVISIBLE
            // Columns past N contribute zero instead of being read.
            b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else // TILE_N_NOT_DIVISIBLE
            b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
#endif // TILE_N_NOT_DIVISIBLE
#if !TRANSPOSE_INPUT1
            b_ptr += N;
#else // !TRANSPOSE_INPUT1
            b_ptr += K;
#endif // !TRANSPOSE_INPUT1
        } // Loading B tile end

#if TRANSPOSE_INPUT1
        // Rewind the row advances done while loading and step SIMD_WIDTH elements along K.
        b_ptr -= K * SIMD_WIDTH - SIMD_WIDTH;

        // B tile shuffling for NT, TT cases
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col0 = BLOCK_SHUFFLE(b_tile, 0);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col1 = BLOCK_SHUFFLE(b_tile, 1);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col2 = BLOCK_SHUFFLE(b_tile, 2);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col3 = BLOCK_SHUFFLE(b_tile, 3);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col4 = BLOCK_SHUFFLE(b_tile, 4);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col5 = BLOCK_SHUFFLE(b_tile, 5);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col6 = BLOCK_SHUFFLE(b_tile, 6);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col7 = BLOCK_SHUFFLE(b_tile, 7);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col8 = BLOCK_SHUFFLE(b_tile, 8);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col9 = BLOCK_SHUFFLE(b_tile, 9);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col10 = BLOCK_SHUFFLE(b_tile, 10);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col11 = BLOCK_SHUFFLE(b_tile, 11);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col12 = BLOCK_SHUFFLE(b_tile, 12);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col13 = BLOCK_SHUFFLE(b_tile, 13);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col14 = BLOCK_SHUFFLE(b_tile, 14);
        MAKE_VECTOR_TYPE(INPUT1_TYPE, SIMD_WIDTH) b_tile_col15 = BLOCK_SHUFFLE(b_tile, 15);

        b_tile.s0 = b_tile_col0[sglid]; b_tile.s1 = b_tile_col1[sglid];
        b_tile.s2 = b_tile_col2[sglid]; b_tile.s3 = b_tile_col3[sglid];
        b_tile.s4 = b_tile_col4[sglid]; b_tile.s5 = b_tile_col5[sglid];
        b_tile.s6 = b_tile_col6[sglid]; b_tile.s7 = b_tile_col7[sglid];
        b_tile.s8 = b_tile_col8[sglid]; b_tile.s9 = b_tile_col9[sglid];
        b_tile.sa = b_tile_col10[sglid]; b_tile.sb = b_tile_col11[sglid];
        b_tile.sc = b_tile_col12[sglid]; b_tile.sd = b_tile_col13[sglid];
        b_tile.se = b_tile_col14[sglid]; b_tile.sf = b_tile_col15[sglid];
#endif // TRANSPOSE_INPUT1

        // Loading A tile and tile C calculation
        unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
#if !TRANSPOSE_INPUT0
#if TILE_K_NOT_DIVISIBLE
            A_FLOATN a_read = a_ptr[dot_id * K + sglid];
#else // TILE_K_NOT_DIVISIBLE
            A_FLOATN a_read = BLOCK_READ_A(a_ptr, dot_id * K);
#endif // TILE_K_NOT_DIVISIBLE

            // Each lane's A element is broadcast to the sub-group and multiplied
            // with the matching row of the B tile.
            unroll_for (uint subtile_k_id = 0; subtile_k_id < TILE_K / SIMD_WIDTH; subtile_k_id++) {
                unroll_for (uint simd_local_id = 0; simd_local_id < SIMD_WIDTH; simd_local_id++) {
#if TILE_K > SIMD_WIDTH
                    c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read[subtile_k_id], simd_local_id)),
                                         b_tile[subtile_k_id * SIMD_WIDTH + simd_local_id], c_tile[dot_id]);
#else // TILE_K > SIMD_WIDTH
                    c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_local_id)), b_tile[simd_local_id], c_tile[dot_id]);
#endif // TILE_K > SIMD_WIDTH
                }
            }
#else // !TRANSPOSE_INPUT0
            a_tile[dot_id] = BLOCK_READ_A(a_ptr, dot_id * M);
#endif // !TRANSPOSE_INPUT0
        } // Loading A tile and tile C calculation end

#if !TRANSPOSE_INPUT0
        a_ptr += TILE_K;
#else // !TRANSPOSE_INPUT0
        a_ptr += TILE_K * M;

        // A tile shuffling for TN, TT cases
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col0 = BLOCK_SHUFFLE(a_tile, 0);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col1 = BLOCK_SHUFFLE(a_tile, 1);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col2 = BLOCK_SHUFFLE(a_tile, 2);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col3 = BLOCK_SHUFFLE(a_tile, 3);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col4 = BLOCK_SHUFFLE(a_tile, 4);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col5 = BLOCK_SHUFFLE(a_tile, 5);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col6 = BLOCK_SHUFFLE(a_tile, 6);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col7 = BLOCK_SHUFFLE(a_tile, 7);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col8 = BLOCK_SHUFFLE(a_tile, 8);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col9 = BLOCK_SHUFFLE(a_tile, 9);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col10 = BLOCK_SHUFFLE(a_tile, 10);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col11 = BLOCK_SHUFFLE(a_tile, 11);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col12 = BLOCK_SHUFFLE(a_tile, 12);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col13 = BLOCK_SHUFFLE(a_tile, 13);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col14 = BLOCK_SHUFFLE(a_tile, 14);
        MAKE_VECTOR_TYPE(INPUT0_TYPE, SIMD_WIDTH) a_tile_col15 = BLOCK_SHUFFLE(a_tile, 15);

        a_tile.s0 = a_tile_col0[sglid]; a_tile.s1 = a_tile_col1[sglid];
        a_tile.s2 = a_tile_col2[sglid]; a_tile.s3 = a_tile_col3[sglid];
        a_tile.s4 = a_tile_col4[sglid]; a_tile.s5 = a_tile_col5[sglid];
        a_tile.s6 = a_tile_col6[sglid]; a_tile.s7 = a_tile_col7[sglid];
        a_tile.s8 = a_tile_col8[sglid]; a_tile.s9 = a_tile_col9[sglid];
        a_tile.sa = a_tile_col10[sglid]; a_tile.sb = a_tile_col11[sglid];
        a_tile.sc = a_tile_col12[sglid]; a_tile.sd = a_tile_col13[sglid];
        a_tile.se = a_tile_col14[sglid]; a_tile.sf = a_tile_col15[sglid];

        // Tile C calculation for TN, TT cases
        unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
            unroll_for (uint simd_local_id = 0; simd_local_id < SIMD_WIDTH; simd_local_id++) {
                c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_tile[dot_id], simd_local_id)), b_tile[simd_local_id], c_tile[dot_id]);
            }
        } // Tile C calculation for TN, TT cases end
#endif // !TRANSPOSE_INPUT0

    } // Full tile calculation end

#if TILE_K_NOT_DIVISIBLE
    // Loading leftovers of the matrix B
    unroll_for (uint b_load_id = 0; b_load_id < TILE_K_LEFTOVER; b_load_id++) {
#if TILE_N_NOT_DIVISIBLE
        b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else // TILE_N_NOT_DIVISIBLE
        b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
#endif // TILE_N_NOT_DIVISIBLE
        b_ptr += N;
    } // Loading leftovers of the matrix B end

    // Loading leftovers of the matrix A and tile C calculation
    unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
        // Only the first TILE_K_LEFTOVER lanes map to valid elements of this row of A.
        // The remaining lanes must not load: for the last row of the buffer the
        // address would lie past the end of input0. Their value is never broadcast.
        INPUT0_TYPE a_read = sglid < TILE_K_LEFTOVER ? a_ptr[dot_id * K + sglid] : (INPUT0_TYPE)0;

        unroll_for (uint simd_id = 0; simd_id < TILE_K_LEFTOVER; simd_id++) {
            c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_id)), b_tile[simd_id], c_tile[dot_id]);
        }
    } // Loading leftovers of the matrix A and tile C calculation end
#endif // TILE_K_NOT_DIVISIBLE

#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
#if TILE_N_NOT_DIVISIBLE
    FUSED_OPS_PRELOAD_SCALAR;
#else // TILE_N_NOT_DIVISIBLE
    FUSED_OPS_PRELOAD_VEC;
#endif // TILE_N_NOT_DIVISIBLE
#endif // HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD

    // Writing result in the global memory
    unroll_for (uint write_id = 0; write_id < tile_m_iterations; write_id++) {
#if TILE_N_NOT_DIVISIBLE
        // Scalar path: each work-item writes one element, skipping columns past N.
        if (b_raw_global_id < N) {
#ifdef INPUT2_TYPE
            OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid];
#else // INPUT2_TYPE
            OUTPUT_TYPE dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
#endif // INPUT2_TYPE

#if HAS_FUSED_OPS
#if FUSED_OPS_CAN_USE_PRELOAD
            FUSED_OPS_CALC_SCALAR;
#else // FUSED_OPS_CAN_USE_PRELOAD
            FUSED_OPS_SCALAR;
#endif // FUSED_OPS_CAN_USE_PRELOAD
            OUTPUT_TYPE res = FUSED_OPS_RESULT_SCALAR;
            d_ptr[sglid] = res;
#else // HAS_FUSED_OPS
            d_ptr[sglid] = dequantized;
#endif // HAS_FUSED_OPS
        }

#else // TILE_N_NOT_DIVISIBLE

        // Vector path: whole B_FLOATN values are block-read / block-written.
#ifdef INPUT2_TYPE
        B_FLOATN c_val = BLOCK_READ_B(c_ptr, 0);
        B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_val;
#else // INPUT2_TYPE
        B_FLOATN dequantized = TO_ACCUMULATOR_TYPE(ALPHA) * c_tile[write_id];
#endif // INPUT2_TYPE

#if HAS_FUSED_OPS
#if FUSED_OPS_CAN_USE_PRELOAD
        FUSED_OPS_CALC_VEC;
#else // FUSED_OPS_CAN_USE_PRELOAD
        FUSED_OPS_VEC;
#endif // FUSED_OPS_CAN_USE_PRELOAD
        B_FLOATN res = FUSED_OPS_RESULT_VEC;
        BLOCK_WRITE_C(d_ptr, 0, res);
#else // HAS_FUSED_OPS
        BLOCK_WRITE_C(d_ptr, 0, dequantized);
#endif // HAS_FUSED_OPS

#endif // TILE_N_NOT_DIVISIBLE

        // Advance to the next output row.
        d_ptr += N;
#ifdef INPUT2_TYPE
        c_ptr += N;
#endif // INPUT2_TYPE
    } // Writing result in the global memory end
}
// Remove the helper macros defined at the top of this kernel file.
#undef unroll_for
#undef BLOCK_SHUFFLE
#undef BLOCK_READ_A
#undef BLOCK_READ_B
#undef BLOCK_WRITE_C

View File

@@ -3237,7 +3237,7 @@ TEST(gemm_gpu, basic_smarcink2) {
}
}
struct gemm_int8_test_params {
struct gemm_base_test_params {
size_t m_size;
size_t n_size;
size_t k_size;
@@ -3253,44 +3253,169 @@ struct gemm_int8_test_params {
bool transpose_input1;
float alpha;
float beta;
cldnn::data_types allocate0_type;
cldnn::data_types allocate1_type;
cldnn::data_types allocate2_type;
cldnn::data_types output_type;
std::vector <int> range0;
std::vector <int> range1;
std::vector <int> range2;
std::string kernel_name;
};
#define CASE_GEMM_INT8_NN_TRANSPOSITION 64, 64, 64, 1, 2, 1, 2, 1, 2, 1, 2, false, false, 1.5f, 2.0f
#define CASE_GEMM_INT8_NT_TRANSPOSITION 32, 64, 32, 2, 1, 2, 1, 2, 1, 2, 1, false, true, 1.7f, 1.3f
#define CASE_GEMM_INT8_TN_TRANSPOSITION 128, 64, 32, 2, 2, 2, 2, 2, 2, 2, 2, true, false, 1.0f, 0.0f
#define CASE_GEMM_INT8_TT_TRANSPOSITION 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.2f, 0.5f
#define CASE_GEMM_INT8_NN_TRANSPOSITION 64, 64, 64, 1, 2, 1, 2, 1, 2, 1, 2, false, false, \
1.5f, 2.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_NT_TRANSPOSITION 32, 64, 32, 2, 1, 2, 1, 2, 1, 2, 1, false, true, \
1.7f, 1.3f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_TN_TRANSPOSITION 128, 64, 32, 2, 2, 2, 2, 2, 2, 2, 2, true, false, \
1.0f, 0.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_TT_TRANSPOSITION 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.2f, 0.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_BROADCAST_1 32, 32, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f
#define CASE_GEMM_INT8_BROADCAST_2 32, 32, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, 1.7f, 1.3f
#define CASE_GEMM_INT8_BROADCAST_3 64, 32, 32, 1, 2, 2, 1, 1, 2, 2, 2, false, false, 1.0f, 1.5f
#define CASE_GEMM_INT8_BROADCAST_4 32, 64, 32, 1, 1, 2, 2, 2, 2, 2, 2, false, false, 1.2f, 0.5f
#define CASE_GEMM_INT8_BROADCAST_1 32, 32, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, \
1.5f, 2.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_BROADCAST_2 32, 32, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, \
1.7f, 1.3f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_BROADCAST_3 64, 32, 32, 1, 2, 2, 1, 1, 2, 2, 2, false, false, \
1.0f, 1.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_BROADCAST_4 32, 64, 32, 1, 1, 2, 2, 2, 2, 2, 2, false, false, \
1.2f, 0.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_1 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
#define CASE_GEMM_INT8_LEFTOVERS_2 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
#define CASE_GEMM_INT8_LEFTOVERS_3 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
#define CASE_GEMM_INT8_LEFTOVERS_4 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
#define CASE_GEMM_INT8_LEFTOVERS_5 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
#define CASE_GEMM_INT8_LEFTOVERS_6 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
#define CASE_GEMM_INT8_LEFTOVERS_7 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
#define CASE_GEMM_INT8_LEFTOVERS_8 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
#define CASE_GEMM_INT8_LEFTOVERS_9 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
#define CASE_GEMM_INT8_LEFTOVERS_10 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
#define CASE_GEMM_INT8_LEFTOVERS_11 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
#define CASE_GEMM_INT8_LEFTOVERS_12 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
#define CASE_GEMM_INT8_LEFTOVERS_1 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.5f, 2.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_2 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.6f, 1.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_3 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.0f, 1.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_4 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.7f, 1.3f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_5 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.5f, 2.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_6 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.6f, 1.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_7 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.0f, 1.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_8 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.7f, 1.3f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_9 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.5f, 2.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_10 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.6f, 1.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_11 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.0f, 1.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_LEFTOVERS_12 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.7f, 1.3f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_COMBO_1 8, 8, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f
#define CASE_GEMM_INT8_COMBO_2 16, 16, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, true, 1.7f, 0.0f
#define CASE_GEMM_INT8_COMBO_3 11, 31, 21, 7, 15, 7, 15, 7, 15, 7, 15, true, false, 1.0f, 1.5f
#define CASE_GEMM_INT8_COMBO_4 32, 32, 32, 3, 6, 3, 6, 3, 6, 3, 6, true, true, 1.2f, 4.0f
#define CASE_GEMM_INT8_COMBO_1 8, 8, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, \
1.5f, 2.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_COMBO_2 16, 16, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, true, \
1.7f, 0.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_COMBO_3 11, 31, 21, 7, 15, 7, 15, 7, 15, 7, 15, true, false, \
1.0f, 1.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_COMBO_4 32, 32, 32, 3, 6, 3, 6, 3, 6, 3, 6, true, true, \
1.2f, 4.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_SLM_COMBO_1 64, 64, 64, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f
#define CASE_GEMM_INT8_SLM_COMBO_2 384, 384, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, 1.7f, 0.0f
#define CASE_GEMM_INT8_SLM_COMBO_3 128, 128, 64, 2, 3, 2, 3, 2, 3, 2, 3, false, false, 1.0f, 1.5f
#define CASE_GEMM_INT8_SLM_COMBO_4 256, 64, 64, 3, 6, 3, 6, 3, 6, 3, 6, false, false, 1.2f, 4.0f
#define CASE_GEMM_INT8_SLM_COMBO_1 64, 64, 64, 1, 2, 1, 1, 1, 1, 1, 2, false, false, \
1.5f, 2.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_SLM_COMBO_2 384, 384, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, \
1.7f, 0.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_SLM_COMBO_3 128, 128, 64, 2, 3, 2, 3, 2, 3, 2, 3, false, false, \
1.0f, 1.5f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
#define CASE_GEMM_INT8_SLM_COMBO_4 256, 64, 64, 3, 6, 3, 6, 3, 6, 3, 6, false, false, \
1.2f, 4.0f, data_types::i8, data_types::u8, data_types::f32, data_types::f32, {-128, 127, 1}, {0, 255, 1}, {-10, 10, 8}
// NOTE(review): the class header on the next two lines is never closed in this listing and is
// immediately followed by macro definitions. It is presumably the pre-change fixture header that
// GemmBaseTest (declared after the macros) replaces — confirm it is absent from the actual file.
template <typename T>
class GemmInt8Test : public ::testing::TestWithParam<T> {
// Parameter packs for the FP32 cases that are run with the "gemm_tiled_opt" kernel below.
// All four data types are f32 and all three value triples are {-10, 10, 8}.
// The NN / NT / TN / TT suffix matches the two booleans in each pack (false,false / false,true /
// true,false / true,true); presumably these are the transpose flags of input0 and input1.
// NOTE(review): the positional meaning of the leading integers depends on the parameter struct,
// which is not visible in this chunk.
#define CASE_GEMM_FP32_TILED_NN_1 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.5f, 2.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NN_2 64, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.7f, 0.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NN_3 31, 47, 65, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.0f, 1.5f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NN_4 65, 31, 47, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.2f, 4.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NT_1 16, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.5f, 2.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NT_2 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.7f, 0.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NT_3 64, 32, 64, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.0f, 1.5f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NT_4 16, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.2f, 4.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TN_1 16, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.5f, 2.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TN_2 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.7f, 0.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TN_3 64, 32, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.0f, 1.5f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TN_4 16, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.2f, 4.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TT_1 16, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.5f, 2.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TT_2 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.7f, 0.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TT_3 64, 32, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.0f, 1.5f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_TT_4 16, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.2f, 4.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
// Broadcast variants: unlike the packs above, the eight counts after the three sizes are not all 1.
#define CASE_GEMM_FP32_TILED_NN_BROADCAST_1 64, 96, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, \
1.5f, 2.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NN_BROADCAST_2 32, 16, 16, 2, 1, 1, 1, 1, 1, 2, 1, false, false, \
1.7f, 0.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NN_BROADCAST_3 5, 1, 3, 1, 2, 2, 1, 1, 2, 2, 2, false, false, \
1.0f, 1.5f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
#define CASE_GEMM_FP32_TILED_NN_BROADCAST_4 64, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, false, false, \
1.2f, 4.0f, data_types::f32, data_types::f32, data_types::f32, data_types::f32, {-10, 10, 8}, {-10, 10, 8}, {-10, 10, 8}
// Parameter packs for the FP16 cases that are run with the "gemm_tiled_opt" kernel below.
// All four data types are f16 and all three value triples are {-1, 1, 1}, i.e. narrower than
// the {-10, 10, 8} triples used by the FP32 packs.
// The NN / NT / TN / TT suffix matches the two booleans in each pack; presumably these are the
// transpose flags of input0 and input1 (the parameter struct is not visible in this chunk).
#define CASE_GEMM_FP16_TILED_NN_1 64, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.5f, 2.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NN_2 128, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.7f, 0.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NN_3 131, 17, 15, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.0f, 1.5f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NN_4 33, 17, 17, 1, 1, 1, 1, 1, 1, 1, 1, false, false, \
1.2f, 4.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NT_1 16, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.5f, 2.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NT_2 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.7f, 0.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NT_3 64, 32, 64, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.0f, 1.5f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NT_4 16, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, false, true, \
1.2f, 4.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TN_1 16, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.5f, 2.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TN_2 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.7f, 0.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TN_3 64, 32, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.0f, 1.5f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TN_4 16, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, false, \
1.2f, 4.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TT_1 16, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.5f, 2.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TT_2 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.7f, 0.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TT_3 64, 32, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.0f, 1.5f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_TT_4 16, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, true, true, \
1.2f, 4.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
// Broadcast variants: unlike the packs above, the eight counts after the three sizes are not all 1.
#define CASE_GEMM_FP16_TILED_NN_BROADCAST_1 64, 96, 128, 1, 2, 1, 1, 1, 1, 1, 2, false, false, \
1.5f, 2.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NN_BROADCAST_2 64, 16, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, \
1.7f, 0.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NN_BROADCAST_3 1, 2, 3, 1, 2, 2, 1, 1, 2, 2, 2, false, false, \
1.0f, 1.5f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
#define CASE_GEMM_FP16_TILED_NN_BROADCAST_4 8, 8, 8, 1, 1, 2, 2, 2, 2, 2, 2, false, false, \
1.2f, 4.0f, data_types::f16, data_types::f16, data_types::f16, data_types::f16, {-1, 1, 1}, {-1, 1, 1}, {-1, 1, 1}
// Value-parameterized gtest fixture shared by the GEMM kernel suites declared after it.
// Template arguments, in order: the parameter struct, the host element types of the three
// inputs, the element type of the host reference output, and the type in which the dot
// products of the reference are accumulated.
// NOTE(review): this listing contains "@@ ... @@" hunk headers and several back-to-back
// old/new line pairs, so it appears to be a diff rendering rather than the final file; parts
// of the class body between hunks are not shown.
template <typename gemm_params, typename input0_type, typename input1_type, typename input2_type, typename output_type, typename accumulator_type>
class GemmBaseTest : public ::testing::TestWithParam<gemm_params> {
public:
// Returns the flat offset of element (x, y, f, b). Each coordinate is wrapped modulo its
// dimension size before it is multiplied by its pitch, so a dimension of size 1 always
// resolves to index 0 regardless of the requested coordinate.
// NOTE(review): the remainder of the parameter list (the pitches used in the return
// expression) is cut off by the hunk header that follows.
inline size_t getGemmIndex(size_t x, size_t y, size_t f, size_t b, size_t x_size, size_t y_size, size_t f_num, size_t b_num,
@@ -3298,7 +3423,7 @@ public:
return (x % x_size) * x_pitch + (y % y_size) * y_pitch + (f % f_num) * f_pitch + (b % b_num) * b_pitch;
}
// Runs a single case: fills three inputs with random values, computes a host reference of
// alpha * (input0 x input1) + beta * input2, executes the GEMM primitive with the kernel named
// in p.kernel_name, and compares the device output with the reference.
// NOTE(review): two signatures appear back to back (T& and gemm_params&); the first presumably
// belongs to the pre-change fixture.
void execute(T& p) {
void execute(gemm_params& p) {
const auto& engine = get_test_engine();
auto y0_size = p.m_size;
@@ -3344,46 +3469,46 @@ public:
}
// Input setup. For each input the listing shows a hard-coded variant (fixed element type,
// literal value range, literal data_types::*) directly followed by a parameterized variant
// that takes the element type from the template arguments, the value range from
// p.range0/1/2 and the device data type from p.allocate0/1/2_type.
auto input0_size = tensor((int)p.b0_num, (int)p.f0_num, (int)x0_size, (int)y0_size);
auto input0_data = generate_random_4d<int8_t>(p.b0_num, p.f0_num, x0_size, y0_size, -128, 127, 1);
VVVVF<input0_type> input0_data = generate_random_4d<input0_type>(p.b0_num, p.f0_num, x0_size, y0_size, p.range0[0], p.range0[1], p.range0[2]);
auto input0_data_bfyx = flatten_4d(format::bfyx, input0_data);
auto input0_mem = memory::allocate(engine, { data_types::i8, format::bfyx, input0_size });
auto input0_mem = memory::allocate(engine, { p.allocate0_type, format::bfyx, input0_size });
set_values(input0_mem, input0_data_bfyx);
auto input1_size = tensor((int)p.b1_num, (int)p.f1_num, (int)x1_size, (int)y1_size);
auto input1_data = generate_random_4d<uint8_t>(p.b1_num, p.f1_num, x1_size, y1_size, 0, 255, 1);
VVVVF<input1_type> input1_data = generate_random_4d<input1_type>(p.b1_num, p.f1_num, x1_size, y1_size, p.range1[0], p.range1[1], p.range1[2]);
auto input1_data_bfyx = flatten_4d(format::bfyx, input1_data);
auto input1_mem = memory::allocate(engine, { data_types::u8, format::bfyx, input1_size });
auto input1_mem = memory::allocate(engine, { p.allocate1_type, format::bfyx, input1_size });
set_values(input1_mem, input1_data_bfyx);
auto input2_size = tensor((int)p.b2_num, (int)p.f2_num, (int)x2_size, (int)y2_size);
auto input2_data = generate_random_4d<float>(p.b2_num, p.f2_num, x2_size, y2_size, -10, 10);
VVVVF<input2_type> input2_data = generate_random_4d<input2_type>(p.b2_num, p.f2_num, x2_size, y2_size, p.range2[0], p.range2[1], p.range2[2]);
auto input2_data_bfyx = flatten_4d(format::bfyx, input2_data);
auto input2_mem = memory::allocate(engine, { data_types::f32, format::bfyx, input2_size });
auto input2_mem = memory::allocate(engine, { p.allocate2_type, format::bfyx, input2_size });
set_values(input2_mem, input2_data_bfyx);
// Host reference buffer: one element per (b, f, m, n) output position.
std::vector<float> out_data(p.b_out_num * p.f_out_num * p.m_size * p.n_size);
std::vector<output_type> out_data(p.b_out_num * p.f_out_num * p.m_size * p.n_size);
// Reference computation. NOTE(review): two equivalent inner loop nests are listed in sequence,
// one indexed by (i, j) with an int32_t accumulator and one indexed by (y, x) with
// accumulator_type; presumably only the second exists in the actual file.
for (size_t b = 0; b < p.b_out_num; ++b) {
for (size_t f = 0; f < p.f_out_num; ++f) {
for (size_t i = 0; i < p.m_size; ++i) {
for (size_t j = 0; j < p.n_size; ++j) {
size_t input2_data_index = getGemmIndex(j, i, f, b, x2_size, y2_size, p.f2_num, p.b2_num, x2_pitch, y2_pitch, f2_pitch, b2_pitch);
size_t out_data_index = getGemmIndex(j, i, f, b, x_out_size, y_out_size, p.f_out_num, p.b_out_num,
x_out_pitch, y_out_pitch, f_out_pitch, b_out_pitch);
int32_t acc = 0;
for (size_t y = 0; y < p.m_size; ++y) {
for (size_t x = 0; x < p.n_size; ++x) {
size_t input2_data_index = getGemmIndex(x, y, f, b, x2_size, y2_size, p.f2_num, p.b2_num, x2_pitch, y2_pitch, f2_pitch, b2_pitch);
size_t out_data_index = getGemmIndex(x, y, f, b, x_out_size, y_out_size, p.f_out_num, p.b_out_num,
x_out_pitch, y_out_pitch, f_out_pitch, b_out_pitch);
accumulator_type acc = 0;
for (size_t k = 0; k < p.k_size; ++k) {
// The transpose flags are used as 0/1 multipliers: they select whether k is passed as the
// x coordinate or as the y coordinate of the input element being read.
size_t input0_data_index = getGemmIndex(k * (!p.transpose_input0) + i * p.transpose_input0, i * (!p.transpose_input0) +
size_t input0_data_index = getGemmIndex(k * (!p.transpose_input0) + y * p.transpose_input0, y * (!p.transpose_input0) +
k * p.transpose_input0, f, b, x0_size, y0_size, p.f0_num, p.b0_num, x0_pitch, y0_pitch, f0_pitch, b0_pitch);
size_t input1_data_index = getGemmIndex(j * (!p.transpose_input1) + k * p.transpose_input1, k * (!p.transpose_input1) +
j * p.transpose_input1, f, b, x1_size, y1_size, p.f1_num, p.b1_num, x1_pitch, y1_pitch, f1_pitch, b1_pitch);
size_t input1_data_index = getGemmIndex(x * (!p.transpose_input1) + k * p.transpose_input1, k * (!p.transpose_input1) +
x * p.transpose_input1, f, b, x1_size, y1_size, p.f1_num, p.b1_num, x1_pitch, y1_pitch, f1_pitch, b1_pitch);
acc += input0_data_bfyx[input0_data_index] * input1_data_bfyx[input1_data_index];
acc += (accumulator_type)input0_data_bfyx[input0_data_index] * (accumulator_type)input1_data_bfyx[input1_data_index];
}
// Epilogue: out = acc * alpha + beta * input2, evaluated in the output element type.
out_data[out_data_index] = (float)acc;
out_data[out_data_index] *= p.alpha;
out_data[out_data_index] += p.beta * input2_data_bfyx[input2_data_index];
out_data[out_data_index] = (output_type)acc;
out_data[out_data_index] *= (output_type)p.alpha;
out_data[out_data_index] += (output_type)p.beta * (output_type)input2_data_bfyx[input2_data_index];
}
}
}
@@ -3393,82 +3518,192 @@ public:
// Topology: three input layouts feed the gemm primitive. In the parameterized variant the
// primitive is named "gemm_bfyx", produces p.output_type, and is followed by a reorder to
// bfyx / f32, which is why the result is read back through a float pointer below.
topology.add(input_layout("input0", input0_mem.get_layout()));
topology.add(input_layout("input1", input1_mem.get_layout()));
topology.add(input_layout("input2", input2_mem.get_layout()));
topology.add(gemm("output", { "input0", "input1", "input2" }, data_types::f32, p.transpose_input0, p.transpose_input1, p.alpha, p.beta));
topology.add(gemm("gemm_bfyx", { "input0", "input1", "input2" }, p.output_type, p.transpose_input0, p.transpose_input1, p.alpha, p.beta));
topology.add(reorder("reorder_bfyx", "gemm_bfyx", format::bfyx, data_types::f32));
// Force the implementation named by p.kernel_name for the gemm primitive.
build_options options;
implementation_desc gemm_int8_impl = { format::bfyx, p.kernel_name };
options.set_option(build_option::force_implementations({ {"output", gemm_int8_impl} }));
implementation_desc gemm_impl = { format::bfyx, p.kernel_name };
options.set_option(build_option::force_implementations({ {"gemm_bfyx", gemm_impl} }));
network network(engine, topology, options);
network.set_input_data("input0", input0_mem);
network.set_input_data("input1", input1_mem);
network.set_input_data("input2", input2_mem);
auto outputs = network.execute();
auto output = outputs.at("output").get_memory();
auto output = outputs.at("reorder_bfyx").get_memory();
auto output_ptr = output.pointer<float>();
EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size));
for (size_t i = 0; i < out_data.size(); ++i) {
EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]);
// The comparison is selected by the byte size of input0_type:
//   1 byte  -> EXPECT_FLOAT_EQ,
//   2 bytes -> EXPECT_NEAR with an absolute tolerance of 1e-1,
//   other   -> EXPECT_NEAR with an absolute tolerance of 1e-4.
// Each failure message carries the flat element index.
if (sizeof(input0_type) == 1) {
for (size_t i = 0; i < out_data.size(); ++i) {
EXPECT_FLOAT_EQ(float(output_ptr[i]), float(out_data[i])) << "index = " << i;
}
} else if (sizeof(input0_type) == 2) {
for (size_t i = 0; i < out_data.size(); ++i) {
EXPECT_NEAR(float(output_ptr[i]), float(out_data[i]), 1e-1) << "index = " << i;
}
} else {
for (size_t i = 0; i < out_data.size(); ++i) {
EXPECT_NEAR(float(output_ptr[i]), float(out_data[i]), 1e-4) << "index = " << i;
}
}
}
};
// Suite that runs the four int8 transposition cases (NN, NT, TN, TT) with the "gemm_mmad_int8"
// kernel. In the GemmBaseTest form the element types are int8_t / uint8_t for input0 / input1,
// float for input2 and the output, and int32_t for the accumulator.
// NOTE(review): the class declaration and the INSTANTIATE_TEST_CASE_P header each appear twice
// (GemmInt8Test / gemm_int8_test_params, then GemmBaseTest / gemm_base_test_params). This looks
// like removed and added diff lines shown together — confirm which set is live.
class gemm_int8_transposition_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
class gemm_int8_transposition_tests : public ::GemmBaseTest<gemm_base_test_params, int8_t, uint8_t, float, float, int32_t> {};
TEST_P(gemm_int8_transposition_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_transposition_tests, ::testing::ValuesIn(std::vector <gemm_int8_test_params> {
gemm_int8_test_params{ CASE_GEMM_INT8_NN_TRANSPOSITION, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_NT_TRANSPOSITION, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_TN_TRANSPOSITION, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_TT_TRANSPOSITION, "gemm_mmad_int8" },
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_transposition_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_INT8_NN_TRANSPOSITION, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_NT_TRANSPOSITION, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_TN_TRANSPOSITION, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_TT_TRANSPOSITION, "gemm_mmad_int8" },
}), );
// Suite that runs the four CASE_GEMM_INT8_BROADCAST_* cases with the "gemm_mmad_int8" kernel.
// NOTE(review): as listed, the fixture declaration and the instantiation header are duplicated
// in an old (GemmInt8Test) and a new (GemmBaseTest) form; presumably only the GemmBaseTest
// lines remain in the actual file — confirm.
class gemm_int8_broadcast_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
class gemm_int8_broadcast_tests : public ::GemmBaseTest<gemm_base_test_params, int8_t, uint8_t, float, float, int32_t> {};
TEST_P(gemm_int8_broadcast_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_broadcast_tests, ::testing::ValuesIn(std::vector <gemm_int8_test_params> {
gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_1, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_2, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_3, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_4, "gemm_mmad_int8" },
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_broadcast_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_INT8_BROADCAST_1, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_BROADCAST_2, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_BROADCAST_3, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_BROADCAST_4, "gemm_mmad_int8" },
}), );
// Suite that runs the twelve CASE_GEMM_INT8_LEFTOVERS_* cases with the "gemm_mmad_int8" kernel.
// NOTE(review): the twelve entries are listed twice, first as gemm_int8_test_params and then as
// gemm_base_test_params, together with two fixture declarations and two instantiation headers.
// This looks like the before/after sides of a diff — confirm which set is live.
class gemm_int8_leftovers_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
class gemm_int8_leftovers_tests : public ::GemmBaseTest<gemm_base_test_params, int8_t, uint8_t, float, float, int32_t> {};
TEST_P(gemm_int8_leftovers_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_leftovers_tests, ::testing::ValuesIn(std::vector <gemm_int8_test_params> {
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_1, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_2, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_3, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_4, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_5, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_6, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_7, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_8, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_9, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_10, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_11, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_12, "gemm_mmad_int8" },
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_leftovers_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_1, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_2, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_3, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_4, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_5, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_6, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_7, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_8, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_9, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_10, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_11, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_LEFTOVERS_12, "gemm_mmad_int8" },
}), );
// Suite that runs CASE_GEMM_INT8_COMBO_1..4 with the "gemm_mmad_int8" kernel.
// NOTE(review): both a GemmInt8Test-based and a GemmBaseTest-based declaration/instantiation
// are listed; presumably the former are removed lines of a diff — confirm.
class gemm_int8_combo_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
class gemm_int8_combo_tests : public ::GemmBaseTest<gemm_base_test_params, int8_t, uint8_t, float, float, int32_t> {};
TEST_P(gemm_int8_combo_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_combo_tests, ::testing::ValuesIn(std::vector <gemm_int8_test_params> {
gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_1, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_2, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_3, "gemm_mmad_int8" },
gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_4, "gemm_mmad_int8" },
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_combo_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_INT8_COMBO_1, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_COMBO_2, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_COMBO_3, "gemm_mmad_int8" },
gemm_base_test_params{ CASE_GEMM_INT8_COMBO_4, "gemm_mmad_int8" },
}), );
// Suite that runs CASE_GEMM_INT8_SLM_COMBO_1..4 with the "gemm_mmad_int8_slm" kernel (the other
// int8 suites use "gemm_mmad_int8").
// NOTE(review): both a GemmInt8Test-based and a GemmBaseTest-based declaration/instantiation
// are listed; presumably the former are removed lines of a diff — confirm.
class gemm_int8_slm_combo_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
class gemm_int8_slm_combo_tests : public ::GemmBaseTest<gemm_base_test_params, int8_t, uint8_t, float, float, int32_t> {};
TEST_P(gemm_int8_slm_combo_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_slm_combo_tests, ::testing::ValuesIn(std::vector <gemm_int8_test_params> {
gemm_int8_test_params{ CASE_GEMM_INT8_SLM_COMBO_1, "gemm_mmad_int8_slm" },
gemm_int8_test_params{ CASE_GEMM_INT8_SLM_COMBO_2, "gemm_mmad_int8_slm" },
gemm_int8_test_params{ CASE_GEMM_INT8_SLM_COMBO_3, "gemm_mmad_int8_slm" },
gemm_int8_test_params{ CASE_GEMM_INT8_SLM_COMBO_4, "gemm_mmad_int8_slm" },
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_slm_combo_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_INT8_SLM_COMBO_1, "gemm_mmad_int8_slm" },
gemm_base_test_params{ CASE_GEMM_INT8_SLM_COMBO_2, "gemm_mmad_int8_slm" },
gemm_base_test_params{ CASE_GEMM_INT8_SLM_COMBO_3, "gemm_mmad_int8_slm" },
gemm_base_test_params{ CASE_GEMM_INT8_SLM_COMBO_4, "gemm_mmad_int8_slm" },
}), );
// FP32 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP32_TILED_NN_* packs.
// All inputs, the output and the accumulator of the host reference are float.
class gemm_fp32_tiled_nn_tests : public ::GemmBaseTest<gemm_base_test_params, float, float, float, float, float> {};
TEST_P(gemm_fp32_tiled_nn_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp32_tiled_nn_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_4, "gemm_tiled_opt" },
}), );
// FP32 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP32_TILED_NT_* packs
// (presumably input0 not transposed, input1 transposed — confirm against the parameter struct).
class gemm_fp32_tiled_nt_tests : public ::GemmBaseTest<gemm_base_test_params, float, float, float, float, float> {};
TEST_P(gemm_fp32_tiled_nt_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp32_tiled_nt_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NT_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NT_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NT_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NT_4, "gemm_tiled_opt" },
}), );
// FP32 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP32_TILED_TN_* packs
// (presumably input0 transposed, input1 not transposed — confirm against the parameter struct).
class gemm_fp32_tiled_tn_tests : public ::GemmBaseTest<gemm_base_test_params, float, float, float, float, float> {};
TEST_P(gemm_fp32_tiled_tn_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp32_tiled_tn_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TN_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TN_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TN_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TN_4, "gemm_tiled_opt" },
}), );
// FP32 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP32_TILED_TT_* packs
// (presumably both inputs transposed — confirm against the parameter struct).
class gemm_fp32_tiled_tt_tests : public ::GemmBaseTest<gemm_base_test_params, float, float, float, float, float> {};
TEST_P(gemm_fp32_tiled_tt_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp32_tiled_tt_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TT_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TT_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TT_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_TT_4, "gemm_tiled_opt" },
}), );
// FP32 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP32_TILED_NN_BROADCAST_*
// packs, whose batch/feature counts differ between the operands.
class gemm_fp32_tiled_nn_broadcast_tests : public ::GemmBaseTest<gemm_base_test_params, float, float, float, float, float> {};
TEST_P(gemm_fp32_tiled_nn_broadcast_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp32_tiled_nn_broadcast_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_BROADCAST_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_BROADCAST_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_BROADCAST_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP32_TILED_NN_BROADCAST_4, "gemm_tiled_opt" },
}), );
// FP16 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP16_TILED_NN_* packs.
// All inputs, the output and the accumulator of the host reference use FLOAT16.
// NOTE(review): the fixture picks its comparison tolerance from sizeof(input0_type); this suite
// presumably lands in the 2-byte branch — confirm sizeof(FLOAT16) == 2.
class gemm_fp16_tiled_nn_tests : public ::GemmBaseTest<gemm_base_test_params, FLOAT16, FLOAT16, FLOAT16, FLOAT16, FLOAT16> {};
TEST_P(gemm_fp16_tiled_nn_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp16_tiled_nn_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_4, "gemm_tiled_opt" },
}), );
// FP16 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP16_TILED_NT_* packs
// (presumably input0 not transposed, input1 transposed — confirm against the parameter struct).
class gemm_fp16_tiled_nt_tests : public ::GemmBaseTest<gemm_base_test_params, FLOAT16, FLOAT16, FLOAT16, FLOAT16, FLOAT16> {};
TEST_P(gemm_fp16_tiled_nt_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp16_tiled_nt_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NT_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NT_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NT_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NT_4, "gemm_tiled_opt" },
}), );
// FP16 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP16_TILED_TN_* packs
// (presumably input0 transposed, input1 not transposed — confirm against the parameter struct).
class gemm_fp16_tiled_tn_tests : public ::GemmBaseTest<gemm_base_test_params, FLOAT16, FLOAT16, FLOAT16, FLOAT16, FLOAT16> {};
TEST_P(gemm_fp16_tiled_tn_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp16_tiled_tn_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TN_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TN_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TN_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TN_4, "gemm_tiled_opt" },
}), );
// FP16 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP16_TILED_TT_* packs
// (presumably both inputs transposed — confirm against the parameter struct).
class gemm_fp16_tiled_tt_tests : public ::GemmBaseTest<gemm_base_test_params, FLOAT16, FLOAT16, FLOAT16, FLOAT16, FLOAT16> {};
TEST_P(gemm_fp16_tiled_tt_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp16_tiled_tt_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TT_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TT_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TT_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_TT_4, "gemm_tiled_opt" },
}), );
// FP16 suite for the "gemm_tiled_opt" kernel using the CASE_GEMM_FP16_TILED_NN_BROADCAST_*
// packs, whose batch/feature counts differ between the operands.
class gemm_fp16_tiled_nn_broadcast_tests : public ::GemmBaseTest<gemm_base_test_params, FLOAT16, FLOAT16, FLOAT16, FLOAT16, FLOAT16> {};
TEST_P(gemm_fp16_tiled_nn_broadcast_tests, basic) { auto p = GetParam(); execute(p); }
INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_fp16_tiled_nn_broadcast_tests, ::testing::ValuesIn(std::vector <gemm_base_test_params> {
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_BROADCAST_1, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_BROADCAST_2, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_BROADCAST_3, "gemm_tiled_opt" },
gemm_base_test_params{ CASE_GEMM_FP16_TILED_NN_BROADCAST_4, "gemm_tiled_opt" },
}), );