[GPU] Make gemm_tiled_opt support outer axis (#19210)
This commit is contained in:
parent
cabb40638a
commit
ef6c8c1d66
@ -23,6 +23,7 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
|
||||
k.EnableInputLayout(DataLayout::bfwzyx);
|
||||
k.EnableOutputLayout(DataLayout::bfwzyx);
|
||||
|
||||
k.EnableTensorOffset();
|
||||
k.EnableBatching();
|
||||
k.EnableDifferentTypes();
|
||||
k.EnableDynamicShapesSupport();
|
||||
@ -234,6 +235,13 @@ bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& o
|
||||
return false;
|
||||
|
||||
const auto& gmm_params = static_cast<const gemm_params&>(params);
|
||||
for (auto input : gmm_params.inputs) {
|
||||
// Only supports outer padding as first element offset
|
||||
if (input.X().pad.Total() != 0 || input.Y().pad.Total() != 0 || input.Z().pad.Total() != 0 ||
|
||||
input.Feature().pad.Total() != 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool gemm_leftovers = gmm_params.inputs[0].X().v % 16 || gmm_params.inputs[0].Y().v % 16 ||
|
||||
gmm_params.inputs[1].X().v % 16 || gmm_params.inputs[1].Y().v % 16;
|
||||
// If gmm_params has dynamic inputs, the correct dimension value cannot be obtained
|
||||
|
Loading…
Reference in New Issue
Block a user