diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/broadcast/broadcast_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/broadcast/broadcast_kernel_base.cpp index 9f79c3a9477..a43bb855297 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/broadcast/broadcast_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/broadcast/broadcast_kernel_base.cpp @@ -87,7 +87,7 @@ KernelsData BroadcastKernelBase::GetCommonKernelsData(const Params& params, 1, 0, 1, - prim_params.outputs[0].is_dynamic()); + prim_params.inputs[0].is_dynamic() || prim_params.outputs[0].is_dynamic()); return {k_data}; } diff --git a/src/plugins/intel_gpu/tests/test_cases/broadcast_gpu_test.cpp b/src/plugins/intel_gpu/tests/test_cases/broadcast_gpu_test.cpp index 3b51ffdb3b2..fc48cc57f4a 100644 --- a/src/plugins/intel_gpu/tests/test_cases/broadcast_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/test_cases/broadcast_gpu_test.cpp @@ -87,7 +87,8 @@ void start_broadcast_test_dynamic(format input_format, data_types input_data_type, ov::Shape output_shape, ov::Shape input_data_shape, - ov::AxisSet broadcast_axes) { + ov::AxisSet broadcast_axes, + bool is_output_static = false) { size_t input_data_size = accumulate(input_data_shape.rbegin(), input_data_shape.rend(), (size_t)1, std::multiplies()); EXPECT_GE(input_data_size, (size_t)1); std::vector input_data = {}; @@ -110,27 +111,42 @@ void start_broadcast_test_dynamic(format input_format, auto& engine = get_test_engine(); auto input = engine.allocate_memory({ov::PartialShape(input_data_shape), input_data_type, fmt}); - - auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt); - auto target_shape_layout = layout(ov::PartialShape{input_rank}, data_types::i32, fmt); - auto target_shape_mem = engine.allocate_memory(target_shape_layout); + topology topology; - topology.add(input_layout("input", in_layout)); - topology.add(input_layout("target_shape", target_shape_layout)); - topology.add(reorder("reorder", input_info("input"), input_format, input_data_type)); - topology.add(broadcast("broadcast", input_info("reorder"), input_info("target_shape"), ov::AxisSet(broadcast_axes))); - topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type)); + memory::ptr target_shape_mem = nullptr; + if (is_output_static) { + auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt); + topology.add(input_layout("input", in_layout)); + topology.add(reorder("reorder", input_info("input"), input_format, input_data_type)); + topology.add(broadcast("broadcast", + input_info("reorder"), + output_shape, + ov::AxisSet(broadcast_axes))); + topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type)); + } else { + auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt); + auto target_shape_layout = layout(ov::PartialShape{input_rank}, data_types::i32, fmt); + target_shape_mem = engine.allocate_memory(target_shape_layout); + topology.add(input_layout("input", in_layout)); + topology.add(input_layout("target_shape", target_shape_layout)); + topology.add(reorder("reorder", input_info("input"), input_format, input_data_type)); + topology.add( + broadcast("broadcast", input_info("reorder"), input_info("target_shape"), ov::AxisSet(broadcast_axes))); + topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type)); + std::vector target_shape_data(output_shape.begin(), output_shape.end()); + set_values(target_shape_mem, target_shape_data); + } build_options bo; bo.set_option(build_option::allow_new_shape_infer(true)); set_values(input, input_data); - std::vector target_shape_data(output_shape.begin(), output_shape.end()); - set_values(target_shape_mem, target_shape_data); network network(engine, topology, bo); network.set_input_data("input", input); - network.set_input_data("target_shape", target_shape_mem); + if (!is_output_static) { + network.set_input_data("target_shape", target_shape_mem); + } auto inst = network.get_primitive("broadcast"); auto impl = inst->get_impl(); @@ -268,14 +284,26 @@ TEST(broadcast_gpu_float, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) { start_broadcast_test_dynamic(format::bfyx, data_types::f32, {4, 5}, {1, 1}, {0, 1}); } +TEST(broadcast_gpu_float, bfyx_1_to_4x5_w_b_axes_0x1_dynamic_with_static_output) { + start_broadcast_test_dynamic(format::bfyx, data_types::f32, {4, 5}, {1, 1}, {0, 1}, true); +} + TEST(broadcast_gpu_uint8_t, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) { start_broadcast_test_dynamic(format::bfyx, data_types::u8, {4, 5}, {1, 1}, {0, 1}); } +TEST(broadcast_gpu_uint8_t, bfyx_1_to_4x5_w_b_axes_0x1x2_dynamic_with_static_output) { + start_broadcast_test_dynamic(format::bfyx, data_types::u8, {4, 5, 2}, {1, 1, 1}, {0, 1, 2}, true); +} + TEST(broadcast_gpu_int64_t, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) { start_broadcast_test_dynamic(format::bfyx, data_types::i64, {4, 5}, {1, 1}, {0, 1}); } +TEST(broadcast_gpu_int64_t, bfyx_1_to_4x5_w_b_axes_0x1x2x3_dynamic_with_static_output) { + start_broadcast_test_dynamic(format::bfyx, data_types::i64, {4, 5, 2, 3}, {1, 1, 1, 1}, {0, 1, 2, 3}); +} + /* Expected golden_data = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,