[GPU] Make gemm_tiled_opt support outer axis (#19210)

This commit is contained in:
Min, Byungil 2023-08-16 21:43:47 +09:00 committed by GitHub
parent cabb40638a
commit ef6c8c1d66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,6 +23,7 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
k.EnableInputLayout(DataLayout::bfwzyx); k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx); k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableTensorOffset();
k.EnableBatching(); k.EnableBatching();
k.EnableDifferentTypes(); k.EnableDifferentTypes();
k.EnableDynamicShapesSupport(); k.EnableDynamicShapesSupport();
@ -234,6 +235,13 @@ bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& o
return false; return false;
const auto& gmm_params = static_cast<const gemm_params&>(params); const auto& gmm_params = static_cast<const gemm_params&>(params);
for (auto input : gmm_params.inputs) {
// Only supports outer padding as first element offset
if (input.X().pad.Total() != 0 || input.Y().pad.Total() != 0 || input.Z().pad.Total() != 0 ||
input.Feature().pad.Total() != 0)
return false;
}
bool gemm_leftovers = gmm_params.inputs[0].X().v % 16 || gmm_params.inputs[0].Y().v % 16 || bool gemm_leftovers = gmm_params.inputs[0].X().v % 16 || gmm_params.inputs[0].Y().v % 16 ||
gmm_params.inputs[1].X().v % 16 || gmm_params.inputs[1].Y().v % 16; gmm_params.inputs[1].X().v % 16 || gmm_params.inputs[1].Y().v % 16;
// If gmm_params has dynamic inputs, the correct dimension value cannot be obtained // If gmm_params has dynamic inputs, the correct dimension value cannot be obtained