[GPU] Fix issue of prepare_output (#7783)
This commit is contained in:
parent
50b0dc1182
commit
a93a67da78
@ -418,9 +418,11 @@ void network::set_output_memory(const primitive_id& id, memory::ptr mem_new) {
|
||||
|
||||
for (auto& prim : o_iter->second) {
|
||||
prim->set_output_memory(eng.reinterpret_buffer(*mem_new, prim->output_memory().get_layout()), false);
|
||||
if (!_reset_arguments)
|
||||
if (!_reset_arguments &&
|
||||
(!prim->get_node().is_type<data>() && !(prim->get_node().is_type<mutable_data>() && prim->get_node().get_dependencies().empty()))) {
|
||||
prim->set_arguments();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cldnn::network::check_names() {
|
||||
|
@ -343,3 +343,41 @@ TEST(set_output_memory_gpu, basic_opt) {
|
||||
EXPECT_TRUE(are_equal(output_ptr[i], output_vec[i])) << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(set_output_memory_gpu, mutable_output_data) {
|
||||
static const int32_t x_size = 2, y_size = 2, feature_num = 4, batch_num = 2;
|
||||
auto& engine = get_test_engine();
|
||||
const int top_k = 2;
|
||||
auto input = engine.allocate_memory({ data_types::f32, format::bfyx,{ batch_num, feature_num, x_size , y_size } });
|
||||
auto second_input = engine.allocate_memory({ data_types::f32, format::bfyx, { top_k, feature_num, x_size , y_size } });
|
||||
auto final_output = engine.allocate_memory({ data_types::f32, format::bfyx,{ 1, 1, 1 , 1 } });
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Add_1396", input->get_layout()));
|
||||
topology.add(cldnn::mutable_data("second_input", second_input));
|
||||
topology.add(cldnn::mutable_data("12220_md_write", final_output));
|
||||
topology.add(arg_max_min("arg_max", { "Add_1396", "12220_md_write", "second_input" }, arg_max_min::min, top_k, arg_max_min::batch));
|
||||
topology.add(cldnn::mutable_data("pred/sink_port_0", {"arg_max"},final_output) );
|
||||
|
||||
std::vector<float> input_vec = {
|
||||
//y0x0 y0x1 y1x0 y1x1
|
||||
/*b0f0*/0.1f, -0.1f, 0.9f, 1.5f,
|
||||
/*b0f1*/0.2f, 0.2f, -10.f, 5.2f,
|
||||
/*b0f2*/0.2f, 0.2f, -10.f, 5.2f,
|
||||
/*b0f3*/0.2f, 0.2f, -10.f, 4.2f,
|
||||
|
||||
/*b1f0*/3.f, 0.5f, 7.f, 10.f,
|
||||
/*b1f1*/4.f, 0.5f, 8.f, 8.2f,
|
||||
/*b1f2*/0.2f, 0.2f, -10.f, 5.2f,
|
||||
/*b1f3*/4.f, 0.5f, 8.f, 8.2f
|
||||
};
|
||||
set_values(input, input_vec);
|
||||
auto prog = program::build_program(engine, topology, build_options());
|
||||
network network(prog, 0);
|
||||
network.set_input_data("Add_1396", input);
|
||||
|
||||
// to make _reset_arguments false
|
||||
network.execute();
|
||||
network.execute();
|
||||
network.set_output_memory("pred/sink_port_0", final_output);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user