[GPU] gemm batch format support (#12474)
This commit is contained in:
parent
5eef0298d9
commit
3443079a7b
@ -2,15 +2,14 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "gemm_inst.h"
|
|
||||||
|
|
||||||
#include "primitive_base.hpp"
|
|
||||||
#include "impls/implementation_map.hpp"
|
|
||||||
#include "kernel_selector_helper.h"
|
|
||||||
#include "gemm/gemm_kernel_selector.h"
|
|
||||||
#include "gemm/gemm_kernel_base.h"
|
#include "gemm/gemm_kernel_base.h"
|
||||||
|
#include "gemm/gemm_kernel_selector.h"
|
||||||
|
#include "gemm_inst.h"
|
||||||
|
#include "impls/implementation_map.hpp"
|
||||||
#include "intel_gpu/runtime/error_handler.hpp"
|
#include "intel_gpu/runtime/error_handler.hpp"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include "kernel_selector_helper.h"
|
||||||
|
#include "primitive_base.hpp"
|
||||||
|
|
||||||
namespace cldnn {
|
namespace cldnn {
|
||||||
namespace ocl {
|
namespace ocl {
|
||||||
@ -142,20 +141,30 @@ public:
|
|||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
attach_gemm_impl::attach_gemm_impl() {
|
attach_gemm_impl::attach_gemm_impl() {
|
||||||
implementation_map<gemm>::add(impl_types::ocl, gemm_impl::create, {
|
const std::vector<data_types> types{data_types::f16,
|
||||||
std::make_tuple(data_types::f32, format::bfyx),
|
data_types::f32,
|
||||||
std::make_tuple(data_types::f16, format::bfyx),
|
data_types::i8,
|
||||||
std::make_tuple(data_types::i8, format::bfyx),
|
data_types::u8,
|
||||||
std::make_tuple(data_types::u8, format::bfyx),
|
data_types::i32};
|
||||||
std::make_tuple(data_types::f32, format::bfzyx),
|
|
||||||
std::make_tuple(data_types::f16, format::bfzyx),
|
const std::vector<format::type> formats {
|
||||||
std::make_tuple(data_types::i8, format::bfzyx),
|
format::bfyx,
|
||||||
std::make_tuple(data_types::u8, format::bfzyx),
|
format::b_fs_yx_fsv16,
|
||||||
std::make_tuple(data_types::f32, format::bfwzyx),
|
format::b_fs_yx_fsv32,
|
||||||
std::make_tuple(data_types::f16, format::bfwzyx),
|
format::bs_fs_yx_bsv16_fsv16,
|
||||||
std::make_tuple(data_types::i8, format::bfwzyx),
|
format::bs_fs_yx_bsv32_fsv16,
|
||||||
std::make_tuple(data_types::u8, format::bfwzyx),
|
format::bs_fs_yx_bsv32_fsv32,
|
||||||
});
|
|
||||||
|
format::bfzyx,
|
||||||
|
format::bs_fs_zyx_bsv16_fsv32,
|
||||||
|
format::bs_fs_zyx_bsv16_fsv16,
|
||||||
|
format::bs_fs_zyx_bsv32_fsv32,
|
||||||
|
format::bs_fs_zyx_bsv32_fsv16,
|
||||||
|
|
||||||
|
format::bfwzyx,
|
||||||
|
};
|
||||||
|
|
||||||
|
implementation_map<gemm>::add(impl_types::ocl, gemm_impl::create, types, formats);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
@ -1516,6 +1516,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
|||||||
prim.type() != cldnn::generate_proposals::type_id() &&
|
prim.type() != cldnn::generate_proposals::type_id() &&
|
||||||
prim.type() != cldnn::reverse::type_id() &&
|
prim.type() != cldnn::reverse::type_id() &&
|
||||||
prim.type() != cldnn::reorg_yolo::type_id() &&
|
prim.type() != cldnn::reorg_yolo::type_id() &&
|
||||||
|
prim.type() != cldnn::gemm::type_id() &&
|
||||||
prim.type() != cldnn::tile::type_id() &&
|
prim.type() != cldnn::tile::type_id() &&
|
||||||
prim.type() != cldnn::scatter_elements_update::type_id() &&
|
prim.type() != cldnn::scatter_elements_update::type_id() &&
|
||||||
prim.type() != cldnn::gather_tree::type_id() &&
|
prim.type() != cldnn::gather_tree::type_id() &&
|
||||||
@ -1565,6 +1566,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
|||||||
prim.type() != cldnn::generate_proposals::type_id() &&
|
prim.type() != cldnn::generate_proposals::type_id() &&
|
||||||
prim.type() != cldnn::reverse::type_id() &&
|
prim.type() != cldnn::reverse::type_id() &&
|
||||||
prim.type() != cldnn::reorg_yolo::type_id() &&
|
prim.type() != cldnn::reorg_yolo::type_id() &&
|
||||||
|
prim.type() != cldnn::gemm::type_id() &&
|
||||||
prim.type() != cldnn::tile::type_id() &&
|
prim.type() != cldnn::tile::type_id() &&
|
||||||
prim.type() != cldnn::scatter_elements_update::type_id() &&
|
prim.type() != cldnn::scatter_elements_update::type_id() &&
|
||||||
prim.type() != cldnn::gather_tree::type_id() &&
|
prim.type() != cldnn::gather_tree::type_id() &&
|
||||||
|
@ -13,8 +13,16 @@ inline uint FUNC(get_input0_index_nt)(uint b, uint f, uint w, uint z, uint y, ui
|
|||||||
#if INPUT0_SIMPLE
|
#if INPUT0_SIMPLE
|
||||||
return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x);
|
return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x);
|
||||||
#else
|
#else
|
||||||
|
#if INPUT0_DIMS == 4
|
||||||
|
return INPUT0_GET_INDEX_SAFE(b, f, y, x);
|
||||||
|
#elif INPUT0_DIMS == 5
|
||||||
|
return INPUT0_GET_INDEX_SAFE(b, f, z, y, x);
|
||||||
|
#elif INPUT0_DIMS == 6
|
||||||
|
return INPUT0_GET_INDEX_SAFE(b, f, w, z, y, x);
|
||||||
|
#else
|
||||||
# error gemm_ref.cl : Unsupported input 0 format
|
# error gemm_ref.cl : Unsupported input 0 format
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline uint FUNC(get_input0_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
|
inline uint FUNC(get_input0_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
|
||||||
@ -29,8 +37,16 @@ inline uint FUNC(get_input1_index_nt)(uint b, uint f, uint w, uint z, uint y, ui
|
|||||||
#if INPUT1_SIMPLE
|
#if INPUT1_SIMPLE
|
||||||
return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x);
|
return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x);
|
||||||
#else
|
#else
|
||||||
|
#if INPUT1_DIMS == 4
|
||||||
|
return INPUT1_GET_INDEX_SAFE(b, f, y, x);
|
||||||
|
#elif INPUT1_DIMS == 5
|
||||||
|
return INPUT1_GET_INDEX_SAFE(b, f, z, y, x);
|
||||||
|
#elif INPUT1_DIMS == 6
|
||||||
|
return INPUT1_GET_INDEX_SAFE(b, f, w, z, y, x);
|
||||||
|
#else
|
||||||
# error gemm_ref.cl : Unsupported input 1 format
|
# error gemm_ref.cl : Unsupported input 1 format
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline uint FUNC(get_input1_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
|
inline uint FUNC(get_input1_index)(uint b, uint f, uint w, uint z, uint y, uint x) {
|
||||||
@ -46,8 +62,16 @@ inline uint FUNC(get_input2_index)(uint b, uint f, uint w, uint z, uint y, uint
|
|||||||
#if INPUT2_SIMPLE
|
#if INPUT2_SIMPLE
|
||||||
return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, y, x);
|
return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, y, x);
|
||||||
#else
|
#else
|
||||||
|
#if INPUT2_DIMS == 4
|
||||||
|
return INPUT2_GET_INDEX_SAFE(b, f, y, x);
|
||||||
|
#elif INPUT2_DIMS == 5
|
||||||
|
return INPUT2_GET_INDEX_SAFE(b, f, z, y, x);
|
||||||
|
#elif INPUT2_DIMS == 6
|
||||||
|
return INPUT2_GET_INDEX_SAFE(b, f, w, z, y, x);
|
||||||
|
#else
|
||||||
# error gemm_ref.cl : Unsupported input 2 format
|
# error gemm_ref.cl : Unsupported input 2 format
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
#endif // INPUT2_TYPE
|
#endif // INPUT2_TYPE
|
||||||
|
|
||||||
@ -55,8 +79,16 @@ inline uint FUNC(get_output_index)(uint b, uint f, uint w, uint z, uint y, uint
|
|||||||
#if OUTPUT_SIMPLE
|
#if OUTPUT_SIMPLE
|
||||||
return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, y, x);
|
return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, y, x);
|
||||||
#else
|
#else
|
||||||
|
#if OUTPUT_DIMS == 4
|
||||||
|
return OUTPUT_GET_INDEX(b, f, y, x);
|
||||||
|
#elif OUTPUT_DIMS == 5
|
||||||
|
return OUTPUT_GET_INDEX(b, f, z, y, x);
|
||||||
|
#elif OUTPUT_DIMS == 6
|
||||||
|
return OUTPUT_GET_INDEX(b, f, w, z, y, x);
|
||||||
|
#else
|
||||||
# error gemm_ref.cl : Unsupported output format
|
# error gemm_ref.cl : Unsupported output format
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
KERNEL(gemm_ref)(
|
KERNEL(gemm_ref)(
|
||||||
|
@ -12,16 +12,14 @@ ParamsKey GemmKernelRef::GetSupportedKey() const {
|
|||||||
k.EnableInputDataType(Datatype::F32);
|
k.EnableInputDataType(Datatype::F32);
|
||||||
k.EnableInputDataType(Datatype::INT8);
|
k.EnableInputDataType(Datatype::INT8);
|
||||||
k.EnableInputDataType(Datatype::UINT8);
|
k.EnableInputDataType(Datatype::UINT8);
|
||||||
|
k.EnableInputDataType(Datatype::INT32);
|
||||||
k.EnableOutputDataType(Datatype::F32);
|
k.EnableOutputDataType(Datatype::F32);
|
||||||
k.EnableOutputDataType(Datatype::F16);
|
k.EnableOutputDataType(Datatype::F16);
|
||||||
k.EnableOutputDataType(Datatype::INT8);
|
k.EnableOutputDataType(Datatype::INT8);
|
||||||
k.EnableOutputDataType(Datatype::UINT8);
|
k.EnableOutputDataType(Datatype::UINT8);
|
||||||
k.EnableInputLayout(DataLayout::bfyx);
|
k.EnableOutputDataType(Datatype::INT32);
|
||||||
k.EnableOutputLayout(DataLayout::bfyx);
|
k.EnableAllInputLayout();
|
||||||
k.EnableInputLayout(DataLayout::bfzyx);
|
k.EnableAllOutputLayout();
|
||||||
k.EnableOutputLayout(DataLayout::bfzyx);
|
|
||||||
k.EnableInputLayout(DataLayout::bfwzyx);
|
|
||||||
k.EnableOutputLayout(DataLayout::bfwzyx);
|
|
||||||
|
|
||||||
k.EnableBatching();
|
k.EnableBatching();
|
||||||
k.EnableDifferentTypes();
|
k.EnableDifferentTypes();
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user