[GPU] Fix for a corner case for broadcast with dynamic input and static output (#14451)
This commit is contained in:
parent
a21da85eb9
commit
a47688e593
@ -87,7 +87,7 @@ KernelsData BroadcastKernelBase::GetCommonKernelsData(const Params& params,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
prim_params.outputs[0].is_dynamic());
|
||||
prim_params.inputs[0].is_dynamic() || prim_params.outputs[0].is_dynamic());
|
||||
|
||||
return {k_data};
|
||||
}
|
||||
|
@ -87,7 +87,8 @@ void start_broadcast_test_dynamic(format input_format,
|
||||
data_types input_data_type,
|
||||
ov::Shape output_shape,
|
||||
ov::Shape input_data_shape,
|
||||
ov::AxisSet broadcast_axes) {
|
||||
ov::AxisSet broadcast_axes,
|
||||
bool is_output_static = false) {
|
||||
size_t input_data_size = accumulate(input_data_shape.rbegin(), input_data_shape.rend(), (size_t)1, std::multiplies<size_t>());
|
||||
EXPECT_GE(input_data_size, (size_t)1);
|
||||
std::vector<T> input_data = {};
|
||||
@ -111,26 +112,41 @@ void start_broadcast_test_dynamic(format input_format,
|
||||
auto& engine = get_test_engine();
|
||||
auto input = engine.allocate_memory({ov::PartialShape(input_data_shape), input_data_type, fmt});
|
||||
|
||||
topology topology;
|
||||
memory::ptr target_shape_mem = nullptr;
|
||||
if (is_output_static) {
|
||||
auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt);
|
||||
topology.add(input_layout("input", in_layout));
|
||||
topology.add(reorder("reorder", input_info("input"), input_format, input_data_type));
|
||||
topology.add(broadcast("broadcast",
|
||||
input_info("reorder"),
|
||||
output_shape,
|
||||
ov::AxisSet(broadcast_axes)));
|
||||
topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type));
|
||||
} else {
|
||||
auto in_layout = layout(ov::PartialShape::dynamic(input_rank), input_data_type, fmt);
|
||||
auto target_shape_layout = layout(ov::PartialShape{input_rank}, data_types::i32, fmt);
|
||||
auto target_shape_mem = engine.allocate_memory(target_shape_layout);
|
||||
topology topology;
|
||||
target_shape_mem = engine.allocate_memory(target_shape_layout);
|
||||
topology.add(input_layout("input", in_layout));
|
||||
topology.add(input_layout("target_shape", target_shape_layout));
|
||||
topology.add(reorder("reorder", input_info("input"), input_format, input_data_type));
|
||||
topology.add(broadcast("broadcast", input_info("reorder"), input_info("target_shape"), ov::AxisSet(broadcast_axes)));
|
||||
topology.add(
|
||||
broadcast("broadcast", input_info("reorder"), input_info("target_shape"), ov::AxisSet(broadcast_axes)));
|
||||
topology.add(reorder("output", input_info("broadcast"), fmt, input_data_type));
|
||||
std::vector<int32_t> target_shape_data(output_shape.begin(), output_shape.end());
|
||||
set_values<int32_t>(target_shape_mem, target_shape_data);
|
||||
}
|
||||
|
||||
build_options bo;
|
||||
bo.set_option(build_option::allow_new_shape_infer(true));
|
||||
|
||||
set_values(input, input_data);
|
||||
std::vector<int32_t> target_shape_data(output_shape.begin(), output_shape.end());
|
||||
set_values<int32_t>(target_shape_mem, target_shape_data);
|
||||
|
||||
network network(engine, topology, bo);
|
||||
network.set_input_data("input", input);
|
||||
if (!is_output_static) {
|
||||
network.set_input_data("target_shape", target_shape_mem);
|
||||
}
|
||||
|
||||
auto inst = network.get_primitive("broadcast");
|
||||
auto impl = inst->get_impl();
|
||||
@ -268,14 +284,26 @@ TEST(broadcast_gpu_float, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) {
|
||||
start_broadcast_test_dynamic<float>(format::bfyx, data_types::f32, {4, 5}, {1, 1}, {0, 1});
|
||||
}
|
||||
|
||||
TEST(broadcast_gpu_float, bfyx_1_to_4x5_w_b_axes_0x1_dynamic_with_static_output) {
|
||||
start_broadcast_test_dynamic<float>(format::bfyx, data_types::f32, {4, 5}, {1, 1}, {0, 1}, true);
|
||||
}
|
||||
|
||||
TEST(broadcast_gpu_uint8_t, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) {
|
||||
start_broadcast_test_dynamic<uint8_t>(format::bfyx, data_types::u8, {4, 5}, {1, 1}, {0, 1});
|
||||
}
|
||||
|
||||
TEST(broadcast_gpu_uint8_t, bfyx_1_to_4x5_w_b_axes_0x1x2_dynamic_with_static_output) {
|
||||
start_broadcast_test_dynamic<uint8_t>(format::bfyx, data_types::u8, {4, 5, 2}, {1, 1, 1}, {0, 1, 2}, true);
|
||||
}
|
||||
|
||||
TEST(broadcast_gpu_int64_t, bfyx_1_to_4x5_w_b_axes_0x1_dynamic) {
|
||||
start_broadcast_test_dynamic<int64_t>(format::bfyx, data_types::i64, {4, 5}, {1, 1}, {0, 1});
|
||||
}
|
||||
|
||||
TEST(broadcast_gpu_int64_t, bfyx_1_to_4x5_w_b_axes_0x1x2x3_dynamic_with_static_output) {
|
||||
start_broadcast_test_dynamic<int64_t>(format::bfyx, data_types::i64, {4, 5, 2, 3}, {1, 1, 1, 1}, {0, 1, 2, 3});
|
||||
}
|
||||
|
||||
|
||||
/* Expected golden_data = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
|
Loading…
Reference in New Issue
Block a user