[GPU] gemm batch format support (#12474)

This commit is contained in:
OlehKravchyshyn 2022-11-11 02:01:38 +02:00 committed by GitHub
parent 5eef0298d9
commit 3443079a7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 586 additions and 3219 deletions

View File

@ -2,15 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gemm_inst.h"
#include "primitive_base.hpp"
#include "impls/implementation_map.hpp"
#include "kernel_selector_helper.h"
#include "gemm/gemm_kernel_selector.h"
#include "gemm/gemm_kernel_base.h"
#include "gemm/gemm_kernel_selector.h"
#include "gemm_inst.h"
#include "impls/implementation_map.hpp"
#include "intel_gpu/runtime/error_handler.hpp"
#include <algorithm>
#include "kernel_selector_helper.h"
#include "primitive_base.hpp"
namespace cldnn {
namespace ocl {
@ -142,20 +141,30 @@ public:
namespace detail {
attach_gemm_impl::attach_gemm_impl() {
implementation_map<gemm>::add(impl_types::ocl, gemm_impl::create, {
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::i8, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
std::make_tuple(data_types::f32, format::bfwzyx),
std::make_tuple(data_types::f16, format::bfwzyx),
std::make_tuple(data_types::i8, format::bfwzyx),
std::make_tuple(data_types::u8, format::bfwzyx),
});
const std::vector<data_types> types{data_types::f16,
data_types::f32,
data_types::i8,
data_types::u8,
data_types::i32};
const std::vector<format::type> formats {
format::bfyx,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32,
format::bfzyx,
format::bs_fs_zyx_bsv16_fsv32,
format::bs_fs_zyx_bsv16_fsv16,
format::bs_fs_zyx_bsv32_fsv32,
format::bs_fs_zyx_bsv32_fsv16,
format::bfwzyx,
};
implementation_map<gemm>::add(impl_types::ocl, gemm_impl::create, types, formats);
}
} // namespace detail

View File

@ -1516,6 +1516,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id() &&
prim.type() != cldnn::gemm::type_id() &&
prim.type() != cldnn::tile::type_id() &&
prim.type() != cldnn::scatter_elements_update::type_id() &&
prim.type() != cldnn::gather_tree::type_id() &&
@ -1565,6 +1566,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id() &&
prim.type() != cldnn::gemm::type_id() &&
prim.type() != cldnn::tile::type_id() &&
prim.type() != cldnn::scatter_elements_update::type_id() &&
prim.type() != cldnn::gather_tree::type_id() &&

View File

@ -13,8 +13,16 @@ inline uint FUNC(get_input0_index_nt)(uint b, uint f, uint w, uint z, uint y, ui
#if INPUT0_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x);
#else
#if INPUT0_DIMS == 4
return INPUT0_GET_INDEX_SAFE(b, f, y, x);
#elif INPUT0_DIMS == 5
return INPUT0_GET_INDEX_SAFE(b, f, z, y, x);
#elif INPUT0_DIMS == 6
return INPUT0_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported input 0 format
#endif
#endif
}
inline uint FUNC(get_input0_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
@ -29,8 +37,16 @@ inline uint FUNC(get_input1_index_nt)(uint b, uint f, uint w, uint z, uint y, ui
#if INPUT1_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x);
#else
#if INPUT1_DIMS == 4
return INPUT1_GET_INDEX_SAFE(b, f, y, x);
#elif INPUT1_DIMS == 5
return INPUT1_GET_INDEX_SAFE(b, f, z, y, x);
#elif INPUT1_DIMS == 6
return INPUT1_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported input 1 format
#endif
#endif
}
inline uint FUNC(get_input1_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
@ -46,8 +62,16 @@ inline uint FUNC(get_input2_index)(uint b, uint f, uint w, uint z, uint y, uint
#if INPUT2_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, y, x);
#else
#if INPUT2_DIMS == 4
return INPUT2_GET_INDEX_SAFE(b, f, y, x);
#elif INPUT2_DIMS == 5
return INPUT2_GET_INDEX_SAFE(b, f, z, y, x);
#elif INPUT2_DIMS == 6
return INPUT2_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported input 2 format
#endif
#endif
}
#endif // INPUT2_TYPE
@ -55,8 +79,16 @@ inline uint FUNC(get_output_index)(uint b, uint f, uint w, uint z, uint y, uint
#if OUTPUT_SIMPLE
return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, y, x);
#else
#if OUTPUT_DIMS == 4
return OUTPUT_GET_INDEX(b, f, y, x);
#elif OUTPUT_DIMS == 5
return OUTPUT_GET_INDEX(b, f, z, y, x);
#elif OUTPUT_DIMS == 6
return OUTPUT_GET_INDEX(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported output format
#endif
#endif
}
KERNEL(gemm_ref)(

View File

@ -12,16 +12,14 @@ ParamsKey GemmKernelRef::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT8);
k.EnableInputDataType(Datatype::UINT8);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableOutputDataType(Datatype::INT32);
k.EnableAllInputLayout();
k.EnableAllOutputLayout();
k.EnableBatching();
k.EnableDifferentTypes();

File diff suppressed because one or more lines are too long