[GPU] Added missing second input to splitted reshape in handle reshape pass (#17254)
This commit is contained in:
committed by
GitHub
parent
bf59a67d94
commit
b019868653
@@ -189,6 +189,8 @@ public:
|
||||
bool connect_int_node_with_old_dep = true,
|
||||
bool move_usrs_of_prev_to_node = false);
|
||||
|
||||
void add_connection(program_node& prev, program_node& next);
|
||||
|
||||
// removes a node from the graph and deletes it afterwards,
|
||||
// prereq: node cannot be marked as output and has to have exactly one dependency
|
||||
// returns if 'node' has been extracted and removed successfully
|
||||
@@ -332,8 +334,6 @@ private:
|
||||
// mark if the node is constant assuming that all dependencies are marked properly
|
||||
void reverse_connection(program_node& dep_node, program_node& user_node);
|
||||
|
||||
void add_connection(program_node& prev, program_node& next);
|
||||
|
||||
void remove_connection(program_node& prev, program_node& next);
|
||||
|
||||
void remove_all_connections(program_node& node);
|
||||
|
||||
@@ -157,6 +157,10 @@ void handle_reshape::run(program& p) {
|
||||
auto& new_reshape_node = p.get_or_create(new_reshape);
|
||||
user->replace_dependency(0, input_node);
|
||||
p.add_intermediate(new_reshape_node, *user, 0);
|
||||
if (new_reshape->input_size() == 2) {
|
||||
p.add_connection(prim_node.get_dependency(1), new_reshape_node);
|
||||
}
|
||||
|
||||
reorder_reshape_nodes.push_back(&new_reshape_node);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +132,54 @@ TEST(handle_reshape, correct_parameters_propagation) {
|
||||
ASSERT_EQ(out_shape1, expected_out_shape);
|
||||
}
|
||||
|
||||
TEST(handle_reshape, correct_parameters_propagation_2_inputs) {
|
||||
auto& engine = get_test_engine();
|
||||
auto data0_mem = engine.allocate_memory({ ov::PartialShape{}, data_types::f16, format::bfyx });
|
||||
auto data1_mem = engine.allocate_memory({ ov::PartialShape{1, 12}, data_types::f16, format::bfyx });
|
||||
auto shape_mem = engine.allocate_memory({ ov::PartialShape{2}, data_types::i32, format::bfyx });
|
||||
auto in_layout = layout{ ov::PartialShape{1, 2, 3, 4}, data_types::f16, format::bfyx };
|
||||
set_values<int32_t>(shape_mem, {2, 12});
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input", in_layout));
|
||||
topology.add(data("data0", data0_mem));
|
||||
topology.add(data("data1", data1_mem));
|
||||
topology.add(data("shape", shape_mem));
|
||||
topology.add(eltwise("e1", input_info("input"), input_info("data0"), eltwise_mode::sum));
|
||||
topology.add(reshape("reshape", input_info("e1"), input_info("shape"), false, {-1, 12}));
|
||||
topology.add(eltwise("e2", input_info("reshape"), input_info("data1"), eltwise_mode::sum));
|
||||
topology.add(reorder("reorder", input_info("reshape"), format::bfyx, data_types::f32));
|
||||
|
||||
ExecutionConfig config = get_test_default_config(engine);
|
||||
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
|
||||
config.set_property(ov::intel_gpu::optimize_data(true));
|
||||
auto prog = program::build_program(engine, topology, config, false, true);
|
||||
|
||||
layout_optimizer lo(true);
|
||||
|
||||
program_wrapper::apply_opt_pass<handle_reshape>(*prog);
|
||||
|
||||
ASSERT_NE(prog, nullptr);
|
||||
ASSERT_TRUE(has_node_with_type<reshape>(*prog));
|
||||
|
||||
auto& reshape_node = prog->get_node("reshape");
|
||||
ASSERT_TRUE(reshape_node.can_be_optimized());
|
||||
ASSERT_EQ(reshape_node.get_dependencies().size(), 2);
|
||||
|
||||
auto& reshape_split_node = prog->get_node("reorder").get_dependency(0);
|
||||
ASSERT_TRUE(reshape_split_node.is_type<reshape>());
|
||||
ASSERT_EQ(reshape_split_node.get_dependencies().size(), 2);
|
||||
|
||||
auto out_shape0 = prog->get_node("e2").get_output_layout().get_partial_shape();
|
||||
auto out_shape1 = prog->get_node("reorder").get_output_layout().get_partial_shape();
|
||||
|
||||
ov::PartialShape expected_out_shape{2, 12};
|
||||
|
||||
// handle_reshape may do reshape split, so ensure that output shape on all branches is correct
|
||||
ASSERT_EQ(out_shape0, expected_out_shape);
|
||||
ASSERT_EQ(out_shape1, expected_out_shape);
|
||||
}
|
||||
|
||||
TEST(handle_reshape, reshape_input_reorder) {
|
||||
auto& engine = get_test_engine();
|
||||
auto shape_memory = engine.allocate_memory({ ov::PartialShape{5}, data_types::i32, format::bfyx });
|
||||
|
||||
Reference in New Issue
Block a user