diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp index d50d306700b..45739c78a1e 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp @@ -1221,8 +1221,9 @@ void prepare_primitive_fusing::fuse_constant_transposes(program& p) { return format::find_format(new_order, fmt.block_sizes()); }; - auto itr = p.get_processing_order().begin(); - while (itr != p.get_processing_order().end()) { + auto& proc_order = p.get_processing_order(); + auto itr = proc_order.begin(); + while (itr != proc_order.end()) { auto& node = *itr++; if (!node->is_type()) @@ -1271,6 +1272,32 @@ void prepare_primitive_fusing::fuse_constant_transposes(program& p) { p.replace(prev_const, new_const_node); new_const_node.recalc_output_layout(false); + + // Add format reorder in case of onednn to avoid overhead during execution on weights memory allocation + if (_lo.get_preferred_impl_type(const_cast(*weightable_node), format::any /*dummy*/) == impl_types::onednn) { + auto next_node = new_const_node.get_users().front(); + bool can_be_fused = next_node->is_type() && + next_node->as().is_simple_reorder() && + next_node->get_users().size() == 1; + if (can_be_fused) { + layout reorder_layout = next_node->get_output_layout(); + reorder_layout.format = format::bfyx; + + auto new_reorder = std::make_shared(next_node->id() + "_reorder_fmt", new_const_node.id(), reorder_layout); + auto& new_reorder_node = p.get_or_create(new_reorder); + p.replace(*next_node, new_reorder_node); + new_reorder_node.recalc_output_layout(false); + itr = std::find(proc_order.begin(), proc_order.end(), &new_reorder_node); + } else { + layout reorder_layout = new_const_node.get_output_layout(); + reorder_layout.format = format::bfyx; + + auto new_reorder = std::make_shared(new_const_node.id() + "_reorder_fmt", new_const_node.id(), reorder_layout); + auto& new_reorder_node = p.get_or_create(std::move(new_reorder)); + p.add_intermediate(new_reorder_node, *new_const_node.get_users().front(), new_const_node); + new_reorder_node.recalc_output_layout(false); + } + } } } diff --git a/src/plugins/intel_gpu/tests/unit/passes/prepare_primitive_fusing_test.cpp b/src/plugins/intel_gpu/tests/unit/passes/prepare_primitive_fusing_test.cpp index f19140d2459..a5f1d9e4706 100644 --- a/src/plugins/intel_gpu/tests/unit/passes/prepare_primitive_fusing_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/passes/prepare_primitive_fusing_test.cpp @@ -528,7 +528,7 @@ TEST(prepare_primitive_fusing, fuse_constant_transposes_removal_check) { input_layout("input", input->get_layout()), data("weights", weights), permute("permute", input_info("weights"), {1, 0}), - reorder("reorder_dt", input_info("permute"), format::bfyx, data_types::f16), + reorder("reorder_dt", input_info("permute"), format::fbyx, data_types::f16), fully_connected("fc", input_info("input"), { "reorder_dt" }, "", data_types::f16) ); @@ -536,13 +536,24 @@ TEST(prepare_primitive_fusing, fuse_constant_transposes_removal_check) { config.set_property(ov::intel_gpu::optimize_data(true)); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + if (engine.get_device_info().supports_immad) { + ov::intel_gpu::ImplementationDesc fc_impl = { format::bfyx, "", impl_types::onednn }; + config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"fc", fc_impl} })); + } + auto prog = program::build_program(engine, topology, config, false, true); layout_optimizer lo(true); + lo.set_implementation_forcing(config.get_property(ov::intel_gpu::force_implementations)); program_wrapper::apply_opt_pass(*prog, lo); ASSERT_TRUE(!has_node(*prog, "permute")); ASSERT_EQ(prog->get_node("weights").get_output_layout().format, format::fbyx); + + if (engine.get_device_info().supports_immad) { + ASSERT_TRUE(has_node(*prog, "reorder_dt")); + ASSERT_EQ(prog->get_node("reorder_dt").get_output_layout().format, format::bfyx); + } } TEST(prepare_primitive_fusing, fuse_constant_transposes_accuracy_test) {