[GPU] Fix issue of prepare_output (#7783)

This commit is contained in:
Taylor Yeonbok Lee 2021-10-13 18:50:43 +09:00 committed by GitHub
parent 50b0dc1182
commit a93a67da78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 1 deletions

View File

@ -418,10 +418,12 @@ void network::set_output_memory(const primitive_id& id, memory::ptr mem_new) {
for (auto& prim : o_iter->second) {
prim->set_output_memory(eng.reinterpret_buffer(*mem_new, prim->output_memory().get_layout()), false);
if (!_reset_arguments)
if (!_reset_arguments &&
(!prim->get_node().is_type<data>() && !(prim->get_node().is_type<mutable_data>() && prim->get_node().get_dependencies().empty()))) {
prim->set_arguments();
}
}
}
void cldnn::network::check_names() {
for (auto const& prim : _primitives) {

View File

@ -343,3 +343,41 @@ TEST(set_output_memory_gpu, basic_opt) {
EXPECT_TRUE(are_equal(output_ptr[i], output_vec[i])) << i;
}
}
TEST(set_output_memory_gpu, mutable_output_data) {
static const int32_t x_size = 2, y_size = 2, feature_num = 4, batch_num = 2;
auto& engine = get_test_engine();
const int top_k = 2;
auto input = engine.allocate_memory({ data_types::f32, format::bfyx,{ batch_num, feature_num, x_size , y_size } });
auto second_input = engine.allocate_memory({ data_types::f32, format::bfyx, { top_k, feature_num, x_size , y_size } });
auto final_output = engine.allocate_memory({ data_types::f32, format::bfyx,{ 1, 1, 1 , 1 } });
topology topology;
topology.add(input_layout("Add_1396", input->get_layout()));
topology.add(cldnn::mutable_data("second_input", second_input));
topology.add(cldnn::mutable_data("12220_md_write", final_output));
topology.add(arg_max_min("arg_max", { "Add_1396", "12220_md_write", "second_input" }, arg_max_min::min, top_k, arg_max_min::batch));
topology.add(cldnn::mutable_data("pred/sink_port_0", {"arg_max"},final_output) );
std::vector<float> input_vec = {
//y0x0 y0x1 y1x0 y1x1
/*b0f0*/0.1f, -0.1f, 0.9f, 1.5f,
/*b0f1*/0.2f, 0.2f, -10.f, 5.2f,
/*b0f2*/0.2f, 0.2f, -10.f, 5.2f,
/*b0f3*/0.2f, 0.2f, -10.f, 4.2f,
/*b1f0*/3.f, 0.5f, 7.f, 10.f,
/*b1f1*/4.f, 0.5f, 8.f, 8.2f,
/*b1f2*/0.2f, 0.2f, -10.f, 5.2f,
/*b1f3*/4.f, 0.5f, 8.f, 8.2f
};
set_values(input, input_vec);
auto prog = program::build_program(engine, topology, build_options());
network network(prog, 0);
network.set_input_data("Add_1396", input);
// to make _reset_arguments false
network.execute();
network.execute();
network.set_output_memory("pred/sink_port_0", final_output);
}