[GPU] Prevent fusing for eltwise which is not broadcastable from input to fused output (#20974)

This commit is contained in:
Kelvin Choi 2023-12-01 16:11:59 +09:00 committed by GitHub
parent c608771e03
commit 4c40716e95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 4 deletions

View File

@ -131,6 +131,14 @@ void prepare_primitive_fusing_through::run(program& p) {
auto new_prev = fuse_through_order[fuse_through_order.size() - 1];
auto new_next = fuse_through_order[fuse_through_order.size() - 2];
// Check broadcastable for fused eltwise's output
if (node->is_type<eltwise>()) {
auto out_shape = new_prev->get_output_layout().get_partial_shape(); // new_prev's layout became node's new layout after fusing
auto in_shape = node->get_dependency(1).get_output_layout().get_partial_shape();
if (!broadcastable(in_shape, out_shape, true, true))
continue;
}
if (new_prev->is_type<input_layout>() ||
new_prev->is_type<mutable_data>() ||
new_prev->is_type<quantize>())

View File

@ -236,17 +236,24 @@ inline ov::PartialShape extend_shape_to_rank_from_begin(ov::PartialShape pshape,
return extended_pshape;
}
inline bool broadcastable(const ov::PartialShape& first_pshape, const ov::PartialShape& second_pshape, bool use_new_shape_infer) {
inline bool broadcastable(const ov::PartialShape& first_pshape, const ov::PartialShape& second_pshape, bool use_new_shape_infer,
bool first_to_second_only = false) {
if (first_pshape.is_dynamic() || second_pshape.is_dynamic()) {
return false;
}
if (first_to_second_only) {
if (first_pshape.size() > second_pshape.size()) {
return false;
}
} else {
if (first_pshape.size() != second_pshape.size() && use_new_shape_infer) {
return false;
}
}
size_t min_size = std::min(first_pshape.size(), second_pshape.size());
for (size_t i = 0; i < min_size; ++i) {
if (!(first_pshape[i] == 1 || second_pshape[i] == 1 || first_pshape[i] == second_pshape[i])) {
if (!(first_pshape[i] == 1 || (!first_to_second_only && second_pshape[i] == 1) || first_pshape[i] == second_pshape[i])) {
return false;
}
}

View File

@ -100,6 +100,7 @@ public:
#define CASE_ELTWISE_FP32_5 { 1, 5, 4, 4 }, data_types::f32, data_types::f32, format::b_fs_yx_fsv4, data_types::f32, format::b_fs_yx_fsv4, eltwise_mode::sum
#define CASE_ELTWISE_FP32_6 { 2, 32, 4, 8 }, data_types::f32, data_types::f32, format::b_fs_yx_fsv4, data_types::f32, format::b_fs_yx_fsv4, eltwise_mode::sum
#define CASE_ELTWISE_FP32_7 { 1, 8, 16, 1 }, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfwzyx, eltwise_mode::sum
#define CASE_ELTWISE_FP16_5 { 2, 32, 4, 8 }, data_types::f16, data_types::f16, format::b_fs_yx_fsv4, data_types::f16, format::b_fs_yx_fsv4, eltwise_mode::sum
#define CASE_ELTWISE_FP16_6 { 1, 32, 4, 8 }, data_types::f16, data_types::f16, format::byxf, data_types::f16, format::byxf, eltwise_mode::sum
#define CASE_ELTWISE_I8_4 { 2, 16, 4, 4 }, data_types::i8, data_types::i8, format::b_fs_yx_fsv4, data_types::f32, format::b_fs_yx_fsv4, eltwise_mode::sum
@ -449,6 +450,28 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, eltwise_fp32_fused_prims, ::testing::Value
eltwise_test_params{ CASE_ELTWISE_U8_4, 3, 5 },
}));
class eltwise_reorder_eltwise_fp32_fused_prims : public EltwiseFusingTest {};
TEST_P(eltwise_reorder_eltwise_fp32_fused_prims, eltwise_activation) {
auto p = GetParam();
create_topologies(
data("const", get_mem(layout{ {p.input_size[2]}, p.input_type, p.input_format }, -10, 10)), // 1d const
data("const2", get_mem(layout{ {1, 1, 1, 1, 1, 1}, p.input_type, p.default_format }, -10, 10)), // 6d const
input_layout("input", get_input_layout(p)),
eltwise("eltwise1", { input_info("input"), input_info("const") }, p.mode, p.input_type),
reorder("reorder6d", input_info("eltwise1"), layout{ {p.input_size[0], p.input_size[1], 1, 1, p.input_size[2], p.input_size[3]}, p.input_type, p.default_format }),
eltwise("eltwise2", { input_info("reorder6d"), input_info("const2") }, eltwise_mode::prod, p.default_type),
activation("activation", input_info("eltwise2"), activation_func::abs),
reorder("out", input_info("activation"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(p.input_type);
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, eltwise_reorder_eltwise_fp32_fused_prims, ::testing::ValuesIn(std::vector<eltwise_test_params>{
eltwise_test_params{ CASE_ELTWISE_FP32_7, 3, 4 },
}));
class eltwise_fp32_scale : public EltwiseFusingTest {};
TEST_P(eltwise_fp32_scale, 6d) {
auto p = GetParam();