[GPU] Fix multiple output issue in get_output_layout() (#19186)

Author: Paul Youngsoo Ahn, 2023-08-15 05:49:21 +09:00 (committed by GitHub)
Parent: a0d1b91a78
Commit: e2db808495
4 changed files with 51 additions and 7 deletions


@@ -139,7 +139,8 @@ public:
             std::shared_ptr<ov::threading::IStreamsExecutor> task_executor,
             bool is_internal);
-    explicit program(engine& engine);
+    explicit program(engine& engine,
+                     const ExecutionConfig& config = {});
     ~program();
 
     engine& get_engine() const { return _engine; }
     const ExecutionConfig& get_config() const { return _config; }
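
The header change above adds an optional ExecutionConfig parameter to the bare-engine constructor; the default argument keeps existing program(engine) call sites source-compatible. A minimal usage sketch, mirroring the new unit test at the bottom of this commit (get_test_engine() and get_test_default_config() are helpers from the GPU plugin's test suite):

    // Build a program whose _config comes from the caller instead of a
    // default-constructed ExecutionConfig.
    auto& engine = get_test_engine();
    ExecutionConfig config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
    cldnn::program prog(engine, config);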


@@ -191,10 +191,11 @@ program::program(engine& engine_ref,
     build_program(is_internal);
 }
 
-program::program(engine& engine)
+program::program(engine& engine,
+                 const ExecutionConfig& config)
     : _engine(engine),
       _stream(_engine.create_stream({})),
-      _config(),
+      _config(config),
       processing_order() {
     init_primitives();
     _config.apply_user_properties(_engine.get_device_info());
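
Note the matching definition change: default-constructing _config() meant a bare program discarded any caller-supplied configuration, so options such as ov::intel_gpu::allow_new_shape_infer could never be enabled on it; the forwarded config is still finalized against the device info by apply_user_properties(), as before.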


@@ -288,8 +288,8 @@ layout program_node::get_output_layout(bool invalidate_users_if_changed, size_t idx) {
     if (valid_output_layouts[idx])
         return output_layouts[idx];
 
-    auto new_layout = calc_output_layout();
-    set_output_layout(new_layout, invalidate_users_if_changed, idx);
+    auto new_layouts = calc_output_layouts();
+    set_output_layouts(new_layouts, invalidate_users_if_changed);
     return output_layouts[idx];
 }
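
This hunk is the actual fix. The old path computed only output 0's layout via the single-output calc_output_layout() and then stamped that one layout onto whichever index was queried, so a second output (for example TopK indices, which should be i32) inherited output 0's data type. The new path computes one layout per output with calc_output_layouts() and stores them all. A rough sketch of the storage step, assuming set_output_layouts() fills the same output_layouts / valid_output_layouts members seen above (not the actual cldnn implementation; user invalidation is elided):

    // Sketch only: cache one layout per output index so a later
    // get_output_layout(idx) is valid for every idx, not just idx 0.
    for (size_t i = 0; i < new_layouts.size(); ++i) {
        output_layouts[i] = new_layouts[i];  // layout computed for output i
        valid_output_layouts[i] = true;      // mark index i as cached
    }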


@@ -8,11 +8,11 @@
 #include <intel_gpu/primitives/permute.hpp>
 #include <intel_gpu/primitives/input_layout.hpp>
 #include <intel_gpu/primitives/mutable_data.hpp>
+#include <arg_max_min_inst.h>
 
 #include "test_utils.h"
+#include "program_wrapper.h"
 
 using namespace cldnn;
 using namespace ::tests;
@@ -930,3 +930,45 @@ TEST(arg_max_min_gpu, dynamic) {
         ASSERT_FLOAT_EQ(output_ptr[i], i < (out_size / 2) ? 0 : 1);
     }
 }
+
+TEST(arg_max_min_test, check_second_output_data_type) {
+    auto& engine = get_test_engine();
+    ExecutionConfig config = get_test_default_config(engine);
+    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+    cldnn::program prog(engine, config);
+
+    std::vector<std::shared_ptr<primitive>> input_prims;
+    std::vector<input_info> input_prim_ids;
+    {
+        auto prim_id = "input";
+        static const int32_t x_size = 1, y_size = 1, feature_num = 2000, batch_num = 1;
+        auto input_static = layout{{batch_num, feature_num, y_size, x_size}, data_types::f16, format::bfyx};
+        auto input_layout_prim = std::make_shared<input_layout>(prim_id, input_static);
+        input_prims.push_back(input_layout_prim);
+        input_prim_ids.push_back(input_info(prim_id));
+    }
+    {
+        auto prim_id = "top_k";
+        auto top_k_input = layout{{1,1,1,1}, data_types::f16, format::bfyx};
+        auto top_k_prim = std::make_shared<input_layout>(prim_id, top_k_input);
+        input_prims.push_back(top_k_prim);
+        input_prim_ids.push_back(input_info(prim_id));
+    }
+
+    auto arg_max_min_prim = std::make_shared<arg_max_min>("output", input_prim_ids,
+                                                          ov::op::TopKMode::MAX, 400, 1,
+                                                          ov::op::TopKSortType::SORT_VALUES, true, false, padding(),
+                                                          data_types::f16, 2);
+    arg_max_min_prim->output_paddings = {padding(), padding()};
+    arg_max_min_prim->output_data_types = {data_types::f16, data_types::i32};
+
+    auto& arg_max_min_node = prog.get_or_create(arg_max_min_prim);
+    for (auto& prim : input_prims) {
+        auto& input_layout_node = prog.get_or_create(prim);
+        program_wrapper::add_connection(prog, input_layout_node, arg_max_min_node);
+    }
+
+    auto second_output_layout = arg_max_min_node.get_output_layout(false, 1);
+    ASSERT_EQ(second_output_layout.data_type, data_types::i32);
+}
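
The new test exercises exactly this path: it builds an arg_max_min node with two outputs (f16 values and i32 indices; the trailing 2 in the constructor call is presumably the output count), then queries output index 1 before any layout has been cached. With the old calc_output_layout() path the returned layout would have carried output 0's f16 data type; with the fix, the assertion sees i32.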