[GPU] Gather tree blocked layouts support (#12803)
* added gather tree blocked layout support * gather tree blocked layout support review cleanup * gather tree blocked layout support review cleanup * gather tree blocked layout support review cleanup * gather tree blocked layout support review cleanup * gather tree blocked layout support review cleanup * build fixed
This commit is contained in:
committed by
GitHub
parent
f7005ca297
commit
4324dcf695
@@ -40,7 +40,16 @@ gather_tree_inst::typed_primitive_inst(network& network, gather_tree_node const&
|
||||
"supported border primitive input formats",
|
||||
format::bfyx,
|
||||
format::yxfb,
|
||||
format::byxf);
|
||||
format::byxf,
|
||||
format::b_fs_yx_fsv16,
|
||||
format::b_fs_yx_fsv32,
|
||||
format::bs_fs_yx_bsv4_fsv4,
|
||||
format::bs_fs_yx_bsv8_fsv4,
|
||||
format::bs_fs_yx_bsv8_fsv2,
|
||||
format::bs_fs_yx_bsv4_fsv2,
|
||||
format::bs_fs_yx_bsv16_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv32);
|
||||
|
||||
auto dependencies = node.get_dependencies();
|
||||
|
||||
|
||||
@@ -44,14 +44,23 @@ struct gather_tree_impl : typed_primitive_impl_ocl<gather_tree> {
|
||||
};
|
||||
namespace detail {
|
||||
attach_gather_tree_impl::attach_gather_tree_impl() {
|
||||
implementation_map<gather_tree>::add(impl_types::ocl, gather_tree_impl::create, {
|
||||
std::make_tuple(data_types::i32, format::yxfb),
|
||||
std::make_tuple(data_types::i32, format::bfyx),
|
||||
std::make_tuple(data_types::i32, format::byxf),
|
||||
std::make_tuple(data_types::f32, format::yxfb),
|
||||
std::make_tuple(data_types::f32, format::bfyx),
|
||||
std::make_tuple(data_types::f32, format::byxf),
|
||||
});
|
||||
auto types = {data_types::i32, data_types::f32};
|
||||
auto formats = {
|
||||
format::yxfb,
|
||||
format::bfyx,
|
||||
format::byxf,
|
||||
format::b_fs_yx_fsv16,
|
||||
format::b_fs_yx_fsv32,
|
||||
format::bs_fs_yx_bsv4_fsv4,
|
||||
format::bs_fs_yx_bsv8_fsv4,
|
||||
format::bs_fs_yx_bsv8_fsv2,
|
||||
format::bs_fs_yx_bsv4_fsv2,
|
||||
format::bs_fs_yx_bsv16_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv32,
|
||||
};
|
||||
|
||||
implementation_map<gather_tree>::add(impl_types::ocl, gather_tree_impl::create, types, formats);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@@ -1452,6 +1452,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
||||
prim.type() != cldnn::reverse::type_id() &&
|
||||
prim.type() != cldnn::reorg_yolo::type_id() &&
|
||||
prim.type() != cldnn::scatter_elements_update::type_id() &&
|
||||
prim.type() != cldnn::gather_tree::type_id() &&
|
||||
prim.type() != cldnn::experimental_detectron_detection_output::type_id()) {
|
||||
can_use_fsv16 = false;
|
||||
}
|
||||
@@ -1495,6 +1496,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
||||
prim.type() != cldnn::reverse::type_id() &&
|
||||
prim.type() != cldnn::reorg_yolo::type_id() &&
|
||||
prim.type() != cldnn::scatter_elements_update::type_id() &&
|
||||
prim.type() != cldnn::gather_tree::type_id() &&
|
||||
prim.type() != cldnn::experimental_detectron_detection_output::type_id() &&
|
||||
prim.type() != cldnn::deconvolution::type_id()) {
|
||||
can_use_bs_fs_yx_bsv16_fsv16 = false;
|
||||
|
||||
@@ -19,11 +19,41 @@ ParamsKey GatherTreeKernelRef::GetSupportedKey() const {
|
||||
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
|
||||
k.EnableInputLayout(DataLayout::yxfb);
|
||||
k.EnableOutputLayout(DataLayout::yxfb);
|
||||
|
||||
k.EnableOutputLayout(DataLayout::byxf);
|
||||
k.EnableInputLayout(DataLayout::byxf);
|
||||
|
||||
k.EnableInputLayout(DataLayout::b_fs_yx_fsv32);
|
||||
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv32);
|
||||
|
||||
k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
|
||||
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv4_fsv4);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv4_fsv4);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv8_fsv4);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv8_fsv4);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv8_fsv2);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv8_fsv2);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv4_fsv2);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv4_fsv2);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv32_fsv32);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv32_fsv32);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv32_fsv16);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv32_fsv16);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bs_fs_yx_bsv16_fsv16);
|
||||
k.EnableOutputLayout(DataLayout::bs_fs_yx_bsv16_fsv16);
|
||||
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
|
||||
return k;
|
||||
|
||||
259
src/plugins/intel_gpu/tests/test_cases/gather_tree_gpu_test.cpp
Normal file
259
src/plugins/intel_gpu/tests/test_cases/gather_tree_gpu_test.cpp
Normal file
@@ -0,0 +1,259 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/gather_tree.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <array>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
namespace {
|
||||
template<typename T>
|
||||
struct Params {
|
||||
tensor step_id_tensor;
|
||||
std::vector<T> step_id;
|
||||
tensor parent_id_tensor;
|
||||
std::vector<T> parent_id;
|
||||
tensor max_seq_len_tensor;
|
||||
std::vector<T> max_seq_len;
|
||||
tensor end_token_tensor;
|
||||
std::vector<T> end_token;
|
||||
tensor final_id_tensor;
|
||||
std::vector<T> final_id;
|
||||
std::string testcase_name;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
using ParamsWithLayout = std::tuple<
|
||||
Params<T>,
|
||||
format::type, // source (plain) layout - bfyx
|
||||
format::type // target (blocked) layout
|
||||
>;
|
||||
|
||||
const std::vector<format::type> layouts = {
|
||||
format::yxfb,
|
||||
format::bfyx,
|
||||
format::byxf,
|
||||
format::b_fs_yx_fsv16,
|
||||
format::b_fs_yx_fsv32,
|
||||
format::bs_fs_yx_bsv4_fsv4,
|
||||
format::bs_fs_yx_bsv8_fsv4,
|
||||
format::bs_fs_yx_bsv8_fsv2,
|
||||
format::bs_fs_yx_bsv4_fsv2,
|
||||
format::bs_fs_yx_bsv16_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv32,
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
std::vector<Params<T>> generateParams() {
|
||||
static const std::vector<Params<T>> result = {
|
||||
{
|
||||
tensor(1, 1, 1, 10),
|
||||
std::vector<T>{1, 4, 9, 7, 9, 1, 2, 3, 9, 9},
|
||||
tensor(1, 1, 1, 10),
|
||||
std::vector<T>{1, 4, 9, 7, 9, 1, 2, 3, 9, 9},
|
||||
tensor(1, 1, 1, 1),
|
||||
std::vector<T>{9},
|
||||
tensor(1, 1, 1, 1),
|
||||
std::vector<T>{9},
|
||||
tensor(1, 1, 1, 10),
|
||||
std::vector<T>{1, 4, 9, 7, 9, 1, 2, 3, 9, 9},
|
||||
"gather_tree_1",
|
||||
},
|
||||
{
|
||||
tensor(5, 1, 1, 10),
|
||||
std::vector<T>{
|
||||
1, 4, 9, 7, 9, 1, 2, 3, 9, 2,
|
||||
3, 1, 4, 2, 4, 4, 7, 4, 9, 5,
|
||||
8, 4, 3, 7, 5, 2, 4, 8, 3, 1,
|
||||
5, 7, 9, 4, 5, 6, 4, 2, 9, 2,
|
||||
8, 8, 7, 9, 8, 3, 1, 7, 5, 9},
|
||||
tensor(5, 1, 1, 10),
|
||||
std::vector<T>{
|
||||
1, 4, 9, 7, 9, 1, 2, 3, 9, 2,
|
||||
3, 1, 4, 2, 4, 4, 7, 4, 9, 5,
|
||||
8, 4, 3, 7, 5, 2, 4, 8, 3, 1,
|
||||
5, 7, 9, 4, 5, 6, 4, 2, 9, 2,
|
||||
8, 8, 7, 9, 8, 3, 1, 7, 5, 9},
|
||||
tensor(1, 1, 1, 1),
|
||||
std::vector<T>{9},
|
||||
tensor(1, 1, 1, 1),
|
||||
std::vector<T>{9},
|
||||
tensor(5, 1, 1, 10),
|
||||
std::vector<T>{
|
||||
4, 4, 9, 9, 4, 9, 2, 9, 9, 9,
|
||||
1, 1, 9, 9, 1, 9, 9, 9, 9, 9,
|
||||
1, 1, 9, 9, 1, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9},
|
||||
"gather_tree_5",
|
||||
},
|
||||
{
|
||||
tensor(20, 1, 1, 10),
|
||||
std::vector<T>{
|
||||
1, 4, 9, 7, 9, 1, 2, 3, 9, 2, 3, 1, 4, 2, 4, 4, 7, 4, 9, 5,
|
||||
8, 4, 3, 7, 5, 2, 4, 8, 3, 1, 5, 7, 9, 4, 5, 6, 4, 2, 9, 2,
|
||||
8, 8, 7, 9, 8, 3, 1, 7, 5, 8, 8, 9, 8, 1, 8, 1, 3, 2, 1, 8,
|
||||
7, 1, 6, 4, 7, 9, 4, 5, 2, 7, 3, 3, 2, 7, 8, 8, 4, 1, 1, 7,
|
||||
6, 9, 6, 7, 3, 3, 5, 8, 2, 1, 1, 5, 5, 9, 1, 3, 9, 3, 2, 2,
|
||||
5, 1, 1, 7, 9, 2, 9, 3, 3, 5, 6, 1, 6, 6, 6, 2, 9, 6, 3, 7,
|
||||
3, 1, 5, 4, 9, 7, 5, 4, 5, 1, 7, 5, 1, 6, 2, 5, 8, 9, 1, 6,
|
||||
8, 9, 5, 2, 5, 2, 9, 8, 4, 4, 5, 2, 6, 9, 4, 4, 6, 7, 6, 7,
|
||||
2, 8, 7, 6, 6, 7, 4, 4, 7, 3, 4, 9, 7, 4, 8, 9, 1, 6, 5, 6,
|
||||
1, 2, 8, 9, 1, 5, 4, 6, 9, 4, 4, 3, 7, 9, 7, 6, 3, 1, 7, 9},
|
||||
tensor(20, 1, 1, 10),
|
||||
std::vector<T>{
|
||||
1, 4, 9, 7, 9, 1, 2, 3, 9, 2, 3, 1, 4, 2, 4, 4, 7, 4, 9, 5,
|
||||
8, 4, 3, 7, 5, 2, 4, 8, 3, 1, 5, 7, 9, 4, 5, 6, 4, 2, 9, 2,
|
||||
8, 8, 7, 9, 8, 3, 1, 7, 5, 8, 8, 9, 8, 1, 8, 1, 3, 2, 1, 8,
|
||||
7, 1, 6, 4, 7, 9, 4, 5, 2, 7, 3, 3, 2, 7, 8, 8, 4, 1, 1, 7,
|
||||
6, 9, 6, 7, 3, 3, 5, 8, 2, 1, 1, 5, 5, 9, 1, 3, 9, 3, 2, 2,
|
||||
5, 1, 1, 7, 9, 2, 9, 3, 3, 5, 6, 1, 6, 6, 6, 2, 9, 6, 3, 7,
|
||||
3, 1, 5, 4, 9, 7, 5, 4, 5, 1, 7, 5, 1, 6, 2, 5, 8, 9, 1, 6,
|
||||
8, 9, 5, 2, 5, 2, 9, 8, 4, 4, 5, 2, 6, 9, 4, 4, 6, 7, 6, 7,
|
||||
2, 8, 7, 6, 6, 7, 4, 4, 7, 3, 4, 9, 7, 4, 8, 9, 1, 6, 5, 6,
|
||||
1, 2, 8, 9, 1, 5, 4, 6, 9, 4, 4, 3, 7, 9, 7, 6, 3, 1, 7, 9},
|
||||
tensor(1, 1, 1, 1),
|
||||
std::vector<T>{9},
|
||||
tensor(1, 1, 1, 1),
|
||||
std::vector<T>{9},
|
||||
tensor(20, 1, 1, 10),
|
||||
std::vector<T>{
|
||||
9, 4, 9, 4, 4, 4, 9, 4, 9, 9, 9, 1, 9, 1, 1, 1, 9, 1, 9, 9,
|
||||
9, 1, 9, 1, 1, 1, 9, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9},
|
||||
"gather_tree_10",
|
||||
},
|
||||
};
|
||||
return result;
|
||||
}
|
||||
|
||||
struct PrintToStringParamName {
|
||||
template<class T>
|
||||
std::string operator()(const testing::TestParamInfo<ParamsWithLayout<T> > ¶m) {
|
||||
std::stringstream buf;
|
||||
Params<T> p;
|
||||
format::type plain_layout;
|
||||
format::type target_layout;
|
||||
std::tie(p, plain_layout, target_layout) = param.param;
|
||||
buf << " test case " << p.testcase_name
|
||||
<< " plain layout " << plain_layout
|
||||
<< " target layout " << target_layout;
|
||||
return buf.str();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct gather_tree_test
|
||||
: public ::testing::TestWithParam<ParamsWithLayout<T> > {
|
||||
public:
|
||||
void test() {
|
||||
const auto data_type = type_to_data_type<T>::value;
|
||||
Params<T> params;
|
||||
format::type plain_layout;
|
||||
format::type target_layout;
|
||||
|
||||
std::tie(params, plain_layout, target_layout) = this->GetParam();
|
||||
|
||||
auto &engine = get_test_engine();
|
||||
topology topology;
|
||||
|
||||
auto step_input = engine.allocate_memory({data_type, plain_layout, params.step_id_tensor});
|
||||
set_values(step_input, params.step_id);
|
||||
const std::string step_id = "step_id";
|
||||
topology.add(input_layout(step_id, step_input->get_layout()));
|
||||
const std::string reorder_step_id = step_id + "_reordered";
|
||||
topology.add(reorder(reorder_step_id, step_id, target_layout, data_type));
|
||||
|
||||
auto parent_input = engine.allocate_memory({data_type, plain_layout, params.parent_id_tensor});
|
||||
set_values(parent_input, params.parent_id);
|
||||
const std::string parent_id = "parent_id";
|
||||
topology.add(input_layout(parent_id, parent_input->get_layout()));
|
||||
const std::string reorder_parent_id = parent_id + "_reordered";
|
||||
topology.add(reorder(reorder_parent_id, parent_id, target_layout, data_type));
|
||||
|
||||
auto max_seq_len_input = engine.allocate_memory({data_type, plain_layout, params.max_seq_len_tensor});
|
||||
set_values(max_seq_len_input, params.max_seq_len);
|
||||
const std::string max_seq_len_id = "max_seq_len_id";
|
||||
topology.add(input_layout(max_seq_len_id, max_seq_len_input->get_layout()));
|
||||
const std::string reorder_max_seq_len_id = max_seq_len_id + "_reordered";
|
||||
topology.add(reorder(reorder_max_seq_len_id, max_seq_len_id, target_layout, data_type));
|
||||
|
||||
auto end_token_input = engine.allocate_memory({data_type, plain_layout, params.end_token_tensor});
|
||||
set_values(end_token_input, params.end_token);
|
||||
const std::string end_token_id = "end_token_id";
|
||||
topology.add(input_layout(end_token_id, end_token_input->get_layout()));
|
||||
const std::string reorder_end_token_id = end_token_id + "_reordered";
|
||||
topology.add(reorder(reorder_end_token_id, end_token_id, target_layout, data_type));
|
||||
|
||||
const std::string result_id = "result_id";
|
||||
topology.add(gather_tree(result_id, reorder_step_id, reorder_parent_id, reorder_max_seq_len_id, reorder_end_token_id));
|
||||
|
||||
const primitive_id reorder_result_id = result_id + "_reordered";
|
||||
topology.add(reorder(reorder_result_id, result_id, plain_layout, data_type));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data(step_id, step_input);
|
||||
network.set_input_data(parent_id, parent_input);
|
||||
network.set_input_data(max_seq_len_id, max_seq_len_input);
|
||||
network.set_input_data(end_token_id, end_token_input);
|
||||
|
||||
auto result = network.execute();
|
||||
|
||||
auto out_mem = result.at(reorder_result_id).get_memory();
|
||||
cldnn::mem_lock<T> out_ptr(out_mem, get_test_stream());
|
||||
|
||||
ASSERT_EQ(params.final_id_tensor.count(), out_ptr.size());
|
||||
|
||||
for (size_t i = 0; i < params.final_id.size(); ++i) {
|
||||
EXPECT_NEAR(params.final_id[i], out_ptr[i], 0.005) << "at i = " << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using gather_tree_test_f32 = gather_tree_test<float>;
|
||||
using gather_tree_test_int32 = gather_tree_test<int32_t>;
|
||||
|
||||
TEST_P(gather_tree_test_f32, test_case) {
|
||||
ASSERT_NO_FATAL_FAILURE(test());
|
||||
}
|
||||
|
||||
TEST_P(gather_tree_test_int32, test_case) {
|
||||
ASSERT_NO_FATAL_FAILURE(test());
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(gather_tree,
|
||||
gather_tree_test_f32,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(generateParams<float>()),
|
||||
::testing::Values(format::bfyx),
|
||||
::testing::ValuesIn(layouts)),
|
||||
PrintToStringParamName());
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(gather_tree,
|
||||
gather_tree_test_int32,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(generateParams<int32_t>()),
|
||||
::testing::Values(format::bfyx),
|
||||
::testing::ValuesIn(layouts)),
|
||||
PrintToStringParamName());
|
||||
Reference in New Issue
Block a user