[GPU] Fix for a corner case for broadcast with dynamic input and static output (#14451)

This commit is contained in:
Taylor Yeonbok Lee 2022-12-06 21:13:15 -08:00 committed by GitHub
parent a21da85eb9
commit a47688e593
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 14 deletions

View File

@ -87,7 +87,7 @@ KernelsData BroadcastKernelBase::GetCommonKernelsData(const Params& params,
1,
0,
1,
prim_params.outputs[0].is_dynamic());
prim_params.inputs[0].is_dynamic() || prim_params.outputs[0].is_dynamic());
return {k_data};
}

View File

@ -87,7 +87,8 @@ void start_broadcast_test_dynamic(format input_format,
data_types input_data_type,
ov::Shape output_shape,
ov::Shape input_data_shape,
ov::AxisSet broadcast_axes) {
ov::AxisSet broadcast_axes,
bool is_output_static = false) {
size_t input_data_size = accumulate(input_data_shape.rbegin(), input_data_shape.rend(), (size_t)1, std::multiplies<size_t>());
EXPECT_GE(input_data_size, (size_t)1);
std::vector<T> input_data = {};
@ -110,27 +111,42 @@ void start_broadcast_test_dynamic(format input_format,
auto& engine = get_test_engine();
auto input = engine.allocate_memory({ov::PartialShape(input_data_shape), input_data_type, fmt});
auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt);
auto target_shape_layout = layout(ov::PartialShape{input_rank}, data_types::i32, fmt);
auto target_shape_mem = engine.allocate_memory(target_shape_layout);
topology topology;
topology.add(input_layout("input", in_layout));
topology.add(input_layout("target_shape", target_shape_layout));
topology.add(reorder("reorder", input_info("input"), input_format, input_data_type));
topology.add(broadcast("broadcast", input_info("reorder"), input_info("target_shape"), ov::AxisSet(broadcast_axes)));
topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type));
memory::ptr target_shape_mem = nullptr;
if (is_output_static) {
auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt);
topology.add(input_layout("input", in_layout));
topology.add(reorder("reorder", input_info("input"), input_format, input_data_type));
topology.add(broadcast("broadcast",
input_info("reorder"),
output_shape,
ov::AxisSet(broadcast_axes)));
topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type));
} else {
auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt);
auto target_shape_layout = layout(ov::PartialShape{input_rank}, data_types::i32, fmt);
target_shape_mem = engine.allocate_memory(target_shape_layout);
topology.add(input_layout("input", in_layout));
topology.add(input_layout("target_shape", target_shape_layout));
topology.add(reorder("reorder", input_info("input"), input_format, input_data_type));
topology.add(
broadcast("broadcast", input_info("reorder"), input_info("target_shape"), ov::AxisSet(broadcast_axes)));
topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type));
std::vector<int32_t> target_shape_data(output_shape.begin(), output_shape.end());
set_values<int32_t>(target_shape_mem, target_shape_data);
}
build_options bo;
bo.set_option(build_option::allow_new_shape_infer(true));
set_values(input, input_data);
std::vector<int32_t> target_shape_data(output_shape.begin(), output_shape.end());
set_values<int32_t>(target_shape_mem, target_shape_data);
network network(engine, topology, bo);
network.set_input_data("input", input);
network.set_input_data("target_shape", target_shape_mem);
if (!is_output_static) {
network.set_input_data("target_shape", target_shape_mem);
}
auto inst = network.get_primitive("broadcast");
auto impl = inst->get_impl();
@ -268,14 +284,26 @@ TEST(broadcast_gpu_float, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) {
start_broadcast_test_dynamic<float>(format::bfyx, data_types::f32, {4, 5}, {1, 1}, {0, 1});
}
// f32 broadcast with a dynamic-rank input layout; the output shape is handed
// to the broadcast primitive directly (is_output_static = true) instead of
// being supplied through a "target_shape" network input.
TEST(broadcast_gpu_float, bfyx_1_to_4x5_w_b_axes_0x1_dynamic_with_static_output) {
    start_broadcast_test_dynamic<float>(format::bfyx,     // input_format
                                        data_types::f32,  // input_data_type
                                        {4, 5},           // output_shape
                                        {1, 1},           // input_data_shape
                                        {0, 1},           // broadcast_axes
                                        true);            // is_output_static
}
// u8 broadcast with a dynamic-rank input layout; is_output_static is left at
// its default (false), so the helper feeds the output shape through the
// "target_shape" network input.
TEST(broadcast_gpu_uint8_t, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) {
    start_broadcast_test_dynamic<uint8_t>(format::bfyx,    // input_format
                                          data_types::u8,  // input_data_type
                                          {4, 5},          // output_shape
                                          {1, 1},          // input_data_shape
                                          {0, 1});         // broadcast_axes
}
// u8 broadcast over three axes with a dynamic-rank input layout; the output
// shape is handed to the broadcast primitive directly (is_output_static = true).
TEST(broadcast_gpu_uint8_t, bfyx_1_to_4x5_w_b_axes_0x1x2_dynamic_with_static_output) {
    start_broadcast_test_dynamic<uint8_t>(format::bfyx,    // input_format
                                          data_types::u8,  // input_data_type
                                          {4, 5, 2},       // output_shape
                                          {1, 1, 1},       // input_data_shape
                                          {0, 1, 2},       // broadcast_axes
                                          true);           // is_output_static
}
// i64 broadcast with a dynamic-rank input layout; is_output_static is left at
// its default (false), so the helper feeds the output shape through the
// "target_shape" network input.
TEST(broadcast_gpu_int64_t, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) {
    start_broadcast_test_dynamic<int64_t>(format::bfyx,     // input_format
                                          data_types::i64,  // input_data_type
                                          {4, 5},           // output_shape
                                          {1, 1},           // input_data_shape
                                          {0, 1});          // broadcast_axes
}
// i64 broadcast over four axes with a dynamic-rank input layout and a static
// output shape. is_output_static must be passed explicitly: it defaults to
// false, which would run the "target_shape" input path and leave the
// static-output case named by this test unexercised.
TEST(broadcast_gpu_int64_t, bfyx_1_to_4x5_w_b_axes_0x1x2x3_dynamic_with_static_output) {
    start_broadcast_test_dynamic<int64_t>(format::bfyx,     // input_format
                                          data_types::i64,  // input_data_type
                                          {4, 5, 2, 3},     // output_shape
                                          {1, 1, 1, 1},     // input_data_shape
                                          {0, 1, 2, 3},     // broadcast_axes
                                          true);            // is_output_static
}
/* Expected golden_data = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,