[GPU] fix pr18171 regression (#18272)
commit 504f1d8237
parent 88fa4b040e
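The diff touches three places. In `prepare_primitive_fusing::fuse_simple_primitives()`, the early-return that blocked fusing an eltwise into a oneDNN gemm whenever the fused output rank exceeds 4 is removed. In `program_node::init_onednn_primitive_attributes()`, the prelu post-op's `oc_dim` is now chosen per shape-infer mode: the partial-shape read is kept only when `allow_new_shape_infer` is set, and the legacy path goes back to reading the cldnn tensor. Finally, the unit test that pinned the old don't-fuse behavior is deleted, since that behavior is gone.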
```diff
@@ -1063,11 +1063,6 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
                 auto eltw_in_size = peer_node->get_output_layout();
                 if (eltw_in_size.is_dynamic())
                     return;
-                // When input rank > 4, fused eltwise to gemm should be converted to 4 dim in init_onednn_primitive_attribute()
-                // But current init_onednn_primitive_attribute() cannot handle dynamic shape case.
-                auto eltw_in_rank = fused_node->get_output_layout().get_rank();
-                if ((fused_node->is_type<gemm>()) && (eltw_in_rank > 4))
-                    return;
             }
             if (parent1.first->is_type<convolution>() && !conv_supports_fusings(parent1.first->as<convolution>()))
                 return;
```
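The deleted comment refers to flattening rank > 4 operands down to 4 dims before the fused eltwise is handed to oneDNN. As a rough sketch of that reshape idea only (a hypothetical helper, not the code in `init_onednn_primitive_attribute()`; the fold-the-leading-dims policy is an assumption):

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: fold every leading dim of a rank>4 static shape into
// one, keeping the last three dims intact, e.g. {2, 2, 2, 2, 2} -> {4, 2, 2, 2}.
// If a dynamic (unknown) dim sits among the folded leading dims, the product
// cannot be computed, which is why the removed workaround bailed out on
// dynamic shapes.
std::vector<int64_t> collapse_to_4d(const std::vector<int64_t>& dims) {
    if (dims.size() <= 4)
        return dims;
    int64_t folded = std::accumulate(dims.begin(), dims.end() - 3,
                                     int64_t{1}, std::multiplies<int64_t>());
    return {folded, dims[dims.size() - 3], dims[dims.size() - 2], dims.back()};
}
```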
```diff
@@ -962,10 +962,15 @@ void program_node::init_onednn_primitive_attributes() {
         auto& desc = cldnn_post_ops[idx];
         if (desc.is_type<activation>()) {
             auto fused_desc = desc.typed_desc<activation>();
+            bool allow_new_shape_infer = get_program().get_config().get_property(ov::intel_gpu::allow_new_shape_infer);
             if (fused_desc->activation_function == cldnn::activation_func::relu_negative_slope
                 && !fused_desc->additional_params_input.empty()) {
                 auto dep_idx = cldnn_post_ops[idx].outer_dep_start_idx;
-                auto oc_dim = static_cast<int>(desc.output_layout.get_partial_shape()[1].get_max_length());
+                int oc_dim = 1;
+                if (allow_new_shape_infer)
+                    oc_dim = static_cast<int>(desc.output_layout.get_partial_shape()[1].get_max_length());
+                else
+                    oc_dim = static_cast<int>(desc.output_layout.get_tensor().feature.size());
                 post_ops.append_prelu(1 << oc_dim);
                 update_onednn_post_op_list(onednn_post_op_type::binary_relu, dep_idx);
             } else if (fused_desc->activation_function == cldnn::activation_func::hard_sigmoid) {
```
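For context on the `1 << oc_dim` argument: oneDNN's `dnnl::post_ops::append_prelu(int mask)` takes a broadcast mask over the destination dims, where a set bit marks a dimension along which the slope tensor has full (non-broadcast) size. A minimal standalone sketch of that convention (plain oneDNN API, not plugin code):

```cpp
#include <oneapi/dnnl/dnnl.hpp>

int main() {
    dnnl::post_ops ops;
    // mask 0: a single slope value shared by the whole tensor (per-tensor prelu)
    ops.append_prelu(0);
    // mask 1 << 1 = 2: slope varies along dim 1, i.e. per-channel prelu
    // for an NCHW-like destination
    ops.append_prelu(1 << 1);
    return 0;
}
```

Getting `oc_dim` wrong therefore shifts the mask bit to the wrong dimension, which is exactly the kind of breakage a regression in this spot produces; the two branches differ only in which layout query supplies the value.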
```diff
@@ -515,38 +515,3 @@ TEST(prepare_primitive_fusing, eltwise_fusing_residual_connection) {
     net.execute();
     ASSERT_TRUE(conv_inst->has_unfused_subgraph());
 }
-
-TEST(prepare_primitive_fusing, dont_fuse_eltwise_to_onednn_gemm_dyn_rank5) {
-    auto& engine = get_test_engine();
-    if (!engine.get_device_info().supports_immad)
-        return;
-    ov::Shape input1_shape = { 2, 2, 2, 2, 2 };
-    ov::Shape input2_shape = { 2, 2, 2, 2, 2 };
-    auto input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfzyx};
-    auto input2_layout = layout{ov::PartialShape::dynamic(input2_shape.size()), data_types::f32, format::bfzyx};
-    auto input1 = engine.allocate_memory(layout{ov::PartialShape(input1_shape), data_types::f32, format::bfzyx});
-    auto input2 = engine.allocate_memory(layout{ov::PartialShape(input2_shape), data_types::f32, format::bfzyx});
-    auto const_layout = layout{ ov::PartialShape{2, 2, 2, 2, 2}, data_types::f32, format::bfzyx };
-    auto const_mem = engine.allocate_memory(const_layout);
-
-    topology topology;
-    topology.add(input_layout("input1", input1_layout));
-    topology.add(input_layout("input2", input2_layout));
-    topology.add(data("const", const_mem));
-    topology.add(gemm("gemm", { input_info("input1"), input_info("input2") }, data_types::f32));
-    topology.add(eltwise("add", { input_info("gemm"), input_info("const") }, eltwise_mode::sum));
-    topology.add(reorder("reorder", input_info("add"), format::bfzyx, data_types::f16));
-
-    ExecutionConfig config = get_test_default_config(engine);
-    config.set_property(ov::intel_gpu::optimize_data(true));
-    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
-    auto prog = program::build_program(engine, topology, config, false, true);
-
-    layout_optimizer lo(true);
-    lo.set_optimization_attribute(layout_optimizer::optimization_attributes_type::use_onednn_impls, true);
-
-    program_wrapper::apply_opt_pass<prepare_primitive_fusing>(*prog, lo);
-
-    ASSERT_NE(prog, nullptr);
-    ASSERT_TRUE(has_node(*prog, "add"));
-}
```
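With the rank > 4 early-return gone from the fusion pass, the eltwise "add" in this scenario is expected to fuse into the oneDNN gemm, so `ASSERT_TRUE(has_node(*prog, "add"))` would no longer hold; the test is removed together with the check it was written to pin down.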