[GPU] Allow softmax_bf kernel for axis=X 4d case (#20699)

This commit is contained in:
Roman Lyamin 2023-10-30 09:11:32 +04:00 committed by GitHub
parent 53c9a0f3d4
commit 5c6b7a5ed4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 3 deletions

View File

@ -92,9 +92,9 @@ bool SoftmaxKernelBaseBF::Validate(const Params& p, const optional_params& o) co
switch (params.dim) {
case SoftmaxDim::X:
return !input.Y().is_dynamic && input.Y().v == 1 &&
return ((!input.Y().is_dynamic && input.Y().v == 1) || input.GetLayout() == DataLayout::bfyx) &&
!input.Z().is_dynamic && input.Z().v == 1 &&
!input.Feature().is_dynamic && input.Feature().v == 1;
((!input.Feature().is_dynamic && input.Feature().v == 1) || input.GetLayout() == DataLayout::bfyx);
case SoftmaxDim::Y:
return !input.X().is_dynamic && input.X().v == 1 &&
!input.Z().is_dynamic && input.Z().v == 1 &&
@ -122,6 +122,10 @@ SoftmaxKernelBase::DispatchData SoftmaxKernelBaseBF::SetDefault(const softmax_pa
OPENVINO_ASSERT(input.X().v == 1, "[GPU] SoftmaxKernelBaseBF: input.X() is expected to be 1 while actual value is ", input.X().v);
dispatchData.dataSetSize = input.Y().v;
dispatchData.dataSetsCount = input.Batch().v * input.Feature().v;
} else if (params.dim == SoftmaxDim::X && (input.Feature().v > 1 || input.Y().v > 1) && input.GetLayout() == DataLayout::bfyx) {
// Flatten BFY for such case
dispatchData.dataSetSize = input.X().v;
dispatchData.dataSetsCount = input.Batch().v * input.Feature().v * input.Y().v;
} else {
auto flatten_input = input.FlattenFeatureAndSpatials();
dispatchData.dataSetSize = flatten_input.Feature().v;

View File

@ -88,13 +88,18 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(ov::AnyMap())),
SoftMax8LayerTest::getTestCaseName);
const std::vector<ov::Shape> stableDiffusionShapes = {
{16, 4096, 4096},
{2, 8, 4096, 4096}
};
INSTANTIATE_TEST_SUITE_P(
smoke_SoftMaxStableDiffusion,
SoftMax8LayerTest,
testing::Combine(testing::ValuesIn(netPrecisions),
::testing::Values(ov::element::undefined),
::testing::Values(ov::element::undefined),
::testing::ValuesIn(ov::test::static_shapes_to_test_representation({{16, 4096, 4096}})),
::testing::ValuesIn(ov::test::static_shapes_to_test_representation(stableDiffusionShapes)),
testing::Values(-1),
testing::Values(ov::test::utils::DEVICE_GPU),
testing::Values(ov::AnyMap())),

View File

@ -151,6 +151,8 @@ TEST(softmax_gpu_dynamic_f32_test_upper_bound, input_same_values) {
format::bfyx);
auto config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
ov::intel_gpu::ImplementationDesc softmax_impl = { format::bfyx, "softmax_gpu_ref" };
config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "softmax", softmax_impl } }));
network network(engine, topology(input_layout("input", in_layout),
reorder("reorder", input_info("input"), format::bfyx, data_types::f16),
softmax("softmax", input_info("reorder"), 3),