[GPU] Functional fixes for nvidia (#17735)

This commit is contained in:
Vladimir Paramuzov
2023-06-01 09:45:30 +04:00
committed by GitHub
parent 2dd0b75529
commit ac26216869
13 changed files with 49 additions and 9 deletions

View File

@@ -94,13 +94,6 @@ bool GemmKernelBase::Validate(const Params& p, const optional_params&) const {
return true;
}
DeviceFeaturesKey GemmKernelBase::get_required_device_features_key(const Params& params, const optional_params& options) const {
auto k = get_common_subgroups_device_features_key(params, options);
k.requires_subgroup_shuffle();
return k;
}
Datatype GemmKernelBase::GetActivationType(const gemm_params& params) const {
if (params.quantization != QuantizationType::NONE)
return Datatype::F32;

View File

@@ -53,8 +53,6 @@ protected:
virtual JitConstants GetFusedPrimitivesJitConstants(const gemm_params& params, const DispatchData& dispatchData) const;
Datatype GetActivationType(const gemm_params& params) const;
// --Fused ops
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
bool Validate(const Params& p, const optional_params&) const override;
};
} // namespace kernel_selector

View File

@@ -32,6 +32,13 @@ ParamsKey GemmKernelMMADint8::GetSupportedKey() const {
return k;
}
DeviceFeaturesKey GemmKernelMMADint8::get_required_device_features_key(const Params& params, const optional_params& options) const {
auto k = get_common_subgroups_device_features_key(params, options);
k.requires_subgroup_shuffle();
return k;
}
JitConstants GemmKernelMMADint8::GetJitConstants(const gemm_params& params) const {
JitConstants jit = Parent::GetJitConstants(params);
GemmTuningData td = SetTuningParams(params);

View File

@@ -41,5 +41,6 @@ protected:
GemmTuningData SetTuningParams(const gemm_params& params) const;
size_t GetMmadOperationsNumber(const GemmTuningData& tuning_data) const;
bool HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const;
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
};
} // namespace kernel_selector

View File

@@ -32,6 +32,13 @@ ParamsKey GemmKernelMMADslmInt8::GetSupportedKey() const {
return k;
}
DeviceFeaturesKey GemmKernelMMADslmInt8::get_required_device_features_key(const Params& params, const optional_params& options) const {
auto k = get_common_subgroups_device_features_key(params, options);
k.requires_subgroup_shuffle();
return k;
}
JitConstants GemmKernelMMADslmInt8::GetJitConstants(const gemm_params& params) const {
JitConstants jit = Parent::GetJitConstants(params);
GemmTuningData td = SetTuningParams(params);

View File

@@ -43,5 +43,6 @@ protected:
GemmTuningData SetTuningParams(const gemm_params& params) const;
size_t GetMmadOperationsNumber(const GemmTuningData& tuning_data) const;
bool HasLeftovers(const GemmTuningData& tuning_data) const;
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
};
} // namespace kernel_selector

View File

@@ -31,6 +31,10 @@ ParamsKey GemmKernelRef::GetSupportedKey() const {
return k;
}
DeviceFeaturesKey GemmKernelRef::get_required_device_features_key(const Params& params, const optional_params& options) const {
return DeviceFeaturesKey();
}
JitConstants GemmKernelRef::GetJitConstants(const gemm_params& params) const {
JitConstants jit = Parent::GetJitConstants(params);

View File

@@ -25,5 +25,6 @@ protected:
}
bool Validate(const Params& params, const optional_params& options) const override;
JitConstants GetJitConstants(const gemm_params& params) const override;
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
};
} // namespace kernel_selector

View File

@@ -30,6 +30,13 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
return k;
}
DeviceFeaturesKey GemmKernelTiledOpt::get_required_device_features_key(const Params& params, const optional_params& options) const {
auto k = get_common_subgroups_device_features_key(params, options);
k.requires_subgroup_shuffle();
return k;
}
GemmKernelBase::DispatchData GemmKernelTiledOpt::SetDefault(const gemm_params& params) const {
const auto& output = params.outputs[0];

View File

@@ -35,5 +35,6 @@ protected:
DispatchData SetDefault(const gemm_params& params) const override;
JitConstants GetJitConstants(const gemm_params& params) const override;
GemmTuningData SetTuningParams(const gemm_params& params) const;
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
};
} // namespace kernel_selector

View File

@@ -29,6 +29,15 @@ ParamsKey SoftmaxKernel_bf::GetSupportedKey() const {
return k;
}
DeviceFeaturesKey SoftmaxKernel_bf::get_required_device_features_key(const Params& params, const optional_params& options) const {
DeviceFeaturesKey k;
k.requires_subgroups();
k.requires_subgroup_reduce();
k.requires_reqd_subgroup_size();
return k;
}
SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softmax_params& params) const {
auto dispatchData = Parent::SetDefault(params);

View File

@@ -20,6 +20,7 @@ public:
protected:
DispatchData SetDefault(const softmax_params& params) const override;
JitConstants GetJitConstants(const softmax_params& params, DispatchData dispatchData) const override;
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
std::vector<KernelBase::FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::QUANTIZE };
}

View File

@@ -166,6 +166,16 @@ device_info init_device_info(const cl::Device& device) {
info.max_work_group_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>());
// For some reason nvidia runtime throws an exception (CL_INVALID_KERNEL_ARGS) for WG as follows:
// global: < 1 x 32 x 5184 >
// local: < 1 x 1 x 576 >
// While local < 1 x 1 x 36 > works fine
// So below we limit max WG size by 64 which was selected based on few experiments.
constexpr int nvidia_vendor_id = 0x10DE;
if (info.vendor_id == nvidia_vendor_id) {
info.max_work_group_size = 64;
}
info.max_local_mem_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_LOCAL_MEM_SIZE>());
info.max_global_mem_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_GLOBAL_MEM_SIZE>());
info.max_alloc_mem_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_MAX_MEM_ALLOC_SIZE>());