[GPU] Try to use softmax_ref when types are mismatched (#19209)

* Remove support key for UINT8 and INT8
This commit is contained in:
Steve Yoo 2023-09-01 08:39:36 +09:00 committed by GitHub
parent f617cc338e
commit 05a24b1776
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 3 deletions

View File

@@ -14,8 +14,6 @@ ParamsKey SoftmaxKernel_bf::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F32); k.EnableInputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableOutputDataType(Datatype::INT8);
k.EnableInputLayout(DataLayout::bfyx); k.EnableInputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bf); k.EnableInputLayout(DataLayout::bf);
k.EnableOutputLayout(DataLayout::bfyx); k.EnableOutputLayout(DataLayout::bfyx);

View File

@@ -35,7 +35,37 @@ inline static size_t GetItemClassCount(const DataTensor& input, SoftmaxDim dim)
return item_class_count; return item_class_count;
} }
// Builds the ParamsKey describing what this kernel accepts: F16/F32 input and
// output data types, the input/output layouts and softmax dimensions listed
// below, plus the extra capability flags set at the end.
// NOTE(review): the one-line version shown first on the next line (the removed
// side of the diff) returned GetDefaultSupportedKey(); what that default enabled
// is not visible here. Per the commit title, the intent appears to be that
// unsupported type combinations end up on softmax_ref instead -- confirm.
ParamsKey SoftmaxKerneItemsClassOptimized::GetSupportedKey() const { return GetDefaultSupportedKey(); } ParamsKey SoftmaxKerneItemsClassOptimized::GetSupportedKey() const {
ParamsKey k;
// Data types: only F16 and F32 are enabled, for both input and output.
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
// Input layouts.
k.EnableInputLayout(DataLayout::byxf);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::yxfb);
k.EnableInputLayout(DataLayout::bf);
k.EnableInputLayout(DataLayout::fb);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::f);
// Output layouts: the same seven layouts as enabled for input above.
k.EnableOutputLayout(DataLayout::f);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::byxf);
k.EnableOutputLayout(DataLayout::yxfb);
k.EnableOutputLayout(DataLayout::bf);
k.EnableOutputLayout(DataLayout::fb);
k.EnableOutputLayout(DataLayout::bfzyx);
// Softmax dimensions enabled: X, Y, Z, FEATURE and BATCH.
k.EnableSoftmaxDim(SoftmaxDim::X);
k.EnableSoftmaxDim(SoftmaxDim::Y);
k.EnableSoftmaxDim(SoftmaxDim::Z);
k.EnableSoftmaxDim(SoftmaxDim::FEATURE);
k.EnableSoftmaxDim(SoftmaxDim::BATCH);
// Additional capability flags.
// NOTE(review): EnableDifferentTypes presumably permits input and output data
// types to differ from each other -- confirm against the ParamsKey definition.
k.EnableDifferentTypes();
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
return k;
}
DeviceFeaturesKey SoftmaxKerneItemsClassOptimized::get_required_device_features_key(const Params& params, const optional_params& options) const { DeviceFeaturesKey SoftmaxKerneItemsClassOptimized::get_required_device_features_key(const Params& params, const optional_params& options) const {
DeviceFeaturesKey k; DeviceFeaturesKey k;

View File

@@ -67,9 +67,11 @@ public:
/* ----------------------------------------------------------------------------------------------------- */ /* ----------------------------------------------------------------------------------------------------- */
#define CASE_SOFTMAX_FP32_1 {1, 15, 4, 5}, data_types::f32, format::bfyx, 1, data_types::f32, format::bfyx #define CASE_SOFTMAX_FP32_1 {1, 15, 4, 5}, data_types::f32, format::bfyx, 1, data_types::f32, format::bfyx
#define CASE_SOFTMAX_FP32_2 {1, 15, 4, 5}, data_types::f32, format::bfyx, 3, data_types::f32, format::bfyx #define CASE_SOFTMAX_FP32_2 {1, 15, 4, 5}, data_types::f32, format::bfyx, 3, data_types::f32, format::bfyx
#define CASE_SOFTMAX_FP32_3 {1, 15, 1, 1}, data_types::f32, format::bfyx, 1, data_types::f32, format::bfyx
#define CASE_SOFTMAX_FP16_1 {1, 15, 4, 5}, data_types::f16, format::bfyx, 1, data_types::f16, format::bfyx #define CASE_SOFTMAX_FP16_1 {1, 15, 4, 5}, data_types::f16, format::bfyx, 1, data_types::f16, format::bfyx
#define CASE_SOFTMAX_FP16_2 {1, 15, 4, 5}, data_types::f16, format::bfyx, 3, data_types::f16, format::bfyx #define CASE_SOFTMAX_FP16_2 {1, 15, 4, 5}, data_types::f16, format::bfyx, 3, data_types::f16, format::bfyx
#define CASE_SOFTMAX_FP16_3 {1, 15, 1, 1}, data_types::f16, format::bfyx, 1, data_types::f16, format::bfyx
class softmax_quantize : public SoftmaxPrimitiveFusingTest {}; class softmax_quantize : public SoftmaxPrimitiveFusingTest {};
TEST_P(softmax_quantize, basic) { TEST_P(softmax_quantize, basic) {
@@ -93,9 +95,11 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, softmax_quantize,
::testing::ValuesIn(std::vector<softmax_test_params>{ ::testing::ValuesIn(std::vector<softmax_test_params>{
softmax_test_params{ CASE_SOFTMAX_FP32_1, 2, 3 }, softmax_test_params{ CASE_SOFTMAX_FP32_1, 2, 3 },
softmax_test_params{ CASE_SOFTMAX_FP32_2, 3, 3 }, softmax_test_params{ CASE_SOFTMAX_FP32_2, 3, 3 },
softmax_test_params{ CASE_SOFTMAX_FP32_3, 2, 3 },
softmax_test_params{ CASE_SOFTMAX_FP16_1, 2, 3 }, softmax_test_params{ CASE_SOFTMAX_FP16_1, 2, 3 },
softmax_test_params{ CASE_SOFTMAX_FP16_2, 3, 3 }, softmax_test_params{ CASE_SOFTMAX_FP16_2, 3, 3 },
softmax_test_params{ CASE_SOFTMAX_FP16_3, 2, 3 },
})); }));
class softmax_quantize_fusing_through : public SoftmaxPrimitiveFusingTest {}; class softmax_quantize_fusing_through : public SoftmaxPrimitiveFusingTest {};