[GPU] gemm batch format support (#12474)

This commit is contained in:
OlehKravchyshyn 2022-11-11 02:01:38 +02:00 committed by GitHub
parent 5eef0298d9
commit 3443079a7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 586 additions and 3219 deletions

View File

@ -2,15 +2,14 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "gemm_inst.h"
#include "primitive_base.hpp"
#include "impls/implementation_map.hpp"
#include "kernel_selector_helper.h"
#include "gemm/gemm_kernel_selector.h"
#include "gemm/gemm_kernel_base.h" #include "gemm/gemm_kernel_base.h"
#include "gemm/gemm_kernel_selector.h"
#include "gemm_inst.h"
#include "impls/implementation_map.hpp"
#include "intel_gpu/runtime/error_handler.hpp" #include "intel_gpu/runtime/error_handler.hpp"
#include <algorithm> #include <algorithm>
#include "kernel_selector_helper.h"
#include "primitive_base.hpp"
namespace cldnn { namespace cldnn {
namespace ocl { namespace ocl {
@ -142,20 +141,30 @@ public:
namespace detail { namespace detail {
attach_gemm_impl::attach_gemm_impl() { attach_gemm_impl::attach_gemm_impl() {
implementation_map<gemm>::add(impl_types::ocl, gemm_impl::create, { const std::vector<data_types> types{data_types::f16,
std::make_tuple(data_types::f32, format::bfyx), data_types::f32,
std::make_tuple(data_types::f16, format::bfyx), data_types::i8,
std::make_tuple(data_types::i8, format::bfyx), data_types::u8,
std::make_tuple(data_types::u8, format::bfyx), data_types::i32};
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx), const std::vector<format::type> formats {
std::make_tuple(data_types::i8, format::bfzyx), format::bfyx,
std::make_tuple(data_types::u8, format::bfzyx), format::b_fs_yx_fsv16,
std::make_tuple(data_types::f32, format::bfwzyx), format::b_fs_yx_fsv32,
std::make_tuple(data_types::f16, format::bfwzyx), format::bs_fs_yx_bsv16_fsv16,
std::make_tuple(data_types::i8, format::bfwzyx), format::bs_fs_yx_bsv32_fsv16,
std::make_tuple(data_types::u8, format::bfwzyx), format::bs_fs_yx_bsv32_fsv32,
});
format::bfzyx,
format::bs_fs_zyx_bsv16_fsv32,
format::bs_fs_zyx_bsv16_fsv16,
format::bs_fs_zyx_bsv32_fsv32,
format::bs_fs_zyx_bsv32_fsv16,
format::bfwzyx,
};
implementation_map<gemm>::add(impl_types::ocl, gemm_impl::create, types, formats);
} }
} // namespace detail } // namespace detail

View File

@ -1516,6 +1516,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::generate_proposals::type_id() && prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id() && prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id() && prim.type() != cldnn::reorg_yolo::type_id() &&
prim.type() != cldnn::gemm::type_id() &&
prim.type() != cldnn::tile::type_id() && prim.type() != cldnn::tile::type_id() &&
prim.type() != cldnn::scatter_elements_update::type_id() && prim.type() != cldnn::scatter_elements_update::type_id() &&
prim.type() != cldnn::gather_tree::type_id() && prim.type() != cldnn::gather_tree::type_id() &&
@ -1565,6 +1566,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::generate_proposals::type_id() && prim.type() != cldnn::generate_proposals::type_id() &&
prim.type() != cldnn::reverse::type_id() && prim.type() != cldnn::reverse::type_id() &&
prim.type() != cldnn::reorg_yolo::type_id() && prim.type() != cldnn::reorg_yolo::type_id() &&
prim.type() != cldnn::gemm::type_id() &&
prim.type() != cldnn::tile::type_id() && prim.type() != cldnn::tile::type_id() &&
prim.type() != cldnn::scatter_elements_update::type_id() && prim.type() != cldnn::scatter_elements_update::type_id() &&
prim.type() != cldnn::gather_tree::type_id() && prim.type() != cldnn::gather_tree::type_id() &&

View File

@ -13,8 +13,16 @@ inline uint FUNC(get_input0_index_nt)(uint b, uint f, uint w, uint z, uint y, ui
#if INPUT0_SIMPLE #if INPUT0_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x); return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x);
#else #else
#if INPUT0_DIMS == 4
return INPUT0_GET_INDEX_SAFE(b, f, y, x);
#elif INPUT0_DIMS == 5
return INPUT0_GET_INDEX_SAFE(b, f, z, y, x);
#elif INPUT0_DIMS == 6
return INPUT0_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported input 0 format # error gemm_ref.cl : Unsupported input 0 format
#endif #endif
#endif
} }
inline uint FUNC(get_input0_index)(uint b, uint f, uint w, uint z, uint y, uint x) { inline uint FUNC(get_input0_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
@ -29,8 +37,16 @@ inline uint FUNC(get_input1_index_nt)(uint b, uint f, uint w, uint z, uint y, ui
#if INPUT1_SIMPLE #if INPUT1_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x); return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x);
#else #else
#if INPUT1_DIMS == 4
return INPUT1_GET_INDEX_SAFE(b, f, y, x);
#elif INPUT1_DIMS == 5
return INPUT1_GET_INDEX_SAFE(b, f, z, y, x);
#elif INPUT1_DIMS == 6
return INPUT1_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported input 1 format # error gemm_ref.cl : Unsupported input 1 format
#endif #endif
#endif
} }
inline uint FUNC(get_input1_index)(uint b, uint f, uint w, uint z, uint y, uint x) { inline uint FUNC(get_input1_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
@ -46,8 +62,16 @@ inline uint FUNC(get_input2_index)(uint b, uint f, uint w, uint z, uint y, uint
#if INPUT2_SIMPLE #if INPUT2_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, y, x); return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, y, x);
#else #else
#if INPUT2_DIMS == 4
return INPUT2_GET_INDEX_SAFE(b, f, y, x);
#elif INPUT2_DIMS == 5
return INPUT2_GET_INDEX_SAFE(b, f, z, y, x);
#elif INPUT2_DIMS == 6
return INPUT2_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported input 2 format # error gemm_ref.cl : Unsupported input 2 format
#endif #endif
#endif
} }
#endif // INPUT2_TYPE #endif // INPUT2_TYPE
@ -55,8 +79,16 @@ inline uint FUNC(get_output_index)(uint b, uint f, uint w, uint z, uint y, uint
#if OUTPUT_SIMPLE #if OUTPUT_SIMPLE
return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, y, x); return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, y, x);
#else #else
#if OUTPUT_DIMS == 4
return OUTPUT_GET_INDEX(b, f, y, x);
#elif OUTPUT_DIMS == 5
return OUTPUT_GET_INDEX(b, f, z, y, x);
#elif OUTPUT_DIMS == 6
return OUTPUT_GET_INDEX(b, f, w, z, y, x);
#else
# error gemm_ref.cl : Unsupported output format # error gemm_ref.cl : Unsupported output format
#endif #endif
#endif
} }
KERNEL(gemm_ref)( KERNEL(gemm_ref)(

View File

@ -12,16 +12,14 @@ ParamsKey GemmKernelRef::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F32); k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT8); k.EnableInputDataType(Datatype::INT8);
k.EnableInputDataType(Datatype::UINT8); k.EnableInputDataType(Datatype::UINT8);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::INT8); k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8); k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx); k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputLayout(DataLayout::bfyx); k.EnableAllInputLayout();
k.EnableInputLayout(DataLayout::bfzyx); k.EnableAllOutputLayout();
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableBatching(); k.EnableBatching();
k.EnableDifferentTypes(); k.EnableDifferentTypes();

File diff suppressed because one or more lines are too long