[GPU] Added missing second input to splitted reshape in handle reshape pass (#17254)

This commit is contained in:
Vladimir Paramuzov
2023-04-27 20:21:17 +04:00
committed by GitHub
parent bf59a67d94
commit b019868653
3 changed files with 54 additions and 2 deletions

View File

@@ -189,6 +189,8 @@ public:
bool connect_int_node_with_old_dep = true,
bool move_usrs_of_prev_to_node = false);
void add_connection(program_node& prev, program_node& next);
// removes a node from the graph and deletes it afterwards,
// prereq: node cannot be marked as output and has to have exactly one dependency
// returns if 'node' has been extracted and removed successfully
@@ -332,8 +334,6 @@ private:
// mark if the node is constant assuming that all dependencies are marked properly
void reverse_connection(program_node& dep_node, program_node& user_node);
void add_connection(program_node& prev, program_node& next);
void remove_connection(program_node& prev, program_node& next);
void remove_all_connections(program_node& node);

View File

@@ -157,6 +157,10 @@ void handle_reshape::run(program& p) {
auto& new_reshape_node = p.get_or_create(new_reshape);
user->replace_dependency(0, input_node);
p.add_intermediate(new_reshape_node, *user, 0);
if (new_reshape->input_size() == 2) {
p.add_connection(prim_node.get_dependency(1), new_reshape_node);
}
reorder_reshape_nodes.push_back(&new_reshape_node);
}
}

View File

@@ -132,6 +132,54 @@ TEST(handle_reshape, correct_parameters_propagation) {
ASSERT_EQ(out_shape1, expected_out_shape);
}
TEST(handle_reshape, correct_parameters_propagation_2_inputs) {
auto& engine = get_test_engine();
auto data0_mem = engine.allocate_memory({ ov::PartialShape{}, data_types::f16, format::bfyx });
auto data1_mem = engine.allocate_memory({ ov::PartialShape{1, 12}, data_types::f16, format::bfyx });
auto shape_mem = engine.allocate_memory({ ov::PartialShape{2}, data_types::i32, format::bfyx });
auto in_layout = layout{ ov::PartialShape{1, 2, 3, 4}, data_types::f16, format::bfyx };
set_values<int32_t>(shape_mem, {2, 12});
topology topology;
topology.add(input_layout("input", in_layout));
topology.add(data("data0", data0_mem));
topology.add(data("data1", data1_mem));
topology.add(data("shape", shape_mem));
topology.add(eltwise("e1", input_info("input"), input_info("data0"), eltwise_mode::sum));
topology.add(reshape("reshape", input_info("e1"), input_info("shape"), false, {-1, 12}));
topology.add(eltwise("e2", input_info("reshape"), input_info("data1"), eltwise_mode::sum));
topology.add(reorder("reorder", input_info("reshape"), format::bfyx, data_types::f32));
ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
config.set_property(ov::intel_gpu::optimize_data(true));
auto prog = program::build_program(engine, topology, config, false, true);
layout_optimizer lo(true);
program_wrapper::apply_opt_pass<handle_reshape>(*prog);
ASSERT_NE(prog, nullptr);
ASSERT_TRUE(has_node_with_type<reshape>(*prog));
auto& reshape_node = prog->get_node("reshape");
ASSERT_TRUE(reshape_node.can_be_optimized());
ASSERT_EQ(reshape_node.get_dependencies().size(), 2);
auto& reshape_split_node = prog->get_node("reorder").get_dependency(0);
ASSERT_TRUE(reshape_split_node.is_type<reshape>());
ASSERT_EQ(reshape_split_node.get_dependencies().size(), 2);
auto out_shape0 = prog->get_node("e2").get_output_layout().get_partial_shape();
auto out_shape1 = prog->get_node("reorder").get_output_layout().get_partial_shape();
ov::PartialShape expected_out_shape{2, 12};
// handle_reshape may do reshape split, so ensure that output shape on all branches is correct
ASSERT_EQ(out_shape0, expected_out_shape);
ASSERT_EQ(out_shape1, expected_out_shape);
}
TEST(handle_reshape, reshape_input_reorder) {
auto& engine = get_test_engine();
auto shape_memory = engine.allocate_memory({ ov::PartialShape{5}, data_types::i32, format::bfyx });