[GPU] Fix issue of prepare_output (#7783)
This commit is contained in:
parent
50b0dc1182
commit
a93a67da78
@ -418,9 +418,11 @@ void network::set_output_memory(const primitive_id& id, memory::ptr mem_new) {
|
|||||||
|
|
||||||
for (auto& prim : o_iter->second) {
|
for (auto& prim : o_iter->second) {
|
||||||
prim->set_output_memory(eng.reinterpret_buffer(*mem_new, prim->output_memory().get_layout()), false);
|
prim->set_output_memory(eng.reinterpret_buffer(*mem_new, prim->output_memory().get_layout()), false);
|
||||||
if (!_reset_arguments)
|
if (!_reset_arguments &&
|
||||||
|
(!prim->get_node().is_type<data>() && !(prim->get_node().is_type<mutable_data>() && prim->get_node().get_dependencies().empty()))) {
|
||||||
prim->set_arguments();
|
prim->set_arguments();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cldnn::network::check_names() {
|
void cldnn::network::check_names() {
|
||||||
|
@ -343,3 +343,41 @@ TEST(set_output_memory_gpu, basic_opt) {
|
|||||||
EXPECT_TRUE(are_equal(output_ptr[i], output_vec[i])) << i;
|
EXPECT_TRUE(are_equal(output_ptr[i], output_vec[i])) << i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(set_output_memory_gpu, mutable_output_data) {
|
||||||
|
static const int32_t x_size = 2, y_size = 2, feature_num = 4, batch_num = 2;
|
||||||
|
auto& engine = get_test_engine();
|
||||||
|
const int top_k = 2;
|
||||||
|
auto input = engine.allocate_memory({ data_types::f32, format::bfyx,{ batch_num, feature_num, x_size , y_size } });
|
||||||
|
auto second_input = engine.allocate_memory({ data_types::f32, format::bfyx, { top_k, feature_num, x_size , y_size } });
|
||||||
|
auto final_output = engine.allocate_memory({ data_types::f32, format::bfyx,{ 1, 1, 1 , 1 } });
|
||||||
|
|
||||||
|
topology topology;
|
||||||
|
topology.add(input_layout("Add_1396", input->get_layout()));
|
||||||
|
topology.add(cldnn::mutable_data("second_input", second_input));
|
||||||
|
topology.add(cldnn::mutable_data("12220_md_write", final_output));
|
||||||
|
topology.add(arg_max_min("arg_max", { "Add_1396", "12220_md_write", "second_input" }, arg_max_min::min, top_k, arg_max_min::batch));
|
||||||
|
topology.add(cldnn::mutable_data("pred/sink_port_0", {"arg_max"},final_output) );
|
||||||
|
|
||||||
|
std::vector<float> input_vec = {
|
||||||
|
//y0x0 y0x1 y1x0 y1x1
|
||||||
|
/*b0f0*/0.1f, -0.1f, 0.9f, 1.5f,
|
||||||
|
/*b0f1*/0.2f, 0.2f, -10.f, 5.2f,
|
||||||
|
/*b0f2*/0.2f, 0.2f, -10.f, 5.2f,
|
||||||
|
/*b0f3*/0.2f, 0.2f, -10.f, 4.2f,
|
||||||
|
|
||||||
|
/*b1f0*/3.f, 0.5f, 7.f, 10.f,
|
||||||
|
/*b1f1*/4.f, 0.5f, 8.f, 8.2f,
|
||||||
|
/*b1f2*/0.2f, 0.2f, -10.f, 5.2f,
|
||||||
|
/*b1f3*/4.f, 0.5f, 8.f, 8.2f
|
||||||
|
};
|
||||||
|
set_values(input, input_vec);
|
||||||
|
auto prog = program::build_program(engine, topology, build_options());
|
||||||
|
network network(prog, 0);
|
||||||
|
network.set_input_data("Add_1396", input);
|
||||||
|
|
||||||
|
// to make _reset_arguments false
|
||||||
|
network.execute();
|
||||||
|
network.execute();
|
||||||
|
network.set_output_memory("pred/sink_port_0", final_output);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user