This case will check the layout of condition in these conditions. (#21335)

- it re-allocated at primitive_inst::realloc_if_needed().
- it can be skip subgraph.
This commit is contained in:
Sungeun Kim 2023-11-29 18:42:16 +09:00 committed by GitHub
parent c1a28d0942
commit 07bcb8b6ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -341,6 +341,75 @@ TEST(condition_gpu, dynamic_shapes) {
}
}
// This case will check the layout of condition in these conditions.
// - it re-allocated at primitive_inst::realloc_if_needed().
// - it can be skip subgraph.
TEST(condition_gpu, dynamic_shapes_skip_condition) {
auto& engine = get_test_engine();
ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::optimize_data(true));
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
const int64_t d1 = 2;
const int64_t d2 = 4;
layout input_lay = {{-1, d1, -1, d2}, data_types::f32, format::bfyx};
auto predicate = engine.allocate_memory({{ 1 }, data_types::u8, format::bfyx });
const primitive_id condition_id = "condition";
const primitive_id condition_id_true = condition_id + "_when_true";
const primitive_id condition_id_false = condition_id + "_when_false";
const primitive_id branch_input_id = "branch_input";
const primitive_id model_input = "input";
const primitive_id predicate_input = "predicate";
const primitive_id reorder_id = "reorder";
const primitive_id tranpose = "transpose";
cldnn::topology topology;
topology.add(input_layout(model_input, input_lay));
topology.add(input_layout(predicate_input, predicate->get_layout()));
topology.add(permute(tranpose, model_input, {1, 0, 2, 3}));
auto generate_simple_branch = [&](bool branch_true_false, const primitive_id& input_id, const data_types dt) {
auto id = branch_true_false ? condition_id_true : condition_id_false;
cldnn::topology branch_topology(input_layout(input_id, { {d1, -1, -1, d2}, dt, format::bfyx }),
reorder(id, input_info(input_id), { {d1, -1, -1, d2}, dt, format::bfyx })
);
condition::branch branch;
branch.inner_program = program::build_program(engine, branch_topology, config, false, false, true);
branch.input_map.insert({tranpose, branch_input_id});
branch.output_map.insert({0, id});
return branch;
};
condition::branch branch_true = generate_simple_branch(true, branch_input_id, data_types::f32);
condition::branch branch_false = generate_simple_branch(false, branch_input_id, data_types::f32);
topology.add(reorder(reorder_id, input_info(predicate_input), { {d1, -1, -1, d2}, data_types::f32, format::bfyx }));
topology.add(condition(condition_id, { reorder_id, tranpose }, branch_true, branch_false));
tests::random_generator rg(GET_SUITE_NAME);
std::vector<uint8_t> predicate_data_true = { 1 };
std::vector<uint8_t> predicate_data_false = { 0 };
network net(engine, topology, config);
for (int i = 0; i < 10; i++) {
layout l = {{1, d1, 1 + static_cast<int64_t>(i), d2}, data_types::f32, format::bfyx};
std::vector<float> input_data = rg.generate_random_1d<float>(l.count(), -10, 10);
auto mem = engine.allocate_memory(l);
set_values(mem, input_data);
set_values(predicate, predicate_data_true);
net.set_input_data(model_input, mem);
net.set_input_data(predicate_input, predicate);
auto outputs = net.execute();
auto cond_layout = outputs.at(condition_id).get_layout();
ASSERT_TRUE(cond_layout.get_dim(2) == (i + 1));
}
}
TEST(condition_gpu, basic_stacked_ifs) {
/*
<prims...>