This case will check the layout of condition in these conditions. (#21335)
- it re-allocated at primitive_inst::realloc_if_needed(). - it can be skip subgraph.
This commit is contained in:
parent
c1a28d0942
commit
07bcb8b6ed
@ -341,6 +341,75 @@ TEST(condition_gpu, dynamic_shapes) {
|
||||
}
|
||||
}
|
||||
|
||||
// This case will check the layout of condition in these conditions.
|
||||
// - it re-allocated at primitive_inst::realloc_if_needed().
|
||||
// - it can be skip subgraph.
|
||||
TEST(condition_gpu, dynamic_shapes_skip_condition) {
|
||||
auto& engine = get_test_engine();
|
||||
ExecutionConfig config = get_test_default_config(engine);
|
||||
config.set_property(ov::intel_gpu::optimize_data(true));
|
||||
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
|
||||
const int64_t d1 = 2;
|
||||
const int64_t d2 = 4;
|
||||
layout input_lay = {{-1, d1, -1, d2}, data_types::f32, format::bfyx};
|
||||
|
||||
auto predicate = engine.allocate_memory({{ 1 }, data_types::u8, format::bfyx });
|
||||
|
||||
const primitive_id condition_id = "condition";
|
||||
const primitive_id condition_id_true = condition_id + "_when_true";
|
||||
const primitive_id condition_id_false = condition_id + "_when_false";
|
||||
const primitive_id branch_input_id = "branch_input";
|
||||
const primitive_id model_input = "input";
|
||||
const primitive_id predicate_input = "predicate";
|
||||
const primitive_id reorder_id = "reorder";
|
||||
const primitive_id tranpose = "transpose";
|
||||
|
||||
cldnn::topology topology;
|
||||
topology.add(input_layout(model_input, input_lay));
|
||||
topology.add(input_layout(predicate_input, predicate->get_layout()));
|
||||
topology.add(permute(tranpose, model_input, {1, 0, 2, 3}));
|
||||
|
||||
auto generate_simple_branch = [&](bool branch_true_false, const primitive_id& input_id, const data_types dt) {
|
||||
auto id = branch_true_false ? condition_id_true : condition_id_false;
|
||||
cldnn::topology branch_topology(input_layout(input_id, { {d1, -1, -1, d2}, dt, format::bfyx }),
|
||||
reorder(id, input_info(input_id), { {d1, -1, -1, d2}, dt, format::bfyx })
|
||||
);
|
||||
condition::branch branch;
|
||||
branch.inner_program = program::build_program(engine, branch_topology, config, false, false, true);
|
||||
branch.input_map.insert({tranpose, branch_input_id});
|
||||
branch.output_map.insert({0, id});
|
||||
|
||||
return branch;
|
||||
};
|
||||
|
||||
condition::branch branch_true = generate_simple_branch(true, branch_input_id, data_types::f32);
|
||||
condition::branch branch_false = generate_simple_branch(false, branch_input_id, data_types::f32);
|
||||
|
||||
topology.add(reorder(reorder_id, input_info(predicate_input), { {d1, -1, -1, d2}, data_types::f32, format::bfyx }));
|
||||
topology.add(condition(condition_id, { reorder_id, tranpose }, branch_true, branch_false));
|
||||
|
||||
tests::random_generator rg(GET_SUITE_NAME);
|
||||
std::vector<uint8_t> predicate_data_true = { 1 };
|
||||
std::vector<uint8_t> predicate_data_false = { 0 };
|
||||
|
||||
network net(engine, topology, config);
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
layout l = {{1, d1, 1 + static_cast<int64_t>(i), d2}, data_types::f32, format::bfyx};
|
||||
std::vector<float> input_data = rg.generate_random_1d<float>(l.count(), -10, 10);
|
||||
auto mem = engine.allocate_memory(l);
|
||||
|
||||
set_values(mem, input_data);
|
||||
set_values(predicate, predicate_data_true);
|
||||
net.set_input_data(model_input, mem);
|
||||
net.set_input_data(predicate_input, predicate);
|
||||
auto outputs = net.execute();
|
||||
|
||||
auto cond_layout = outputs.at(condition_id).get_layout();
|
||||
ASSERT_TRUE(cond_layout.get_dim(2) == (i + 1));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(condition_gpu, basic_stacked_ifs) {
|
||||
/*
|
||||
<prims...>
|
||||
|
Loading…
Reference in New Issue
Block a user