From b01986865337b8ebf214a02dbfaac422d648b8a8 Mon Sep 17 00:00:00 2001 From: Vladimir Paramuzov Date: Thu, 27 Apr 2023 20:21:17 +0400 Subject: [PATCH] [GPU] Added missing second input to split reshape in handle reshape pass (#17254) --- .../include/intel_gpu/graph/program.hpp | 4 +- .../graph/graph_optimizer/handle_reshape.cpp | 4 ++ .../intel_gpu/tests/passes/handle_reshape.cpp | 48 +++++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/graph/program.hpp b/src/plugins/intel_gpu/include/intel_gpu/graph/program.hpp index c537b1335a7..bd5ca7ef9d3 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/graph/program.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/graph/program.hpp @@ -189,6 +189,8 @@ public: bool connect_int_node_with_old_dep = true, bool move_usrs_of_prev_to_node = false); + void add_connection(program_node& prev, program_node& next); + // removes a node from the graph and deletes it afterwards, // prereq: node cannot be marked as output and has to have exactly one dependency // returns if 'node' has been extracted and removed successfully @@ -332,8 +334,6 @@ private: // mark if the node is constant assuming that all dependencies are marked properly void reverse_connection(program_node& dep_node, program_node& user_node); - void add_connection(program_node& prev, program_node& next); - void remove_connection(program_node& prev, program_node& next); void remove_all_connections(program_node& node); diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/handle_reshape.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/handle_reshape.cpp index ab69ea235de..1a5d9287994 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/handle_reshape.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/handle_reshape.cpp @@ -157,6 +157,10 @@ void handle_reshape::run(program& p) { auto& new_reshape_node = p.get_or_create(new_reshape); user->replace_dependency(0, input_node); 
p.add_intermediate(new_reshape_node, *user, 0); + if (new_reshape->input_size() == 2) { + p.add_connection(prim_node.get_dependency(1), new_reshape_node); + } + reorder_reshape_nodes.push_back(&new_reshape_node); } } diff --git a/src/plugins/intel_gpu/tests/passes/handle_reshape.cpp b/src/plugins/intel_gpu/tests/passes/handle_reshape.cpp index 60853848680..41b553fcf62 100644 --- a/src/plugins/intel_gpu/tests/passes/handle_reshape.cpp +++ b/src/plugins/intel_gpu/tests/passes/handle_reshape.cpp @@ -132,6 +132,54 @@ TEST(handle_reshape, correct_parameters_propagation) { ASSERT_EQ(out_shape1, expected_out_shape); } +TEST(handle_reshape, correct_parameters_propagation_2_inputs) { + auto& engine = get_test_engine(); + auto data0_mem = engine.allocate_memory({ ov::PartialShape{}, data_types::f16, format::bfyx }); + auto data1_mem = engine.allocate_memory({ ov::PartialShape{1, 12}, data_types::f16, format::bfyx }); + auto shape_mem = engine.allocate_memory({ ov::PartialShape{2}, data_types::i32, format::bfyx }); + auto in_layout = layout{ ov::PartialShape{1, 2, 3, 4}, data_types::f16, format::bfyx }; + set_values(shape_mem, {2, 12}); + + topology topology; + topology.add(input_layout("input", in_layout)); + topology.add(data("data0", data0_mem)); + topology.add(data("data1", data1_mem)); + topology.add(data("shape", shape_mem)); + topology.add(eltwise("e1", input_info("input"), input_info("data0"), eltwise_mode::sum)); + topology.add(reshape("reshape", input_info("e1"), input_info("shape"), false, {-1, 12})); + topology.add(eltwise("e2", input_info("reshape"), input_info("data1"), eltwise_mode::sum)); + topology.add(reorder("reorder", input_info("reshape"), format::bfyx, data_types::f32)); + + ExecutionConfig config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + config.set_property(ov::intel_gpu::optimize_data(true)); + auto prog = program::build_program(engine, topology, config, false, true); + + layout_optimizer 
lo(true); + + program_wrapper::apply_opt_pass(*prog); + + ASSERT_NE(prog, nullptr); + ASSERT_TRUE(has_node_with_type(*prog)); + + auto& reshape_node = prog->get_node("reshape"); + ASSERT_TRUE(reshape_node.can_be_optimized()); + ASSERT_EQ(reshape_node.get_dependencies().size(), 2); + + auto& reshape_split_node = prog->get_node("reorder").get_dependency(0); + ASSERT_TRUE(reshape_split_node.is_type()); + ASSERT_EQ(reshape_split_node.get_dependencies().size(), 2); + + auto out_shape0 = prog->get_node("e2").get_output_layout().get_partial_shape(); + auto out_shape1 = prog->get_node("reorder").get_output_layout().get_partial_shape(); + + ov::PartialShape expected_out_shape{2, 12}; + + // handle_reshape may do reshape split, so ensure that output shape on all branches is correct + ASSERT_EQ(out_shape0, expected_out_shape); + ASSERT_EQ(out_shape1, expected_out_shape); +} + TEST(handle_reshape, reshape_input_reorder) { auto& engine = get_test_engine(); auto shape_memory = engine.allocate_memory({ ov::PartialShape{5}, data_types::i32, format::bfyx });