[GPU] Functional fixes for nvidia (#17735)
This commit is contained in:
committed by
GitHub
parent
2dd0b75529
commit
ac26216869
@@ -94,13 +94,6 @@ bool GemmKernelBase::Validate(const Params& p, const optional_params&) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Builds the device-features key for this kernel: starts from the key returned by
// get_common_subgroups_device_features_key(params, options) and additionally
// marks subgroup shuffle as required before returning it.
DeviceFeaturesKey GemmKernelBase::get_required_device_features_key(const Params& params, const optional_params& options) const {
|
||||
auto k = get_common_subgroups_device_features_key(params, options);
|
||||
k.requires_subgroup_shuffle();
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
Datatype GemmKernelBase::GetActivationType(const gemm_params& params) const {
|
||||
if (params.quantization != QuantizationType::NONE)
|
||||
return Datatype::F32;
|
||||
|
||||
@@ -53,8 +53,6 @@ protected:
|
||||
virtual JitConstants GetFusedPrimitivesJitConstants(const gemm_params& params, const DispatchData& dispatchData) const;
|
||||
Datatype GetActivationType(const gemm_params& params) const;
|
||||
// --Fused ops
|
||||
|
||||
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
|
||||
bool Validate(const Params& p, const optional_params&) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
|
||||
@@ -32,6 +32,13 @@ ParamsKey GemmKernelMMADint8::GetSupportedKey() const {
|
||||
return k;
|
||||
}
|
||||
|
||||
// Builds the device-features key for the MMAD int8 GEMM kernel: takes the key from
// get_common_subgroups_device_features_key(params, options) and adds the
// subgroup-shuffle requirement on top of it.
DeviceFeaturesKey GemmKernelMMADint8::get_required_device_features_key(const Params& params, const optional_params& options) const {
|
||||
auto k = get_common_subgroups_device_features_key(params, options);
|
||||
k.requires_subgroup_shuffle();
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
JitConstants GemmKernelMMADint8::GetJitConstants(const gemm_params& params) const {
|
||||
JitConstants jit = Parent::GetJitConstants(params);
|
||||
GemmTuningData td = SetTuningParams(params);
|
||||
|
||||
@@ -41,5 +41,6 @@ protected:
|
||||
GemmTuningData SetTuningParams(const gemm_params& params) const;
|
||||
size_t GetMmadOperationsNumber(const GemmTuningData& tuning_data) const;
|
||||
bool HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const;
|
||||
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
|
||||
@@ -32,6 +32,13 @@ ParamsKey GemmKernelMMADslmInt8::GetSupportedKey() const {
|
||||
return k;
|
||||
}
|
||||
|
||||
// Builds the device-features key for the MMAD SLM int8 GEMM kernel: takes the key
// from get_common_subgroups_device_features_key(params, options) and adds the
// subgroup-shuffle requirement on top of it.
DeviceFeaturesKey GemmKernelMMADslmInt8::get_required_device_features_key(const Params& params, const optional_params& options) const {
|
||||
auto k = get_common_subgroups_device_features_key(params, options);
|
||||
k.requires_subgroup_shuffle();
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
JitConstants GemmKernelMMADslmInt8::GetJitConstants(const gemm_params& params) const {
|
||||
JitConstants jit = Parent::GetJitConstants(params);
|
||||
GemmTuningData td = SetTuningParams(params);
|
||||
|
||||
@@ -43,5 +43,6 @@ protected:
|
||||
GemmTuningData SetTuningParams(const gemm_params& params) const;
|
||||
size_t GetMmadOperationsNumber(const GemmTuningData& tuning_data) const;
|
||||
bool HasLeftovers(const GemmTuningData& tuning_data) const;
|
||||
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
|
||||
@@ -31,6 +31,10 @@ ParamsKey GemmKernelRef::GetSupportedKey() const {
|
||||
return k;
|
||||
}
|
||||
|
||||
// Returns a default-constructed DeviceFeaturesKey; neither `params` nor `options`
// is consulted. NOTE(review): presumably an empty key means the reference kernel
// needs no optional device features - confirm against DeviceFeaturesKey.
DeviceFeaturesKey GemmKernelRef::get_required_device_features_key(const Params& params, const optional_params& options) const {
|
||||
return DeviceFeaturesKey();
|
||||
}
|
||||
|
||||
JitConstants GemmKernelRef::GetJitConstants(const gemm_params& params) const {
|
||||
JitConstants jit = Parent::GetJitConstants(params);
|
||||
|
||||
|
||||
@@ -25,5 +25,6 @@ protected:
|
||||
}
|
||||
bool Validate(const Params& params, const optional_params& options) const override;
|
||||
JitConstants GetJitConstants(const gemm_params& params) const override;
|
||||
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
|
||||
@@ -30,6 +30,13 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
|
||||
return k;
|
||||
}
|
||||
|
||||
// Builds the device-features key for the tiled-opt GEMM kernel: takes the key from
// get_common_subgroups_device_features_key(params, options) and adds the
// subgroup-shuffle requirement on top of it.
DeviceFeaturesKey GemmKernelTiledOpt::get_required_device_features_key(const Params& params, const optional_params& options) const {
|
||||
auto k = get_common_subgroups_device_features_key(params, options);
|
||||
k.requires_subgroup_shuffle();
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
GemmKernelBase::DispatchData GemmKernelTiledOpt::SetDefault(const gemm_params& params) const {
|
||||
const auto& output = params.outputs[0];
|
||||
|
||||
|
||||
@@ -35,5 +35,6 @@ protected:
|
||||
DispatchData SetDefault(const gemm_params& params) const override;
|
||||
JitConstants GetJitConstants(const gemm_params& params) const override;
|
||||
GemmTuningData SetTuningParams(const gemm_params& params) const;
|
||||
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
|
||||
@@ -29,6 +29,15 @@ ParamsKey SoftmaxKernel_bf::GetSupportedKey() const {
|
||||
return k;
|
||||
}
|
||||
|
||||
// Builds the device-features key for this softmax kernel from a default-constructed
// key, marking three requirements on it: subgroups, subgroup reduce, and
// reqd_subgroup_size. Neither `params` nor `options` is consulted.
DeviceFeaturesKey SoftmaxKernel_bf::get_required_device_features_key(const Params& params, const optional_params& options) const {
|
||||
DeviceFeaturesKey k;
|
||||
k.requires_subgroups();
|
||||
|
||||
k.requires_subgroup_reduce();
|
||||
|
||||
k.requires_reqd_subgroup_size();
|
||||
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
SoftmaxKernel_bf::Parent::DispatchData SoftmaxKernel_bf::SetDefault(const softmax_params& params) const {
|
||||
auto dispatchData = Parent::SetDefault(params);
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ public:
|
||||
protected:
|
||||
DispatchData SetDefault(const softmax_params& params) const override;
|
||||
JitConstants GetJitConstants(const softmax_params& params, DispatchData dispatchData) const override;
|
||||
DeviceFeaturesKey get_required_device_features_key(const Params& params, const optional_params& /*options*/) const override;
|
||||
std::vector<KernelBase::FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::QUANTIZE };
|
||||
}
|
||||
|
||||
@@ -166,6 +166,16 @@ device_info init_device_info(const cl::Device& device) {
|
||||
|
||||
info.max_work_group_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>());
|
||||
|
||||
// For some reason the NVIDIA runtime throws an exception (CL_INVALID_KERNEL_ARGS) for work-group (WG) sizes such as:
|
||||
// global: < 1 x 32 x 5184 >
|
||||
// local: < 1 x 1 x 576 >
|
||||
// While local < 1 x 1 x 36 > works fine
|
||||
// So below we limit the max WG size to 64, a value selected based on a few experiments.
|
||||
constexpr int nvidia_vendor_id = 0x10DE;
|
||||
if (info.vendor_id == nvidia_vendor_id) {
|
||||
info.max_work_group_size = 64;
|
||||
}
|
||||
|
||||
info.max_local_mem_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_LOCAL_MEM_SIZE>());
|
||||
info.max_global_mem_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_GLOBAL_MEM_SIZE>());
|
||||
info.max_alloc_mem_size = static_cast<uint64_t>(device.getInfo<CL_DEVICE_MAX_MEM_ALLOC_SIZE>());
|
||||
|
||||
Reference in New Issue
Block a user