[GPU] Try to use softmax_ref when types are mismatched (#19209)
* Remove support key for UINT8 and INT8
This commit is contained in:
parent
f617cc338e
commit
05a24b1776
@ -14,8 +14,6 @@ ParamsKey SoftmaxKernel_bf::GetSupportedKey() const {
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableInputLayout(DataLayout::bf);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
|
@ -35,7 +35,37 @@ inline static size_t GetItemClassCount(const DataTensor& input, SoftmaxDim dim)
|
||||
return item_class_count;
|
||||
}
|
||||
|
||||
ParamsKey SoftmaxKerneItemsClassOptimized::GetSupportedKey() const { return GetDefaultSupportedKey(); }
|
||||
ParamsKey SoftmaxKerneItemsClassOptimized::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableInputLayout(DataLayout::byxf);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableInputLayout(DataLayout::yxfb);
|
||||
k.EnableInputLayout(DataLayout::bf);
|
||||
k.EnableInputLayout(DataLayout::fb);
|
||||
k.EnableInputLayout(DataLayout::bfzyx);
|
||||
k.EnableInputLayout(DataLayout::f);
|
||||
k.EnableOutputLayout(DataLayout::f);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::byxf);
|
||||
k.EnableOutputLayout(DataLayout::yxfb);
|
||||
k.EnableOutputLayout(DataLayout::bf);
|
||||
k.EnableOutputLayout(DataLayout::fb);
|
||||
k.EnableOutputLayout(DataLayout::bfzyx);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::X);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::Y);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::Z);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::FEATURE);
|
||||
k.EnableSoftmaxDim(SoftmaxDim::BATCH);
|
||||
k.EnableDifferentTypes();
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
return k;
|
||||
}
|
||||
|
||||
DeviceFeaturesKey SoftmaxKerneItemsClassOptimized::get_required_device_features_key(const Params& params, const optional_params& options) const {
|
||||
DeviceFeaturesKey k;
|
||||
|
@ -67,9 +67,11 @@ public:
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
#define CASE_SOFTMAX_FP32_1 {1, 15, 4, 5}, data_types::f32, format::bfyx, 1, data_types::f32, format::bfyx
|
||||
#define CASE_SOFTMAX_FP32_2 {1, 15, 4, 5}, data_types::f32, format::bfyx, 3, data_types::f32, format::bfyx
|
||||
#define CASE_SOFTMAX_FP32_3 {1, 15, 1, 1}, data_types::f32, format::bfyx, 1, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_SOFTMAX_FP16_1 {1, 15, 4, 5}, data_types::f16, format::bfyx, 1, data_types::f16, format::bfyx
|
||||
#define CASE_SOFTMAX_FP16_2 {1, 15, 4, 5}, data_types::f16, format::bfyx, 3, data_types::f16, format::bfyx
|
||||
#define CASE_SOFTMAX_FP16_3 {1, 15, 1, 1}, data_types::f16, format::bfyx, 1, data_types::f16, format::bfyx
|
||||
|
||||
class softmax_quantize : public SoftmaxPrimitiveFusingTest {};
|
||||
TEST_P(softmax_quantize, basic) {
|
||||
@ -93,9 +95,11 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, softmax_quantize,
|
||||
::testing::ValuesIn(std::vector<softmax_test_params>{
|
||||
softmax_test_params{ CASE_SOFTMAX_FP32_1, 2, 3 },
|
||||
softmax_test_params{ CASE_SOFTMAX_FP32_2, 3, 3 },
|
||||
softmax_test_params{ CASE_SOFTMAX_FP32_3, 2, 3 },
|
||||
|
||||
softmax_test_params{ CASE_SOFTMAX_FP16_1, 2, 3 },
|
||||
softmax_test_params{ CASE_SOFTMAX_FP16_2, 3, 3 },
|
||||
softmax_test_params{ CASE_SOFTMAX_FP16_3, 2, 3 },
|
||||
}));
|
||||
|
||||
class softmax_quantize_fusing_through : public SoftmaxPrimitiveFusingTest {};
|
||||
|
Loading…
Reference in New Issue
Block a user