[GPU] Allow to set empty tensor for inner network of condition primitive (#21415)
This commit is contained in:
parent
0a271c136a
commit
009951d969
@ -109,18 +109,25 @@ struct condition_impl : typed_primitive_impl<condition> {
|
||||
if (iter != branch.input_map.end()) {
|
||||
const primitive_id& input_internal_id = iter->second;
|
||||
auto mem_ptr = instance.input_memory_ptr(mem_idx);
|
||||
if (mem_ptr) {
|
||||
auto dep = instance.dependencies()[mem_idx];
|
||||
auto layout = dep.first->get_impl_params()->get_output_layout(dep.second);
|
||||
if (mem_ptr) {
|
||||
GPU_DEBUG_LOG << "Reshape input from " << mem_ptr->get_layout().to_short_string() << " to "
|
||||
<< layout.to_short_string() << std::endl;
|
||||
// Preallocation logic may allocate more memory than actually produced on current iteration, so
|
||||
// we need to adjust input buffers layout
|
||||
mem_ptr = instance.get_network().get_engine().reinterpret_buffer(*mem_ptr, layout);
|
||||
} else if (layout.count() == 0) {
|
||||
// Use dummy memory for empty tensor
|
||||
mem_ptr = std::make_shared<simple_attached_memory>(layout, nullptr);
|
||||
}
|
||||
OPENVINO_ASSERT(mem_ptr != nullptr, "[GPU] Can't assign nullptr memory buffer for condition primitive with id=", instance.id(), " ("
|
||||
"mem_idx=", mem_idx, ", "
|
||||
"external_id=", input_external_id, ", "
|
||||
"internal_id=", input_internal_id, ")");
|
||||
executed_net->set_input_data(input_internal_id, mem_ptr);
|
||||
GPU_DEBUG_LOG << "Inner net - Inputs[" << mem_idx << "]" << mem_ptr->get_layout().to_short_string()
|
||||
<< std::endl;
|
||||
GPU_DEBUG_LOG << "Inner net - Inputs[" << mem_idx << "]: layout=" << mem_ptr->get_layout().to_short_string() << ", "
|
||||
<< "allocation_type=" << mem_ptr->get_allocation_type() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -43,7 +43,8 @@ event::ptr input_layout_inst::set_data(memory::ptr mem) {
|
||||
auto& engine = get_network().get_engine();
|
||||
auto& stream = get_network().get_stream();
|
||||
|
||||
if (mem->is_allocated_by(engine)) {
|
||||
// Allow to set dummy simple_attached_memory empty tensor as network input
|
||||
if (mem->is_allocated_by(engine) || mem->get_layout().count() == 0) {
|
||||
OPENVINO_ASSERT(!_outputs.empty(), "[GPU] Can't set data for empty input memory");
|
||||
_outputs[0] = mem;
|
||||
ev = stream.create_user_event(true);
|
||||
|
@ -978,3 +978,72 @@ TEST(condition_gpu, empty_body_with_different_shapes) {
|
||||
auto out_pshape = output_layout.get_partial_shape();
|
||||
ASSERT_EQ(out_pshape, oned_pshape);
|
||||
}
|
||||
|
||||
TEST(condition_gpu, set_empty_tensor) {
|
||||
auto& engine = get_test_engine();
|
||||
ExecutionConfig config = get_test_default_config(engine);
|
||||
config.set_property(ov::intel_gpu::optimize_data(true));
|
||||
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
|
||||
auto empty_mem = engine.allocate_memory({ { 1, 1, 1, 1 }, data_types::f16, format::bfyx });
|
||||
auto empty_input_mem = engine.reinterpret_buffer(*empty_mem, { { 1, 1, 0, 1 }, data_types::f16, format::bfyx });
|
||||
auto input_mem = engine.allocate_memory({ { 1, 1, 4, 1 }, data_types::f32, format::bfyx });
|
||||
auto predicate_mem = engine.allocate_memory({ { 1, 1, 1, 1 }, data_types::u8, format::bfyx });
|
||||
auto concat_data = engine.allocate_memory({ { 1, 1, 4, 1 }, data_types::f32, format::bfyx });
|
||||
|
||||
set_values(predicate_mem, {1});
|
||||
|
||||
primitive_id empty_input_id = "input1";
|
||||
primitive_id reorder_id = "reorder";
|
||||
primitive_id input_id = "input2";
|
||||
primitive_id pred_id = "predicate";
|
||||
primitive_id branch_input_id1 = "branch_input1";
|
||||
primitive_id branch_input_id2 = "branch_input2";
|
||||
primitive_id concat_data_id = "concat_data";
|
||||
primitive_id concat_id = "concat";
|
||||
primitive_id cond_id = "condi";
|
||||
|
||||
condition::branch branch_true;
|
||||
{
|
||||
topology branch_true_topology;
|
||||
branch_true_topology.add(
|
||||
input_layout(branch_input_id1, { { 1, 1, -1, 1 }, data_types::f32, format::bfyx }),
|
||||
data(concat_data_id, concat_data),
|
||||
concatenation(concat_id, { input_info(branch_input_id1), input_info(concat_data_id) }, 2)
|
||||
);
|
||||
branch_true.inner_program = program::build_program(engine, branch_true_topology, config, false, false, true);
|
||||
branch_true.input_map.insert({reorder_id, branch_input_id1});
|
||||
branch_true.output_map.insert({0, concat_id});
|
||||
}
|
||||
|
||||
condition::branch branch_false;
|
||||
{
|
||||
topology branch_false_topology;
|
||||
branch_false_topology.add(
|
||||
input_layout(branch_input_id2, { { 1, 1, 4, 1 }, data_types::f32, format::bfyx }),
|
||||
reorder("result", input_info(branch_input_id2), format::bfyx, data_types::f32)
|
||||
);
|
||||
branch_false.inner_program = program::build_program(engine, branch_false_topology, config, false, false, true);
|
||||
branch_false.input_map.insert({input_id, branch_input_id2});
|
||||
branch_false.output_map.insert({0, "result"});
|
||||
}
|
||||
|
||||
auto empty_input_layout = layout({ 1, 1, -1, 1 }, data_types::f32, format::bfyx);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout(pred_id, predicate_mem->get_layout()));
|
||||
topology.add(input_layout(empty_input_id, empty_input_layout));
|
||||
topology.add(input_layout(input_id, input_mem->get_layout()));
|
||||
topology.add(reorder(reorder_id, input_info(empty_input_id), format::bfyx, data_types::f32));
|
||||
topology.add(condition(cond_id, {input_info(pred_id), input_info(reorder_id), input_info(input_id)}, branch_true, branch_false));
|
||||
|
||||
network net(engine, topology, config);
|
||||
ASSERT_TRUE(net.get_primitive(cond_id)->get_node().as<condition>().get_branch_false().inner_program->can_be_optimized());
|
||||
ASSERT_FALSE(net.get_primitive(cond_id)->get_node().as<condition>().get_branch_true().inner_program->can_be_optimized());
|
||||
|
||||
net.set_input_data(pred_id, predicate_mem);
|
||||
net.set_input_data(empty_input_id, empty_input_mem);
|
||||
net.set_input_data(input_id, input_mem);
|
||||
|
||||
ASSERT_NO_THROW(net.execute());
|
||||
ASSERT_NO_THROW(net.get_output(cond_id).get_memory());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user