[GPU] Prevent fusing for eltwise which is not broadcastable from input to fused output (#20974)
This commit is contained in:
parent
c608771e03
commit
4c40716e95
@ -131,6 +131,14 @@ void prepare_primitive_fusing_through::run(program& p) {
|
||||
auto new_prev = fuse_through_order[fuse_through_order.size() - 1];
|
||||
auto new_next = fuse_through_order[fuse_through_order.size() - 2];
|
||||
|
||||
// Check broadcastable for fused eltwise's output
|
||||
if (node->is_type<eltwise>()) {
|
||||
auto out_shape = new_prev->get_output_layout().get_partial_shape(); // new_prev's layout became node's new layout after fusing
|
||||
auto in_shape = node->get_dependency(1).get_output_layout().get_partial_shape();
|
||||
if (!broadcastable(in_shape, out_shape, true, true))
|
||||
continue;
|
||||
}
|
||||
|
||||
if (new_prev->is_type<input_layout>() ||
|
||||
new_prev->is_type<mutable_data>() ||
|
||||
new_prev->is_type<quantize>())
|
||||
|
@ -236,17 +236,24 @@ inline ov::PartialShape extend_shape_to_rank_from_begin(ov::PartialShape pshape,
|
||||
return extended_pshape;
|
||||
}
|
||||
|
||||
inline bool broadcastable(const ov::PartialShape& first_pshape, const ov::PartialShape& second_pshape, bool use_new_shape_infer) {
|
||||
inline bool broadcastable(const ov::PartialShape& first_pshape, const ov::PartialShape& second_pshape, bool use_new_shape_infer,
|
||||
bool first_to_second_only = false) {
|
||||
if (first_pshape.is_dynamic() || second_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
if (first_pshape.size() != second_pshape.size() && use_new_shape_infer) {
|
||||
return false;
|
||||
if (first_to_second_only) {
|
||||
if (first_pshape.size() > second_pshape.size()) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (first_pshape.size() != second_pshape.size() && use_new_shape_infer) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
size_t min_size = std::min(first_pshape.size(), second_pshape.size());
|
||||
|
||||
for (size_t i = 0; i < min_size; ++i) {
|
||||
if (!(first_pshape[i] == 1 || second_pshape[i] == 1 || first_pshape[i] == second_pshape[i])) {
|
||||
if (!(first_pshape[i] == 1 || (!first_to_second_only && second_pshape[i] == 1) || first_pshape[i] == second_pshape[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -100,6 +100,7 @@ public:
|
||||
|
||||
#define CASE_ELTWISE_FP32_5 { 1, 5, 4, 4 }, data_types::f32, data_types::f32, format::b_fs_yx_fsv4, data_types::f32, format::b_fs_yx_fsv4, eltwise_mode::sum
|
||||
#define CASE_ELTWISE_FP32_6 { 2, 32, 4, 8 }, data_types::f32, data_types::f32, format::b_fs_yx_fsv4, data_types::f32, format::b_fs_yx_fsv4, eltwise_mode::sum
|
||||
#define CASE_ELTWISE_FP32_7 { 1, 8, 16, 1 }, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfwzyx, eltwise_mode::sum
|
||||
#define CASE_ELTWISE_FP16_5 { 2, 32, 4, 8 }, data_types::f16, data_types::f16, format::b_fs_yx_fsv4, data_types::f16, format::b_fs_yx_fsv4, eltwise_mode::sum
|
||||
#define CASE_ELTWISE_FP16_6 { 1, 32, 4, 8 }, data_types::f16, data_types::f16, format::byxf, data_types::f16, format::byxf, eltwise_mode::sum
|
||||
#define CASE_ELTWISE_I8_4 { 2, 16, 4, 4 }, data_types::i8, data_types::i8, format::b_fs_yx_fsv4, data_types::f32, format::b_fs_yx_fsv4, eltwise_mode::sum
|
||||
@ -449,6 +450,28 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, eltwise_fp32_fused_prims, ::testing::Value
|
||||
eltwise_test_params{ CASE_ELTWISE_U8_4, 3, 5 },
|
||||
}));
|
||||
|
||||
class eltwise_reorder_eltwise_fp32_fused_prims : public EltwiseFusingTest {};
|
||||
TEST_P(eltwise_reorder_eltwise_fp32_fused_prims, eltwise_activation) {
|
||||
auto p = GetParam();
|
||||
create_topologies(
|
||||
data("const", get_mem(layout{ {p.input_size[2]}, p.input_type, p.input_format }, -10, 10)), // 1d const
|
||||
data("const2", get_mem(layout{ {1, 1, 1, 1, 1, 1}, p.input_type, p.default_format }, -10, 10)), // 6d const
|
||||
input_layout("input", get_input_layout(p)),
|
||||
eltwise("eltwise1", { input_info("input"), input_info("const") }, p.mode, p.input_type),
|
||||
reorder("reorder6d", input_info("eltwise1"), layout{ {p.input_size[0], p.input_size[1], 1, 1, p.input_size[2], p.input_size[3]}, p.input_type, p.default_format }),
|
||||
eltwise("eltwise2", { input_info("reorder6d"), input_info("const2") }, eltwise_mode::prod, p.default_type),
|
||||
activation("activation", input_info("eltwise2"), activation_func::abs),
|
||||
reorder("out", input_info("activation"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.input_type);
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, eltwise_reorder_eltwise_fp32_fused_prims, ::testing::ValuesIn(std::vector<eltwise_test_params>{
|
||||
eltwise_test_params{ CASE_ELTWISE_FP32_7, 3, 4 },
|
||||
}));
|
||||
|
||||
class eltwise_fp32_scale : public EltwiseFusingTest {};
|
||||
TEST_P(eltwise_fp32_scale, 6d) {
|
||||
auto p = GetParam();
|
||||
|
Loading…
Reference in New Issue
Block a user